YOLOv8 Attention Detection
The detector is trained by reusing a drowsy-driving (drowsy-driver) dataset and model setup; the trained YOLOv8 weights are then used to drive the attention-judgment rule served by the code below.
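Concretely, the rule implemented in map_attributes below reduces each frame's detections to two per-class maximum confidences, p_awake and p_drowsy, and maps them to attributes:

    state          = "unknown" if p_awake = p_drowsy = 0,
                     otherwise "drowsy" if p_drowsy >= p_awake, else "awake"
    attentionScore = min(1, max(0, p_awake))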
import io
import json
import base64
import time
import logging
from typing import Dict, Any, Tuple, Optional

from flask import Flask, request, jsonify
from PIL import Image
from ultralytics import YOLO

# Configure the logger
logger = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# ===== Configurable parameters =====
# Model path (adjust to where your trained weights actually live)
MODEL_PATH = r"D:\Project\2025aunt\aiclassroom\aiclassroom\YOLO8\Drowsy Driver Dataset.v7i.yolov8\exp_drowsy_v8n\weights\best.pt"
# If the weights are under train/weights/best.pt:
# MODEL_PATH = r"d:\MyData\MyCode\YOLO8\train\weights\best.pt"
DEFAULT_CONF = 0.25     # confidence threshold, matched to the 320x240 frames sampled on the frontend
DEFAULT_IMGSZ = 320     # inference image size
DEFAULT_IOU = 0.45      # IoU threshold
DEFAULT_DEVICE = "cpu"  # "cpu" or "cuda:0"
# Class mapping. If your model's classes are ['awake', 'drowsy'], use the mapping below.
# If unsure, it is read automatically from model.names.
CLS_IDX_TO_NAME: Optional[Dict[int, str]] = {0: "awake", 1: "drowsy"}

app = Flask(__name__)
model = YOLO(MODEL_PATH)

# Read the class mapping from the model automatically (if not provided manually)
if CLS_IDX_TO_NAME is None and hasattr(model, "names"):
    CLS_IDX_TO_NAME = {int(k): str(v) for k, v in model.names.items()}

def parse_params() -> Tuple[float, int, float, str]:
    """Parse conf, imgsz, iou and device from the request, optionally overriding the defaults."""
    conf = float(request.form.get("conf", request.args.get("conf", DEFAULT_CONF)))
    imgsz = int(request.form.get("imgsz", request.args.get("imgsz", DEFAULT_IMGSZ)))
    iou = float(request.form.get("iou", request.args.get("iou", DEFAULT_IOU)))
    device = request.form.get("device", request.args.get("device", DEFAULT_DEVICE))
    return conf, imgsz, iou, device
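# Example (hypothetical request): POST /infer?conf=0.4&imgsz=416 overrides the
# defaults for that request only; form fields of the same names take priority
# over query-string values.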

def read_image_from_request() -> Tuple[Optional[Image.Image], Optional[str]]:
    """
    Read the image from the request:
    - the 'image' file field of a multipart/form-data request is used first
    - 'image_base64' in form/json is also supported (data URL or raw base64)
    Returns (PIL.Image, error string).
    """
    try:
        if "image" in request.files:
            file = request.files["image"]
            image_bytes = file.read()
            img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            return img, None
        # base64 fallback
        b64 = request.form.get("image_base64") or (request.json.get("image_base64") if request.is_json else None)
        if b64:
            # strip a data URL prefix if present
            if "," in b64:
                b64 = b64.split(",", 1)[1]
            image_bytes = base64.b64decode(b64)
            img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            return img, None
        return None, "missing image: please provide 'image' file or 'image_base64'"
    except Exception as e:
        return None, f"invalid image: {str(e)}"

def apply_roi_crop(img: Image.Image) -> Image.Image:
    """
    Optional ROI crop:
    - the 'roi' field of form/json, as a JSON string or object
    - pixel coordinates: {x, y, w, h}
    - normalized coordinates: {x, y, w, h, normalized: true} (0~1 relative to the original image)
    Returns the image unchanged if the ROI is invalid.
    """
    roi_raw = request.form.get("roi") or (request.json.get("roi") if request.is_json else None)
    if not roi_raw:
        return img
    try:
        roi = json.loads(roi_raw) if isinstance(roi_raw, str) else dict(roi_raw)
        W, H = img.size
        x, y, w, h = float(roi.get("x", 0)), float(roi.get("y", 0)), float(roi.get("w", 0)), float(roi.get("h", 0))
        normalized = bool(roi.get("normalized", False))
        if normalized:
            x, y, w, h = x * W, y * H, w * W, h * H
        # clamp to the image bounds
        x1 = max(0, min(W - 1, int(x)))
        y1 = max(0, min(H - 1, int(y)))
        x2 = max(x1 + 1, min(W, int(x + w)))
        y2 = max(y1 + 1, min(H, int(y + h)))
        if x2 <= x1 or y2 <= y1:
            return img
        return img.crop((x1, y1, x2, y2))
    except Exception:
        return img
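# Example (hypothetical values): a normalized ROI covering the central region of
# the frame could be sent as
#   roi = '{"x": 0.25, "y": 0.2, "w": 0.5, "h": 0.6, "normalized": true}'
# apply_roi_crop() scales it to the frame size and crops before inference.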

def map_attributes(det_summary: Dict[str, float]) -> Dict[str, Any]:
    """
    Reduce the binary detections to audience attributes:
    - state: 'awake' | 'drowsy' | 'unknown'
    - drowsy: boolean
    - attentionScore: 0~1, approximated by the maximum 'awake' confidence
    """
    awake_p = det_summary.get("awake", 0.0)
    drowsy_p = det_summary.get("drowsy", 0.0)
    if awake_p == 0.0 and drowsy_p == 0.0:
        state = "unknown"
    else:
        state = "drowsy" if drowsy_p >= awake_p else "awake"
    attention_score = max(0.0, min(1.0, awake_p))
    return {
        "state": state,
        "drowsy": state == "drowsy",
        "attentionScore": attention_score
    }

@app.get("/health")
def health():
    logger.info("[health] ok")
    return jsonify({"status": "ok"})

@app.post("/infer")
def infer():
    conf, imgsz, iou, device = parse_params()
    logger.info(f"[infer] params conf={conf}, imgsz={imgsz}, iou={iou}, device={device}")
    img, err = read_image_from_request()
    if err:
        logger.warning(f"[infer] bad request: {err}")
        return jsonify({"error": err}), 400

    # optional ROI crop
    img = apply_roi_crop(img)
    try:
        w, h = img.size
        logger.info(f"[infer] image size: {w}x{h}")
    except Exception:
        pass

    # run inference
    try:
        start_t = time.time()
        results = model.predict(
            img,
            conf=conf,
            imgsz=imgsz,
            iou=iou,
            device=device,
            verbose=False
        )
        elapsed_ms = (time.time() - start_t) * 1000.0
        logger.info(f"[infer] inference done in {elapsed_ms:.1f} ms")
    except Exception as e:
        logger.exception(f"[infer] inference failed: {e}")
        return jsonify({"error": f"inference failed: {str(e)}"}), 500

    dets = []
    det_summary = {"awake": 0.0, "drowsy": 0.0}
    # parse the results
    try:
        if results:
            res = results[0]
            if hasattr(res, "boxes") and res.boxes is not None:
                for b in res.boxes:
                    cls_idx = int(b.cls[0].item()) if hasattr(b, "cls") else None
                    conf_v = float(b.conf[0].item()) if hasattr(b, "conf") else 0.0
                    name = CLS_IDX_TO_NAME.get(cls_idx, str(cls_idx)) if CLS_IDX_TO_NAME else str(cls_idx)
                    if name in ("awake", "drowsy"):
                        det_summary[name] = max(det_summary.get(name, 0.0), conf_v)
                    xyxy = b.xyxy[0].tolist() if hasattr(b, "xyxy") else []
                    dets.append({
                        "class": name,
                        "conf": conf_v,
                        "bbox": xyxy
                    })
    except Exception as e:
        logger.exception(f"[infer] parse results failed: {e}")
        return jsonify({"error": f"parse results failed: {str(e)}"}), 500

    attributes = map_attributes(det_summary)
    logger.info(f"[infer] attributes: {json.dumps(attributes, ensure_ascii=False)}")
    logger.info(f"[infer] detections({len(dets)}): {json.dumps(dets, ensure_ascii=False)}")
    return jsonify({
        "attributes": attributes,
        "detections": dets,
        "meta": {
            "conf": conf,
            "imgsz": imgsz,
            "iou": iou,
            "device": device,
            "model_path": MODEL_PATH,
            "classes": CLS_IDX_TO_NAME
        }
    })

if __name__ == "__main__":
    # Note: the first run may trigger a Windows Firewall prompt; allow the access.
    app.run(host="0.0.0.0", port=5001, debug=False)
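A minimal client sketch for a quick end-to-end test, assuming the server above is running locally on port 5001; "frame.jpg" is a placeholder path for any test image:

import requests

resp = requests.post(
    "http://127.0.0.1:5001/infer",
    files={"image": open("frame.jpg", "rb")},   # the 'image' file field expected by /infer
    data={"conf": "0.25", "imgsz": "320"},      # optional per-request overrides
    timeout=10,
)
payload = resp.json()
print(payload["attributes"])                    # {"state": ..., "drowsy": ..., "attentionScore": ...}
print(len(payload["detections"]), "detection(s)")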