背景

这段代码实现的是监督微调(Supervised Fine-Tuning, SFT) 中常用的两个函数:

  • build_sft_data:将一条 prompt-response 对处理成模型可直接使用的 input_idslabels
  • compute_sft_loss:计算因果语言模型的自回归损失,只对 response 部分求梯度。

训练时,我们希望模型在看到 prompt 后能生成正确的 response,但 prompt 部分不计算损失(只作为条件),因此需要对标签做特殊处理。


TODO 1:构造 labels

labels = [-100] * len(prompt_ids) + response_ids

作用:创建与 input_ids 等长的标签序列,其中 prompt 部分用 -100 填充,response 部分保留原始的 token id。

为什么是 -100
PyTorch 的 nn.CrossEntropyLoss(ignore_index=-100) 会忽略标签值为 -100 的位置,这些位置不贡献任何梯度,也不会影响损失值。这样模型在 prompt 区域不会被迫去“预测”这些 token,只有 response 区域才会计算损失并更新参数。

示例

  • prompt_ids = [1, 2, 3]
  • response_ids = [4, 5, 6]
  • labels = [-100, -100, -100, 4, 5, 6]
    模型在位置 0~2 的预测不会被惩罚,位置 3~5 的预测目标分别是 4、5、6。

TODO 2:截断与填充

# 超长截断
if len(input_ids) > max_len:
    input_ids = input_ids[:max_len]
    labels = labels[:max_len]
# 不足填充
else:
    pad_len = max_len - len(input_ids)
    input_ids = input_ids + [pad_id] * pad_len
    labels = labels + [-100] * pad_len

作用:保证每个样本的长度统一为 max_len,便于批处理。

  • 超长截断:直接保留前 max_len 个 token(input_ids[:max_len])。如果超长部分包含 response,则模型会失去部分生成目标;通常设置足够大的 max_len 来避免这种情况。
  • 不足填充:在序列末尾添加 pad_id(一般是 0 或特殊填充符)作为输入,同时标签对应位置填充 -100
    为什么标签填 -100? 因为填充符是无效 token,不应参与损失计算。ignore_index=-100 同样会跳过这些位置。

TODO 3:实现 Shift 错位对齐

shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

作用:将模型输出的 logits 与标签对齐,以满足因果语言模型(自回归) 的预测方式。

原理
在自回归模型中,我们使用位置 t 的输出去预测位置 t+1 的 token。因此:

  • 模型输入一个长度为 L 的序列,输出的 logits 形状为 [batch, L, vocab_size],其中第 t 个位置的输出对应对第 t+1 个 token 的预测
  • 我们的标签 labels 长度为 L,其中 labels[t] 是位置 t 处的真实 token。

要计算损失,我们需要将 t 个预测t+1 的真实 token 对齐。所以:

  • logits[..., :-1, :]:丢弃最后一个位置的预测(因为没有 L+1 的真实 token)。
  • labels[..., 1:]:丢弃第一个位置的标签(因为第一个 token 没有前一个位置的预测来对应它)。
  • 对齐后长度变为 L-1,且 shift_logits[t] 对应 shift_labels[t](即原序列的 labels[t+1])。

.contiguous():确保张量在内存中连续存储,为后续 view 操作做准备。

示例
序列 [A, B, C, D],模型看到 [A, B, C] 后预测 [B, C, D],logits 的形状是 [4, V]。去掉 logits 最后一个位置(预测 E 的无用输出),去掉 labels 第一个位置(A 不需要被预测),即可对齐。


TODO 4:展平并计算交叉熵损失

loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
shift_logits = shift_logits.view(-1, shift_logits.size(-1))
shift_labels = shift_labels.view(-1)
loss = loss_fct(shift_logits, shift_labels)

作用:将对齐后的 logits 和 labels 展平为二维和一维,然后计算交叉熵损失,并自动忽略 label 为 -100 的位置。

  • 展平shift_logits.view(-1, vocab_size) 把所有 batch 和序列位置合并成一个大的二维矩阵 [总 token 数, vocab_size]shift_labels.view(-1) 变成一维向量 [总 token 数]。这是因为 CrossEntropyLoss 要求输入为 (N, C)(N),其中 N 是样本总数,C 是类别数。
  • ignore_index=-100:之前我们在 prompt 部分和填充部分都使用了 -100 作为标签,现在损失函数会自动跳过这些位置的损失计算,只对 response 部分的有效 token 计算交叉熵。
  • 最终的 loss 是一个标量,表示 batch 中所有有效 token 的平均损失。

为什么这样设计损失?
在 SFT 中,我们只关心模型学习生成 response 的能力,prompt 仅作为上下文。因此 prompt 位置不贡献梯度,模型可以专注于优化 answer 部分,提高对话/任务完成的质量。


总结

TODO 操作 核心目的
1 labels = [-100]*len(prompt) + response 屏蔽 prompt 部分的损失
2 超长截断 / 不足填充(pad 处 label 填 -100) 统一长度,填充位置不计算损失
3 切掉 logits 最后一维、切掉 labels 第一维 对齐自回归预测与真实标签
4 展平后用 CrossEntropyLoss(ignore_index=-100) 只计算有效 response 部分的交叉熵

这四个步骤完整实现了 SFT 的数据处理和损失计算流程,是当前大语言模型微调的基础。

Logo

智能硬件社区聚焦AI智能硬件技术生态,汇聚嵌入式AI、物联网硬件开发者,打造交流分享平台,同步全国赛事资讯、开展 OPC 核心人才招募,助力技术落地与开发者成长。

更多推荐