1
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import numpy as np
import os
设置GPU内存增长(避免GPU内存错误)
physical_devices = tf.config.experimental.list_physical_devices('GPU')
if len(physical_devices) > 0:
tf.config.experimental.set_memory_growth(physical_devices[0], True)
print("TensorFlow版本:", tf.version)
try:
# 1. 加载和预处理数据
(x_train, y_train), (x_test, y_test) = datasets.cifar10.load_data()
# 数据预处理
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
# 标签one-hot编码
y_train = tf.keras.utils.to_categorical(y_train, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)
print("数据加载成功!")
print("训练集形状:", x_train.shape)
print("测试集形状:", x_test.shape)
# 2. 构建网络(简化版本,避免内存问题)
model = models.Sequential([
layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
layers.MaxPooling2D((2, 2)),
layers.Conv2D(64, (3, 3), activation='relu'),
layers.MaxPooling2D((2, 2)),
layers.Flatten(),
layers.Dense(64, activation='relu'), # 减少神经元数量
layers.Dense(10, activation='softmax')
])
# 3. 编译网络
model.compile(optimizer='adam',
loss='categorical_crossentropy',
metrics=['accuracy'])
model.summary()
# 4. 训练网络(减少epochs进行测试)
print("开始训练...")
history = model.fit(x_train, y_train,
epochs=5, # 先试5个epoch
batch_size=64,
validation_split=0.2,
verbose=1)
# 5. 评估网络
test_loss, test_accuracy = model.evaluate(x_test, y_test)
print(f"测试集精度: {test_accuracy:.4f}")
except Exception as e:
print(f"错误信息: {e}")
print("\n请检查以下问题:")
print("1. TensorFlow是否正确安装")
print("2. Python版本是否兼容")
print("3. 网络连接是否正常(下载数据集需要网络)")
浙公网安备 33010602011771号