There are two main ways to access subsets of the elements in a tensor, either of which should work for your example.
- Use the indexing operator (based on `tf.slice()`) to extract a contiguous slice from the tensor.

  ```python
  import tensorflow as tf  # TensorFlow 2.x, eager execution

  input = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
  output = input[0, :]
  print(output.numpy())  # ==> [1 2 3]  (TF 1.x graph mode: print(sess.run(output)))
  ```

  The indexing operator supports many of the same slice specifications as NumPy does.
- Use the `tf.gather()` op to select a non-contiguous slice from the tensor.

  ```python
  import tensorflow as tf  # TensorFlow 2.x, eager execution

  input = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
  output = tf.gather(input, 0)
  print(output.numpy())  # ==> [1 2 3]
  output = tf.gather(input, [0, 2])
  print(output.numpy())  # ==> [[1 2 3] [7 8 9]]
  ```

  Note that by default `tf.gather()` selects whole slices along the 0th dimension (whole rows, in the case of a matrix). Newer versions of TensorFlow accept an `axis` argument to gather along a different dimension; otherwise, you may need to `tf.reshape()` or `tf.transpose()` your input to obtain the appropriate elements.
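To illustrate the NumPy-style slice specifications that the indexing operator accepts, here is a short sketch. It assumes TensorFlow 2.x with eager execution (so `.numpy()` is available); the variable names are illustrative only. The last line shows an explicit `tf.slice()` call that extracts the same row as `x[0, :]`; note that `tf.slice()` keeps the sliced dimension, so its result has shape `(1, 3)` rather than `(3,)`.

```python
import tensorflow as tf  # assumes TensorFlow 2.x with eager execution

x = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

last_row = x[-1]               # negative index            ==> [7 8 9]
first_col = x[:, 0]            # all rows, column 0        ==> [1 4 7]
every_other = x[::2, ::2]      # stride 2 in both dims     ==> [[1 3] [7 9]]
reversed_rows = x[::-1]        # negative stride reverses the row order
expanded = x[tf.newaxis, ...]  # add a leading axis        ==> shape (1, 3, 3)

# Explicit tf.slice(): start at [0, 0] and take a block of size [1, 3].
row0 = tf.slice(x, begin=[0, 0], size=[1, 3])  # ==> [[1 2 3]]

print(every_other.numpy())
print(row0.numpy())
```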
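The `tf.reshape()` / `tf.transpose()` workaround for `tf.gather()` can be made concrete with a sketch that picks columns and individual elements from the same matrix. It assumes TensorFlow 2.x with eager execution, and the `axis=1` line requires a TensorFlow version whose `tf.gather()` accepts `axis`. The flat-index arithmetic (`row * num_cols + col`) is one illustrative way to address single elements after flattening, not the only one.

```python
import tensorflow as tf  # assumes TensorFlow 2.x with eager execution

x = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# Columns 0 and 2: transpose so columns become rows, gather, transpose back.
cols_via_transpose = tf.transpose(tf.gather(tf.transpose(x), [0, 2]))
# ==> [[1 3] [4 6] [7 9]]

# Same result using the axis argument of newer tf.gather() versions.
cols_via_axis = tf.gather(x, [0, 2], axis=1)

# Individual elements x[0, 1] and x[2, 2]: flatten to 1-D, then gather
# using flat indices computed as row * num_cols + col.
num_cols = x.shape[1]
flat_indices = [0 * num_cols + 1, 2 * num_cols + 2]
elements = tf.gather(tf.reshape(x, [-1]), flat_indices)  # ==> [2 9]

print(cols_via_transpose.numpy())
print(elements.numpy())
```

The transpose-based version works even where `axis` is unavailable, at the cost of two extra ops; the reshape-based version is useful when the elements you want do not form whole rows or columns.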