大模型原理:为什么缩放注意力分数 作者:马育民 • 2026-01-18 18:53 • 阅读:10000 # 一句话解释 因为点积注意力的分数会随着 **向量维度 增大而变大,导致 Softmax 梯度变小,模型训练困难。缩放可以让数值回到合理范围,让训练稳定。** ------------------------------------ # 为什么分数会“爆炸”? 假设 Query 和 Key 的维度是 d_model(比如 512 或 1024)。 点积 = Q · K = 求和(Q_i * K_i) 共 d_model 项。 如果 Q、K 的元素均值为 0、方差为 1,那么点积的均值是 0,方差是 d_model。 也就是说: - d_model = 512 → 点积的方差 = 512 - d_model = 1024 → 点积的方差 = 1024 维度越大,点积结果越大。 这会导致分数变得非常大。 ## 理解 #### 类比 可以把点积想象成“扔骰子求和”: - 扔1个骰子(d_model=1):和的范围是1~6,波动小; - 扔10个骰子(d_model=10):和的范围是10~60,波动大; - 扔1024个骰子(d_model=1024):和的范围能到上千,波动极大。 Q/K的点积就像“扔d_model个方差为1的骰子求和”,骰子数越多(维度越大),和的波动(方差)就越大,数值自然就越容易变得很大。 #### 步骤1:均值为0、方差为1的意思 假设Q和K的每个元素都是「均值=0,方差=1」的随机数(这是模型初始化时的常见分布,比如正态分布): - 均值=0:元素值在0附近波动(有正有负,平均下来是0); - 方差=1:元素值的波动范围大概在[-2, 2]之间(正态分布95%的值在±2倍方差内)。 举个例子,d_model=2时,Q=[0.5, -1.2],K=[1.1, 0.8],每个元素都符合“均值0、方差1”。 #### 步骤2:点积的本质是“多个数相加” 点积 = Q₁×K₁ + Q₂×K₂ + ... + Q_d×K_d(共d_model项相加)。 我们先分析**每一项(Q_i×K_i)的均值和方差**: - 因为Q_i和K_i独立,且均值都是0,所以 Q_i×K_i 的均值 = 均值(Q_i) × 均值(K_i) = 0×0 = 0; - 方差的性质:独立变量相乘的方差 ≈ 方差(Q_i) × 方差(K_i) = 1×1 = 1(简化理解,严格来说是 Var(XY)=E[X²]E[Y²] - (E[X]E[Y])²,这里E[X²]=Var(X)+E[X]²=1+0=1,所以Var(XY)=1×1 - 0=1)。 简单说:**点积里的每一项(Q_i×K_i)都是“均值0、方差1”的数**。 #### 步骤3:多个独立数相加,方差会“累加” 方差的核心性质:**独立变量相加,方差等于各方差之和**。 比如: - 1项相加(d_model=1):方差=1; - 2项相加(d_model=2):方差=1+1=2; - 3项相加(d_model=3):方差=1+1+1=3; - ... - d_model项相加:方差=1×d_model = d_model。 而**均值**:独立变量相加,均值等于各均值之和,每一项均值都是0,所以总均值=0+0+...+0=0。 #### 步骤4:用具体例子验证 | 维度d_model | 点积的构成(每一项都是均值0、方差1) | 点积的均值 | 点积的方差 | 点积的典型数值范围(±2倍方差) | |------------|--------------------------------------|------------|------------|--------------------------------| | 2 | Q₁K₁ + Q₂K₂ | 0 | 2 | [-√8, √8] ≈ [-2.8, 2.8] | | 4 | Q₁K₁ + Q₂K₂ + Q₃K₃ + Q₄K₄ | 0 | 4 | [-√16, √16] = [-4, 4] | | 512 | 512个Q_iK_i相加 | 0 | 512 | [-√2048, √2048] ≈ [-45.2, 45.2]| | 1024 | 1024个Q_iK_i相加 | 0 | 1024 | [-√4096, √4096] = [-64, 64] | 你看: - 维度从2涨到1024,点积的方差从2涨到1024; - 数值范围从±2.8涨到±64,**维度越大,点积的数值就越容易出现很大的数**(比如1024维时,点积可能达到±64,而2维时最多±2.8)。 #### 步骤5:为什么“方差大”就等于“分数大”? 方差代表“数值的波动幅度”: - 方差小(比如2):数值只在±2.8之间晃,不会太大; - 方差大(比如1024):数值会频繁出现±60、±50这样的大数。 这就是“维度越大,点积结果越大”的本质——不是每个点积都一定大,而是**出现大数的概率会急剧增加**,整体数值范围被拉得很宽。 # 分数太大会导致什么问题? Softmax 对非常大的数值会变得“极端”。 例如: ``` Softmax([10, 20, 30]) ≈ [0, 0, 1] ``` 所有概率都集中在最大的那个值上,其他几乎为 0。 这会导致: - 注意力分布变得“过于尖锐” - 梯度变得非常小(接近 0) - 模型难以训练(梯度消失) 这在大模型中尤其严重,因为维度通常是 512、1024、2048 甚至更高。 ------------------------------------ # 缩放解决了什么? Transformer 论文提出: ``` scaled_scores = scores / sqrt(d_k) ``` 这样可以让点积的方差从 d_k 变回 1。 结果: - 数值范围合理 - Softmax 不会饱和 - 梯度正常 - 训练稳定 这就是为什么缩放对大模型至关重要。 ------------------------------------ # 例子 ### 简单例子 ``` import torch score = torch.tensor([100.0, 120.0, 150.0]) print("注意力分数:",score) weight1 = torch.softmax(score, dim=-1) print("未缩放,直接执行softmax:", weight1) scale_score = score/32 print("缩放分数:", scale_score) weight2 = torch.softmax(scale_score, dim=-1) print("缩放后,执行softmax::", weight2) ``` 执行结果: ``` 注意力分数: tensor([100., 120., 150.]) 未缩放,直接执行softmax: tensor([1.9287e-22, 9.3576e-14, 1.0000e+00]) 缩放分数: tensor([3.1250, 3.7500, 4.6875]) 缩放后,执行softmax:: tensor([0.1309, 0.2446, 0.6245]) ``` - 未缩放,直接执行softmax,结果:`tensor([1.9287e-22, 9.3576e-14, 1.0000e+00])`,所有概率都集中在最大的那个值上,其他几乎为 0。 - 缩放后的softmax结果:`tensor([0.1309, 0.2446, 0.6245])`,更加平滑 ### 完整例子 ``` import torch import torch.nn.functional as F # ====================== 1. 模拟大模型的注意力分数 ====================== # 模拟 d_k=1024 时的注意力分数(3个位置的分数,数值上百) scores = torch.tensor([100.0, 120.0, 150.0], requires_grad=True) # 需计算梯度 print("===== 原始注意力分数(未缩放) =====") print(f"分数值:{scores.tolist()}") # ====================== 2. 未缩放:Softmax 极端化 + 梯度消失 ====================== # 计算未缩放的 Softmax softmax_unscaled = F.softmax(scores, dim=0) print("\n===== 未缩放的 Softmax 结果 =====") print(f"Softmax 输出:{softmax_unscaled.tolist()}") print(f"Softmax 求和:{softmax_unscaled.sum().item():.4f}") # 验证和为1 # 计算 Softmax 第一个位置的值的梯度(模拟真实训练中的损失反向传播) softmax_unscaled[0].backward(retain_graph=True) # 求第一个位置的梯度 grad_unscaled = scores.grad.clone() # 复制梯度 scores.grad.zero_() # 清空梯度,准备下一次计算 print(f"未缩放的梯度(对第一个位置):{grad_unscaled.tolist()}") # ====================== 3. 缩放:分数正常 + Softmax 平滑 + 梯度正常 ====================== # 步骤1:计算缩放因子(d_k=1024,sqrt(1024)=32) d_k = 1024 scale_factor = torch.sqrt(torch.tensor(d_k, dtype=torch.float32)) # 步骤2:缩放注意力分数 scores_scaled = scores / scale_factor print("\n===== 缩放后的注意力分数 =====") print(f"缩放因子:{scale_factor.item()}") print(f"缩放后分数:{scores_scaled.tolist()}") # 步骤3:计算缩放后的 Softmax softmax_scaled = F.softmax(scores_scaled, dim=0) print("\n===== 缩放后的 Softmax 结果 =====") print(f"Softmax 输出:{softmax_scaled.tolist()}") print(f"Softmax 求和:{softmax_scaled.sum().item():.4f}") # 计算缩放后 Softmax 第一个位置的值的梯度 softmax_scaled[0].backward() grad_scaled = scores.grad.clone() print(f"缩放后的梯度(对第一个位置):{grad_scaled.tolist()}") # ====================== 4. 核心对比总结 ====================== print("\n===== 核心对比 =====") print(f"未缩放 Softmax:几乎极端化 {softmax_unscaled.tolist()}") print(f"缩放后 Softmax:平滑分布 {softmax_scaled.tolist()}") print(f"未缩放梯度:接近 0 {grad_unscaled.tolist()}") print(f"缩放后梯度:正常数值 {grad_scaled.tolist()}") ``` 执行结果: ``` ===== 原始注意力分数(未缩放) ===== 分数值:[100.0, 120.0, 150.0] ===== 未缩放的 Softmax 结果 ===== Softmax 输出:[1.9287498933537385e-22, 9.357622912219837e-14, 1.0] Softmax 求和:1.0000 未缩放的梯度(对第一个位置):[1.9287498933537385e-22, -1.8048514720778033e-35, -1.9287498933537385e-22] ===== 缩放后的注意力分数 ===== 缩放因子:32.0 缩放后分数:[3.125, 3.75, 4.6875] ===== 缩放后的 Softmax 结果 ===== Softmax 输出:[0.13090753555297852, 0.24456748366355896, 0.6245249509811401] Softmax 求和:1.0000 缩放后的梯度(对第一个位置):[0.0035553360357880592, -0.0010004914365708828, -0.002554844366386533] ===== 核心对比 ===== 未缩放 Softmax:几乎极端化 [1.9287498933537385e-22, 9.357622912219837e-14, 1.0] 缩放后 Softmax:平滑分布 [0.13090753555297852, 0.24456748366355896, 0.6245249509811401] 未缩放梯度:接近 0 [1.9287498933537385e-22, -1.8048514720778033e-35, -1.9287498933537385e-22] 缩放后梯度:正常数值 [0.0035553360357880592, -0.0010004914365708828, -0.002554844366386533] ``` # 为什么大模型尤其需要缩放? 因为大模型通常有: - 更高的维度(512 → 1024 → 2048) - 更深的网络 - 更复杂的注意力分布 维度越大,分数爆炸越严重,所以缩放对大模型是必须的。 ------------------------------------ # 总结 **缩放注意力分数是为了避免 Softmax 饱和,防止梯度消失,让大模型训练稳定。** 原文出处:http://malaoshi.top/show_1GW2cFPwovqG.html