import numpy as np
from tensorflow.keras.layers import Lambda
import tensorflow as tf
# Demo: run a small integer matrix through two Keras ``Lambda`` layers in eager
# mode, printing the tensor produced by each layer followed by a separator.
SEPARATOR = '---------------------------------------'


def _square(t):
    """Return ``t`` raised element-wise to the power 2."""
    return t ** 2


def _reshape_cols_rows(t):
    """Reshape ``t`` to ``(-1, cols, rows)``, where rows/cols are its first two dims.

    This is a row-major reshape, not a transpose: the elements keep their
    order. For a 2-D input the inferred leading axis (-1) resolves to 1.
    """
    n_rows = tf.shape(t)[0]  # dynamic size, resolved when the op runs
    n_cols = t.shape[1]      # static size taken from the tensor's known shape
    return tf.reshape(t, (-1, n_cols, n_rows))


x = np.array([[1, 2], [3, 4]])

x = Lambda(_square)(x)
print(x)
print(SEPARATOR)

x = Lambda(_reshape_cols_rows)(x)
print(x)
print(SEPARATOR)
# Output of the script above:
#
# tf.Tensor(
# [[ 1 4]
# [ 9 16]], shape=(2, 2), dtype=int64)
# ---------------------------------------
# tf.Tensor(
# [[[ 1 4]
# [ 9 16]]], shape=(1, 2, 2), dtype=int64)
# ---------------------------------------