这里有几个重要的注意点。一是模型:它提供了好几个模型,但对于时间序列任务以及初学者,基本都是选择 LGBModel(LightGBM)。
二是因子。它自带一些因子库,最常见的就是 Alpha158,可以自己去学习。因为我想做自己的因子,所以都是自己定义的;事实证明,花一两周时间学到的只是皮毛。这里一定要搞清楚模型根据因子学习的原理。举个例子:传统量化可能是找金叉再加上其它条件来触发信号,AI 里没有这个概念——如果在 AI 里也这样写,就失去了使用 AI 的意义。你只要告诉它关注 dif 和 dea 这两个因子的数据,它会自己去学习。如果因子写得太具体,就容易过拟合,也就是变成了定制信号,搞不好这种形态以后再也不会出现。所以写因子会是整个 AI 量化的绝大部分工作。另外要注意,不支持跨周期是我放弃它的最大原因。
三是可以参考它的例子开启训练:qlib_src\examples\benchmarks\LightGBM 下有好几个 yaml 配置文件,启动训练的命令是 qrun 配置.yaml。yaml 里的这些参数,够学习很长时间了。
它提供的是一个整体工作流,从训练、回测到分析都有,但对刚接触 AI 的人来说实际上非常难。所以为了学习,我在这里把这些步骤拆开了:训练就只做训练。
四,用配置文件是个非常不方便的方法,所以我用代码来生成配置(qlib_config_generator.py)。下面先是使用它的训练脚本:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Train a Qlib model from a generated config and save it per contract/frequency.

The structure follows qlib_train_save.py, but the configuration dict is
produced by QlibConfigGenerator instead of being read from a YAML file.
"""

import json
import os
import sys
import traceback
from datetime import date, datetime
from pathlib import Path

import joblib
import qlib
from qlib.utils import init_instance_by_config

# Project-local config generator (replaces the YAML workflow file).
from qlib_config_generator import QlibConfigGenerator

# Root directory under which all models and the model index are stored.
MODELS_DIR = 'models'


def create_model_directory(contract, freq):
    """Create (if needed) and return the directory ``models/<contract>/<freq>``."""
    model_dir = os.path.join(MODELS_DIR, contract, freq)
    os.makedirs(model_dir, exist_ok=True)
    return model_dir


def _get_handler_config(dataset, getter_name, config_key, description):
    """Read a feature/label expression list from ``dataset.handler``.

    Prefers the handler's own getter method; otherwise falls back to the
    ``config`` dict of the handler's data loader. Any failure is reported
    and an empty list is returned, because this information is only used
    for the saved metadata and must not abort a finished training run.
    """
    try:
        handler = dataset.handler
        if hasattr(handler, getter_name):
            return getattr(handler, getter_name)()
        if hasattr(handler, 'data_loader') and hasattr(handler.data_loader, 'config'):
            return handler.data_loader.config.get(config_key, [])
        return []
    except Exception as e:
        print(f"⚠ 获取{description}配置失败: {e}")
        return []


def get_features_from_handler(dataset):
    """Return the feature expressions configured on the dataset's handler."""
    return _get_handler_config(dataset, 'get_feature_config', 'feature', '特征')


def get_labels_from_handler(dataset):
    """Return the label expressions configured on the dataset's handler."""
    return _get_handler_config(dataset, 'get_label_config', 'label', '标签')


def json_serializer(obj):
    """``json.dump`` fallback that renders date/datetime objects as ISO strings.

    Raises:
        TypeError: for any other type that json cannot serialize.
    """
    if isinstance(obj, (datetime, date)):
        return obj.isoformat()
    raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")


def _as_scalar(stat):
    """Return the first element when ``stat`` is Series-like, else ``stat`` itself."""
    return stat.iloc[0] if hasattr(stat, 'iloc') else stat


def _evaluate_model(model, dataset):
    """Predict on the "test" segment and summarise the predictions.

    Returns a dict with ``pred_mean``, ``pred_std`` and ``pred_count``; on
    any failure returns ``{'error': <message>}`` so that the trained model
    can still be saved.
    """
    try:
        test_pred = model.predict(dataset, segment="test")
        pred_mean = _as_scalar(test_pred.mean())
        pred_std = _as_scalar(test_pred.std())
        stats = {
            'pred_mean': float(pred_mean),
            'pred_std': float(pred_std),
            'pred_count': len(test_pred),
        }
        print(f"模型评估 - 预测均值: {pred_mean:.6f}, 标准差: {pred_std:.6f}")
        return stats
    except Exception as e:
        print(f"⚠ 模型评估失败: {e}")
        return {'error': str(e)}


def _write_json(path, payload):
    """Write ``payload`` to ``path`` as indented UTF-8 JSON."""
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(payload, f, ensure_ascii=False, indent=2, default=json_serializer)


def _load_model_index(index_file):
    """Load the model index, starting fresh when it is missing or unreadable."""
    if not os.path.exists(index_file):
        return {}
    try:
        with open(index_file, 'r', encoding='utf-8') as f:
            index = json.load(f)
    except json.JSONDecodeError as e:
        # A damaged index must not discard a model that was just trained.
        print(f"⚠ 模型索引文件损坏,将重新创建: {e}")
        return {}
    return index if isinstance(index, dict) else {}


def train_model(contract, freq, start_time=None, end_time=None):
    """Train one model for a contract/frequency and save it with its metadata.

    Args:
        contract: contract name, e.g. "EBL8.DCE".
        freq: bar frequency, e.g. "1min", "5min", "15min", "day".
        start_time: optional start date; taken from the data when omitted.
        end_time: optional end date; taken from the data when omitted.

    Returns:
        tuple: ``(model, model_info)`` where ``model_info`` is the metadata
        dict that was also written to disk.
    """
    print("=== 按合约和周期训练并保存qlib模型 ===")
    print(f"合约: {contract}, 周期: {freq}")

    # 1. Generate the configuration dict.
    generator = QlibConfigGenerator()
    config = generator.generate_config(contract, freq, start_time, end_time)
    print("✓ 配置生成完成")

    # 2. Create the output directory.
    model_dir = create_model_directory(contract, freq)
    print(f"✓ 模型保存目录: {model_dir}")

    # 3. Initialise qlib with the data location from the config.
    qlib_config = config['qlib_init']
    qlib.init(
        provider_uri=qlib_config['provider_uri'],
        region=qlib_config['region'],
    )
    print("✓ Qlib初始化完成")
    print(f" 数据路径: {qlib_config['provider_uri']}")

    # 4. Build the dataset.
    task = config['task']
    dataset = init_instance_by_config(task['dataset'])
    print("✓ 数据集创建完成")

    # 5. Train the model.
    model = init_instance_by_config(task['model'])
    print("开始训练模型...")
    model.fit(dataset)
    print("✓ 模型训练完成")

    # 6. Evaluate (best effort).
    performance_stats = _evaluate_model(model, dataset)

    # 7. Save a timestamped copy plus a "latest" copy of the model.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_file = os.path.join(model_dir, f'{contract}_{freq}_model_{timestamp}.pkl')
    latest_file = os.path.join(model_dir, f'{contract}_{freq}_latest.pkl')
    joblib.dump(model, model_file)
    joblib.dump(model, latest_file)
    print(f"✓ 模型已保存到: {model_file}")
    print(f"✓ 最新模型: {latest_file}")

    # 8. Save the model metadata next to the model files.
    dataset_kwargs = task['dataset']['kwargs']
    segments = dataset_kwargs['segments']
    model_info = {
        'contract': contract,
        'frequency': freq,
        'timestamp': timestamp,
        'model_type': task['model']['class'],
        'model_params': task['model']['kwargs'],
        'training_period': {
            'train': segments['train'],
            'valid': segments['valid'],
            'test': segments['test'],
        },
        'data_config': {
            'instruments': dataset_kwargs['handler']['kwargs']['instruments'],
            'features': get_features_from_handler(dataset),
            'label': get_labels_from_handler(dataset),
        },
        'performance': performance_stats,
        'files': {
            'model_file': model_file,
            'latest_file': latest_file,
        },
    }

    info_file = os.path.join(model_dir, f'{contract}_{freq}_info_{timestamp}.json')
    latest_info_file = os.path.join(model_dir, f'{contract}_{freq}_latest_info.json')
    _write_json(info_file, model_info)
    _write_json(latest_info_file, model_info)
    print(f"✓ 模型信息已保存: {info_file}")

    # 9. Update the index that maps "<contract>_<freq>" to its latest model.
    index_file = os.path.join(MODELS_DIR, 'model_index.json')
    index = _load_model_index(index_file)
    index[f"{contract}_{freq}"] = {
        'contract': contract,
        'frequency': freq,
        'latest_model': latest_file,
        'latest_info': latest_info_file,
        'last_updated': timestamp,
        'model_dir': model_dir,
    }
    _write_json(index_file, index)
    print(f"✓ 模型索引已更新: {index_file}")

    print("\n🎉 模型训练和保存完成!")
    print("📁 模型目录结构:")
    print("   models/")
    print(f"   ├── {contract}/")
    print(f"   │   └── {freq}/")
    print(f"   │       ├── {contract}_{freq}_latest.pkl")
    print(f"   │       ├── {contract}_{freq}_latest_info.json")
    print(f"   │       └── {contract}_{freq}_model_{timestamp}.pkl")
    print("   └── model_index.json")
    print("\n💡 可以使用模型索引文件快速定位和加载模型")

    return model, model_info


def batch_train(contracts, frequencies, start_time=None, end_time=None):
    """Train every contract/frequency combination, continuing past failures.

    Args:
        contracts: list of contract names, e.g. ["EBL8.DCE", "RBL8.SHFE"].
        frequencies: list of frequencies, e.g. ["1min", "15min", "day"].
        start_time: optional start date; taken from the data when omitted.
        end_time: optional end date; taken from the data when omitted.

    Returns:
        dict: ``{contract: {freq: status_string}}``.
    """
    print("=== 批量训练 ===")
    print(f"合约: {contracts}")
    print(f"频率: {frequencies}")
    if start_time and end_time:
        print(f"时间范围: {start_time} 到 {end_time}")
    else:
        print("时间范围: 自动从数据中获取")

    results = {}
    for contract in contracts:
        results[contract] = {}
        for freq in frequencies:
            print(f"\n{'=' * 80}")
            try:
                train_model(contract, freq, start_time, end_time)
                results[contract][freq] = "成功"
            except Exception as e:
                # One failing combination must not stop the rest of the batch.
                print(f"✗ {contract}-{freq} 训练失败: {e}")
                results[contract][freq] = f"失败: {e}"
                traceback.print_exc()

    # Print the summary.
    print(f"\n{'=' * 80}")
    print("=== 批量训练结果汇总 ===")
    for contract, freq_results in results.items():
        print(f"\n{contract}:")
        for freq, status in freq_results.items():
            print(f" {freq}: {status}")

    return results


def main():
    """Train a single example model; return 0 on success and 1 on failure."""
    print("=== Qlib训练脚本(使用配置生成器)===")

    try:
        # Single-model example.
        contract = "EBL8.DCE"
        freq = "day"

        print(f"训练合约: {contract}")
        print(f"训练频率: {freq}")

        train_model(contract, freq)
        print("\n✓ 训练成功完成!")

        # For batch training, call for example:
        #   batch_train(contracts=["EBL8.DCE", "RBL8.SHFE"],
        #               frequencies=["15min", "day"])
    except Exception as e:
        print(f"❌ 训练失败: {e}")
        traceback.print_exc()
        return 1
    return 0


if __name__ == "__main__":
    sys.exit(main())
qlib_config_generator.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Qlib config generator: custom factor handlers plus config dict/YAML generation."""

import glob
import os
import traceback
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd
import qlib
import yaml
from qlib.data import D
from qlib.data.dataset.handler import DataHandlerLP
from qlib.log import get_module_logger

# ==================== Custom handler classes ====================


class _CustomTrendHandlerBase(DataHandlerLP):
    """Shared constructor for the per-frequency factor handlers.

    Subclasses set ``DEFAULT_FREQ`` and implement ``get_feature_config`` and
    ``get_label_config``; both return lists of qlib expression strings.
    """

    DEFAULT_FREQ = 'day'

    def __init__(self, **kwargs):
        # 'freq' is consumed here and forwarded to the data loader below.
        self.freq = kwargs.pop('freq', self.DEFAULT_FREQ)
        # 'inst_processors' is dropped and never forwarded to the base class.
        kwargs.pop('inst_processors', None)

        feature_config = self.get_feature_config()
        label_config = self.get_label_config()

        kwargs['data_loader'] = {
            'class': 'QlibDataLoader',
            'module_path': 'qlib.data.dataset.loader',
            'kwargs': {
                'freq': self.freq,
                'config': {
                    'feature': feature_config,
                    'label': label_config,
                },
            },
        }

        super().__init__(**kwargs)
        self.logger = get_module_logger(self.__class__.__name__)
        self.logger.info(
            f"{self.__class__.__name__} 初始化完成,freq = {self.freq},"
            f"使用 {len(feature_config)} 个因子"
        )

    def get_feature_config(self):
        """Return the list of feature expressions (implemented by subclasses)."""
        raise NotImplementedError

    def get_label_config(self):
        """Return the list of label expressions (implemented by subclasses)."""
        raise NotImplementedError


class CustomTrendHandler_m1(_CustomTrendHandlerBase):
    """Custom factor handler for 1-minute bars."""

    DEFAULT_FREQ = '1min'

    def get_feature_config(self):
        """Return the 1-minute factor expressions."""
        ma7 = "Mean(Mean($close, 7), 3)"
        ma20 = "Mean(Mean($close, 20), 3)"
        ma60 = "Mean(Mean($close, 60), 3)"
        v5 = "Mean(Mean($volume, 5), 3)"
        v15 = "Mean(Mean($volume, 15), 3)"
        v50 = "Mean(Mean($volume, 50), 3)"

        volume_fields = [
            f"$volume / {v50}",
            f"Ref($volume,1) / Ref({v50},1)",
            f"Ref($volume,2) / Ref({v50},2)",
            f"Ref($volume,3) / Ref({v50},3)",
            f"({v5} - {v15})/(Ref({v5} - {v15},1))",
            f"({v5} - {v50})/(Ref({v5} - {v50},1))",
            f"({v15} - {v50})/(Ref({v15} - {v50},1))",
        ]

        ma_fields = [
            f"{ma7} / {ma20}",
            f"{ma7} / {ma60}",
            f"{ma20} / {ma60}",
            f"(({ma7} - {ma20}) / Ref({ma7} - {ma20}, 1))",
            f"(({ma7} - {ma60}) / Ref({ma7} - {ma60}, 1))",
            f"(({ma20} - {ma60}) / Ref({ma20} - {ma60}, 1))",
        ]

        current_body = "Abs($close - $open)"
        prev_4_avg = "Mean(Ref(Abs($close - $open), 1), 4)"

        price_fields = [
            f"{current_body} / {prev_4_avg}",
            current_body,
            f"{current_body} / $close",
            "($close - $open) / $close",
            "($high - Greater($close, $open)) / $close",
            "(Less($close, $open) - $low) / $close",
            "($close - Ref($close, 3)) / Ref($close, 3)",
        ]

        return volume_fields + ma_fields + price_fields

    def get_label_config(self):
        """Return the label: close 10 bars ahead relative to the current close."""
        return ["Ref($close, -10) / $close - 1"]


class CustomTrendHandler_m5(_CustomTrendHandlerBase):
    """Custom factor handler for 5-minute bars."""

    DEFAULT_FREQ = '5min'

    def get_feature_config(self):
        """Return the 5-minute factor expressions."""
        ma5 = "Mean($close, 5)"
        ma15 = "Mean($close, 15)"
        ma30 = "Mean($close, 30)"

        volume_fields = [
            "$volume / Mean($volume, 20)",
            "($volume - Ref($volume, 1)) / Ref($volume, 1)",
            "$volume / Mean(Mean($volume, 30), 3)",
        ]

        ma_fields = [
            f"({ma5} - Ref({ma5}, 1)) / {ma5}",
            f"({ma15} - Ref({ma15}, 1)) / {ma15}",
            f"({ma30} - Ref({ma30}, 1)) / {ma30}",
            f"{ma5} / {ma15}",
            f"{ma15} / {ma30}",
        ]

        price_fields = [
            "Abs($close - $open) / $close",
            "($close - $open) / $close",
            "($high - Greater($close, $open)) / $close",
            "(Less($close, $open) - $low) / $close",
        ]

        return volume_fields + ma_fields + price_fields

    def get_label_config(self):
        """Return the label: close 5 bars ahead relative to the current close."""
        return ["Ref($close, -5) / $close - 1"]


class CustomTrendHandler_m15(_CustomTrendHandlerBase):
    """Custom factor handler for 15-minute bars."""

    DEFAULT_FREQ = '15min'

    def get_feature_config(self):
        """Return the 15-minute factor expressions."""
        ma4 = "Mean($close, 4)"
        ma12 = "Mean($close, 12)"
        ma24 = "Mean($close, 24)"

        volume_fields = [
            "$volume / Mean($volume, 16)",
            "($volume - Ref($volume, 1)) / Ref($volume, 1)",
            "$volume / Mean(Mean($volume, 20), 2)",
        ]

        ma_fields = [
            f"({ma4} - Ref({ma4}, 1)) / {ma4}",
            f"({ma12} - Ref({ma12}, 1)) / {ma12}",
            f"({ma24} - Ref({ma24}, 1)) / {ma24}",
            f"{ma4} / {ma12}",
            f"{ma12} / {ma24}",
        ]

        price_fields = [
            "Abs($close - $open) / $close",
            "($close - $open) / $close",
            "($high - Greater($close, $open)) / $close",
            "(Less($close, $open) - $low) / $close",
            "($close - Ref($close, 2)) / Ref($close, 2)",
        ]

        return volume_fields + ma_fields + price_fields

    def get_label_config(self):
        """Return the label: close 3 bars ahead relative to the current close."""
        return ["Ref($close, -3) / $close - 1"]


# Parameters of the day-line "surge then pullback" factors.
_SURGE_THRESHOLD = 0.03  # a one-bar return above this counts as a surge
_SURGE_LOOKBACK = 5      # how many bars back a surge is searched for


def _daily_return(days_ago):
    """Expression: close-to-close return of the bar ``days_ago`` bars back."""
    return (
        f"(Ref($close, {days_ago}) - Ref($close, {days_ago + 1}))"
        f" / Ref($close, {days_ago + 1})"
    )


def _is_surge(days_ago):
    """Expression: the bar ``days_ago`` bars back rose more than the threshold."""
    return f"{_daily_return(days_ago)} > {_SURGE_THRESHOLD}"


def _pullback_flag(days_ago):
    """Expression: 1 when price has since dipped below that bar's close, else 0."""
    if days_ago == 1:
        return "If($close < Ref($close, 1), 1, 0)"
    return f"If(Min($close, {days_ago}) < Ref($close, {days_ago}), 1, 0)"


def _latest_surge_case(value_for, default):
    """Nested ``If`` selecting ``value_for(n)`` for the most recent surge bar n.

    Bars are checked from the nearest (1 bar back) to the farthest;
    ``default`` is used when no bar in the lookback window was a surge.
    """
    expr = default
    # Build from the innermost (oldest) case outwards.
    for days_ago in range(_SURGE_LOOKBACK, 0, -1):
        expr = f"If({_is_surge(days_ago)}, {value_for(days_ago)}, {expr})"
    return expr


class CustomTrendHandler_day(_CustomTrendHandlerBase):
    """Custom factor handler for daily bars."""

    DEFAULT_FREQ = 'day'

    def get_feature_config(self):
        """Return the daily factors, including the surge/pullback strategy factors."""
        # Basic moving averages.
        ma5 = "Mean($close, 5)"
        ma20 = "Mean($close, 20)"
        ma60 = "Mean($close, 60)"

        volume_fields = [
            "$volume / Mean($volume, 20)",
            "($volume - Ref($volume, 1)) / Ref($volume, 1)",
            "$volume / Mean($volume, 5)",
        ]

        ma_fields = [
            f"({ma5} - Ref({ma5}, 1)) / {ma5}",
            f"({ma20} - Ref({ma20}, 1)) / {ma20}",
            f"({ma60} - Ref({ma60}, 1)) / {ma60}",
            f"{ma5} / {ma20}",
            f"{ma20} / {ma60}",
        ]

        price_fields = [
            "Abs($close - $open) / $close",
            "($close - $open) / $close",
            "($high - Greater($close, $open)) / $close",
            "(Less($close, $open) - $low) / $close",
            "($close - Ref($close, 1)) / Ref($close, 1)",
        ]

        lookback = range(1, _SURGE_LOOKBACK + 1)

        # Any surge within the lookback window (today excluded).
        any_surge = " | ".join(f"({_is_surge(n)})" for n in lookback)
        # Close of the most recent surge bar (current close when there is none).
        latest_surge_close = _latest_surge_case(lambda n: f"Ref($close, {n})", "$close")
        # Whether price pulled back after the most recent surge.
        pullback_after_surge = _latest_surge_case(_pullback_flag, "0")
        # Current close relative to the most recent surge close.
        ratio_to_surge_close = _latest_surge_case(
            lambda n: f"$close / Ref($close, {n})", "1"
        )
        # Today closed up on volume above 1.2x the 5-bar average.
        volume_up_today = "($close > $open) & ($volume > Mean($volume, 5) * 1.2)"

        strategy_fields = (
            # 1. Return of each of the previous bars.
            [_daily_return(n) for n in lookback]
            # 2. Surge flag for each of the previous bars.
            + [f"If({_is_surge(n)}, 1, 0)" for n in lookback]
            + [
                # 3. Any surge within the window.
                f"If({any_surge}, 1, 0)",
                # 4. Close of the most recent surge bar.
                latest_surge_close,
            ]
            # 5. Per-bar "surged n bars ago and pulled back since" flags.
            + [f"If({_is_surge(n)}, {_pullback_flag(n)}, 0)" for n in lookback]
            + [
                # 6. Combined: pullback after the most recent surge.
                pullback_after_surge,
                # 7. Price relative to the most recent surge close.
                ratio_to_surge_close,
                # 8. Up-day on expanded volume, and the volume multiple on up-days.
                f"If({volume_up_today}, 1, 0)",
                "($volume / Mean($volume, 5)) * If($close > $open, 1, 0)",
                # 9. Full signal: surge in window, pullback since, still at
                #    least 80% of the surge close, and up on volume today.
                f"If(({any_surge}) & ({pullback_after_surge} == 1)"
                f" & ({ratio_to_surge_close} >= 0.8)"
                f" & ({volume_up_today}), 1, 0)",
                # 10. Auxiliary factors.
                "Max(($close - Ref($close, 1)) / Ref($close, 1), 5)",
                "($close - Min($close, 5)) / Min($close, 5)",
                "(Max($close, 5) - $close) / Max($close, 5)",
                # 11. Price/volume interaction.
                "($close - $open) / $close * ($volume / Mean($volume, 5))",
                "If($volume > Mean($volume, 5) * 1.2, ($close - $open) / $close, 0)",
            ]
        )

        return volume_fields + ma_fields + price_fields + strategy_fields

    def get_label_config(self):
        """Return the label: next bar's close relative to the current close."""
        return ["Ref($close, -1) / $close - 1"]


# ==================== Config generator ====================

# Supported frequencies and the handler class / qlib frequency each one uses.
_FREQ_HANDLERS = {
    "1min": {"handler": "CustomTrendHandler_m1", "qlib_freq": "1min"},
    "5min": {"handler": "CustomTrendHandler_m5", "qlib_freq": "5min"},
    "15min": {"handler": "CustomTrendHandler_m15", "qlib_freq": "15min"},
    "day": {"handler": "CustomTrendHandler_day", "qlib_freq": "day"},
}

# Module that init_instance_by_config imports the handler classes from.
_DEFAULT_HANDLER_MODULE = "qlib_config_generator"


class QlibConfigGenerator:
    """Build Qlib task configuration dicts for one contract and frequency."""

    def __init__(self,
                 base_data_path="C:/ctp/q_qlib/qlib_data/cn_data",
                 source_data_path="C:/ctp/q_qlib/cleaned_data"):
        """Store the data locations.

        Args:
            base_data_path: directory holding one qlib data directory per contract.
            source_data_path: directory of the cleaned source data (stored only;
                not read by this class).
        """
        self.base_data_path = base_data_path
        self.source_data_path = source_data_path

    def get_data_time_range(self, contract, freq):
        """Read the available time range from the qlib data and split it 90/10.

        Args:
            contract: contract name, e.g. "EBL8.DCE".
            freq: frequency, e.g. "1min", "5min", "15min", "day".

        Returns:
            tuple: ``(start_time, end_time, train_end, valid_start, valid_end)``
            as 'YYYY-MM-DD' strings.

        Raises:
            RuntimeError: when the data cannot be read (the cause is chained).
        """
        try:
            contract_path = f"{self.base_data_path}/{contract}"
            print(f" 从qlib数据源读取: {contract_path}")

            # Initialise qlib just to read the data; a failure here is reported
            # and the read is still attempted (qlib may already be initialised).
            try:
                qlib.init(provider_uri=contract_path, region="cn")
            except Exception as init_err:
                print(f" ⚠ qlib初始化被跳过: {init_err}")

            # One cheap field is enough to obtain the datetime index.
            data = D.features([contract], ["$close"], freq=freq)
            if data is None or len(data) == 0:
                raise ValueError(
                    f"未找到qlib数据: {contract} - {freq},请检查数据路径: {contract_path}"
                )

            time_index = data.index.get_level_values('datetime')
            unique_times = time_index.unique().sort_values()
            start_time = unique_times[0]
            end_time = unique_times[-1]
            total_rows = len(unique_times)

            print(f" 数据时间范围: {start_time} 到 {end_time}")
            print(f" 总数据量: {total_rows} 行")

            # 90% of the timestamps for training, the rest for valid/test;
            # always keep at least one training timestamp.
            train_rows = max(1, int(total_rows * 0.9))
            rest_rows = total_rows - train_rows
            train_end_time = unique_times[train_rows - 1]

            start_str = start_time.strftime('%Y-%m-%d')
            end_str = end_time.strftime('%Y-%m-%d')
            train_end_str = train_end_time.strftime('%Y-%m-%d')

            # NOTE(review): valid and test share the same remaining 10%, and
            # the boundary date belongs to both train and valid — confirm that
            # this overlap is intended.
            valid_start_str = train_end_str
            valid_end_str = end_str

            train_pct = train_rows / total_rows * 100
            rest_pct = rest_rows / total_rows * 100
            print(" 时间分割:")
            print(f" 训练集: {start_str} 到 {train_end_str} ({train_rows} 行, {train_pct:.1f}%)")
            print(f" 验证集: {valid_start_str} 到 {valid_end_str} ({rest_rows} 行, {rest_pct:.1f}%)")
            print(f" 测试集: {valid_start_str} 到 {valid_end_str} ({rest_rows} 行, {rest_pct:.1f}%)")

            return start_str, end_str, train_end_str, valid_start_str, valid_end_str

        except Exception as e:
            print(f"✗ 从qlib获取数据时间范围失败: {e}")
            traceback.print_exc()
            raise RuntimeError(
                f"无法获取合约 {contract} 频率 {freq} 的数据时间范围,请检查数据是否存在"
            ) from e

    def generate_config(self, contract, freq, start_time=None, end_time=None, handler_module=None):
        """Build the configuration dict for one contract and frequency.

        Args:
            contract: contract name, e.g. "EBL8.DCE".
            freq: frequency, one of "1min", "5min", "15min", "day".
            start_time: optional 'YYYY-MM-DD' start; read from the data when omitted.
            end_time: optional 'YYYY-MM-DD' end; read from the data when omitted.
            handler_module: module path the handler class is imported from;
                defaults to "qlib_config_generator".

        Returns:
            dict: config with "qlib_init" and "task" sections.

        Raises:
            ValueError: for an unsupported frequency.
            RuntimeError: when the time range has to be read from the data
                and that fails.
        """
        # Validate first so that a typo does not trigger an expensive data read.
        if freq not in _FREQ_HANDLERS:
            raise ValueError(f"不支持的频率: {freq}")
        handler_info = _FREQ_HANDLERS[freq]

        if start_time is None or end_time is None:
            # At least one bound is missing: take range and split from the data.
            print(f"自动获取 {contract} - {freq} 的数据时间范围...")
            auto_start, auto_end, train_end, valid_start, valid_end = (
                self.get_data_time_range(contract, freq)
            )
            start_time = start_time or auto_start
            end_time = end_time or auto_end
        else:
            # Both bounds given: split the calendar range 90% / 10%.
            print(f"使用指定时间范围: {start_time} 到 {end_time}")
            start_dt = datetime.strptime(start_time, '%Y-%m-%d')
            end_dt = datetime.strptime(end_time, '%Y-%m-%d')
            train_days = int((end_dt - start_dt).days * 0.9)
            train_end_dt = start_dt + pd.Timedelta(days=train_days)

            train_end = train_end_dt.strftime('%Y-%m-%d')
            valid_start = train_end
            valid_end = end_time

        contract_path = f"{self.base_data_path}/{contract}"

        # Honour an explicit module path; otherwise the handlers live here.
        if handler_module is None:
            handler_module = _DEFAULT_HANDLER_MODULE
        print(f" 使用Handler模块: {handler_module}")

        config = {
            "qlib_init": {
                "provider_uri": contract_path,
                "region": "cn",
            },
            "task": {
                "model": {
                    "class": "LGBModel",
                    "module_path": "qlib.contrib.model.gbdt",
                    "kwargs": {
                        "loss": "mse",
                        "colsample_bytree": 0.8879,
                        "learning_rate": 0.2,
                        "subsample": 0.8789,
                        "max_depth": 8,
                        "num_leaves": 210,
                        "num_threads": 20,
                    },
                },
                "dataset": {
                    "class": "DatasetH",
                    "module_path": "qlib.data.dataset",
                    "kwargs": {
                        "handler": {
                            "class": handler_info["handler"],
                            "module_path": handler_module,
                            "kwargs": {
                                "start_time": start_time,
                                "end_time": end_time,
                                "instruments": [contract],
                                "freq": handler_info["qlib_freq"],
                            },
                        },
                        "segments": {
                            "train": [start_time, train_end],
                            "valid": [valid_start, valid_end],
                            "test": [valid_start, valid_end],
                        },
                    },
                },
                "record": [
                    {
                        "class": "SignalRecord",
                        "module_path": "qlib.workflow.record_temp",
                        "kwargs": {
                            "model": "<MODEL>",
                            "dataset": "<DATASET>",
                        },
                    },
                    {
                        "class": "SigAnaRecord",
                        "module_path": "qlib.workflow.record_temp",
                        "kwargs": {
                            "ana_long_short": False,
                            "ann_scaler": 252,
                        },
                    },
                ],
            },
        }

        return config

    def save_config_yaml(self, config, filename):
        """Write ``config`` to ``filename`` as YAML."""
        with open(filename, 'w', encoding='utf-8') as f:
            yaml.dump(config, f, default_flow_style=False, allow_unicode=True, indent=2)
        print(f"✓ 配置文件已保存: {filename}")

    def generate_all_configs(self, contract):
        """Generate and save one YAML config per supported frequency."""
        frequencies = ["1min", "5min", "15min", "day"]

        for freq in frequencies:
            config = self.generate_config(contract, freq)
            filename = f"config_{contract.replace('.', '_')}_{freq}.yaml"
            self.save_config_yaml(config, filename)

        print(f"✓ 已为合约 {contract} 生成所有配置文件")


# ==================== Usage example ====================

def main():
    """Demonstrate the generator: single, custom-range and all-frequency configs."""
    generator = QlibConfigGenerator()

    # Single config, time range read from the data.
    print("=== 生成单个配置文件(自动时间范围)===")
    contract = "EBL8.DCE"
    freq = "15min"

    config = generator.generate_config(contract, freq)
    generator.save_config_yaml(config, f"my_config_{freq}.yaml")

    # Config for an explicit time range.
    print("\n=== 生成指定时间范围的配置文件 ===")
    config_custom = generator.generate_config(contract, freq, "2020-01-01", "2024-12-31")
    generator.save_config_yaml(config_custom, f"my_config_{freq}_custom.yaml")

    # Configs for every supported frequency.
    print("\n=== 生成所有频率配置文件 ===")
    generator.generate_all_configs(contract)

    print("\n=== 配置生成完成 ===")
    print("✓ 自动时间范围:从数据中获取第一行到最后一行,90%训练,10%验证/测试")
    print("✓ 现在你可以直接使用这些配置文件进行训练,无需额外的factors目录!")


if __name__ == "__main__":
    main()
浙公网安备 33010602011771号