神经网络四大范式详解
CNN · GAN · Auto-Regressive · Diffusion
训练与推理流程 · 损失函数 · 输入输出形状 · 完整代码
目录
CNN(端到端学习)
核心思想
CNN 是最”直白”的范式:给一个输入 X,直接预测一个输出 Y,用损失函数衡量预测的好坏,梯度反传更新参数。这种”输入→网络→输出→loss→反传”的完整链条,就是所谓的”端到端(End-to-End)”。
比喻: CNN 像是一个学生直接背答案。你给他一道题(图像),他直接给你一个答案(类别/分割图),老师(loss)告诉他答得有多差,他就调整自己的”答题方式”。全程没有中间人。
与后面的生成模型不同,CNN 的输出空间是确定性的——同一张图,每次推理结果一样。这是判别式模型(Discriminative Model)的典型特征。
三种典型任务
图像分类(Classification):
输入: (B, 3, H, W) ─────► 网络 ─────► (B, C) → CrossEntropy Loss
一张图 C 个类别的 logits
语义分割(Segmentation):
输入: (B, 3, H, W) ─────► 网络 ─────► (B, C, H, W) → CrossEntropy / Dice Loss
一张图 每个像素的 C 类概率
图像超分(Super Resolution):
输入: (B, 3, h, w) ─────► 网络 ─────► (B, 3, H, W) → L1 / L2 / Perceptual Loss
低分辨率 高分辨率
H = s·h, W = s·w (s 为放大倍数)
训练与推理流程
训练流程:
① 采样 mini-batch: x (B, 3, H, W), y_true (B,) 或 (B, C, H, W)
② 前向传播: y_pred = model(x)
③ 计算损失: loss = criterion(y_pred, y_true)
④ 反向传播: loss.backward()
⑤ 参数更新: optimizer.step()
推理流程(确定性,无随机性):
① with torch.no_grad():
② y_pred = model(x)
③ 后处理(argmax / softmax / 上采样等)
常用损失函数
| 任务 | 损失函数 | 形式 |
|---|---|---|
| 分类 | Cross-Entropy | |
| 分割(类别不均衡) | Dice Loss | |
| 回归/重建 | MSE (L2) | |
| 回归/重建(鲁棒) | MAE (L1) | |
| 超分/风格 | Perceptual Loss | , 为预训练 VGG |
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
# ─────────────────────────────────────────────
# 1.1 图像分类(端到端最简示例)
# ─────────────────────────────────────────────
class SimpleCNNClassifier(nn.Module):
"""
输入: (B, 3, 32, 32)
输出: (B, num_classes) ← logits(未经 softmax)
"""
def __init__(self, num_classes=10):
super().__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
nn.MaxPool2d(2), # (B, 64, 16, 16)
nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
nn.AdaptiveAvgPool2d(4), # (B, 128, 4, 4)
)
self.classifier = nn.Sequential(
nn.Flatten(),
nn.Linear(128 * 4 * 4, 256), nn.ReLU(), nn.Dropout(0.5),
nn.Linear(256, num_classes)
)
def forward(self, x):
return self.classifier(self.features(x))
def train_classifier(model, train_loader, epochs=5, lr=1e-3):
"""端到端分类训练循环"""
device = next(model.parameters()).device
criterion = nn.CrossEntropyLoss() # 内部含 softmax,输入 logits
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
for epoch in range(epochs):
model.train()
total_loss, correct, total = 0.0, 0, 0
for x, y in train_loader:
x, y = x.to(device), y.to(device)
# ── 标准端到端步骤 ──
y_pred = model(x) # (B, C) logits
loss = criterion(y_pred, y) # scalar
optimizer.zero_grad()
loss.backward()
optimizer.step()
# ──────────────────
total_loss += loss.item() * x.size(0)
correct += (y_pred.argmax(1) == y).sum().item()
total += x.size(0)
scheduler.step()
print(f"Epoch {epoch+1}: loss={total_loss/total:.4f}, "
f"acc={correct/total:.4f}")
@torch.no_grad()
def inference_classifier(model, x):
"""推理:确定性,无随机"""
model.eval()
logits = model(x) # (B, C)
probs = F.softmax(logits, dim=-1) # (B, C) 概率
preds = logits.argmax(dim=-1) # (B,) 预测类别
return preds, probs
# ─────────────────────────────────────────────
# 1.2 语义分割(像素级分类)
# ─────────────────────────────────────────────
class DiceLoss(nn.Module):
"""
Dice Loss:对类别不均衡鲁棒(医学图像分割常用)
输入 pred: (B, C, H, W) 经 softmax 后的概率
输入 target: (B, H, W) 整数类别标签
"""
def __init__(self, smooth=1.0):
super().__init__()
self.smooth = smooth
def forward(self, pred, target):
C = pred.shape[1]
# one-hot 编码 target
target_oh = F.one_hot(target, C).permute(0,3,1,2).float() # (B, C, H, W)
pred = pred.softmax(dim=1)
# 按 batch 和 class 计算 Dice
intersection = (pred * target_oh).sum(dim=(2,3)) # (B, C)
union = pred.sum(dim=(2,3)) + target_oh.sum(dim=(2,3))
dice = (2 * intersection + self.smooth) / (union + self.smooth)
return 1 - dice.mean()
class PerceptualLoss(nn.Module):
"""
感知损失(Perceptual Loss):用预训练 VGG 提取特征,
在特征空间而非像素空间度量重建质量
适用于超分辨率、风格迁移
"""
def __init__(self):
super().__init__()
import torchvision.models as models
vgg = models.vgg16(pretrained=False)
# 取前 16 层(relu3_3 之前的特征)
self.features = nn.Sequential(*list(vgg.features.children())[:16])
for p in self.features.parameters():
p.requires_grad = False # 冻结,不参与训练
def forward(self, pred, target):
# pred, target: (B, 3, H, W),假设已归一化到 [0,1]
feat_pred = self.features(pred)
feat_target = self.features(target.detach())
return F.mse_loss(feat_pred, feat_target)
# ── 快速验证 ──
model = SimpleCNNClassifier(num_classes=10)
# 构造假数据
x = torch.randn(64, 3, 32, 32)
y = torch.randint(0, 10, (64,))
ds = TensorDataset(x, y)
loader = DataLoader(ds, batch_size=16)
train_classifier(model, loader, epochs=2)2. GAN(生成对抗网络)
核心思想
GAN 的思想来自一个博弈论比喻:造假币的人(Generator)vs 验钞员(Discriminator)。
- Generator(G):接收一个随机噪声 ,生成假图像,目标是骗过 D
- Discriminator(D):接收真实图像或假图像,判断真假,目标是不被骗
两者在训练中相互博弈、共同进化:G 越来越擅长造假,D 越来越擅长辨别,最终 G 生成的图像以假乱真(D 的输出趋近于 0.5,无法判断真假)。
关键洞见: 我们无法直接写出”真实图像的分布长什么样”,但我们可以通过对抗训练让 G 隐式地学到这个分布。G 永远不需要知道真实数据长什么样,只需要知道 D 给了多少分。
训练流程(双人博弈,交替更新)
真实数据 x_real ~ p_data 噪声 z ~ N(0, I)
│ │
│ Generator G │
│ z → G(z) = x_fake │
│ │ │
▼ ▼
Discriminator D 接收两类输入:
D(x_real) → 希望接近 1(真的打高分)
D(x_fake) → 希望接近 0(假的打低分)
─── 第一步:更新 D(固定 G)───────────────
loss_D = -[log D(x_real) + log(1 - D(G(z)))]
目标:最大化辨别能力(最小化 loss_D)
─── 第二步:更新 G(固定 D)───────────────
loss_G = -log D(G(z)) ← 非饱和版本(实践常用)
或 log(1 - D(G(z))) ← 原始论文版(梯度小,难训练)
目标:让 D 把假图打高分(骗过 D)
注意:两步交替,不同时更新!
训练不稳定问题与改进
原始 GAN 的问题:
✗ 梯度消失:D 太强时,D(G(z))≈0,log(1-D(G(z)))≈0,G 无梯度
✗ 模式崩溃(Mode Collapse):G 只生成一种图像就能骗过 D
✗ 训练不稳定,对超参敏感
改进方向:
DCGAN:用卷积架构替代全连接,BatchNorm,LeakyReLU
WGAN:用 Wasserstein 距离替代 JS 散度,梯度裁剪
WGAN-GP:用梯度惩罚替代梯度裁剪(更稳定)
StyleGAN:噪声注入、AdaIN、渐进式训练(高分辨率合成)
条件 GAN:G 和 D 都以类别标签为条件
输入输出形状
Generator G:
输入: z (B, latent_dim) ← 纯随机噪声,如 latent_dim=100
输出: x_fake (B, C, H, W) ← 生成的假图像,值域 [-1, 1] 或 [0, 1]
Discriminator D:
输入: x (B, C, H, W) ← 真实图像 或 假图像
输出: score (B, 1) ← 真实概率(DCGAN)或真实性分数(WGAN)
训练时的 Tensor 流:
z: (B, 100) → G → x_fake: (B, 3, 64, 64)
x_real: (B, 3, 64, 64) → D → real_score: (B, 1)
x_fake: (B, 3, 64, 64) → D → fake_score: (B, 1)
import torch
import torch.nn as nn
import torch.nn.functional as F
# ─────────────────────────────────────────────
# 2.1 DCGAN(深度卷积 GAN)
# ─────────────────────────────────────────────
class DCGANGenerator(nn.Module):
"""
z: (B, latent_dim) → x_fake: (B, 3, 64, 64)
使用转置卷积逐步上采样,无池化层
"""
def __init__(self, latent_dim=100, img_channels=3, feature_map=64):
super().__init__()
fm = feature_map
self.net = nn.Sequential(
# 输入 z 形状: (B, latent_dim, 1, 1)
nn.ConvTranspose2d(latent_dim, fm*8, 4, 1, 0, bias=False), # → (B, 512, 4, 4)
nn.BatchNorm2d(fm*8), nn.ReLU(True),
nn.ConvTranspose2d(fm*8, fm*4, 4, 2, 1, bias=False), # → (B, 256, 8, 8)
nn.BatchNorm2d(fm*4), nn.ReLU(True),
nn.ConvTranspose2d(fm*4, fm*2, 4, 2, 1, bias=False), # → (B, 128, 16, 16)
nn.BatchNorm2d(fm*2), nn.ReLU(True),
nn.ConvTranspose2d(fm*2, fm, 4, 2, 1, bias=False), # → (B, 64, 32, 32)
nn.BatchNorm2d(fm), nn.ReLU(True),
nn.ConvTranspose2d(fm, img_channels, 4, 2, 1, bias=False), # → (B, 3, 64, 64)
nn.Tanh() # 输出 [-1, 1],对应归一化的图像
)
def forward(self, z):
# z: (B, latent_dim) → 变形为 (B, latent_dim, 1, 1)
z = z.view(z.size(0), -1, 1, 1)
return self.net(z) # (B, 3, 64, 64)
class DCGANDiscriminator(nn.Module):
"""
x: (B, 3, 64, 64) → score: (B, 1)
使用步长卷积下采样,避免 MaxPool(会丢失空间信息)
用 LeakyReLU(非 ReLU):防止梯度死亡,让负值有梯度
"""
def __init__(self, img_channels=3, feature_map=64):
super().__init__()
fm = feature_map
self.net = nn.Sequential(
# 注意:第一层不用 BN(作者经验)
nn.Conv2d(img_channels, fm, 4, 2, 1, bias=False), # → (B, 64, 32, 32)
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(fm, fm*2, 4, 2, 1, bias=False), # → (B, 128, 16, 16)
nn.BatchNorm2d(fm*2), nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(fm*2, fm*4, 4, 2, 1, bias=False), # → (B, 256, 8, 8)
nn.BatchNorm2d(fm*4), nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(fm*4, fm*8, 4, 2, 1, bias=False), # → (B, 512, 4, 4)
nn.BatchNorm2d(fm*8), nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(fm*8, 1, 4, 1, 0, bias=False), # → (B, 1, 1, 1)
nn.Sigmoid() # 输出 (0, 1) 概率
)
def forward(self, x):
return self.net(x).view(-1, 1) # (B, 1)
def train_dcgan(G, D, dataloader, epochs=50, latent_dim=100, lr=2e-4):
"""
DCGAN 训练循环:D 和 G 交替更新
关键:D 更新时固定 G,G 更新时固定 D
"""
device = next(G.parameters()).device
opt_D = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
opt_G = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
# betas=(0.5, 0.999):DCGAN 论文推荐,β1=0.5 比默认 0.9 更稳定
criterion = nn.BCELoss() # Binary Cross Entropy
real_label = 1.0 # 真实图像的目标标签
fake_label = 0.0 # 假图像的目标标签
for epoch in range(epochs):
for i, (x_real, _) in enumerate(dataloader):
x_real = x_real.to(device)
B = x_real.size(0)
# ═══════ 第一步:更新 Discriminator D ═══════
# 目标:最大化 log D(x_real) + log(1 - D(G(z)))
D.zero_grad()
# 真实图像的损失
real_labels = torch.full((B, 1), real_label, device=device)
real_score = D(x_real) # (B, 1)
loss_D_real = criterion(real_score, real_labels)
loss_D_real.backward()
# 假图像的损失
z = torch.randn(B, latent_dim, device=device) # (B, 100)
x_fake = G(z) # (B, 3, 64, 64)
fake_labels = torch.full((B, 1), fake_label, device=device)
fake_score = D(x_fake.detach()) # detach:不反传到 G
loss_D_fake = criterion(fake_score, fake_labels)
loss_D_fake.backward()
loss_D = loss_D_real + loss_D_fake
opt_D.step()
# ═══════ 第二步:更新 Generator G ═══════
# 目标:最大化 log D(G(z)),即让 D 把假图打成真的
G.zero_grad()
# 用真实标签欺骗 D(非饱和目标,梯度更大)
fake_score_for_G = D(x_fake) # 不 detach,梯度流回 G
loss_G = criterion(fake_score_for_G, real_labels) # 假图配真标签
loss_G.backward()
opt_G.step()
if i % 100 == 0:
print(f"Epoch {epoch} Step {i}: "
f"loss_D={loss_D.item():.4f}, loss_G={loss_G.item():.4f}")
@torch.no_grad()
def generate_images(G, n=16, latent_dim=100):
"""
推理:GAN 生成只需一步,无需迭代
输入随机噪声,输出图像
"""
G.eval()
z = torch.randn(n, latent_dim) # (16, 100) 随机采样
x_fake = G(z) # (16, 3, 64, 64) 直接生成
# 反归一化 [-1,1] → [0,1]
x_fake = (x_fake + 1) / 2
return x_fake
# ─────────────────────────────────────────────
# 2.2 WGAN-GP(更稳定的 GAN 变体)
# ─────────────────────────────────────────────
def gradient_penalty(D, x_real, x_fake, device):
"""
梯度惩罚:强制 D 满足 1-Lipschitz 约束
在真假样本的插值点上,要求梯度范数接近 1
"""
B = x_real.size(0)
# 随机插值系数 α ~ Uniform(0,1)
alpha = torch.rand(B, 1, 1, 1, device=device)
x_interp = (alpha * x_real + (1 - alpha) * x_fake.detach()).requires_grad_(True)
score = D(x_interp)
grad = torch.autograd.grad(
outputs=score, inputs=x_interp,
grad_outputs=torch.ones_like(score),
create_graph=True, retain_graph=True
)[0] # (B, C, H, W)
grad_norm = grad.view(B, -1).norm(2, dim=1) # (B,)
# 惩罚梯度范数偏离 1 的程度
gp = ((grad_norm - 1) ** 2).mean()
return gp
def train_wgan_gp(G, D, dataloader, epochs=50, latent_dim=100,
lr=1e-4, lambda_gp=10, n_critic=5):
"""
WGAN-GP 训练
n_critic=5:每更新 1 次 G,先更新 5 次 D(D 需要更充分训练)
无 BN in D(梯度惩罚与 BN 不兼容)
"""
device = next(G.parameters()).device
opt_D = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.0, 0.9))
opt_G = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.0, 0.9))
for epoch in range(epochs):
for i, (x_real, _) in enumerate(dataloader):
x_real = x_real.to(device)
B = x_real.size(0)
# 更新 D(n_critic 次)
for _ in range(n_critic):
z = torch.randn(B, latent_dim, device=device)
x_fake = G(z).detach()
# Wasserstein 距离估计:E[D(real)] - E[D(fake)]
# WGAN 的 D 不是分类器,是"评分器"(Critic),无 Sigmoid
loss_D = -D(x_real).mean() + D(x_fake).mean()
gp = gradient_penalty(D, x_real, x_fake, device)
loss_D = loss_D + lambda_gp * gp
D.zero_grad(); loss_D.backward(); opt_D.step()
# 更新 G(1 次)
z = torch.randn(B, latent_dim, device=device)
x_fake = G(z)
loss_G = -D(x_fake).mean() # 最大化 D(G(z))
G.zero_grad(); loss_G.backward(); opt_G.step()
# 快速验证
G = DCGANGenerator(latent_dim=100)
D = DCGANDiscriminator()
z = torch.randn(4, 100)
x_fake = G(z)
print("G output:", x_fake.shape) # (4, 3, 64, 64)
score = D(x_fake)
print("D output:", score.shape) # (4, 1)3. Auto-Regressive(自回归生成)
核心思想
自回归模型是语言模型的主流范式。其核心假设非常简单:一个序列的联合概率,可以分解为每个位置的条件概率之积。
比喻: 就像写作文一样,每次写下一个字,都要基于前面已经写下的所有内容来决定下一个字写什么。GPT 做的事情正是这个:给定前面的 token,预测下一个最可能的 token(Next-Token Prediction)。
训练目标极其简单:最大化训练数据在模型下的似然(等价于最小化交叉熵)。
Next-Token Prediction
训练时:一次并行预测所有位置的下一个 token
输入序列: [A, B, C, D, E] (T=5 个 token)
目标序列: [B, C, D, E, F] (向右移一位)
因果 mask(下三角矩阵)确保位置 t 只能看到 t 之前的 token:
位置 1 (A):只看 A,预测 B
位置 2 (B):看 A B,预测 C
位置 3 (C):看 A B C,预测 D
...
Loss = CrossEntropy(pred[0:T-1], target[1:T])
= 平均每个位置的预测 loss(Teacher Forcing)
"Teacher Forcing":训练时用真实的 token 做历史(而非模型的预测),
这让训练稳定且并行,但推理时必须串行生成。
推理:自回归串行生成
推理时:串行,每次只预测一个新 token
Prompt: [A, B, C]
│
step 1: 模型看 [A,B,C] → 预测 x_4,采样得到 D
step 2: 模型看 [A,B,C,D] → 预测 x_5,采样得到 E
step 3: 模型看 [A,B,C,D,E] → 预测 x_6,采样得到 F
...
直到生成 <EOS> 或达到最大长度
时间复杂度:O(T²·d)
每步都要重新计算所有历史 token 的 K/V → 很慢
→ KV Cache 优化!
KV Cache:推理加速的核心技术
没有 KV Cache(每步重算所有历史):
Step 1: 计算 [A,B,C] 的 K/V,预测第 4 个 token
Step 2: 重算 [A,B,C,D] 的 K/V,预测第 5 个 token ← A/B/C 的 K/V 重复计算了!
Step 3: 重算 [A,B,C,D,E] 的 K/V ...
有 KV Cache(缓存历史 K/V,每步只算新 token 的 K/V):
Step 1: 计算 [A,B,C] 的 K/V,存入 Cache,预测第 4 个 token
Step 2: 只算 [D] 的 K/V,拼到 Cache,预测第 5 个 token ✓ 不重复计算
Step 3: 只算 [E] 的 K/V,拼到 Cache ...
KV Cache 的代价:
显存:每层存 K/V,形状 (2, B, T, n_kv_heads, d_k)
batch=1, 32层, T=4096, 8 heads, d_k=128 → 约 4GB(FP16)
→ 这是 LLM 推理显存的主要消耗!
GQA(分组查询注意力)本质上就是为了减少 KV Cache 大小:
MHA: K/V 各 32 个 head → Cache 满
GQA: K/V 各 8 个 head → Cache 缩减 4 倍
采样策略
Greedy(贪心):每步取概率最大的 token
优点:确定性,快
缺点:容易重复,不多样
Temperature Sampling:logits / T,T<1 变尖锐(更保守),T>1 变平坦(更多样)
T→0:退化为贪心
T→∞:变为均匀随机
Top-K Sampling:只从概率最高的 K 个 token 中采样
K=50:常用设置,过滤掉长尾低概率词
Top-P(Nucleus)Sampling:累积概率超过 P 的最小集合
P=0.9:动态调整候选集大小,比 Top-K 更自适应
实践:通常 Temperature + Top-P 组合使用
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# ─────────────────────────────────────────────
# 3.1 极简 GPT(仅展示核心结构)
# ─────────────────────────────────────────────
class CausalSelfAttentionWithKVCache(nn.Module):
"""
带 KV Cache 的因果自注意力
训练时:一次并行处理整个序列(T > 1,需要因果 mask)
推理时:每次只输入 1 个新 token(T = 1,不需要 mask,用 cache 拼接历史)
"""
def __init__(self, d_model=256, n_heads=8):
super().__init__()
self.n_heads = n_heads
self.d_k = d_model // n_heads
self.W_qkv = nn.Linear(d_model, 3 * d_model, bias=False)
self.W_o = nn.Linear(d_model, d_model, bias=False)
def forward(self, x, kv_cache=None, use_cache=False):
"""
x: (B, T, D)
kv_cache: None 或 (K_cache, V_cache),各 (B, h, T_past, d_k)
use_cache: 推理时设为 True
返回: (output, new_kv_cache)
"""
B, T, D = x.shape
q, k, v = self.W_qkv(x).chunk(3, dim=-1) # 各 (B, T, D)
def split(t):
return t.view(B, T, self.n_heads, self.d_k).transpose(1, 2)
# → (B, h, T, d_k)
q, k, v = split(q), split(k), split(v)
# 拼接历史 KV Cache
if kv_cache is not None:
k_cache, v_cache = kv_cache
k = torch.cat([k_cache, k], dim=2) # (B, h, T_past+T, d_k)
v = torch.cat([v_cache, v], dim=2)
new_cache = (k, v) if use_cache else None
# 注意力计算
S = k.shape[2] # 当前总序列长度(含历史)
scores = (q @ k.transpose(-2,-1)) / math.sqrt(self.d_k) # (B, h, T, S)
# 因果 mask:训练时必须,推理单步时 T=1 可以省略
if T > 1:
# 只 mask query 对应的位置(T × S 的子矩阵)
mask = torch.triu(torch.full((T, S), float('-inf'), device=x.device), 1)
# 如果有 cache,mask 的列从 S-T 开始才有效
if kv_cache is not None:
T_past = S - T
mask = torch.tril(torch.ones(T, S, device=x.device),
diagonal=T_past) == 0
scores = scores.masked_fill(mask.unsqueeze(0).unsqueeze(0),
float('-inf'))
else:
scores = scores + mask
attn = F.softmax(scores, dim=-1)
out = (attn @ v).transpose(1,2).contiguous().view(B, T, D)
return self.W_o(out), new_cache
class GPTBlock(nn.Module):
def __init__(self, d_model=256, n_heads=8, ffn_ratio=4):
super().__init__()
self.norm1 = nn.LayerNorm(d_model)
self.attn = CausalSelfAttentionWithKVCache(d_model, n_heads)
self.norm2 = nn.LayerNorm(d_model)
self.ffn = nn.Sequential(
nn.Linear(d_model, d_model * ffn_ratio),
nn.GELU(),
nn.Linear(d_model * ffn_ratio, d_model)
)
def forward(self, x, kv_cache=None, use_cache=False):
attn_out, new_cache = self.attn(self.norm1(x), kv_cache, use_cache)
x = x + attn_out
x = x + self.ffn(self.norm2(x))
return x, new_cache
class MiniGPT(nn.Module):
def __init__(self, vocab_size=1000, d_model=256, n_heads=8,
n_layers=4, max_len=512):
super().__init__()
self.token_emb = nn.Embedding(vocab_size, d_model)
self.pos_emb = nn.Embedding(max_len, d_model)
self.blocks = nn.ModuleList([
GPTBlock(d_model, n_heads) for _ in range(n_layers)
])
self.norm = nn.LayerNorm(d_model)
self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
def forward(self, token_ids, kv_caches=None, use_cache=False):
"""
token_ids: (B, T) int
kv_caches: None 或 长度为 n_layers 的列表,每个元素是 (K, V)
训练时: kv_caches=None, use_cache=False
推理时: kv_caches=past_caches, use_cache=True
"""
B, T = token_ids.shape
# 位置编码:如果有 cache,起始位置是 T_past
T_past = kv_caches[0][0].shape[2] if kv_caches else 0
pos = torch.arange(T_past, T_past + T, device=token_ids.device)
x = self.token_emb(token_ids) + self.pos_emb(pos) # (B, T, D)
new_kv_caches = []
for i, block in enumerate(self.blocks):
cache = kv_caches[i] if kv_caches else None
x, new_cache = block(x, cache, use_cache)
new_kv_caches.append(new_cache)
logits = self.lm_head(self.norm(x)) # (B, T, vocab_size)
if use_cache:
return logits, new_kv_caches
return logits
def train_gpt(model, dataloader, epochs=5, lr=3e-4):
"""
GPT 训练:Next-Token Prediction
输入序列 [x_0, x_1, ..., x_{T-1}]
预测目标 [x_1, x_2, ..., x_T](向右偏移一位)
注意:Teacher Forcing——用真实 token 作为历史,不用模型自己预测的
"""
device = next(model.parameters()).device
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.1)
for epoch in range(epochs):
model.train()
total_loss = 0
for batch in dataloader:
token_ids = batch.to(device) # (B, T)
# ── 核心训练设置 ──
# 输入:token_ids[:, :-1] 即 [x_0, ..., x_{T-2}] 形状 (B, T-1)
# 目标:token_ids[:, 1:] 即 [x_1, ..., x_{T-1}] 形状 (B, T-1)
# 这样位置 t 的预测对应目标 t+1,实现 next-token prediction
logits = model(token_ids[:, :-1]) # (B, T-1, vocab_size)
targets = token_ids[:, 1:] # (B, T-1)
# CrossEntropy:把 (B, T-1, V) 展平成 (B*(T-1), V)
loss = F.cross_entropy(
logits.reshape(-1, logits.size(-1)), # (B*(T-1), vocab)
targets.reshape(-1) # (B*(T-1),)
)
optimizer.zero_grad()
loss.backward()
# 梯度裁剪:防止 Transformer 训练时梯度爆炸
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
total_loss += loss.item()
print(f"Epoch {epoch+1}: avg loss = {total_loss/len(dataloader):.4f}")
@torch.no_grad()
def generate_with_kv_cache(model, prompt_ids, max_new_tokens=50,
temperature=1.0, top_p=0.9):
"""
带 KV Cache 的自回归生成
核心逻辑:
1. 先处理 prompt,建立初始 KV Cache
2. 每步只输入 1 个新 token,利用 Cache 避免重算历史
"""
model.eval()
device = prompt_ids.device
token_ids = prompt_ids # (B, T_prompt)
kv_caches = None
# ── Phase 1: Prefill(预填充)──────────────────────────────
# 一次性处理整个 prompt,建立 KV Cache
logits, kv_caches = model(token_ids, kv_caches=None, use_cache=True)
# logits: (B, T_prompt, vocab), kv_caches: 每层的 (K, V)
# 取最后一个位置的 logits,采样第一个新 token
next_token = sample_token(logits[:, -1, :], temperature, top_p) # (B, 1)
generated = [next_token]
# ── Phase 2: Decode(逐步解码)────────────────────────────
for step in range(max_new_tokens - 1):
# 每次只输入 1 个 token,形状 (B, 1)
logits, kv_caches = model(next_token, kv_caches=kv_caches, use_cache=True)
# logits: (B, 1, vocab)
next_token = sample_token(logits[:, -1, :], temperature, top_p) # (B, 1)
generated.append(next_token)
# 遇到 EOS 提前停止(实际使用时需要设定 eos_token_id)
# if (next_token == eos_token_id).all(): break
return torch.cat([prompt_ids] + generated, dim=1) # (B, T_prompt + max_new)
def sample_token(logits, temperature=1.0, top_p=0.9):
"""
Top-P(Nucleus)采样
logits: (B, vocab_size)
"""
logits = logits / max(temperature, 1e-6)
probs = F.softmax(logits, dim=-1) # (B, vocab)
# Top-P:按概率从大到小排序,取累积概率 ≤ P 的集合
sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
cumsum = sorted_probs.cumsum(dim=-1)
# 找到超过 top_p 的位置,把后面的 token 概率置零
remove = cumsum - sorted_probs > top_p
sorted_probs[remove] = 0.0
sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True) # 重归一化
# 采样
sample_idx = torch.multinomial(sorted_probs, 1) # (B, 1) 在排序后的索引
next_token_idx = sorted_idx.gather(1, sample_idx) # 还原到原始词表索引
return next_token_idx # (B, 1)
# 验证
model = MiniGPT(vocab_size=1000, d_model=128, n_heads=4, n_layers=2)
# 训练(仿造数据)
fake_data = torch.randint(0, 1000, (32, 64))
ds = torch.utils.data.TensorDataset(fake_data)
loader = torch.utils.data.DataLoader(ds, batch_size=8)
# 解包每个 batch
class FlatLoader:
def __init__(self, dl): self.dl = dl
def __iter__(self):
for (x,) in self.dl: yield x
def __len__(self): return len(self.dl)
train_gpt(model, FlatLoader(loader), epochs=2)
# 推理
prompt = torch.randint(0, 1000, (1, 5))
output = generate_with_kv_cache(model, prompt, max_new_tokens=10)
print("Generated shape:", output.shape) # (1, 15)4. Diffusion(扩散模型)
核心思想
扩散模型来自一个物理比喻:墨水滴入水中会慢慢扩散,最终变成均匀的颜色。
- 前向过程(Forward Process):把一张真实图像,逐步加入高斯噪声,经过 步后完全变成纯噪声。这个过程是固定的,不需要学习。
- 反向过程(Reverse Process):训练一个神经网络,学会逆转噪声——从纯噪声出发,一步步去除噪声,最终还原出清晰图像。
为什么不直接让网络学 noise → image? 因为直接映射太难(噪声空间和图像空间的分布差异太大)。扩散模型把这个困难问题拆解成 T 个小步骤,每一步只需预测”从 到 “的微小变化,每步都容易学。
比喻: 就像雕塑家把大理石雕成雕像,不是一刀刻出来的,而是每次只去掉一点点多余的石料,T 步后得到完美作品。网络学的就是”每一步应该去掉哪块石料”。
前向过程(加噪过程)
给定真实图像 x_0,逐步加噪:
x_t = √(ᾱ_t) · x_0 + √(1 - ᾱ_t) · ε, ε ~ N(0, I)
其中:
β_t:第 t 步的噪声强度(噪声调度),β_1 < β_2 < ... < β_T
α_t = 1 - β_t
ᾱ_t = ∏_{s=1}^{t} α_s (累积乘积)
t=0: x_0 = 原始图像,完全清晰
t=T/2: x_{T/2} 开始有明显噪声,但还能看出大致内容
t=T: x_T ≈ N(0, I),纯噪声,完全看不出原始内容
关键性质:
① 可以跳步:给定 x_0,直接一步算出任意时刻的 x_t(不需要逐步迭代)
② 条件分布:q(x_t | x_0) = N(x_t; √ᾱ_t·x_0, (1-ᾱ_t)·I)
三种预测目标
这是扩散模型最重要的设计选择之一:网络 究竟预测什么?
方式一:预测噪声 ε(ε-prediction,最常用,DDPM 默认)
x_t = √ᾱ_t · x_0 + √(1-ᾱ_t) · ε
网络学习:ε_θ(x_t, t) ≈ ε(反向解出 x_0 = (x_t - √(1-ᾱ_t)·ε_θ) / √ᾱ_t)
Loss: ||ε - ε_θ(x_t, t)||²
方式二:预测原图 x_0(x-prediction)
网络学习:x_θ(x_t, t) ≈ x_0
可以通过 x_0 推出 ε,等价但训练信号不同
Loss: ||x_0 - x_θ(x_t, t)||²
方式三:预测 v(v-prediction,Stable Diffusion v2 / EDM 使用)
v = √ᾱ_t · ε - √(1-ᾱ_t) · x_0 (一种特殊的线性组合)
网络学习:v_θ(x_t, t) ≈ v
优点:在 t 很大或很小时梯度更稳定,数值条件数更好
Loss: ||v - v_θ(x_t, t)||²
三者的等价关系(可互相推导):
已知 x_t 和任一预测,都可以推出其他两者
x_0 = √ᾱ_t · x_t - √(1-ᾱ_t) · v_pred
= (x_t - √(1-ᾱ_t) · ε_pred) / √ᾱ_t
DDPM vs DDIM 推理
DDPM(Denoising Diffusion Probabilistic Models):
● 推理步数 = T = 1000(很慢!每次推理需要 1000 步 UNet 前向)
● 每步有随机性(加回部分噪声),是随机采样
● x_{t-1} = μ_θ(x_t, t) + σ_t · z,z ~ N(0,I)
DDIM(Denoising Diffusion Implicit Models):
● 推理步数可以是 50、20、10 甚至更少!
● 每步是确定性的(η=0 时),相同噪声 → 相同图像
● 核心思想:跳步!不需要走完所有 1000 步,可以直接跳从 t=1000 到 t=900
DDIM 更新公式(η=0 时完全确定性):
x_{t-1} = √ᾱ_{t-1} · x̂_0(x_t)
+ √(1-ᾱ_{t-1} - η²σ_t²) · ε_θ(x_t, t)
+ η · σ_t · z
其中 x̂_0(x_t) = (x_t - √(1-ᾱ_t) · ε_θ(x_t, t)) / √ᾱ_t
是当前步预测的"去噪后原图"
η=0:完全确定性,速度快,但多样性略低
η=1:退化回 DDPM,有随机性,多样性高
Latent Space vs Pixel Space
Pixel Space Diffusion(原始 DDPM):
直接在像素空间做扩散
x_t: (B, 3, H, W) ← 对图像像素加噪/去噪
问题:H=W=256 → 196608 维,计算量极大
Latent Diffusion(Stable Diffusion 等):
先用 VAE 编码到低维隐空间,在隐空间做扩散,最后解码
┌─────────────────────────────────────────────────┐
│ 图像 x: (B, 3, 512, 512) │
│ ↓ VAE Encoder(冻结,预训练) │
│ 隐变量 z: (B, 4, 64, 64) ← 压缩 64 倍 │
│ ↓ 加噪/去噪(扩散过程在这里发生!) │
│ z_denoised: (B, 4, 64, 64) │
│ ↓ VAE Decoder(冻结,预训练) │
│ 生成图像: (B, 3, 512, 512) │
└─────────────────────────────────────────────────┘
优点:
✦ 隐空间更小(4×64×64 vs 3×512×512),计算量降低 ~48×
✦ 隐空间更"语义化",扩散效果更好
✦ UNet 只需要处理 64×64 的特征图
条件生成(文生图):
UNet 接受额外的条件 c(文本 embedding)
通过 Cross-Attention 把文本信息注入 UNet 中间层
ε_θ(z_t, t, c) → 条件去噪
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# ─────────────────────────────────────────────
# 4.1 噪声调度器(Noise Scheduler)
# ─────────────────────────────────────────────
class DDPMScheduler:
"""
管理扩散过程的所有参数:β_t, α_t, ᾱ_t
以及前向加噪和反向去噪的步骤
"""
def __init__(self, T=1000, beta_start=1e-4, beta_end=0.02,
schedule='linear'):
self.T = T
# 噪声调度:从小到大的 β_t 序列
if schedule == 'linear':
# 线性:最简单,DDPM 原论文使用
betas = torch.linspace(beta_start, beta_end, T)
elif schedule == 'cosine':
# 余弦:在 t 接近 0 时变化更平滑,生成质量更好
steps = T + 1
x = torch.linspace(0, T, steps)
alphas_cumprod = torch.cos(((x/T) + 0.008) / 1.008 * math.pi/2) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
betas = betas.clamp(0, 0.999)
elif schedule == 'sqrt':
# 平方根调度:适合一致性模型
betas = torch.linspace(beta_start**0.5, beta_end**0.5, T)**2
self.betas = betas # (T,)
self.alphas = 1.0 - betas # (T,)
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) # ᾱ_t (T,)
self.sqrt_alphas_cumprod = self.alphas_cumprod.sqrt()
self.sqrt_one_minus_alphas_cumprod = (1 - self.alphas_cumprod).sqrt()
def q_sample(self, x0, t, noise=None):
"""
前向加噪(Forward Process):一步算出任意时刻的 x_t
x_t = √ᾱ_t · x_0 + √(1-ᾱ_t) · ε
输入: x0 (B, C, H, W), t (B,) 时间步索引
输出: x_t (B, C, H, W)
"""
if noise is None:
noise = torch.randn_like(x0)
# 取 t 对应的调度参数,并 reshape 为 (B, 1, 1, 1) 便于广播
sqrt_a = self.sqrt_alphas_cumprod[t].view(-1,1,1,1)
sqrt_1a = self.sqrt_one_minus_alphas_cumprod[t].view(-1,1,1,1)
x_t = sqrt_a * x0 + sqrt_1a * noise
return x_t, noise # 返回 x_t 和 ε,训练时 ε 是监督目标
def predict_x0_from_eps(self, x_t, t, eps_pred):
"""从预测的噪声 ε_pred 还原 x_0"""
sqrt_a = self.sqrt_alphas_cumprod[t].view(-1,1,1,1)
sqrt_1a = self.sqrt_one_minus_alphas_cumprod[t].view(-1,1,1,1)
return (x_t - sqrt_1a * eps_pred) / sqrt_a
def predict_x0_from_v(self, x_t, t, v_pred):
"""从预测的 v 还原 x_0(v-prediction)"""
sqrt_a = self.sqrt_alphas_cumprod[t].view(-1,1,1,1)
sqrt_1a = self.sqrt_one_minus_alphas_cumprod[t].view(-1,1,1,1)
# v = √ᾱ·ε - √(1-ᾱ)·x_0 → x_0 = √ᾱ·x_t - √(1-ᾱ)·v
return sqrt_a * x_t - sqrt_1a * v_pred
def predict_eps_from_v(self, x_t, t, v_pred):
"""从预测的 v 还原 ε"""
sqrt_a = self.sqrt_alphas_cumprod[t].view(-1,1,1,1)
sqrt_1a = self.sqrt_one_minus_alphas_cumprod[t].view(-1,1,1,1)
# ε = √ᾱ·v + √(1-ᾱ)·x_t
return sqrt_a * v_pred + sqrt_1a * x_t
def ddpm_step(self, x_t, t_idx, eps_pred):
"""
DDPM 反向一步(随机):x_t → x_{t-1}
"""
t = torch.tensor([t_idx])
beta = self.betas[t_idx]
a = self.alphas[t_idx]
a_bar = self.alphas_cumprod[t_idx]
a_bar_prev = self.alphas_cumprod[t_idx-1] if t_idx > 0 \
else torch.tensor(1.0)
# 均值(DDPM 论文 Eq. 11)
coef1 = beta * a_bar_prev.sqrt() / (1 - a_bar)
coef2 = (1 - a_bar_prev) * a.sqrt() / (1 - a_bar)
x0_pred = self.predict_x0_from_eps(
x_t.unsqueeze(0), t, eps_pred.unsqueeze(0)
).squeeze(0)
mu = coef1 * x0_pred + coef2 * x_t
# 方差(固定为 β̃_t)
sigma = ((1 - a_bar_prev) / (1 - a_bar) * beta).sqrt()
z = torch.randn_like(x_t) if t_idx > 0 else 0
return mu + sigma * z # x_{t-1}
# ─────────────────────────────────────────────
# 4.2 简化版 UNet(去噪网络的主体)
# ─────────────────────────────────────────────
class SinusoidalTimeEmb(nn.Module):
"""
时间步 t 的编码:类似 Transformer 的位置编码
网络需要知道当前处于哪个噪声级别,才能做出正确的预测
"""
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, t):
# t: (B,) 整数时间步
half = self.dim // 2
freqs = torch.exp(
-math.log(10000) * torch.arange(half, device=t.device) / half
)
angles = t[:, None].float() * freqs[None] # (B, dim/2)
emb = torch.cat([angles.sin(), angles.cos()], dim=-1) # (B, dim)
return emb
class ResBlock(nn.Module):
"""UNet 中的基本残差块,接受时间步 embedding 作为额外输入"""
def __init__(self, in_ch, out_ch, time_dim):
super().__init__()
self.conv1 = nn.Sequential(nn.GroupNorm(8, in_ch), nn.SiLU(),
nn.Conv2d(in_ch, out_ch, 3, padding=1))
# 时间步 embedding 注入:通过 scale-shift(仿射变换)方式
self.time_mlp = nn.Sequential(nn.SiLU(),
nn.Linear(time_dim, out_ch * 2))
self.conv2 = nn.Sequential(nn.GroupNorm(8, out_ch), nn.SiLU(),
nn.Conv2d(out_ch, out_ch, 3, padding=1))
self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch \
else nn.Identity()
def forward(self, x, t_emb):
h = self.conv1(x)
# time_mlp 输出 scale 和 shift,做特征调制
scale, shift = self.time_mlp(t_emb).chunk(2, dim=-1)
scale = scale.view(*scale.shape, 1, 1)
shift = shift.view(*shift.shape, 1, 1)
h = h * (1 + scale) + shift # 仿射变换,注入时间信息
h = self.conv2(h)
return h + self.skip(x) # 残差
class SimpleUNet(nn.Module):
"""
扩散模型的去噪网络
输入: x_t (B, C, H, W) + 时间步 t (B,)
输出: 预测的噪声 ε,或 x_0,或 v(取决于 prediction_type)
形状不变:输出 (B, C, H, W)
"""
def __init__(self, in_channels=3, base_channels=64,
channel_mults=(1, 2, 4, 8), time_dim=256,
prediction_type='epsilon'):
super().__init__()
self.prediction_type = prediction_type
channels = [base_channels * m for m in channel_mults]
# 时间编码:t → embedding 向量
self.time_emb = nn.Sequential(
SinusoidalTimeEmb(base_channels),
nn.Linear(base_channels, time_dim),
nn.SiLU(),
nn.Linear(time_dim, time_dim)
)
# 初始卷积
self.init_conv = nn.Conv2d(in_channels, channels[0], 3, padding=1)
# Encoder(下采样)
self.downs = nn.ModuleList()
self.downpools = nn.ModuleList()
in_ch = channels[0]
for out_ch in channels[1:]:
self.downs.append(ResBlock(in_ch, out_ch, time_dim))
self.downpools.append(nn.Conv2d(out_ch, out_ch, 4, 2, 1)) # 步长2下采样
in_ch = out_ch
# Bottleneck
self.mid = ResBlock(channels[-1], channels[-1], time_dim)
# Decoder(上采样)
self.ups = nn.ModuleList()
self.uppools = nn.ModuleList()
for out_ch in reversed(channels[:-1]):
self.uppools.append(nn.ConvTranspose2d(in_ch, in_ch, 4, 2, 1))
self.ups.append(ResBlock(in_ch + out_ch, out_ch, time_dim)) # +out_ch 是 skip
in_ch = out_ch
self.final = nn.Sequential(
nn.GroupNorm(8, channels[0]),
nn.SiLU(),
nn.Conv2d(channels[0], in_channels, 3, padding=1) # 输出和输入同形状
)
def forward(self, x_t, t):
"""
x_t: (B, C, H, W) 含噪图像
t: (B,) 时间步(整数)
返回: (B, C, H, W) 预测目标(ε / x_0 / v)
"""
t_emb = self.time_emb(t) # (B, time_dim)
x = self.init_conv(x_t)
# Encoder + 保存 skip 特征
skips = []
for down, pool in zip(self.downs, self.downpools):
x = down(x, t_emb)
skips.append(x)
x = pool(x)
x = self.mid(x, t_emb)
# Decoder + 融合 skip 特征
for up_pool, up, skip in zip(self.uppools, self.ups, reversed(skips)):
x = up_pool(x)
if x.shape != skip.shape:
x = F.interpolate(x, size=skip.shape[2:])
x = torch.cat([x, skip], dim=1)
x = up(x, t_emb)
return self.final(x)
# ─────────────────────────────────────────────
# 4.3 DDPM 训练
# ─────────────────────────────────────────────
def train_ddpm(unet, scheduler, dataloader, epochs=100, lr=2e-4,
prediction_type='epsilon'):
"""
DDPM 训练:
① 随机采样真实图像 x_0
② 随机采样时间步 t
③ 随机采样噪声 ε
④ 用前向过程得到 x_t
⑤ 用 UNet 预测 ε(或 x_0 或 v)
⑥ 计算 MSE 损失
"""
device = next(unet.parameters()).device
optimizer = torch.optim.AdamW(unet.parameters(), lr=lr)
T = scheduler.T
for epoch in range(epochs):
total_loss = 0
for x0, _ in dataloader:
x0 = x0.to(device) # (B, 3, H, W) 真实图像,归一化到 [-1, 1]
B = x0.shape[0]
# ① 均匀采样时间步 t ~ Uniform{1, ..., T}
t = torch.randint(0, T, (B,), device=device) # (B,)
# ② 随机采样噪声
noise = torch.randn_like(x0) # (B, 3, H, W)
# ③ 前向加噪:得到 x_t
x_t, _ = scheduler.q_sample(x0, t, noise) # (B, 3, H, W)
# ④ UNet 预测
pred = unet(x_t, t) # (B, 3, H, W)
# ⑤ 计算损失(根据 prediction_type 选择监督目标)
if prediction_type == 'epsilon':
target = noise # 预测噪声 ε
elif prediction_type == 'x0':
target = x0 # 预测原图 x_0
elif prediction_type == 'v':
# v = √ᾱ_t · ε - √(1-ᾱ_t) · x_0
sqrt_a = scheduler.sqrt_alphas_cumprod[t].view(-1,1,1,1)
sqrt_1a = scheduler.sqrt_one_minus_alphas_cumprod[t].view(-1,1,1,1)
target = sqrt_a * noise - sqrt_1a * x0
loss = F.mse_loss(pred, target)
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
optimizer.step()
total_loss += loss.item()
if epoch % 10 == 0:
print(f"Epoch {epoch}: loss = {total_loss/len(dataloader):.5f}")
# ─────────────────────────────────────────────
# 4.4 DDPM 推理(随机采样,1000步)
# ─────────────────────────────────────────────
@torch.no_grad()
def ddpm_sample(unet, scheduler, shape, device, prediction_type='epsilon'):
"""
DDPM 标准采样:从纯噪声开始,逐步去噪 T 步
shape: (B, C, H, W) 要生成的图像形状
每一步都有随机性(加回 σ_t · z),所以每次生成结果不同
"""
unet.eval()
T = scheduler.T
# 从纯高斯噪声开始
x_t = torch.randn(*shape, device=device) # x_T ~ N(0, I)
for t_idx in reversed(range(T)): # t = T-1, T-2, ..., 1, 0
t_batch = torch.full((shape[0],), t_idx, device=device, dtype=torch.long)
# UNet 预测
pred = unet(x_t, t_batch) # (B, C, H, W)
# 根据 prediction_type 转换为 ε
if prediction_type == 'epsilon':
eps_pred = pred
elif prediction_type == 'x0':
# 从 x_0 推出 ε
sqrt_a = scheduler.sqrt_alphas_cumprod[t_idx]
sqrt_1a = scheduler.sqrt_one_minus_alphas_cumprod[t_idx]
eps_pred = (x_t - sqrt_a * pred) / sqrt_1a
elif prediction_type == 'v':
eps_pred = scheduler.predict_eps_from_v(
x_t, torch.tensor([t_idx]), pred.unsqueeze(0)
).squeeze(0)
# DDPM 反向一步
x_t = scheduler.ddpm_step(x_t, t_idx, eps_pred)
return x_t.clamp(-1, 1) # 裁剪到合理范围
# ─────────────────────────────────────────────
# 4.5 DDIM 推理(确定性,支持快速采样)
# ─────────────────────────────────────────────
@torch.no_grad()
def ddim_sample(unet, scheduler, shape, device,
n_steps=50, eta=0.0, prediction_type='epsilon'):
"""
DDIM 快速确定性采样
n_steps: 推理步数(可以远小于训练的 T=1000,如 50 步)
eta=0.0: 完全确定性(相同 noise → 相同图像)
eta=1.0: 退化为 DDPM(随机)
关键:在 T 个时间步中等间隔选取 n_steps 个,跳步推理
"""
unet.eval()
T = scheduler.T
# 选取等间隔的时间步序列(从大到小)
timesteps = torch.linspace(T-1, 0, n_steps, dtype=torch.long)
x_t = torch.randn(*shape, device=device) # 从纯噪声开始
for i, t_idx in enumerate(timesteps):
t_idx_int = t_idx.item()
t_batch = torch.full((shape[0],), t_idx_int, device=device, dtype=torch.long)
# UNet 预测
pred = unet(x_t, t_batch)
# 统一转换为 ε 预测和 x_0 预测
if prediction_type == 'epsilon':
eps_pred = pred
sqrt_a = scheduler.sqrt_alphas_cumprod[t_idx_int]
sqrt_1a = scheduler.sqrt_one_minus_alphas_cumprod[t_idx_int]
x0_pred = (x_t - sqrt_1a * eps_pred) / sqrt_a
elif prediction_type == 'x0':
x0_pred = pred
sqrt_a = scheduler.sqrt_alphas_cumprod[t_idx_int]
sqrt_1a = scheduler.sqrt_one_minus_alphas_cumprod[t_idx_int]
eps_pred = (x_t - sqrt_a * x0_pred) / sqrt_1a
elif prediction_type == 'v':
x0_pred = scheduler.predict_x0_from_v(x_t, torch.tensor([t_idx_int]), pred)
sqrt_a = scheduler.sqrt_alphas_cumprod[t_idx_int]
sqrt_1a = scheduler.sqrt_one_minus_alphas_cumprod[t_idx_int]
eps_pred = (x_t - sqrt_a * x0_pred) / sqrt_1a
# 确定下一个时间步 t_prev
if i + 1 < len(timesteps):
t_prev = timesteps[i + 1].item()
else:
t_prev = 0
# DDIM 更新公式
a_t = scheduler.alphas_cumprod[t_idx_int]
a_prev = scheduler.alphas_cumprod[t_prev] if t_prev > 0 \
else torch.tensor(1.0)
sigma_t = eta * ((1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev)).sqrt()
# x_{t-1} = √ᾱ_{t-1} · x_0_pred
# + √(1-ᾱ_{t-1}-σ²) · ε_pred ← "指向噪声的方向"
# + σ · z ← 可选随机项
x_t = (a_prev.sqrt() * x0_pred
+ (1 - a_prev - sigma_t**2).clamp(min=0).sqrt() * eps_pred
+ sigma_t * torch.randn_like(x_t))
return x_t.clamp(-1, 1)
# ─────────────────────────────────────────────
# 4.6 Latent Diffusion(潜在扩散模型骨架)
# ─────────────────────────────────────────────
class SimpleVAE(nn.Module):
"""
极简 VAE:图像 ↔ 隐变量
真实使用时(如 Stable Diffusion)的 VAE 是预训练好后冻结的
"""
def __init__(self, img_channels=3, latent_channels=4, base_ch=64):
super().__init__()
# Encoder:图像 → 均值 + 方差
self.encoder = nn.Sequential(
nn.Conv2d(img_channels, base_ch, 3, padding=1),
nn.SiLU(),
nn.Conv2d(base_ch, base_ch*2, 4, stride=2, padding=1), # ÷2
nn.SiLU(),
nn.Conv2d(base_ch*2, base_ch*4, 4, stride=2, padding=1), # ÷4
nn.SiLU(),
nn.Conv2d(base_ch*4, latent_channels*2, 3, padding=1) # mu + logvar
)
# Decoder:隐变量 → 图像
self.decoder = nn.Sequential(
nn.Conv2d(latent_channels, base_ch*4, 3, padding=1),
nn.SiLU(),
nn.ConvTranspose2d(base_ch*4, base_ch*2, 4, stride=2, padding=1), # ×2
nn.SiLU(),
nn.ConvTranspose2d(base_ch*2, base_ch, 4, stride=2, padding=1), # ×4
nn.SiLU(),
nn.Conv2d(base_ch, img_channels, 3, padding=1),
nn.Tanh()
)
def encode(self, x):
"""图像 → (μ, σ)"""
out = self.encoder(x) # (B, 2C, H/4, W/4)
mu, logvar = out.chunk(2, dim=1)
return mu, logvar
def reparameterize(self, mu, logvar):
"""重参数化采样:z = μ + ε·σ"""
std = (0.5 * logvar).exp()
eps = torch.randn_like(std)
return mu + eps * std
def decode(self, z):
return self.decoder(z)
def forward(self, x):
mu, logvar = self.encode(x)
z = self.reparameterize(mu, logvar)
x_recon = self.decode(z)
return x_recon, mu, logvar
def latent_diffusion_train_step(vae, unet, scheduler, x0, t, device,
prediction_type='epsilon'):
"""
Latent Diffusion 单步训练
关键区别:扩散过程发生在隐空间 z,而非像素空间 x
"""
# ① 编码到隐空间(VAE 冻结,不计算梯度)
with torch.no_grad():
mu, logvar = vae.encode(x0)
# 训练时通常用确定性的 μ,不加随机扰动
z0 = mu # (B, 4, H/8, W/8) ← 尺寸是原图的 1/8×1/8
# ② 在隐空间加噪
noise = torch.randn_like(z0)
z_t, _ = scheduler.q_sample(z0, t, noise) # (B, 4, H/8, W/8)
# ③ UNet 在隐空间预测(输入输出都是 4 通道的隐变量)
pred = unet(z_t, t) # (B, 4, H/8, W/8)
# ④ 损失(和像素空间完全一样,只是作用在 z 上)
if prediction_type == 'epsilon':
target = noise
elif prediction_type == 'x0':
target = z0
elif prediction_type == 'v':
sqrt_a = scheduler.sqrt_alphas_cumprod[t].view(-1,1,1,1)
sqrt_1a = scheduler.sqrt_one_minus_alphas_cumprod[t].view(-1,1,1,1)
target = sqrt_a * noise - sqrt_1a * z0
return F.mse_loss(pred, target)
@torch.no_grad()
def latent_diffusion_sample(vae, unet, scheduler, n=4,
latent_shape=(4, 8, 8), n_steps=50, device='cpu'):
"""
Latent Diffusion 采样:
① 在隐空间采样(DDIM 快速)
② VAE 解码回像素空间
"""
# 在隐空间生成
z_shape = (n, *latent_shape)
z_sample = ddim_sample(unet, scheduler, z_shape, device, n_steps=n_steps)
# z_sample: (n, 4, H/8, W/8)
# VAE 解码回图像
images = vae.decode(z_sample) # (n, 3, H, W)
return (images + 1) / 2 # [-1,1] → [0,1]
# ── 快速验证 ──
scheduler = DDPMScheduler(T=1000, schedule='cosine')
unet = SimpleUNet(in_channels=3, base_channels=32,
channel_mults=(1, 2, 4), time_dim=128)
# 验证前向过程
x0 = torch.randn(2, 3, 32, 32) # 假图像
t = torch.randint(0, 1000, (2,))
x_t, noise = scheduler.q_sample(x0, t)
print("x_t shape:", x_t.shape) # (2, 3, 32, 32)
# 验证 UNet
pred = unet(x_t, t)
print("pred shape:", pred.shape) # (2, 3, 32, 32)
# 验证 DDIM 采样
samples = ddim_sample(unet, scheduler, (2, 3, 32, 32), device='cpu', n_steps=10)
print("samples shape:", samples.shape) # (2, 3, 32, 32)5. 四大范式对比总结
训练目标对比
| 范式 | 训练目标 | 损失函数 | 输入 | 输出 |
|---|---|---|---|---|
| CNN | 最小化预测误差 | CE / MSE / Dice | (B, C, H, W) | (B, classes) 或 (B, C, H, W) |
| GAN | G 骗过 D,D 识破 G | BCE (DCGAN) / Wasserstein (WGAN) | z: (B, latent) | G: 图像 (B, C, H, W);D: 概率 (B, 1) |
| AR | 最大化序列似然 | Cross-Entropy | token_ids: (B, T) | logits: (B, T, V) |
| Diffusion | 预测噪声/原图/v | MSE | x_t: (B,C,H,W), t: (B,) | ε / x₀ / v: (B, C, H, W) |
推理流程对比
| 范式 | 推理步数 | 随机性 | 速度 | 特点 |
|---|---|---|---|---|
| CNN | 1 步 | 无(确定性) | 极快 | 直接映射 |
| GAN | 1 步 | 有(采样 z) | 极快 | 一步生成 |
| AR | T 步(T=生成长度) | 有(token 采样) | 慢(串行) | KV Cache 加速 |
| Diffusion (DDPM) | 1000 步 | 有(随机去噪) | 极慢 | 质量高 |
| Diffusion (DDIM) | 20~50 步 | 可调(η) | 较快 | 速度质量平衡 |
生成能力对比
CNN(判别式):
✦ 不擅长生成,主要用于判别/回归任务
✦ 生成类任务需要特殊结构(超分、去噪)
GAN:
✦ 快速生成高质量图像(一步到位)
✦ 训练不稳定,可能 mode collapse
✦ 难以控制生成内容(条件 GAN 部分解决)
Auto-Regressive:
✦ 文本生成的绝对主流(GPT 系列)
✦ 支持任意长度序列生成
✦ 推理串行,速度受限于序列长度
✦ 可结合图像 token 做图像生成(VQVAE + GPT)
Diffusion:
✦ 目前图像生成质量最高(SD、DALL-E 3)
✦ 训练稳定,无 mode collapse
✦ 推理需多步,有加速方案(DDIM、LCM)
✦ 易于条件控制(CFG、ControlNet)
一图看懂四大范式
训练信号来源
│
┌─────────────────────────┼──────────────────────────────┐
│ │ │
▼ ▼ ▼
标注标签 对抗博弈 数据本身
(supervised) (adversarial) (self-supervised)
│ │ │ │
▼ ▼ ▼ ▼
CNN GAN AR Diffusion
(判别) (生成) (自回归) (生成)
确定性 一步生成 串行生成 多步去噪
推理快 训练难 训练快 训练稳
参考:DDPM (Ho et al., 2020) · DDIM (Song et al., 2021) · GPT (Radford et al.) · DCGAN (Radford et al., 2016)