pytorch api文档:torch.randperm()生成无重复随机整数的一维张量 作者:马育民 • 2026-01-18 09:50 • 阅读:10003 # 介绍 `torch.randperm()` 函数,这是生成 **无重复随机整数 的 一维张量** 的专用工具,核心作用是创建 `0 到 n-1` 的随机打乱序列,常用来打乱数据集、生成随机索引等场景 ### 作用 生成一个包含 `0, 1, 2, ..., n-1` 所有整数的 **一维张量**,且这些整数被 **随机打乱顺序**(无重复、不遗漏)。 它和 `torch.randint()` 的核心区别: - `torch.randint()`:生成的整数可能重复(如 [2,2,5]); ### 应用场景 - 打乱数据集(最常用) - 随机划分训练/测试集 - 随机选取无重复样本 # 语法 ``` torch.randperm(n, dtype=torch.int64, device=None, requires_grad=False) ``` #### 参数 | 参数 | 作用 | 注意事项 | |----------------|----------------------------------------------------------------------|--------------------------------------------------------------------------| | `n` | 生成 0~n-1 的随机排列,必传(整数)| 若 n=0,返回空张量;n=1,返回 [0] | | `dtype` | 指定整数类型(`torch.int32`/`torch.int64`)| 默认 `torch.int64`(LongTensor),兼容大部分索引场景 | | `device` | 指定存储设备(`cpu`/`cuda`)| 需和待索引的张量同设备,避免数据迁移 | | `requires_grad`| 是否开启梯度计算 | 整数张量无法求梯度,设为 `True` 也无效(仅为参数兼容)| # 例子 `n` 是唯一必传参数(指定生成整数的最大值+1),直接看代码更直观: ```python import torch # 示例1:基础用法(生成 0~4 的随机排列) perm1 = torch.randperm(5) print("0~4 随机排列:", perm1) # 输出示例:tensor([3, 1, 4, 0, 2])(包含0-4所有数,无重复) print("形状:", perm1.shape) # torch.Size([5]) # 示例2:固定随机种子(复现结果) torch.manual_seed(42) perm2 = torch.randperm(5) print("固定种子结果1:", perm2) # tensor([2, 4, 1, 0, 3]) torch.manual_seed(42) perm3 = torch.randperm(5) print("固定种子结果2:", perm3) # tensor([2, 4, 1, 0, 3])(完全一致) # 示例3:指定数据类型(int32) perm4 = torch.randperm(5, dtype=torch.int32) print("数据类型:", perm4.dtype) # torch.int32 print("数值:", perm4) # tensor([2, 4, 1, 0, 3], dtype=torch.int32) ``` ### 参数示例 ```python # 示例1:指定GPU设备(需CUDA环境) if torch.cuda.is_available(): perm_cuda = torch.randperm(10, device="cuda") print("GPU张量设备:", perm_cuda.device) # cuda:0 print("GPU张量:", perm_cuda) else: print("无CUDA环境,默认生成CPU张量") # 示例2:n=0/1的边界情况 print("n=0:", torch.randperm(0)) # tensor([], dtype=torch.int64) print("n=1:", torch.randperm(1)) # tensor([0]) ``` ### 绝对无重复(核心优势) `torch.randperm()` 生成的序列包含 0~n-1 所有整数,且每个数仅出现一次,这是它和 `torch.randint()` 的核心差异: ```python # 对比:randint 可能重复,randperm 绝对无重复 randint_res = torch.randint(0, 5, size=(5,)) # 生成5个0~4的整数(可能重复) randperm_res = torch.randperm(5) # 生成0~4的全排列(无重复) print("randint 结果(可能重复):", randint_res) # 输出示例:tensor([2, 2, 4, 1, 2]) print("randperm 结果(无重复):", randperm_res) # 输出示例:tensor([3, 0, 4, 1, 2]) # 验证:randperm 包含所有0~4的数 print("randperm 包含所有数:", sorted(randperm_res.tolist()) == [0,1,2,3,4]) # True ``` ### 仅生成一维张量 `torch.randperm()` 只能生成一维张量,若需高维随机无重复索引,需先生成一维排列再重塑形状: ```python # 需求:生成 2×3 的无重复索引(0~5) perm = torch.randperm(6) # 先生成0~5的一维排列 perm_2d = perm.reshape(2, 3) # 重塑为二维 print("一维排列:", perm) print("二维无重复索引:\n", perm_2d) # 输出示例: # tensor([3, 1, 4, 0, 2, 5]) # tensor([[3, 1, 4], # [0, 2, 5]]) ``` # 实战核心场景 ### 打乱数据集(最常用) 训练模型时,需随机打乱数据集的样本顺序,`randperm()` 是最优选择: ```python # 模拟:数据集(10个样本,5维特征)+ 标签 data = torch.randn(10, 5) labels = torch.randint(0, 2, size=(10,)) # 步骤1:生成0~9的随机排列(打乱索引) shuffle_indices = torch.randperm(10) print("打乱索引:", shuffle_indices) # 步骤2:用索引打乱数据和标签(保证数据-标签对应) shuffled_data = data[shuffle_indices] shuffled_labels = labels[shuffle_indices] print("\n原始数据前3行:\n", data[:3]) print("打乱后数据前3行:\n", shuffled_data[:3]) print("\n原始标签:", labels) print("打乱后标签:", shuffled_labels) ``` ### 随机划分训练/测试集 生成无重复的索引,实现数据集的随机划分(无重叠): ```python # 场景:将100个样本按8:2划分为训练集和测试集 total_samples = 100 train_ratio = 0.8 train_size = int(total_samples * train_ratio) # 步骤1:生成0~99的随机排列 all_indices = torch.randperm(total_samples) # 步骤2:划分训练/测试索引 train_indices = all_indices[:train_size] test_indices = all_indices[train_size:] print("训练集索引数量:", len(train_indices)) # 80 print("测试集索引数量:", len(test_indices)) # 20 print("索引是否重叠:", len(set(train_indices.tolist()) & set(test_indices.tolist())) == 0) # True ``` ### 随机选取无重复样本 从数据集中选取指定数量的无重复样本(区别于 `randint()` 的可能重复): ```python # 场景:从50个样本中选10个无重复样本 total = 50 select_num = 10 # 生成0~49的随机排列,取前10个作为索引 select_indices = torch.randperm(total)[:select_num] # 模拟采样 data = torch.randn(total, 10) selected_data = data[select_indices] print("选取的索引:", select_indices) print("选取的样本形状:", selected_data.shape) # torch.Size([10, 10]) print("索引是否重复:", len(select_indices) == len(set(select_indices.tolist()))) # True ``` # 总结 1. `torch.randperm(n)` 生成 **0~n-1 的随机无重复全排列**,是唯一能保证整数无重复的随机索引生成函数。 2. 核心特性:仅生成一维张量,绝对无重复,固定种子可复现结果;与 `randint()` 相比,前者适合“无重复打乱”,后者适合“可重复随机整数”。 3. 适用场景:打乱数据集、划分训练/测试集、选取无重复样本等,是深度学习数据预处理的核心工具。 原文出处:http://malaoshi.top/show_1GW2c67TYa8B.html