# 简单粗暴的 TensorFlow —— tf.function
# tf.function：图执行模式
# 基础使用
import tensorflow as tf
import time
from zh.model.mnist.cnn import CNN
from zh.model.utils import MNISTLoader
# Training hyperparameters: number of steps, examples per step, Adam step size.
num_batches, batch_size, learning_rate = 1000, 50, 0.001

# Shared training state used by the training step and loop that follow.
# Construction order (loader, then model, then optimizer) is kept as-is.
data_loader = MNISTLoader()  # project helper; presumably loads MNIST — confirm in zh.model.utils
model = CNN()
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
@tf.function
def train_one_step(X, y):
    """Run one optimization step on a single batch, compiled as a TF graph.

    Computes the mean sparse categorical cross-entropy between ``y`` and the
    module-level ``model``'s predictions on ``X``, prints the loss, and applies
    one update to the model's trainable variables with the module-level
    ``optimizer``.

    Args:
        X: Input batch passed straight to ``model`` (assumed to match the CNN's
            expected input shape — TODO confirm against
            ``MNISTLoader.get_batch``).
        y: Class labels for the batch in sparse (integer index) form, as
            required by ``sparse_categorical_crossentropy``.
    """
    with tf.GradientTape() as tape:
        y_pred = model(X)
        # from_logits defaults to False, so this assumes the model outputs
        # probabilities (e.g. a final softmax) — confirm against CNN.
        loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=y, y_pred=y_pred)
        loss = tf.reduce_mean(loss)
        # Use tf.print, not Python's print: inside @tf.function a Python print
        # runs only while the function is being traced, whereas tf.print
        # becomes a graph op that executes on every call.
        tf.print("loss", loss)
    # Differentiate w.r.t. trainable variables only. Non-trainable variables
    # (e.g. BatchNormalization moving statistics) receive no gradient, and
    # passing None gradients to apply_gradients triggers warnings or errors.
    variables = model.trainable_variables
    grads = tape.gradient(loss, variables)
    optimizer.apply_gradients(grads_and_vars=zip(grads, variables))
# Time the full training run. time.perf_counter() is a monotonic,
# high-resolution clock intended for measuring intervals; time.time() follows
# the wall clock and can jump if the system time is adjusted mid-run.
# Note: the first call to train_one_step also pays the one-off cost of tracing
# the tf.function into a graph, and that cost is included in this total.
start_time = time.perf_counter()
for batch_index in range(num_batches):
    X, y = data_loader.get_batch(batch_size)
    train_one_step(X, y)
end_time = time.perf_counter()
print(end_time - start_time)
# 天道酬勤 循序渐进 技压群雄
# 浙公网安备 33010602011771号