第 1 步:建立”什么叫好剪枝信号”的直觉
在看任何具体信号之前,先想清楚一件事:剪枝信号本质上是一个排序函数,输入 N’ 个 token,输出 N’ 个分数,然后 top-k。
所以一个好的信号只需要满足两个条件:
- 相关 token 的分数 > 不相关 token 的分数(绝对值不重要,排序才重要)
- 计算便宜(否则剪枝省下的时间都被打分吃掉)
这个视角很关键,因为它告诉你不需要纠结 score 的 scale / normalization,只要 ranking 对就行。很多初学者会陷入”我的分数数值对不对”的纠结,其实完全不用管。
第 2 步:理解最基础的信号——cross-modal cosine similarity
这是你必须第一个实现和理解的信号。不是因为它最好,而是因为它是所有复杂信号的基线。如果一个复杂信号打不过 cosine,那基本可以判定该信号没价值。
信号定义
对每个 visual token v_i ∈ R^D(D 是 LLM hidden dim)
对 query pooled embedding q ∈ R^D
score_i = cos(v_i, q) = (v_i · q) / (||v_i|| × ||q||)
为什么它有效
因为 projector 的训练目标之一就是让 visual token 落到和 LLM text embedding 可以做语义比较的空间里。LLaVA 等模型的 projector 是用 caption 数据训过的,图文语义是对齐的。所以 cosine 衡量的就是”这个 visual token 在语义上和 query 有多接近”。
为什么它会失效
有三种典型失效模式,你必须都能识别出来:
失效 1:Modality gap 即使 projector 对齐了,visual embedding 和 text embedding 仍然倾向于聚在 embedding 空间的不同区域(这是 CLIP 论文里就讨论过的现象)。cosine 在这种情况下不稳定——可能所有 visual token 的 cosine 值都挤在 0.2-0.3 这个窄区间,区分度差。
诊断方法:画 scores 的直方图。如果方差很小(比如 std < 0.05),说明 modality gap 影响大。
失效 2:Query 语义太宽 “What is happening in the video?” 这种 query 的 embedding 几乎和所有 visual token 都有相似度,cosine 没有判别力。
诊断方法:看 top-k 和 bottom-k 的分数差距。差距小说明 query 太宽。
失效 3:High-norm visual token bias 有些 visual token 的 L2 norm 特别大(比如图像里一些特定 patch),即使做了 cosine normalization,它们的 dot product 仍然偏高。这会让某些 token 永远被选中,无论 query 是什么。
诊断方法:统计每个 visual token 被选中的频率(across 多个 query)。如果某几个 token 被选中的频率接近 100%,就是这个问题。
实现要点
几个容易踩坑的地方:
- Query 的 embedding 必须用 LLM 的 embedding layer,不是 CLIP 或其他 text encoder
- Query 的 pooling:用 mean pool 即可,不要用 last token(那个往往是 ”?” 或句号,语义弱)
- 忽略 special tokens:BOS / system prompt / “Question:” 这种 template token 不要进 pooling,否则会稀释信号
- Normalization 放在最后:先 dot product 再 normalize,数值稳定性更好
学习任务
实现这个信号,然后做两个诊断实验:
实验 A:在同一张图上用不同 query 打分,看 top-k 选择是否真的随 query 变化。如果变化不明显,说明 cosine 信号太弱。
实验 B:固定 query,看 top-k 选出的 visual token 在原图上的空间分布。用图像 patch 的空间坐标可视化。如果它们都集中在某一个区域,要么 query 就是指那个区域(好现象),要么说明 high-norm bias(坏现象)。
第 3 步:理解 attention-based 信号以及为什么多数人选它
虽然你之前决定不用 attention,但你必须理解它,否则你无法在论文里证明”不用 attention 也行”。
信号定义
在 LLM 的第 K 层(比如第 8 层),抽取 query token 对 visual token 的 attention weight:
attn_{i,j} = softmax_j(Q_i · K_j / sqrt(d))
i: query token index
j: visual token index
visual_importance_j = Σ_i attn_{i,j} # 或 max_i
为什么它比 cosine 强
关键原因:attention 是在 LLM 自己已经做过一层语义变换之后才算的。cosine 用的是 projector 刚出来的原始 visual embedding,LLM 根本还没”理解”它。而 attention 在第 K 层(K > 0),visual token 已经和 query token、system prompt 在前 K-1 层做过信息交互,表示已经 contextualized 了。
这就像你问一个人”今天天气怎么样”,他先听完你的话(contextualization)再判断重要的词,而不是看见词就立即打分。
为什么选第 8 层左右
这是 FlexSelect / DyToK 等论文的实测结论:
- 太浅(第 1-2 层):跨模态融合不充分,attention 信号还没”理解” query
- 太深(第 20+ 层):视觉 token 已经被压缩进少数”信息枢纽 token”,attention 趋向集中,失去细粒度
- 中早层(第 6-10 层):cross-modal 融合刚完成但尚未过度压缩,判别力最好
为什么你不用它
主要原因工程上已经讨论过:
- FlashAttention 不暴露 weight,要 fall back 到慢实现
- 第 K 层的 attention 要真的跑 K 层 LLM 才能拿到,不够”轻”
- 跨 head 聚合方式没定论
学习任务
即使不用,也要亲手实现一次 attention-based 信号,然后对比 cosine:
实验 C:在同一组 (image, query) 上同时算 cosine 分数和第 8 层的 attention 分数,计算两个排名的 Spearman correlation。如果 correlation > 0.7,说明 cosine 已经捕捉到了 attention 要捕捉的大部分信号,你用 cosine 就够了。如果 < 0.5,说明 attention 捕捉到了 cosine 没有的东西,你需要找一个中间方案(比如下面的 probing)。
第 4 步:理解 hidden state norm 这种”旁路信号”
这是一类被很多人忽视但非常有用的信号:不依赖 query,只看 visual token 本身在某层的状态。
信号定义
用 LLM 前 K 层 forward 一遍
取第 K 层的 hidden state h_j ∈ R^D
norm_j = ||h_j||_2
选 top-k 时,norm 大的保留。
为什么它有效
这是 LLM 的一个经验规律:“信息量大”的 token 在深层通常 norm 更大,因为 residual connection 不断累加信息。Massive activation 这个现象在 LLM 里普遍存在。
更精妙的是:attention sink token(BOS 附近的特殊 token)通常 norm 也大,这会污染信号。但有意思的是,如果你限定只看 visual token,attention sink 的干扰就没了——因为它们不在你的筛选范围内。
它是 query-agnostic 的,为什么能做 query-aware 剪枝
它单独用不能做 query-aware。但它可以:
- 作为 cosine 的修正项:
final_score = cosine + λ × norm,让”语义相关 + 信息量大”的 token 优先 - 作为前段冗余剪枝的信号:在 query 未知时用它来做初筛
- 作为 anti-bias 信号:和 cosine 结合后,减少 cosine 的 modality gap 问题
学习任务
实验 D:只用 norm 做剪枝,看性能。通常会比 cosine 差但不会差太多。这证明了”即使完全不看 query 也能剪到不错的子集”,也就是你前段 diversity-based 剪枝的理论基础。
第 5 步:组合信号——真正的训练无关方案
单一信号都有缺陷,你真正要用的是组合信号。这是很多论文避而不谈的”调参艺术”,但其实有规律可循。
组合策略 1:线性组合
final_score = α × cosine_sim + β × norm + γ × position_bias
其中 position_bias 是可选的,用来降低 boundary / corner patch 的权重(它们往往是无意义的背景)。
α, β, γ 怎么选:
- 不要在 test set 上调,会过拟合
- 用一个小的 validation set(比如 MMDU 的 200 个样本)做 grid search
- 经验值:α ≈ 1.0, β ≈ 0.1, γ ≈ -0.3
组合策略 2:rank fusion
不直接组合分数,而是组合排名:
rank_cosine_j = rank of token j by cosine
rank_norm_j = rank of token j by norm
final_rank_j = (rank_cosine_j + rank_norm_j) / 2
这个方法的好处是对每个信号的 scale 不敏感,更稳定。推荐用这个。
组合策略 3:Multiplicative gating
final_score = cosine_sim × sigmoid(norm / threshold)
这种是”cosine 为主,norm 作为 gate 抑制低信息 token”。适合 cosine 本身已经够好、只想过滤明显的 noise token 的场景。
学习任务
实验 E:对比三种组合策略和单一 cosine,在你的评测 benchmark 上报性能。通常 rank fusion > multiplicative > linear > single。这个对比可以直接进论文消融。
第 6 步:引入多轮场景的信号演化
前面讨论的都是单轮。多轮场景下,信号本身没变,但**“query”这个输入变了**。
累积 query 的实现
class CumulativeQueryEncoder:
def __init__(self, llm_embed, alpha=0.3):
self.llm_embed = llm_embed # LLM 的 embedding layer
self.alpha = alpha
self.history_embeddings = []
def update(self, new_query_text):
query_ids = tokenizer(new_query_text).input_ids
query_emb = self.llm_embed(query_ids).mean(dim=0) # [D]
self.history_embeddings.append(query_emb)
def get_query_vector(self):
if len(self.history_embeddings) == 0:
return None
if len(self.history_embeddings) == 1:
return self.history_embeddings[0]
# 当前 query + 历史累积
current = self.history_embeddings[-1]
history = torch.stack(self.history_embeddings[:-1]).mean(0)
return self.alpha * history + (1 - self.alpha) * currentalpha 怎么调
alpha 控制”历史 vs 当前” 的权重:
- alpha = 0:完全用当前 query,不管历史。在 topic shift 时好,但历史相关的 follow-up 问题会丢信息
- alpha = 1:完全用历史平均,忽略当前。永远稳但不够 query-specific
- alpha = 0.3:偏向当前,但历史提供 context。经验上合理起点
Topic shift 的处理
多轮场景最怕的是用户突然换话题,这时候累积 query 会把新 topic 稀释。可以加一个简单检测:
def should_reset_history(new_query_emb, history_emb):
similarity = F.cosine_similarity(new_query_emb, history_emb.mean(0), dim=0)
return similarity < 0.3 # 阈值经验值
if should_reset_history(current, history):
# 清空历史,重新开始累积
cumulative = current
else:
cumulative = alpha * history + (1-alpha) * current学习任务
实验 F:在 MMDU 上对比三种 query 策略:
- 只用当前 query
- 累积 query(固定 alpha = 0.3)
- 累积 query + topic shift 重置
按 turn index 分层报性能(turn 1 / turn 2-5 / turn 6+)。你会看到策略 1 在 turn 1 最好但后续掉点,策略 3 在高 turn 上最稳。这是一个能讲故事的实验。
第 7 步:Diversity 作为正则项——防止信号塌陷
前面所有信号都是单独给 token 打分的,存在一个严重问题:top-k 可能全部来自同一个区域。
比如 query 是”图中有几个人”,cosine 把所有人脸 patch 都打高分,top-k 全选脸,完全丢失了身体 / 背景 / 场景信息。但”几个人”这个问题需要全局计数,只看脸是不够的。
解决方案:Greedy diverse top-k
不要一次性 topk,而是贪心地选:
def diverse_topk(scores, tokens, k, lambda_div=0.5):
selected = [argmax(scores)] # 先选分数最高的
for _ in range(k - 1):
# 对每个候选,计算 "query 相关性 - 与已选集合的相似度惩罚"
candidates = [i for i in range(len(scores)) if i not in selected]
best_score = -inf
best_idx = None
for i in candidates:
relevance = scores[i]
diversity_penalty = max(
cosine(tokens[i], tokens[j]) for j in selected
)
combined = relevance - lambda_div * diversity_penalty
if combined > best_score:
best_score = combined
best_idx = i
selected.append(best_idx)
return selected为什么这个特别重要
对你的多轮场景,这是关键中的关键。为什么?
单轮:query 明确,top-k 集中也没问题,因为都是相关 token。 多轮:前一轮的 query A 让某类 token 分数高,但后一轮 query B 可能需要完全不同的 token。如果前一轮选出的 N” 已经过度集中,那么这 N” 个 token 的 KV 即使永久保留,对后一轮用处也有限。
所以 diversity 的作用不是单轮的”不丢信息”,而是多轮的”留好后路”。这是一个可以在论文里专门讲的 insight。
学习任务
实验 G:对比普通 top-k 和 diverse top-k 在多轮 benchmark 上的 per-turn 性能。你会看到 diverse 在 turn 1 可能略差(因为不是严格的 top relevance),但从 turn 2 开始稳定胜出。这个 gap 会随 turn 增加而放大。
第 8 步:把所有东西串起来的最终信号
综合前面所有,你的训练无关方案应该是这样的决策流程:
Input:
visual_tokens: [N', D] (来自 projector + diversity-based 前段剪枝)
current_query_text, history_query_texts
Step 1: 构造 query vector
使用累积 query 策略(带 topic shift 重置)
Step 2: 计算基础信号
s_cos[i] = cosine(visual_tokens[i], query_vec)
s_norm[i] = ||hidden_state_at_layer_K[i]|| # 可选,需要跑 K 层
Step 3: 组合信号
s_combined = rank_fusion(s_cos, s_norm)
# 或者如果不跑 K 层: s_combined = s_cos
Step 4: Diverse top-k 选择
selected = diverse_topk(s_combined, visual_tokens, k=N'', lambda_div=0.5)
Output:
selected_indices → 用于 PagedAttention 的 page table
学习路径总结:你应该按什么顺序掌握
我给你排一个实际的学习顺序,每一步都有可跑的实验:
第 1 周:单一信号
- 实现 cosine similarity 信号
- 实验 A / B:诊断 cosine 的 failure mode
- 在 MMDU 小子集上跑通端到端 pipeline,不求性能,求能跑
第 2 周:对比信号
- 实现 attention-based 信号(用 non-flash attention 实现)
- 实现 hidden norm 信号
- 实验 C / D:对比三种信号,建立哪种信号强/弱的直觉
第 3 周:组合信号
- 实现 linear / rank fusion / multiplicative 三种组合
- 实验 E:在 validation set 上调参,确定最好的组合策略
第 4 周:多轮扩展
- 实现累积 query 和 topic shift 重置
- 实验 F:按 turn 分层的消融实验,这是你论文 main result 的核心
第 5 周:Diversity 正则
- 实现 diverse top-k
- 实验 G:diversity 对多轮的增益
这五周走完,你会对剪枝信号有非常扎实的理解,训练无关版本也完整跑出来了。之后再上训练版本,你会发现训练带来的提升可能并没有你想象的那么大——很多时候训练无关版本加上合适的 diversity + 累积 query,在多轮场景下已经能打败复杂的训练版方法。这本身就是论文可以讲的一个 insight。
一个重要的心态
不要一上来就追求最复杂的方法。剪枝这个领域的大量论文看起来花里胡哨,但核心信号就是 cosine / attention / norm 这三种,和它们的组合。你的创新点不会来自于发明新信号,而来自于如何组合它们、如何让它们适配你的场景(多轮 + cache 复用)。所以把基础信号理解透、能快速诊断它们的失效模式,比堆砌新技巧重要 10 倍。
你想先从哪个实验开始动手?我可以陪你一步步做第一个实验的完整实现。