09 SFT Training Loop
LLM 算子深度解析:09 SFT Training Loop — 两行代码决定你的 SFT 是"微调"还是"背诵"
1. SFT 训练的核心矛盾:教会模型回答,不是教会模型提问
1.1 预训练和微调的本质区别
预训练阶段,模型的任务是"给定前文,预测下一个字"。训练数据是一本一本书、一篇文章一篇文章——每一个 token 都要算 Loss,因为模型需要学会语言本身的规律。
SFT 阶段完全不同。一条训练数据长这样:
[Prompt] 请帮我写一首关于春天的五言绝句。
[Response] 春风拂柳绿,细雨润花红。燕舞莺歌处,人间万象融。
我们的目标是:模型看到 Prompt 后,能输出 Response。我们只关心 Response 的质量,不关心 Prompt 长什么样。
如果你把 Prompt 和 Response 一起送进 CrossEntropyLoss,模型会干嘛?它会努力去"背诵"人类的提问方式——“请帮我”、“写一首”、“关于春天的”——这些 token 的预测也会产生梯度,白白浪费算力去学一些对回答问题毫无帮助的东西。
1.2 解决方案就在一个参数里
PyTorch 的 CrossEntropyLoss 有一个神仙参数叫 ignore_index,默认值是 -100。任何 labels 中值为 -100 的位置,都不会产生梯度,也不会计入 Loss 统计。
所以 SFT 训练的核心 trick 就是一句话:把 labels 中属于 Prompt 的部分全部设成 -100。
输入: [请, 帮我, 写, 一首, 诗, 春风, 拂柳, 绿, ...]
|--- Prompt ---| |------ Response ------|
labels: [-100, -100, -100, -100, -100, 春风, 拂柳, 绿, ...]
↑ 全 mask,不产生梯度 ↑ 保留原样,真正的监督信号
2. Shift 错位对齐:自回归模型的一个"坑"
2.1 为什么需要错位?
自回归语言模型的预测逻辑是:
用 token_0, token_1, ..., token_t → 预测 token_{t+1}
也就是模型的输出 logits[t] 预测的是第 t+1 个位置的 token。但 labels 序列的第 t 个位置存的是 token_t 本身——这就错位了。
打个比方:老师给你看前 3 个字,让你猜第 4 个字。答案应该是第 4 个字,而不是第 3 个字。
2.2 怎么对齐?
shift_logits = logits[..., :-1, :] # 丢掉最后一个位置的预测
shift_labels = labels[..., 1:] # 丢掉第一个位置的标签
对齐之后:
Shift 前: logits[t] → 预测 labels[t] ❌ 自己预测自己(复制)
Shift 后: logits[t] → 预测 labels[t+1] ✅ 前文预测后文(生成)
3. 代码实现:从零手写 SFT 训练核心
3.1 数据构造:一行代码的讲究
def build_sft_data(prompt_ids: list[int], response_ids: list[int],
pad_id: int = 0, max_len: int = 16):
# Step 1: 拼接
input_ids = prompt_ids + response_ids
# Step 2: 构造 labels — 核心就这一行
labels = [-100] * len(prompt_ids) + response_ids
# Step 3: 截断与填充
if len(input_ids) > max_len:
input_ids = input_ids[:max_len]
labels = labels[:max_len]
else:
pad_len = max_len - len(input_ids)
input_ids = input_ids + [pad_id] * pad_len
labels = labels + [-100] * pad_len # padding 也要 mask!
return torch.tensor(input_ids, dtype=torch.long), torch.tensor(labels, dtype=torch.long)
三个关键点:
[-100] * len(prompt_ids):Prompt 全部 mask,一个不剩。- 截断时同时截 labels:否则 input_ids 和 labels 长度不一致,CrossEntropyLoss 直接报错。
- Padding 也填
-100:Padding token(通常是 0 或<pad>)不是真实数据,模型不应该学它。如果忘了这行,模型会花大量算力去预测填充符。
3.2 Loss 计算:四个操作一气呵成
def compute_sft_loss(logits: torch.Tensor, labels: torch.Tensor):
# logits: [B, seq_len, vocab_size]
# labels: [B, seq_len]
# Step 1: Shift 错位
shift_logits = logits[..., :-1, :].contiguous() # [B, seq_len-1, V]
shift_labels = labels[..., 1:].contiguous() # [B, seq_len-1]
# Step 2: 展平 + CrossEntropyLoss
loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
shift_logits = shift_logits.view(-1, shift_logits.size(-1)) # [B*S, V]
shift_labels = shift_labels.view(-1) # [B*S]
loss = loss_fct(shift_logits, shift_labels)
return loss
每一步的解释:
logits[..., :-1, :]:切掉最后一个预测位置。因为序列的最后一个 token 没有"下一个 token"作为监督信号。labels[..., 1:]:切掉第一个标签 token。因为第一个位置的 token 没有对应的"上一个预测"。.contiguous():切片操作可能导致 tensor 在内存中不连续,后续.view()要求连续内存。没有这一行,PyTorch 会报"view size is not compatible"的错——这是新手最常见的报错之一。ignore_index=-100:PyTorch 底层会跳过所有 labels 中值为 -100 的位置,不计入前向 Loss 也不参与反向梯度。这比手动构造 mask 矩阵高效得多,因为是在 CUDA kernel 层面做的跳过。
4. 一张表看清 SFT 的标签流向
以一个具体例子走一遍完整流程:
prompt_ids = [10, 20, 30] # 3 个 token 的提问
response_ids = [40, 50, 60, 70] # 4 个 token 的回答
max_len = 8, pad_id = 0
| 步骤 | input_ids | labels | 说明 |
|---|---|---|---|
| 拼接 | [10,20,30,40,50,60,70] |
[-100,-100,-100,40,50,60,70] |
Prompt 全 mask |
| 填充 (+1) | [10,20,30,40,50,60,70,0] |
[-100,-100,-100,40,50,60,70,-100] |
Pad 也 mask |
| Shift logits | 取 [:-1] → 7 个位置 |
— | 丢掉最后一个预测 |
| Shift labels | — | 取 [1:] → 7 个位置 |
丢掉第一个标签 |
| 有效位置 | — | 共 7 个位置,只有 4 个有效 | Prompt(3) + Pad(1) 被 mask |
最终 CrossEntropyLoss 只计算 4 个 Response token 的 Loss,Prompt 和 Padding 完全不影响训练。
5. 与工业实现对照
| 框架 | 实现方式 | 差异 |
|---|---|---|
| HuggingFace Trainer | DataCollatorForCompletionOnlyLM 自动查找 Response 起始位置并设 -100 |
自动化程度高,但不透明 |
| LLaMA-Factory | preprocess.py 中手动构造 labels |
和本节实现几乎一模一样 |
| DeepSpeed-Chat | data_utils.py 中使用相同的 mask 逻辑 |
大规模分布式训练的首选方案 |
工业界的共识就是这三步:拼接 → mask prompt → shift 对齐。没有银弹,没有黑魔法。
6. 踩坑记录
6.1 忘了 shift 导致 Loss 偏高
这是 SFT 训练中最常见的 bug。如果不做 labels[..., 1:],相当于让 logits[t] 预测 labels[t]——模型的任务从"预测下一个 token"变成了"复制当前 token"。模型当然学得会(复制嘛),但推理时生成出来的东西就是一坨随机噪声,因为推理时根本不知道当前 token 是什么。
检查方法:打印一个 batch 的 shift_labels 中非 -100 的 token 数量,和 Response 长度对比。
6.2 Padding 的 labels 没设 -100
模型输出的 token 分布中 padding token(通常是 0)的频率异常高。因为模型在努力学"在句子末尾输出 "——这显然不是你想要的。
6.3 截断剪掉了所有 Response
max_len 设得太小,截断后 Prompt 还在但 Response 全被剪掉了。此时 labels 全是 -100,Loss = 0,但模型什么都没有学到。训练曲线看着很美(Loss=0),实际上是空跑。
检查方法:确保 max_len > len(prompt) + min_response_len。
7. 延伸思考
- Chat Template 和本节的关系:真实 SFT 数据还需要插入特殊 token(LLaMA 的
<|begin_of_text|>、<|start_header_id|>等)。这些特殊 token 也应该 mask 掉,因为它们属于"格式"而非"内容"。但 chat template 的处理是数据预处理层的事,本节的核心 label mask 逻辑不变。 - Packing 技术:为提高 GPU 利用率,可以把多条短数据拼成一条长序列(packing)。此时 label mask 变得复杂:不仅要 mask prompt,还要 mask 不同样本之间的 padding 和 cross-sample attention。
- 与 RLHF 的关系:SFT 是整个 RLHF 管线(SFT → Reward Model → PPO)的第一步。本节的数据构造逻辑直接沿用给后续的 PPO 阶段(PPO 中也需要 mask prompt 来避免策略在无关 token 上乱改)。
更多推荐
所有评论(0)