# Sanity check: swapping the first two axes of a rank-3 tensor with
# tf.transpose gives the same result as relocating each leading-axis slice
# to axis 1 and concatenating the slices there.
shape = (3, 4, 5)
n_slices, n_rows, n_depth = shape
a = tf.random.uniform(shape)

# Reference result: exchange axes 0 and 1, leave axis 2 where it is.
a_t = tf.transpose(a, perm=(1, 0, 2))

# Manual construction: a[i:i + 1] keeps a size-1 leading axis, so reshaping it
# to (n_rows, 1, n_depth) only moves that singleton axis and does not reorder
# any elements. Joining the pieces along axis 1 puts slice i at index i there.
a_concat = tf.concat(
    [
        tf.reshape(a[i:i + 1, :, :], (n_rows, 1, n_depth))
        for i in range(n_slices)
    ],
    axis=1,
)

# Raises if the two tensors differ in any element.
tf.debugging.assert_equal(a_t, a_concat)