pytorch api文档:torch.transpose() 函数(转置操作、行列互换) 作者:马育民 • 2026-01-17 17:25 • 阅读:10001 # 介绍 `torch.transpose()` 函数,这是比 `.T` 更灵活的张量维度互换工具,能精准指定任意两个维度进行转置,是处理高维张量(如批量矩阵、图像、序列)的核心函数,我会从基础用法、参数细节到实战场景帮你全面掌握。 # 作用 对输入张量的**指定两个维度(dim0 和 dim1)** 进行互换(转置),其他维度保持不变,返回原张量的视图(view,共享内存)而非副本。 - 对比 `.T`:`.T` 是“反转所有维度”(仅适合二维张量),而 `torch.transpose()` 是“精准互换两个维度”(适合所有维度张量); - 核心优势:灵活性高,可针对高维张量的任意两个维度做转置,不影响其他维度。 # 语法 ``` torch.transpose(input, dim0, dim1) ``` ### 参数 | 参数 | 作用 | 注意事项 | |---------|----------------------------------------------------------------------|--------------------------------------------------------------------------| | `input` | 待转置的输入张量 | 支持任意维度(1维/2维/高维),1维张量转置无效果(dim0=dim1=0 无意义)| | `dim0` | 要互换的第一个维度(索引从 0 开始)| 必须是合法的维度索引(如 shape [2,3,4] 的维度索引只能是 0/1/2)| | `dim1` | 要互换的第二个维度 | 必须与 dim0 不同,否则返回原张量 | ### 易错点:维度索引错误 ```python # 错误示例:dim0/dim1 超出张量维度范围 x = torch.randn(2,3) # torch.transpose(x, 0, 2) # 报错:Dimension out of range (expected to be in range of [-2, 1], but got 2) # 错误示例:dim0=dim1(无意义) x_trans_bad = torch.transpose(x, 0, 0) print(torch.equal(x, x_trans_bad)) # True(返回原张量) ``` # 例子 先通过不同维度的张量示例,理解 `dim0` 和 `dim1` 参数的作用: ```python import torch # 示例1:二维张量(矩阵)转置(等价于 .T) x = torch.tensor([[1, 2, 3], [4, 5, 6]]) # shape [2, 3](2行3列) # 互换维度0(行)和维度1(列) x_trans = torch.transpose(x, 0, 1) print("原张量:\n", x) print("转置后张量:\n", x_trans) # 输出: # tensor([[1, 4], # [2, 5], # [3, 6]]) # shape [3, 2] print("转置后形状:", x_trans.shape) # torch.Size([3, 2]) # 示例2:三维张量(批量矩阵)转置(核心场景) # shape [batch, row, col] = [2, 3, 4](2个 3×4 的矩阵) batch_mat = torch.tensor([[[1,2,3,4], [5,6,7,8], [9,10,11,12]], [[13,14,15,16], [17,18,19,20], [21,22,23,24]]]) # 对每个批量矩阵,互换 row(dim1)和 col(dim2)→ 批量矩阵转置 batch_mat_trans = torch.transpose(batch_mat, 1, 2) print("原批量矩阵形状:", batch_mat.shape) # [2,3,4] print("转置后形状:", batch_mat_trans.shape) # [2,4,3](每个矩阵从 3×4 → 4×3) print("第一个矩阵转置结果:\n", batch_mat_trans[0]) # 输出: # [[ 1, 5, 9], # [ 2, 6, 10], # [ 3, 7, 11], # [ 4, 8, 12]] # 示例3:四维张量(图像数据)转置 # shape [batch, H, W, C] = [4, 28, 28, 3](4张28×28像素的3通道图像) img = torch.randn(4, 28, 28, 3) # 互换 W(dim2)和 C(dim3)→ 调整通道维度位置 img_trans = torch.transpose(img, 2, 3) print("图像转置后形状:", img_trans.shape) # [4,28,3,28] ``` # 返回视图,共享内存 `torch.transpose()` 不会创建新张量,转置后的张量与原张量共享内存,修改其一会同步修改另一个: ```python x = torch.tensor([[1,2],[3,4]]) x_trans = torch.transpose(x, 0, 1) # 修改转置后的张量 x_trans[0,1] = 100 print("转置后张量:\n", x_trans) # tensor([[ 1, 100], # [ 2, 4]]) print("原张量(同步修改):\n", x) # tensor([[ 1, 2], # [100, 4]]) # 如需独立张量,加 .clone() x_trans_clone = torch.transpose(x, 0, 1).clone() x_trans_clone[0,1] = 200 print("原张量(不受影响):\n", x) # 仍为 [[1,2],[100,4]] ``` # 与 .T、permute() 的区别 PyTorch 中有三种维度调整工具,需根据场景选择: | 工具 | 核心逻辑 | 适用场景 | |-----------------------|---------------------------|-----------------------------------| | `tensor.T` | 反转所有维度(如 [a,b,c]→[c,b,a]) | 仅二维张量转置(矩阵)| | `torch.transpose()` | 互换指定两个维度 | 仅需互换两个维度(如批量矩阵转置)| | `torch.permute()` | 重新排列所有维度 | 需调整多个维度(如 [b,h,w,c]→[b,c,h,w]) | ```python # 示例:permute 调整多个维度(比 transpose 更灵活) img = torch.randn(4, 28, 28, 3) # [batch, H, W, C] # 目标:转为 [batch, C, H, W](PyTorch 图像标准格式) # 方法1:两次 transpose(麻烦) img_trans1 = torch.transpose(torch.transpose(img, 1, 3), 2, 3) # 方法2:一次 permute(简洁) img_permute = img.permute(0, 3, 1, 2) print("transpose 结果形状:", img_trans1.shape) # [4,3,28,28] print("permute 结果形状:", img_permute.shape) # [4,3,28,28] ``` # 可求梯度的张量转置 如果原张量开启 `requires_grad=True`(可训练参数),转置后的张量仍保留梯度属性,不影响反向传播: ```python x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True) x_trans = torch.transpose(x, 0, 1) # 模拟反向传播 y = x_trans.sum() y.backward() print("原张量的梯度:\n", x.grad) # tensor([[1., 1.], # [1., 1.]])(转置不影响梯度计算) ``` # 应用场景 ### (1)批量矩阵乘法 处理批量矩阵时,常需要转置其中一个张量以满足矩阵乘法的维度要求: ```python # 场景:2个 3×4 矩阵 和 2个 3×4 矩阵做批量矩阵乘法(需转置第二个矩阵) a = torch.randn(2, 3, 4) # [batch, m, n] b = torch.randn(2, 3, 4) # [batch, m, n] # 对b的每个矩阵转置(dim1和dim2互换)→ [batch, n, m] = [2,4,3] b_trans = torch.transpose(b, 1, 2) # 批量矩阵乘法:(2,3,4) × (2,4,3) → (2,3,3) result = a @ b_trans print("批量矩阵乘法结果形状:", result.shape) # [2,3,3] ``` ### (2)图像维度调整 将 OpenCV/PIL 格式的图像(H,W,C)转为 PyTorch 格式(C,H,W),或反之: ```python # 模拟:读取图像(H=28, W=28, C=3) img_cv = torch.randn(28, 28, 3) # 转为 PyTorch 格式(C,H,W):互换 H(dim0) 和 C(dim2) img_torch = torch.transpose(torch.transpose(img_cv, 0, 2), 1, 2) print("PyTorch格式形状:", img_torch.shape) # [3,28,28] # 转回 OpenCV 格式:反向转置 img_cv2 = torch.transpose(torch.transpose(img_torch, 0, 2), 1, 2) print("转回后形状:", img_cv2.shape) # [28,28,3] ``` # 总结 1. `torch.transpose(input, dim0, dim1)` 是**精准互换两个维度**的转置函数,其他维度保持不变,比 `.T` 更灵活。 2. 核心特性:返回原张量的视图(共享内存),需独立张量则加 `.clone()`;支持任意维度张量,1维张量转置无效果。 3. 场景选择:二维矩阵转置可用 `.T` 或 `transpose(0,1)`;高维张量仅互换两个维度用 `transpose()`;调整多个维度用 `permute()`。 原文出处:http://malaoshi.top/show_1GW2bqM8Ufqm.html