实用指南:[特殊字符] 用 PyTorch 打造「CNN-LSTM-Attention」股票预测神器!——从 0 到 1 的保姆级教程(附完整源码)
前言:为什么这套模型能让你的策略胜率飙升?
在量化江湖里,CNN 擅于捕局部形态(如 K 线组合),LSTM 长于记长期记忆(如趋势),Attention 专治“信息过载”(自动给重要时间点加权)。把三大杀器融合,就是今天的主角——CNN-LSTM-Attention 多模态股价预测模型。
读完本文,你将收获:
1、一套可直接跑通的 PyTorch 源码(model.py + train.py + predict.py)
2、数据爬取→特征工程→训练→评估→预测的 全流程拆解
3、早停法、反标准化、维度对齐等 工程技巧
第 0 步:环境清单(30 秒搞定)
| 组件 | 版本 | 安装命令 |
|---|---|---|
| Python | ≥3.8 | — |
| PyTorch | ≥1.12 | pip install torch |
| mootdx | 最新 | pip install mootdx |
| numpy & pandas | 最新 | pip install numpy pandas |
提示:CUDA 驱动≥11.6 可启用 GPU 加速,训练速度提升 5×。
第 1 步:数据获取——用 mootdx 拉取 A 股 1 分钟都不耽误
from mootdx.reader import Reader
reader = Reader.factory(market='std', tdxdir='C:/new_tdx') # 你的通达信安装目录
df = reader.daily(symbol='000061') # 以「农产品」为例
- 返回字段:open, close, high, low, volume——刚好 5 维特征,完美契合模型输入。
- 数据已天然复权,无需额外清洗,量化小白也能 0 踩坑。
️ 第 2 步:模型架构——一张图看懂“三件套”如何串联
Input(30,5) ─►Permute ─►Conv1d ─►BN ─►ReLU ─►Permute ─►LSTM ─►Attention ─►FC ─►Output(1)
| 模块 | 输出形状 | 作用 |
|---|---|---|
| Conv1d | (64,30) | 提取局部波动特征,类似“识别 K 线组合” |
| LSTM | (64,) | 捕捉 30 天里的长期依赖 |
| Attention | (30,) 权重 | 自动聚焦“最关键那几天” |
| FC | 1 | 映射为下一天收盘价 |
代码亮点:
- 维度自动修复:x.dim()==2 时自动 unsqueeze,避免 RuntimeError
- 输入 time_step 可变,只需改 1 个参数,模型自动适配。
第 3 步:数据管道——60% 训练 / 20% 验证 / 20% 测试
train_size = int(0.6 * len(X))
val_size = int(0.2 * len(X))
- 采用 滚动窗口 生成样本,防止未来函数。
- 标准化使用全局 mean/std 并保存,预测阶段反向还原,保证线上线下一致性。
第 4 步:训练技巧——早停法 + 最佳模型保存
best_val_loss = float('inf')
patience = 5
...
if val_loss < best_val_loss:
best_val_loss = val_loss
torch.save(model.state_dict(), 'best_model.pth')
- 早停法让模型 自动停在最优点,避免过拟合。
- 验证集与测试集 完全隔离,杜绝“偷看答案”。
第 5 步:一键预测——3 行代码给出明日收盘价
predictor = StockPredictor('stock_model.pth', '000061')
pred = predictor.predict_next_day(recent_30_days)
print(f"预测下个交易日收盘价: {
pred:.2f}")
- 类封装 + 反标准化,开箱即用。
- 支持批量换仓:只需循环调用 predict_next_day,即可生成全市场打分。
第 6 步:结果可视化——把loss曲线画出来,老板更爱看
import matplotlib.pyplot as plt
plt.plot(val_loss_list, label='Val Loss')
plt.plot(test_loss_list, label='Test Loss')
plt.title('CNN-LSTM-Attention Training Curve')
plt.savefig('loss_curve.png', dpi=300)
第 7 步:超参 tuning——3 个旋钮让预测误差再降 10%
| 超参 | 推荐范围 | 作用 |
|---|---|---|
| lstm_hidden | 32~128 | 越大记忆容量越高,但易过拟合 |
| cnn_channels | 32~128 | 控制卷积核数量,影响局部特征丰富度 |
| lr | 1e-4~1e-2 | 学习率过大震荡,过小收敛慢 |
建议使用 Optuna 自动搜索,10 次试验即可找到最佳组合。
第 8 步:风险声明——模型不是“印钞机”
1、过往业绩不代表未来表现,股市有风险,投资需谨慎。
2、本文仅供 教育 & 研究 之用,不构成任何投资建议。
3、实盘前请做 充分回测 & 压力测试,并配合风控系统。
附录:完整源码
网络结构定义model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
class CNN_LSTM_Attention(nn.Module):
def __init__(self, input_dim, time_step, lstm_hidden=64
浙公网安备 33010602011771号