神经网络架构详解

涵盖 ResNet18 · UNet · SwinIR · CLIP · LLaVA · LLaMA · Qwen3
包含结构图、直觉理解、核心模块解析、完整代码实现


目录

  1. 基础知识:Attention & 位置编码
  2. ResNet18
  3. UNet
  4. SwinIR
  5. CLIP
  6. LLaVA
  7. LLaMA
  8. Qwen3
  9. 架构对比速查表

基础知识

back to 目录

0.1 Self-Attention

直觉理解

想象你在读一句话:“追着在客厅里”。
当你理解”跑”这个词时,你的大脑会自动回头看”猫”和”球”——这正是 Self-Attention 在做的事情:让序列中的每个位置都能”看”到其他所有位置,并按相关程度加权融合信息

传统 RNN 只能一步步传递信息(像电话传话游戏,传着传着就失真了)。Self-Attention 则像在一个圆桌会议上,每个人可以直接和所有人交流,距离不再是障碍。

Q / K / V 三件套的比喻

把注意力机制想象成一个图书馆检索系统:

  • Q(Query,查询):你手里的借书申请单——“我想找关于猫的书”
  • K(Key,键):书架上每本书的标签——“这本书讲猫、那本书讲狗”
  • V(Value,值):书本身的内容

你拿着申请单(Q)去匹配每本书的标签(K),相似度高的书你就多看一些,最终你读到的内容(输出)是按相似度加权的书的内容(V)的混合。

数学形式

除以 是为了防止点积值过大导致 softmax 进入梯度接近零的饱和区——这就像把麦克风音量调到合适大小,避免声音失真。

输入 X: (B, N, D)
         │
    ┌────┴─────┐
   W_Q        W_K        W_V     ← 三个独立的线性投影(学习"怎么提问/怎么打标签/什么是内容")
    │          │          │
    Q          K          V
 (B,N,d_k) (B,N,d_k) (B,N,d_v)
    │          │
    └──── QK^T / √d_k ────┘
              │
           softmax    → (B, N, N) 注意力图,第 i 行表示第 i 个 token 关注其他 token 的权重
              │
           × V        → 按权重加权求和每个位置的 Value
              │
         output (B, N, d_v)

Multi-Head 的意义

单个注意力头只能学习一种”关注模式”(比如语法关系)。多头注意力就像多个摄像头从不同角度拍摄同一个场景——有的头关注语义,有的头关注位置距离,有的头关注共指关系,最终把所有视角合并。

X → [Q,K,V 线性投影] → 拆成 h 个 head
每个 head 独立做 Attention(Q_i, K_i, V_i)  ← h 个不同的"视角"
所有 head concat → 线性投影 W_O → 输出
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
 
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.0):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
 
        self.W_qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.W_o   = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)
 
    def forward(self, x, mask=None):
        B, N, D = x.shape
        # 一次性计算 Q, K, V(效率更高)
        qkv = self.W_qkv(x)                          # (B, N, 3D)
        q, k, v = qkv.chunk(3, dim=-1)               # 各 (B, N, D)
 
        # 拆分多头:每个 head 独立处理一个 d_k 维子空间
        def split_heads(t):
            return t.view(B, N, self.n_heads, self.d_k).transpose(1, 2)
            # -> (B, h, N, d_k)
 
        q, k, v = split_heads(q), split_heads(k), split_heads(v)
 
        # Scaled Dot-Product Attention
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        # scores: (B, h, N, N) —— 每对 token 之间的相关性
 
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
 
        attn = F.softmax(scores, dim=-1)   # 归一化为概率分布
        attn = self.dropout(attn)
 
        out = torch.matmul(attn, v)                   # (B, h, N, d_k)
        out = out.transpose(1, 2).contiguous()        # (B, N, h, d_k)
        out = out.view(B, N, D)                       # 合并所有 head
        return self.W_o(out)                          # 最终投影
 
 
# ---- 带 KV Cache 的因果 Self-Attention(用于 LLM 推理加速)----
# 推理时每次只处理一个新 token,把历史 K/V 缓存起来,避免重复计算
class CausalSelfAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.W_qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.W_o   = nn.Linear(d_model, d_model, bias=False)
 
    def forward(self, x, kv_cache=None):
        B, T, D = x.shape
        q, k, v = self.W_qkv(x).chunk(3, dim=-1)
 
        def split(t):
            return t.view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        q, k, v = split(q), split(k), split(v)
 
        # KV Cache:把新的 K/V 拼接到历史缓存上
        if kv_cache is not None:
            k = torch.cat([kv_cache[0], k], dim=2)
            v = torch.cat([kv_cache[1], v], dim=2)
        new_cache = (k, v)
 
        S = k.shape[2]
        scores = torch.matmul(q, k.transpose(-2,-1)) / math.sqrt(self.d_k)
 
        # 下三角 mask:每个 token 只能看到自己及之前的 token(自回归性质)
        causal_mask = torch.tril(torch.ones(T, S, device=x.device))
        scores = scores.masked_fill(causal_mask[-T:, :S] == 0, float('-inf'))
 
        out = torch.matmul(F.softmax(scores, dim=-1), v)
        out = out.transpose(1,2).contiguous().view(B, T, D)
        return self.W_o(out), new_cache

0.2 Cross-Attention

直觉理解

如果说 Self-Attention 是”自我反思”(序列内部的信息整合),那 Cross-Attention 就是”跨界对话”——一个序列拿着自己的问题(Q),去另一个序列里查答案(K/V)

最经典的例子是机器翻译:解码器(生成中文)用当前状态作为 Q,去关注编码器(理解英文)产出的 K/V,从而知道”此刻应该关注原文的哪个部分”。

在多模态模型(如 LLaVA)中,Cross-Attention 让文本 token 能”看”图像特征,从而理解图像内容。

Query 来源 X_q: (B, N_q, D)      ← 解码器当前状态 / 文本 token
Key/Value 来源 X_kv: (B, N_kv, D) ← 编码器输出 / 图像特征

Q = X_q  · W_Q   ← "我现在想知道什么"
K = X_kv · W_K   ← "我(上下文)里有什么主题"
V = X_kv · W_V   ← "上下文的实际内容"

注意力图 softmax(QK^T/√d): (B, h, N_q, N_kv)
  ↑ 第 i 个 query token 对第 j 个 context token 的关注程度

输出: Attention(Q, K, V) → (B, N_q, D)
  ↑ 每个 query token 现在"看过"了所有 context 信息
class CrossAttention(nn.Module):
    def __init__(self, d_model: int, d_context: int, n_heads: int):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        # Q 来自目标序列,K/V 来自上下文序列(维度可以不同)
        self.W_q = nn.Linear(d_model,   d_model, bias=False)
        self.W_k = nn.Linear(d_context, d_model, bias=False)
        self.W_v = nn.Linear(d_context, d_model, bias=False)
        self.W_o = nn.Linear(d_model,   d_model, bias=False)
 
    def forward(self, x, context):
        """
        x:       (B, N_q,  D)       ← query source(提问者)
        context: (B, N_kv, D_ctx)   ← key/value source(知识库)
        """
        B, N_q,  _ = x.shape
        B, N_kv, _ = context.shape
 
        q = self.W_q(x)        # (B, N_q,  D)
        k = self.W_k(context)  # (B, N_kv, D)
        v = self.W_v(context)  # (B, N_kv, D)
 
        def split(t, N):
            return t.view(B, N, self.n_heads, self.d_k).transpose(1, 2)
 
        q, k, v = split(q, N_q), split(k, N_kv), split(v, N_kv)
 
        scores = torch.matmul(q, k.transpose(-2,-1)) / math.sqrt(self.d_k)
        attn   = F.softmax(scores, dim=-1)   # (B, h, N_q, N_kv)
        out    = torch.matmul(attn, v)       # (B, h, N_q, d_k)
        out    = out.transpose(1,2).contiguous().view(B, N_q, -1)
        return self.W_o(out)

0.3 位置编码

为什么需要位置编码?

Attention 本质上是一种无序操作——把输入打乱顺序,输出不变(只是行列换一下)。但语言是有顺序的:“狗咬人” ≠ “人咬狗”。因此必须用位置编码把序号信息注入进去。

各种位置编码的核心问题只有一个:怎么让模型知道 token 在哪个位置,以及两个 token 之间有多远?


0.3.1 Sinusoidal PE(原始 Transformer 论文)

思路: 用不同频率的正弦/余弦波来表示位置,就像钟表的时针、分针、秒针——频率不同,组合起来可以唯一表示任意时刻。
低维度用高频(变化快,区分近邻),高维度用低频(变化慢,区分远端)。

优点: 无需学习,天然能外推到训练时没见过的更长序列。
缺点: 相对位置信息是隐式的,模型需要自己去学习如何利用它。

class SinusoidalPE(nn.Module):
    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1).float()          # (L, 1)
        div = torch.exp(torch.arange(0, d_model, 2).float()
                        * -(math.log(10000.0) / d_model))            # (D/2,)
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer('pe', pe.unsqueeze(0))                  # (1, L, D)
 
    def forward(self, x):
        # 位置编码直接加到输入 embedding 上
        return x + self.pe[:, :x.size(1)]

0.3.2 Learnable PE(BERT / ViT)

思路: 既然不知道最好的编码方式,就让模型自己学。每个位置分配一个可训练的向量,和 word embedding 一样从数据中学习。

优点: 灵活,针对任务自适应。
缺点: 无法外推(训练时最长 512,推理时遇到 600 长度就不知道怎么处理了)。

class LearnablePE(nn.Module):
    def __init__(self, d_model: int, max_len: int):
        super().__init__()
        self.pe = nn.Embedding(max_len, d_model)  # 每个位置一个可学习的向量
 
    def forward(self, x):
        pos = torch.arange(x.size(1), device=x.device)
        return x + self.pe(pos)

0.3.3 RoPE(Rotary Position Embedding)⭐ LLaMA / Qwen 使用

思路: 这是最精妙的设计。不直接”加”位置信息,而是对 Q/K 向量旋转一个与位置成正比的角度。
数学上可以证明:旋转后的 只与相对位置差 有关,与绝对位置无关。

比喻: 想象两个人站在圆形跑道上,他们之间的”相对方向”只取决于各自的位置差,与他们在跑道哪个绝对位置无关。RoPE 就是让注意力分数只感知这个”相对方向”。

为什么 LLM 青睐 RoPE:

  • 外推性比 Learnable PE 强
  • 相对位置天然编码
  • 可以通过调整 theta 来扩展上下文长度(LLaMA 用 10000,Qwen3 用 1,000,000)

旋转示意(2D 情况):

def precompute_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0):
    """
    预计算每个位置、每个维度对的旋转角度(以复数形式存储)
    dim/2 对维度,每对用一个复数 e^{i*pos*θ_j} 表示
    """
    # 每对维度的"旋转速度":低维快(高频),高维慢(低频)
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t     = torch.arange(seq_len)
    freqs = torch.outer(t, freqs)                     # (seq_len, dim/2)
    # 转为复数:e^{i*θ} = cos(θ) + i*sin(θ)
    return torch.polar(torch.ones_like(freqs), freqs)
 
def apply_rotary_emb(xq, xk, freqs_cis):
    """
    将旋转位置编码应用到 Q 和 K
    原理:复数乘法 = 旋转。把每对维度当作一个二维向量,乘以旋转复数
    """
    def rotate(x, f):
        # x: (B, T, h, d_k)
        # 把每对相邻维度合成一个复数
        x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        # f: (T, d_k/2) → 广播到 (1, T, 1, d_k/2)
        f    = f[:x.shape[1]].unsqueeze(0).unsqueeze(2)
        # 复数乘法:相当于旋转
        x_rot = torch.view_as_real(x_c * f).flatten(3)
        return x_rot.type_as(x)
 
    xq = rotate(xq, freqs_cis)
    xk = rotate(xk, freqs_cis)
    return xq, xk

0.3.4 ALiBi(Attention with Linear Biases)

思路: 不修改输入,直接在注意力分数矩阵上加一个线性惩罚——距离越远,分数扣得越多,天然让模型更关注近邻。

不同 head 用不同的惩罚斜率,有的 head 视野窄(斜率大,惩罚重),有的 head 视野宽(斜率小,惩罚轻)。

优点: 外推能力极强(训练 1024,推理 4096 也不差)。
缺点: 先验假设”近的更重要”不一定对所有任务成立。

def get_alibi_slopes(n_heads):
    """计算等比数列斜率,斜率越大 = 对距离越敏感"""
    def _slopes(n):
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        return [start * (start ** i) for i in range(n)]
    return torch.tensor(_slopes(n_heads))
 
def alibi_bias(seq_len, n_heads, slopes):
    pos  = torch.arange(seq_len)
    # |i - j|:位置距离矩阵
    bias = -torch.abs(pos.unsqueeze(0) - pos.unsqueeze(1))  # (T, T)
    # 每个 head 乘以自己的斜率
    bias = bias.unsqueeze(0) * slopes.view(-1,1,1)          # (h, T, T)
    return bias

0.3.5 Swin Transformer 相对位置偏置(2D 图像专用)

思路: 不编码绝对坐标,而是学习每对 patch 之间相对位移对应的偏置值。
大小为 的窗口里,行列各有 种相对位移,所以偏置表大小是
这比绝对位置编码参数更少泛化更好

class RelativePositionBias(nn.Module):
    """
    Swin Transformer 中的相对位置偏置
    核心思想:不记录"我在哪",只记录"我们相差多少"
    """
    def __init__(self, window_size: int, n_heads: int):
        super().__init__()
        self.window_size = window_size
        M = 2 * window_size - 1
        # 可学习偏置表:(2M-1)² 种相对位置 × n_heads
        self.bias_table = nn.Parameter(torch.zeros(M*M, n_heads))
        nn.init.trunc_normal_(self.bias_table, std=0.02)
 
        # 预计算所有 patch 对的相对位置索引(只需计算一次)
        coords = torch.stack(torch.meshgrid(
            torch.arange(window_size), torch.arange(window_size), indexing='ij'
        ))                                              # (2, W, W)
        coords_flat = coords.flatten(1)                # (2, W²)
        rel = coords_flat[:, :, None] - coords_flat[:, None, :]  # (2, W², W²)
        rel = rel.permute(1, 2, 0).contiguous()
        rel[:, :, 0] += window_size - 1                # 偏移到非负
        rel[:, :, 1] += window_size - 1
        rel[:, :, 0] *= 2 * window_size - 1
        self.register_buffer('rel_index', rel.sum(-1)) # (W², W²)
 
    def forward(self):
        bias = self.bias_table[self.rel_index.view(-1)]
        bias = bias.view(self.window_size**2, self.window_size**2, -1)
        return bias.permute(2, 0, 1).contiguous()      # (n_heads, W², W²)

ResNet18

back to 目录

背景与动机

2014 年前后,更深的网络反而训练效果更差——不是因为过拟合,而是梯度消失导致深层网络根本没学到东西。研究者的直觉是:如果新加的层什么都不学(恒等映射),至少不该比浅层网络差。

残差学习的核心思想: 与其让网络直接学目标映射 ,不如让它学残差 ,目标变成
时退化为恒等映射,这比学一个精确的恒等映射容易得多(把所有参数推向零即可)。

形象地说:残差连接就像给每一层装了一条”高速公路”,梯度可以直接从深层流回浅层,不用经过层层”收费站”(非线性变换)。

整体架构

输入图像 (3, 224, 224)
        │
   ┌────────────────────────────┐
   │ Stem(快速下采样)          │
   │  Conv7×7, 64, stride=2     │  → (64, 112, 112)  空间减半
   │  BN + ReLU                 │
   │  MaxPool 3×3, stride=2     │  → (64, 56, 56)    再减半
   └────────────────────────────┘
        │
   ┌────────────────────────────────────────────────────────────┐
   │ 4个 Stage,通道数翻倍,空间尺寸减半(除第一个Stage)       │
   │                                                            │
   │  Stage 1: 2× BasicBlock(64→64,   stride=1) → (64, 56, 56) │
   │  Stage 2: 2× BasicBlock(64→128,  stride=2) → (128,28,28)  │
   │  Stage 3: 2× BasicBlock(128→256, stride=2) → (256,14,14)  │
   │  Stage 4: 2× BasicBlock(256→512, stride=2) → (512, 7, 7)  │
   └────────────────────────────────────────────────────────────┘
        │
   AdaptiveAvgPool2d(1,1)   → (512, 1, 1)  全局平均池化,去掉空间维度
   Flatten                  → (512,)
   Linear(512, num_classes) → (num_classes,)
        │
      输出 logits

BasicBlock:核心模块

BasicBlock 是整个 ResNet18 的”原子单元”,理解了它就理解了 ResNet 的精髓。

输入 x: (B, C_in, H, W)
   │
   ├─── 主路径(学习残差 F(x))──────────────────────
   │    Conv3×3(C_in→C_out, stride)  ← 可能改变通道数和空间尺寸
   │    BN → ReLU                    ← 归一化+激活
   │    Conv3×3(C_out→C_out)         ← 精细调整特征
   │    BN                           ← 注意:这里不加 ReLU!
   │
   └─── 捷径(shortcut)───────────────────────────────
        ● 若尺寸不变:直接 Identity(x 原样通过)
        ● 若尺寸改变:Conv1×1(stride=2) + BN 投影对齐
   │
   ① 主路径输出 + shortcut 输出  ← 残差相加
   ② ReLU                        ← 相加后才激活
        │
      输出 (B, C_out, H', W')

关键细节:
  - 两个 Conv 之间有 ReLU,最后一个 BN 后先加再激活
  - 这个顺序(pre-activation vs post-activation)有细微差别
  - ResNet18/34 用 BasicBlock;ResNet50+ 用 Bottleneck(1×1-3×3-1×1)
import torch
import torch.nn as nn
import torch.nn.functional as F
 
class BasicBlock(nn.Module):
    expansion = 1  # ResNet50+ 的 Bottleneck 此值为4,用于计算输出通道数
 
    def __init__(self, in_channels: int, out_channels: int, stride: int = 1):
        super().__init__()
        # 第一个卷积:可能下采样(stride=2)
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3,
                               stride=stride, padding=1, bias=False)
        # BN + Conv 组合,bias=False 是因为 BN 的 beta 参数已经起到偏置的作用
        self.bn1   = nn.BatchNorm2d(out_channels)
        self.relu  = nn.ReLU(inplace=True)
        # 第二个卷积:不改变尺寸
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3,
                               padding=1, bias=False)
        self.bn2   = nn.BatchNorm2d(out_channels)
 
        # 快捷路径:维度不匹配时需要投影
        # 两种情况:① stride=2(空间缩小)② 通道数变化
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                # 1×1 卷积专门用于改变通道数和空间尺寸
                nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
 
    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))  # Conv → BN → ReLU
        out = self.bn2(self.conv2(out))            # Conv → BN(还没 ReLU!)
        out += self.shortcut(x)                    # 残差相加
        return self.relu(out)                      # 相加后再 ReLU
 
 
class ResNet18(nn.Module):
    def __init__(self, num_classes: int = 1000):
        super().__init__()
        # Stem:大感受野快速下采样
        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, stride=2, padding=1)
        )
        # 4个 Stage,通道数翻倍,空间缩小
        self.layer1 = self._make_layer(64,  64,  n=2, stride=1)
        self.layer2 = self._make_layer(64,  128, n=2, stride=2)
        self.layer3 = self._make_layer(128, 256, n=2, stride=2)
        self.layer4 = self._make_layer(256, 512, n=2, stride=2)
        # 分类头
        self.pool   = nn.AdaptiveAvgPool2d(1)
        self.fc     = nn.Linear(512, num_classes)
        self._init_weights()
 
    def _make_layer(self, in_ch, out_ch, n, stride):
        # 第一个 block 可能有 stride(下采样),之后的 block 维持尺寸
        layers = [BasicBlock(in_ch, out_ch, stride)]
        for _ in range(1, n):
            layers.append(BasicBlock(out_ch, out_ch))
        return nn.Sequential(*layers)
 
    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight); nn.init.zeros_(m.bias)
 
    def forward(self, x):
        x = self.stem(x)     # (B, 64, 56, 56)
        x = self.layer1(x)   # (B, 64, 56, 56)
        x = self.layer2(x)   # (B, 128, 28, 28)
        x = self.layer3(x)   # (B, 256, 14, 14)
        x = self.layer4(x)   # (B, 512, 7, 7)
        x = self.pool(x).flatten(1)   # (B, 512)
        return self.fc(x)    # (B, num_classes)
 
 
# 验证:参数量约 11.7M
model = ResNet18(num_classes=1000)
x = torch.randn(2, 3, 224, 224)
print(model(x).shape)
print(f"{sum(p.numel() for p in model.parameters())/1e6:.1f}M params")

UNet

back to 目录

背景与动机

UNet 诞生于 2015 年的医学图像分割领域,面对的核心挑战是:用极少的标注数据,对每个像素进行分类

普通分类网络通过下采样丢失了大量空间细节(“这里是什么”),而语义分割不仅要知道”是什么”,还要精确到”在哪里”。UNet 的核心设计哲学是:

下采样(Encoder) 负责理解语义(“这是一个细胞核”)
上采样(Decoder) 负责恢复位置(“它在图像的左上角”)
跳跃连接(Skip Connection) 把 Encoder 中丢失的细节直接”传送”给 Decoder

形象地说:Encoder 是近视眼(看不清细节但能看清整体),Decoder 是老花眼(需要细节帮助)。Skip Connection 就是给 Decoder 配了一副眼镜,把 Encoder 看清楚的细节直接传过去。

整体架构

输入 (B, C_in, H, W)   例如 (B, 1, 256, 256)
        │
Encoder(编码器 / 下采样路径):
  ┌─────────────────────────────────────────────────────────┐
  │  Block1: ConvBlock(C_in→64)   → feat1 (B, 64,  H,   W) │──── skip1 ────────────────────────────────────┐
  │  MaxPool ↓2                   → (B, 64,  H/2, W/2)     │                                               │
  │  Block2: ConvBlock(64→128)    → feat2 (B, 128, H/2, W/2)│──── skip2 ──────────────────────────────┐    │
  │  MaxPool ↓2                   → (B, 128, H/4, W/4)     │                                          │    │
  │  Block3: ConvBlock(128→256)   → feat3 (B, 256, H/4, W/4)│──── skip3 ─────────────────────────┐   │    │
  │  MaxPool ↓2                   → (B, 256, H/8, W/8)     │                                     │   │    │
  │  Block4: ConvBlock(256→512)   → feat4 (B, 512, H/8, W/8)│──── skip4 ────────────────────┐   │   │    │
  │  MaxPool ↓2                   → (B, 512, H/16,W/16)    │                               │   │   │    │
  └─────────────────────────────────────────────────────────┘                               │   │   │    │
        │                                                                                    │   │   │    │
Bottleneck:                                                                                  │   │   │    │
  Block5: ConvBlock(512→1024)     → (B, 1024, H/16, W/16)                                  │   │   │    │
        │                                                                                    │   │   │    │
Decoder(解码器 / 上采样路径):                                                               │   │   │    │
  UpConv2×2                       → (B, 512,  H/8,  W/8)                                    │   │   │    │
  Concat(skip4)                   → (B, 1024, H/8,  W/8)  ← ─────────────────────────────┘   │   │    │
  ConvBlock(1024→512)             → (B, 512,  H/8,  W/8)                                        │   │    │
                                                                                                  │   │    │
  UpConv2×2                       → (B, 256,  H/4,  W/4)                                         │   │    │
  Concat(skip3)                   → (B, 512,  H/4,  W/4)  ← ────────────────────────────────────┘   │    │
  ConvBlock(512→256)              → (B, 256,  H/4,  W/4)                                              │    │
  ... (以此类推)                                                                                        │    │
  UpConv2×2                       → (B, 64,   H,    W)                                                │    │
  Concat(skip1)                   → (B, 128,  H,    W)    ← ─────────────────────────────────────────┘    │
  ConvBlock(128→64)               → (B, 64,   H,    W)                                                     │
                                                                                                             │
  Conv1×1                         → (B, num_classes, H, W)  ← 每个像素的分类结果                            │

Skip Connection 的关键作用

没有跳跃连接,解码器就像在黑暗中摸索——它知道”大概在哪里”但不知道”确切边界在哪”。跳跃连接把 Encoder 里精确的位置、边缘、纹理信息直接送给 Decoder,让分割边界变得精确。

class ConvBlock(nn.Module):
    """
    UNet 的基本单元:两次 Conv3×3 + BN + ReLU
    选择 3×3 而非 1×1:保持感受野,充分利用空间上下文
    """
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch,  out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
        )
    def forward(self, x): return self.block(x)
 
 
class UNet(nn.Module):
    def __init__(self, in_channels=1, num_classes=2, features=(64,128,256,512)):
        super().__init__()
        # Encoder
        self.encoders = nn.ModuleList()
        self.pools    = nn.ModuleList()
        ch = in_channels
        for f in features:
            self.encoders.append(ConvBlock(ch, f))
            self.pools.append(nn.MaxPool2d(2))
            ch = f
 
        # Bottleneck:特征维度最大,空间最小,语义最强
        self.bottleneck = ConvBlock(features[-1], features[-1] * 2)
 
        # Decoder:逐步上采样 + 融合跳跃连接
        self.upconvs  = nn.ModuleList()
        self.decoders = nn.ModuleList()
        for f in reversed(features):
            # ConvTranspose2d:可学习的上采样(比双线性插值效果好)
            self.upconvs.append(
                nn.ConvTranspose2d(f * 2, f, kernel_size=2, stride=2)
            )
            # 上采样后 Concat 跳跃特征,通道数翻倍
            self.decoders.append(ConvBlock(f * 2, f))
 
        self.final = nn.Conv2d(features[0], num_classes, 1)
 
    def forward(self, x):
        skips = []
        # Encoder:保存每个 stage 的特征图用于跳跃连接
        for enc, pool in zip(self.encoders, self.pools):
            x = enc(x)
            skips.append(x)   # 保存池化前的特征(细节更丰富)
            x = pool(x)
 
        x = self.bottleneck(x)
 
        # Decoder:上采样 + 拼接 skip 特征
        for up, dec, skip in zip(self.upconvs, self.decoders, reversed(skips)):
            x = up(x)
            # 处理尺寸不整除的边角情况(如输入不是 2 的幂次方)
            if x.shape != skip.shape:
                x = F.interpolate(x, size=skip.shape[2:])
            # Concat 而非 Add:保留两路信息,不做融合(让后续卷积自己学怎么用)
            x = torch.cat([skip, x], dim=1)
            x = dec(x)
 
        return self.final(x)
 
 
# 验证
model = UNet(in_channels=1, num_classes=2)
x = torch.randn(2, 1, 256, 256)
print(model(x).shape)   # (2, 2, 256, 256)  每个像素 2 分类

SwinIR

back to 目录

背景与动机

SwinIR 是基于 Swin Transformer 的图像超分辨率模型(2021)。它解决了 ViT 用于图像的两个核心问题:

  1. 计算量问题: 标准 Self-Attention 的复杂度是 ,对 224×224 图像有 50176 个 patch,算不起。Swin 把注意力限制在小窗口内,变成 是窗口大小。

  2. 跨窗口信息交流: 但窗口内部关注,窗口之间信息孤立。Swin 用移位窗口(Shifted Window) 解决——交替使用正常窗口和偏移半个窗口大小的窗口,让相邻窗口能间接交流。

比喻: 标准注意力像全国拨打任意号码的电话网络(贵但自由),窗口注意力像只能打本地电话(便宜但受限),移位窗口像相邻区域的号码段有重叠(以较低成本建立了跨区连接)。

SwinIR 整体架构

低分辨率输入 (B, 3, H, W)
        │
  ┌──────────────────────────────┐
  │ 浅层特征提取                  │
  │ Conv3×3 → (B, C, H, W)      │  简单卷积,C=180
  │ 作用:把图像映射到特征空间     │
  └──────────────────────────────┘
        │ F_shallow(保留用于最终残差)
        │
  ┌──────────────────────────────────────────────────┐
  │ 深层特征提取(6个 RSTB)                          │
  │                                                  │
  │  RSTB (Residual Swin Transformer Block)  ×6      │
  │  ┌────────────────────────────────────────────┐  │
  │  │                                            │  │
  │  │  STL (Swin Transformer Layer)  ×6          │  │
  │  │  ┌──────────────────────────────────────┐  │  │
  │  │  │  LN                                  │  │  │
  │  │  │  W-MSA(偶数层,正常窗口)            │  │  │
  │  │  │  或 SW-MSA(奇数层,移位窗口)        │  │  │
  │  │  │  + 残差                               │  │  │
  │  │  │  LN                                  │  │  │
  │  │  │  FFN (MLP)                           │  │  │
  │  │  │  + 残差                               │  │  │
  │  │  └──────────────────────────────────────┘  │  │
  │  │                                            │  │
  │  │  Conv3×3(局部信息增强 + 残差连接)        │  │
  │  └────────────────────────────────────────────┘  │
  │                                                  │
  └──────────────────────────────────────────────────┘
        │ F_deep
        │
  Conv3×3 + (F_shallow + F_deep)  ← 浅深特征融合
        │
  ┌──────────────────────────────┐
  │ 上采样重建                    │
  │ Conv3×3                      │
  │ PixelShuffle ×scale          │  ← 亚像素卷积,1/r² 的计算量生成 r² 个子像素
  │ Conv3×3                      │
  └──────────────────────────────┘
        │
  高分辨率输出 (B, 3, s·H, s·W)

窗口注意力机制

输入特征图 (B, H, W, C)
    │
    ▼
分成 nW = (H/M)×(W/M) 个不重叠 M×M 窗口
    → (B×nW, M×M, C)
    │
    ▼
在每个窗口内独立做 Multi-Head Attention
(M=8 时,每次只算 64 个 token,而不是 H×W 个)
加上相对位置偏置(Relative Position Bias)
    │
    ▼
窗口合并回 (B, H, W, C)

SW-MSA(Shifted Window)的做法:
    原窗口边界:  ┌───┬───┐   移位后边界:  ┌─┬─────┬─┐
                │ A │ B │              │D│  A  │B│
                ├───┼───┤              ├─┼─────┼─┤
                │ C │ D │              │ │     │ │
                └───┴───┘              │C│     │A│
                                       └─┴─────┴─┘
    原来 A/B/C/D 四个窗口互不相交,移位后出现了跨越原窗口边界的新窗口
    → A 窗口的 token 现在和 B/C/D 的边界 token 同处一个注意力计算中
    → 用 Cyclic Shift + Attention Mask 高效实现(无需 padding)
def window_partition(x, window_size):
    """(B, H, W, C) → (B*nW, M, M, C)"""
    B, H, W, C = x.shape
    x = x.view(B, H//window_size, window_size, W//window_size, window_size, C)
    windows = x.permute(0,1,3,2,4,5).contiguous()
    return windows.view(-1, window_size, window_size, C)
 
def window_reverse(windows, window_size, H, W):
    """window_partition 的逆操作"""
    B_nW = windows.shape[0]
    B = int(B_nW / (H * W / window_size / window_size))
    x = windows.view(B, H//window_size, W//window_size, window_size, window_size, -1)
    return x.permute(0,1,3,2,4,5).contiguous().view(B, H, W, -1)
 
 
class WindowAttention(nn.Module):
    def __init__(self, dim, window_size, n_heads):
        super().__init__()
        self.n_heads  = n_heads
        self.d_k      = dim // n_heads
        self.scale    = self.d_k ** -0.5
        self.W_qkv    = nn.Linear(dim, 3*dim)
        self.W_o      = nn.Linear(dim, dim)
        self.rel_bias = RelativePositionBias(window_size, n_heads)
 
    def forward(self, x, mask=None):
        # x: (B*nW, M*M, C)
        B_, N, C = x.shape
        qkv = self.W_qkv(x).reshape(B_, N, 3, self.n_heads, self.d_k)
        q, k, v = qkv.permute(2,0,3,1,4).unbind(0)
 
        q  = q * self.scale
        attn = q @ k.transpose(-2,-1)
        attn = attn + self.rel_bias()   # 加相对位置偏置:让模型感知窗口内的相对位置
 
        if mask is not None:
            # SW-MSA 的 mask:屏蔽不该相互关注的跨区域 token 对
            nW = mask.shape[0]
            attn = attn.view(B_//nW, nW, self.n_heads, N, N)
            attn = attn + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.n_heads, N, N)
 
        attn = F.softmax(attn, dim=-1)
        out  = (attn @ v).transpose(1,2).reshape(B_, N, C)
        return self.W_o(out)
 
 
class SwinTransformerLayer(nn.Module):
    """
    一个 Swin Transformer 层
    交替使用 W-MSA 和 SW-MSA:
    - 偶数层(shift=False):正常窗口,学习局部内部关系
    - 奇数层(shift=True):移位窗口,学习跨窗口关系
    """
    def __init__(self, dim, n_heads, window_size=8, shift=False, mlp_ratio=4.0):
        super().__init__()
        self.shift       = shift
        self.window_size = window_size
        self.shift_size  = window_size // 2 if shift else 0
 
        self.norm1 = nn.LayerNorm(dim)
        self.attn  = WindowAttention(dim, window_size, n_heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp   = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(dim * mlp_ratio), dim)
        )
 
    def forward(self, x, H, W):
        B, L, C = x.shape
        x_2d = x.view(B, H, W, C)
 
        # Cyclic Shift:把特征图整体滚动,让边界处的 token 聚集到中央的同一个窗口
        if self.shift_size > 0:
            x_2d = torch.roll(x_2d, (-self.shift_size, -self.shift_size), dims=(1,2))
 
        windows  = window_partition(x_2d, self.window_size)
        windows  = windows.view(-1, self.window_size**2, C)
        attn_out = self.attn(self.norm1(windows))
        attn_out = attn_out.view(-1, self.window_size, self.window_size, C)
        x_2d     = window_reverse(attn_out, self.window_size, H, W)
 
        # 反向滚动还原位置
        if self.shift_size > 0:
            x_2d = torch.roll(x_2d, (self.shift_size, self.shift_size), dims=(1,2))
 
        x = x + x_2d.view(B, L, C)   # 残差
        x = x + self.mlp(self.norm2(x))
        return x
 
 
class RSTB(nn.Module):
    """
    Residual Swin Transformer Block
    = N 个 SwinTransformerLayer + Conv3×3 + 残差
    最后的 Conv3×3 增强局部建模能力(Transformer 偏全局,Conv 偏局部,互补)
    """
    def __init__(self, dim, n_heads, n_layers=6, window_size=8):
        super().__init__()
        self.layers = nn.ModuleList([
            SwinTransformerLayer(dim, n_heads, window_size, shift=(i%2==1))
            for i in range(n_layers)
        ])
        self.norm = nn.LayerNorm(dim)
        self.conv = nn.Conv2d(dim, dim, 3, padding=1)
 
    def forward(self, x, H, W):
        res = x
        for layer in self.layers:
            x = layer(x, H, W)
        x = self.norm(x)
        x = x.transpose(1,2).view(-1, x.shape[-1], H, W)
        x = self.conv(x).flatten(2).transpose(1,2)
        return x + res   # 残差:跨越整个 RSTB 的捷径
 
 
class SwinIR(nn.Module):
    def __init__(self, in_ch=3, dim=60, n_heads=6, n_rstb=6,
                 window_size=8, scale=4):
        super().__init__()
        self.shallow = nn.Conv2d(in_ch, dim, 3, padding=1)
        self.deep = nn.ModuleList([
            RSTB(dim, n_heads, window_size=window_size) for _ in range(n_rstb)
        ])
        self.deep_norm = nn.LayerNorm(dim)
        self.deep_conv = nn.Conv2d(dim, dim, 3, padding=1)
        self.upsample  = nn.Sequential(
            nn.Conv2d(dim, dim * scale**2, 3, padding=1),
            nn.PixelShuffle(scale),   # 亚像素卷积:把通道维的信息"折叠"到空间维
            nn.Conv2d(dim, in_ch, 3, padding=1)
        )
 
    def forward(self, x):
        B, C, H, W = x.shape
        feat = self.shallow(x)
 
        deep = feat.flatten(2).transpose(1,2)   # (B, H*W, dim)
        for rstb in self.deep:
            deep = rstb(deep, H, W)
        deep = self.deep_norm(deep)
        deep = deep.transpose(1,2).view(B, -1, H, W)
        deep = self.deep_conv(deep) + feat       # 深浅特征相加
 
        return self.upsample(deep)
 
 
model = SwinIR(scale=4)
x = torch.randn(1, 3, 64, 64)
print(model(x).shape)   # (1, 3, 256, 256) 4×超分辨率

CLIP

back to 目录

背景与动机

传统图像识别需要大量手工标注,而且每次换任务就要重新训练。CLIP(Contrastive Language-Image Pre-Training,2021)的核心洞察是:互联网上有海量的图文对(图片+描述),这是天然的监督信号

CLIP 的训练思路极为简单而优雅:

  • 给一批图像和对应的文本描述
  • 训练两个编码器,使配对的图文特征尽量相近,不配对的尽量远离
  • 这就是对比学习(Contrastive Learning)

零样本推理能力是 CLIP 最神奇的特性:训练时没见过”猫”这个类别,推理时只需构建文本 “a photo of a cat”,让图像特征与各类别文本特征比较相似度即可分类。模型学到的是通用的图文对齐空间,而不是特定类别的分类器。

整体架构

训练时:一批 N 对 (图像_i, 文本_i)

图像端                              文本端
(B, 3, 224, 224)                  (B, 77)  ← 最多 77 个 token
      │                                 │
 ViT-B/16 图像编码器               Transformer 文本编码器
  ┌─────────────────┐              ┌──────────────────┐
  │ 分成 14×14 个    │              │  Token Embedding  │
  │ 16×16 的 patch  │              │  + Positional Emb │
  │ 线性投影+CLS     │              │                  │
  │ Transformer ×12 │              │  Transformer ×12  │
  │ 取 CLS token    │              │  取 EOS token     │
  └─────────────────┘              └──────────────────┘
      │                                 │
  Linear 投影                       Linear 投影
      │                                 │
  image_feat (B, 512)               text_feat (B, 512)
      │                                 │
      └──── L2 归一化 ─────┐  ┌──── L2 归一化 ────┘
                           │  │
            相似度矩阵 = image_feat @ text_feat.T * exp(τ)
            → (B, B) 矩阵,第 i 行第 j 列 = 第 i 张图和第 j 段文本的相似度
            → 对角线是正样本,其余是负样本

对比损失(InfoNCE):
  行方向 CE:每张图 → 正确文本(共 B 个负样本)
  列方向 CE:每段文本 → 正确图像
  最终 loss = (行CE + 列CE) / 2

温度参数 τ 的作用:控制分布的”锐利程度”。τ 小→分布尖锐,区分度高;τ 大→分布平缓,训练更稳定。τ 本身也是可学习的。

class ViTImageEncoder(nn.Module):
    """
    CLIP 图像编码器:Vision Transformer (ViT)
    核心操作:把图像切成 patch → 当作 token 序列 → Transformer 处理
    """
    def __init__(self, img_size=224, patch_size=16, d_model=768,
                 n_heads=12, n_layers=12, embed_dim=512):
        super().__init__()
        n_patches = (img_size // patch_size) ** 2   # 14×14 = 196 个 patch
        # 用卷积实现 patch embedding(等价于把每个 patch 展平再做 Linear)
        self.patch_embed = nn.Conv2d(3, d_model, patch_size, stride=patch_size)
        # [CLS] token:汇聚全局信息,最终用它来表示整张图
        self.cls_token   = nn.Parameter(torch.zeros(1, 1, d_model))
        # 可学习位置编码(ViT 使用)
        self.pos_embed   = nn.Parameter(torch.zeros(1, n_patches + 1, d_model))
 
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=d_model*4,
            dropout=0.0, activation='gelu', batch_first=True, norm_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, n_layers)
        self.ln   = nn.LayerNorm(d_model)
        self.proj = nn.Linear(d_model, embed_dim, bias=False)
 
    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x).flatten(2).transpose(1,2)   # (B, N, D)
        cls = self.cls_token.expand(B, -1, -1)
        x   = torch.cat([cls, x], dim=1) + self.pos_embed   # (B, N+1, D)
        x   = self.transformer(x)
        x   = self.ln(x[:, 0])                              # 取 CLS token
        return self.proj(x)                                  # (B, embed_dim)
 
 
class TextEncoder(nn.Module):
    """
    CLIP 文本编码器:因果 Transformer(使用 causal mask)
    取 EOS token 的输出作为整个文本的表示(而不是 CLS)
    这是 CLIP 的独特选择,EOS 在生成过程中看到了所有 token,信息最完整
    """
    def __init__(self, vocab_size=49408, ctx_len=77, d_model=512,
                 n_heads=8, n_layers=12, embed_dim=512):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed   = nn.Parameter(torch.zeros(ctx_len, d_model))
 
        decoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=d_model*4,
            dropout=0.0, activation='gelu', batch_first=True, norm_first=True
        )
        self.transformer = nn.TransformerEncoder(decoder_layer, n_layers)
        self.ln   = nn.LayerNorm(d_model)
        self.proj = nn.Linear(d_model, embed_dim, bias=False)
 
    def forward(self, tokens):
        x = self.token_embed(tokens) + self.pos_embed[:tokens.shape[1]]
        x = self.transformer(x)
        x = self.ln(x)
        eot_idx = tokens.argmax(dim=-1)   # EOS token 的位置
        x = x[torch.arange(x.shape[0]), eot_idx]
        return self.proj(x)
 
 
class CLIP(nn.Module):
    def __init__(self, embed_dim=512):
        super().__init__()
        self.image_encoder = ViTImageEncoder(embed_dim=embed_dim)
        self.text_encoder  = TextEncoder(embed_dim=embed_dim)
        # 可学习温度:初始化为 1/0.07 ≈ 14.3,在 log 空间学习(保证正值)
        self.logit_scale   = nn.Parameter(torch.ones([]) * math.log(1/0.07))
 
    def encode_image(self, image):
        # L2 归一化:把特征映射到单位超球面上,相似度 = 余弦相似度
        return F.normalize(self.image_encoder(image), dim=-1)
 
    def encode_text(self, tokens):
        return F.normalize(self.text_encoder(tokens), dim=-1)
 
    def forward(self, image, tokens):
        img_feat  = self.encode_image(image)     # (B, D)
        text_feat = self.encode_text(tokens)     # (B, D)
 
        scale  = self.logit_scale.exp()
        # (B, B) 矩阵:每对图文的余弦相似度 × 温度
        logits = scale * img_feat @ text_feat.T
 
        # 标签就是对角线索引(第 i 张图对应第 i 段文本)
        labels = torch.arange(logits.shape[0], device=logits.device)
        loss_i = F.cross_entropy(logits,   labels)   # 图匹配文
        loss_t = F.cross_entropy(logits.T, labels)   # 文匹配图
        return (loss_i + loss_t) / 2
 
    @torch.no_grad()
    def zero_shot_classify(self, image, class_prompts_tokens):
        """
        零样本分类:不需要额外训练,用文本描述代替类别标签
        例如:["a photo of a cat", "a photo of a dog", ...]
        """
        img_feat   = self.encode_image(image)
        text_feats = self.encode_text(class_prompts_tokens)   # (C, D)
        probs = (img_feat @ text_feats.T * self.logit_scale.exp()).softmax(-1)
        return probs   # (B, C)

LLaVA

back to 目录

背景与动机

LLaVA(Large Language and Vision Assistant,2023)的核心问题是:如何让已经很强的 LLM 拥有”看图说话”的能力?

最暴力的想法是从头联合训练图像理解和语言生成,但代价极高。LLaVA 的优雅方案是:CLIP 已经学会了理解图像,LLM 已经学会了理解和生成语言,我只需要一个”翻译器”把图像特征转换到 LLM 能理解的空间

这个”翻译器”就是 MLP Projector——两层线性变换而已,轻量但有效。

整体架构

图像 (B, 3, 336, 336)
        │
  CLIP ViT-L/14@336px
  (训练时通常冻结,已具备强大图像理解能力)
        │ 取所有 patch token(去掉 CLS)
  (B, 576, 1024)   ← 576 = 24×24 个 patch,每个 1024 维
        │
  MLP Projector(可学习!这是两个模态的"桥梁")
  Linear(1024 → 4096) → GELU → Linear(4096 → 4096)
        │ (B, 576, 4096)  ← 对齐到 LLaMA 的 hidden size
        │
  ┌──────────────────────────────────────────────────────┐
  │ LLaMA / Vicuna 7B(指令微调的 LLaMA)                │
  │                                                      │
  │ 输入序列构建:                                        │
  │  [系统提示] [USER:] [576个视觉token] [文字问题] [ASSISTANT:] │
  │                                                      │
  │  word tokens → Embedding(vocab, 4096)                │
  │  + 视觉 token(从 Projector 来)直接拼接              │
  │        ↓                                             │
  │  合并序列 (B, 576 + T_text, 4096)                    │
  │        ↓                                             │
  │  LLaMA Transformer ×32 层                           │
  │        ↓                                             │
  │  输出 logits,预测下一个 token                        │
  └──────────────────────────────────────────────────────┘

两阶段训练策略:
  Stage 1 - 特征对齐(约 60 万图文对,1 epoch):
    ✅ 训练 MLP Projector
    ❄️ 冻结 CLIP ViT
    ❄️ 冻结 LLM
    目标:让 Projector 学会将视觉特征翻译成 LLM 的"语言"

  Stage 2 - 视觉对话微调(约 15 万多轮对话,1 epoch):
    ✅ 训练 MLP Projector
    ❄️ 冻结 CLIP ViT
    ✅ 训练 LLM(LoRA 或全参数)
    目标:让 LLM 学会用视觉信息回答问题
class MLPProjector(nn.Module):
    """
    视觉-语言桥梁模块
    设计简单但关键:把 CLIP 的视觉空间映射到 LLM 的 token 空间
    LLaVA-1.5 发现两层 MLP 比单层线性效果好很多
    """
    def __init__(self, vision_dim=1024, llm_dim=4096):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(vision_dim, llm_dim),
            nn.GELU(),
            nn.Linear(llm_dim, llm_dim),
        )
    def forward(self, x):
        return self.proj(x)   # (B, num_patches, llm_dim)
 
 
class LLaVA(nn.Module):
    """
    LLaVA 整体结构
    核心设计:视觉 token 和文本 token 在 LLM 输入层拼接,
    之后完全用统一的 Transformer 处理,不需要特殊的 Cross-Attention
    """
    def __init__(self, vision_encoder, llm, projector,
                 image_token_id: int = -200):
        super().__init__()
        self.vision_encoder = vision_encoder
        self.projector      = projector
        self.llm            = llm
        self.image_token_id = image_token_id   # 特殊占位符 token 的 id
 
    def get_visual_tokens(self, images):
        with torch.no_grad():
            # 取 patch tokens(去掉 CLS),保留空间信息
            vision_feats = self.vision_encoder(images)
            patch_feats  = vision_feats[:, 1:]     # (B, N_patch, D_vision)
        return self.projector(patch_feats)         # (B, N_patch, D_llm)
 
    def forward(self, input_ids, images, attention_mask=None, labels=None):
        """
        input_ids: (B, T) 文本 token,其中 <IMAGE> 位置为特殊 token id
        images:    (B, 3, H, W)
 
        典型的 input_ids 结构(展开后):
        [SYSTEM] [USER:] [<IMAGE>×576] [问题文字 tokens] [ASSISTANT:]
        """
        # 1. 文本 token 的 embedding
        inputs_embeds = self.llm.get_input_embeddings()(input_ids)   # (B, T, D)
 
        # 2. 视觉 token 的 embedding
        vis_tokens = self.get_visual_tokens(images)   # (B, N, D)
 
        # 3. 将视觉 embedding 替换掉 <IMAGE> 占位符的位置
        image_mask = (input_ids == self.image_token_id)
        for b in range(input_ids.shape[0]):
            img_pos = image_mask[b].nonzero(as_tuple=True)[0]
            n_img   = vis_tokens.shape[1]
            inputs_embeds[b, img_pos[:n_img]] = vis_tokens[b]
 
        # 4. 整体送入 LLM,像正常文本一样处理
        outputs = self.llm(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels   # 只在 ASSISTANT 的回答部分计算 loss
        )
        return outputs

LLaMA

back to 目录

背景与动机

LLaMA(2023,Meta)是开源 LLM 领域的里程碑,其设计哲学是:在同等参数量下,用更多的数据和更精心的架构设计,超越更大的模型

LLaMA 相对原始 Transformer 的主要改进可以概括为四点:

改进点原始 TransformerLLaMA
归一化位置Post-LN(层后归一化)Pre-RMSNorm(层前,去掉均值)
激活函数ReLU / GELUSwiGLU(门控激活)
位置编码绝对位置编码RoPE(旋转位置编码)
注意力MHA(多头注意力)GQA(分组查询注意力)

整体架构

输入 token ids: (B, T)
        │
  Embedding(vocab_size=32000, dim=4096)
        │ (B, T, 4096)
        │
  ┌───────────────────────────────────────────────────────┐
  │ Transformer Block ×32                                 │
  │                                                       │
  │  x ──► RMSNorm ──► Self-Attention(RoPE + GQA) ──►──┐ │
  │  │                                                  + │ │
  │  └──────────────────────────────────────────────────┘ │
  │  ↓                                                     │
  │  x ──► RMSNorm ──► SwiGLU FFN ──────────────────────┐ │
  │  │                                                   + │ │
  │  └───────────────────────────────────────────────────┘ │
  └───────────────────────────────────────────────────────┘
        │
  RMSNorm(最后一层)
        │
  Linear(4096, 32000) = lm_head
  (与 Embedding 共享权重,节省参数)
        │
  logits: (B, T, 32000)

Pre-RMSNorm 的意义

原始 Transformer 在 Attention/FFN 之后归一化(Post-LN),训练不稳定,需要 warmup 很长时间。
Pre-LN(LN 在 Attention/FFN 之前)训练更稳定,梯度流更好。
RMSNorm 去掉了均值项,计算更快,效果相当。

SwiGLU 的意义

普通 FFN:
SwiGLU:

多了一个门控机制: 决定”开多大门”, 是信号。这让模型能更灵活地选择性地传递信息,实践中比 GELU/ReLU 效果更好。

GQA 的意义(Grouped Query Attention)

标准多头注意力:32 个 Q head,32 个 K/V head → 显存占用大
GQA(LLaMA-2 采用):32 个 Q head,只有 8 个 K/V head,每 4 个 Q 共享 1 对 K/V
→ K/V Cache 的显存缩减 4 倍,推理速度大幅提升,精度损失很小

class RMSNorm(nn.Module):
    """
    Root Mean Square Normalization
    相比 LayerNorm 去掉了减均值的步骤,计算量更小,效果相当
    """
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps    = eps
    def forward(self, x):
        # 在 float32 精度下计算,避免 fp16 溢出
        norm = x.float().pow(2).mean(-1, keepdim=True).add(self.eps).sqrt()
        return (x.float() / norm).type_as(x) * self.weight
 
 
class SwiGLU(nn.Module):
    """
    SwiGLU FFN:比 FFN 多了一个门控路径
    hidden_dim 通常为 int(8/3 * dim),然后对齐到 256 的倍数
    """
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)   # gate
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)   # down
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)   # up
    def forward(self, x):
        # SiLU(x) = x * sigmoid(x)(Swish 激活)
        # 门控:SiLU(gate) ⊙ up,决定信息流量
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
 
 
class LlamaAttention(nn.Module):
    """
    LLaMA Attention:GQA + RoPE + 因果 mask
    n_kv_heads < n_heads 时启用 GQA
    """
    def __init__(self, dim: int, n_heads: int, n_kv_heads: int = None,
                 max_seq_len: int = 4096, base: float = 10000.0):
        super().__init__()
        self.n_heads    = n_heads
        self.n_kv_heads = n_kv_heads or n_heads
        self.n_rep      = n_heads // self.n_kv_heads   # 每个 KV head 被几个 Q head 共享
        self.d_k        = dim // n_heads
 
        self.Wq = nn.Linear(dim, n_heads    * self.d_k, bias=False)
        self.Wk = nn.Linear(dim, self.n_kv_heads * self.d_k, bias=False)
        self.Wv = nn.Linear(dim, self.n_kv_heads * self.d_k, bias=False)
        self.Wo = nn.Linear(n_heads * self.d_k, dim, bias=False)
 
        freqs = precompute_freqs_cis(self.d_k, max_seq_len, base)
        self.register_buffer('freqs_cis', freqs)
 
    def forward(self, x, start_pos=0):
        B, T, _ = x.shape
        q = self.Wq(x).view(B, T, self.n_heads,    self.d_k)
        k = self.Wk(x).view(B, T, self.n_kv_heads, self.d_k)
        v = self.Wv(x).view(B, T, self.n_kv_heads, self.d_k)
 
        # 旋转位置编码:只旋转 Q 和 K,不动 V
        freqs = self.freqs_cis[start_pos: start_pos + T]
        q, k  = apply_rotary_emb(q, k, freqs)
 
        # GQA:把 K/V 的 head 数扩展到和 Q 一样(repeat_interleave)
        if self.n_rep > 1:
            k = k.repeat_interleave(self.n_rep, dim=2)
            v = v.repeat_interleave(self.n_rep, dim=2)
 
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
 
        scores = (q @ k.transpose(-2,-1)) / math.sqrt(self.d_k)
 
        # 上三角 mask(不含对角线):未来 token 不可见
        mask   = torch.triu(torch.full((T, T), float('-inf'), device=x.device), 1)
        scores = scores + mask
 
        out = F.softmax(scores, dim=-1) @ v
        out = out.transpose(1,2).contiguous().view(B, T, -1)
        return self.Wo(out)
 
 
class LlamaBlock(nn.Module):
    """Pre-Norm Transformer Block"""
    def __init__(self, dim=4096, n_heads=32, n_kv_heads=8, ffn_hidden=11008):
        super().__init__()
        self.norm1 = RMSNorm(dim)
        self.attn  = LlamaAttention(dim, n_heads, n_kv_heads)
        self.norm2 = RMSNorm(dim)
        self.ffn   = SwiGLU(dim, ffn_hidden)
 
    def forward(self, x, start_pos=0):
        # Pre-Norm:先归一化再运算,残差连接跨越归一化
        x = x + self.attn(self.norm1(x), start_pos)
        x = x + self.ffn(self.norm2(x))
        return x
 
 
class LLaMA(nn.Module):
    def __init__(self, vocab_size=32000, dim=4096, n_layers=32,
                 n_heads=32, n_kv_heads=8, ffn_hidden=11008, max_seq_len=4096):
        super().__init__()
        self.embed   = nn.Embedding(vocab_size, dim)
        self.layers  = nn.ModuleList([
            LlamaBlock(dim, n_heads, n_kv_heads, ffn_hidden)
            for _ in range(n_layers)
        ])
        self.norm    = RMSNorm(dim)
        self.lm_head = nn.Linear(dim, vocab_size, bias=False)
 
        # 权重绑定:embedding 矩阵和 lm_head 共享同一份参数
        # 直觉:输入和输出都是词表空间,用同一套表示更一致,且节省约 500M 参数
        self.lm_head.weight = self.embed.weight
 
    def forward(self, tokens, start_pos=0):
        x = self.embed(tokens)
        for layer in self.layers:
            x = layer(x, start_pos)
        return self.lm_head(self.norm(x))   # (B, T, vocab_size)
 
    @torch.no_grad()
    def generate(self, prompt_ids, max_new_tokens=100, temperature=1.0):
        """自回归生成:每次预测下一个 token,拼到序列尾部,循环"""
        self.eval()
        x = prompt_ids
        for _ in range(max_new_tokens):
            logits  = self.forward(x)[:, -1, :]   # 只取最后一步的预测
            logits  = logits / temperature
            next_t  = torch.multinomial(F.softmax(logits, -1), 1)
            x = torch.cat([x, next_t], dim=1)
        return x
 
 
# 小型验证
model = LLaMA(vocab_size=1000, dim=256, n_layers=4,
              n_heads=8, n_kv_heads=4, ffn_hidden=512)
x = torch.randint(0, 1000, (2, 16))
print(model(x).shape)   # (2, 16, 1000)

Qwen3

back to 目录

背景与动机

Qwen3(2025,阿里通义)是当前(2025年初)最强的开源多语言 LLM 之一,在架构上相对 LLaMA 做了若干精细化改进。

Qwen3 的整体思路是:站在 LLaMA 的肩膀上,把每个细节都再优化一遍。核心改进:

  1. Q/K 加 RMSNorm:在 RoPE 之前对 Q 和 K 各做一次 RMSNorm,防止注意力分数数值不稳定(特别是在长上下文场景)
  2. 更大的 RoPE base(θ=1,000,000):LLaMA-2 用 10000,Qwen3 用 1,000,000。更大的 base 意味着旋转”更慢”,保留更多低频位置信息,支持更长的上下文(128K+)
  3. MoE 版本(Mixture of Experts):Qwen3-MoE 系列用 64 个专家 FFN,每次只激活 8 个,在保持参数量大(表达力强)的同时,计算量只是密集版的一小部分

Qwen3 vs LLaMA 架构差异

                    LLaMA                    Qwen3
归一化         RMSNorm(Pre-Norm)      RMSNorm(Pre-Norm)✓ 相同
RoPE base      10,000                  1,000,000           ✦ 更大
Q/K 偏置       无 bias                 Q/K 有 bias         ✦ 新增
Q/K Norm       无                      RMSNorm(Q), RMSNorm(K)  ✦ 新增
GQA            部分版本                全系列标配           ✓ 相同
FFN            SwiGLU                  SwiGLU 或 MoE       ✦ MoE 版新增

MoE FFN 架构

输入 x: (B, T, D)
        │
  Router: Linear(D, n_experts=64)
  → softmax → 选出每个 token 的 Top-8 专家 + 分配权重
        │
  ┌─────┬─────┬─────┬─────────────────────────┐
  │FFN_0│FFN_1│FFN_2│  ...  FFN_63            │  ← 64 个 SwiGLU 专家
  └─────┴─────┴─────┴─────────────────────────┘
     每个 token 只经过被选中的 8 个专家(稀疏激活!)
        │
  加权求和(router 分配的权重)
        │
  + 共享专家(Shared Expert,每个 token 都过)
        │
  输出 (B, T, D)

优点:
  ✦ 总参数量 = 64 × FFN 大小(容量大,知识丰富)
  ✦ 每步计算量 = 8 × FFN 大小(推理快,显存占用低)
  ✦ 不同专家自然学习到不同领域的知识(语言、代码、数学...)
class Qwen3Attention(nn.Module):
    """
    Qwen3 的核心改进:
    1. Q/K 带 bias(微小但有效的改进)
    2. Q/K 在 RoPE 前各过一次 RMSNorm(稳定长上下文的 attention score)
    3. RoPE base = 1,000,000(支持 128K+ 上下文)
    """
    def __init__(self, dim=4096, n_heads=32, n_kv_heads=8,
                 max_seq_len=32768, rope_base=1_000_000):
        super().__init__()
        self.n_heads    = n_heads
        self.n_kv_heads = n_kv_heads
        self.n_rep      = n_heads // n_kv_heads
        self.d_k        = dim // n_heads
 
        # Q/K 有 bias,V 没有 bias(Qwen3 的选择)
        self.Wq = nn.Linear(dim, n_heads    * self.d_k, bias=True)
        self.Wk = nn.Linear(dim, n_kv_heads * self.d_k, bias=True)
        self.Wv = nn.Linear(dim, n_kv_heads * self.d_k, bias=False)
        self.Wo = nn.Linear(n_heads * self.d_k, dim,    bias=False)
 
        # 对每个 head 的 Q/K 独立做 RMSNorm
        # 动机:防止极长序列中 QK^T 的值过大,导致注意力坍塌(只关注极少数位置)
        self.q_norm = RMSNorm(self.d_k)
        self.k_norm = RMSNorm(self.d_k)
 
        freqs = precompute_freqs_cis(self.d_k, max_seq_len, theta=rope_base)
        self.register_buffer('freqs_cis', freqs)
 
    def forward(self, x, start_pos=0):
        B, T, _ = x.shape
        q = self.Wq(x).view(B, T, self.n_heads,    self.d_k)
        k = self.Wk(x).view(B, T, self.n_kv_heads, self.d_k)
        v = self.Wv(x).view(B, T, self.n_kv_heads, self.d_k)
 
        # 先归一化,再旋转——让 Q/K 的模长稳定在合理范围
        q = self.q_norm(q)
        k = self.k_norm(k)
 
        freqs = self.freqs_cis[start_pos: start_pos + T]
        q, k  = apply_rotary_emb(q, k, freqs)
 
        if self.n_rep > 1:
            k = k.repeat_interleave(self.n_rep, dim=2)
            v = v.repeat_interleave(self.n_rep, dim=2)
 
        q = q.transpose(1,2); k = k.transpose(1,2); v = v.transpose(1,2)
        scores = (q @ k.transpose(-2,-1)) / math.sqrt(self.d_k)
        mask   = torch.triu(torch.full((T,T), float('-inf'), device=x.device), 1)
        out    = F.softmax(scores + mask, dim=-1) @ v
        out    = out.transpose(1,2).contiguous().view(B, T, -1)
        return self.Wo(out)
 
 
class MoEFFN(nn.Module):
    """
    Mixture of Experts FFN(Qwen3-MoE 使用)
    稀疏激活:每个 token 只走 n_active 个专家,大幅降低计算量
    """
    def __init__(self, dim: int, n_experts: int = 64, n_active: int = 8,
                 ffn_hidden: int = 2048, n_shared: int = 1):
        super().__init__()
        self.router   = nn.Linear(dim, n_experts, bias=False)
        self.n_active = n_active
        # 稀疏专家:每次只激活部分
        self.experts  = nn.ModuleList([SwiGLU(dim, ffn_hidden) for _ in range(n_experts)])
        # 共享专家:每个 token 都过,提供稳定的基础表示
        # 设计灵感:不是所有知识都需要专家化,通用能力用共享专家维持
        self.shared_experts = nn.ModuleList([SwiGLU(dim, ffn_hidden) for _ in range(n_shared)])
 
    def forward(self, x):
        B, T, D = x.shape
        x_flat = x.view(-1, D)
 
        # 路由:为每个 token 选出 top-k 个专家
        logits  = self.router(x_flat)
        weights, indices = torch.topk(logits, self.n_active, dim=-1)
        weights = F.softmax(weights, dim=-1)   # 归一化权重
 
        # 稀疏计算(生产实现用 scatter/expert-parallel,这里用循环示意)
        out = torch.zeros_like(x_flat)
        for k in range(self.n_active):
            expert_id = indices[:, k]
            w         = weights[:, k].unsqueeze(-1)
            for i, expert in enumerate(self.experts):
                mask = (expert_id == i)
                if mask.any():
                    out[mask] += w[mask] * expert(x_flat[mask])
 
        # 共享专家(每个 token 都走,补充通用能力)
        for shared in self.shared_experts:
            out = out + shared(x_flat)
 
        return out.view(B, T, D)
 
 
class Qwen3Block(nn.Module):
    def __init__(self, dim=4096, n_heads=32, n_kv_heads=8,
                 ffn_hidden=11008, use_moe=False):
        super().__init__()
        self.norm1 = RMSNorm(dim)
        self.attn  = Qwen3Attention(dim, n_heads, n_kv_heads)
        self.norm2 = RMSNorm(dim)
        self.ffn   = MoEFFN(dim, ffn_hidden=ffn_hidden) if use_moe \
                     else SwiGLU(dim, ffn_hidden)
 
    def forward(self, x, start_pos=0):
        x = x + self.attn(self.norm1(x), start_pos)
        x = x + self.ffn(self.norm2(x))
        return x
 
 
class Qwen3(nn.Module):
    def __init__(self, vocab_size=151936, dim=4096, n_layers=32,
                 n_heads=32, n_kv_heads=8, ffn_hidden=11008, use_moe=False):
        super().__init__()
        self.embed   = nn.Embedding(vocab_size, dim)
        self.layers  = nn.ModuleList([
            Qwen3Block(dim, n_heads, n_kv_heads, ffn_hidden, use_moe)
            for _ in range(n_layers)
        ])
        self.norm    = RMSNorm(dim)
        self.lm_head = nn.Linear(dim, vocab_size, bias=False)
 
    def forward(self, tokens, start_pos=0):
        x = self.embed(tokens)
        for layer in self.layers:
            x = layer(x, start_pos)
        return self.lm_head(self.norm(x))
 
 
# 小型验证
model = Qwen3(vocab_size=1000, dim=256, n_layers=2,
              n_heads=8, n_kv_heads=4, ffn_hidden=512)
x = torch.randint(0, 1000, (2, 16))
print(model(x).shape)   # (2, 16, 1000)

架构对比速查表

back to 目录

整体架构对比

模型类型核心创新注意力机制位置编码归一化激活函数主要用途
ResNet18CNN残差连接BatchNormReLU图像分类
UNetCNN+跳连编解码+跳连BatchNormReLU图像分割/生成
SwinIR层级 ViT窗口注意力W-MSA / SW-MSA相对位置偏置LayerNormGELU图像复原/超分
CLIP双塔对比学习MHSALearnableLayerNormGELU图文对齐
LLaVA多模态模态对齐MHSARoPERMSNormSwiGLU视觉问答
LLaMA纯解码器高效 LLMGQA + 因果maskRoPE(1e4)RMSNormSwiGLU文本生成
Qwen3纯解码器Q/K-Norm + MoEGQA + 因果maskRoPE(1e6)RMSNormSwiGLU多语言/多模态

位置编码选择指南

场景推荐编码原因
短序列分类(BERT 风格)Learnable PE简单有效,不需外推
图像 patch(ViT/CLIP)Learnable PE2D 位置,学习更直接
图像局部窗口(Swin)相对位置偏置只需感知窗口内相对距离
长文本生成(LLM)RoPE天然相对位置,外推好
超长上下文(128K+)RoPE(大 base)更大 theta,低频信息更丰富
极长外推(MPT 等)ALiBi线性衰减,外推最稳

注意力机制复杂度对比

设序列长度 N,窗口大小 M,专家数 E,激活专家数 K

MHA(标准多头注意力):   O(N²·d)          ← N 大时极慢
GQA(分组查询注意力):   O(N²·d)          ← 同 MHA,但 KV Cache 显存少
Window Attention:       O(N·M²·d)        ← M 固定,随 N 线性增长
MoE FFN:               O(N·K/E·d_ffn)   ← K/E 通常很小(8/64=12.5%)

参考资料:timm · HuggingFace Transformers · 各模型原论文