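tf.map_fn

tf.map_fn applies a function to each element of a tensor (each slice along axis 0) and stacks the results into a new tensor.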

>>> import numpy as np
>>> import tensorflow as tf
>>> elems = np.array([1, 2, 3, 4, 5, 6])
>>> tf.map_fn(lambda x: x * x, elems)  # square each element
<tf.Tensor: shape=(6,), dtype=int64, numpy=array([ 1,  4,  9, 16, 25, 36])>
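The squaring example works without extra arguments because the function's output has the same dtype and structure as the input. The sketch below, which assumes TensorFlow 2.3 or later (where the `fn_output_signature` parameter exists), covers two cases where that does not hold. One is a function that changes the dtype. The other is a function that receives a tuple of tensors and returns a single tensor. In both cases map_fn cannot infer the output, so it has to be told explicitly. The variable names are only illustrative.

```python
import numpy as np
import tensorflow as tf

# Integer input elements, as in the squaring example.
elems = np.array([1, 2, 3, 4, 5, 6])

# fn returns float32 while elems is integer, so the output dtype
# differs from the input dtype and must be declared.
halves = tf.map_fn(
    lambda x: tf.cast(x, tf.float32) / 2.0,
    elems,
    fn_output_signature=tf.float32,
)
# -> values [0.5, 1.0, 1.5, 2.0, 2.5, 3.0], dtype float32

# elems can also be a tuple of tensors with the same first dimension:
# fn then receives one slice of each tensor per call. Because fn returns
# a single tensor rather than a tuple, the output signature is required.
a = tf.constant([1, 2, 3])
b = tf.constant([10, 20, 30])
sums = tf.map_fn(lambda ab: ab[0] + ab[1], (a, b), fn_output_signature=tf.int32)
# -> values [11, 22, 33], dtype int32

print(halves)
print(sums)
```

For simple element-wise arithmetic like these cases, a vectorized expression such as `elems * elems` or `a + b` is usually faster. map_fn is most useful when the per-element function cannot be written in vectorized form.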
posted @ 2022-08-19 22:50  luoganttcc