张量的拼接操作在神经网络搭建过程中是非常常用的方法,例如: 在残差网络、注意力机制中都使用到了张量拼接。
-
torch.cat 函数的使用
-
torch.stack 函数的使用
-
torch.cat 函数的使用 {#title-0} =============================
torch.cat 函数主要用于根据指定的维度将两个张量拼接到一起,用于拼接的两个张量的维度一般相同。
import torch
def test():
data1 = torch.randint(0, 10, [3, 5, 4])
data2 = torch.randint(0, 10, [3, 5, 4])
print(data1)
print(data2)
print('-' * 50)
# 1. 按0维度拼接
new_data = torch.cat([data1, data2], dim=0)
print(new_data)
print('-' * 50)
# 2. 按1维度拼接
new_data = torch.cat([data1, data2], dim=1)
print(new_data)
print('-' * 50)
# 3. 按2维度拼接
new_data = torch.cat([data1, data2], dim=2)
print(new_data)
if name == 'main':
test()
输出结果:
tensor([[[1, 0, 3, 9],
[4, 1, 7, 3],
[5, 6, 5, 1],
[8, 8, 1, 3],
[3, 6, 8, 9]],
[[0, 0, 9, 3],
[3, 9, 3, 5],
[9, 2, 8, 6],
[5, 3, 6, 9],
[6, 2, 0, 4]],
[[0, 7, 3, 0],
[0, 2, 8, 2],
[4, 9, 6, 8],
[0, 7, 9, 9],
[9, 9, 8, 1]]])
tensor([[[4, 6, 3, 0],
[4, 4, 4, 3],
[0, 2, 4, 6],
[5, 6, 6, 7],
[0, 7, 1, 6]],
[[0, 2, 5, 8],
[8, 2, 1, 8],
[9, 4, 9, 7],
[3, 6, 7, 8],
[4, 8, 7, 0]],
[[6, 0, 2, 5],
[2, 4, 6, 3],
[3, 7, 7, 0],
[8, 6, 0, 0],
[6, 7, 3, 8]]])
tensor([[[1, 0, 3, 9],
[4, 1, 7, 3],
[5, 6, 5, 1],
[8, 8, 1, 3],
[3, 6, 8, 9]],
[[0, 0, 9, 3],
[3, 9, 3, 5],
[9, 2, 8, 6],
[5, 3, 6, 9],
[6, 2, 0, 4]],
[[0, 7, 3, 0],
[0, 2, 8, 2],
[4, 9, 6, 8],
[0, 7, 9, 9],
[9, 9, 8, 1]],
[[4, 6, 3, 0],
[4, 4, 4, 3],
[0, 2, 4, 6],
[5, 6, 6, 7],
[0, 7, 1, 6]],
[[0, 2, 5, 8],
[8, 2, 1, 8],
[9, 4, 9, 7],
[3, 6, 7, 8],
[4, 8, 7, 0]],
[[6, 0, 2, 5],
[2, 4, 6, 3],
[3, 7, 7, 0],
[8, 6, 0, 0],
[6, 7, 3, 8]]])
tensor([[[1, 0, 3, 9],
[4, 1, 7, 3],
[5, 6, 5, 1],
[8, 8, 1, 3],
[3, 6, 8, 9],
[4, 6, 3, 0],
[4, 4, 4, 3],
[0, 2, 4, 6],
[5, 6, 6, 7],
[0, 7, 1, 6]],
[[0, 0, 9, 3],
[3, 9, 3, 5],
[9, 2, 8, 6],
[5, 3, 6, 9],
[6, 2, 0, 4],
[0, 2, 5, 8],
[8, 2, 1, 8],
[9, 4, 9, 7],
[3, 6, 7, 8],
[4, 8, 7, 0]],
[[0, 7, 3, 0],
[0, 2, 8, 2],
[4, 9, 6, 8],
[0, 7, 9, 9],
[9, 9, 8, 1],
[6, 0, 2, 5],
[2, 4, 6, 3],
[3, 7, 7, 0],
[8, 6, 0, 0],
[6, 7, 3, 8]]])
tensor([[[1, 0, 3, 9, 4, 6, 3, 0],
[4, 1, 7, 3, 4, 4, 4, 3],
[5, 6, 5, 1, 0, 2, 4, 6],
[8, 8, 1, 3, 5, 6, 6, 7],
[3, 6, 8, 9, 0, 7, 1, 6]],
[[0, 0, 9, 3, 0, 2, 5, 8],
[3, 9, 3, 5, 8, 2, 1, 8],
[9, 2, 8, 6, 9, 4, 9, 7],
[5, 3, 6, 9, 3, 6, 7, 8],
[6, 2, 0, 4, 4, 8, 7, 0]],
[[0, 7, 3, 0, 6, 0, 2, 5],
[0, 2, 8, 2, 2, 4, 6, 3],
[4, 9, 6, 8, 3, 7, 7, 0],
[0, 7, 9, 9, 8, 6, 0, 0],
[9, 9, 8, 1, 6, 7, 3, 8]]])
- torch.stack 函数的使用 {#title-1} ===============================
torch.cat 函数主要用于根据指定的维度将两个张量叠加到一起,用于拼接的两个张量的维度一般相同,其结果会使得数据增加一维。
import torch
def test():
data1= torch.randint(0, 10, [2, 3])
data2= torch.randint(0, 10, [2, 3])
print(data1)
print(data2)
print('-' * 50)
new_data = torch.stack([data1, data2], dim=0)
print(new_data)
print('-' * 50)
new_data = torch.stack([data1, data2], dim=1)
print(new_data)
print('-' * 50)
new_data = torch.stack([data1, data2], dim=2)
print(new_data)
if name == 'main':
test()
输出结果:
tensor([[2, 9, 8],
[9, 0, 1]])
tensor([[7, 3, 6],
[1, 0, 3]])
--------------------------------------------------
tensor([[[2, 9, 8],
[9, 0, 1]],
[[7, 3, 6],
[1, 0, 3]]])
tensor([[[2, 9, 8],
[7, 3, 6]],
[[9, 0, 1],
[1, 0, 3]]])
tensor([[[2, 7],
[9, 3],
[8, 6]],
[[9, 1],
[0, 0],
[1, 3]]])
至此,本篇文章结束。