pytorch api文档:torch.sqrt()函数-平方根 作者:马育民 • 2026-01-23 20:57 • 阅读:10000 # 介绍 `torch.sqrt()` 函数,是 PyTorch 中用于计算张量元素平方根的核心函数,非常常用。 # 语法 ``` torch.sqrt(input, *out=None) → Tensor ``` #### 参数 - input:需要计算平方根的输入张量,必须传入。必须是 **非负数**,否则会返回 `nan`(非数字),并且可能触发警告。 - out:输出张量的存储位置,避免函数重新创建新张量,节省内存(尤其对大张量) #### 隐藏的 “设备 / 数据类型” 逻辑 `torch.sqrt()` 没有显式的 `device`/`dtype` 参数,但输出张量的属性会继承 `input` 的特性: - 设备:输入在 GPU 上,输出也在 GPU;输入在 CPU 上,输出也在 CPU。 - 数据类型:输出默认是浮点型(即使输入是整数张量,也会转为 torch.float32/float64)。 示例: # 例子 下面是完整的可运行代码,展示 `torch.sqrt()` 的基本用法: ```python import torch # 1. 基础用法:对张量所有元素开平方 x = torch.tensor([1.0, 4.0, 9.0, 16.0]) sqrt_x = torch.sqrt(x) print("原始张量:", x) print("平方根结果:", sqrt_x) # 输出: tensor([1., 2., 3., 4.]) # 2. 二维张量示例 y = torch.tensor([[1.0, 81.0], [25.0, 49.0]]) sqrt_y = torch.sqrt(y) print("\n二维张量平方根结果:\n", sqrt_y) # 输出: # tensor([[1., 9.], # [5., 7.]]) # 3. 处理负数(会返回nan) z = torch.tensor([-1.0, 4.0]) sqrt_z = torch.sqrt(z) print("\n含负数的张量平方根结果:", sqrt_z) # 输出: tensor([nan, 2.]) ``` ### 三、进阶用法:避免负数报错/返回nan 如果你的数据可能包含负数,可以先对张量做处理(比如取绝对值),再计算平方根: ```python import torch # 处理可能含负数的张量 z = torch.tensor([-1.0, 4.0, -9.0]) # 先取绝对值,再开平方 sqrt_z_safe = torch.sqrt(torch.abs(z)) print("安全计算的平方根结果:", sqrt_z_safe) # 输出: tensor([1., 2., 3.]) ``` ### 四、等价写法 除了 `torch.sqrt(x)`,还可以用张量的方法调用: ```python x = torch.tensor([4.0, 9.0]) sqrt_x = x.sqrt() # 和 torch.sqrt(x) 效果完全一致 print(sqrt_x) # 输出: tensor([2., 3.]) ``` ### 总结 1. `torch.sqrt()` 是逐元素计算张量平方根的函数,输入需为非负数(否则返回nan)。 2. 常用写法有两种:`torch.sqrt(张量)` 或 `张量.sqrt()`。 3. 处理可能含负数的张量时,建议先用 `torch.abs()` 取绝对值,避免出现nan。 原文出处:http://malaoshi.top/show_1GW2e7wRErGM.html