pytorch api文档:torch.no_grad()函数-关闭梯度计算 作者:马育民 • 2026-01-28 14:48 • 阅读:10003 # 作用 `torch.no_grad()` 是一个**上下文管理器**(Context Manager) **作用:** **关闭梯度计算**,即:在它包裹的代码块内,所有张量的计算都不会构建计算图,也不会计算和存储梯度。 ### 应用场景 - 模型推理(预测)阶段。 - 验证/测试集评估(计算准确率、损失等,无需反向传播)。 - 不需要梯度的张量计算(比如纯数值运算)。 # 例子 ### 用法1:作为上下文管理器(最常用) ```python import torch import torch.nn as nn # 定义一个简单的模型 class SimpleModel(nn.Module): def __init__(self): super().__init__() self.linear = nn.Linear(10, 1) def forward(self, x): return self.linear(x) model = SimpleModel() # 切换到评估模式(和no_grad是互补的,推理时建议一起用) model.eval() # 模拟输入 input_tensor = torch.randn(5, 10) # 推理阶段使用no_grad() with torch.no_grad(): output = model(input_tensor) print("输出结果:", output) # 验证梯度是否关闭:模型参数的grad为None print("线性层权重的梯度:", model.linear.weight.grad) # 输出 None ``` ### 用法2:作为装饰器(装饰函数) ```python @torch.no_grad() def predict(model, x): model.eval() return model(x) # 调用函数,内部自动关闭梯度 output = predict(model, input_tensor) ``` ### 用法3:临时启用/禁用(较少用) ```python torch.no_grad().__enter__() # 开启no_grad状态 output1 = model(input_tensor) torch.no_grad().__exit__(None, None, None) # 关闭no_grad状态 # 恢复梯度计算 output2 = model(input_tensor) output2.sum().backward() # 可以正常计算梯度 print("反向传播后权重梯度:", model.linear.weight.grad) # 输出非None的张量 ``` # .eval() 和 .no_grad() `model.eval()` ≠ `torch.no_grad()`: - `model.eval()`:主要用于切换模型的评估模式(比如关闭 Dropout、BatchNorm 用训练时的统计量),**不影响梯度计算**。 - `torch.no_grad()`:仅关闭梯度计算,**不影响模型的模式**。 推理时建议**同时使用**:`model.eval()` + `with torch.no_grad()`。 # 总结 1. `torch.no_grad()` 的核心作用是**关闭梯度计算**,节省内存/显存并提升运行速度。 2. 最常用方式是 `with torch.no_grad():` 包裹推理代码,推理时需配合 `model.eval()`。 3. 它仅影响梯度计算,不改变模型的训练/评估模式,二者是互补关系。 原文出处:http://malaoshi.top/show_1GW2ftCYYlFM.html