交叉熵损失函数

大语言模型微调及其应用的探索 跟踪前沿的技术

交叉熵损失函数

  1. torch.nn.functional.cross_entropy(…)
  2. 第一个参数是:logits, 第二个参数是: targets
  3. logits: 是模型做出最终概率化之前, 为每一个token打的分
  4. 或者可以称为它为置信度(信心程度), 值越高可信度越高
  5. targets中存放的是”答案”

cross_entropy做什么?

  1. 对每个样本做了softmax, 将logits转成了可能性概率
  2. 获得每个样本中,当前位置token答案的概率
  3. 将当前批次中所有的这些答案概率连接起来
  4. 将连接好的结果作为参数转给-log(p), 得到损失