pytorch 常用函数 max ,eq说明
max找出tensor 的行或者列最大的值:
找出每行的最大值:
import torch
outputs=torch.FloatTensor([[1],[2],[3]])
print(torch.max(outputs.data,1))
输出:
(tensor([ 1., 2., 3.]), tensor([ 0, 0, 0]))
找出每列的最大值:
import torch
outputs=torch.FloatTensor([[1],[2],[3]])
print(torch.max(outputs.data,0))
输出结果:
(tensor([ 3.]), tensor([ 2]))
Tensor比较eq相等:
import torch
outputs=torch.FloatTensor([[1],[2],[3]])
targets=torch.FloatTensor([[0],[2],[3]])
print(targets.eq(outputs.data))
输出结果:
tensor([[ 0],
[ 1],
[ 1]], dtype=torch.uint8)
使用sum() 统计相等的个数:
import torch
outputs=torch.FloatTensor([[1],[2],[3]])
targets=torch.FloatTensor([[0],[2],[3]])
print(targets.eq(outputs.data).cpu().sum())
输出结果:
tensor(2)
补充知识:PyTorch - torch.eq、torch.ne、torch.gt、torch.lt、torch.ge、torch.le
flyfish
torch.eq、torch.ne、torch.gt、torch.lt、torch.ge、torch.le
以上全是简写
参数是input, other, out=None
逐元素比较input和other
返回是torch.BoolTensor
import torch
a=torch.tensor([[1, 2], [3, 4]])
b=torch.tensor([[1, 2], [4, 3]])
print(torch.eq(a,b))#equals
# tensor([[ True, True],
# [False, False]])
print(torch.ne(a,b))#not equal to
# tensor([[False, False],
# [ True, True]])
print(torch.gt(a,b))#greater than
# tensor([[False, False],
# [False, True]])
print(torch.lt(a,b))#less than
# tensor([[False, False],
# [ True, False]])
print(torch.ge(a,b))#greater than or equal to
# tensor([[ True, True],
# [False, True]])
print(torch.le(a,b))#less than or equal to
# tensor([[ True, True],
# [ True, False]])
以上这篇pytorch 常用函数 max ,eq说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。
以上是 pytorch 常用函数 max ,eq说明 的全部内容, 来源链接: utcz.com/z/335917.html