如何在 PyTorch 中对张量的元素进行排序?

要在 PyTorch 中对张量的元素进行排序,我们可以使用方法。此方法返回两个张量。第一个张量是具有元素排序值的张量,第二个张量是原始张量中元素索引的张量。我们可以按行和列计算二维张量。torch.sort()

脚步

  • 导入所需的库。在以下所有 Python 示例中,所需的 Python 库是torch。确保您已经安装了它。

  • 创建一个 PyTorch 张量并打印它。

  • 要对上面创建的张量的元素进行排序,请计算。将此值分配给一个新变量"v" 。这里,输入是输入张量,dim是元素排序的维度。对元素按行排序,dim 设置为 1,按列对元素排序,dim设置为 0。torch.sort(input, dim)

  • 具有排序值的张量可以作为v[0]访问,排序元素的索引张量作为v[1] 访问。

  • 打印带有排序值的张量和带有排序值索引的张量。

示例 1

以下 Python 程序展示了如何对一维张量的元素进行排序。

# Python program to sort elements of a tensor

# import necessary library

import torch

# Create a tensor

T = torch.Tensor([2.334,4.433,-4.33,-0.433,5, 4.443])

print("Original Tensor:\n", T)

# sort the tensor T

# it sorts the tensor in ascending order

v = torch.sort(T)

# print(v)

# print tensor of sorted value

print("Tensor with sorted value:\n", v[0])

# print indices of sorted value

print("Indices of sorted value:\n", v[1])

输出结果
Original Tensor:

   tensor([ 2.3340, 4.4330, -4.3300, -0.4330, 5.0000, 4.4430])

Tensor with sorted value:

   tensor([-4.3300, -0.4330, 2.3340, 4.4330, 4.4430, 5.0000])

Indices of sorted value:

   tensor([2, 3, 0, 1, 5, 4])

示例 2

以下 Python 程序显示了如何对 2D 张量的元素进行排序。

# Python program to sort elements of a 2-D tensor

# import the library

import torch

# Create a 2-D tensor

T = torch.Tensor([[2,3,-32],

                  [43,4,-53],

                  [4,37,-4],

                  [3,-75,34]])

print("Original Tensor:\n", T)

# sort tensor T

# it sorts the tensor in ascending order

v = torch.sort(T)

# print(v)

# print tensor of sorted value

print("Tensor with sorted value:\n", v[0])

# print indices of sorted value

print("Indices of sorted value:\n", v[1])

print("Sort tensor Column-wise")

v = torch.sort(T, 0)

# print(v)

# print tensor of sorted value

print("Tensor with sorted value:\n", v[0])

# print indices of sorted value

print("Indices of sorted value:\n", v[1])

print("Sort tensor Row-wise")

v = torch.sort(T, 1)

# print(v)

# print tensor of sorted value

print("Tensor with sorted value:\n", v[0])

# print indices of sorted value

print("Indices of sorted value:\n", v[1])

输出结果
Original Tensor:

tensor([[ 2., 3., -32.],

        [ 43., 4., -53.],

        [ 4., 37., -4.],

        [ 3., -75., 34.]])

Tensor with sorted value:

tensor([[-32., 2., 3.],

         [-53., 4., 43.],

         [ -4., 4., 37.],

         [-75., 3., 34.]])

Indices of sorted value:

tensor([[2, 0, 1],

         [2, 1, 0],

         [2, 0, 1],

         [1, 0, 2]])

Sort tensor Column-wise

Tensor with sorted value:

tensor([[ 2., -75., -53.],

         [ 3., 3., -32.],

         [ 4., 4., -4.],

         [ 43., 37., 34.]])

Indices of sorted value:

tensor([[0, 3, 1],

         [3, 0, 0],

         [2, 1, 2],

         [1, 2, 3]])

Sort tensor Row-wise

Tensor with sorted value:

tensor([[-32., 2., 3.],

         [-53., 4., 43.],

         [ -4., 4., 37.],

         [-75., 3., 34.]])

Indices of sorted value:

tensor([[2, 0, 1],

         [2, 1, 0],

         [2, 0, 1],

         [1, 0, 2]])

以上是 如何在 PyTorch 中对张量的元素进行排序? 的全部内容, 来源链接: utcz.com/z/359950.html

回到顶部