如何在 PyTorch 中获取张量的数据类型?

PyTorch 张量是同质的,即张量的所有元素都具有相同的数据类型。我们可以使用张量的“.dtype”属性访问张量的数据类型。它返回张量的数据类型。

脚步

  • 导入所需的库。在以下所有 Python 示例中,所需的 Python 库是torch。确保您已经安装了它。

  • 创建一个张量并打印它。

  • 计算T.dtype。这里 T 是我们想要获取数据类型的张量。

  • 打印张量的数据类型。

示例 1

以下 Python 程序展示了如何获取张量的数据类型。

# Import the library

import torch

# Create a tensor of random numbers of size 3x4

T = torch.randn(3,4)

print("Original Tensor T:\n", T)

# Get the data type of above tensor

data_type = T.dtype

# Print the data type of the tensor

print("Data type of tensor T:\n", data_type)

输出结果
Original Tensor T:

tensor([[ 2.1768, -0.1328, 0.8155, -0.7967],

         [ 0.1194, 1.0465, 0.0779, 0.9103],

         [-0.1809, 1.8085, 0.8393, -0.2463]])

Data type of tensor T:

torch.float32

示例 2

# Python program to get data type of a tensor

# Import the library

import torch

# Create a tensor of random numbers of size 3x4

T = torch.Tensor([1,2,3,4])

print("Original Tensor T:\n", T)

# Get the data type of above tensor

data_type = T.dtype

# Print the data type of the tensor

print("Data type of tensor T:\n", data_type)

输出结果
Original Tensor T:

   tensor([1., 2., 3., 4.])

Data type of tensor T:

   torch.float32

以上是 如何在 PyTorch 中获取张量的数据类型? 的全部内容, 来源链接: utcz.com/z/343732.html

回到顶部