Model.py

"""
================================================================================
model.py - 多模态大语言模型定义
================================================================================

整体架构:
  图像 (B,3,64,64)
    └─► PatchEmbedding  ──► 可学习位置编码 ──► VisionEncoder ──► 线性投影层
                                                              │
  文本 token ids (B, T)                                      │
    └─► TokenEmbedding ──► PositionalEncoding ────────────────┤
                                                              │
                                                    Transformer 解码器(带 Causal Mask)
                                                              │
                                                            lm_head
                                                              │
                                                        下一个 token 预测

================================================================================
"""

# -----------------------------------------------------------------------------
# 导入标准库和 PyTorch
# -----------------------------------------------------------------------------

import torch                      # PyTorch 核心库,提供张量、自动微分、神经网络模块
import torch.nn as nn              # nn 模块,包含所有神经网络层(Linear, Conv2d 等)
import math                        # 数学库,提供 log、exp 等函数,用于位置编码计算


# ================================================================================
# 第 1 部分:视觉编码器(简化版 ViT - Vision Transformer)
# ================================================================================
# ViT 的核心思想:
#   1. 把图像切分成固定大小的小块(patch),例如 64x64 的图像用 8x8 的 patch
#   2. 每个 patch 展平后通过一个线性层映射为一个向量(patch embedding)
#   3. 加入可学习的位置编码,让模型知道每个 patch 出现在图像的哪个位置
#   4. 通过 Transformer Encoder 让各 patch 之间互相"看"到彼此,提取空间关系
#   5. 输出所有 patch 的特征向量序列

class PatchEmbedding(nn.Module):
    """
    将图像切分为 patch,并将其线性投影为向量序列。

    实现方式:用卷积核大小 = patch 大小、步长 = patch 大小的卷积层,
    一次卷积操作等价于"切分 patch + 线性映射"两步操作。

    例如:输入 (B,3,64,64),patch_size=8,embed_dim=128
         输出 (B, 64, 128),因为 64/8=8,8*8=64 个 patch
    """

    def __init__(self, image_size=64, patch_size=8, in_channels=3, embed_dim=128):
        """
        初始化 patch 嵌入层。

        Args:
            image_size:   输入图像的宽高(假设为正方形,单位:像素)
            patch_size:   每个 patch 的宽高(图像被切成的块大小)
            in_channels: 输入图像通道数,彩色图像为 3(RGB)
            embed_dim:    每个 patch 映射到的向量维度(也是 Transformer 的隐层维度)
        """
        super().__init__()  # 调用父类 nn.Module 的初始化方法,确保模块正确注册

        # 计算图像能被切成多少个 patch
        # 以 64x64 图像、8x8 patch 为例:(64//8)**2 = 64 个 patch
        self.num_patches = (image_size // patch_size) ** 2

        # -----------------------------------------------------------
        # 用卷积实现"切分 patch + 线性映射"
        # - kernel_size=patch_size:每个卷积核正好覆盖一个 patch
        # - stride=patch_size:卷积核不重叠,每次移动一个 patch 的距离
        # - 输出通道数 = embed_dim:每个 patch 被映射为 embed_dim 维向量
        # -----------------------------------------------------------
        self.proj = nn.Conv2d(
            in_channels,              # 输入通道数(3 = RGB)
            embed_dim,                # 输出通道数 = embed_dim(每个 patch 的向量维度)
            kernel_size=patch_size,   # 卷积核大小 = patch 大小
            stride=patch_size         # 步长 = patch 大小,保证不重叠切分
        )

    def forward(self, x):
        """
        前向传播:将图像转换为 patch 向量序列。

        Args:
            x: 输入图像张量,shape 为 (B, C, H, W)
               - B: batch size(一个 batch 中的图像数量)
               - C: 通道数(3 = RGB)
               - H, W: 图像的高和宽

        Returns:
            输出张量,shape 为 (B, num_patches, embed_dim)
               - B: batch size(不变)
               - num_patches: patch 总数 = (H/patch_size) * (W/patch_size)
               - embed_dim: 每个 patch 的向量维度
        """
        # -----------------------------------------------------------------
        # 第 1 步:卷积提取 patch
        # 输入 (B, C, H, W) → 输出 (B, embed_dim, H/p, W/p)
        # 以 64x64 图像、8x8 patch 为例:(B, 3, 64, 64) → (B, 128, 8, 8)
        # -----------------------------------------------------------------
        x = self.proj(x)

        # -----------------------------------------------------------------
        # 第 2 步:展平空间维度
        # (B, embed_dim, H/p, W/p) → (B, embed_dim, H/p * W/p)
        # flatten(2) 表示从第 2 维开始展平(即把 H/p 和 W/p 合并)
        # -----------------------------------------------------------------
        x = x.flatten(2)               # → (B, embed_dim, num_patches)

        # -----------------------------------------------------------------
        # 第 3 步:调换维度顺序
        # (B, embed_dim, num_patches) → (B, num_patches, embed_dim)
        # transpose(1, 2) 交换第 1 维和第 2 维
        # 最终形状:(B, num_patches, embed_dim),每行代表一个 patch 的向量
        # -----------------------------------------------------------------
        x = x.transpose(1, 2)          # → (B, num_patches, embed_dim)

        return x                       # 返回 patch 向量序列


class VisionEncoder(nn.Module):
    """
    简化版 ViT 视觉编码器。

    由三部分组成:
      1. PatchEmbedding:将图像切分为 patch 并映射为向量
      2. 位置编码(Positional Embedding):为每个 patch 添加位置信息
      3. Transformer Encoder:让各 patch 之间通过自注意力机制交换信息

    Transformer Encoder 不使用 causal mask,各 patch 可以看到图像中的所有区域,
    类似于标准 ViT 的做法。
    """

    def __init__(self, image_size=64, patch_size=8, embed_dim=128, num_heads=4, num_layers=2):
        """
        初始化视觉编码器。

        Args:
            image_size:  输入图像的宽高(像素)
            patch_size: 每个 patch 的大小
            embed_dim:  特征维度(Transformer 的 d_model)
            num_heads:  注意力头数量(多头注意力的头数)
            num_layers: Transformer Encoder 的层数
        """
        super().__init__()  # 调用父类初始化

        # -----------------------------------------------------------------
        # 第 1 部分:Patch 嵌入层
        # 将 (B, 3, 64, 64) 的图像转换为 (B, num_patches, embed_dim) 的向量序列
        # -----------------------------------------------------------------
        self.patch_embed = PatchEmbedding(
            image_size=image_size,  # 输入图像尺寸(如 64)
            patch_size=patch_size,  # patch 大小(如 8)
            in_channels=3,          # RGB 三通道
            embed_dim=embed_dim     # 输出向量维度
        )

        # 计算 patch 总数,用于初始化位置编码
        # 例如:64/8=8,每个维度 8 个 patch,总共 8*8=64 个 patch
        num_patches = (image_size // patch_size) ** 2

        # -----------------------------------------------------------------
        # 第 2 部分:可学习的位置编码(Learnable Positional Embedding)
        # 形状:(1, num_patches, embed_dim)
        #   - 第 1 维为 1:所有 batch 共享同一套位置编码
        #   - 第 2 维为 num_patches:每个 patch 对应一个位置向量
        #   - 第 3 维为 embed_dim:与 patch 向量维度对齐,可以直接相加
        # nn.Parameter 表示这是一个可学习的参数,PyTorch 会自动加入梯度计算
        # -----------------------------------------------------------------
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches, embed_dim)  # 初始化为全 0
        )

        # -----------------------------------------------------------------
        # 第 3 部分:Transformer Encoder 层
        # 自注意力机制允许每个 patch "看到"图像中的所有其他 patch,
        # 从而学习到全局的空间关系(如物体各部分之间的相对位置)。
        # -----------------------------------------------------------------
        # 单层 Transformer Encoder 的配置
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,              # 输入/输出向量维度,必须与 embed_dim 一致
            nhead=num_heads,                 # 注意力头数量(如 4)
            # 前馈神经网络层(FFN)的隐藏层维度,通常是 d_model 的 4 倍
            dim_feedforward=embed_dim * 4,
            batch_first=True,                # 输入/输出张量的第 0 维是 batch((B, T, D))
            norm_first=True                  # 先做 LayerNorm 再做注意力(Pre-LN 风格,更稳定)
        )

        # 将单层 Encoder 重复 num_layers 次,构成深层网络
        self.transformer = nn.TransformerEncoder(
            encoder_layer,    # 传入上面定义的单层配置
            num_layers=num_layers  # 重复的层数
        )

        # -----------------------------------------------------------------
        # 第 4 部分:输出 LayerNorm
        # 在 Transformer 之后加一层 LayerNorm,稳定训练
        # -----------------------------------------------------------------
        self.norm = nn.LayerNorm(embed_dim)

        # -----------------------------------------------------------------
        # 权重初始化
        # 对位置编码用截断正态分布初始化(均值 0,标准差 0.02)
        # 这是一种常见的 Transformer 初始化策略
        # -----------------------------------------------------------------
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, images):
        """
        前向传播:将图像编码为 patch 特征序列。

        Args:
            images: 输入图像张量,shape 为 (B, 3, H, W)

        Returns:
            输出张量,shape 为 (B, num_patches, embed_dim)
            每个 patch 对应一个 embed_dim 维的特征向量
        """
        # -----------------------------------------------------------------
        # 第 1 步:Patch 嵌入
        # (B, 3, H, W) → (B, num_patches, embed_dim)
        # 将图像切分为 patch 并线性投影为向量
        # -----------------------------------------------------------------
        x = self.patch_embed(images)

        # -----------------------------------------------------------------
        # 第 2 步:加上位置编码
        # 位置编码与 patch 向量维度相同,直接相加即可
        # self.pos_embed 形状为 (1, num_patches, embed_dim)
        # 广播机制会自动扩展到 (B, num_patches, embed_dim)
        # -----------------------------------------------------------------
        x = x + self.pos_embed

        # -----------------------------------------------------------------
        # 第 3 步:过 Transformer Encoder
        # 自注意力让各 patch 之间互相交流信息,学习图像的全局特征
        # -----------------------------------------------------------------
        x = self.transformer(x)

        # -----------------------------------------------------------------
        # 第 4 步:LayerNorm
        # 标准化输出,稳定训练
        # -----------------------------------------------------------------
        x = self.norm(x)

        return x  # (B, num_patches, embed_dim)


# ================================================================================
# 第 2 部分:文本位置编码(正弦/余弦位置编码)
# ================================================================================
# 这是原始 Transformer 论文(Vaswani et al., 2017)提出的位置编码方法。
# 原理:用不同频率的正弦和余弦函数为每个位置生成唯一的编码。
# 优势:
#   - 不需要学习,推理速度快
#   - 可以处理任意长度的序列(外推能力)
#   - 能让模型学到相对位置关系(因为 sin(a+b) 和 cos(a+b) 与位置差有关)

class PositionalEncoding(nn.Module):
    """
    正弦余弦位置编码。

    为文本序列中的每个位置生成一个固定的位置向量,
    与 Token Embedding 相加后输入 Transformer,使模型感知词的顺序。

    公式(来自《Attention is All You Need》):
        PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
        PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

    其中 pos 是位置,i 是维度索引,偶数维度用 sin,奇数维度用 cos。
    """

    def __init__(self, embed_dim, max_len=512):
        """
        初始化位置编码。

        Args:
            embed_dim: 位置向量的维度,必须与 Token Embedding 维度一致
            max_len:   最大支持的序列长度(如 512)
        """
        super().__init__()  # 调用父类初始化

        # -----------------------------------------------------------------
        # 创建位置编码矩阵
        # 形状:(max_len, embed_dim)
        # 第 0 维是位置(0 到 max_len-1),第 1 维是维度(0 到 embed_dim-1)
        # -----------------------------------------------------------------
        pe = torch.zeros(max_len, embed_dim)  # 初始化为全 0 矩阵

        # -----------------------------------------------------------------
        # 创建位置索引向量
        # arange(max_len):生成 [0, 1, 2, ..., max_len-1]
        # unsqueeze(1):在第 1 维增加一个维度,变成 (max_len, 1)
        # 这样可以与 div_term 广播相乘,得到每个位置对应的编码值
        # -----------------------------------------------------------------
        position = torch.arange(max_len).unsqueeze(1).float()  # (max_len, 1)

        # -----------------------------------------------------------------
        # 计算除数项 div_term
        # 公式中的 10000^(2i/d_model)
        # 用 exp(log(10000) * (-2i/d_model)) 实现,避免数值溢出
        # torch.arange(0, embed_dim, 2) 生成 [0, 2, 4, ...](偶数维度索引)
        # -----------------------------------------------------------------
        div_term = torch.exp(
            torch.arange(0, embed_dim, 2).float()  # [0, 2, 4, ..., embed_dim-2]
            * (-math.log(10000.0) / embed_dim)     # -log(10000)/d * [0, 2, 4, ...]
        )   # 结果形状:(embed_dim/2,)

        # -----------------------------------------------------------------
        # 填充位置编码矩阵

        # 对于偶数维度(2i):用 sin(position * div_term[i])
        # position 形状 (max_len, 1),div_term 形状 (embed_dim/2,)
        # 广播后形状 (max_len, embed_dim/2)
        # -----------------------------------------------------------------
        pe[:, 0::2] = torch.sin(position * div_term)

        # -----------------------------------------------------------------
        # 对于奇数维度(2i+1):用 cos(position * div_term[i])
        # 同理,用余弦函数生成奇数维度的编码
        # -----------------------------------------------------------------
        pe[:, 1::2] = torch.cos(position * div_term)

        # -----------------------------------------------------------------
        # 注册为 buffer(不参与梯度计算,但会随模型保存/加载)
        # unsqueeze(0):在第 0 维增加 batch 维度 → (1, max_len, embed_dim)
        # register_buffer 确保它在 device 迁移时自动同步
        # -----------------------------------------------------------------
        self.register_buffer('pe', pe.unsqueeze(0))  # 形状:(1, max_len, embed_dim)

    def forward(self, x):
        """
        将位置编码加到输入序列上。

        Args:
            x: 输入张量,shape 为 (B, T, embed_dim)

        Returns:
            输出张量,shape 为 (B, T, embed_dim)
        """
        # self.pe[:, :x.size(1)] 根据实际序列长度截取对应长度的位置编码
        # x.size(1) 获取序列长度 T
        # self.pe 形状 (1, max_len, embed_dim),自动广播到 (1, T, embed_dim)
        # 然后直接与输入 x 相加(因为维度相同)
        return x + self.pe[:, :x.size(1)]


# ================================================================================
# 第 3 部分:多模态大语言模型( Multimodal LLM)
# ================================================================================
# 这是整个模型的核心,把视觉编码器和语言模型融合在一起。
# 融合策略:将图像特征作为"前缀"拼接到文本 token 序列的前面,
#           类似于 GPT 类的自回归语言模型,但输入多了图像信息。

class MultimodalLLM(nn.Module):
    """
    多模态大语言模型。

    输入:
        - 图像:shape (B, 3, H, W)
        - 文本:shape (B, T),T 为 token 序列长度

    输出:
        - logits:shape (B, T, vocab_size)
          每个文本位置对应词表中每个 token 的预测分数(未归一化的概率)
          取 argmax 即为预测的 token

    数据流:
        图像 → VisionEncoder → 线性投影
              ─────────────────────────────────────────────────►
                                                                    │
        文本 token ids → TokenEmbedding → PositionalEncoding ───────►│
                                                                    │
                                                              拼接 (concat)
                                                                    │
                                                                    ▼
                                                           Transformer 解码器
                                                          (带 Causal Mask)
                                                                    │
                                                                    ▼
                                                           只取文本部分输出
                                                                    │
                                                                    ▼
                                                               lm_head
                                                                    │
                                                                    ▼
                                                          下一个 token logits
    """

    def __init__(
        self,
        vocab_size=1000,        # 词表大小(模型能识别的不同 token 总数)
        embed_dim=128,          # 特征维度(所有 embedding 和 Transformer 的向量维度)
        num_heads=4,            # 注意力头数量
        num_layers=4,           # Transformer 解码器层数
        image_size=64,          # 输入图像尺寸
        patch_size=8,           # 图像切分的 patch 大小
        max_seq_len=128,       # 最大文本序列长度(用于初始化位置编码表)
    ):
        """
        初始化多模态大语言模型的所有子模块。
        """
        super().__init__()  # 调用父类初始化

        self.embed_dim = embed_dim  # 保存维度配置,供其他方法使用

        # -------------------------------------------------------------------------
        # 第 1 部分:视觉编码器(Vision Encoder)
        # 输入 64x64 RGB 图像,输出每个 patch 的特征向量
        # -------------------------------------------------------------------------
        self.vision_encoder = VisionEncoder(
            image_size=image_size,   # 图像尺寸(64)
            patch_size=patch_size,  # patch 大小(8)
            embed_dim=embed_dim,    # 特征维度(128)
            num_heads=num_heads,    # 注意力头数(4)
            num_layers=2,           # 视觉编码器用 2 层(比语言模型少,语言模型更深)
        )

        # -------------------------------------------------------------------------
        # 第 2 部分:视觉特征投影层
        # 将视觉编码器输出的特征向量映射到语言模型的维度
        # 如果视觉编码器和语言模型的 embed_dim 相同,这层就是恒等映射(等价于没有)
        # 但保留这一层可以增加灵活性,方便后续扩展
        # -------------------------------------------------------------------------
        self.vision_proj = nn.Linear(embed_dim, embed_dim)

        # -------------------------------------------------------------------------
        # 第 3 部分:文本 Token Embedding
        # 将 token id(如 0, 1, 2, ..., vocab_size-1)映射为 embed_dim 维向量
        # 例如:token_id=5 → lookup → 一个 128 维的向量
        # -------------------------------------------------------------------------
        self.token_embed = nn.Embedding(vocab_size, embed_dim)

        # -------------------------------------------------------------------------
        # 第 4 部分:文本位置编码
        # 与图像位置编码不同,文本需要区分 token 的先后顺序
        # -------------------------------------------------------------------------
        self.pos_encoding = PositionalEncoding(embed_dim, max_len=max_seq_len + 64)

        # -------------------------------------------------------------------------
        # 第 5 部分:Transformer 解码器
        # 使用与 Transformer Encoder 相同的结构,但配合 Causal Mask 实现解码
        # Causal Mask:每个位置只能看到自己和之前的 token,不能"偷看"未来
        # -------------------------------------------------------------------------
        decoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,              # 输入/输出向量维度
            nhead=num_heads,                 # 注意力头数量
            dim_feedforward=embed_dim * 4,  # FFN 隐藏层维度
            batch_first=True,                # batch 在第 0 维
            norm_first=True                  # Pre-LN 风格(先 LayerNorm 再注意力)
        )

        self.transformer = nn.TransformerEncoder(
            decoder_layer,   # 单层配置
            num_layers=num_layers  # 堆叠的层数
        )

        # -------------------------------------------------------------------------
        # 第 6 部分:输出 LayerNorm
        # 在 Transformer 之后做层归一化
        # -------------------------------------------------------------------------
        self.norm = nn.LayerNorm(embed_dim)

        # -------------------------------------------------------------------------
        # 第 7 部分:语言模型头(LM Head)
        # 将 Transformer 输出的 embed_dim 维向量映射回词表维度
        # 输出 shape:(B, T, vocab_size),每个位置对应词表中每个 token 的分数
        # -------------------------------------------------------------------------
        self.lm_head = nn.Linear(embed_dim, vocab_size)

        # -------------------------------------------------------------------------
        # 权重初始化
        # -------------------------------------------------------------------------
        self._init_weights()

    def _init_weights(self):
        """
        权重初始化。

        - Token Embedding 用截断正态分布(标准差 0.02)
        - LM Head 的偏置初始化为 0
        """
        # 用正态分布初始化 token embedding
        nn.init.normal_(self.token_embed.weight, std=0.02)

        # 将 lm_head 的偏置初始化为 0(如果不初始化,默认也是 0,这里显式写出来)
        nn.init.zeros_(self.lm_head.bias)

    def _causal_mask(self, seq_len, device):
        """
        生成因果掩码(Causal Mask / Attention Mask)。

        目的:防止解码器在预测第 t 个 token 时看到第 t+1、t+2 ... 个 token,
              确保模型只能利用当前及之前的信息来预测下一个 token。

        实现:生成一个上三角矩阵,上三角(含对角线)为 -inf,下三角为 0。
              在计算注意力时,-inf 的位置会被 softmax 变成 0(不参与注意力加权)。

        示例(seq_len=4):
              [[0, -inf, -inf, -inf],   # 位置 0 只能看自己
               [0,     0, -inf, -inf],   # 位置 1 只能看 0,1
               [0,     0,     0, -inf],   # 位置 2 只能看 0,1,2
               [0,     0,     0,     0]]   # 位置 3 可以看所有

        Args:
            seq_len: 序列长度
            device:  张量所在设备(cpu/cuda/mps)

        Returns:
            mask: shape (seq_len, seq_len) 的掩码张量
        """
        # -----------------------------------------------------------------
        # torch.triu(input, diagonal=1):
        # 返回输入矩阵的上三角部分(不含对角线),其余位置填 0
        # 例如 triu(ones(4,4), diagonal=1):
        # [[0, 1, 1, 1],
        #  [0, 0, 1, 1],
        #  [0, 0, 0, 1],
        #  [0, 0, 0, 0]]
        # -----------------------------------------------------------------
        mask = torch.triu(
            torch.ones(seq_len, seq_len, device=device),  # 全 1 矩阵
            diagonal=1                                   # 从对角线右上方开始为 1
        )

        # -----------------------------------------------------------------
        # 将值为 1 的位置替换为 -inf
        # masked_fill(condition, value):在 condition 为 True 的位置填充 value
        # -----------------------------------------------------------------
        mask = mask.masked_fill(mask == 1, float('-inf'))

        return mask  # (seq_len, seq_len)

    def forward(self, images, input_ids):
        """
        前向传播。

        Args:
            images:    输入图像张量,shape (B, 3, H, W)
            input_ids: 输入文本 token id 序列,shape (B, T)

        Returns:
            logits: 每个文本位置的 token 预测分数,shape (B, T, vocab_size)
                    对第 t 个位置,logits[t] 是词表中每个 token 的分数
        """
        B, T = input_ids.shape  # B: batch size, T: 文本序列长度

        # =========================================================================
        # 第 1 步:编码图像
        # images: (B, 3, H, W) → vision_encoder → (B, num_patches, embed_dim)
        # =========================================================================
        img_features = self.vision_encoder(images)
        img_features = self.vision_proj(img_features)  # 线性投影到语言模型维度
        num_img_tokens = img_features.size(1)         # 记录图像 token 数量

        # =========================================================================
        # 第 2 步:编码文本
        # input_ids: (B, T) → token_embed → (B, T, embed_dim)
        # → pos_encoding → (B, T, embed_dim)
        # =========================================================================
        text_features = self.token_embed(input_ids)  # Token Embedding
        text_features = self.pos_encoding(text_features)  # 加上位置编码

        # =========================================================================
        # 第 3 步:拼接图像特征和文本特征
        # 图像特征作为"前缀"(prefix),文本跟在后面
        # torch.cat(dim=1):在第 1 维(序列维)拼接
        # 结果:(B, num_img_tokens + T, embed_dim)
        # =========================================================================
        x = torch.cat([img_features, text_features], dim=1)
        total_len = num_img_tokens + T  # 拼接后的总序列长度

        # =========================================================================
        # 第 4 步:生成因果掩码
        # 确保文本部分不能看到未来的 token
        # 注意:图像部分(序列的前 num_img_tokens 个位置)也有 causal mask,
        #       但由于图像本身是对称的 patch 集合,这个限制影响不大
        # =========================================================================
        mask = self._causal_mask(total_len, images.device)

        # =========================================================================
        # 第 5 步:过 Transformer 解码器
        # 自注意力 + 因果掩码,实现标准的 Transformer 解码器行为
        # =========================================================================
        x = self.transformer(x, mask=mask)  # (B, total_len, embed_dim)
        x = self.norm(x)                     # 层归一化

        # =========================================================================
        # 第 6 步:取文本部分的输出
        # x 的前 num_img_tokens 个位置是图像 token,后 T 个是文本 token
        # 只取文本部分用于预测下一个 token
        # =========================================================================
        text_out = x[:, num_img_tokens:, :]  # (B, T, embed_dim)

        # =========================================================================
        # 第 7 步:映射到词表
        # lm_head 将每个位置的 embed_dim 维向量映射为 vocab_size 维向量
        # logits[t][v] = 第 t 个位置是词表第 v 个 token 的分数
        # =========================================================================
        logits = self.lm_head(text_out)  # (B, T, vocab_size)

        return logits


# ================================================================================
# 测试代码:验证模型结构是否正确
# ================================================================================
if __name__ == "__main__":
    """
    当直接运行 `python model.py` 时,执行此测试代码。
    验证模型可以正常前向传播,并打印各层张量形状。
    """
    # 创建模型实例,vocab_size=1000 表示词表有 1000 个不同的 token
    model = MultimodalLLM(vocab_size=1000, embed_dim=128)

    # 创建随机输入数据(模拟一个 batch 的数据)
    # 图像:(batch=2, channels=3, height=64, width=64)
    images = torch.randn(2, 3, 64, 64)

    # 文本 token ids:(batch=2, seq_len=16),每个元素是 [0, 999] 之间的整数
    input_ids = torch.randint(0, 1000, (2, 16))

    # 前向传播
    logits = model(images, input_ids)

    # 打印结果
    print(f"输入图像 shape:   {images.shape}")      # torch.Size([2, 3, 64, 64])
    print(f"输入文本 shape:   {input_ids.shape}")    # torch.Size([2, 16])
    print(f"输出 logits:     {logits.shape}")       # torch.Size([2, 16, 1000])
    # 期望:logits 的第 2 维=16(文本序列长度),第 3 维=1000(词表大小)

    # 统计模型参数量
    total_params = sum(p.numel() for p in model.parameters())
    print(f"模型总参数量:     {total_params:,}")

训练主程序Train.py

"""
================================================================================
train.py - 训练主程序
================================================================================

标准 PyTorch 训练流程:
  1. 选择计算设备(CPU / GPU / Apple Silicon MPS)
  2. 加载数据集和数据加载器
  3. 初始化模型
  4. 配置优化器(AdamW)和学习率调度器(Cosine Annealing)
  5. 定义损失函数(交叉熵)
  6. 循环训练 N 个 epoch,每个 epoch 遍历所有 batch
  7. 每个 epoch 结束后打印 loss 和困惑度(Perplexity)
  8. 保存最优模型权重到磁盘

关键概念:
  - Epoch:遍历一次完整训练集称为一个 epoch
  - Batch:每次更新参数时使用的一小批样本
  - Loss:模型预测与真实标签之间的差距(越小越好)
  - Perplexity(困惑度):语言模型的常用评估指标,= exp(loss),越小越好
  - 梯度裁剪:防止梯度爆炸,将梯度的 L2 范数限制在某个阈值以内
================================================================================
"""

# -----------------------------------------------------------------------------
# 导入依赖
# -----------------------------------------------------------------------------
import torch                                    # PyTorch 核心库
import torch.nn as nn                           # 神经网络模块(包含损失函数等)
from torch.optim import AdamW                   # AdamW 优化器(Adam + 权重衰减)
from torch.optim.lr_scheduler import CosineAnnealingLR  # 余弦退火学习率调度器

from model import MultimodalLLM                 # 导入我们定义的多模态模型
from dataset import get_dataloader              # 导入数据加载器工厂函数


# ================================================================================
# 超参数配置
# ================================================================================
# 将所有超参数集中在一个字典中,方便统一管理和修改

CONFIG = {
    # ── 数据相关 ──────────────────────────────────────────────────────────────
    "num_samples": 500,       # 训练样本总数(模拟数据集大小)
    "batch_size": 16,         # 每个 batch 的样本数(越大训练越稳定,但显存占用越多)
    "seq_len": 16,            # 文本序列长度(包含 BOS 和 EOS)
    "vocab_size": 1000,       # 词表大小(模型能识别的不同 token 总数)
    "image_size": 64,         # 输入图像的宽高(像素)

    # ── 模型结构 ──────────────────────────────────────────────────────────────
    "embed_dim": 128,         # 特征向量维度(Transformer 的 d_model)
    "num_heads": 4,           # 多头注意力的头数(embed_dim 必须能被 num_heads 整除)
    "num_layers": 4,          # Transformer 解码器的层数(越深表达能力越强)
    "patch_size": 8,          # ViT 的 patch 大小(图像被切成 8x8 的小块)

    # ── 训练超参数 ────────────────────────────────────────────────────────────
    "epochs": 10,             # 训练轮数(遍历完整数据集的次数)
    "lr": 3e-4,               # 初始学习率(3e-4 = 0.0003,AdamW 的常用默认值)
    "weight_decay": 0.01,     # 权重衰减系数(L2 正则化,防止过拟合)
    "grad_clip": 1.0,         # 梯度裁剪阈值(梯度 L2 范数超过此值时进行裁剪)

    # ── 保存路径 ──────────────────────────────────────────────────────────────
    "save_path": "checkpoint.pt",  # 模型权重保存路径
}


# ================================================================================
# 训练单个 epoch 的函数
# ================================================================================

def train_epoch(model, loader, optimizer, criterion, device, grad_clip):
    """
    训练一个完整的 epoch(遍历一次所有训练数据)。

    Args:
        model:     MultimodalLLM 模型实例
        loader:    DataLoader,提供批量训练数据
        optimizer: 优化器(AdamW),负责更新模型参数
        criterion: 损失函数(CrossEntropyLoss)
        device:    计算设备("cpu" / "cuda" / "mps")
        grad_clip: 梯度裁剪阈值(0 表示不裁剪)

    Returns:
        avg_loss:   本 epoch 的平均 loss(按 token 数量加权平均)
        perplexity: 困惑度 = exp(avg_loss)
    """
    # -----------------------------------------------------------------
    # 切换到训练模式
    # model.train() 会启用 Dropout 和 BatchNorm 的训练行为
    # (本模型没有 Dropout,但养成好习惯)
    # -----------------------------------------------------------------
    model.train()

    total_loss = 0.0    # 累计所有 batch 的总 loss(按 token 数量加权)
    total_tokens = 0    # 累计处理的 token 总数(用于计算平均 loss)

    # -----------------------------------------------------------------
    # 遍历所有 batch
    # enumerate(loader) 返回 (batch_idx, batch_data) 的迭代器
    # batch_data 是 Dataset.__getitem__ 返回值的批量版本
    # -----------------------------------------------------------------
    for batch_idx, (images, input_ids, labels) in enumerate(loader):

        # -----------------------------------------------------------------
        # 将数据移动到指定设备(CPU/GPU/MPS)
        # .to(device) 会返回一个新张量,存储在目标设备上
        # 如果数据已经在目标设备上,则不做任何操作
        # -----------------------------------------------------------------
        images = images.to(device)        # 图像张量移到设备
        input_ids = input_ids.to(device)  # 输入 token ids 移到设备
        labels = labels.to(device)        # 标签 token ids 移到设备

        # -----------------------------------------------------------------
        # 前向传播(Forward Pass)
        # 将图像和文本输入模型,得到每个位置的 token 预测分数
        # logits 形状:(B, T, vocab_size)
        # -----------------------------------------------------------------
        logits = model(images, input_ids)

        # -----------------------------------------------------------------
        # 计算损失(Loss)
        # CrossEntropyLoss 期望输入形状为 (N, C),其中 N 是样本数,C 是类别数
        # 但我们的 logits 是 (B, T, vocab_size),需要先展平
        # -----------------------------------------------------------------
        B, T, V = logits.shape  # B: batch size, T: 序列长度, V: 词表大小

        # reshape(B*T, V):将 batch 和序列维度合并,每个 token 位置作为一个独立样本
        # labels.reshape(B*T):同样展平标签
        loss = criterion(
            logits.reshape(B * T, V),  # 预测分数,shape (B*T, vocab_size)
            labels.reshape(B * T)      # 真实标签,shape (B*T,)
        )

        # -----------------------------------------------------------------
        # 反向传播(Backward Pass)
        # -----------------------------------------------------------------

        # 清零梯度
        # 必须在每次反向传播前清零,否则梯度会累积(PyTorch 默认行为)
        optimizer.zero_grad()

        # 计算梯度
        # loss.backward() 会自动计算 loss 对所有可训练参数的梯度
        # 梯度存储在每个参数的 .grad 属性中
        loss.backward()

        # -----------------------------------------------------------------
        # 梯度裁剪(Gradient Clipping)
        # 防止梯度爆炸:当梯度的 L2 范数超过 grad_clip 时,等比例缩小所有梯度
        # 这在训练 Transformer 时非常重要
        # -----------------------------------------------------------------
        if grad_clip > 0:
            nn.utils.clip_grad_norm_(
                model.parameters(),  # 所有可训练参数
                grad_clip            # 最大梯度范数(如 1.0)
            )

        # -----------------------------------------------------------------
        # 参数更新
        # optimizer.step() 根据梯度和学习率更新模型参数
        # AdamW 的更新公式:θ = θ - lr * (m / (√v + ε) + weight_decay * θ)
        # -----------------------------------------------------------------
        optimizer.step()

        # -----------------------------------------------------------------
        # 累计统计量(用于计算 epoch 平均 loss)
        # loss.item() 将标量张量转为 Python float
        # 乘以 B*T 是因为 CrossEntropyLoss 默认对 token 取平均,
        # 这里还原为总 loss,最后再除以总 token 数得到真正的平均
        # -----------------------------------------------------------------
        total_loss += loss.item() * B * T  # 累计总 loss
        total_tokens += B * T              # 累计总 token 数

    # -----------------------------------------------------------------
    # 计算 epoch 平均 loss
    # -----------------------------------------------------------------
    avg_loss = total_loss / total_tokens

    # -----------------------------------------------------------------
    # 计算困惑度(Perplexity)
    # Perplexity = exp(loss),是语言模型的标准评估指标
    # 直觉理解:困惑度越低,模型对下一个 token 的预测越确定
    # 例如:困惑度=10 表示模型平均在 10 个候选 token 中选择
    # -----------------------------------------------------------------
    perplexity = torch.exp(torch.tensor(avg_loss)).item()

    return avg_loss, perplexity  # 返回平均 loss 和困惑度


# ================================================================================
# 主训练流程
# ================================================================================

def main():
    """
    完整的训练流程:初始化 → 训练 → 保存。
    """

    # =========================================================================
    # 第 1 步:选择计算设备
    # =========================================================================
    device = (
        "cuda" if torch.cuda.is_available()          # 优先使用 NVIDIA GPU
        else "mps" if torch.backends.mps.is_available()  # 其次使用 Apple Silicon GPU
        else "cpu"                                    # 最后使用 CPU
    )
    print(f"使用设备: {device}")

    # =========================================================================
    # 第 2 步:加载数据
    # =========================================================================
    print("\n── 加载数据 ──")
    train_loader = get_dataloader(
        num_samples=CONFIG["num_samples"],  # 样本总数
        batch_size=CONFIG["batch_size"],    # 每个 batch 的大小
        vocab_size=CONFIG["vocab_size"],    # 词表大小
        seq_len=CONFIG["seq_len"],          # 序列长度
        image_size=CONFIG["image_size"],    # 图像尺寸
    )
    # 打印数据集信息
    print(f"训练集: {CONFIG['num_samples']} 样本,{len(train_loader)} 个 batch")

    # =========================================================================
    # 第 3 步:初始化模型
    # =========================================================================
    print("\n── 初始化模型 ──")
    model = MultimodalLLM(
        vocab_size=CONFIG["vocab_size"],    # 词表大小
        embed_dim=CONFIG["embed_dim"],      # 特征维度
        num_heads=CONFIG["num_heads"],      # 注意力头数
        num_layers=CONFIG["num_layers"],    # Transformer 层数
        image_size=CONFIG["image_size"],    # 图像尺寸
        patch_size=CONFIG["patch_size"],    # patch 大小
        max_seq_len=CONFIG["seq_len"] + 10, # 最大序列长度(留一些余量)
    ).to(device)  # 将模型移动到指定设备

    # 统计参数量
    total_params = sum(p.numel() for p in model.parameters())
    # p.requires_grad 为 True 表示该参数参与梯度计算(可训练)
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"总参数量:     {total_params:,}")
    print(f"可训练参数量: {trainable_params:,}")

    # =========================================================================
    # 第 4 步:配置优化器
    # =========================================================================
    # AdamW = Adam + 权重衰减(Weight Decay)
    # Adam 是目前最常用的深度学习优化器,自适应调整每个参数的学习率
    # 权重衰减是一种正则化手段,防止参数过大(过拟合)
    optimizer = AdamW(
        model.parameters(),           # 需要优化的参数(所有可训练参数)
        lr=CONFIG["lr"],              # 初始学习率
        weight_decay=CONFIG["weight_decay"]  # 权重衰减系数
    )

    # =========================================================================
    # 第 5 步:配置学习率调度器
    # =========================================================================
    # CosineAnnealingLR:余弦退火调度器
    # 学习率从初始值按余弦曲线逐渐降低到接近 0
    # 好处:训练初期学习率较大(快速收敛),后期学习率较小(精细调整)
    scheduler = CosineAnnealingLR(
        optimizer,              # 需要调度的优化器
        T_max=CONFIG["epochs"]  # 完成一个余弦周期所需的 epoch 数
    )

    # =========================================================================
    # 第 6 步:定义损失函数
    # =========================================================================
    # CrossEntropyLoss:交叉熵损失,用于多分类问题(预测下一个 token 是哪个)
    # ignore_index=0:忽略 PAD token(id=0)的 loss,不让填充 token 影响训练
    criterion = nn.CrossEntropyLoss(ignore_index=0)

    # =========================================================================
    # 第 7 步:训练循环
    # =========================================================================
    print(f"\n── 开始训练(共 {CONFIG['epochs']} epochs)──")
    best_loss = float("inf")  # 记录历史最优 loss,初始化为正无穷

    # 遍历每个 epoch(range(1, N+1) 让 epoch 从 1 开始,更直观)
    for epoch in range(1, CONFIG["epochs"] + 1):

        # 训练一个 epoch,返回平均 loss 和困惑度
        loss, ppl = train_epoch(
            model,              # 模型
            train_loader,       # 数据加载器
            optimizer,          # 优化器
            criterion,          # 损失函数
            device,             # 计算设备
            CONFIG["grad_clip"] # 梯度裁剪阈值
        )

        # 更新学习率(每个 epoch 结束后调用一次)
        scheduler.step()

        # 获取当前学习率(用于打印)
        lr_now = optimizer.param_groups[0]["lr"]

        # 打印本 epoch 的训练信息
        print(
            f"Epoch {epoch:2d}/{CONFIG['epochs']} | "  # 当前 epoch / 总 epoch
            f"Loss: {loss:.4f} | "                     # 平均 loss(4 位小数)
            f"Perplexity: {ppl:.2f} | "                # 困惑度(2 位小数)
            f"LR: {lr_now:.2e}"                        # 当前学习率(科学计数法)
        )

        # -----------------------------------------------------------------
        # 保存最优模型
        # 如果本 epoch 的 loss 比历史最优更低,则保存模型权重
        # -----------------------------------------------------------------
        if loss < best_loss:
            best_loss = loss  # 更新历史最优 loss

            # torch.save 保存一个字典,包含:
            # - epoch:当前训练轮数(方便恢复训练)
            # - model_state_dict:模型所有参数的权重
            # - optimizer_state_dict:优化器状态(包含动量等,方便恢复训练)
            # - loss:当前 loss(方便查看保存时的性能)
            # - config:超参数配置(方便推理时重建模型)
            torch.save({
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "loss": loss,
                "config": CONFIG,
            }, CONFIG["save_path"])

            print(f"  ✓ 保存最优模型 (loss={loss:.4f})")

    # 训练结束,打印最终结果
    print(f"\n训练完成!最优 loss: {best_loss:.4f}")
    print(f"模型已保存至: {CONFIG['save_path']}")


# ================================================================================
# 程序入口
# ================================================================================
if __name__ == "__main__":
    """
    当直接运行 `python train.py` 时,执行 main() 函数。
    如果作为模块被其他文件 import,则不执行(避免意外触发训练)。
    """
    main()

推理文本生成示例

"""
================================================================================
inference.py - 推理(文本生成)示例
================================================================================

演示如何用训练好的模型做自回归文本生成。

自回归生成(Autoregressive Generation)的原理:
  1. 给定图像和起始 token(BOS),让模型预测下一个 token
  2. 将预测的 token 追加到序列末尾
  3. 用新的序列再次预测下一个 token
  4. 重复上述过程,直到生成 EOS token 或达到最大长度

示意图:
  步骤 1:输入 [BOS]         → 预测 w1
  步骤 2:输入 [BOS, w1]     → 预测 w2
  步骤 3:输入 [BOS, w1, w2] → 预测 w3
  ...

采样策略:
  - 贪心解码(Greedy):每次选概率最高的 token(确定性,但可能重复)
  - Top-K 采样:从概率最高的 K 个 token 中随机采样(多样性更好)
  - 温度(Temperature):控制概率分布的"尖锐程度"
    - 温度 < 1:分布更尖锐,倾向于选高概率 token(更保守)
    - 温度 > 1:分布更平坦,各 token 概率趋于均等(更随机)
    - 温度 = 1:不改变原始分布
================================================================================
"""

# -----------------------------------------------------------------------------
# 导入依赖
# -----------------------------------------------------------------------------
import torch                  # PyTorch 核心库
from model import MultimodalLLM  # 导入多模态模型


# ================================================================================
# 自回归文本生成函数
# ================================================================================

@torch.no_grad()  # 装饰器:推理时不需要计算梯度,节省内存和计算量
def generate(model, image, max_new_tokens=20, temperature=1.0, top_k=50,
             bos_token=1, eos_token=2, device="cpu"):
    """
    自回归文本生成。

    给定图像,逐步生成文本 token 序列。

    Args:
        model:          训练好的 MultimodalLLM 模型
        image:          输入图像张量,shape (1, 3, H, W),batch_size 必须为 1
        max_new_tokens: 最多生成多少个新 token(不含 BOS)
        temperature:    温度参数,控制生成的随机性
                        - 0.01 ≈ 贪心解码(几乎总是选最高概率的 token)
                        - 1.0  = 原始概率分布
                        - 2.0  = 更随机(概率分布更平坦)
        top_k:          Top-K 采样的 K 值
                        - 只从概率最高的 K 个 token 中采样
                        - 0 表示不做 Top-K 过滤(从全词表采样)
        bos_token:      句子开始 token 的 id(Begin Of Sentence)
        eos_token:      句子结束 token 的 id(End Of Sentence)
        device:         计算设备

    Returns:
        generated: 生成的 token id 列表(包含 BOS,可能包含 EOS)
    """
    # -----------------------------------------------------------------
    # 切换到推理模式
    # model.eval() 会关闭 Dropout 等训练专用层
    # -----------------------------------------------------------------
    model.eval()

    # 将图像移动到指定设备
    image = image.to(device)

    # -----------------------------------------------------------------
    # 初始化生成序列
    # 从 BOS token 开始,这是语言模型生成的标准起点
    # -----------------------------------------------------------------
    generated = [bos_token]  # Python 列表,存储已生成的 token id

    # 将生成序列转为张量,shape (1, 1)
    # 第 0 维是 batch(=1),第 1 维是序列长度(初始=1)
    input_ids = torch.tensor([generated], dtype=torch.long, device=device)

    # -----------------------------------------------------------------
    # 自回归生成循环
    # 每次迭代生成一个新 token
    # -----------------------------------------------------------------
    for step in range(max_new_tokens):

        # -----------------------------------------------------------------
        # 第 1 步:前向传播,获取当前序列的 logits
        # logits 形状:(1, T, vocab_size)
        # T 是当前序列长度(随着生成逐渐增加)
        # -----------------------------------------------------------------
        logits = model(image, input_ids)

        # -----------------------------------------------------------------
        # 第 2 步:取最后一个位置的 logits
        # logits[0]:取 batch 中第 0 个样本,shape (T, vocab_size)
        # logits[0, -1, :]:取序列最后一个位置的 logits,shape (vocab_size,)
        # 最后一个位置的 logits 代表"给定当前序列,下一个 token 的预测分数"
        # -----------------------------------------------------------------
        next_logits = logits[0, -1, :]  # shape: (vocab_size,)

        # -----------------------------------------------------------------
        # 第 3 步:温度缩放
        # 将 logits 除以温度,改变概率分布的"尖锐程度"
        # 温度越低 → logits 差异越大 → softmax 后概率越集中 → 更确定
        # 温度越高 → logits 差异越小 → softmax 后概率越均匀 → 更随机
        # -----------------------------------------------------------------
        next_logits = next_logits / temperature

        # -----------------------------------------------------------------
        # 第 4 步:Top-K 过滤
        # 只保留概率最高的 K 个 token,其余设为 -inf(softmax 后概率为 0)
        # 这样可以避免采样到概率极低的奇怪 token
        # -----------------------------------------------------------------
        if top_k > 0:
            # torch.topk:返回最大的 top_k 个值及其索引
            # top_values 形状:(top_k,),按降序排列
            top_values, _ = torch.topk(next_logits, top_k)

            # 取第 K 个最大值作为阈值(即第 K 大的 logit 值)
            threshold = top_values[-1]

            # 将所有小于阈值的 logit 设为 -inf
            # 这样 softmax 后这些 token 的概率为 0,不会被采样到
            next_logits[next_logits < threshold] = float('-inf')

        # -----------------------------------------------------------------
        # 第 5 步:计算概率分布并采样
        # softmax 将 logits 转为概率(所有值在 [0,1] 之间,且和为 1)
        # multinomial 根据概率分布随机采样一个 token
        # -----------------------------------------------------------------
        probs = torch.softmax(next_logits, dim=-1)  # shape: (vocab_size,)

        # torch.multinomial:按概率分布随机采样
        # num_samples=1:采样 1 个 token
        # .item():将单元素张量转为 Python int
        next_token = torch.multinomial(probs, num_samples=1).item()

        # -----------------------------------------------------------------
        # 第 6 步:将新 token 追加到生成序列
        # -----------------------------------------------------------------
        generated.append(next_token)

        # -----------------------------------------------------------------
        # 第 7 步:检查是否生成了 EOS token
        # 如果生成了句子结束标记,停止生成
        # -----------------------------------------------------------------
        if next_token == eos_token:
            break  # 退出循环,生成结束

        # -----------------------------------------------------------------
        # 第 8 步:更新输入序列
        # 将新生成的 token 追加到输入序列,用于下一步预测
        # 注意:每次都重新创建张量(简单但低效,生产环境可用 KV Cache 优化)
        # -----------------------------------------------------------------
        input_ids = torch.tensor([generated], dtype=torch.long, device=device)

    return generated  # 返回完整的生成序列(token id 列表)


# ================================================================================
# 主程序
# ================================================================================

def main():
    """
    推理主程序:加载模型,生成文本,展示不同采样策略的效果。
    """

    # =========================================================================
    # 第 1 步:选择计算设备
    # =========================================================================
    device = (
        "cuda" if torch.cuda.is_available()               # 优先 NVIDIA GPU
        else "mps" if torch.backends.mps.is_available()   # 其次 Apple Silicon GPU
        else "cpu"                                         # 最后 CPU
    )
    print(f"使用设备: {device}")

    # =========================================================================
    # 第 2 步:加载模型
    # =========================================================================
    checkpoint_path = "checkpoint.pt"  # 模型权重文件路径

    try:
        # -----------------------------------------------------------------
        # 方式一:加载训练好的模型权重
        # torch.load 加载 checkpoint 字典
        # map_location=device:将权重加载到指定设备(避免 GPU 权重加载到 CPU 报错)
        # -----------------------------------------------------------------
        checkpoint = torch.load(checkpoint_path, map_location=device)

        # 从 checkpoint 中取出超参数配置(训练时保存的 CONFIG 字典)
        cfg = checkpoint["config"]

        print(f"加载模型: {checkpoint_path}")
        print(f"  训练 epoch: {checkpoint['epoch']}")
        print(f"  训练 loss:  {checkpoint['loss']:.4f}")

        # 根据保存的配置重建模型结构
        model = MultimodalLLM(
            vocab_size=cfg["vocab_size"],   # 词表大小
            embed_dim=cfg["embed_dim"],     # 特征维度
            num_heads=cfg["num_heads"],     # 注意力头数
            num_layers=cfg["num_layers"],   # Transformer 层数
            image_size=cfg["image_size"],   # 图像尺寸
            patch_size=cfg["patch_size"],   # patch 大小
        ).to(device)  # 将模型移到设备

        # 加载模型权重
        # load_state_dict 将保存的参数值填充到模型中
        model.load_state_dict(checkpoint["model_state_dict"])

    except FileNotFoundError:
        # -----------------------------------------------------------------
        # 方式二:如果没有训练好的模型,用随机初始化的模型演示推理流程
        # 这种情况下生成的 token 是随机的,没有实际意义
        # -----------------------------------------------------------------
        print(f"未找到 {checkpoint_path},使用随机初始化模型演示推理流程")
        print("提示:先运行 `python train.py` 训练模型,再运行推理\n")

        # 创建一个随机初始化的模型(参数随机,输出无意义)
        model = MultimodalLLM(vocab_size=1000, embed_dim=128).to(device)

    # =========================================================================
    # 第 3 步:准备输入图像
    # =========================================================================
    # 这里用随机张量模拟一张图像
    # 真实使用时,替换为:
    #   from PIL import Image
    #   import torchvision.transforms as T
    #   img = Image.open("your_image.jpg")
    #   transform = T.Compose([T.Resize((64, 64)), T.ToTensor()])
    #   image = transform(img).unsqueeze(0)  # 增加 batch 维度 → (1, 3, 64, 64)
    image = torch.randn(1, 3, 64, 64)  # (batch=1, channels=3, height=64, width=64)

    print("\n── 开始生成 ──")
    print(f"输入图像 shape: {image.shape}")

    # =========================================================================
    # 第 4 步:Top-K 采样生成(有一定随机性)
    # =========================================================================
    tokens = generate(
        model,                  # 模型
        image,                  # 输入图像
        max_new_tokens=15,      # 最多生成 15 个新 token
        temperature=0.8,        # 温度 0.8:略微降低随机性,生成更合理
        top_k=50,               # 从概率最高的 50 个 token 中采样
        device=device,
    )

    print(f"\n[Top-K 采样] 生成的 token ids: {tokens}")
    print(f"生成长度: {len(tokens)} tokens")
    print("注意:这里输出的是 token id,真实场景中需要用 tokenizer.decode() 转回文字")

    # =========================================================================
    # 第 5 步:贪心解码(确定性,每次选最高概率的 token)
    # =========================================================================
    tokens_greedy = generate(
        model,
        image,
        max_new_tokens=15,
        temperature=0.01,  # 极低温度 ≈ 贪心解码(几乎总是选最高概率的 token)
        top_k=1,           # 只保留概率最高的 1 个 token(等价于 argmax)
        device=device,
    )

    print(f"\n[贪心解码] 生成的 token ids: {tokens_greedy}")
    print("贪心解码是确定性的:相同输入每次生成结果相同")

    # =========================================================================
    # 第 6 步:高温采样(更随机,多样性更高)
    # =========================================================================
    tokens_creative = generate(
        model,
        image,
        max_new_tokens=15,
        temperature=1.5,   # 高温:概率分布更平坦,生成更多样
        top_k=100,         # 从更多候选中采样
        device=device,
    )

    print(f"\n[高温采样] 生成的 token ids: {tokens_creative}")
    print("高温采样多样性更高,但可能生成不太合理的内容")


# ================================================================================
# 程序入口
# ================================================================================
if __name__ == "__main__":
    """
    当直接运行 `python inference.py` 时,执行 main() 函数。
    """
    main()