TensorFlow: Max of a tensor along an axis

The tf.reduce_max() operator provides exactly this functionality. By default it computes the global maximum of the given tensor, but you can pass an axis (or a list of axes) to reduce over, which has the same meaning as axis in NumPy. Older TensorFlow 1.x releases called this argument reduction_indices. To complete your example:

import tensorflow as tf

x = tf.constant([[1, 220, 55], [4, 3, -1]])
x_max = tf.reduce_max(x, axis=1)  # Maximum of each row.
print(x_max.numpy())  # ==> [220   4]
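To show how the axis argument mirrors NumPy, here is a small sketch. It assumes TensorFlow 2.x with eager execution and compares reducing with no axis, with axis=0, and with axis=1. It also shows keepdims=True, which keeps the reduced dimension with length 1.

```python
import tensorflow as tf

x = tf.constant([[1, 220, 55], [4, 3, -1]])

# No axis: reduce over every element and return a scalar.
global_max = tf.reduce_max(x)
# axis=0: maximum down each column.
col_max = tf.reduce_max(x, axis=0)
# axis=1: maximum across each row.
row_max = tf.reduce_max(x, axis=1)
# keepdims=True keeps the reduced axis with length 1, giving shape (2, 1).
row_max_2d = tf.reduce_max(x, axis=1, keepdims=True)

print(global_max.numpy())  # 220
print(col_max.numpy())     # [  4 220  55]
print(row_max.numpy())     # [220   4]
print(row_max_2d.shape)    # (2, 1)
```

keepdims=True is handy when the result has to broadcast back against x, for example in x - row_max_2d.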

If you compute the argmax using tf.argmax(), you can obtain the corresponding values from a different tensor y. Flatten y with tf.reshape(), convert the per-row argmax indices into indices into the flattened vector as follows, and then use tf.gather() to extract the appropriate values:

# Uses x from the example above.
ind_max = tf.argmax(x, axis=1)  # int64 column index of each row's maximum: [1, 0]
y = tf.constant([[1, 2, 3], [6, 5, 4]])

flat_y = tf.reshape(y, [-1])  # Reshape to a vector.

# N.B. Handles 2-D case only: element (r, c) sits at flat index r * num_cols + c.
flat_ind_max = ind_max + tf.cast(tf.range(tf.shape(y)[0]) * tf.shape(y)[1], tf.int64)

y_ = tf.gather(flat_y, flat_ind_max)

print(y_.numpy())  # ==> [2 6]
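In TensorFlow 2.x, tf.gather() also accepts a batch_dims argument, which performs the per-row lookup directly without manual flattening. The sketch below assumes TensorFlow 2.x. gather_at_argmax is a hypothetical helper name. It lifts the 2-D restriction by taking the argmax over the last axis of a tensor of any rank and treating every leading axis as a batch axis.

```python
import tensorflow as tf

def gather_at_argmax(x, y):
    """Return y's values where x is maximal along the last axis.

    Hypothetical helper: x and y must have the same, statically known shape.
    """
    last = len(x.shape) - 1
    # Index of the maximum along the last axis; its shape is x.shape[:-1].
    ind_max = tf.argmax(x, axis=last)
    # Every leading axis is a batch axis, so each index selects within its own row.
    return tf.gather(y, ind_max, axis=last, batch_dims=last)

x = tf.constant([[1, 220, 55], [4, 3, -1]])
y = tf.constant([[1, 2, 3], [6, 5, 4]])
print(gather_at_argmax(x, y).numpy())  # [2 6]
```

Here the result has shape x.shape[:-1], the same shape tf.reduce_max(x, axis=-1) would produce.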
