小隐的博客

人生在世,笑饮一生
  博客园  :: 首页  :: 新随笔  :: 联系 :: 订阅 订阅  :: 管理

AI量化qlib学习笔记三:训练模型

Posted on 2025-08-09 16:49  隐客  阅读(228)  评论(0)    收藏  举报

这里有几个重要注意点,一是模型,它提供了好几个模型,但对时间序列以及初入学者,基本都是选择LGBModel(LightGBM)

二是因子,它有一些因子库,最常见的就是Alpha158,可以自己去学习,因为我是想做自己的因子,所以都是自己定义的,事实证明花一两周的时间,只是个匹毛,这里一定要搞清楚模型根据因子学习的原理,举个例子,传统量化可能是找金叉再加其它条件,可能触发信号,AI里没有这个概念,因为你要在AI里也这样写,就没有用AI的意义了。你只要告诉它注意dif和dea这两个因子的数据,它会自己去学习。如果写得太具体,就容易过拟合,也就是定制信号,搞不好这种后面都不再出现。所以写因子会是整个AI量化的绝大分工作。另外要注意,不支持跨周期是我放弃的最大原因。

三是可以参考它的例子开启训练,qlib_src\examples\benchmarks\LightGBM下有好几个yaml配置文件 ,启动训练:qrun 配置.yaml ,这里yaml里的参数,够学习很长时间了。

它这个是一个整体工作流,从训练,到回测到分析都有,但实际对刚接触AI的人非常难,所以为了学习我在这里就把这些步骤分拆了训练就单训练。

四,用配置文件是个非常不方便的方法 ,所以我用代码来实现了配置文件qlib_config_generator.py

  1 #!/usr/bin/env python3
  2 # -*- coding: utf-8 -*-
  3 """
  4 使用配置生成器的Qlib训练脚本 - 参考qlib_train_save.py的结构
  5 """
  6 
  7 import qlib
  8 import joblib
  9 import os
 10 import json
 11 from qlib.utils import init_instance_by_config
 12 from datetime import datetime, date
 13 from pathlib import Path
 14 
 15 # 导入我们的配置生成器
 16 from qlib_config_generator import QlibConfigGenerator
 17 
 18 
 19 def create_model_directory(contract, freq):
 20     """创建模型保存目录"""
 21     base_dir = 'models'
 22     model_dir = os.path.join(base_dir, contract, freq)
 23     os.makedirs(model_dir, exist_ok=True)
 24     return model_dir
 25 
 26 def get_features_from_handler(dataset):
 27     """从dataset的handler中获取特征配置"""
 28     try:
 29         handler = dataset.handler
 30         if hasattr(handler, 'get_feature_config'):
 31             return handler.get_feature_config()
 32         elif hasattr(handler, 'data_loader') and hasattr(handler.data_loader, 'config'):
 33             return handler.data_loader.config.get('feature', [])
 34         else:
 35             return []
 36     except Exception as e:
 37         print(f"⚠ 获取特征配置失败: {e}")
 38         return []
 39 
 40 def get_labels_from_handler(dataset):
 41     """从dataset的handler中获取标签配置"""
 42     try:
 43         handler = dataset.handler
 44         if hasattr(handler, 'get_label_config'):
 45             return handler.get_label_config()
 46         elif hasattr(handler, 'data_loader') and hasattr(handler.data_loader, 'config'):
 47             return handler.data_loader.config.get('label', [])
 48         else:
 49             return []
 50     except Exception as e:
 51         print(f"⚠ 获取标签配置失败: {e}")
 52         return []
 53 
 54 def json_serializer(obj):
 55     """自定义JSON序列化器,处理date和datetime对象"""
 56     if isinstance(obj, (datetime, date)):
 57         return obj.isoformat()
 58     raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")
 59 
 60 def train_model(contract, freq, start_time=None, end_time=None):
 61     """
 62     训练模型
 63     
 64     Args:
 65         contract: 合约名称,如 "EBL8.DCE"
 66         freq: 频率,如 "1min", "5min", "15min", "day"
 67         start_time: 开始时间(可选,默认从数据中自动获取)
 68         end_time: 结束时间(可选,默认从数据中自动获取)
 69     """
 70     print(f"=== 按合约和周期训练并保存qlib模型 ===")
 71     print(f"合约: {contract}, 周期: {freq}")
 72     
 73     # 1. 生成配置
 74     generator = QlibConfigGenerator()
 75     config = generator.generate_config(contract, freq, start_time, end_time)
 76     print("✓ 配置生成完成")
 77     
 78     # 2. 创建模型保存目录
 79     model_dir = create_model_directory(contract, freq)
 80     print(f"✓ 模型保存目录: {model_dir}")
 81     
 82     # 3. 初始化qlib
 83     qlib_config = config['qlib_init']
 84     qlib.init(
 85         provider_uri=qlib_config['provider_uri'],
 86         region=qlib_config['region']
 87     )
 88     print("✓ Qlib初始化完成")
 89     print(f"  数据路径: {qlib_config['provider_uri']}")
 90     
 91     # 4. 创建数据集
 92     dataset_config = config['task']['dataset']
 93     dataset = init_instance_by_config(dataset_config)
 94     print("✓ 数据集创建完成")
 95     
 96     # 5. 训练模型
 97     model_config = config['task']['model']
 98     model = init_instance_by_config(model_config)
 99     
100     print("开始训练模型...")
101     model.fit(dataset)
102     print("✓ 模型训练完成")
103     
104     # 6. 评估模型
105     performance_stats = {}
106     try:
107         test_pred = model.predict(dataset, segment="test")
108         pred_mean = test_pred.mean().iloc[0] if hasattr(test_pred.mean(), 'iloc') else test_pred.mean()
109         pred_std = test_pred.std().iloc[0] if hasattr(test_pred.std(), 'iloc') else test_pred.std()
110         
111         performance_stats = {
112             'pred_mean': float(pred_mean),
113             'pred_std': float(pred_std),
114             'pred_count': len(test_pred)
115         }
116         
117         print(f"模型评估 - 预测均值: {pred_mean:.6f}, 标准差: {pred_std:.6f}")
118     except Exception as e:
119         print(f"⚠ 模型评估失败: {e}")
120         performance_stats = {'error': str(e)}
121     
122     # 7. 保存模型(按目录结构)
123     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
124     
125     # 模型文件路径
126     model_file = os.path.join(model_dir, f'{contract}_{freq}_model_{timestamp}.pkl')
127     latest_file = os.path.join(model_dir, f'{contract}_{freq}_latest.pkl')
128     
129     # 保存模型
130     joblib.dump(model, model_file)
131     joblib.dump(model, latest_file)
132     
133     print(f"✓ 模型已保存到: {model_file}")
134     print(f"✓ 最新模型: {latest_file}")
135     
136     # 8. 保存详细的模型信息
137     model_info = {
138         'contract': contract,
139         'frequency': freq,
140         'timestamp': timestamp,
141         'model_type': config['task']['model']['class'],
142         'model_params': config['task']['model']['kwargs'],
143         'training_period': {
144             'train': config['task']['dataset']['kwargs']['segments']['train'],
145             'valid': config['task']['dataset']['kwargs']['segments']['valid'],
146             'test': config['task']['dataset']['kwargs']['segments']['test']
147         },
148         'data_config': {
149             'instruments': config['task']['dataset']['kwargs']['handler']['kwargs']['instruments'],
150             'features': get_features_from_handler(dataset),
151             'label': get_labels_from_handler(dataset)
152         },
153         'performance': performance_stats,
154         'files': {
155             'model_file': model_file,
156             'latest_file': latest_file
157         }
158     }
159     
160     # 保存模型信息
161     info_file = os.path.join(model_dir, f'{contract}_{freq}_info_{timestamp}.json')
162     latest_info_file = os.path.join(model_dir, f'{contract}_{freq}_latest_info.json')
163     
164     with open(info_file, 'w', encoding='utf-8') as f:
165         json.dump(model_info, f, ensure_ascii=False, indent=2, default=json_serializer)
166     
167     with open(latest_info_file, 'w', encoding='utf-8') as f:
168         json.dump(model_info, f, ensure_ascii=False, indent=2, default=json_serializer)
169     
170     print(f"✓ 模型信息已保存: {info_file}")
171     
172     # 9. 创建模型索引文件(便于管理多个模型)
173     index_file = os.path.join('models', 'model_index.json')
174     
175     # 读取现有索引
176     if os.path.exists(index_file):
177         with open(index_file, 'r', encoding='utf-8') as f:
178             index = json.load(f)
179     else:
180         index = {}
181     
182     # 更新索引
183     key = f"{contract}_{freq}"
184     index[key] = {
185         'contract': contract,
186         'frequency': freq,
187         'latest_model': latest_file,
188         'latest_info': latest_info_file,
189         'last_updated': timestamp,
190         'model_dir': model_dir
191     }
192     
193     # 保存索引
194     with open(index_file, 'w', encoding='utf-8') as f:
195         json.dump(index, f, ensure_ascii=False, indent=2, default=json_serializer)
196     
197     print(f"✓ 模型索引已更新: {index_file}")
198     
199     print(f"\n🎉 模型训练和保存完成!")
200     print(f"📁 模型目录结构:")
201     print(f"   models/")
202     print(f"   ├── {contract}/")
203     print(f"   │   └── {freq}/")
204     print(f"   │       ├── {contract}_{freq}_latest.pkl")
205     print(f"   │       ├── {contract}_{freq}_latest_info.json")
206     print(f"   │       └── {contract}_{freq}_model_{timestamp}.pkl")
207     print(f"   └── model_index.json")
208     print(f"\n💡 可以使用模型索引文件快速定位和加载模型")
209     
210     return model, model_info
211 
212 def batch_train(contracts, frequencies, start_time=None, end_time=None):
213     """
214     批量训练多个合约和频率
215     
216     Args:
217         contracts: 合约列表,如 ["EBL8.DCE", "RBL8.SHFE"]
218         frequencies: 频率列表,如 ["1min", "15min", "day"]
219         start_time: 开始时间(可选,默认从数据中自动获取)
220         end_time: 结束时间(可选,默认从数据中自动获取)
221     """
222     print(f"=== 批量训练 ===")
223     print(f"合约: {contracts}")
224     print(f"频率: {frequencies}")
225     if start_time and end_time:
226         print(f"时间范围: {start_time} 到 {end_time}")
227     else:
228         print("时间范围: 自动从数据中获取")
229     
230     results = {}
231     
232     for contract in contracts:
233         results[contract] = {}
234         for freq in frequencies:
235             print(f"\n{'='*80}")
236             try:
237                 model, model_info = train_model(contract, freq, start_time, end_time)
238                 results[contract][freq] = "成功"
239             except Exception as e:
240                 print(f"✗ {contract}-{freq} 训练失败: {e}")
241                 results[contract][freq] = f"失败: {e}"
242                 import traceback
243                 traceback.print_exc()
244     
245     # 打印汇总结果
246     print(f"\n{'='*80}")
247     print("=== 批量训练结果汇总 ===")
248     for contract, freq_results in results.items():
249         print(f"\n{contract}:")
250         for freq, status in freq_results.items():
251             print(f"  {freq}: {status}")
252     
253     return results
254 
255 def main():
256     """主函数"""
257     print("=== Qlib训练脚本(使用配置生成器)===")
258     
259     try:
260         # 单个训练示例
261         contract = "EBL8.DCE"
262         freq = "day"
263         
264         print(f"训练合约: {contract}")
265         print(f"训练频率: {freq}")
266         
267         model, model_info = train_model(contract, freq)
268         print("\n✓ 训练成功完成!")
269         
270         # 如果需要批量训练,取消下面的注释
271         # print("\n" + "="*80)
272         # batch_results = batch_train(
273         #     contracts=["EBL8.DCE", "RBL8.SHFE"], 
274         #     frequencies=["15min", "day"]
275         # )
276         
277     except Exception as e:
278         print(f"❌ 训练失败: {e}")
279         import traceback
280         traceback.print_exc()
281 
282 if __name__ == "__main__":
283     main()

qlib_config_generator.py

 

  1 #!/usr/bin/env python3
  2 # -*- coding: utf-8 -*-
  3 """
  4 Qlib配置生成器 - 集成自定义Handler和配置文件生成
  5 """
  6 
  7 import pandas as pd
  8 import numpy as np
  9 import yaml
 10 import os
 11 from pathlib import Path
 12 from qlib.data.dataset.handler import DataHandlerLP
 13 from qlib.log import get_module_logger
 14 import glob
 15 from datetime import datetime
 16 import qlib
 17 from qlib.data import D
 18 
 19 # ==================== 自定义Handler类 ====================
 20 
 21 class CustomTrendHandler_m1(DataHandlerLP):
 22     """自定义因子处理器 - 1分钟"""
 23     
 24     def __init__(self, **kwargs):
 25         self.freq = kwargs.pop('freq', '1min')
 26         kwargs.pop('inst_processors', None)
 27         
 28         feature_config = self.get_feature_config()
 29         label_config = self.get_label_config()
 30         
 31         kwargs['data_loader'] = {
 32             'class': 'QlibDataLoader',
 33             'module_path': 'qlib.data.dataset.loader',
 34             'kwargs': {
 35                 'freq': self.freq,
 36                 'config': {
 37                     'feature': feature_config,
 38                     'label': label_config
 39                 }
 40             }
 41         }
 42                 
 43         super().__init__(**kwargs)
 44         self.logger = get_module_logger(self.__class__.__name__)
 45         self.logger.info(f"CustomTrendHandler初始化完成,使用 {len(feature_config)} 个因子")
 46         
 47     def get_feature_config(self):
 48         """1分钟级别的因子配置"""
 49         ma7 = "Mean(Mean($close, 7), 3)"
 50         ma20 = "Mean(Mean($close, 20), 3)" 
 51         ma60 = "Mean(Mean($close, 60), 3)"
 52         v5 = "Mean(Mean($volume, 5), 3)"
 53         v15= "Mean(Mean($volume, 15), 3)"
 54         v50= "Mean(Mean($volume, 50), 3)"
 55         
 56         volume_fields = [
 57             f"$volume / {v50}",
 58             f"Ref($volume,1) / Ref({v50},1)",
 59             f"Ref($volume,2) / Ref({v50},2)",
 60             f"Ref($volume,3) / Ref({v50},3)",
 61             f"({v5} - {v15})/(Ref({v5} - {v15},1))",
 62             f"({v5} - {v50})/(Ref({v5} - {v50},1))",
 63             f"({v15} - {v50})/(Ref({v15} - {v50},1))",
 64         ]
 65         
 66         ma_fields = [
 67             f"{ma7} / {ma20}",
 68             f"{ma7} / {ma60}",
 69             f"{ma20} / {ma60}",  
 70             f"(({ma7} - {ma20}) / Ref({ma7} - {ma20}, 1))",
 71             f"(({ma7} - {ma60}) / Ref({ma7} - {ma60}, 1))",
 72             f"(({ma20} - {ma60}) / Ref({ma20} - {ma60}, 1))",
 73         ]
 74         
 75         current_body = "Abs($close - $open)"
 76         prev_4_avg = "Mean(Ref(Abs($close - $open), 1), 4)"
 77         
 78         price_fields = [
 79             f"{current_body} / {prev_4_avg}",
 80             current_body,
 81             f"{current_body} / $close",
 82             "($close - $open) / $close",
 83             "($high - Greater($close, $open)) / $close",
 84             "(Less($close, $open) - $low) / $close",
 85             f"($close - Ref($close, 3)) / Ref($close, 3)",
 86         ]
 87         
 88         return volume_fields + ma_fields + price_fields
 89     
 90     def get_label_config(self):
 91         return ["Ref($close, -10) / $close - 1"]
 92 
 93 class CustomTrendHandler_m5(DataHandlerLP):
 94     """自定义因子处理器 - 5分钟"""
 95     
 96     def __init__(self, **kwargs):
 97         self.freq = kwargs.pop('freq', '5min')
 98         kwargs.pop('inst_processors', None)
 99         
100         feature_config = self.get_feature_config()
101         label_config = self.get_label_config()
102         
103         kwargs['data_loader'] = {
104             'class': 'QlibDataLoader',
105             'module_path': 'qlib.data.dataset.loader',
106             'kwargs': {
107                 'freq': self.freq,
108                 'config': {
109                     'feature': feature_config,
110                     'label': label_config
111                 }
112             }
113         }
114                 
115         super().__init__(**kwargs)
116         self.logger = get_module_logger(self.__class__.__name__)
117         
118     def get_feature_config(self):
119         """5分钟级别的因子配置"""
120         ma5 = "Mean($close, 5)"
121         ma15 = "Mean($close, 15)" 
122         ma30 = "Mean($close, 30)"
123         
124         volume_fields = [
125             "$volume / Mean($volume, 20)",
126             "($volume - Ref($volume, 1)) / Ref($volume, 1)",
127             "$volume / Mean(Mean($volume, 30), 3)",
128         ]
129         
130         ma_fields = [
131             f"({ma5} - Ref({ma5}, 1)) / {ma5}",
132             f"({ma15} - Ref({ma15}, 1)) / {ma15}",
133             f"({ma30} - Ref({ma30}, 1)) / {ma30}",
134             f"{ma5} / {ma15}",
135             f"{ma15} / {ma30}",
136         ]
137         
138         price_fields = [
139             "Abs($close - $open) / $close",
140             "($close - $open) / $close",
141             "($high - Greater($close, $open)) / $close",
142             "(Less($close, $open) - $low) / $close",
143         ]
144         
145         return volume_fields + ma_fields + price_fields
146     
147     def get_label_config(self):
148         return ["Ref($close, -5) / $close - 1"]
149 
150 class CustomTrendHandler_m15(DataHandlerLP):
151     """自定义因子处理器 - 15分钟"""
152     
153     def __init__(self, **kwargs):
154         self.freq = kwargs.pop('freq', '15min')
155         kwargs.pop('inst_processors', None)
156         print("CustomTrendHandler_m15 初始化,freq =", self.freq)
157         feature_config = self.get_feature_config()
158         label_config = self.get_label_config()
159         
160         kwargs['data_loader'] = {
161             'class': 'QlibDataLoader',
162             'module_path': 'qlib.data.dataset.loader',
163             'kwargs': {
164                 'freq': self.freq,
165                 'config': {
166                     'feature': feature_config,
167                     'label': label_config
168                 }
169             }
170         }
171                 
172         super().__init__(**kwargs)
173         self.logger = get_module_logger(self.__class__.__name__)
174         
175     def get_feature_config(self):
176         """15分钟级别的因子配置"""
177         ma4 = "Mean($close, 4)"
178         ma12 = "Mean($close, 12)" 
179         ma24 = "Mean($close, 24)"
180         
181         volume_fields = [
182             "$volume / Mean($volume, 16)",
183             "($volume - Ref($volume, 1)) / Ref($volume, 1)",
184             "$volume / Mean(Mean($volume, 20), 2)",
185         ]
186         
187         ma_fields = [
188             f"({ma4} - Ref({ma4}, 1)) / {ma4}",
189             f"({ma12} - Ref({ma12}, 1)) / {ma12}",
190             f"({ma24} - Ref({ma24}, 1)) / {ma24}",
191             f"{ma4} / {ma12}",
192             f"{ma12} / {ma24}",
193         ]
194         
195         price_fields = [
196             "Abs($close - $open) / $close",
197             "($close - $open) / $close",
198             "($high - Greater($close, $open)) / $close",
199             "(Less($close, $open) - $low) / $close",
200             "($close - Ref($close, 2)) / Ref($close, 2)",
201         ]
202         
203         return volume_fields + ma_fields + price_fields
204     
205     def get_label_config(self):
206         return ["Ref($close, -3) / $close - 1"]
207 
208 class CustomTrendHandler_day(DataHandlerLP):
209     """自定义因子处理器 - 日线"""
210     
211     def __init__(self, **kwargs):
212         self.freq = kwargs.pop('freq', 'day')
213         kwargs.pop('inst_processors', None)
214         print("CustomTrendHandler_day 初始化,freq =", self.freq)
215         feature_config = self.get_feature_config()
216         label_config = self.get_label_config()
217         
218         kwargs['data_loader'] = {
219             'class': 'QlibDataLoader',
220             'module_path': 'qlib.data.dataset.loader',
221             'kwargs': {
222                 'freq': self.freq,
223                 'config': {
224                     'feature': feature_config,
225                     'label': label_config
226                 }
227             }
228         }
229                 
230         super().__init__(**kwargs)
231         self.logger = get_module_logger(self.__class__.__name__)
232         
233     def get_feature_config(self):
234         """日线级别的因子配置 - 修复的大涨回调策略因子"""
235         
236         # 基础均线
237         ma5 = "Mean($close, 5)"
238         ma20 = "Mean($close, 20)" 
239         ma60 = "Mean($close, 60)"
240         
241         # 原有的基础因子
242         volume_fields = [
243             "$volume / Mean($volume, 20)",
244             "($volume - Ref($volume, 1)) / Ref($volume, 1)",
245             "$volume / Mean($volume, 5)",
246         ]
247         
248         ma_fields = [
249             f"({ma5} - Ref({ma5}, 1)) / {ma5}",
250             f"({ma20} - Ref({ma20}, 1)) / {ma20}",
251             f"({ma60} - Ref({ma60}, 1)) / {ma60}",
252             f"{ma5} / {ma20}",
253             f"{ma20} / {ma60}",
254         ]
255         
256         price_fields = [
257             "Abs($close - $open) / $close",
258             "($close - $open) / $close",
259             "($high - Greater($close, $open)) / $close",
260             "(Less($close, $open) - $low) / $close",
261             "($close - Ref($close, 1)) / Ref($close, 1)",
262         ]
263         
264         # 修复的大涨回调策略因子
265         strategy_fields = [
266             # 1. 检测各天的涨幅
267             "(Ref($close, 1) - Ref($close, 2)) / Ref($close, 2)",  # 1天前涨幅
268             "(Ref($close, 2) - Ref($close, 3)) / Ref($close, 3)",  # 2天前涨幅
269             "(Ref($close, 3) - Ref($close, 4)) / Ref($close, 4)",  # 3天前涨幅
270             "(Ref($close, 4) - Ref($close, 5)) / Ref($close, 5)",  # 4天前涨幅
271             "(Ref($close, 5) - Ref($close, 6)) / Ref($close, 6)",  # 5天前涨幅
272             
273             # 2. 检测是否有大涨(>3%)
274             "If((Ref($close, 1) - Ref($close, 2)) / Ref($close, 2) > 0.03, 1, 0)",  # 1天前大涨
275             "If((Ref($close, 2) - Ref($close, 3)) / Ref($close, 3) > 0.03, 1, 0)",  # 2天前大涨
276             "If((Ref($close, 3) - Ref($close, 4)) / Ref($close, 4) > 0.03, 1, 0)",  # 3天前大涨
277             "If((Ref($close, 4) - Ref($close, 5)) / Ref($close, 5) > 0.03, 1, 0)",  # 4天前大涨
278             "If((Ref($close, 5) - Ref($close, 6)) / Ref($close, 6) > 0.03, 1, 0)",  # 5天前大涨
279             
280             # 3. 5天内是否有大涨
281             "If(((Ref($close, 1) - Ref($close, 2)) / Ref($close, 2) > 0.03) | " +
282             "((Ref($close, 2) - Ref($close, 3)) / Ref($close, 3) > 0.03) | " +
283             "((Ref($close, 3) - Ref($close, 4)) / Ref($close, 4) > 0.03) | " +
284             "((Ref($close, 4) - Ref($close, 5)) / Ref($close, 5) > 0.03) | " +
285             "((Ref($close, 5) - Ref($close, 6)) / Ref($close, 6) > 0.03), 1, 0)",  # 5天内有大涨
286             
287             # 4. 找到最近的大涨日收盘价(从近到远检查)
288             "If((Ref($close, 1) - Ref($close, 2)) / Ref($close, 2) > 0.03, Ref($close, 1), " +
289             "If((Ref($close, 2) - Ref($close, 3)) / Ref($close, 3) > 0.03, Ref($close, 2), " +
290             "If((Ref($close, 3) - Ref($close, 4)) / Ref($close, 4) > 0.03, Ref($close, 3), " +
291             "If((Ref($close, 4) - Ref($close, 5)) / Ref($close, 5) > 0.03, Ref($close, 4), " +
292             "If((Ref($close, 5) - Ref($close, 6)) / Ref($close, 6) > 0.03, Ref($close, 5), $close)))))",  # 最近大涨日收盘价
293             
294             # 5. 大涨后是否有回调(检查大涨日后的最低价是否低于大涨日收盘价)
295             # 1天前大涨的情况
296             "If((Ref($close, 1) - Ref($close, 2)) / Ref($close, 2) > 0.03, " +
297             "If($close < Ref($close, 1), 1, 0), 0)",  # 1天前大涨后今天回调
298             
299             # 2天前大涨的情况
300             "If((Ref($close, 2) - Ref($close, 3)) / Ref($close, 3) > 0.03, " +
301             "If(Min($close, 2) < Ref($close, 2), 1, 0), 0)",  # 2天前大涨后有回调
302             
303             # 3天前大涨的情况
304             "If((Ref($close, 3) - Ref($close, 4)) / Ref($close, 4) > 0.03, " +
305             "If(Min($close, 3) < Ref($close, 3), 1, 0), 0)",  # 3天前大涨后有回调
306             
307             # 4天前大涨的情况
308             "If((Ref($close, 4) - Ref($close, 5)) / Ref($close, 5) > 0.03, " +
309             "If(Min($close, 4) < Ref($close, 4), 1, 0), 0)",  # 4天前大涨后有回调
310             
311             # 5天前大涨的情况
312             "If((Ref($close, 5) - Ref($close, 6)) / Ref($close, 6) > 0.03, " +
313             "If(Min($close, 5) < Ref($close, 5), 1, 0), 0)",  # 5天前大涨后有回调
314             
315             # 6. 综合:是否有大涨后回调
316             "If((Ref($close, 1) - Ref($close, 2)) / Ref($close, 2) > 0.03, If($close < Ref($close, 1), 1, 0), " +
317             "If((Ref($close, 2) - Ref($close, 3)) / Ref($close, 3) > 0.03, If(Min($close, 2) < Ref($close, 2), 1, 0), " +
318             "If((Ref($close, 3) - Ref($close, 4)) / Ref($close, 4) > 0.03, If(Min($close, 3) < Ref($close, 3), 1, 0), " +
319             "If((Ref($close, 4) - Ref($close, 5)) / Ref($close, 5) > 0.03, If(Min($close, 4) < Ref($close, 4), 1, 0), " +
320             "If((Ref($close, 5) - Ref($close, 6)) / Ref($close, 6) > 0.03, If(Min($close, 5) < Ref($close, 5), 1, 0), 0)))))",
321             
322             # 7. 当前价格相对最近大涨日的比例(检查是否跌破80%)
323             "If((Ref($close, 1) - Ref($close, 2)) / Ref($close, 2) > 0.03, $close / Ref($close, 1), " +
324             "If((Ref($close, 2) - Ref($close, 3)) / Ref($close, 3) > 0.03, $close / Ref($close, 2), " +
325             "If((Ref($close, 3) - Ref($close, 4)) / Ref($close, 4) > 0.03, $close / Ref($close, 3), " +
326             "If((Ref($close, 4) - Ref($close, 5)) / Ref($close, 5) > 0.03, $close / Ref($close, 4), " +
327             "If((Ref($close, 5) - Ref($close, 6)) / Ref($close, 6) > 0.03, $close / Ref($close, 5), 1)))))",
328             
329             # 8. 今日放量上涨
330             "If(($close > $open) & ($volume > Mean($volume, 5) * 1.2), 1, 0)",  # 今日放量上涨
331             "($volume / Mean($volume, 5)) * If($close > $open, 1, 0)",  # 上涨放量倍数
332             
333             # 9. 完整策略信号
334             "If((" +
335             # 条件1:5天内有大涨且不是今天
336             "((Ref($close, 1) - Ref($close, 2)) / Ref($close, 2) > 0.03) | " +
337             "((Ref($close, 2) - Ref($close, 3)) / Ref($close, 3) > 0.03) | " +
338             "((Ref($close, 3) - Ref($close, 4)) / Ref($close, 4) > 0.03) | " +
339             "((Ref($close, 4) - Ref($close, 5)) / Ref($close, 5) > 0.03) | " +
340             "((Ref($close, 5) - Ref($close, 6)) / Ref($close, 6) > 0.03)" +
341             ") & (" +
342             # 条件2:大涨后有回调
343             "If((Ref($close, 1) - Ref($close, 2)) / Ref($close, 2) > 0.03, If($close < Ref($close, 1), 1, 0), " +
344             "If((Ref($close, 2) - Ref($close, 3)) / Ref($close, 3) > 0.03, If(Min($close, 2) < Ref($close, 2), 1, 0), " +
345             "If((Ref($close, 3) - Ref($close, 4)) / Ref($close, 4) > 0.03, If(Min($close, 3) < Ref($close, 3), 1, 0), " +
346             "If((Ref($close, 4) - Ref($close, 5)) / Ref($close, 5) > 0.03, If(Min($close, 4) < Ref($close, 4), 1, 0), " +
347             "If((Ref($close, 5) - Ref($close, 6)) / Ref($close, 6) > 0.03, If(Min($close, 5) < Ref($close, 5), 1, 0), 0))))) == 1" +
348             ") & (" +
349             # 条件3:未跌破80%
350             "If((Ref($close, 1) - Ref($close, 2)) / Ref($close, 2) > 0.03, $close / Ref($close, 1), " +
351             "If((Ref($close, 2) - Ref($close, 3)) / Ref($close, 3) > 0.03, $close / Ref($close, 2), " +
352             "If((Ref($close, 3) - Ref($close, 4)) / Ref($close, 4) > 0.03, $close / Ref($close, 3), " +
353             "If((Ref($close, 4) - Ref($close, 5)) / Ref($close, 5) > 0.03, $close / Ref($close, 4), " +
354             "If((Ref($close, 5) - Ref($close, 6)) / Ref($close, 6) > 0.03, $close / Ref($close, 5), 1))))) >= 0.8" +
355             ") & (" +
356             # 条件4:今日放量上涨
357             "($close > $open) & ($volume > Mean($volume, 5) * 1.2)" +
358             "), 1, 0)",
359             
360             # 10. 辅助因子
361             "Max(($close - Ref($close, 1)) / Ref($close, 1), 5)",  # 5天内最大涨幅
362             "($close - Min($close, 5)) / Min($close, 5)",  # 相对5天最低点涨幅
363             "(Max($close, 5) - $close) / Max($close, 5)",  # 相对5天最高点回调幅度
364             
365             # 11. 量价配合
366             "($close - $open) / $close * ($volume / Mean($volume, 5))",  # 涨幅*放量倍数
367             "If($volume > Mean($volume, 5) * 1.2, ($close - $open) / $close, 0)",  # 放量时的涨幅
368         ]
369         
370         return volume_fields + ma_fields + price_fields + strategy_fields
371 
372 
373 
374 
375     def get_label_config(self):
376         return ["Ref($close, -1) / $close - 1"]
377 
378 # ==================== 配置生成器 ====================
379 
380 class QlibConfigGenerator:
381     """Qlib配置生成器"""
382     
383     def __init__(self):
384         self.base_data_path = "C:/ctp/q_qlib/qlib_data/cn_data"
385         self.source_data_path = "C:/ctp/q_qlib/cleaned_data"
386     
387     def get_data_time_range(self, contract, freq):
388         """
389         从qlib数据源中获取时间范围
390         
391         Args:
392             contract: 合约名称,如 "EBL8.DCE"
393             freq: 频率,如 "1min", "5min", "15min", "day"
394         
395         Returns:
396             tuple: (start_time, end_time, train_end, valid_start, valid_end)
397         """
398         try:
399             # 频率映射
400             freq_mapping = {
401                 "1min": "1min",
402                 "5min": "5min", 
403                 "15min": "15min",
404                 "day": "day"
405             }
406             
407             qlib_freq = freq_mapping.get(freq, freq)
408             contract_path = f"{self.base_data_path}/{contract}"
409             
410             print(f"  从qlib数据源读取: {contract_path}")
411             
412             # 临时初始化qlib以读取数据
413             try:
414                 qlib.init(provider_uri=contract_path, region="cn")
415             except:
416                 # 如果已经初始化过,忽略错误
417                 pass
418             
419             # 使用qlib的D.features接口读取数据
420             # 读取一个简单的字段来获取时间索引
421             data = D.features([contract], ["$close"], freq=qlib_freq)
422             
423             if data is None or len(data) == 0:
424                 raise ValueError(f"未找到qlib数据: {contract} - {qlib_freq},请检查数据路径: {contract_path}")
425             
426             # 获取时间索引
427             time_index = data.index.get_level_values('datetime')
428             start_time = time_index.min()
429             end_time = time_index.max()
430             total_rows = len(time_index.unique())
431             
432             print(f"  数据时间范围: {start_time} 到 {end_time}")
433             print(f"  总数据量: {total_rows} 行")
434             
435             # 计算分割点:90%用于训练,10%用于验证和测试
436             train_rows = int(total_rows * 0.9)
437             
438             # 获取排序后的唯一时间点
439             unique_times = sorted(time_index.unique())
440             train_end_time = unique_times[train_rows - 1]
441             
442             # 格式化时间为字符串
443             start_str = start_time.strftime('%Y-%m-%d')
444             end_str = end_time.strftime('%Y-%m-%d')
445             train_end_str = train_end_time.strftime('%Y-%m-%d')
446             
447             # valid和test都使用剩余的10%数据
448             valid_start_str = train_end_str
449             valid_end_str = end_str
450             
451             print(f"  时间分割:")
452             print(f"    训练集: {start_str} 到 {train_end_str} ({train_rows} 行, {train_rows/total_rows*100:.1f}%)")
453             print(f"    验证集: {valid_start_str} 到 {valid_end_str} ({total_rows-train_rows} 行, {(total_rows-train_rows)/total_rows*100:.1f}%)")
454             print(f"    测试集: {valid_start_str} 到 {valid_end_str} ({total_rows-train_rows} 行, {(total_rows-train_rows)/total_rows*100:.1f}%)")
455             
456             return start_str, end_str, train_end_str, valid_start_str, valid_end_str
457             
458         except Exception as e:
459             print(f"✗ 从qlib获取数据时间范围失败: {e}")
460             import traceback
461             traceback.print_exc()
462             raise RuntimeError(f"无法获取合约 {contract} 频率 {freq} 的数据时间范围,请检查数据是否存在") from e
463         
464     def generate_config(self, contract, freq, start_time=None, end_time=None, handler_module=None):
465         """
466         生成配置字典
467         
468         Args:
469             contract: 合约名称,如 "EBL8.DCE"
470             freq: 频率,如 "1min", "5min", "15min", "day"
471             start_time: 开始时间(可选,默认从数据中自动获取)
472             end_time: 结束时间(可选,默认从数据中自动获取)
473             handler_module: Handler模块路径(可选,默认自动检测)
474         """
475         # 如果没有提供时间范围,从数据中自动获取
476         if start_time is None or end_time is None:
477             print(f"自动获取 {contract} - {freq} 的数据时间范围...")
478             auto_start, auto_end, train_end, valid_start, valid_end = self.get_data_time_range(contract, freq)
479             start_time = start_time or auto_start
480             end_time = end_time or auto_end
481         else:
482             # 使用提供的时间范围,按90%-10%分割
483             print(f"使用指定时间范围: {start_time} 到 {end_time}")
484             start_dt = datetime.strptime(start_time, '%Y-%m-%d')
485             end_dt = datetime.strptime(end_time, '%Y-%m-%d')
486             total_days = (end_dt - start_dt).days
487             train_days = int(total_days * 0.9)
488             train_end_dt = start_dt + pd.Timedelta(days=train_days)
489             
490             train_end = train_end_dt.strftime('%Y-%m-%d')
491             valid_start = train_end
492             valid_end = end_time
493         
494         # 频率映射
495         freq_mapping = {
496             "1min": {"handler": "CustomTrendHandler_m1", "qlib_freq": "1min"},
497             "5min": {"handler": "CustomTrendHandler_m5", "qlib_freq": "5min"},
498             "15min": {"handler": "CustomTrendHandler_m15", "qlib_freq": "15min"},
499             "day": {"handler": "CustomTrendHandler_day", "qlib_freq": "day"}
500         }
501         
502         if freq not in freq_mapping:
503             raise ValueError(f"不支持的频率: {freq}")
504         
505         handler_info = freq_mapping[freq]
506         contract_path = f"{self.base_data_path}/{contract}"
507         
508         # 自动检测Handler模块路径
509         if handler_module is None:
510             import inspect
511             frame = inspect.currentframe()
512             try:
513                 # 获取调用者的模块名
514                 caller_frame = frame.f_back.f_back  # 跳过generate_config和train_model
515                 caller_module = caller_frame.f_globals.get('__name__', 'qlib_config_generator')
516                 if caller_module == '__main__':
517                     # 如果是主模块,尝试获取文件名
518                     caller_file = caller_frame.f_globals.get('__file__', '')
519                     if caller_file:
520                         import os
521                         module_name = os.path.splitext(os.path.basename(caller_file))[0]
522                         handler_module = module_name
523                     else:
524                         handler_module = 'qlib_config_generator'
525                 else:
526                     handler_module = caller_module
527             finally:
528                 del frame
529         
530         # 简化方案:直接使用配置生成器模块
531         handler_module = "qlib_config_generator"
532         print(f"  使用Handler模块: {handler_module}")
533         
534         config = {
535             "qlib_init": {
536                 "provider_uri": contract_path,
537                 "region": "cn"
538             },
539             "task": {
540                 "model": {
541                     "class": "LGBModel",
542                     "module_path": "qlib.contrib.model.gbdt",
543                     "kwargs": {
544                         "loss": "mse",
545                         "colsample_bytree": 0.8879,
546                         "learning_rate": 0.2,
547                         "subsample": 0.8789,
548                         "max_depth": 8,
549                         "num_leaves": 210,
550                         "num_threads": 20
551                     }
552                 },
553                 "dataset": {
554                     "class": "DatasetH",
555                     "module_path": "qlib.data.dataset",
556                     "kwargs": {
557                         "handler": {
558                             "class": handler_info["handler"],
559                             "module_path": handler_module,  # 动态检测的模块路径
560                             "kwargs": {
561                                 "start_time": start_time,
562                                 "end_time": end_time,
563                                 "instruments": [contract],
564                                 "freq": handler_info["qlib_freq"]
565                             }
566                         },
567                         "segments": {
568                             "train": [start_time, train_end],
569                             "valid": [valid_start, valid_end],
570                             "test": [valid_start, valid_end]
571                         }
572                     }
573                 },
574                 "record": [
575                     {
576                         "class": "SignalRecord",
577                         "module_path": "qlib.workflow.record_temp",
578                         "kwargs": {
579                             "model": "<MODEL>",
580                             "dataset": "<DATASET>"
581                         }
582                     },
583                     {
584                         "class": "SigAnaRecord",
585                         "module_path": "qlib.workflow.record_temp",
586                         "kwargs": {
587                             "ana_long_short": False,
588                             "ann_scaler": 252
589                         }
590                     }
591                 ]
592             }
593         }
594         
595         return config
596     
597     def save_config_yaml(self, config, filename):
598         """保存配置为YAML文件"""
599         with open(filename, 'w', encoding='utf-8') as f:
600             yaml.dump(config, f, default_flow_style=False, allow_unicode=True, indent=2)
601         print(f"✓ 配置文件已保存: {filename}")
602     
603     def generate_all_configs(self, contract):
604         """为指定合约生成所有频率的配置文件"""
605         frequencies = ["1min", "5min", "15min", "day"]
606         
607         for freq in frequencies:
608             config = self.generate_config(contract, freq)
609             filename = f"config_{contract.replace('.', '_')}_{freq}.yaml"
610             self.save_config_yaml(config, filename)
611         
612         print(f"✓ 已为合约 {contract} 生成所有配置文件")
613 
614 # ==================== 使用示例 ====================
615 
616 def main():
617     """主函数 - 演示如何使用"""
618     generator = QlibConfigGenerator()
619     
620     # 生成单个配置(自动获取时间范围)
621     print("=== 生成单个配置文件(自动时间范围)===")
622     contract = "EBL8.DCE"
623     freq = "15min"
624     
625     config = generator.generate_config(contract, freq)
626     generator.save_config_yaml(config, f"my_config_{freq}.yaml")
627     
628     # 生成指定时间范围的配置
629     print("\n=== 生成指定时间范围的配置文件 ===")
630     config_custom = generator.generate_config(contract, freq, "2020-01-01", "2024-12-31")
631     generator.save_config_yaml(config_custom, f"my_config_{freq}_custom.yaml")
632     
633     # 生成所有频率的配置
634     print("\n=== 生成所有频率配置文件 ===")
635     generator.generate_all_configs(contract)
636     
637     print("\n=== 配置生成完成 ===")
638     print("✓ 自动时间范围:从数据中获取第一行到最后一行,90%训练,10%验证/测试")
639     print("✓ 现在你可以直接使用这些配置文件进行训练,无需额外的factors目录!")
640 
641 if __name__ == "__main__":
642     main()

 

QQ交流