Numpy: changement de max dans chaque ligne à 1, tous les autres nombres de 0

Je suis en train de mettre en œuvre une fonction numpy qui remplace le max dans chaque ligne d'un tableau 2D avec 1, et tous les autres chiffres avec zéro:

>>> a = np.array([[0, 1],
...               [2, 3],
...               [4, 5],
...               [6, 7],
...               [9, 8]])
>>> b = some_function(a)
>>> b
[[0. 1.]
 [0. 1.]
 [0. 1.]
 [0. 1.]
 [1. 0.]]

Ce que j'ai essayé jusqu'à présent

def some_function(x):
    a = np.zeros(x.shape)
    a[:,np.argmax(x, axis=1)] = 1
    return a

>>> b = some_function(a)
>>> b
[[1. 1.]
 [1. 1.]
 [1. 1.]
 [1. 1.]
 [1. 1.]]
InformationsquelleAutor MikeRand | 2013-11-30