Transformer模型训练中的损失函数设计踩坑记录
最近在训练Transformer模型时,遇到了一个令人头疼的问题:模型训练loss居高不下,且验证集表现糟糕。经过一番排查,发现是损失函数设计不当导致的。
问题复现
使用标准的交叉熵损失函数训练时出现了以下异常:
import torch
import torch.nn as nn
# 错误示例
loss_fn = nn.CrossEntropyLoss()
class TransformerModel(nn.Module):
def __init__(self):
super().__init__()
self.transformer = nn.TransformerEncoder(...)
def forward(self, x):
output = self.transformer(x)
return output # 直接输出logits
核心问题分析
通过调试发现,模型输出的logits维度与标签维度不匹配。在处理序列数据时,需要确保:
- 输出维度:(batch_size, seq_len, vocab_size)
- 标签维度:(batch_size, seq_len) 或 (batch_size*seq_len)
正确实现方式
# 正确示例
loss_fn = nn.CrossEntropyLoss(ignore_index=0) # 忽略padding token
# 注意:确保模型输出和标签维度匹配
outputs = model(input_ids)
loss = loss_fn(outputs.view(-1, outputs.size(-1)), labels.view(-1))
关键踩坑点
- 维度不匹配:常见于多层嵌套处理后未正确reshape
- padding处理:未设置ignore_index导致训练不稳定
- 损失函数选择:在不同任务(分类vs生成)下应选用不同损失函数
建议在模型训练初期就加入维度检查代码,避免后期调试成本过高。

讨论