Flatten batch in tensorflow

You can do it easily with tf.reshape() without knowing the batch size.

x = tf.placeholder(tf.float32, shape=[None, 9,2])
shape = x.get_shape().as_list()        # a list: [None, 9, 2]
dim = numpy.prod(shape[1:])            # dim = prod(9,2) = 18
x2 = tf.reshape(x, [-1, dim])           # -1 means "all"

The -1 in the last line means the whole column no matter what the batchsize is in the runtime. You can see it in tf.reshape().


Update: shape = [None, 3, None]

Thanks @kbrose. For the cases where more than 1 dimension are undefined, we can use tf.shape() with tf.reduce_prod() alternatively.

x = tf.placeholder(tf.float32, shape=[None, 3, None])
dim = tf.reduce_prod(tf.shape(x)[1:])
x2 = tf.reshape(x, [-1, dim])

tf.shape() returns a shape Tensor which can be evaluated in runtime. The difference between tf.get_shape() and tf.shape() can be seen in the doc.

I also tried tf.contrib.layers.flatten() in another . It is simplest for the first case, but it can’t handle the second.

Leave a Comment

Hata!: SQLSTATE[HY000] [1045] Access denied for user 'divattrend_liink'@'localhost' (using password: YES)