51工具盒子

依楼听风雨
笑看云卷云舒,淡观潮起潮落

PyTorch Tensor 索引操作

我们在操作张量时,经常需要去进行获取或者修改操作,掌握张量的花式索引操作是必须的一项能力。

import torch

data = torch.randint(0, 10, [4, 5]) print(data) print('-' * 50)

1. 简单行、列索引

def test01():

print(data[0])
print(data[:, 0])
print('-' * 50)

2. 列表索引

def test02():

# 返回 (0, 1)、(1, 2) 两个位置的元素
print(data[[0, 1], [1, 2]])
print('-' * 50)

# 返回 0、1 行的 1、2 列共4个元素
print(data[[[0], [1]], [1, 2]])

3. 范围索引

def test03(): # 前3行的前2列数据 print(data[:3, :2]) # 第2行到最后的前2列数据 print(data[2:, :2])

4. 布尔索引

def test04():

# 第三列大于5的行数据
print(data[data[:, 2] > 5])
# 第二行大于5的列数据
print(data[:, data[1] > 5])

5. 多维索引

def test05():

data = torch.randint(0, 10, [3, 4, 5])
print(data)
print('-' * 50)

print(data[0, :, :])
print(data[:, 0, :])
print(data[:, :, 0])

if name == 'main': test05()

赞(6)
未经允许不得转载:工具盒子 » PyTorch Tensor 索引操作