TensorFlow Lambda usage

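import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Lambda

x = np.array([[1, 2], [3, 4]])  # integer NumPy array of shape (2, 2)

# Lambda wraps an arbitrary function as a Keras layer. Calling the layer on the
# NumPy array converts it to a tensor and applies the function (here, element-wise squaring).
x = Lambda(lambda t: t ** 2)(x)
print(x)

print('---------------------------------------')
# Reshape to (-1, 2, 2): t.shape[1] is the static size of axis 1 (2),
# tf.shape(t)[0] is the run-time size of axis 0 (2), and -1 lets tf.reshape
# infer the remaining axis: 4 elements / (2 * 2) = 1, giving shape (1, 2, 2).
x = Lambda(lambda t: tf.reshape(t, (-1, t.shape[1], tf.shape(t)[0])))(x)
print(x)
print('---------------------------------------')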

Output:

tf.Tensor(
[[ 1  4]
 [ 9 16]], shape=(2, 2), dtype=int64)
---------------------------------------
tf.Tensor(
[[[ 1  4]
  [ 9 16]]], shape=(1, 2, 2), dtype=int64)
---------------------------------------
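The calls above apply each Lambda layer directly to a concrete array. In practice a Lambda layer is more often one node inside a Keras model. The following is a minimal sketch, assuming TensorFlow 2.x with `tf.keras`; the names `inputs`, `squared`, `model`, and `batch` are illustrative and not from the original post. It also illustrates a likely reason the reshape above reads the first dimension with `tf.shape(x)[0]` rather than `x.shape[0]`: inside a model the batch axis is unknown (`None`) when the layer graph is built, so only the dynamic `tf.shape` value gives its size at run time.

```python
import numpy as np
import tensorflow as tf

# Symbolic input: each sample is a length-2 vector; the batch size stays unknown (None).
inputs = tf.keras.Input(shape=(2,))
print(inputs.shape)  # static shape whose batch axis is None

# The same element-wise square as in the post, now as a node in a functional model.
squared = tf.keras.layers.Lambda(lambda t: t ** 2)(inputs)
model = tf.keras.Model(inputs=inputs, outputs=squared)

# Run the model eagerly on a concrete batch of 3 samples.
batch = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32)
result = model(batch)
print(result.numpy())  # values [[1, 4], [9, 16], [25, 36]] as float32
```

Because `t ** 2` only relies on the tensor's own operators, the same Lambda works whether it receives a concrete tensor (as in the post) or a symbolic model input.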
posted @ 2022-08-19 22:49  luoganttcc