TensorFlow——Eager模式
Eager模式简介
eager模式是一个命令式编程环境,可以立即评估操作产生的结果,无需构建计算图。
tensorflow的交互模式
- tensorflow2.0默认使用eager模式
- eager模式支持GPU加速和大多数tensorflow操作
- eager模式下tensorflow可与NumPy很好的协作
tf.executing_eagerly() # 判断是否在eager模式下
自然控制流-eager模式下使用Python控制流而不是图控制流,简化了动态模型的创建。
张量(Tensor)
张量=容器
- 0维张量/标量
- 1维张量/向量
- 2维张量/矩阵
- 3维张量是一个数字构成的立方体
tensorflow数据基本操作
# 两个矩阵相乘
m = tf.matmul(x,x)
c = tf.multiply(a,b)
# 使用tensorflow方法建立一个常量
a = tf.constant([[1,2],[3,4]])
# 将tensor对象转为numpy
a.numpy()
# 将数字转为tensor对象
g = tf.convert_to_tensor(10)
b = tf.add(a,1)
# 创建变量
v = tf.Variable(0.0)
# 改变变量值
v.assign(5)
# 加一
v.assign_add(1)
# 读取变量值
v.read_value()
自动微分运算
# 自动微分运算(只能计算float类型的数据)
w = tf.Variable([[1.0]])
with tf.GradientTape() as t: # 上下文管理器
loss = w * w
# 求解loss对w的微分/梯度
grad = t.gradient(loss,w)
# 自动微分运算
w = tf.constant(3.0)
with tf.GradientTape() as t:
t.watch(w) # t跟踪w的运算
loss = w * w
dloss_dw = t.gradient(loss,w)
# 自动微分运算
w = tf.constant(3.0)
with tf.GradientTape(persistent = True) as t: # persistent = True将微分持久记录,可以多次计算微分
t.watch(w) # t跟踪w的运算
y = w * w
z = y * y
dy_dx = t.gradient(y,w)
dz_dw = t.gradient(z, w)
自定义训练
(train_images,train_labels) ,_ = tf.keras.datasets.mnist.load_data()
# 改变数据类型
train_images = tf.cast(train_images/255,tf.float32)
train_labels = tf.cast(train_labels,tf.int64)
train_images = tf.expand_dims(train_images,-1) # 扩张维度,扩张最后一列为1
dataset = tf.data.Dataset.from_tensor_slices((train_images,train_labels))
dataset = dataset.shuffle(10000).batch(32) # 不设置repeat()时默认会重复一次
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(16,(3,3),activation='relu',input_shape=(None,None,1)), # (None,None,1)表示图片大小任意
tf.keras.layers.Conv2D(32,(3,3),activation='relu'),
tf.keras.layers.GlobalMaxPooling2D(),
tf.keras.layers.Dense(10)
])
# 自定义循环
optimizer = tf.keras.optimizers.Adam()
loss_func = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) # 未激活,from_logits=True
features, labels = next (iter(dataset)) # 取出第一个批次的数据
# 定义损失函数,loss返回交叉熵损失
def loss(model, x, y):
y_ = model(x)
return loss_func(y,y_) # y_为预测值
# 训练一批次
def train_step(model,images,labels):
with tf.GradientTape() as t:
loss_step = loss(model,images,labels) # 每一步的损失值
grads = t.gradient(loss_step,model.trainable_variables) # 损失函数与可训练参数之间的梯度
optimizer.apply_gradients(zip(grads,model.trainable_variables)) # 优化函数应用梯度进行优化
# 训练
def train():
for epoch in range(10):
for (batch,(images,labels)) in enumerate(dataset):
train_step(model,images,labels)
print('Epoch{} is finish'.format(epoch))
tf.kears.metrics汇总计算模块
添加汇总计算模块后可以打印显示训练数据与测试数据的loss和acc值
基本操作
# 计算均值
m = tf.keras.metrics.Mean('acc')
# 得到计算均值的结果
m.result()
# 重置均值计算状态:0
m.reset_states()
a = tf.keras.metrics.SparseCategoricalAccuracy('acc')
# 自动选出model(features)中的最大值与labels进行比较
a(labels,model(features))
带汇总计算的自定义训练模型实例
# 初始化计算对象
train_loss = tf.keras.metrics.Mean('train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy('train_accuracy')
test_loss = tf.keras.metrics.Mean('test_loss')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy('test_accuracy')
def train_step(model,images,labels):
with tf.GradientTape() as t:
pred = model(images)
loss_step = loss_func(labels,pred)
#loss_step = loss(model,images,labels) # 每一步的损失值
grads = t.gradient(loss_step,model.trainable_variables) # 损失函数与可训练参数之间的梯度
optimizer.apply_gradients(zip(grads,model.trainable_variables)) # 优化函数应用梯度进行优化
# 汇总计算平均loss
train_loss(loss_step)
# 汇总计算平均acc
train_accuracy(labels,pred)
def test_step(model,images,labels):
pred = model(images)
loss_step = loss_func(labels,pred)
# 汇总计算平均loss
test_loss(loss_step)
# 汇总计算平均acc
test_accuracy(labels,pred)
def train():
for epoch in range(10):
for (batch,(images,labels)) in enumerate(dataset):
train_step(model,images,labels)
print('Epoch{} loss is {}, accuracy is {}'.format(epoch,train_loss.result(),train_accuracy.result()))
for (batch,(images,labels)) in enumerate(test_dataset):
test_step(model,images,labels)
print('Epoch{} test_loss is {}, test_accuracy is {}'.format(epoch,test_loss.result(),test_accuracy.result()))
# 重置:每个循环结束清0
train_loss.reset_states()
train_accuracy.reset_states()
test_loss.reset_states()
test_accuracy.reset_states()

浙公网安备 33010602011771号