NumPy: change the max in each row to 1, all other numbers to 0

Method #1, tweaking your approach:

>>> import numpy as np
>>> a = np.array([[0, 1], [2, 3], [4, 5], [6, 7], [9, 8]])
>>> b = np.zeros_like(a)
>>> b[np.arange(len(a)), a.argmax(1)] = 1
>>> b
array([[0, 1],
       [0, 1],
       [0, 1],
       [0, 1],
       [1, 0]])

[Actually, range(len(a)) works just as well as the row index here; I wrote np.arange out of habit.]
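Wrapped as a reusable function, Method #1 might look like the sketch below. The name argmax_one_hot is hypothetical. Because argmax returns the first occurrence of the maximum, a row with a tie gets a single 1 in its first maximal column. The result has the same dtype as the input, since it is built with zeros_like.

```python
import numpy as np

def argmax_one_hot(a):
    """Hypothetical helper: same shape and dtype as `a`, with a single 1
    at each row's argmax and 0 everywhere else.

    Ties go to the first maximal column, because argmax returns the
    first occurrence of the maximum.
    """
    a = np.asarray(a)
    b = np.zeros_like(a)
    # One (row, column) pair per row; a plain range() works as the row index too
    b[range(len(a)), a.argmax(axis=1)] = 1
    return b

print(argmax_one_hot([[0, 1], [2, 3], [4, 5], [6, 7], [9, 8]]).tolist())
# [[0, 1], [0, 1], [0, 1], [0, 1], [1, 0]]
print(argmax_one_hot([[0, 1], [2, 2], [4, 3]]).tolist())
# [[0, 1], [1, 0], [1, 0]]  (the tied middle row keeps only its first max)
```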

Method #2, using max instead of argmax so that a row where several elements tie for the maximum gets a 1 in every tied position:

>>> a = np.array([[0, 1], [2, 2], [4, 3]])
>>> (a == a.max(axis=1)[:,None]).astype(int)
array([[0, 1],
       [1, 1],
       [1, 0]])
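The same comparison can be written with keepdims=True, which keeps the row maxima as an (n, 1) column so they broadcast against a exactly like a.max(axis=1)[:,None]. The sketch below wraps it in a hypothetical helper, row_max_mask, and checks which rows ended up with more than one 1.

```python
import numpy as np

def row_max_mask(a):
    """Hypothetical helper: 1 wherever an element equals its row's maximum,
    0 elsewhere. Tied maxima all get a 1."""
    a = np.asarray(a)
    # keepdims=True gives shape (n, 1), equivalent to a.max(axis=1)[:, None]
    return (a == a.max(axis=1, keepdims=True)).astype(int)

a = np.array([[0, 1], [2, 2], [4, 3]])
mask = row_max_mask(a)
print(mask.tolist())               # [[0, 1], [1, 1], [1, 0]]
# Count the 1s per row: values above 1 reveal ties
print(mask.sum(axis=1).tolist())   # [1, 2, 1]
```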
