pytorch api文档:nn.functional.cross_entropy()交叉熵损失函数 作者:马育民 • 2026-01-28 19:27 • 阅读:10006 需要掌握:[损失函数-交叉熵(多分类)](https://www.malaoshi.top/show_1EF4La0KENhc.html "损失函数-交叉熵(多分类)") # 介绍 `torch.nn.functional.cross_entropy`(简称 `F.cross_entropy`) 是 PyTorch 中 **分类任务(含大模型文本生成)的核心损失函数** # 语法 ```python torch.nn.functional.cross_entropy( input: torch.Tensor, target: torch.Tensor, weight: Optional[torch.Tensor] = None, size_average: Optional[bool] = None, ignore_index: int = -100, reduce: Optional[bool] = None, reduction: str = 'mean', label_smoothing: float = 0.0 ) -> torch.Tensor ``` #### 参数解释 | 参数 | 类型 | 核心作用 | 实战注意事项 | |------|------|----------|--------------| | `input`(必选) | Tensor | 模型输出的 logits,形状:- 分类:`(batch_size, num_classes)`- 文本生成:`(batch_size, seq_len, vocab_size)` | ✅ 不能提前做 softmax(会破坏梯度);✅ 数据类型为浮点型(float32/float64) | | `target`(必选) | Tensor | 真实标签,形状:- 分类:`(batch_size,)`- 文本生成:`(batch_size, seq_len)` | ✅ 是**类别索引**(int64),不是 one-hot 编码;✅ 文本生成中 padding 位置设为 `-100` | | `weight`(可选) | Tensor | 类别权重,形状 `(num_classes,)` | 用于处理类别不平衡(如文本生成中低频 token 加权) | | `ignore_index`(可选) | int | 忽略的标签值,默认 `-100` | 文本生成中必须设置,忽略 padding token 的损失计算 | | `reduction`(可选) | str | 损失聚合方式,可选:- `mean`:默认,返回批次平均损失;- `sum`:返回批次总损失;- `none`:返回每个样本/位置的损失 | 文本生成中常用 `mean`,需聚合整个 batch 的损失 | | `label_smoothing`(可选) | float | 标签平滑系数(0~1) | 防止模型过自信,文本生成中常用 0.1~0.2 | | `size_average`/`reduce`(可选) | bool | 已弃用,由 `reduction` 替代 | 避免使用,仅兼容旧代码 | # 计算过程 ### 一、准备测试数据 为了方便手动计算,设定: - **模型输出logits**:`input = torch.tensor([[1.0, 2.0, 3.0]])`(batch_size=1,num_classes=3) - **真实标签**:`target = torch.tensor([2])`(类别索引为2,对应第三个类别) - 无权重、无ignore_index、默认reduction='mean' --- ### 二、手动分步计算 cross_entropy #### 步骤1:计算 softmax(将logits转为概率) softmax公式: $$p\_i = \frac{e^{z\_i}}{\sum\_{j=1}^C e^{z\_j}}$$ 解释:$$C$$ 为类别数,这里 $$C=3$$ - 其中,$$z\_0=1.0, z\_1=2.0, z\_2=3.0$$ - 分子: $$e^{1.0} \approx 2.71$$,$$e^{2.0} \approx 7.38$$,$$e^{3.0} \approx 20.08$$ - 分母(总和):$$2.71828 + 7.38906 + 20.08554 \approx 30.19288$$ - softmax结果: - 结果:$$p\_0 = 2.71/30.19 \approx 0.09$$ - 结果:$$p\_1 = 7.38/30.19 \approx 0.24$$ - 结果:$$p\_2 = 20.08/30.19 \approx 0.66$$ #### 步骤2:计算 log_softmax(对softmax结果取自然对数) $$\log(p_i)$$: - 结果:$$\log(0.09) \approx -2.40$$ - 结果:$$\log(0.24) \approx -1.40$$ - 结果:$$\log(0.66) \approx -0.40$$ #### 步骤3:计算负对数似然(NLL) 取真实标签对应位置的log_softmax值,取负数: - 真实标签是2 → 取 $$\log(p_2) \approx -0.4076$$ - NLL = $$-(-0.4076) = 0.4076$$ #### 步骤4:reduction='mean'(平均) 因为 `batch_size=1`,平均后结果仍为 **0.4076**(这就是最终的cross_entropy损失值)。 --- ### 三、PyTorch代码验证计算过程 ```python import torch import torch.nn.functional as F # 1. 定义极简数据 input = torch.tensor([[1.0, 2.0, 3.0]]) # logits target = torch.tensor([2]) # 真实类别索引 # 2. 直接计算cross_entropy loss = F.cross_entropy(input, target) print("F.cross_entropy 结果:", loss.item()) # 输出:0.4076079726219177 # 3. 分步拆解验证(和手动计算对应) # 步骤1+2:log_softmax log_softmax = F.log_softmax(input, dim=1) print("\nlog_softmax 结果:", log_softmax) # 输出:tensor([[-2.4076, -1.4076, -0.4076]]) # 步骤3:nll_loss(等价于cross_entropy) nll_loss = F.nll_loss(log_softmax, target) print("F.nll_loss 结果:", nll_loss.item()) # 输出:0.4076079726219177 ``` --- ### 文本生成场景的极简演示 针对大模型文本生成场景,再用「带序列+padding」的极简数据演示: #### 1. 数据准备 - logits:`input = torch.tensor([[[1.0,2.0], [3.0,4.0], [5.0,6.0]]])`(batch=1,seq_len=3,vocab_size=2) - target:`torch.tensor([[1, 0, -100]])`(标签1、0,最后一位是padding,ignore_index=-100) #### 2. 手动计算(仅有效位置) - 位置0:logits=[1,2],标签1 → 损失≈0.1269 - 位置1:logits=[3,4],标签0 → 损失≈1.3133 - 位置2:padding,忽略 - 平均损失:(0.1269 + 1.3133)/2 ≈ 0.7201 #### 3. 代码验证 ```python # 文本生成场景数据 input_seq = torch.tensor([[[1.0,2.0], [3.0,4.0], [5.0,6.0]]]) # (1,3,2) target_seq = torch.tensor([[1, 0, -100]]) # (1,3) # 展平计算(文本生成必备步骤) loss_seq = F.cross_entropy( input_seq.reshape(-1, 2), # 展平为(3,2) target_seq.reshape(-1), # 展平为(3,) ignore_index=-100 ) print("\n文本生成场景损失:", loss_seq.item()) # 输出:0.7201269865036011 ``` --- ### 总结 1. **核心步骤**:cross_entropy = softmax → log → 取真实标签位置负值 → 聚合(mean/sum); 2. **极简验证**:单样本3分类场景下,手动计算结果(0.4076)和PyTorch输出完全一致; 3. **文本生成适配**:需展平logits和target,通过ignore_index忽略padding,仅计算有效token的损失。 # 例子 ### 1. 基础分类任务(2分类) ```python import torch import torch.nn.functional as F # 模型输出:batch_size=2,num_classes=2 logits = torch.tensor([[1.2, 0.8], [0.3, 1.5]]) # 真实标签:类别索引(0/1) target = torch.tensor([0, 1]) # 基础损失计算 loss = F.cross_entropy(logits, target) print("基础损失:", loss.item()) # 输出:0.4163(近似值) ``` ### 2. 大模型文本生成(核心场景) ```python # 模拟大模型输出:batch_size=2,seq_len=3,vocab_size=5 logits = torch.randn(2, 3, 5) # 模拟标签:包含padding(-100),shape=(2,3) target = torch.tensor([[1, 3, -100], [2, 4, 1]]) # 计算文本生成损失(忽略padding) loss = F.cross_entropy( logits.reshape(-1, 5), # 展平为 (6,5) target.reshape(-1), # 展平为 (6,) ignore_index=-100, # 忽略padding reduction='mean' # 平均有效位置的损失 ) print("文本生成损失:", loss.item()) ``` ### 3. 类别加权(处理不平衡) ```python # 3分类任务,类别0权重1,类别1权重2,类别2权重3 weight = torch.tensor([1.0, 2.0, 3.0]) logits = torch.tensor([[0.5, 1.0, 1.5], [2.0, 0.8, 0.3]]) target = torch.tensor([2, 0]) # 带权重的损失 loss = F.cross_entropy(logits, target, weight=weight) print("加权损失:", loss.item()) ``` --- # 避坑点(结合大模型场景) ### **坑1:input 提前做 softmax** ❌ 错误:`logits = F.softmax(logits, dim=-1)` 后再算损失; ✅ 正确:直接用模型原始 logits 输入,函数内部会自动做 log_softmax。 原因:提前 softmax 会导致梯度消失,尤其是大模型深层网络。 ### **坑2:target 用 one-hot 编码** ❌ 错误:`target = torch.tensor([[0,1,0], [1,0,0]])`(one-hot); ✅ 正确:转成类别索引 `target = torch.argmax(one_hot, dim=-1)`。 ### **坑3:未处理 padding 导致损失异常** ❌ 错误:target 中 padding 设为 0,未设置 `ignore_index`; ✅ 正确:padding 设为 `-100`,并指定 `ignore_index=-100`,避免无效梯度。 ### **坑4:维度不匹配** 文本生成中,input 最后一维必须等于词表大小,target 取值范围必须在 `[0, vocab_size-1]`(padding 除外)。 --- # 实现原理 `F.cross_entropy` 等价于以下两步操作,帮你理解底层实现: ```python # 方式1:直接用 F.cross_entropy loss1 = F.cross_entropy(logits, target, ignore_index=-100) # 方式2:手动拆解(等价) log_softmax = F.log_softmax(logits, dim=-1) # 对最后一维做log_softmax loss2 = F.nll_loss(log_softmax, target, ignore_index=-100) print(loss1 == loss2) # 输出:True ``` --- # 总结 1. **核心本质**:`F.cross_entropy = log_softmax + nll_loss`,无需手动处理 logits 归一化; 2. **关键参数**:`input`(原始 logits)、`target`(类别索引)、`ignore_index`(忽略 padding)是文本生成场景的核心参数; 3. **避坑重点**:input 不提前 softmax、target 不用 one-hot、必须处理 padding; 4. **场景适配**:文本生成中需展平 logits 和 target,仅计算有效 token 的损失,是大模型训练的核心损失计算方式。 原文出处:http://malaoshi.top/show_1GW2fy7aYCus.html