numpy:将每行的max更改为1,所有其他数字更改为0

我正在尝试实现一个numpy函数,该函数将2D数组每一行的max替换为1,并将所有其他数字替换为零:

>>> a = np.array([[0, 1],

... [2, 3],

... [4, 5],

... [6, 7],

... [9, 8]])

>>> b = some_function(a)

>>> b

[[0. 1.]

[0. 1.]

[0. 1.]

[0. 1.]

[1. 0.]]

def some_function(x):

a = np.zeros(x.shape)

a[:,np.argmax(x, axis=1)] = 1

return a

>>> b = some_function(a)

>>> b

[[1. 1.]

[1. 1.]

[1. 1.]

[1. 1.]

[1. 1.]]

回答:

方法1,调整您的方法:

>>> a = np.array([[0, 1], [2, 3], [4, 5], [6, 7], [9, 8]])

>>> b = np.zeros_like(a)

>>> b[np.arange(len(a)), a.argmax(1)] = 1

>>> b

array([[0, 1],

[0, 1],

[0, 1],

[0, 1],

[1, 0]])

[实际上,range可以正常工作;我arange出于习惯写信。]

方法2,使用max而不是argmax处理多个元素达到最大值的情况:

>>> a = np.array([[0, 1], [2, 2], [4, 3]])

>>> (a == a.max(axis=1)[:,None]).astype(int)

array([[0, 1],

[1, 1],

[1, 0]])

以上是 numpy:将每行的max更改为1,所有其他数字更改为0 的全部内容, 来源链接: utcz.com/qa/423141.html

回到顶部