09. SFT Training Loop 代码笔记
背景
这段代码实现的是监督微调(Supervised Fine-Tuning, SFT) 中常用的两个函数:
build_sft_data:将一条 prompt-response 对处理成模型可直接使用的input_ids和labels。compute_sft_loss:计算因果语言模型的自回归损失,只对 response 部分求梯度。
训练时,我们希望模型在看到 prompt 后能生成正确的 response,但 prompt 部分不计算损失(只作为条件),因此需要对标签做特殊处理。
TODO 1:构造 labels
labels = [-100] * len(prompt_ids) + response_ids
作用:创建与 input_ids 等长的标签序列,其中 prompt 部分用 -100 填充,response 部分保留原始的 token id。
为什么是 -100?
PyTorch 的 nn.CrossEntropyLoss(ignore_index=-100) 会忽略标签值为 -100 的位置,这些位置不贡献任何梯度,也不会影响损失值。这样模型在 prompt 区域不会被迫去“预测”这些 token,只有 response 区域才会计算损失并更新参数。
示例:
prompt_ids = [1, 2, 3]response_ids = [4, 5, 6]- 则
labels = [-100, -100, -100, 4, 5, 6]
模型在位置 0~2 的预测不会被惩罚,位置 3~5 的预测目标分别是 4、5、6。
TODO 2:截断与填充
# 超长截断
if len(input_ids) > max_len:
input_ids = input_ids[:max_len]
labels = labels[:max_len]
# 不足填充
else:
pad_len = max_len - len(input_ids)
input_ids = input_ids + [pad_id] * pad_len
labels = labels + [-100] * pad_len
作用:保证每个样本的长度统一为 max_len,便于批处理。
- 超长截断:直接保留前
max_len个 token(input_ids[:max_len])。如果超长部分包含 response,则模型会失去部分生成目标;通常设置足够大的max_len来避免这种情况。 - 不足填充:在序列末尾添加
pad_id(一般是 0 或特殊填充符)作为输入,同时标签对应位置填充-100。
为什么标签填-100? 因为填充符是无效 token,不应参与损失计算。ignore_index=-100同样会跳过这些位置。
TODO 3:实现 Shift 错位对齐
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
作用:将模型输出的 logits 与标签对齐,以满足因果语言模型(自回归) 的预测方式。
原理:
在自回归模型中,我们使用位置 t 的输出去预测位置 t+1 的 token。因此:
- 模型输入一个长度为
L的序列,输出的 logits 形状为[batch, L, vocab_size],其中第t个位置的输出对应对第t+1个 token 的预测。 - 我们的标签
labels长度为L,其中labels[t]是位置t处的真实 token。
要计算损失,我们需要将 第 t 个预测 与 第 t+1 的真实 token 对齐。所以:
logits[..., :-1, :]:丢弃最后一个位置的预测(因为没有L+1的真实 token)。labels[..., 1:]:丢弃第一个位置的标签(因为第一个 token 没有前一个位置的预测来对应它)。- 对齐后长度变为
L-1,且shift_logits[t]对应shift_labels[t](即原序列的labels[t+1])。
.contiguous():确保张量在内存中连续存储,为后续 view 操作做准备。
示例:
序列 [A, B, C, D],模型看到 [A, B, C] 后预测 [B, C, D],logits 的形状是 [4, V]。去掉 logits 最后一个位置(预测 E 的无用输出),去掉 labels 第一个位置(A 不需要被预测),即可对齐。
TODO 4:展平并计算交叉熵损失
loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
shift_logits = shift_logits.view(-1, shift_logits.size(-1))
shift_labels = shift_labels.view(-1)
loss = loss_fct(shift_logits, shift_labels)
作用:将对齐后的 logits 和 labels 展平为二维和一维,然后计算交叉熵损失,并自动忽略 label 为 -100 的位置。
- 展平:
shift_logits.view(-1, vocab_size)把所有 batch 和序列位置合并成一个大的二维矩阵[总 token 数, vocab_size];shift_labels.view(-1)变成一维向量[总 token 数]。这是因为CrossEntropyLoss要求输入为(N, C)和(N),其中N是样本总数,C是类别数。 - ignore_index=-100:之前我们在 prompt 部分和填充部分都使用了
-100作为标签,现在损失函数会自动跳过这些位置的损失计算,只对 response 部分的有效 token 计算交叉熵。 - 最终的
loss是一个标量,表示 batch 中所有有效 token 的平均损失。
为什么这样设计损失?
在 SFT 中,我们只关心模型学习生成 response 的能力,prompt 仅作为上下文。因此 prompt 位置不贡献梯度,模型可以专注于优化 answer 部分,提高对话/任务完成的质量。
总结
| TODO | 操作 | 核心目的 |
|---|---|---|
| 1 | labels = [-100]*len(prompt) + response |
屏蔽 prompt 部分的损失 |
| 2 | 超长截断 / 不足填充(pad 处 label 填 -100) | 统一长度,填充位置不计算损失 |
| 3 | 切掉 logits 最后一维、切掉 labels 第一维 | 对齐自回归预测与真实标签 |
| 4 | 展平后用 CrossEntropyLoss(ignore_index=-100) |
只计算有效 response 部分的交叉熵 |
这四个步骤完整实现了 SFT 的数据处理和损失计算流程,是当前大语言模型微调的基础。
更多推荐
所有评论(0)