pytorch api文档:torch.flatten()函数-将张量展平成一维 作者:马育民 • 2026-01-28 15:19 • 阅读:10003 # 介绍 PyTorch 中的 `torch.flatten()` 函数,这是用于将张量(Tensor)展平成一维的常用函数。 ### 作用 `torch.flatten()` 的核心是将 **任意形状的张量**,按照指定维度范围压缩成 **一维张量**,是深度学习中(如全连接层前的特征展平)最常用的操作之一。 ### 使用场景 1. **CNN + 全连接层**:CNN 输出是 4 维张量 `(batch, channel, h, w)`,全连接层需要 2 维输入 `(batch, feature)`,此时用 `torch.flatten(x, start_dim=1)` 是标准操作; 2. **可视化/统计**:需要将多维数据转为一维,方便计算均值、方差或绘制分布图; 3. **数据预处理**:将高维特征(如图片像素)转为一维特征向量。 # 语法 ```python torch.flatten(input, start_dim=0, end_dim=-1) ``` #### 参数解释 | 参数 | 类型 | 默认值 | 核心作用 | |--------------|---------|--------|--------------------------------------------------------------------------| | `input` | Tensor | 无 | 必选参数,需要展平的输入张量(任意形状) | | `start_dim` | int | 0 | 展平的起始维度,从该维度开始往后的所有维度都会被展平 | | `end_dim` | int | -1 | 展平的结束维度,到该维度为止(-1 表示最后一个维度) | # 例子 先定义一个多维张量作为基础示例: ```python import torch # 定义一个 4 维张量:(批次, 通道, 高度, 宽度) x = torch.randn(2, 3, 4, 5) # shape: [2, 3, 4, 5] print("原始张量形状:", x.shape) ``` ### 1. 默认用法(展平所有维度) 不指定 `start_dim` 和 `end_dim`,默认从第 `0` 维展平到最后一维,得到一维张量: ```python flatten_all = torch.flatten(x) print("全展平后形状:", flatten_all.shape) # 输出: torch.Size([120]) (2*3*4*5=120) ``` ### 2. 指定起始维度(保留前 N 维) 最常用场景:深度学习中保留批次维度,只展平特征维度(比如 CNN 输出接全连接层): ```python # 保留批次维度(第0维),展平从第1维到最后一维 flatten_feature = torch.flatten(x, start_dim=1) print("保留批次展平后形状:", flatten_feature.shape) # 输出: torch.Size([2, 60]) (3*4*5=60) ``` ### 3. 指定起始+结束维度(部分展平) 只展平中间某几个维度,前后维度保留: ```python # 只展平第1、2维(通道、高度),保留第0维(批次)和第3维(宽度) flatten_part = torch.flatten(x, start_dim=1, end_dim=2) print("部分展平后形状:", flatten_part.shape) # 输出: torch.Size([2, 12, 5]) (3*4=12) ``` ### 4. 负数维度(按倒数维度指定) `end_dim` 支持负数,-1 表示最后一维,-2 表示倒数第二维,以此类推: ```python # 展平第1维到倒数第二维(高度),保留最后一维(宽度) flatten_negative = torch.flatten(x, start_dim=1, end_dim=-2) print("负数维度展平后形状:", flatten_negative.shape) # 输出: torch.Size([2, 12, 5])(和上面示例等价) ``` # 注意事项 1. **维度索引规则**:PyTorch 张量维度从 0 开始计数,比如 `(2,3,4,5)` 的维度索引是 0:批次、1:通道、2:高度、3:宽度; 2. **返回值特性**:`torch.flatten()` 返回的是原张量的**视图(view)**(非深拷贝),修改展平后的张量会影响原张量(节省内存); 3. **等价替代**:`torch.flatten(x, start_dim=1)` 等价于 `x.view(x.size(0), -1)`,后者是更底层的写法,效果完全一致。 # 总结 1. `torch.flatten()` 用于将张量展平为一维,核心参数是 `start_dim`(起始展平维度)和 `end_dim`(结束展平维度); 2. 最常用场景是 `start_dim=1`,保留批次维度,仅展平特征维度(适配全连接层输入); 3. 展平操作返回的是原张量的视图,不额外占用内存,维度索引从 0 开始,支持负数索引。 原文出处:http://malaoshi.top/show_1GW2ftlTVQM9.html