tf.eye
tf.eye(2)
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[1., 0.],
       [0., 1.]], dtype=float32)>
tf.eye(2, batch_shape=[3])
<tf.Tensor: shape=(3, 2, 2), dtype=float32, numpy=
array([[[1., 0.],
        [0., 1.]],
       [[1., 0.],
        [0., 1.]],
       [[1., 0.],
        [0., 1.]]], dtype=float32)>
tf.eye(2, num_columns=3)
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[1., 0., 0.],
       [0., 1., 0.]], dtype=float32)>
 
                    
                     
                    
                 
                    
                
 
                
            
         
         浙公网安备 33010602011771号
浙公网安备 33010602011771号