神经网络四大范式详解

CNN · GAN · Auto-Regressive · Diffusion
训练与推理流程 · 损失函数 · 输入输出形状 · 完整代码


目录

  1. CNN(端到端学习)
  2. GAN(生成对抗网络)
  3. Auto-Regressive(自回归)
  4. Diffusion(扩散模型)
  5. 四大范式对比总结

CNN(端到端学习)

核心思想

CNN 是最”直白”的范式:给一个输入 X,直接预测一个输出 Y,用损失函数衡量预测的好坏,梯度反传更新参数。这种”输入→网络→输出→loss→反传”的完整链条,就是所谓的”端到端(End-to-End)”。

比喻: CNN 像是一个学生直接背答案。你给他一道题(图像),他直接给你一个答案(类别/分割图),老师(loss)告诉他答得有多差,他就调整自己的”答题方式”。全程没有中间人。

与后面的生成模型不同,CNN 的输出空间是确定性的——同一张图,每次推理结果一样。这是判别式模型(Discriminative Model)的典型特征。

三种典型任务

图像分类(Classification):
  输入: (B, 3, H, W)  ─────► 网络 ─────► (B, C)  → CrossEntropy Loss
        一张图                              C 个类别的 logits

语义分割(Segmentation):
  输入: (B, 3, H, W)  ─────► 网络 ─────► (B, C, H, W)  → CrossEntropy / Dice Loss
        一张图                              每个像素的 C 类概率

图像超分(Super Resolution):
  输入: (B, 3, h, w)  ─────► 网络 ─────► (B, 3, H, W)  → L1 / L2 / Perceptual Loss
        低分辨率                             高分辨率
        H = s·h, W = s·w  (s 为放大倍数)

训练与推理流程

训练流程:
  ① 采样 mini-batch: x (B, 3, H, W),  y_true (B,) 或 (B, C, H, W)
  ② 前向传播: y_pred = model(x)
  ③ 计算损失: loss = criterion(y_pred, y_true)
  ④ 反向传播: loss.backward()
  ⑤ 参数更新: optimizer.step()

推理流程(确定性,无随机性):
  ① with torch.no_grad():
  ② y_pred = model(x)
  ③ 后处理(argmax / softmax / 上采样等)

常用损失函数

任务损失函数形式
分类Cross-Entropy
分割(类别不均衡)Dice Loss
回归/重建MSE (L2)
回归/重建(鲁棒)MAE (L1)
超分/风格Perceptual Loss 为预训练 VGG
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
 
# ─────────────────────────────────────────────
# 1.1 图像分类(端到端最简示例)
# ─────────────────────────────────────────────
class SimpleCNNClassifier(nn.Module):
    """
    输入: (B, 3, 32, 32)
    输出: (B, num_classes)  ← logits(未经 softmax)
    """
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.MaxPool2d(2),           # (B, 64, 16, 16)
            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
            nn.AdaptiveAvgPool2d(4),   # (B, 128, 4, 4)
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, 256), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )
 
    def forward(self, x):
        return self.classifier(self.features(x))
 
 
def train_classifier(model, train_loader, epochs=5, lr=1e-3):
    """端到端分类训练循环"""
    device    = next(model.parameters()).device
    criterion = nn.CrossEntropyLoss()           # 内部含 softmax,输入 logits
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
 
    for epoch in range(epochs):
        model.train()
        total_loss, correct, total = 0.0, 0, 0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
 
            # ── 标准端到端步骤 ──
            y_pred = model(x)                            # (B, C) logits
            loss   = criterion(y_pred, y)                # scalar
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # ──────────────────
 
            total_loss += loss.item() * x.size(0)
            correct    += (y_pred.argmax(1) == y).sum().item()
            total      += x.size(0)
 
        scheduler.step()
        print(f"Epoch {epoch+1}: loss={total_loss/total:.4f}, "
              f"acc={correct/total:.4f}")
 
 
@torch.no_grad()
def inference_classifier(model, x):
    """推理:确定性,无随机"""
    model.eval()
    logits = model(x)                     # (B, C)
    probs  = F.softmax(logits, dim=-1)    # (B, C) 概率
    preds  = logits.argmax(dim=-1)        # (B,)   预测类别
    return preds, probs
 
 
# ─────────────────────────────────────────────
# 1.2 语义分割(像素级分类)
# ─────────────────────────────────────────────
class DiceLoss(nn.Module):
    """
    Dice Loss:对类别不均衡鲁棒(医学图像分割常用)
    输入 pred: (B, C, H, W) 经 softmax 后的概率
    输入 target: (B, H, W) 整数类别标签
    """
    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth
 
    def forward(self, pred, target):
        C = pred.shape[1]
        # one-hot 编码 target
        target_oh = F.one_hot(target, C).permute(0,3,1,2).float()  # (B, C, H, W)
        pred = pred.softmax(dim=1)
 
        # 按 batch 和 class 计算 Dice
        intersection = (pred * target_oh).sum(dim=(2,3))   # (B, C)
        union        = pred.sum(dim=(2,3)) + target_oh.sum(dim=(2,3))
        dice = (2 * intersection + self.smooth) / (union + self.smooth)
        return 1 - dice.mean()
 
 
class PerceptualLoss(nn.Module):
    """
    感知损失(Perceptual Loss):用预训练 VGG 提取特征,
    在特征空间而非像素空间度量重建质量
    适用于超分辨率、风格迁移
    """
    def __init__(self):
        super().__init__()
        import torchvision.models as models
        vgg   = models.vgg16(pretrained=False)
        # 取前 16 层(relu3_3 之前的特征)
        self.features = nn.Sequential(*list(vgg.features.children())[:16])
        for p in self.features.parameters():
            p.requires_grad = False   # 冻结,不参与训练
 
    def forward(self, pred, target):
        # pred, target: (B, 3, H, W),假设已归一化到 [0,1]
        feat_pred   = self.features(pred)
        feat_target = self.features(target.detach())
        return F.mse_loss(feat_pred, feat_target)
 
 
# ── 快速验证 ──
model = SimpleCNNClassifier(num_classes=10)
# 构造假数据
x = torch.randn(64, 3, 32, 32)
y = torch.randint(0, 10, (64,))
ds = TensorDataset(x, y)
loader = DataLoader(ds, batch_size=16)
train_classifier(model, loader, epochs=2)

2. GAN(生成对抗网络)

核心思想

GAN 的思想来自一个博弈论比喻:造假币的人(Generator)vs 验钞员(Discriminator)

  • Generator(G):接收一个随机噪声 ,生成假图像,目标是骗过 D
  • Discriminator(D):接收真实图像或假图像,判断真假,目标是不被骗

两者在训练中相互博弈、共同进化:G 越来越擅长造假,D 越来越擅长辨别,最终 G 生成的图像以假乱真(D 的输出趋近于 0.5,无法判断真假)。

关键洞见: 我们无法直接写出”真实图像的分布长什么样”,但我们可以通过对抗训练让 G 隐式地学到这个分布。G 永远不需要知道真实数据长什么样,只需要知道 D 给了多少分。

训练流程(双人博弈,交替更新)

真实数据 x_real ~ p_data           噪声 z ~ N(0, I)
         │                                │
         │          Generator G           │
         │          z → G(z) = x_fake     │
         │                │               │
         ▼                ▼               
  Discriminator D 接收两类输入:
    D(x_real) → 希望接近 1(真的打高分)
    D(x_fake) → 希望接近 0(假的打低分)

─── 第一步:更新 D(固定 G)───────────────
  loss_D = -[log D(x_real) + log(1 - D(G(z)))]
  目标:最大化辨别能力(最小化 loss_D)

─── 第二步:更新 G(固定 D)───────────────
  loss_G = -log D(G(z))          ← 非饱和版本(实践常用)
        或 log(1 - D(G(z)))      ← 原始论文版(梯度小,难训练)
  目标:让 D 把假图打高分(骗过 D)

注意:两步交替,不同时更新!

训练不稳定问题与改进

原始 GAN 的问题:
  ✗ 梯度消失:D 太强时,D(G(z))≈0,log(1-D(G(z)))≈0,G 无梯度
  ✗ 模式崩溃(Mode Collapse):G 只生成一种图像就能骗过 D
  ✗ 训练不稳定,对超参敏感

改进方向:
  DCGAN:用卷积架构替代全连接,BatchNorm,LeakyReLU
  WGAN:用 Wasserstein 距离替代 JS 散度,梯度裁剪
  WGAN-GP:用梯度惩罚替代梯度裁剪(更稳定)
  StyleGAN:噪声注入、AdaIN、渐进式训练(高分辨率合成)
  条件 GAN:G 和 D 都以类别标签为条件

输入输出形状

Generator G:
  输入:  z (B, latent_dim)        ← 纯随机噪声,如 latent_dim=100
  输出:  x_fake (B, C, H, W)      ← 生成的假图像,值域 [-1, 1] 或 [0, 1]

Discriminator D:
  输入:  x (B, C, H, W)           ← 真实图像 或 假图像
  输出:  score (B, 1)             ← 真实概率(DCGAN)或真实性分数(WGAN)

训练时的 Tensor 流:
  z: (B, 100)  →  G  →  x_fake: (B, 3, 64, 64)
  x_real: (B, 3, 64, 64)  →  D  →  real_score: (B, 1)
  x_fake: (B, 3, 64, 64)  →  D  →  fake_score: (B, 1)
import torch
import torch.nn as nn
import torch.nn.functional as F
 
# ─────────────────────────────────────────────
# 2.1 DCGAN(深度卷积 GAN)
# ─────────────────────────────────────────────
 
class DCGANGenerator(nn.Module):
    """
    z: (B, latent_dim) → x_fake: (B, 3, 64, 64)
    使用转置卷积逐步上采样,无池化层
    """
    def __init__(self, latent_dim=100, img_channels=3, feature_map=64):
        super().__init__()
        fm = feature_map
        self.net = nn.Sequential(
            # 输入 z 形状: (B, latent_dim, 1, 1)
            nn.ConvTranspose2d(latent_dim, fm*8, 4, 1, 0, bias=False),  # → (B, 512, 4, 4)
            nn.BatchNorm2d(fm*8), nn.ReLU(True),
 
            nn.ConvTranspose2d(fm*8, fm*4, 4, 2, 1, bias=False),        # → (B, 256, 8, 8)
            nn.BatchNorm2d(fm*4), nn.ReLU(True),
 
            nn.ConvTranspose2d(fm*4, fm*2, 4, 2, 1, bias=False),        # → (B, 128, 16, 16)
            nn.BatchNorm2d(fm*2), nn.ReLU(True),
 
            nn.ConvTranspose2d(fm*2, fm,   4, 2, 1, bias=False),        # → (B, 64, 32, 32)
            nn.BatchNorm2d(fm), nn.ReLU(True),
 
            nn.ConvTranspose2d(fm, img_channels, 4, 2, 1, bias=False),  # → (B, 3, 64, 64)
            nn.Tanh()   # 输出 [-1, 1],对应归一化的图像
        )
 
    def forward(self, z):
        # z: (B, latent_dim) → 变形为 (B, latent_dim, 1, 1)
        z = z.view(z.size(0), -1, 1, 1)
        return self.net(z)   # (B, 3, 64, 64)
 
 
class DCGANDiscriminator(nn.Module):
    """
    x: (B, 3, 64, 64) → score: (B, 1)
    使用步长卷积下采样,避免 MaxPool(会丢失空间信息)
    用 LeakyReLU(非 ReLU):防止梯度死亡,让负值有梯度
    """
    def __init__(self, img_channels=3, feature_map=64):
        super().__init__()
        fm = feature_map
        self.net = nn.Sequential(
            # 注意:第一层不用 BN(作者经验)
            nn.Conv2d(img_channels, fm, 4, 2, 1, bias=False),      # → (B, 64, 32, 32)
            nn.LeakyReLU(0.2, inplace=True),
 
            nn.Conv2d(fm, fm*2, 4, 2, 1, bias=False),              # → (B, 128, 16, 16)
            nn.BatchNorm2d(fm*2), nn.LeakyReLU(0.2, inplace=True),
 
            nn.Conv2d(fm*2, fm*4, 4, 2, 1, bias=False),            # → (B, 256, 8, 8)
            nn.BatchNorm2d(fm*4), nn.LeakyReLU(0.2, inplace=True),
 
            nn.Conv2d(fm*4, fm*8, 4, 2, 1, bias=False),            # → (B, 512, 4, 4)
            nn.BatchNorm2d(fm*8), nn.LeakyReLU(0.2, inplace=True),
 
            nn.Conv2d(fm*8, 1, 4, 1, 0, bias=False),               # → (B, 1, 1, 1)
            nn.Sigmoid()   # 输出 (0, 1) 概率
        )
 
    def forward(self, x):
        return self.net(x).view(-1, 1)   # (B, 1)
 
 
def train_dcgan(G, D, dataloader, epochs=50, latent_dim=100, lr=2e-4):
    """
    DCGAN 训练循环:D 和 G 交替更新
    关键:D 更新时固定 G,G 更新时固定 D
    """
    device   = next(G.parameters()).device
    opt_D    = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
    opt_G    = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
    # betas=(0.5, 0.999):DCGAN 论文推荐,β1=0.5 比默认 0.9 更稳定
 
    criterion = nn.BCELoss()   # Binary Cross Entropy
 
    real_label = 1.0   # 真实图像的目标标签
    fake_label = 0.0   # 假图像的目标标签
 
    for epoch in range(epochs):
        for i, (x_real, _) in enumerate(dataloader):
            x_real = x_real.to(device)
            B      = x_real.size(0)
 
            # ═══════ 第一步:更新 Discriminator D ═══════
            # 目标:最大化 log D(x_real) + log(1 - D(G(z)))
            D.zero_grad()
 
            # 真实图像的损失
            real_labels = torch.full((B, 1), real_label, device=device)
            real_score  = D(x_real)                  # (B, 1)
            loss_D_real = criterion(real_score, real_labels)
            loss_D_real.backward()
 
            # 假图像的损失
            z           = torch.randn(B, latent_dim, device=device)  # (B, 100)
            x_fake      = G(z)                       # (B, 3, 64, 64)
            fake_labels = torch.full((B, 1), fake_label, device=device)
            fake_score  = D(x_fake.detach())         # detach:不反传到 G
            loss_D_fake = criterion(fake_score, fake_labels)
            loss_D_fake.backward()
 
            loss_D = loss_D_real + loss_D_fake
            opt_D.step()
 
            # ═══════ 第二步:更新 Generator G ═══════
            # 目标:最大化 log D(G(z)),即让 D 把假图打成真的
            G.zero_grad()
 
            # 用真实标签欺骗 D(非饱和目标,梯度更大)
            fake_score_for_G = D(x_fake)             # 不 detach,梯度流回 G
            loss_G = criterion(fake_score_for_G, real_labels)  # 假图配真标签
            loss_G.backward()
            opt_G.step()
 
            if i % 100 == 0:
                print(f"Epoch {epoch} Step {i}: "
                      f"loss_D={loss_D.item():.4f}, loss_G={loss_G.item():.4f}")
 
 
@torch.no_grad()
def generate_images(G, n=16, latent_dim=100):
    """
    推理:GAN 生成只需一步,无需迭代
    输入随机噪声,输出图像
    """
    G.eval()
    z = torch.randn(n, latent_dim)   # (16, 100) 随机采样
    x_fake = G(z)                    # (16, 3, 64, 64) 直接生成
    # 反归一化 [-1,1] → [0,1]
    x_fake = (x_fake + 1) / 2
    return x_fake
 
 
# ─────────────────────────────────────────────
# 2.2 WGAN-GP(更稳定的 GAN 变体)
# ─────────────────────────────────────────────
 
def gradient_penalty(D, x_real, x_fake, device):
    """
    梯度惩罚:强制 D 满足 1-Lipschitz 约束
    在真假样本的插值点上,要求梯度范数接近 1
    """
    B = x_real.size(0)
    # 随机插值系数 α ~ Uniform(0,1)
    alpha    = torch.rand(B, 1, 1, 1, device=device)
    x_interp = (alpha * x_real + (1 - alpha) * x_fake.detach()).requires_grad_(True)
 
    score    = D(x_interp)
    grad     = torch.autograd.grad(
        outputs=score, inputs=x_interp,
        grad_outputs=torch.ones_like(score),
        create_graph=True, retain_graph=True
    )[0]   # (B, C, H, W)
 
    grad_norm = grad.view(B, -1).norm(2, dim=1)   # (B,)
    # 惩罚梯度范数偏离 1 的程度
    gp = ((grad_norm - 1) ** 2).mean()
    return gp
 
 
def train_wgan_gp(G, D, dataloader, epochs=50, latent_dim=100,
                  lr=1e-4, lambda_gp=10, n_critic=5):
    """
    WGAN-GP 训练
    n_critic=5:每更新 1 次 G,先更新 5 次 D(D 需要更充分训练)
    无 BN in D(梯度惩罚与 BN 不兼容)
    """
    device = next(G.parameters()).device
    opt_D  = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.0, 0.9))
    opt_G  = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.0, 0.9))
 
    for epoch in range(epochs):
        for i, (x_real, _) in enumerate(dataloader):
            x_real = x_real.to(device)
            B      = x_real.size(0)
 
            # 更新 D(n_critic 次)
            for _ in range(n_critic):
                z      = torch.randn(B, latent_dim, device=device)
                x_fake = G(z).detach()
 
                # Wasserstein 距离估计:E[D(real)] - E[D(fake)]
                # WGAN 的 D 不是分类器,是"评分器"(Critic),无 Sigmoid
                loss_D = -D(x_real).mean() + D(x_fake).mean()
                gp     = gradient_penalty(D, x_real, x_fake, device)
                loss_D = loss_D + lambda_gp * gp
 
                D.zero_grad(); loss_D.backward(); opt_D.step()
 
            # 更新 G(1 次)
            z      = torch.randn(B, latent_dim, device=device)
            x_fake = G(z)
            loss_G = -D(x_fake).mean()   # 最大化 D(G(z))
 
            G.zero_grad(); loss_G.backward(); opt_G.step()
 
 
# 快速验证
G = DCGANGenerator(latent_dim=100)
D = DCGANDiscriminator()
z = torch.randn(4, 100)
x_fake = G(z)
print("G output:", x_fake.shape)          # (4, 3, 64, 64)
score  = D(x_fake)
print("D output:", score.shape)           # (4, 1)

3. Auto-Regressive(自回归生成)

核心思想

自回归模型是语言模型的主流范式。其核心假设非常简单:一个序列的联合概率,可以分解为每个位置的条件概率之积

比喻: 就像写作文一样,每次写下一个字,都要基于前面已经写下的所有内容来决定下一个字写什么。GPT 做的事情正是这个:给定前面的 token,预测下一个最可能的 token(Next-Token Prediction)

训练目标极其简单:最大化训练数据在模型下的似然(等价于最小化交叉熵)

Next-Token Prediction

训练时:一次并行预测所有位置的下一个 token

输入序列:  [A,  B,  C,  D,  E]   (T=5 个 token)
目标序列:  [B,  C,  D,  E,  F]   (向右移一位)

因果 mask(下三角矩阵)确保位置 t 只能看到 t 之前的 token:
  位置 1 (A):只看 A,预测 B
  位置 2 (B):看 A B,预测 C
  位置 3 (C):看 A B C,预测 D
  ...

Loss = CrossEntropy(pred[0:T-1], target[1:T])
     = 平均每个位置的预测 loss(Teacher Forcing)

"Teacher Forcing":训练时用真实的 token 做历史(而非模型的预测),
这让训练稳定且并行,但推理时必须串行生成。

推理:自回归串行生成

推理时:串行,每次只预测一个新 token

Prompt: [A, B, C]
        │
  step 1: 模型看 [A,B,C]       → 预测 x_4,采样得到 D
  step 2: 模型看 [A,B,C,D]     → 预测 x_5,采样得到 E
  step 3: 模型看 [A,B,C,D,E]   → 预测 x_6,采样得到 F
  ...
  直到生成 <EOS> 或达到最大长度

时间复杂度:O(T²·d)
  每步都要重新计算所有历史 token 的 K/V → 很慢
  → KV Cache 优化!

KV Cache:推理加速的核心技术

没有 KV Cache(每步重算所有历史):
  Step 1: 计算 [A,B,C]   的 K/V,预测第 4 个 token
  Step 2: 重算 [A,B,C,D] 的 K/V,预测第 5 个 token  ← A/B/C 的 K/V 重复计算了!
  Step 3: 重算 [A,B,C,D,E] 的 K/V ...

有 KV Cache(缓存历史 K/V,每步只算新 token 的 K/V):
  Step 1: 计算 [A,B,C]   的 K/V,存入 Cache,预测第 4 个 token
  Step 2: 只算 [D]       的 K/V,拼到 Cache,预测第 5 个 token  ✓ 不重复计算
  Step 3: 只算 [E]       的 K/V,拼到 Cache ...

KV Cache 的代价:
  显存:每层存 K/V,形状 (2, B, T, n_kv_heads, d_k)
  batch=1, 32层, T=4096, 8 heads, d_k=128 → 约 4GB(FP16)
  → 这是 LLM 推理显存的主要消耗!

GQA(分组查询注意力)本质上就是为了减少 KV Cache 大小:
  MHA:  K/V 各 32 个 head → Cache 满
  GQA:  K/V 各  8 个 head → Cache 缩减 4 倍

采样策略

Greedy(贪心):每步取概率最大的 token
  优点:确定性,快
  缺点:容易重复,不多样

Temperature Sampling:logits / T,T<1 变尖锐(更保守),T>1 变平坦(更多样)
  T→0:退化为贪心
  T→∞:变为均匀随机

Top-K Sampling:只从概率最高的 K 个 token 中采样
  K=50:常用设置,过滤掉长尾低概率词

Top-P(Nucleus)Sampling:累积概率超过 P 的最小集合
  P=0.9:动态调整候选集大小,比 Top-K 更自适应

实践:通常 Temperature + Top-P 组合使用
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
 
# ─────────────────────────────────────────────
# 3.1 极简 GPT(仅展示核心结构)
# ─────────────────────────────────────────────
 
class CausalSelfAttentionWithKVCache(nn.Module):
    """
    带 KV Cache 的因果自注意力
    训练时:一次并行处理整个序列(T > 1,需要因果 mask)
    推理时:每次只输入 1 个新 token(T = 1,不需要 mask,用 cache 拼接历史)
    """
    def __init__(self, d_model=256, n_heads=8):
        super().__init__()
        self.n_heads = n_heads
        self.d_k     = d_model // n_heads
        self.W_qkv   = nn.Linear(d_model, 3 * d_model, bias=False)
        self.W_o     = nn.Linear(d_model, d_model, bias=False)
 
    def forward(self, x, kv_cache=None, use_cache=False):
        """
        x:        (B, T, D)
        kv_cache: None 或 (K_cache, V_cache),各 (B, h, T_past, d_k)
        use_cache: 推理时设为 True
 
        返回: (output, new_kv_cache)
        """
        B, T, D = x.shape
 
        q, k, v = self.W_qkv(x).chunk(3, dim=-1)   # 各 (B, T, D)
 
        def split(t):
            return t.view(B, T, self.n_heads, self.d_k).transpose(1, 2)
            # → (B, h, T, d_k)
 
        q, k, v = split(q), split(k), split(v)
 
        # 拼接历史 KV Cache
        if kv_cache is not None:
            k_cache, v_cache = kv_cache
            k = torch.cat([k_cache, k], dim=2)  # (B, h, T_past+T, d_k)
            v = torch.cat([v_cache, v], dim=2)
 
        new_cache = (k, v) if use_cache else None
 
        # 注意力计算
        S      = k.shape[2]  # 当前总序列长度(含历史)
        scores = (q @ k.transpose(-2,-1)) / math.sqrt(self.d_k)   # (B, h, T, S)
 
        # 因果 mask:训练时必须,推理单步时 T=1 可以省略
        if T > 1:
            # 只 mask query 对应的位置(T × S 的子矩阵)
            mask = torch.triu(torch.full((T, S), float('-inf'), device=x.device), 1)
            # 如果有 cache,mask 的列从 S-T 开始才有效
            if kv_cache is not None:
                T_past = S - T
                mask   = torch.tril(torch.ones(T, S, device=x.device),
                                    diagonal=T_past) == 0
                scores = scores.masked_fill(mask.unsqueeze(0).unsqueeze(0),
                                            float('-inf'))
            else:
                scores = scores + mask
 
        attn = F.softmax(scores, dim=-1)
        out  = (attn @ v).transpose(1,2).contiguous().view(B, T, D)
        return self.W_o(out), new_cache
 
 
class GPTBlock(nn.Module):
    def __init__(self, d_model=256, n_heads=8, ffn_ratio=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn  = CausalSelfAttentionWithKVCache(d_model, n_heads)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn   = nn.Sequential(
            nn.Linear(d_model, d_model * ffn_ratio),
            nn.GELU(),
            nn.Linear(d_model * ffn_ratio, d_model)
        )
 
    def forward(self, x, kv_cache=None, use_cache=False):
        attn_out, new_cache = self.attn(self.norm1(x), kv_cache, use_cache)
        x = x + attn_out
        x = x + self.ffn(self.norm2(x))
        return x, new_cache
 
 
class MiniGPT(nn.Module):
    def __init__(self, vocab_size=1000, d_model=256, n_heads=8,
                 n_layers=4, max_len=512):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb   = nn.Embedding(max_len, d_model)
        self.blocks    = nn.ModuleList([
            GPTBlock(d_model, n_heads) for _ in range(n_layers)
        ])
        self.norm    = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
 
    def forward(self, token_ids, kv_caches=None, use_cache=False):
        """
        token_ids: (B, T) int
        kv_caches: None 或 长度为 n_layers 的列表,每个元素是 (K, V)
 
        训练时: kv_caches=None, use_cache=False
        推理时: kv_caches=past_caches, use_cache=True
        """
        B, T  = token_ids.shape
        # 位置编码:如果有 cache,起始位置是 T_past
        T_past = kv_caches[0][0].shape[2] if kv_caches else 0
        pos    = torch.arange(T_past, T_past + T, device=token_ids.device)
 
        x = self.token_emb(token_ids) + self.pos_emb(pos)   # (B, T, D)
 
        new_kv_caches = []
        for i, block in enumerate(self.blocks):
            cache = kv_caches[i] if kv_caches else None
            x, new_cache = block(x, cache, use_cache)
            new_kv_caches.append(new_cache)
 
        logits = self.lm_head(self.norm(x))   # (B, T, vocab_size)
 
        if use_cache:
            return logits, new_kv_caches
        return logits
 
 
def train_gpt(model, dataloader, epochs=5, lr=3e-4):
    """
    GPT 训练:Next-Token Prediction
    输入序列 [x_0, x_1, ..., x_{T-1}]
    预测目标 [x_1, x_2, ..., x_T](向右偏移一位)
 
    注意:Teacher Forcing——用真实 token 作为历史,不用模型自己预测的
    """
    device    = next(model.parameters()).device
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.1)
 
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for batch in dataloader:
            token_ids = batch.to(device)   # (B, T)
 
            # ── 核心训练设置 ──
            # 输入:token_ids[:, :-1]   即 [x_0, ..., x_{T-2}]  形状 (B, T-1)
            # 目标:token_ids[:, 1:]    即 [x_1, ..., x_{T-1}]  形状 (B, T-1)
            # 这样位置 t 的预测对应目标 t+1,实现 next-token prediction
 
            logits = model(token_ids[:, :-1])   # (B, T-1, vocab_size)
            targets = token_ids[:, 1:]          # (B, T-1)
 
            # CrossEntropy:把 (B, T-1, V) 展平成 (B*(T-1), V)
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),   # (B*(T-1), vocab)
                targets.reshape(-1)                    # (B*(T-1),)
            )
 
            optimizer.zero_grad()
            loss.backward()
            # 梯度裁剪:防止 Transformer 训练时梯度爆炸
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
 
            total_loss += loss.item()
 
        print(f"Epoch {epoch+1}: avg loss = {total_loss/len(dataloader):.4f}")
 
 
@torch.no_grad()
def generate_with_kv_cache(model, prompt_ids, max_new_tokens=50,
                            temperature=1.0, top_p=0.9):
    """
    带 KV Cache 的自回归生成
    核心逻辑:
      1. 先处理 prompt,建立初始 KV Cache
      2. 每步只输入 1 个新 token,利用 Cache 避免重算历史
    """
    model.eval()
    device     = prompt_ids.device
    token_ids  = prompt_ids                        # (B, T_prompt)
    kv_caches  = None
 
    # ── Phase 1: Prefill(预填充)──────────────────────────────
    # 一次性处理整个 prompt,建立 KV Cache
    logits, kv_caches = model(token_ids, kv_caches=None, use_cache=True)
    # logits: (B, T_prompt, vocab), kv_caches: 每层的 (K, V)
 
    # 取最后一个位置的 logits,采样第一个新 token
    next_token = sample_token(logits[:, -1, :], temperature, top_p)  # (B, 1)
 
    generated = [next_token]
 
    # ── Phase 2: Decode(逐步解码)────────────────────────────
    for step in range(max_new_tokens - 1):
        # 每次只输入 1 个 token,形状 (B, 1)
        logits, kv_caches = model(next_token, kv_caches=kv_caches, use_cache=True)
        # logits: (B, 1, vocab)
 
        next_token = sample_token(logits[:, -1, :], temperature, top_p)  # (B, 1)
        generated.append(next_token)
 
        # 遇到 EOS 提前停止(实际使用时需要设定 eos_token_id)
        # if (next_token == eos_token_id).all(): break
 
    return torch.cat([prompt_ids] + generated, dim=1)   # (B, T_prompt + max_new)
 
 
def sample_token(logits, temperature=1.0, top_p=0.9):
    """
    Top-P(Nucleus)采样
    logits: (B, vocab_size)
    """
    logits = logits / max(temperature, 1e-6)
    probs  = F.softmax(logits, dim=-1)        # (B, vocab)
 
    # Top-P:按概率从大到小排序,取累积概率 ≤ P 的集合
    sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
    cumsum    = sorted_probs.cumsum(dim=-1)
    # 找到超过 top_p 的位置,把后面的 token 概率置零
    remove    = cumsum - sorted_probs > top_p
    sorted_probs[remove] = 0.0
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)  # 重归一化
 
    # 采样
    sample_idx     = torch.multinomial(sorted_probs, 1)  # (B, 1) 在排序后的索引
    next_token_idx = sorted_idx.gather(1, sample_idx)    # 还原到原始词表索引
    return next_token_idx   # (B, 1)
 
 
# 验证
model = MiniGPT(vocab_size=1000, d_model=128, n_heads=4, n_layers=2)
# 训练(仿造数据)
fake_data = torch.randint(0, 1000, (32, 64))
ds     = torch.utils.data.TensorDataset(fake_data)
loader = torch.utils.data.DataLoader(ds, batch_size=8)
# 解包每个 batch
class FlatLoader:
    def __init__(self, dl): self.dl = dl
    def __iter__(self):
        for (x,) in self.dl: yield x
    def __len__(self): return len(self.dl)
 
train_gpt(model, FlatLoader(loader), epochs=2)
 
# 推理
prompt = torch.randint(0, 1000, (1, 5))
output = generate_with_kv_cache(model, prompt, max_new_tokens=10)
print("Generated shape:", output.shape)   # (1, 15)

4. Diffusion(扩散模型)

核心思想

扩散模型来自一个物理比喻:墨水滴入水中会慢慢扩散,最终变成均匀的颜色

  • 前向过程(Forward Process):把一张真实图像,逐步加入高斯噪声,经过 步后完全变成纯噪声。这个过程是固定的,不需要学习。
  • 反向过程(Reverse Process):训练一个神经网络,学会逆转噪声——从纯噪声出发,一步步去除噪声,最终还原出清晰图像。

为什么不直接让网络学 noise → image? 因为直接映射太难(噪声空间和图像空间的分布差异太大)。扩散模型把这个困难问题拆解成 T 个小步骤,每一步只需预测”从 “的微小变化,每步都容易学。

比喻: 就像雕塑家把大理石雕成雕像,不是一刀刻出来的,而是每次只去掉一点点多余的石料,T 步后得到完美作品。网络学的就是”每一步应该去掉哪块石料”。

前向过程(加噪过程)

给定真实图像 x_0,逐步加噪:
  x_t = √(ᾱ_t) · x_0 + √(1 - ᾱ_t) · ε,  ε ~ N(0, I)

其中:
  β_t:第 t 步的噪声强度(噪声调度),β_1 < β_2 < ... < β_T
  α_t = 1 - β_t
  ᾱ_t = ∏_{s=1}^{t} α_s  (累积乘积)

  t=0:   x_0 = 原始图像,完全清晰
  t=T/2: x_{T/2} 开始有明显噪声,但还能看出大致内容
  t=T:   x_T ≈ N(0, I),纯噪声,完全看不出原始内容

关键性质:
  ① 可以跳步:给定 x_0,直接一步算出任意时刻的 x_t(不需要逐步迭代)
  ② 条件分布:q(x_t | x_0) = N(x_t; √ᾱ_t·x_0, (1-ᾱ_t)·I)

三种预测目标

这是扩散模型最重要的设计选择之一:网络 究竟预测什么?

方式一:预测噪声 ε(ε-prediction,最常用,DDPM 默认)
  x_t = √ᾱ_t · x_0 + √(1-ᾱ_t) · ε
  网络学习:ε_θ(x_t, t) ≈ ε(反向解出 x_0 = (x_t - √(1-ᾱ_t)·ε_θ) / √ᾱ_t)
  Loss: ||ε - ε_θ(x_t, t)||²

方式二:预测原图 x_0(x-prediction)
  网络学习:x_θ(x_t, t) ≈ x_0
  可以通过 x_0 推出 ε,等价但训练信号不同
  Loss: ||x_0 - x_θ(x_t, t)||²

方式三:预测 v(v-prediction,Stable Diffusion v2 / EDM 使用)
  v = √ᾱ_t · ε - √(1-ᾱ_t) · x_0  (一种特殊的线性组合)
  网络学习:v_θ(x_t, t) ≈ v
  优点:在 t 很大或很小时梯度更稳定,数值条件数更好
  Loss: ||v - v_θ(x_t, t)||²

三者的等价关系(可互相推导):
  已知 x_t 和任一预测,都可以推出其他两者
  x_0 = √ᾱ_t · x_t - √(1-ᾱ_t) · v_pred
      = (x_t - √(1-ᾱ_t) · ε_pred) / √ᾱ_t

DDPM vs DDIM 推理

DDPM(Denoising Diffusion Probabilistic Models):
  ● 推理步数 = T = 1000(很慢!每次推理需要 1000 步 UNet 前向)
  ● 每步有随机性(加回部分噪声),是随机采样
  ● x_{t-1} = μ_θ(x_t, t) + σ_t · z,z ~ N(0,I)

DDIM(Denoising Diffusion Implicit Models):
  ● 推理步数可以是 50、20、10 甚至更少!
  ● 每步是确定性的(η=0 时),相同噪声 → 相同图像
  ● 核心思想:跳步!不需要走完所有 1000 步,可以直接跳从 t=1000 到 t=900

DDIM 更新公式(η=0 时完全确定性):
  x_{t-1} = √ᾱ_{t-1} · x̂_0(x_t)
           + √(1-ᾱ_{t-1} - η²σ_t²) · ε_θ(x_t, t)
           + η · σ_t · z

  其中 x̂_0(x_t) = (x_t - √(1-ᾱ_t) · ε_θ(x_t, t)) / √ᾱ_t
      是当前步预测的"去噪后原图"

  η=0:完全确定性,速度快,但多样性略低
  η=1:退化回 DDPM,有随机性,多样性高

Latent Space vs Pixel Space

Pixel Space Diffusion(原始 DDPM):
  直接在像素空间做扩散
  x_t: (B, 3, H, W)  ← 对图像像素加噪/去噪
  问题:H=W=256 → 196608 维,计算量极大

Latent Diffusion(Stable Diffusion 等):
  先用 VAE 编码到低维隐空间,在隐空间做扩散,最后解码
  ┌─────────────────────────────────────────────────┐
  │  图像 x: (B, 3, 512, 512)                       │
  │       ↓  VAE Encoder(冻结,预训练)              │
  │  隐变量 z: (B, 4, 64, 64)   ← 压缩 64 倍        │
  │       ↓  加噪/去噪(扩散过程在这里发生!)         │
  │  z_denoised: (B, 4, 64, 64)                     │
  │       ↓  VAE Decoder(冻结,预训练)              │
  │  生成图像: (B, 3, 512, 512)                      │
  └─────────────────────────────────────────────────┘

  优点:
    ✦ 隐空间更小(4×64×64 vs 3×512×512),计算量降低 ~48×
    ✦ 隐空间更"语义化",扩散效果更好
    ✦ UNet 只需要处理 64×64 的特征图

条件生成(文生图):
  UNet 接受额外的条件 c(文本 embedding)
  通过 Cross-Attention 把文本信息注入 UNet 中间层
  ε_θ(z_t, t, c) → 条件去噪
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
 
# ─────────────────────────────────────────────
# 4.1 噪声调度器(Noise Scheduler)
# ─────────────────────────────────────────────
 
class DDPMScheduler:
    """
    管理扩散过程的所有参数:β_t, α_t, ᾱ_t
    以及前向加噪和反向去噪的步骤
    """
    def __init__(self, T=1000, beta_start=1e-4, beta_end=0.02,
                 schedule='linear'):
        self.T = T
 
        # 噪声调度:从小到大的 β_t 序列
        if schedule == 'linear':
            # 线性:最简单,DDPM 原论文使用
            betas = torch.linspace(beta_start, beta_end, T)
        elif schedule == 'cosine':
            # 余弦:在 t 接近 0 时变化更平滑,生成质量更好
            steps = T + 1
            x     = torch.linspace(0, T, steps)
            alphas_cumprod = torch.cos(((x/T) + 0.008) / 1.008 * math.pi/2) ** 2
            alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
            betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
            betas = betas.clamp(0, 0.999)
        elif schedule == 'sqrt':
            # 平方根调度:适合一致性模型
            betas = torch.linspace(beta_start**0.5, beta_end**0.5, T)**2
 
        self.betas  = betas                          # (T,)
        self.alphas = 1.0 - betas                    # (T,)
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)   # ᾱ_t (T,)
        self.sqrt_alphas_cumprod       = self.alphas_cumprod.sqrt()
        self.sqrt_one_minus_alphas_cumprod = (1 - self.alphas_cumprod).sqrt()
 
    def q_sample(self, x0, t, noise=None):
        """
        前向加噪(Forward Process):一步算出任意时刻的 x_t
        x_t = √ᾱ_t · x_0 + √(1-ᾱ_t) · ε
        输入: x0 (B, C, H, W), t (B,) 时间步索引
        输出: x_t (B, C, H, W)
        """
        if noise is None:
            noise = torch.randn_like(x0)
 
        # 取 t 对应的调度参数,并 reshape 为 (B, 1, 1, 1) 便于广播
        sqrt_a  = self.sqrt_alphas_cumprod[t].view(-1,1,1,1)
        sqrt_1a = self.sqrt_one_minus_alphas_cumprod[t].view(-1,1,1,1)
 
        x_t = sqrt_a * x0 + sqrt_1a * noise
        return x_t, noise   # 返回 x_t 和 ε,训练时 ε 是监督目标
 
    def predict_x0_from_eps(self, x_t, t, eps_pred):
        """从预测的噪声 ε_pred 还原 x_0"""
        sqrt_a  = self.sqrt_alphas_cumprod[t].view(-1,1,1,1)
        sqrt_1a = self.sqrt_one_minus_alphas_cumprod[t].view(-1,1,1,1)
        return (x_t - sqrt_1a * eps_pred) / sqrt_a
 
    def predict_x0_from_v(self, x_t, t, v_pred):
        """从预测的 v 还原 x_0(v-prediction)"""
        sqrt_a  = self.sqrt_alphas_cumprod[t].view(-1,1,1,1)
        sqrt_1a = self.sqrt_one_minus_alphas_cumprod[t].view(-1,1,1,1)
        # v = √ᾱ·ε - √(1-ᾱ)·x_0  →  x_0 = √ᾱ·x_t - √(1-ᾱ)·v
        return sqrt_a * x_t - sqrt_1a * v_pred
 
    def predict_eps_from_v(self, x_t, t, v_pred):
        """从预测的 v 还原 ε"""
        sqrt_a  = self.sqrt_alphas_cumprod[t].view(-1,1,1,1)
        sqrt_1a = self.sqrt_one_minus_alphas_cumprod[t].view(-1,1,1,1)
        # ε = √ᾱ·v + √(1-ᾱ)·x_t
        return sqrt_a * v_pred + sqrt_1a * x_t
 
    def ddpm_step(self, x_t, t_idx, eps_pred):
        """
        DDPM 反向一步(随机):x_t → x_{t-1}
        """
        t    = torch.tensor([t_idx])
        beta = self.betas[t_idx]
        a    = self.alphas[t_idx]
        a_bar = self.alphas_cumprod[t_idx]
        a_bar_prev = self.alphas_cumprod[t_idx-1] if t_idx > 0 \
                     else torch.tensor(1.0)
 
        # 均值(DDPM 论文 Eq. 11)
        coef1 = beta * a_bar_prev.sqrt() / (1 - a_bar)
        coef2 = (1 - a_bar_prev) * a.sqrt() / (1 - a_bar)
        x0_pred = self.predict_x0_from_eps(
            x_t.unsqueeze(0), t, eps_pred.unsqueeze(0)
        ).squeeze(0)
        mu = coef1 * x0_pred + coef2 * x_t
 
        # 方差(固定为 β̃_t)
        sigma = ((1 - a_bar_prev) / (1 - a_bar) * beta).sqrt()
 
        z = torch.randn_like(x_t) if t_idx > 0 else 0
        return mu + sigma * z   # x_{t-1}
 
 
# ─────────────────────────────────────────────
# 4.2 简化版 UNet(去噪网络的主体)
# ─────────────────────────────────────────────
 
class SinusoidalTimeEmb(nn.Module):
    """
    时间步 t 的编码:类似 Transformer 的位置编码
    网络需要知道当前处于哪个噪声级别,才能做出正确的预测
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
 
    def forward(self, t):
        # t: (B,)  整数时间步
        half   = self.dim // 2
        freqs  = torch.exp(
            -math.log(10000) * torch.arange(half, device=t.device) / half
        )
        angles = t[:, None].float() * freqs[None]          # (B, dim/2)
        emb    = torch.cat([angles.sin(), angles.cos()], dim=-1)  # (B, dim)
        return emb
 
 
class ResBlock(nn.Module):
    """UNet 中的基本残差块,接受时间步 embedding 作为额外输入"""
    def __init__(self, in_ch, out_ch, time_dim):
        super().__init__()
        self.conv1   = nn.Sequential(nn.GroupNorm(8, in_ch), nn.SiLU(),
                                     nn.Conv2d(in_ch, out_ch, 3, padding=1))
        # 时间步 embedding 注入:通过 scale-shift(仿射变换)方式
        self.time_mlp = nn.Sequential(nn.SiLU(),
                                      nn.Linear(time_dim, out_ch * 2))
        self.conv2   = nn.Sequential(nn.GroupNorm(8, out_ch), nn.SiLU(),
                                     nn.Conv2d(out_ch, out_ch, 3, padding=1))
        self.skip    = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch \
                       else nn.Identity()
 
    def forward(self, x, t_emb):
        h = self.conv1(x)
        # time_mlp 输出 scale 和 shift,做特征调制
        scale, shift = self.time_mlp(t_emb).chunk(2, dim=-1)
        scale = scale.view(*scale.shape, 1, 1)
        shift = shift.view(*shift.shape, 1, 1)
        h = h * (1 + scale) + shift    # 仿射变换,注入时间信息
        h = self.conv2(h)
        return h + self.skip(x)        # 残差
 
 
class SimpleUNet(nn.Module):
    """
    扩散模型的去噪网络
    输入: x_t (B, C, H, W) + 时间步 t (B,)
    输出: 预测的噪声 ε,或 x_0,或 v(取决于 prediction_type)
    形状不变:输出 (B, C, H, W)
    """
    def __init__(self, in_channels=3, base_channels=64,
                 channel_mults=(1, 2, 4, 8), time_dim=256,
                 prediction_type='epsilon'):
        super().__init__()
        self.prediction_type = prediction_type
        channels = [base_channels * m for m in channel_mults]
 
        # 时间编码:t → embedding 向量
        self.time_emb = nn.Sequential(
            SinusoidalTimeEmb(base_channels),
            nn.Linear(base_channels, time_dim),
            nn.SiLU(),
            nn.Linear(time_dim, time_dim)
        )
 
        # 初始卷积
        self.init_conv = nn.Conv2d(in_channels, channels[0], 3, padding=1)
 
        # Encoder(下采样)
        self.downs   = nn.ModuleList()
        self.downpools = nn.ModuleList()
        in_ch = channels[0]
        for out_ch in channels[1:]:
            self.downs.append(ResBlock(in_ch, out_ch, time_dim))
            self.downpools.append(nn.Conv2d(out_ch, out_ch, 4, 2, 1))  # 步长2下采样
            in_ch = out_ch
 
        # Bottleneck
        self.mid = ResBlock(channels[-1], channels[-1], time_dim)
 
        # Decoder(上采样)
        self.ups   = nn.ModuleList()
        self.uppools = nn.ModuleList()
        for out_ch in reversed(channels[:-1]):
            self.uppools.append(nn.ConvTranspose2d(in_ch, in_ch, 4, 2, 1))
            self.ups.append(ResBlock(in_ch + out_ch, out_ch, time_dim))  # +out_ch 是 skip
            in_ch = out_ch
 
        self.final = nn.Sequential(
            nn.GroupNorm(8, channels[0]),
            nn.SiLU(),
            nn.Conv2d(channels[0], in_channels, 3, padding=1)  # 输出和输入同形状
        )
 
    def forward(self, x_t, t):
        """
        x_t: (B, C, H, W)  含噪图像
        t:   (B,)           时间步(整数)
        返回: (B, C, H, W)  预测目标(ε / x_0 / v)
        """
        t_emb = self.time_emb(t)          # (B, time_dim)
        x = self.init_conv(x_t)
 
        # Encoder + 保存 skip 特征
        skips = []
        for down, pool in zip(self.downs, self.downpools):
            x = down(x, t_emb)
            skips.append(x)
            x = pool(x)
 
        x = self.mid(x, t_emb)
 
        # Decoder + 融合 skip 特征
        for up_pool, up, skip in zip(self.uppools, self.ups, reversed(skips)):
            x = up_pool(x)
            if x.shape != skip.shape:
                x = F.interpolate(x, size=skip.shape[2:])
            x = torch.cat([x, skip], dim=1)
            x = up(x, t_emb)
 
        return self.final(x)
 
 
# ─────────────────────────────────────────────
# 4.3 DDPM 训练
# ─────────────────────────────────────────────
 
def train_ddpm(unet, scheduler, dataloader, epochs=100, lr=2e-4,
               prediction_type='epsilon'):
    """
    DDPM 训练:
    ① 随机采样真实图像 x_0
    ② 随机采样时间步 t
    ③ 随机采样噪声 ε
    ④ 用前向过程得到 x_t
    ⑤ 用 UNet 预测 ε(或 x_0 或 v)
    ⑥ 计算 MSE 损失
    """
    device    = next(unet.parameters()).device
    optimizer = torch.optim.AdamW(unet.parameters(), lr=lr)
    T         = scheduler.T
 
    for epoch in range(epochs):
        total_loss = 0
        for x0, _ in dataloader:
            x0 = x0.to(device)       # (B, 3, H, W) 真实图像,归一化到 [-1, 1]
            B  = x0.shape[0]
 
            # ① 均匀采样时间步 t ~ Uniform{1, ..., T}
            t = torch.randint(0, T, (B,), device=device)   # (B,)
 
            # ② 随机采样噪声
            noise = torch.randn_like(x0)                   # (B, 3, H, W)
 
            # ③ 前向加噪:得到 x_t
            x_t, _ = scheduler.q_sample(x0, t, noise)     # (B, 3, H, W)
 
            # ④ UNet 预测
            pred = unet(x_t, t)                             # (B, 3, H, W)
 
            # ⑤ 计算损失(根据 prediction_type 选择监督目标)
            if prediction_type == 'epsilon':
                target = noise                              # 预测噪声 ε
            elif prediction_type == 'x0':
                target = x0                                 # 预测原图 x_0
            elif prediction_type == 'v':
                # v = √ᾱ_t · ε - √(1-ᾱ_t) · x_0
                sqrt_a  = scheduler.sqrt_alphas_cumprod[t].view(-1,1,1,1)
                sqrt_1a = scheduler.sqrt_one_minus_alphas_cumprod[t].view(-1,1,1,1)
                target  = sqrt_a * noise - sqrt_1a * x0
 
            loss = F.mse_loss(pred, target)
 
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
            optimizer.step()
            total_loss += loss.item()
 
        if epoch % 10 == 0:
            print(f"Epoch {epoch}: loss = {total_loss/len(dataloader):.5f}")
 
 
# ─────────────────────────────────────────────
# 4.4 DDPM 推理(随机采样,1000步)
# ─────────────────────────────────────────────
 
@torch.no_grad()
def ddpm_sample(unet, scheduler, shape, device, prediction_type='epsilon'):
    """
    DDPM 标准采样:从纯噪声开始,逐步去噪 T 步
 
    shape: (B, C, H, W)  要生成的图像形状
    每一步都有随机性(加回 σ_t · z),所以每次生成结果不同
    """
    unet.eval()
    T  = scheduler.T
 
    # 从纯高斯噪声开始
    x_t = torch.randn(*shape, device=device)   # x_T ~ N(0, I)
 
    for t_idx in reversed(range(T)):           # t = T-1, T-2, ..., 1, 0
        t_batch = torch.full((shape[0],), t_idx, device=device, dtype=torch.long)
 
        # UNet 预测
        pred = unet(x_t, t_batch)   # (B, C, H, W)
 
        # 根据 prediction_type 转换为 ε
        if prediction_type == 'epsilon':
            eps_pred = pred
        elif prediction_type == 'x0':
            # 从 x_0 推出 ε
            sqrt_a  = scheduler.sqrt_alphas_cumprod[t_idx]
            sqrt_1a = scheduler.sqrt_one_minus_alphas_cumprod[t_idx]
            eps_pred = (x_t - sqrt_a * pred) / sqrt_1a
        elif prediction_type == 'v':
            eps_pred = scheduler.predict_eps_from_v(
                x_t, torch.tensor([t_idx]), pred.unsqueeze(0)
            ).squeeze(0)
 
        # DDPM 反向一步
        x_t = scheduler.ddpm_step(x_t, t_idx, eps_pred)
 
    return x_t.clamp(-1, 1)   # 裁剪到合理范围
 
 
# ─────────────────────────────────────────────
# 4.5 DDIM 推理(确定性,支持快速采样)
# ─────────────────────────────────────────────
 
@torch.no_grad()
def ddim_sample(unet, scheduler, shape, device,
                n_steps=50, eta=0.0, prediction_type='epsilon'):
    """
    DDIM 快速确定性采样
 
    n_steps: 推理步数(可以远小于训练的 T=1000,如 50 步)
    eta=0.0: 完全确定性(相同 noise → 相同图像)
    eta=1.0: 退化为 DDPM(随机)
 
    关键:在 T 个时间步中等间隔选取 n_steps 个,跳步推理
    """
    unet.eval()
    T = scheduler.T
 
    # 选取等间隔的时间步序列(从大到小)
    timesteps = torch.linspace(T-1, 0, n_steps, dtype=torch.long)
 
    x_t = torch.randn(*shape, device=device)   # 从纯噪声开始
 
    for i, t_idx in enumerate(timesteps):
        t_idx_int = t_idx.item()
        t_batch   = torch.full((shape[0],), t_idx_int, device=device, dtype=torch.long)
 
        # UNet 预测
        pred = unet(x_t, t_batch)
 
        # 统一转换为 ε 预测和 x_0 预测
        if prediction_type == 'epsilon':
            eps_pred = pred
            sqrt_a   = scheduler.sqrt_alphas_cumprod[t_idx_int]
            sqrt_1a  = scheduler.sqrt_one_minus_alphas_cumprod[t_idx_int]
            x0_pred  = (x_t - sqrt_1a * eps_pred) / sqrt_a
        elif prediction_type == 'x0':
            x0_pred  = pred
            sqrt_a   = scheduler.sqrt_alphas_cumprod[t_idx_int]
            sqrt_1a  = scheduler.sqrt_one_minus_alphas_cumprod[t_idx_int]
            eps_pred = (x_t - sqrt_a * x0_pred) / sqrt_1a
        elif prediction_type == 'v':
            x0_pred  = scheduler.predict_x0_from_v(x_t, torch.tensor([t_idx_int]), pred)
            sqrt_a   = scheduler.sqrt_alphas_cumprod[t_idx_int]
            sqrt_1a  = scheduler.sqrt_one_minus_alphas_cumprod[t_idx_int]
            eps_pred = (x_t - sqrt_a * x0_pred) / sqrt_1a
 
        # 确定下一个时间步 t_prev
        if i + 1 < len(timesteps):
            t_prev = timesteps[i + 1].item()
        else:
            t_prev = 0
 
        # DDIM 更新公式
        a_t    = scheduler.alphas_cumprod[t_idx_int]
        a_prev = scheduler.alphas_cumprod[t_prev] if t_prev > 0 \
                 else torch.tensor(1.0)
 
        sigma_t = eta * ((1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev)).sqrt()
 
        # x_{t-1} = √ᾱ_{t-1} · x_0_pred
        #         + √(1-ᾱ_{t-1}-σ²) · ε_pred   ← "指向噪声的方向"
        #         + σ · z                         ← 可选随机项
        x_t = (a_prev.sqrt() * x0_pred
               + (1 - a_prev - sigma_t**2).clamp(min=0).sqrt() * eps_pred
               + sigma_t * torch.randn_like(x_t))
 
    return x_t.clamp(-1, 1)
 
 
# ─────────────────────────────────────────────
# 4.6 Latent Diffusion(潜在扩散模型骨架)
# ─────────────────────────────────────────────
 
class SimpleVAE(nn.Module):
    """
    极简 VAE:图像 ↔ 隐变量
    真实使用时(如 Stable Diffusion)的 VAE 是预训练好后冻结的
    """
    def __init__(self, img_channels=3, latent_channels=4, base_ch=64):
        super().__init__()
        # Encoder:图像 → 均值 + 方差
        self.encoder = nn.Sequential(
            nn.Conv2d(img_channels, base_ch, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(base_ch, base_ch*2, 4, stride=2, padding=1),   # ÷2
            nn.SiLU(),
            nn.Conv2d(base_ch*2, base_ch*4, 4, stride=2, padding=1), # ÷4
            nn.SiLU(),
            nn.Conv2d(base_ch*4, latent_channels*2, 3, padding=1)    # mu + logvar
        )
        # Decoder:隐变量 → 图像
        self.decoder = nn.Sequential(
            nn.Conv2d(latent_channels, base_ch*4, 3, padding=1),
            nn.SiLU(),
            nn.ConvTranspose2d(base_ch*4, base_ch*2, 4, stride=2, padding=1), # ×2
            nn.SiLU(),
            nn.ConvTranspose2d(base_ch*2, base_ch, 4, stride=2, padding=1),   # ×4
            nn.SiLU(),
            nn.Conv2d(base_ch, img_channels, 3, padding=1),
            nn.Tanh()
        )
 
    def encode(self, x):
        """图像 → (μ, σ)"""
        out = self.encoder(x)                   # (B, 2C, H/4, W/4)
        mu, logvar = out.chunk(2, dim=1)
        return mu, logvar
 
    def reparameterize(self, mu, logvar):
        """重参数化采样:z = μ + ε·σ"""
        std = (0.5 * logvar).exp()
        eps = torch.randn_like(std)
        return mu + eps * std
 
    def decode(self, z):
        return self.decoder(z)
 
    def forward(self, x):
        mu, logvar = self.encode(x)
        z          = self.reparameterize(mu, logvar)
        x_recon    = self.decode(z)
        return x_recon, mu, logvar
 
 
def latent_diffusion_train_step(vae, unet, scheduler, x0, t, device,
                                prediction_type='epsilon'):
    """
    Latent Diffusion 单步训练
 
    关键区别:扩散过程发生在隐空间 z,而非像素空间 x
    """
    # ① 编码到隐空间(VAE 冻结,不计算梯度)
    with torch.no_grad():
        mu, logvar = vae.encode(x0)
        # 训练时通常用确定性的 μ,不加随机扰动
        z0 = mu   # (B, 4, H/8, W/8)  ← 尺寸是原图的 1/8×1/8
 
    # ② 在隐空间加噪
    noise = torch.randn_like(z0)
    z_t, _ = scheduler.q_sample(z0, t, noise)   # (B, 4, H/8, W/8)
 
    # ③ UNet 在隐空间预测(输入输出都是 4 通道的隐变量)
    pred = unet(z_t, t)   # (B, 4, H/8, W/8)
 
    # ④ 损失(和像素空间完全一样,只是作用在 z 上)
    if prediction_type == 'epsilon':
        target = noise
    elif prediction_type == 'x0':
        target = z0
    elif prediction_type == 'v':
        sqrt_a  = scheduler.sqrt_alphas_cumprod[t].view(-1,1,1,1)
        sqrt_1a = scheduler.sqrt_one_minus_alphas_cumprod[t].view(-1,1,1,1)
        target  = sqrt_a * noise - sqrt_1a * z0
 
    return F.mse_loss(pred, target)
 
 
@torch.no_grad()
def latent_diffusion_sample(vae, unet, scheduler, n=4,
                             latent_shape=(4, 8, 8), n_steps=50, device='cpu'):
    """
    Latent Diffusion 采样:
    ① 在隐空间采样(DDIM 快速)
    ② VAE 解码回像素空间
    """
    # 在隐空间生成
    z_shape  = (n, *latent_shape)
    z_sample = ddim_sample(unet, scheduler, z_shape, device, n_steps=n_steps)
    # z_sample: (n, 4, H/8, W/8)
 
    # VAE 解码回图像
    images = vae.decode(z_sample)   # (n, 3, H, W)
    return (images + 1) / 2          # [-1,1] → [0,1]
 
 
# ── 快速验证 ──
scheduler = DDPMScheduler(T=1000, schedule='cosine')
unet = SimpleUNet(in_channels=3, base_channels=32,
                  channel_mults=(1, 2, 4), time_dim=128)
 
# 验证前向过程
x0 = torch.randn(2, 3, 32, 32)    # 假图像
t  = torch.randint(0, 1000, (2,))
x_t, noise = scheduler.q_sample(x0, t)
print("x_t shape:", x_t.shape)     # (2, 3, 32, 32)
 
# 验证 UNet
pred = unet(x_t, t)
print("pred shape:", pred.shape)   # (2, 3, 32, 32)
 
# 验证 DDIM 采样
samples = ddim_sample(unet, scheduler, (2, 3, 32, 32), device='cpu', n_steps=10)
print("samples shape:", samples.shape)  # (2, 3, 32, 32)

5. 四大范式对比总结

训练目标对比

范式训练目标损失函数输入输出
CNN最小化预测误差CE / MSE / Dice(B, C, H, W)(B, classes) 或 (B, C, H, W)
GANG 骗过 D,D 识破 GBCE (DCGAN) / Wasserstein (WGAN)z: (B, latent)G: 图像 (B, C, H, W);D: 概率 (B, 1)
AR最大化序列似然Cross-Entropytoken_ids: (B, T)logits: (B, T, V)
Diffusion预测噪声/原图/vMSEx_t: (B,C,H,W), t: (B,)ε / x₀ / v: (B, C, H, W)

推理流程对比

范式推理步数随机性速度特点
CNN1 步无(确定性)极快直接映射
GAN1 步有(采样 z)极快一步生成
ART 步(T=生成长度)有(token 采样)慢(串行)KV Cache 加速
Diffusion (DDPM)1000 步有(随机去噪)极慢质量高
Diffusion (DDIM)20~50 步可调(η)较快速度质量平衡

生成能力对比

CNN(判别式):
  ✦ 不擅长生成,主要用于判别/回归任务
  ✦ 生成类任务需要特殊结构(超分、去噪)

GAN:
  ✦ 快速生成高质量图像(一步到位)
  ✦ 训练不稳定,可能 mode collapse
  ✦ 难以控制生成内容(条件 GAN 部分解决)

Auto-Regressive:
  ✦ 文本生成的绝对主流(GPT 系列)
  ✦ 支持任意长度序列生成
  ✦ 推理串行,速度受限于序列长度
  ✦ 可结合图像 token 做图像生成(VQVAE + GPT)

Diffusion:
  ✦ 目前图像生成质量最高(SD、DALL-E 3)
  ✦ 训练稳定,无 mode collapse
  ✦ 推理需多步,有加速方案(DDIM、LCM)
  ✦ 易于条件控制(CFG、ControlNet)

一图看懂四大范式

                        训练信号来源
                              │
    ┌─────────────────────────┼──────────────────────────────┐
    │                         │                              │
    ▼                         ▼                              ▼
 标注标签               对抗博弈                        数据本身
(supervised)          (adversarial)                (self-supervised)
    │                         │                          │       │
    ▼                         ▼                          ▼       ▼
   CNN                       GAN                        AR    Diffusion
  (判别)                    (生成)                  (自回归)  (生成)
  确定性                   一步生成                  串行生成  多步去噪
  推理快                   训练难                    训练快   训练稳

参考:DDPM (Ho et al., 2020) · DDIM (Song et al., 2021) · GPT (Radford et al.) · DCGAN (Radford et al., 2016)