51工具盒子

依楼听风雨
笑看云卷云舒,淡观潮起潮落

PyTorch Tensor 拼接操作

张量的拼接操作在神经网络搭建过程中是非常常用的方法,例如: 在残差网络、注意力机制中都使用到了张量拼接。

  1. torch.cat 函数的使用

  2. torch.stack 函数的使用

  3. torch.cat 函数的使用 {#title-0} =============================

torch.cat 函数主要用于根据指定的维度将两个张量拼接到一起,用于拼接的两个张量的维度一般相同。

import torch

def test():

data1 = torch.randint(0, 10, [3, 5, 4])
data2 = torch.randint(0, 10, [3, 5, 4])

print(data1)
print(data2)
print('-' * 50)

# 1. 按0维度拼接
new_data = torch.cat([data1, data2], dim=0)
print(new_data)
print('-' * 50)

# 2. 按1维度拼接
new_data = torch.cat([data1, data2], dim=1)
print(new_data)
print('-' * 50)

# 3. 按2维度拼接
new_data = torch.cat([data1, data2], dim=2)
print(new_data)

if name == 'main': test()

输出结果:

tensor([[[1, 0, 3, 9],
         [4, 1, 7, 3],
         [5, 6, 5, 1],
         [8, 8, 1, 3],
         [3, 6, 8, 9]],
    [[0, 0, 9, 3],
     [3, 9, 3, 5],
     [9, 2, 8, 6],
     [5, 3, 6, 9],
     [6, 2, 0, 4]],

    [[0, 7, 3, 0],
     [0, 2, 8, 2],
     [4, 9, 6, 8],
     [0, 7, 9, 9],
     [9, 9, 8, 1]]])

tensor([[[4, 6, 3, 0], [4, 4, 4, 3], [0, 2, 4, 6], [5, 6, 6, 7], [0, 7, 1, 6]],

    [[0, 2, 5, 8],
     [8, 2, 1, 8],
     [9, 4, 9, 7],
     [3, 6, 7, 8],
     [4, 8, 7, 0]],

    [[6, 0, 2, 5],
     [2, 4, 6, 3],
     [3, 7, 7, 0],
     [8, 6, 0, 0],
     [6, 7, 3, 8]]])

tensor([[[1, 0, 3, 9], [4, 1, 7, 3], [5, 6, 5, 1], [8, 8, 1, 3], [3, 6, 8, 9]],

    [[0, 0, 9, 3],
     [3, 9, 3, 5],
     [9, 2, 8, 6],
     [5, 3, 6, 9],
     [6, 2, 0, 4]],

    [[0, 7, 3, 0],
     [0, 2, 8, 2],
     [4, 9, 6, 8],
     [0, 7, 9, 9],
     [9, 9, 8, 1]],

    [[4, 6, 3, 0],
     [4, 4, 4, 3],
     [0, 2, 4, 6],
     [5, 6, 6, 7],
     [0, 7, 1, 6]],

    [[0, 2, 5, 8],
     [8, 2, 1, 8],
     [9, 4, 9, 7],
     [3, 6, 7, 8],
     [4, 8, 7, 0]],

    [[6, 0, 2, 5],
     [2, 4, 6, 3],
     [3, 7, 7, 0],
     [8, 6, 0, 0],
     [6, 7, 3, 8]]])

tensor([[[1, 0, 3, 9], [4, 1, 7, 3], [5, 6, 5, 1], [8, 8, 1, 3], [3, 6, 8, 9], [4, 6, 3, 0], [4, 4, 4, 3], [0, 2, 4, 6], [5, 6, 6, 7], [0, 7, 1, 6]],

    [[0, 0, 9, 3],
     [3, 9, 3, 5],
     [9, 2, 8, 6],
     [5, 3, 6, 9],
     [6, 2, 0, 4],
     [0, 2, 5, 8],
     [8, 2, 1, 8],
     [9, 4, 9, 7],
     [3, 6, 7, 8],
     [4, 8, 7, 0]],

    [[0, 7, 3, 0],
     [0, 2, 8, 2],
     [4, 9, 6, 8],
     [0, 7, 9, 9],
     [9, 9, 8, 1],
     [6, 0, 2, 5],
     [2, 4, 6, 3],
     [3, 7, 7, 0],
     [8, 6, 0, 0],
     [6, 7, 3, 8]]])

tensor([[[1, 0, 3, 9, 4, 6, 3, 0], [4, 1, 7, 3, 4, 4, 4, 3], [5, 6, 5, 1, 0, 2, 4, 6], [8, 8, 1, 3, 5, 6, 6, 7], [3, 6, 8, 9, 0, 7, 1, 6]],

    [[0, 0, 9, 3, 0, 2, 5, 8],
     [3, 9, 3, 5, 8, 2, 1, 8],
     [9, 2, 8, 6, 9, 4, 9, 7],
     [5, 3, 6, 9, 3, 6, 7, 8],
     [6, 2, 0, 4, 4, 8, 7, 0]],

    [[0, 7, 3, 0, 6, 0, 2, 5],
     [0, 2, 8, 2, 2, 4, 6, 3],
     [4, 9, 6, 8, 3, 7, 7, 0],
     [0, 7, 9, 9, 8, 6, 0, 0],
     [9, 9, 8, 1, 6, 7, 3, 8]]])


  1. torch.stack 函数的使用 {#title-1} ===============================

torch.cat 函数主要用于根据指定的维度将两个张量叠加到一起,用于拼接的两个张量的维度一般相同,其结果会使得数据增加一维。

import torch

def test():

data1= torch.randint(0, 10, [2, 3])
data2= torch.randint(0, 10, [2, 3])
print(data1)
print(data2)
print('-' * 50)

new_data = torch.stack([data1, data2], dim=0)
print(new_data)
print('-' * 50)

new_data = torch.stack([data1, data2], dim=1)
print(new_data)
print('-' * 50)

new_data = torch.stack([data1, data2], dim=2)
print(new_data)

if name == 'main': test()

输出结果:

tensor([[2, 9, 8],
        [9, 0, 1]])
tensor([[7, 3, 6],
        [1, 0, 3]])
--------------------------------------------------
tensor([[[2, 9, 8],
         [9, 0, 1]],
    [[7, 3, 6],
     [1, 0, 3]]])

tensor([[[2, 9, 8], [7, 3, 6]],

    [[9, 0, 1],
     [1, 0, 3]]])

tensor([[[2, 7], [9, 3], [8, 6]],

    [[9, 1],
     [0, 0],
     [1, 3]]])

至此,本篇文章结束。

赞(6)
未经允许不得转载:工具盒子 » PyTorch Tensor 拼接操作