神经网络架构详解
涵盖 ResNet18 · UNet · SwinIR · CLIP · LLaVA · LLaMA · Qwen3
包含结构图、直觉理解、核心模块解析、完整代码实现
目录
基础知识
back to 目录
0.1 Self-Attention
直觉理解
想象你在读一句话:“猫追着球在客厅里跑”。
当你理解”跑”这个词时,你的大脑会自动回头看”猫”和”球”——这正是 Self-Attention 在做的事情:让序列中的每个位置都能”看”到其他所有位置,并按相关程度加权融合信息。
传统 RNN 只能一步步传递信息(像电话传话游戏,传着传着就失真了)。Self-Attention 则像在一个圆桌会议上,每个人可以直接和所有人交流,距离不再是障碍。
Q / K / V 三件套的比喻
把注意力机制想象成一个图书馆检索系统:
- Q(Query,查询):你手里的借书申请单——“我想找关于猫的书”
- K(Key,键):书架上每本书的标签——“这本书讲猫、那本书讲狗”
- V(Value,值):书本身的内容
你拿着申请单(Q)去匹配每本书的标签(K),相似度高的书你就多看一些,最终你读到的内容(输出)是按相似度加权的书的内容(V)的混合。
数学形式
除以 是为了防止点积值过大导致 softmax 进入梯度接近零的饱和区——这就像把麦克风音量调到合适大小,避免声音失真。
输入 X: (B, N, D)
│
┌────┴─────┐
W_Q W_K W_V ← 三个独立的线性投影(学习"怎么提问/怎么打标签/什么是内容")
│ │ │
Q K V
(B,N,d_k) (B,N,d_k) (B,N,d_v)
│ │
└──── QK^T / √d_k ────┘
│
softmax → (B, N, N) 注意力图,第 i 行表示第 i 个 token 关注其他 token 的权重
│
× V → 按权重加权求和每个位置的 Value
│
output (B, N, d_v)
Multi-Head 的意义
单个注意力头只能学习一种”关注模式”(比如语法关系)。多头注意力就像多个摄像头从不同角度拍摄同一个场景——有的头关注语义,有的头关注位置距离,有的头关注共指关系,最终把所有视角合并。
X → [Q,K,V 线性投影] → 拆成 h 个 head
每个 head 独立做 Attention(Q_i, K_i, V_i) ← h 个不同的"视角"
所有 head concat → 线性投影 W_O → 输出
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class MultiHeadSelfAttention(nn.Module):
def __init__(self, d_model: int, n_heads: int, dropout: float = 0.0):
super().__init__()
assert d_model % n_heads == 0
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads
self.W_qkv = nn.Linear(d_model, 3 * d_model, bias=False)
self.W_o = nn.Linear(d_model, d_model, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
B, N, D = x.shape
# 一次性计算 Q, K, V(效率更高)
qkv = self.W_qkv(x) # (B, N, 3D)
q, k, v = qkv.chunk(3, dim=-1) # 各 (B, N, D)
# 拆分多头:每个 head 独立处理一个 d_k 维子空间
def split_heads(t):
return t.view(B, N, self.n_heads, self.d_k).transpose(1, 2)
# -> (B, h, N, d_k)
q, k, v = split_heads(q), split_heads(k), split_heads(v)
# Scaled Dot-Product Attention
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
# scores: (B, h, N, N) —— 每对 token 之间的相关性
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attn = F.softmax(scores, dim=-1) # 归一化为概率分布
attn = self.dropout(attn)
out = torch.matmul(attn, v) # (B, h, N, d_k)
out = out.transpose(1, 2).contiguous() # (B, N, h, d_k)
out = out.view(B, N, D) # 合并所有 head
return self.W_o(out) # 最终投影
# ---- 带 KV Cache 的因果 Self-Attention(用于 LLM 推理加速)----
# 推理时每次只处理一个新 token,把历史 K/V 缓存起来,避免重复计算
class CausalSelfAttention(nn.Module):
def __init__(self, d_model, n_heads):
super().__init__()
self.n_heads = n_heads
self.d_k = d_model // n_heads
self.W_qkv = nn.Linear(d_model, 3 * d_model, bias=False)
self.W_o = nn.Linear(d_model, d_model, bias=False)
def forward(self, x, kv_cache=None):
B, T, D = x.shape
q, k, v = self.W_qkv(x).chunk(3, dim=-1)
def split(t):
return t.view(B, T, self.n_heads, self.d_k).transpose(1, 2)
q, k, v = split(q), split(k), split(v)
# KV Cache:把新的 K/V 拼接到历史缓存上
if kv_cache is not None:
k = torch.cat([kv_cache[0], k], dim=2)
v = torch.cat([kv_cache[1], v], dim=2)
new_cache = (k, v)
S = k.shape[2]
scores = torch.matmul(q, k.transpose(-2,-1)) / math.sqrt(self.d_k)
# 下三角 mask:每个 token 只能看到自己及之前的 token(自回归性质)
causal_mask = torch.tril(torch.ones(T, S, device=x.device))
scores = scores.masked_fill(causal_mask[-T:, :S] == 0, float('-inf'))
out = torch.matmul(F.softmax(scores, dim=-1), v)
out = out.transpose(1,2).contiguous().view(B, T, D)
return self.W_o(out), new_cache0.2 Cross-Attention
直觉理解
如果说 Self-Attention 是”自我反思”(序列内部的信息整合),那 Cross-Attention 就是”跨界对话”——一个序列拿着自己的问题(Q),去另一个序列里查答案(K/V)。
最经典的例子是机器翻译:解码器(生成中文)用当前状态作为 Q,去关注编码器(理解英文)产出的 K/V,从而知道”此刻应该关注原文的哪个部分”。
在多模态模型(如 LLaVA)中,Cross-Attention 让文本 token 能”看”图像特征,从而理解图像内容。
Query 来源 X_q: (B, N_q, D) ← 解码器当前状态 / 文本 token
Key/Value 来源 X_kv: (B, N_kv, D) ← 编码器输出 / 图像特征
Q = X_q · W_Q ← "我现在想知道什么"
K = X_kv · W_K ← "我(上下文)里有什么主题"
V = X_kv · W_V ← "上下文的实际内容"
注意力图 softmax(QK^T/√d): (B, h, N_q, N_kv)
↑ 第 i 个 query token 对第 j 个 context token 的关注程度
输出: Attention(Q, K, V) → (B, N_q, D)
↑ 每个 query token 现在"看过"了所有 context 信息
class CrossAttention(nn.Module):
def __init__(self, d_model: int, d_context: int, n_heads: int):
super().__init__()
self.n_heads = n_heads
self.d_k = d_model // n_heads
# Q 来自目标序列,K/V 来自上下文序列(维度可以不同)
self.W_q = nn.Linear(d_model, d_model, bias=False)
self.W_k = nn.Linear(d_context, d_model, bias=False)
self.W_v = nn.Linear(d_context, d_model, bias=False)
self.W_o = nn.Linear(d_model, d_model, bias=False)
def forward(self, x, context):
"""
x: (B, N_q, D) ← query source(提问者)
context: (B, N_kv, D_ctx) ← key/value source(知识库)
"""
B, N_q, _ = x.shape
B, N_kv, _ = context.shape
q = self.W_q(x) # (B, N_q, D)
k = self.W_k(context) # (B, N_kv, D)
v = self.W_v(context) # (B, N_kv, D)
def split(t, N):
return t.view(B, N, self.n_heads, self.d_k).transpose(1, 2)
q, k, v = split(q, N_q), split(k, N_kv), split(v, N_kv)
scores = torch.matmul(q, k.transpose(-2,-1)) / math.sqrt(self.d_k)
attn = F.softmax(scores, dim=-1) # (B, h, N_q, N_kv)
out = torch.matmul(attn, v) # (B, h, N_q, d_k)
out = out.transpose(1,2).contiguous().view(B, N_q, -1)
return self.W_o(out)0.3 位置编码
为什么需要位置编码?
Attention 本质上是一种无序操作——把输入打乱顺序,输出不变(只是行列换一下)。但语言是有顺序的:“狗咬人” ≠ “人咬狗”。因此必须用位置编码把序号信息注入进去。
各种位置编码的核心问题只有一个:怎么让模型知道 token 在哪个位置,以及两个 token 之间有多远?
0.3.1 Sinusoidal PE(原始 Transformer 论文)
思路: 用不同频率的正弦/余弦波来表示位置,就像钟表的时针、分针、秒针——频率不同,组合起来可以唯一表示任意时刻。
低维度用高频(变化快,区分近邻),高维度用低频(变化慢,区分远端)。
优点: 无需学习,天然能外推到训练时没见过的更长序列。
缺点: 相对位置信息是隐式的,模型需要自己去学习如何利用它。
class SinusoidalPE(nn.Module):
def __init__(self, d_model: int, max_len: int = 5000):
super().__init__()
pe = torch.zeros(max_len, d_model)
pos = torch.arange(0, max_len).unsqueeze(1).float() # (L, 1)
div = torch.exp(torch.arange(0, d_model, 2).float()
* -(math.log(10000.0) / d_model)) # (D/2,)
pe[:, 0::2] = torch.sin(pos * div)
pe[:, 1::2] = torch.cos(pos * div)
self.register_buffer('pe', pe.unsqueeze(0)) # (1, L, D)
def forward(self, x):
# 位置编码直接加到输入 embedding 上
return x + self.pe[:, :x.size(1)]0.3.2 Learnable PE(BERT / ViT)
思路: 既然不知道最好的编码方式,就让模型自己学。每个位置分配一个可训练的向量,和 word embedding 一样从数据中学习。
优点: 灵活,针对任务自适应。
缺点: 无法外推(训练时最长 512,推理时遇到 600 长度就不知道怎么处理了)。
class LearnablePE(nn.Module):
def __init__(self, d_model: int, max_len: int):
super().__init__()
self.pe = nn.Embedding(max_len, d_model) # 每个位置一个可学习的向量
def forward(self, x):
pos = torch.arange(x.size(1), device=x.device)
return x + self.pe(pos)0.3.3 RoPE(Rotary Position Embedding)⭐ LLaMA / Qwen 使用
思路: 这是最精妙的设计。不直接”加”位置信息,而是对 Q/K 向量旋转一个与位置成正比的角度。
数学上可以证明:旋转后的 只与相对位置差 有关,与绝对位置无关。
比喻: 想象两个人站在圆形跑道上,他们之间的”相对方向”只取决于各自的位置差,与他们在跑道哪个绝对位置无关。RoPE 就是让注意力分数只感知这个”相对方向”。
为什么 LLM 青睐 RoPE:
- 外推性比 Learnable PE 强
- 相对位置天然编码
- 可以通过调整
theta来扩展上下文长度(LLaMA 用 10000,Qwen3 用 1,000,000)
旋转示意(2D 情况):
def precompute_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0):
"""
预计算每个位置、每个维度对的旋转角度(以复数形式存储)
dim/2 对维度,每对用一个复数 e^{i*pos*θ_j} 表示
"""
# 每对维度的"旋转速度":低维快(高频),高维慢(低频)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(seq_len)
freqs = torch.outer(t, freqs) # (seq_len, dim/2)
# 转为复数:e^{i*θ} = cos(θ) + i*sin(θ)
return torch.polar(torch.ones_like(freqs), freqs)
def apply_rotary_emb(xq, xk, freqs_cis):
"""
将旋转位置编码应用到 Q 和 K
原理:复数乘法 = 旋转。把每对维度当作一个二维向量,乘以旋转复数
"""
def rotate(x, f):
# x: (B, T, h, d_k)
# 把每对相邻维度合成一个复数
x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
# f: (T, d_k/2) → 广播到 (1, T, 1, d_k/2)
f = f[:x.shape[1]].unsqueeze(0).unsqueeze(2)
# 复数乘法:相当于旋转
x_rot = torch.view_as_real(x_c * f).flatten(3)
return x_rot.type_as(x)
xq = rotate(xq, freqs_cis)
xk = rotate(xk, freqs_cis)
return xq, xk0.3.4 ALiBi(Attention with Linear Biases)
思路: 不修改输入,直接在注意力分数矩阵上加一个线性惩罚——距离越远,分数扣得越多,天然让模型更关注近邻。
不同 head 用不同的惩罚斜率,有的 head 视野窄(斜率大,惩罚重),有的 head 视野宽(斜率小,惩罚轻)。
优点: 外推能力极强(训练 1024,推理 4096 也不差)。
缺点: 先验假设”近的更重要”不一定对所有任务成立。
def get_alibi_slopes(n_heads):
"""计算等比数列斜率,斜率越大 = 对距离越敏感"""
def _slopes(n):
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
return [start * (start ** i) for i in range(n)]
return torch.tensor(_slopes(n_heads))
def alibi_bias(seq_len, n_heads, slopes):
pos = torch.arange(seq_len)
# |i - j|:位置距离矩阵
bias = -torch.abs(pos.unsqueeze(0) - pos.unsqueeze(1)) # (T, T)
# 每个 head 乘以自己的斜率
bias = bias.unsqueeze(0) * slopes.view(-1,1,1) # (h, T, T)
return bias0.3.5 Swin Transformer 相对位置偏置(2D 图像专用)
思路: 不编码绝对坐标,而是学习每对 patch 之间相对位移对应的偏置值。
大小为 的窗口里,行列各有 种相对位移,所以偏置表大小是 。
这比绝对位置编码参数更少、泛化更好。
class RelativePositionBias(nn.Module):
"""
Swin Transformer 中的相对位置偏置
核心思想:不记录"我在哪",只记录"我们相差多少"
"""
def __init__(self, window_size: int, n_heads: int):
super().__init__()
self.window_size = window_size
M = 2 * window_size - 1
# 可学习偏置表:(2M-1)² 种相对位置 × n_heads
self.bias_table = nn.Parameter(torch.zeros(M*M, n_heads))
nn.init.trunc_normal_(self.bias_table, std=0.02)
# 预计算所有 patch 对的相对位置索引(只需计算一次)
coords = torch.stack(torch.meshgrid(
torch.arange(window_size), torch.arange(window_size), indexing='ij'
)) # (2, W, W)
coords_flat = coords.flatten(1) # (2, W²)
rel = coords_flat[:, :, None] - coords_flat[:, None, :] # (2, W², W²)
rel = rel.permute(1, 2, 0).contiguous()
rel[:, :, 0] += window_size - 1 # 偏移到非负
rel[:, :, 1] += window_size - 1
rel[:, :, 0] *= 2 * window_size - 1
self.register_buffer('rel_index', rel.sum(-1)) # (W², W²)
def forward(self):
bias = self.bias_table[self.rel_index.view(-1)]
bias = bias.view(self.window_size**2, self.window_size**2, -1)
return bias.permute(2, 0, 1).contiguous() # (n_heads, W², W²)ResNet18
back to 目录
背景与动机
2014 年前后,更深的网络反而训练效果更差——不是因为过拟合,而是梯度消失导致深层网络根本没学到东西。研究者的直觉是:如果新加的层什么都不学(恒等映射),至少不该比浅层网络差。
残差学习的核心思想: 与其让网络直接学目标映射 ,不如让它学残差 ,目标变成 。
当 时退化为恒等映射,这比学一个精确的恒等映射容易得多(把所有参数推向零即可)。
形象地说:残差连接就像给每一层装了一条”高速公路”,梯度可以直接从深层流回浅层,不用经过层层”收费站”(非线性变换)。
整体架构
输入图像 (3, 224, 224)
│
┌────────────────────────────┐
│ Stem(快速下采样) │
│ Conv7×7, 64, stride=2 │ → (64, 112, 112) 空间减半
│ BN + ReLU │
│ MaxPool 3×3, stride=2 │ → (64, 56, 56) 再减半
└────────────────────────────┘
│
┌────────────────────────────────────────────────────────────┐
│ 4个 Stage,通道数翻倍,空间尺寸减半(除第一个Stage) │
│ │
│ Stage 1: 2× BasicBlock(64→64, stride=1) → (64, 56, 56) │
│ Stage 2: 2× BasicBlock(64→128, stride=2) → (128,28,28) │
│ Stage 3: 2× BasicBlock(128→256, stride=2) → (256,14,14) │
│ Stage 4: 2× BasicBlock(256→512, stride=2) → (512, 7, 7) │
└────────────────────────────────────────────────────────────┘
│
AdaptiveAvgPool2d(1,1) → (512, 1, 1) 全局平均池化,去掉空间维度
Flatten → (512,)
Linear(512, num_classes) → (num_classes,)
│
输出 logits
BasicBlock:核心模块
BasicBlock 是整个 ResNet18 的”原子单元”,理解了它就理解了 ResNet 的精髓。
输入 x: (B, C_in, H, W)
│
├─── 主路径(学习残差 F(x))──────────────────────
│ Conv3×3(C_in→C_out, stride) ← 可能改变通道数和空间尺寸
│ BN → ReLU ← 归一化+激活
│ Conv3×3(C_out→C_out) ← 精细调整特征
│ BN ← 注意:这里不加 ReLU!
│
└─── 捷径(shortcut)───────────────────────────────
● 若尺寸不变:直接 Identity(x 原样通过)
● 若尺寸改变:Conv1×1(stride=2) + BN 投影对齐
│
① 主路径输出 + shortcut 输出 ← 残差相加
② ReLU ← 相加后才激活
│
输出 (B, C_out, H', W')
关键细节:
- 两个 Conv 之间有 ReLU,最后一个 BN 后先加再激活
- 这个顺序(pre-activation vs post-activation)有细微差别
- ResNet18/34 用 BasicBlock;ResNet50+ 用 Bottleneck(1×1-3×3-1×1)
import torch
import torch.nn as nn
import torch.nn.functional as F
class BasicBlock(nn.Module):
expansion = 1 # ResNet50+ 的 Bottleneck 此值为4,用于计算输出通道数
def __init__(self, in_channels: int, out_channels: int, stride: int = 1):
super().__init__()
# 第一个卷积:可能下采样(stride=2)
self.conv1 = nn.Conv2d(in_channels, out_channels, 3,
stride=stride, padding=1, bias=False)
# BN + Conv 组合,bias=False 是因为 BN 的 beta 参数已经起到偏置的作用
self.bn1 = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
# 第二个卷积:不改变尺寸
self.conv2 = nn.Conv2d(out_channels, out_channels, 3,
padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_channels)
# 快捷路径:维度不匹配时需要投影
# 两种情况:① stride=2(空间缩小)② 通道数变化
self.shortcut = nn.Sequential()
if stride != 1 or in_channels != out_channels:
self.shortcut = nn.Sequential(
# 1×1 卷积专门用于改变通道数和空间尺寸
nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False),
nn.BatchNorm2d(out_channels)
)
def forward(self, x):
out = self.relu(self.bn1(self.conv1(x))) # Conv → BN → ReLU
out = self.bn2(self.conv2(out)) # Conv → BN(还没 ReLU!)
out += self.shortcut(x) # 残差相加
return self.relu(out) # 相加后再 ReLU
class ResNet18(nn.Module):
def __init__(self, num_classes: int = 1000):
super().__init__()
# Stem:大感受野快速下采样
self.stem = nn.Sequential(
nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(3, stride=2, padding=1)
)
# 4个 Stage,通道数翻倍,空间缩小
self.layer1 = self._make_layer(64, 64, n=2, stride=1)
self.layer2 = self._make_layer(64, 128, n=2, stride=2)
self.layer3 = self._make_layer(128, 256, n=2, stride=2)
self.layer4 = self._make_layer(256, 512, n=2, stride=2)
# 分类头
self.pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(512, num_classes)
self._init_weights()
def _make_layer(self, in_ch, out_ch, n, stride):
# 第一个 block 可能有 stride(下采样),之后的 block 维持尺寸
layers = [BasicBlock(in_ch, out_ch, stride)]
for _ in range(1, n):
layers.append(BasicBlock(out_ch, out_ch))
return nn.Sequential(*layers)
def _init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.ones_(m.weight); nn.init.zeros_(m.bias)
def forward(self, x):
x = self.stem(x) # (B, 64, 56, 56)
x = self.layer1(x) # (B, 64, 56, 56)
x = self.layer2(x) # (B, 128, 28, 28)
x = self.layer3(x) # (B, 256, 14, 14)
x = self.layer4(x) # (B, 512, 7, 7)
x = self.pool(x).flatten(1) # (B, 512)
return self.fc(x) # (B, num_classes)
# 验证:参数量约 11.7M
model = ResNet18(num_classes=1000)
x = torch.randn(2, 3, 224, 224)
print(model(x).shape)
print(f"{sum(p.numel() for p in model.parameters())/1e6:.1f}M params")UNet
back to 目录
背景与动机
UNet 诞生于 2015 年的医学图像分割领域,面对的核心挑战是:用极少的标注数据,对每个像素进行分类。
普通分类网络通过下采样丢失了大量空间细节(“这里是什么”),而语义分割不仅要知道”是什么”,还要精确到”在哪里”。UNet 的核心设计哲学是:
下采样(Encoder) 负责理解语义(“这是一个细胞核”)
上采样(Decoder) 负责恢复位置(“它在图像的左上角”)
跳跃连接(Skip Connection) 把 Encoder 中丢失的细节直接”传送”给 Decoder
形象地说:Encoder 是近视眼(看不清细节但能看清整体),Decoder 是老花眼(需要细节帮助)。Skip Connection 就是给 Decoder 配了一副眼镜,把 Encoder 看清楚的细节直接传过去。
整体架构
输入 (B, C_in, H, W) 例如 (B, 1, 256, 256)
│
Encoder(编码器 / 下采样路径):
┌─────────────────────────────────────────────────────────┐
│ Block1: ConvBlock(C_in→64) → feat1 (B, 64, H, W) │──── skip1 ────────────────────────────────────┐
│ MaxPool ↓2 → (B, 64, H/2, W/2) │ │
│ Block2: ConvBlock(64→128) → feat2 (B, 128, H/2, W/2)│──── skip2 ──────────────────────────────┐ │
│ MaxPool ↓2 → (B, 128, H/4, W/4) │ │ │
│ Block3: ConvBlock(128→256) → feat3 (B, 256, H/4, W/4)│──── skip3 ─────────────────────────┐ │ │
│ MaxPool ↓2 → (B, 256, H/8, W/8) │ │ │ │
│ Block4: ConvBlock(256→512) → feat4 (B, 512, H/8, W/8)│──── skip4 ────────────────────┐ │ │ │
│ MaxPool ↓2 → (B, 512, H/16,W/16) │ │ │ │ │
└─────────────────────────────────────────────────────────┘ │ │ │ │
│ │ │ │ │
Bottleneck: │ │ │ │
Block5: ConvBlock(512→1024) → (B, 1024, H/16, W/16) │ │ │ │
│ │ │ │ │
Decoder(解码器 / 上采样路径): │ │ │ │
UpConv2×2 → (B, 512, H/8, W/8) │ │ │ │
Concat(skip4) → (B, 1024, H/8, W/8) ← ─────────────────────────────┘ │ │ │
ConvBlock(1024→512) → (B, 512, H/8, W/8) │ │ │
│ │ │
UpConv2×2 → (B, 256, H/4, W/4) │ │ │
Concat(skip3) → (B, 512, H/4, W/4) ← ────────────────────────────────────┘ │ │
ConvBlock(512→256) → (B, 256, H/4, W/4) │ │
... (以此类推) │ │
UpConv2×2 → (B, 64, H, W) │ │
Concat(skip1) → (B, 128, H, W) ← ─────────────────────────────────────────┘ │
ConvBlock(128→64) → (B, 64, H, W) │
│
Conv1×1 → (B, num_classes, H, W) ← 每个像素的分类结果 │
Skip Connection 的关键作用
没有跳跃连接,解码器就像在黑暗中摸索——它知道”大概在哪里”但不知道”确切边界在哪”。跳跃连接把 Encoder 里精确的位置、边缘、纹理信息直接送给 Decoder,让分割边界变得精确。
class ConvBlock(nn.Module):
"""
UNet 的基本单元:两次 Conv3×3 + BN + ReLU
选择 3×3 而非 1×1:保持感受野,充分利用空间上下文
"""
def __init__(self, in_ch, out_ch):
super().__init__()
self.block = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
)
def forward(self, x): return self.block(x)
class UNet(nn.Module):
def __init__(self, in_channels=1, num_classes=2, features=(64,128,256,512)):
super().__init__()
# Encoder
self.encoders = nn.ModuleList()
self.pools = nn.ModuleList()
ch = in_channels
for f in features:
self.encoders.append(ConvBlock(ch, f))
self.pools.append(nn.MaxPool2d(2))
ch = f
# Bottleneck:特征维度最大,空间最小,语义最强
self.bottleneck = ConvBlock(features[-1], features[-1] * 2)
# Decoder:逐步上采样 + 融合跳跃连接
self.upconvs = nn.ModuleList()
self.decoders = nn.ModuleList()
for f in reversed(features):
# ConvTranspose2d:可学习的上采样(比双线性插值效果好)
self.upconvs.append(
nn.ConvTranspose2d(f * 2, f, kernel_size=2, stride=2)
)
# 上采样后 Concat 跳跃特征,通道数翻倍
self.decoders.append(ConvBlock(f * 2, f))
self.final = nn.Conv2d(features[0], num_classes, 1)
def forward(self, x):
skips = []
# Encoder:保存每个 stage 的特征图用于跳跃连接
for enc, pool in zip(self.encoders, self.pools):
x = enc(x)
skips.append(x) # 保存池化前的特征(细节更丰富)
x = pool(x)
x = self.bottleneck(x)
# Decoder:上采样 + 拼接 skip 特征
for up, dec, skip in zip(self.upconvs, self.decoders, reversed(skips)):
x = up(x)
# 处理尺寸不整除的边角情况(如输入不是 2 的幂次方)
if x.shape != skip.shape:
x = F.interpolate(x, size=skip.shape[2:])
# Concat 而非 Add:保留两路信息,不做融合(让后续卷积自己学怎么用)
x = torch.cat([skip, x], dim=1)
x = dec(x)
return self.final(x)
# 验证
model = UNet(in_channels=1, num_classes=2)
x = torch.randn(2, 1, 256, 256)
print(model(x).shape) # (2, 2, 256, 256) 每个像素 2 分类SwinIR
back to 目录
背景与动机
SwinIR 是基于 Swin Transformer 的图像超分辨率模型(2021)。它解决了 ViT 用于图像的两个核心问题:
-
计算量问题: 标准 Self-Attention 的复杂度是 ,对 224×224 图像有 50176 个 patch,算不起。Swin 把注意力限制在小窗口内,变成 , 是窗口大小。
-
跨窗口信息交流: 但窗口内部关注,窗口之间信息孤立。Swin 用移位窗口(Shifted Window) 解决——交替使用正常窗口和偏移半个窗口大小的窗口,让相邻窗口能间接交流。
比喻: 标准注意力像全国拨打任意号码的电话网络(贵但自由),窗口注意力像只能打本地电话(便宜但受限),移位窗口像相邻区域的号码段有重叠(以较低成本建立了跨区连接)。
SwinIR 整体架构
低分辨率输入 (B, 3, H, W)
│
┌──────────────────────────────┐
│ 浅层特征提取 │
│ Conv3×3 → (B, C, H, W) │ 简单卷积,C=180
│ 作用:把图像映射到特征空间 │
└──────────────────────────────┘
│ F_shallow(保留用于最终残差)
│
┌──────────────────────────────────────────────────┐
│ 深层特征提取(6个 RSTB) │
│ │
│ RSTB (Residual Swin Transformer Block) ×6 │
│ ┌────────────────────────────────────────────┐ │
│ │ │ │
│ │ STL (Swin Transformer Layer) ×6 │ │
│ │ ┌──────────────────────────────────────┐ │ │
│ │ │ LN │ │ │
│ │ │ W-MSA(偶数层,正常窗口) │ │ │
│ │ │ 或 SW-MSA(奇数层,移位窗口) │ │ │
│ │ │ + 残差 │ │ │
│ │ │ LN │ │ │
│ │ │ FFN (MLP) │ │ │
│ │ │ + 残差 │ │ │
│ │ └──────────────────────────────────────┘ │ │
│ │ │ │
│ │ Conv3×3(局部信息增强 + 残差连接) │ │
│ └────────────────────────────────────────────┘ │
│ │
└──────────────────────────────────────────────────┘
│ F_deep
│
Conv3×3 + (F_shallow + F_deep) ← 浅深特征融合
│
┌──────────────────────────────┐
│ 上采样重建 │
│ Conv3×3 │
│ PixelShuffle ×scale │ ← 亚像素卷积,1/r² 的计算量生成 r² 个子像素
│ Conv3×3 │
└──────────────────────────────┘
│
高分辨率输出 (B, 3, s·H, s·W)
窗口注意力机制
输入特征图 (B, H, W, C)
│
▼
分成 nW = (H/M)×(W/M) 个不重叠 M×M 窗口
→ (B×nW, M×M, C)
│
▼
在每个窗口内独立做 Multi-Head Attention
(M=8 时,每次只算 64 个 token,而不是 H×W 个)
加上相对位置偏置(Relative Position Bias)
│
▼
窗口合并回 (B, H, W, C)
SW-MSA(Shifted Window)的做法:
原窗口边界: ┌───┬───┐ 移位后边界: ┌─┬─────┬─┐
│ A │ B │ │D│ A │B│
├───┼───┤ ├─┼─────┼─┤
│ C │ D │ │ │ │ │
└───┴───┘ │C│ │A│
└─┴─────┴─┘
原来 A/B/C/D 四个窗口互不相交,移位后出现了跨越原窗口边界的新窗口
→ A 窗口的 token 现在和 B/C/D 的边界 token 同处一个注意力计算中
→ 用 Cyclic Shift + Attention Mask 高效实现(无需 padding)
def window_partition(x, window_size):
"""(B, H, W, C) → (B*nW, M, M, C)"""
B, H, W, C = x.shape
x = x.view(B, H//window_size, window_size, W//window_size, window_size, C)
windows = x.permute(0,1,3,2,4,5).contiguous()
return windows.view(-1, window_size, window_size, C)
def window_reverse(windows, window_size, H, W):
"""window_partition 的逆操作"""
B_nW = windows.shape[0]
B = int(B_nW / (H * W / window_size / window_size))
x = windows.view(B, H//window_size, W//window_size, window_size, window_size, -1)
return x.permute(0,1,3,2,4,5).contiguous().view(B, H, W, -1)
class WindowAttention(nn.Module):
def __init__(self, dim, window_size, n_heads):
super().__init__()
self.n_heads = n_heads
self.d_k = dim // n_heads
self.scale = self.d_k ** -0.5
self.W_qkv = nn.Linear(dim, 3*dim)
self.W_o = nn.Linear(dim, dim)
self.rel_bias = RelativePositionBias(window_size, n_heads)
def forward(self, x, mask=None):
# x: (B*nW, M*M, C)
B_, N, C = x.shape
qkv = self.W_qkv(x).reshape(B_, N, 3, self.n_heads, self.d_k)
q, k, v = qkv.permute(2,0,3,1,4).unbind(0)
q = q * self.scale
attn = q @ k.transpose(-2,-1)
attn = attn + self.rel_bias() # 加相对位置偏置:让模型感知窗口内的相对位置
if mask is not None:
# SW-MSA 的 mask:屏蔽不该相互关注的跨区域 token 对
nW = mask.shape[0]
attn = attn.view(B_//nW, nW, self.n_heads, N, N)
attn = attn + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.n_heads, N, N)
attn = F.softmax(attn, dim=-1)
out = (attn @ v).transpose(1,2).reshape(B_, N, C)
return self.W_o(out)
class SwinTransformerLayer(nn.Module):
"""
一个 Swin Transformer 层
交替使用 W-MSA 和 SW-MSA:
- 偶数层(shift=False):正常窗口,学习局部内部关系
- 奇数层(shift=True):移位窗口,学习跨窗口关系
"""
def __init__(self, dim, n_heads, window_size=8, shift=False, mlp_ratio=4.0):
super().__init__()
self.shift = shift
self.window_size = window_size
self.shift_size = window_size // 2 if shift else 0
self.norm1 = nn.LayerNorm(dim)
self.attn = WindowAttention(dim, window_size, n_heads)
self.norm2 = nn.LayerNorm(dim)
self.mlp = nn.Sequential(
nn.Linear(dim, int(dim * mlp_ratio)),
nn.GELU(),
nn.Linear(int(dim * mlp_ratio), dim)
)
def forward(self, x, H, W):
B, L, C = x.shape
x_2d = x.view(B, H, W, C)
# Cyclic Shift:把特征图整体滚动,让边界处的 token 聚集到中央的同一个窗口
if self.shift_size > 0:
x_2d = torch.roll(x_2d, (-self.shift_size, -self.shift_size), dims=(1,2))
windows = window_partition(x_2d, self.window_size)
windows = windows.view(-1, self.window_size**2, C)
attn_out = self.attn(self.norm1(windows))
attn_out = attn_out.view(-1, self.window_size, self.window_size, C)
x_2d = window_reverse(attn_out, self.window_size, H, W)
# 反向滚动还原位置
if self.shift_size > 0:
x_2d = torch.roll(x_2d, (self.shift_size, self.shift_size), dims=(1,2))
x = x + x_2d.view(B, L, C) # 残差
x = x + self.mlp(self.norm2(x))
return x
class RSTB(nn.Module):
"""
Residual Swin Transformer Block
= N 个 SwinTransformerLayer + Conv3×3 + 残差
最后的 Conv3×3 增强局部建模能力(Transformer 偏全局,Conv 偏局部,互补)
"""
def __init__(self, dim, n_heads, n_layers=6, window_size=8):
super().__init__()
self.layers = nn.ModuleList([
SwinTransformerLayer(dim, n_heads, window_size, shift=(i%2==1))
for i in range(n_layers)
])
self.norm = nn.LayerNorm(dim)
self.conv = nn.Conv2d(dim, dim, 3, padding=1)
def forward(self, x, H, W):
res = x
for layer in self.layers:
x = layer(x, H, W)
x = self.norm(x)
x = x.transpose(1,2).view(-1, x.shape[-1], H, W)
x = self.conv(x).flatten(2).transpose(1,2)
return x + res # 残差:跨越整个 RSTB 的捷径
class SwinIR(nn.Module):
def __init__(self, in_ch=3, dim=60, n_heads=6, n_rstb=6,
window_size=8, scale=4):
super().__init__()
self.shallow = nn.Conv2d(in_ch, dim, 3, padding=1)
self.deep = nn.ModuleList([
RSTB(dim, n_heads, window_size=window_size) for _ in range(n_rstb)
])
self.deep_norm = nn.LayerNorm(dim)
self.deep_conv = nn.Conv2d(dim, dim, 3, padding=1)
self.upsample = nn.Sequential(
nn.Conv2d(dim, dim * scale**2, 3, padding=1),
nn.PixelShuffle(scale), # 亚像素卷积:把通道维的信息"折叠"到空间维
nn.Conv2d(dim, in_ch, 3, padding=1)
)
def forward(self, x):
B, C, H, W = x.shape
feat = self.shallow(x)
deep = feat.flatten(2).transpose(1,2) # (B, H*W, dim)
for rstb in self.deep:
deep = rstb(deep, H, W)
deep = self.deep_norm(deep)
deep = deep.transpose(1,2).view(B, -1, H, W)
deep = self.deep_conv(deep) + feat # 深浅特征相加
return self.upsample(deep)
model = SwinIR(scale=4)
x = torch.randn(1, 3, 64, 64)
print(model(x).shape) # (1, 3, 256, 256) 4×超分辨率CLIP
back to 目录
背景与动机
传统图像识别需要大量手工标注,而且每次换任务就要重新训练。CLIP(Contrastive Language-Image Pre-Training,2021)的核心洞察是:互联网上有海量的图文对(图片+描述),这是天然的监督信号。
CLIP 的训练思路极为简单而优雅:
- 给一批图像和对应的文本描述
- 训练两个编码器,使配对的图文特征尽量相近,不配对的尽量远离
- 这就是对比学习(Contrastive Learning)
零样本推理能力是 CLIP 最神奇的特性:训练时没见过”猫”这个类别,推理时只需构建文本 “a photo of a cat”,让图像特征与各类别文本特征比较相似度即可分类。模型学到的是通用的图文对齐空间,而不是特定类别的分类器。
整体架构
训练时:一批 N 对 (图像_i, 文本_i)
图像端 文本端
(B, 3, 224, 224) (B, 77) ← 最多 77 个 token
│ │
ViT-B/16 图像编码器 Transformer 文本编码器
┌─────────────────┐ ┌──────────────────┐
│ 分成 14×14 个 │ │ Token Embedding │
│ 16×16 的 patch │ │ + Positional Emb │
│ 线性投影+CLS │ │ │
│ Transformer ×12 │ │ Transformer ×12 │
│ 取 CLS token │ │ 取 EOS token │
└─────────────────┘ └──────────────────┘
│ │
Linear 投影 Linear 投影
│ │
image_feat (B, 512) text_feat (B, 512)
│ │
└──── L2 归一化 ─────┐ ┌──── L2 归一化 ────┘
│ │
相似度矩阵 = image_feat @ text_feat.T * exp(τ)
→ (B, B) 矩阵,第 i 行第 j 列 = 第 i 张图和第 j 段文本的相似度
→ 对角线是正样本,其余是负样本
对比损失(InfoNCE):
行方向 CE:每张图 → 正确文本(共 B 个负样本)
列方向 CE:每段文本 → 正确图像
最终 loss = (行CE + 列CE) / 2
温度参数 τ 的作用:控制分布的”锐利程度”。τ 小→分布尖锐,区分度高;τ 大→分布平缓,训练更稳定。τ 本身也是可学习的。
class ViTImageEncoder(nn.Module):
"""
CLIP 图像编码器:Vision Transformer (ViT)
核心操作:把图像切成 patch → 当作 token 序列 → Transformer 处理
"""
def __init__(self, img_size=224, patch_size=16, d_model=768,
n_heads=12, n_layers=12, embed_dim=512):
super().__init__()
n_patches = (img_size // patch_size) ** 2 # 14×14 = 196 个 patch
# 用卷积实现 patch embedding(等价于把每个 patch 展平再做 Linear)
self.patch_embed = nn.Conv2d(3, d_model, patch_size, stride=patch_size)
# [CLS] token:汇聚全局信息,最终用它来表示整张图
self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
# 可学习位置编码(ViT 使用)
self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, d_model))
encoder_layer = nn.TransformerEncoderLayer(
d_model=d_model, nhead=n_heads, dim_feedforward=d_model*4,
dropout=0.0, activation='gelu', batch_first=True, norm_first=True
)
self.transformer = nn.TransformerEncoder(encoder_layer, n_layers)
self.ln = nn.LayerNorm(d_model)
self.proj = nn.Linear(d_model, embed_dim, bias=False)
def forward(self, x):
B = x.shape[0]
x = self.patch_embed(x).flatten(2).transpose(1,2) # (B, N, D)
cls = self.cls_token.expand(B, -1, -1)
x = torch.cat([cls, x], dim=1) + self.pos_embed # (B, N+1, D)
x = self.transformer(x)
x = self.ln(x[:, 0]) # 取 CLS token
return self.proj(x) # (B, embed_dim)
class TextEncoder(nn.Module):
"""
CLIP 文本编码器:因果 Transformer(使用 causal mask)
取 EOS token 的输出作为整个文本的表示(而不是 CLS)
这是 CLIP 的独特选择,EOS 在生成过程中看到了所有 token,信息最完整
"""
def __init__(self, vocab_size=49408, ctx_len=77, d_model=512,
n_heads=8, n_layers=12, embed_dim=512):
super().__init__()
self.token_embed = nn.Embedding(vocab_size, d_model)
self.pos_embed = nn.Parameter(torch.zeros(ctx_len, d_model))
decoder_layer = nn.TransformerEncoderLayer(
d_model=d_model, nhead=n_heads, dim_feedforward=d_model*4,
dropout=0.0, activation='gelu', batch_first=True, norm_first=True
)
self.transformer = nn.TransformerEncoder(decoder_layer, n_layers)
self.ln = nn.LayerNorm(d_model)
self.proj = nn.Linear(d_model, embed_dim, bias=False)
def forward(self, tokens):
x = self.token_embed(tokens) + self.pos_embed[:tokens.shape[1]]
x = self.transformer(x)
x = self.ln(x)
eot_idx = tokens.argmax(dim=-1) # EOS token 的位置
x = x[torch.arange(x.shape[0]), eot_idx]
return self.proj(x)
class CLIP(nn.Module):
def __init__(self, embed_dim=512):
super().__init__()
self.image_encoder = ViTImageEncoder(embed_dim=embed_dim)
self.text_encoder = TextEncoder(embed_dim=embed_dim)
# 可学习温度:初始化为 1/0.07 ≈ 14.3,在 log 空间学习(保证正值)
self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1/0.07))
def encode_image(self, image):
# L2 归一化:把特征映射到单位超球面上,相似度 = 余弦相似度
return F.normalize(self.image_encoder(image), dim=-1)
def encode_text(self, tokens):
return F.normalize(self.text_encoder(tokens), dim=-1)
def forward(self, image, tokens):
img_feat = self.encode_image(image) # (B, D)
text_feat = self.encode_text(tokens) # (B, D)
scale = self.logit_scale.exp()
# (B, B) 矩阵:每对图文的余弦相似度 × 温度
logits = scale * img_feat @ text_feat.T
# 标签就是对角线索引(第 i 张图对应第 i 段文本)
labels = torch.arange(logits.shape[0], device=logits.device)
loss_i = F.cross_entropy(logits, labels) # 图匹配文
loss_t = F.cross_entropy(logits.T, labels) # 文匹配图
return (loss_i + loss_t) / 2
@torch.no_grad()
def zero_shot_classify(self, image, class_prompts_tokens):
"""
零样本分类:不需要额外训练,用文本描述代替类别标签
例如:["a photo of a cat", "a photo of a dog", ...]
"""
img_feat = self.encode_image(image)
text_feats = self.encode_text(class_prompts_tokens) # (C, D)
probs = (img_feat @ text_feats.T * self.logit_scale.exp()).softmax(-1)
return probs # (B, C)LLaVA
back to 目录
背景与动机
LLaVA(Large Language and Vision Assistant,2023)的核心问题是:如何让已经很强的 LLM 拥有”看图说话”的能力?
最暴力的想法是从头联合训练图像理解和语言生成,但代价极高。LLaVA 的优雅方案是:CLIP 已经学会了理解图像,LLM 已经学会了理解和生成语言,我只需要一个”翻译器”把图像特征转换到 LLM 能理解的空间。
这个”翻译器”就是 MLP Projector——两层线性变换而已,轻量但有效。
整体架构
图像 (B, 3, 336, 336)
│
CLIP ViT-L/14@336px
(训练时通常冻结,已具备强大图像理解能力)
│ 取所有 patch token(去掉 CLS)
(B, 576, 1024) ← 576 = 24×24 个 patch,每个 1024 维
│
MLP Projector(可学习!这是两个模态的"桥梁")
Linear(1024 → 4096) → GELU → Linear(4096 → 4096)
│ (B, 576, 4096) ← 对齐到 LLaMA 的 hidden size
│
┌──────────────────────────────────────────────────────┐
│ LLaMA / Vicuna 7B(指令微调的 LLaMA) │
│ │
│ 输入序列构建: │
│ [系统提示] [USER:] [576个视觉token] [文字问题] [ASSISTANT:] │
│ │
│ word tokens → Embedding(vocab, 4096) │
│ + 视觉 token(从 Projector 来)直接拼接 │
│ ↓ │
│ 合并序列 (B, 576 + T_text, 4096) │
│ ↓ │
│ LLaMA Transformer ×32 层 │
│ ↓ │
│ 输出 logits,预测下一个 token │
└──────────────────────────────────────────────────────┘
两阶段训练策略:
Stage 1 - 特征对齐(约 60 万图文对,1 epoch):
✅ 训练 MLP Projector
❄️ 冻结 CLIP ViT
❄️ 冻结 LLM
目标:让 Projector 学会将视觉特征翻译成 LLM 的"语言"
Stage 2 - 视觉对话微调(约 15 万多轮对话,1 epoch):
✅ 训练 MLP Projector
❄️ 冻结 CLIP ViT
✅ 训练 LLM(LoRA 或全参数)
目标:让 LLM 学会用视觉信息回答问题
class MLPProjector(nn.Module):
"""
视觉-语言桥梁模块
设计简单但关键:把 CLIP 的视觉空间映射到 LLM 的 token 空间
LLaVA-1.5 发现两层 MLP 比单层线性效果好很多
"""
def __init__(self, vision_dim=1024, llm_dim=4096):
super().__init__()
self.proj = nn.Sequential(
nn.Linear(vision_dim, llm_dim),
nn.GELU(),
nn.Linear(llm_dim, llm_dim),
)
def forward(self, x):
return self.proj(x) # (B, num_patches, llm_dim)
class LLaVA(nn.Module):
"""
LLaVA 整体结构
核心设计:视觉 token 和文本 token 在 LLM 输入层拼接,
之后完全用统一的 Transformer 处理,不需要特殊的 Cross-Attention
"""
def __init__(self, vision_encoder, llm, projector,
image_token_id: int = -200):
super().__init__()
self.vision_encoder = vision_encoder
self.projector = projector
self.llm = llm
self.image_token_id = image_token_id # 特殊占位符 token 的 id
def get_visual_tokens(self, images):
with torch.no_grad():
# 取 patch tokens(去掉 CLS),保留空间信息
vision_feats = self.vision_encoder(images)
patch_feats = vision_feats[:, 1:] # (B, N_patch, D_vision)
return self.projector(patch_feats) # (B, N_patch, D_llm)
def forward(self, input_ids, images, attention_mask=None, labels=None):
"""
input_ids: (B, T) 文本 token,其中 <IMAGE> 位置为特殊 token id
images: (B, 3, H, W)
典型的 input_ids 结构(展开后):
[SYSTEM] [USER:] [<IMAGE>×576] [问题文字 tokens] [ASSISTANT:]
"""
# 1. 文本 token 的 embedding
inputs_embeds = self.llm.get_input_embeddings()(input_ids) # (B, T, D)
# 2. 视觉 token 的 embedding
vis_tokens = self.get_visual_tokens(images) # (B, N, D)
# 3. 将视觉 embedding 替换掉 <IMAGE> 占位符的位置
image_mask = (input_ids == self.image_token_id)
for b in range(input_ids.shape[0]):
img_pos = image_mask[b].nonzero(as_tuple=True)[0]
n_img = vis_tokens.shape[1]
inputs_embeds[b, img_pos[:n_img]] = vis_tokens[b]
# 4. 整体送入 LLM,像正常文本一样处理
outputs = self.llm(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
labels=labels # 只在 ASSISTANT 的回答部分计算 loss
)
return outputsLLaMA
back to 目录
背景与动机
LLaMA(2023,Meta)是开源 LLM 领域的里程碑,其设计哲学是:在同等参数量下,用更多的数据和更精心的架构设计,超越更大的模型。
LLaMA 相对原始 Transformer 的主要改进可以概括为四点:
| 改进点 | 原始 Transformer | LLaMA |
|---|---|---|
| 归一化位置 | Post-LN(层后归一化) | Pre-RMSNorm(层前,去掉均值) |
| 激活函数 | ReLU / GELU | SwiGLU(门控激活) |
| 位置编码 | 绝对位置编码 | RoPE(旋转位置编码) |
| 注意力 | MHA(多头注意力) | GQA(分组查询注意力) |
整体架构
输入 token ids: (B, T)
│
Embedding(vocab_size=32000, dim=4096)
│ (B, T, 4096)
│
┌───────────────────────────────────────────────────────┐
│ Transformer Block ×32 │
│ │
│ x ──► RMSNorm ──► Self-Attention(RoPE + GQA) ──►──┐ │
│ │ + │ │
│ └──────────────────────────────────────────────────┘ │
│ ↓ │
│ x ──► RMSNorm ──► SwiGLU FFN ──────────────────────┐ │
│ │ + │ │
│ └───────────────────────────────────────────────────┘ │
└───────────────────────────────────────────────────────┘
│
RMSNorm(最后一层)
│
Linear(4096, 32000) = lm_head
(与 Embedding 共享权重,节省参数)
│
logits: (B, T, 32000)
Pre-RMSNorm 的意义
原始 Transformer 在 Attention/FFN 之后归一化(Post-LN),训练不稳定,需要 warmup 很长时间。
Pre-LN(LN 在 Attention/FFN 之前)训练更稳定,梯度流更好。
RMSNorm 去掉了均值项,计算更快,效果相当。
SwiGLU 的意义
普通 FFN:
SwiGLU:
多了一个门控机制: 决定”开多大门”, 是信号。这让模型能更灵活地选择性地传递信息,实践中比 GELU/ReLU 效果更好。
GQA 的意义(Grouped Query Attention)
标准多头注意力:32 个 Q head,32 个 K/V head → 显存占用大
GQA(LLaMA-2 采用):32 个 Q head,只有 8 个 K/V head,每 4 个 Q 共享 1 对 K/V
→ K/V Cache 的显存缩减 4 倍,推理速度大幅提升,精度损失很小
class RMSNorm(nn.Module):
"""
Root Mean Square Normalization
相比 LayerNorm 去掉了减均值的步骤,计算量更小,效果相当
"""
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(dim))
self.eps = eps
def forward(self, x):
# 在 float32 精度下计算,避免 fp16 溢出
norm = x.float().pow(2).mean(-1, keepdim=True).add(self.eps).sqrt()
return (x.float() / norm).type_as(x) * self.weight
class SwiGLU(nn.Module):
"""
SwiGLU FFN:比 FFN 多了一个门控路径
hidden_dim 通常为 int(8/3 * dim),然后对齐到 256 的倍数
"""
def __init__(self, dim: int, hidden_dim: int):
super().__init__()
self.w1 = nn.Linear(dim, hidden_dim, bias=False) # gate
self.w2 = nn.Linear(hidden_dim, dim, bias=False) # down
self.w3 = nn.Linear(dim, hidden_dim, bias=False) # up
def forward(self, x):
# SiLU(x) = x * sigmoid(x)(Swish 激活)
# 门控:SiLU(gate) ⊙ up,决定信息流量
return self.w2(F.silu(self.w1(x)) * self.w3(x))
class LlamaAttention(nn.Module):
"""
LLaMA Attention:GQA + RoPE + 因果 mask
n_kv_heads < n_heads 时启用 GQA
"""
def __init__(self, dim: int, n_heads: int, n_kv_heads: int = None,
max_seq_len: int = 4096, base: float = 10000.0):
super().__init__()
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads or n_heads
self.n_rep = n_heads // self.n_kv_heads # 每个 KV head 被几个 Q head 共享
self.d_k = dim // n_heads
self.Wq = nn.Linear(dim, n_heads * self.d_k, bias=False)
self.Wk = nn.Linear(dim, self.n_kv_heads * self.d_k, bias=False)
self.Wv = nn.Linear(dim, self.n_kv_heads * self.d_k, bias=False)
self.Wo = nn.Linear(n_heads * self.d_k, dim, bias=False)
freqs = precompute_freqs_cis(self.d_k, max_seq_len, base)
self.register_buffer('freqs_cis', freqs)
def forward(self, x, start_pos=0):
B, T, _ = x.shape
q = self.Wq(x).view(B, T, self.n_heads, self.d_k)
k = self.Wk(x).view(B, T, self.n_kv_heads, self.d_k)
v = self.Wv(x).view(B, T, self.n_kv_heads, self.d_k)
# 旋转位置编码:只旋转 Q 和 K,不动 V
freqs = self.freqs_cis[start_pos: start_pos + T]
q, k = apply_rotary_emb(q, k, freqs)
# GQA:把 K/V 的 head 数扩展到和 Q 一样(repeat_interleave)
if self.n_rep > 1:
k = k.repeat_interleave(self.n_rep, dim=2)
v = v.repeat_interleave(self.n_rep, dim=2)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
scores = (q @ k.transpose(-2,-1)) / math.sqrt(self.d_k)
# 上三角 mask(不含对角线):未来 token 不可见
mask = torch.triu(torch.full((T, T), float('-inf'), device=x.device), 1)
scores = scores + mask
out = F.softmax(scores, dim=-1) @ v
out = out.transpose(1,2).contiguous().view(B, T, -1)
return self.Wo(out)
class LlamaBlock(nn.Module):
"""Pre-Norm Transformer Block"""
def __init__(self, dim=4096, n_heads=32, n_kv_heads=8, ffn_hidden=11008):
super().__init__()
self.norm1 = RMSNorm(dim)
self.attn = LlamaAttention(dim, n_heads, n_kv_heads)
self.norm2 = RMSNorm(dim)
self.ffn = SwiGLU(dim, ffn_hidden)
def forward(self, x, start_pos=0):
# Pre-Norm:先归一化再运算,残差连接跨越归一化
x = x + self.attn(self.norm1(x), start_pos)
x = x + self.ffn(self.norm2(x))
return x
class LLaMA(nn.Module):
def __init__(self, vocab_size=32000, dim=4096, n_layers=32,
n_heads=32, n_kv_heads=8, ffn_hidden=11008, max_seq_len=4096):
super().__init__()
self.embed = nn.Embedding(vocab_size, dim)
self.layers = nn.ModuleList([
LlamaBlock(dim, n_heads, n_kv_heads, ffn_hidden)
for _ in range(n_layers)
])
self.norm = RMSNorm(dim)
self.lm_head = nn.Linear(dim, vocab_size, bias=False)
# 权重绑定:embedding 矩阵和 lm_head 共享同一份参数
# 直觉:输入和输出都是词表空间,用同一套表示更一致,且节省约 500M 参数
self.lm_head.weight = self.embed.weight
def forward(self, tokens, start_pos=0):
x = self.embed(tokens)
for layer in self.layers:
x = layer(x, start_pos)
return self.lm_head(self.norm(x)) # (B, T, vocab_size)
@torch.no_grad()
def generate(self, prompt_ids, max_new_tokens=100, temperature=1.0):
"""自回归生成:每次预测下一个 token,拼到序列尾部,循环"""
self.eval()
x = prompt_ids
for _ in range(max_new_tokens):
logits = self.forward(x)[:, -1, :] # 只取最后一步的预测
logits = logits / temperature
next_t = torch.multinomial(F.softmax(logits, -1), 1)
x = torch.cat([x, next_t], dim=1)
return x
# 小型验证
model = LLaMA(vocab_size=1000, dim=256, n_layers=4,
n_heads=8, n_kv_heads=4, ffn_hidden=512)
x = torch.randint(0, 1000, (2, 16))
print(model(x).shape) # (2, 16, 1000)Qwen3
back to 目录
背景与动机
Qwen3(2025,阿里通义)是当前(2025年初)最强的开源多语言 LLM 之一,在架构上相对 LLaMA 做了若干精细化改进。
Qwen3 的整体思路是:站在 LLaMA 的肩膀上,把每个细节都再优化一遍。核心改进:
- Q/K 加 RMSNorm:在 RoPE 之前对 Q 和 K 各做一次 RMSNorm,防止注意力分数数值不稳定(特别是在长上下文场景)
- 更大的 RoPE base(θ=1,000,000):LLaMA-2 用 10000,Qwen3 用 1,000,000。更大的 base 意味着旋转”更慢”,保留更多低频位置信息,支持更长的上下文(128K+)
- MoE 版本(Mixture of Experts):Qwen3-MoE 系列用 64 个专家 FFN,每次只激活 8 个,在保持参数量大(表达力强)的同时,计算量只是密集版的一小部分
Qwen3 vs LLaMA 架构差异
LLaMA Qwen3
归一化 RMSNorm(Pre-Norm) RMSNorm(Pre-Norm)✓ 相同
RoPE base 10,000 1,000,000 ✦ 更大
Q/K 偏置 无 bias Q/K 有 bias ✦ 新增
Q/K Norm 无 RMSNorm(Q), RMSNorm(K) ✦ 新增
GQA 部分版本 全系列标配 ✓ 相同
FFN SwiGLU SwiGLU 或 MoE ✦ MoE 版新增
MoE FFN 架构
输入 x: (B, T, D)
│
Router: Linear(D, n_experts=64)
→ softmax → 选出每个 token 的 Top-8 专家 + 分配权重
│
┌─────┬─────┬─────┬─────────────────────────┐
│FFN_0│FFN_1│FFN_2│ ... FFN_63 │ ← 64 个 SwiGLU 专家
└─────┴─────┴─────┴─────────────────────────┘
每个 token 只经过被选中的 8 个专家(稀疏激活!)
│
加权求和(router 分配的权重)
│
+ 共享专家(Shared Expert,每个 token 都过)
│
输出 (B, T, D)
优点:
✦ 总参数量 = 64 × FFN 大小(容量大,知识丰富)
✦ 每步计算量 = 8 × FFN 大小(推理快,显存占用低)
✦ 不同专家自然学习到不同领域的知识(语言、代码、数学...)
class Qwen3Attention(nn.Module):
"""
Qwen3 的核心改进:
1. Q/K 带 bias(微小但有效的改进)
2. Q/K 在 RoPE 前各过一次 RMSNorm(稳定长上下文的 attention score)
3. RoPE base = 1,000,000(支持 128K+ 上下文)
"""
def __init__(self, dim=4096, n_heads=32, n_kv_heads=8,
max_seq_len=32768, rope_base=1_000_000):
super().__init__()
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.n_rep = n_heads // n_kv_heads
self.d_k = dim // n_heads
# Q/K 有 bias,V 没有 bias(Qwen3 的选择)
self.Wq = nn.Linear(dim, n_heads * self.d_k, bias=True)
self.Wk = nn.Linear(dim, n_kv_heads * self.d_k, bias=True)
self.Wv = nn.Linear(dim, n_kv_heads * self.d_k, bias=False)
self.Wo = nn.Linear(n_heads * self.d_k, dim, bias=False)
# 对每个 head 的 Q/K 独立做 RMSNorm
# 动机:防止极长序列中 QK^T 的值过大,导致注意力坍塌(只关注极少数位置)
self.q_norm = RMSNorm(self.d_k)
self.k_norm = RMSNorm(self.d_k)
freqs = precompute_freqs_cis(self.d_k, max_seq_len, theta=rope_base)
self.register_buffer('freqs_cis', freqs)
def forward(self, x, start_pos=0):
B, T, _ = x.shape
q = self.Wq(x).view(B, T, self.n_heads, self.d_k)
k = self.Wk(x).view(B, T, self.n_kv_heads, self.d_k)
v = self.Wv(x).view(B, T, self.n_kv_heads, self.d_k)
# 先归一化,再旋转——让 Q/K 的模长稳定在合理范围
q = self.q_norm(q)
k = self.k_norm(k)
freqs = self.freqs_cis[start_pos: start_pos + T]
q, k = apply_rotary_emb(q, k, freqs)
if self.n_rep > 1:
k = k.repeat_interleave(self.n_rep, dim=2)
v = v.repeat_interleave(self.n_rep, dim=2)
q = q.transpose(1,2); k = k.transpose(1,2); v = v.transpose(1,2)
scores = (q @ k.transpose(-2,-1)) / math.sqrt(self.d_k)
mask = torch.triu(torch.full((T,T), float('-inf'), device=x.device), 1)
out = F.softmax(scores + mask, dim=-1) @ v
out = out.transpose(1,2).contiguous().view(B, T, -1)
return self.Wo(out)
class MoEFFN(nn.Module):
"""
Mixture of Experts FFN(Qwen3-MoE 使用)
稀疏激活:每个 token 只走 n_active 个专家,大幅降低计算量
"""
def __init__(self, dim: int, n_experts: int = 64, n_active: int = 8,
ffn_hidden: int = 2048, n_shared: int = 1):
super().__init__()
self.router = nn.Linear(dim, n_experts, bias=False)
self.n_active = n_active
# 稀疏专家:每次只激活部分
self.experts = nn.ModuleList([SwiGLU(dim, ffn_hidden) for _ in range(n_experts)])
# 共享专家:每个 token 都过,提供稳定的基础表示
# 设计灵感:不是所有知识都需要专家化,通用能力用共享专家维持
self.shared_experts = nn.ModuleList([SwiGLU(dim, ffn_hidden) for _ in range(n_shared)])
def forward(self, x):
B, T, D = x.shape
x_flat = x.view(-1, D)
# 路由:为每个 token 选出 top-k 个专家
logits = self.router(x_flat)
weights, indices = torch.topk(logits, self.n_active, dim=-1)
weights = F.softmax(weights, dim=-1) # 归一化权重
# 稀疏计算(生产实现用 scatter/expert-parallel,这里用循环示意)
out = torch.zeros_like(x_flat)
for k in range(self.n_active):
expert_id = indices[:, k]
w = weights[:, k].unsqueeze(-1)
for i, expert in enumerate(self.experts):
mask = (expert_id == i)
if mask.any():
out[mask] += w[mask] * expert(x_flat[mask])
# 共享专家(每个 token 都走,补充通用能力)
for shared in self.shared_experts:
out = out + shared(x_flat)
return out.view(B, T, D)
class Qwen3Block(nn.Module):
def __init__(self, dim=4096, n_heads=32, n_kv_heads=8,
ffn_hidden=11008, use_moe=False):
super().__init__()
self.norm1 = RMSNorm(dim)
self.attn = Qwen3Attention(dim, n_heads, n_kv_heads)
self.norm2 = RMSNorm(dim)
self.ffn = MoEFFN(dim, ffn_hidden=ffn_hidden) if use_moe \
else SwiGLU(dim, ffn_hidden)
def forward(self, x, start_pos=0):
x = x + self.attn(self.norm1(x), start_pos)
x = x + self.ffn(self.norm2(x))
return x
class Qwen3(nn.Module):
def __init__(self, vocab_size=151936, dim=4096, n_layers=32,
n_heads=32, n_kv_heads=8, ffn_hidden=11008, use_moe=False):
super().__init__()
self.embed = nn.Embedding(vocab_size, dim)
self.layers = nn.ModuleList([
Qwen3Block(dim, n_heads, n_kv_heads, ffn_hidden, use_moe)
for _ in range(n_layers)
])
self.norm = RMSNorm(dim)
self.lm_head = nn.Linear(dim, vocab_size, bias=False)
def forward(self, tokens, start_pos=0):
x = self.embed(tokens)
for layer in self.layers:
x = layer(x, start_pos)
return self.lm_head(self.norm(x))
# 小型验证
model = Qwen3(vocab_size=1000, dim=256, n_layers=2,
n_heads=8, n_kv_heads=4, ffn_hidden=512)
x = torch.randint(0, 1000, (2, 16))
print(model(x).shape) # (2, 16, 1000)架构对比速查表
back to 目录
整体架构对比
| 模型 | 类型 | 核心创新 | 注意力机制 | 位置编码 | 归一化 | 激活函数 | 主要用途 |
|---|---|---|---|---|---|---|---|
| ResNet18 | CNN | 残差连接 | 无 | 无 | BatchNorm | ReLU | 图像分类 |
| UNet | CNN+跳连 | 编解码+跳连 | 无 | 无 | BatchNorm | ReLU | 图像分割/生成 |
| SwinIR | 层级 ViT | 窗口注意力 | W-MSA / SW-MSA | 相对位置偏置 | LayerNorm | GELU | 图像复原/超分 |
| CLIP | 双塔 | 对比学习 | MHSA | Learnable | LayerNorm | GELU | 图文对齐 |
| LLaVA | 多模态 | 模态对齐 | MHSA | RoPE | RMSNorm | SwiGLU | 视觉问答 |
| LLaMA | 纯解码器 | 高效 LLM | GQA + 因果mask | RoPE(1e4) | RMSNorm | SwiGLU | 文本生成 |
| Qwen3 | 纯解码器 | Q/K-Norm + MoE | GQA + 因果mask | RoPE(1e6) | RMSNorm | SwiGLU | 多语言/多模态 |
位置编码选择指南
| 场景 | 推荐编码 | 原因 |
|---|---|---|
| 短序列分类(BERT 风格) | Learnable PE | 简单有效,不需外推 |
| 图像 patch(ViT/CLIP) | Learnable PE | 2D 位置,学习更直接 |
| 图像局部窗口(Swin) | 相对位置偏置 | 只需感知窗口内相对距离 |
| 长文本生成(LLM) | RoPE | 天然相对位置,外推好 |
| 超长上下文(128K+) | RoPE(大 base) | 更大 theta,低频信息更丰富 |
| 极长外推(MPT 等) | ALiBi | 线性衰减,外推最稳 |
注意力机制复杂度对比
设序列长度 N,窗口大小 M,专家数 E,激活专家数 K
MHA(标准多头注意力): O(N²·d) ← N 大时极慢
GQA(分组查询注意力): O(N²·d) ← 同 MHA,但 KV Cache 显存少
Window Attention: O(N·M²·d) ← M 固定,随 N 线性增长
MoE FFN: O(N·K/E·d_ffn) ← K/E 通常很小(8/64=12.5%)
参考资料:timm · HuggingFace Transformers · 各模型原论文