import torch
import numpy as np
# 1. 张量和 numpy 的转换
def test01():
data_tensor = torch.tensor([2, 3, 4])
# 将张量转换为 numpy 数组
data_numpy = data_tensor.numpy()
# data_tensor 和 data_numpy 共享内存
# 修改其中任意一个变量, 另外一个也会发生改变
data_tensor[0] = 100
print(data_numpy)
data_numpy[0] = 200
print(data_tensor)
# 2. 写时拷贝
def test02():
data_numpy = np.array([2, 3, 4])
data_numpy[0] = 100
# data_tensor 与 data_numpy 共享内存
# 修改 data_numpy 会导致 data_tensor 发生改变
data_tensor = torch.from_numpy(data_numpy)
print(data_tensor)
# 当 data_tensor 修改时, 发生写时拷贝(延迟拷贝)
# data_numpy 并不会改变
data_tensor[0] = 200
print(data_numpy)
# 3. 标量张量和数字的转换
def test03():
# 当张量只包含一个元素时, 可以通过 item 函数提取出该值
data = torch.tensor([30,])
print(data.item())
data = torch.tensor(30)
print(data.item())
if __name__ == '__main__':
test02()
PyTorch Tensor 类型转换
未经允许不得转载:工具盒子 » PyTorch Tensor 类型转换