如何在 PyTorch 中找到张量的第 k 个和前“k”个元素?

PyTorch 提供了一种查找张量的第 k 个元素的方法。它返回按升序排序的张量的第 k 个元素的值,以及该元素在原始张量中的索引。torch.kthvalue()

torch.topk()方法用于查找前“k”个元素。它返回张量中最高的“k”或最大的“k”元素。

脚步

  • 导入所需的库。在以下所有 Python 示例中,所需的 Python 库是torch。确保您已经安装了它。

  • 创建一个 PyTorch 张量并打印它。

  • 计算。它返回两个张量。将这两个张量分配给两个新变量"value"和"index"。这里,输入是一个张量,k 是一个整数。torch.kthvalue(input, k)

  • 计算。它返回两个张量。第一个张量具有前“k”个元素的值,第二个张量具有原始张量中这些元素的索引。将这两个张量分配给新变量"values"和"indices"。torch.topk(input, k)

  • 打印张量的第 k 个元素的值和索引,以及张量的前“k”个元素的值和索引。

示例 1

这个 python 程序展示了如何找到张量的第 k 个元素。

# Python program to find k-th element of a tensor

# import necessary library

import torch

# Create a 1D tensor

T = torch.Tensor([2.334,4.433,-4.33,-0.433,5, 4.443])

print("Original Tensor:\n", T)

# Find the 3rd element in sorted tensor. First it sorts the

# tensor in ascending order then returns the kth element value

# from sorted tensor and the index of element in original tensor

value, index = torch.kthvalue(T, 3)

# print 3rd element with value and index

print("第三个元素值:", value)

print("第三个元素索引:", index)

输出结果
Original Tensor:

   tensor([ 2.3340, 4.4330, -4.3300, -0.4330, 5.0000, 4.4430])

第三个元素值: tensor(2.3340)

第三个元素索引: tensor(0)

示例 2

以下 Python 程序显示了如何查找张量的前“k”个或最大的“k”个元素。

# Python program to find to top k elements of a tensor

# import necessary library

import torch

# Create a 1D tensor

T = torch.Tensor([2.334,4.433,-4.33,-0.433,5, 4.443])

print("Original Tensor:\n", T)

# Find the top k=2 or 2 largest elements of the tensor

# returns the 2 largest values and their indices in original

# tensor

values, indices = torch.topk(T, 2)

# print top 2 elements with value and index

print("前 2 个元素值:", values)

print("前 2 个元素索引:", indices)

输出结果
Original Tensor:

   tensor([ 2.3340, 4.4330, -4.3300, -0.4330, 5.0000, 4.4430])

前 2 个元素值: tensor([5.0000, 4.4430])

前 2 个元素索引: tensor([4, 5])

以上是 如何在 PyTorch 中找到张量的第 k 个和前“k”个元素? 的全部内容, 来源链接: utcz.com/z/338670.html

回到顶部