pytorch api文档:torch.softmax() 作者:马育民 • 2026-01-17 11:39 • 阅读:10001 # 介绍 `torch.softmax()` 函数(实际常用 `torch.nn.functional.softmax`),这是深度学习多分类任务中最核心的激活函数 # softmax函数公式 对于输入: $$z = [z\_1,z\_2,\ldots,z\_n] $$ softmax 计算每个元素的输出: $$\text{softmax}(z\_i) = \frac{e^{z\_i}}{\sum\_{j=1}^n e^{z\_j}}$$ $$ \text{j} = 1,2,3, \ldots ,n$$ ### 特征 - 输出值在 `(0, 1)` 之间; - 所有输出值之 **和为 1**; - 指数特性会**放大数值差异**(大值更大,小值更小),适合突出“最可能的类别”。 # 语法 ``` torch.nn.functional.softmax(input, dim=None, dtype=None) ``` 将任意实数范围的输入张量(通常叫 logits),**按指定维度归一化为 0~1 之间的概率分布**,且该维度下所有元素的总和为 1。 ### 参数 | 参数 | 作用 | 注意事项 | |--------|----------------------------------------------------------------------|--------------------------------------------------------------------------| | `input` | 输入张量(模型输出的 logits,无需提前归一化)| 通常为 float32/float64,整数张量需先转换类型 | | `dim` | 指定归一化的维度(核心!)| 多分类中:一维张量用 `dim=0`;批量二维张量用 `dim=1`(类别维度) | | `dtype` | 指定输出张量的数据类型 | 可选,默认与输入一致,用于精度控制(如 float16 训练) | ### dim 参数的理解(轴的编号) 张量的维度编号从 0 开始,比如 `shape [2,3]` 的张量: - `dim=0`:垂直方向(行)求和; - `dim=1`:水平方向(列)求和。 ### dim 参数选错的后果 ```python # 错误示例:对 batch 维度(dim=0)归一化(无意义) wrong_softmax = F.softmax(batch_logits, dim=0) print("错误dim=0的输出:\n", wrong_softmax) print("错误求和验证:", wrong_softmax.sum(dim=0)) # tensor([1.0000, 1.0000, 1.0000]) # 后果:每个「类别列」求和为1,而非每个样本的「类别行」求和为1,完全不符合分类逻辑 ``` # 例子 先通过代码掌握核心用法,注意 `dim` 参数是**最关键的必传参数**: ```python import torch import torch.nn.functional as F # 示例1:一维张量(单样本,3个类别) logits = torch.tensor([1.0, 2.0, 3.0]) # 模型原始输出(logits) softmax_out = F.softmax(logits, dim=0) print("softmax输出:", softmax_out) # tensor([0.0900, 0.2447, 0.6652]) print("求和验证:", softmax_out.sum()) # tensor(1.0)(满足概率分布) # 示例2:二维张量(批量样本,2个样本×3个类别) batch_logits = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) # dim=1:对每个样本的「类别维度」做归一化(核心场景) batch_softmax = F.softmax(batch_logits, dim=1) print("批量softmax输出:\n", batch_softmax) # 输出: # [[0.0900, 0.2447, 0.6652], # 第一个样本的类别概率 # [0.0900, 0.2447, 0.6652]] # 第二个样本的类别概率 print("每个样本求和:", batch_softmax.sum(dim=1)) # tensor([1.0000, 1.0000]) ``` # 注意 ### 数值稳定性(内部优化) 直接按公式计算 softmax 容易因 `e^x` 溢出(比如 logits 为 1000 时,`e^1000` 超出浮点数范围),PyTorch 的 `softmax` 内部已做**数值稳定化优化**: ```python # 原理:先减去输入维度的最大值,再计算指数 def stable_softmax(x, dim): x_max = torch.max(x, dim=dim, keepdim=True)[0] x_stable = x - x_max # 避免指数溢出 exp_x = torch.exp(x_stable) return exp_x / exp_x.sum(dim=dim, keepdim=True) # 验证溢出场景 overflow_logits = torch.tensor([1000.0, 1001.0, 1002.0]) # PyTorch 官方 softmax 正常输出 print(F.softmax(overflow_logits, dim=0)) # tensor([0.0900, 0.2447, 0.6652]) # 手动实现的稳定版也一致 print(stable_softmax(overflow_logits, dim=0)) # tensor([0.0900, 0.2447, 0.6652]) ``` **提示:**无需手动做这个优化,PyTorch 已内置实现! # 应用场景 ### 多分类任务的完整流程 ```python # 模拟:模型输出 logits → softmax 转概率 → 取最大概率的类别 logits = torch.tensor([[2.5, 1.8, 3.2], [0.5, 4.1, 2.0]]) prob = F.softmax(logits, dim=1) pred = torch.argmax(prob, dim=1) # 取概率最大的类别索引 print("类别概率:\n", prob) print("预测类别:", pred) # tensor([2, 1])(第一个样本预测类别2,第二个预测类别1) ``` ### (2)不要和 CrossEntropyLoss 重复使用 PyTorch 的 `nn.CrossEntropyLoss` 内部**已经包含了 softmax 计算**,如果输入是 logits,无需提前做 softmax: ```python import torch.nn as nn # 错误:logits → softmax → CrossEntropyLoss(重复计算,结果错误) logits = torch.tensor([[1.0, 2.0, 3.0]]) label = torch.tensor([2]) prob = F.softmax(logits, dim=1) loss_wrong = nn.CrossEntropyLoss()(prob, label) # 正确:直接用 logits 输入 CrossEntropyLoss loss_right = nn.CrossEntropyLoss()(logits, label) print("错误损失:", loss_wrong) # tensor(1.1019) print("正确损失:", loss_right) # tensor(0.4076)(符合预期) ``` ### (3)softmax 与 log_softmax 如果需要计算 log 概率(比如配合 `NLLLoss`),优先用 `F.log_softmax` 而非 `torch.log(F.softmax())`,数值更稳定: ```python # 推荐:直接用 log_softmax log_prob = F.log_softmax(logits, dim=1) # 不推荐:log(softmax) 可能因数值过小导致精度丢失 log_prob_bad = torch.log(F.softmax(logits, dim=1)) print("log_softmax输出:\n", log_prob) print("数值一致性:", torch.allclose(log_prob, log_prob_bad)) # True(简单场景下一致) ``` # 总结 1. `torch.nn.functional.softmax` 核心作用是**按指定维度将 logits 归一化为和为 1 的概率分布**,是多分类任务的标配。 2. `dim` 参数是关键:批量多分类中必须设为 `dim=1`(对每个样本的类别维度归一化),选错维度会导致结果无意义。 3. 实战注意:`CrossEntropyLoss` 内置 softmax,输入无需提前处理;需 log 概率时用 `log_softmax` 更稳定。 原文出处:http://malaoshi.top/show_1GW2blPk6St3.html