![]()
from flask import Flask, request, jsonify
from flask_cors import CORS # 导入 CORS
from PIL import Image
import io
import torch
from torchvision import transforms
from models import EnhancedResNet152 # 确保导入修改后的模型
import logging
# 配置日志
logging.basicConfig(level=logging.INFO)
app = Flask(__name__)
# 启用 CORS
CORS(app) # 允许所有跨域请求
# 加载模型
try:
model = EnhancedResNet152(num_classes=8)
model.load_state_dict(torch.load('best_model.pth', map_location=torch.device('cpu')))
model.eval()
logging.info("模型加载成功!")
except Exception as e:
logging.error(f"模型加载失败:{str(e)}")
exit(1)
# 图像预处理(与验证预处理一致)
transform = transforms.Compose([
transforms.Resize((224, 224)), # 调整图像大小
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
@app.route('/recognize', methods=['POST'])
def recognize():
try:
# 检查是否上传了左右眼图像
if 'left_image' not in request.files or 'right_image' not in request.files:
return jsonify({"error": "缺少必要的图像文件"}), 400
# 加载图像
left_image = Image.open(io.BytesIO(request.files['left_image'].read())).convert('RGB')
right_image = Image.open(io.BytesIO(request.files['right_image'].read())).convert('RGB')
logging.info("图像加载成功!")
# 预处理图像
left_image = transform(left_image).unsqueeze(0)
right_image = transform(right_image).unsqueeze(0)
logging.info("图像预处理成功!")
# 进行预测
with torch.no_grad():
outputs = model(left_image, right_image)
predictions = torch.sigmoid(outputs).numpy()[0]
logging.info("预测成功!")
# 将 NumPy 数组转换为 Python 列表
predictions = predictions.tolist()
# 构造返回结果
labels = ['N', 'D', 'G', 'C', 'A', 'H', 'M', 'O']
results = dict(zip(labels, predictions))
return jsonify(results)
except Exception as e:
logging.error(f"Error in /recognize: {str(e)}")
return jsonify({"error": str(e)}), 500
@app.route('/')
def index():
return "欢迎使用眼疾智能识别系统!请通过 /recognize 路由上传图像进行识别。"
if __name__ == '__main__':
app.run(debug=True)