这一步比较简单：把训练后保存的模型重新加载进来，再把要测试的数据输入模型进行预测。
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Futures contract model test script.

Reloads a trained qlib model, predicts on the dataset's test segment and
checks the direction of the strongest long/short signals against the
realised close-to-close return of the following bar(s).
"""

import qlib
import yaml
import joblib
import pandas as pd
import numpy as np
import os
import json
from datetime import datetime
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from matplotlib import rcParams

# Fonts able to render the Chinese labels used in console output / plots.
rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans']
rcParams['axes.unicode_minus'] = False

# Default number of highest and of lowest predictions that are verified.
TOP_N = 10
# Extra time fetched after the last prediction so the last bars still have a
# following bar to compare against (covers weekends/holidays for daily data).
PRICE_LOOKAHEAD = pd.Timedelta(days=5)
# Column name qlib returns for the "$close" expression.
CLOSE_COL = '$close'
# Timestamp format used in every result table.
TIME_FMT = '%Y-%m-%d %H:%M:%S'
# (minimum accuracy %, icon, label) grades, checked from best to worst.
_QUALITY_GRADES = ((70, "🎉", "优秀"), (60, "👍", "良好"), (50, "👌", "一般"))


# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------

def _generate_config(contract, freq):
    """Build the qlib workflow config for *contract*/*freq* with the project generator."""
    from qlib_config_generator import QlibConfigGenerator
    generator = QlibConfigGenerator()
    config = generator.generate_config(contract, freq)
    print("✓ 配置生成完成")
    return config


def _init_qlib(config):
    """Initialise qlib from the ``qlib_init`` section of *config*."""
    qlib_config = config['qlib_init']
    qlib.init(
        provider_uri=qlib_config['provider_uri'],
        region=qlib_config['region'],
    )
    print("✓ Qlib初始化完成")


def _get_data_shape(data):
    """Return a printable shape for a prepared dataset segment; never raises."""
    try:
        if data is None:
            return 'None'
        if isinstance(data, (list, tuple)) and len(data) > 0:
            return data[0].shape
        if hasattr(data, 'shape'):
            return data.shape
        return f'Unknown type: {type(data)}'
    except Exception as e:
        return f'Error: {e}'


def _first_frame(data):
    """Return the frame of a prepared segment (a frame, or a list/tuple of frames)."""
    if isinstance(data, (list, tuple)):
        return data[0] if data else None
    return data


def _prediction_frame(predictions):
    """Flatten *predictions* into a frame sorted by a parsed ``datetime`` column.

    Returns:
        ``(pred_df, score_col)`` where ``score_col`` is the last column after
        ``reset_index`` (the prediction values), or None when the predictions
        carry no ``datetime`` level/column.
    """
    pred_df = predictions.reset_index()
    score_col = pred_df.columns[-1]
    if 'datetime' not in pred_df.columns:
        return None
    pred_df['datetime'] = pd.to_datetime(pred_df['datetime'])
    return pred_df.sort_values('datetime'), score_col


def _fetch_close_prices(instruments, start_time, end_time, freq="day"):
    """Fetch ``$close`` from qlib as a flat frame sorted by time.

    A plain-string *instruments* is expanded with ``D.instruments`` (this is
    how qlib's data loader interprets strings); ``D.features`` itself only
    accepts a list/tuple or an instruments config dict.

    Returns:
        DataFrame with ``instrument``, ``datetime`` and ``$close`` columns, or
        None when qlib returns no data.
    """
    from qlib.data import D
    if isinstance(instruments, str):
        instruments = D.instruments(instruments)
    price_data = D.features(instruments, [CLOSE_COL],
                            start_time=start_time, end_time=end_time, freq=freq)
    if price_data is None or price_data.empty:
        return None
    price_df = price_data.reset_index()
    price_df['datetime'] = pd.to_datetime(price_df['datetime'])
    return price_df.sort_values('datetime').reset_index(drop=True)


def _price_path(price_df, row, horizon):
    """Look up the close at the prediction time and up to *horizon* later closes.

    Rows are matched on ``datetime`` and, when both the prediction row and the
    price frame carry it, on ``instrument`` too, so each contract is compared
    with its own prices. Missing (NaN) closes are treated as unavailable.

    Returns:
        ``(current_close, future_closes)``; ``(None, None)`` when no usable
        close exists at the prediction time. ``future_closes`` is a Series that
        may be empty.
    """
    prices = price_df
    if 'instrument' in row.index and 'instrument' in price_df.columns:
        prices = price_df[price_df['instrument'] == row['instrument']]
    pred_time = row['datetime']
    current = prices.loc[prices['datetime'] == pred_time, CLOSE_COL].dropna()
    # A zero close would make the return undefined; treat it like a missing one.
    if current.empty or current.iloc[0] == 0:
        return None, None
    future = prices.loc[prices['datetime'] > pred_time, CLOSE_COL].dropna().head(horizon)
    return current.iloc[0], future


def _direction_hit(pred_score, actual_return):
    """True when the predicted sign matches the realised return sign (0 counts as down)."""
    return bool((pred_score > 0) == (actual_return > 0))


def _ratio(hits, total):
    """Format ``hits/total = pct%``; ``N/A`` when nothing could be verified."""
    if total == 0:
        return f"{hits}/0 = N/A"
    return f"{hits}/{total} = {hits / total * 100:.1f}%"


def _print_score_rows(points, score_col):
    """Print the selected prediction points without any price verification."""
    for i, (_, row) in enumerate(points.iterrows(), 1):
        print(f"{i:<4} {row['datetime'].strftime(TIME_FMT):<20} {row[score_col]:+.6f} (无法验证)")


def _print_unverified(top_predictions, bottom_predictions, score_col, top_n):
    """Fallback output for the verbose validator when price data is unavailable."""
    _print_score_rows(top_predictions, score_col)
    print(f"\n" + "=" * 70)
    print(f"预测收益最低的{top_n}个时间点 (做空信号):")
    print(f"{'序号':<4} {'时间':<20} {'预测收益':<12} {'验证结果'}")
    print("-" * 70)
    _print_score_rows(bottom_predictions, score_col)


def _check_points_verbose(points, price_df, score_col):
    """Verify each point and show up to 3 following closes; return ``(hits, verified)``."""
    hits = verified = 0
    for i, (_, row) in enumerate(points.iterrows(), 1):
        pred_score = row[score_col]
        stamp = row['datetime'].strftime(TIME_FMT)
        current_price, future = _price_path(price_df, row, horizon=3)
        if current_price is None:
            print(f"{i:<4} {stamp:<20} {pred_score:+.6f} ? 找不到当日价格")
            continue
        if future.empty:
            print(f"{i:<4} {stamp:<20} {pred_score:+.6f} ? 无后续价格数据")
            continue

        actual_return = (future.iloc[0] - current_price) / current_price
        hit = _direction_hit(pred_score, actual_return)
        hits += hit
        verified += 1

        pred_direction = "涨" if pred_score > 0 else "跌"
        actual_direction = "涨" if actual_return > 0 else "跌"
        mark = "✓" if hit else "✗"
        print(f"{i:<4} {stamp:<20} {pred_score:+.6f} {mark} 预测{pred_direction} 实际{actual_direction} ({actual_return:+.4f})")

        # Following closes with their % change relative to the current close.
        steps = " ".join(
            f"{price:.2f}({(price - current_price) / current_price * 100:+.2f}%)"
            for price in future
        )
        print(f" 价格变化: {current_price:.2f} → {steps} ")
    return hits, verified


def _report_side(point_label, side_label, hits, verified):
    """Print accuracy and a quality grade for one signal side."""
    if verified == 0:
        print(f"⚠ 无法验证任何{point_label}")
        return
    accuracy = hits / verified * 100
    print(f"\n{point_label}验证结果: {hits}/{verified} = {accuracy:.1f}%")
    for threshold, icon, label in _QUALITY_GRADES:
        if accuracy >= threshold:
            break
    else:
        icon, label = "👎", "较差"
    print(f"{icon} {side_label}预测质量: {label}")


def _print_simple_header(title):
    """Print the title and column header of the simple validation table."""
    print(title)
    print(f"{'序号':<4} {'时间':<20} {'预测收益':<12} {'当前价格':<10} {'下期价格':<10} {'实际收益':<12} {'结果'}")
    print("-" * 90)


def _check_points_simple(points, price_df, score_col):
    """Verify each point against the next bar's close; return ``(hits, verified)``."""
    hits = verified = 0
    for i, (_, row) in enumerate(points.iterrows(), 1):
        pred_score = row[score_col]
        stamp = row['datetime'].strftime(TIME_FMT)
        current_price, future = _price_path(price_df, row, horizon=1)
        if current_price is None:
            print(f"{i:<4} {stamp:<20} {pred_score:+.6f} 找不到价格")
            continue
        if future.empty:
            print(f"{i:<4} {stamp:<20} {pred_score:+.6f} {current_price:<10.2f} 无下期数据")
            continue

        next_price = future.iloc[0]
        actual_return = (next_price - current_price) / current_price
        hit = _direction_hit(pred_score, actual_return)
        hits += hit
        verified += 1
        mark = "✓" if hit else "✗"
        print(f"{i:<4} {stamp:<20} {pred_score:+.6f} {current_price:<10.2f} {next_price:<10.2f} {actual_return:+.6f} {mark}")
    return hits, verified


# ---------------------------------------------------------------------------
# Public entry points
# ---------------------------------------------------------------------------

def test_data_loading(contract="EBL8.DCE", freq="15min"):
    """Check that config generation, qlib init and dataset preparation work.

    Args:
        contract: Futures contract code passed to the config generator.
        freq: Bar frequency passed to the config generator.

    Returns:
        True when the dataset was built and all segments prepared, False
        otherwise. Errors are printed, not raised.
    """
    print("=== 测试期货数据加载 ===")

    # 1-2. Generate the config and initialise qlib.
    config = _generate_config(contract, freq)
    _init_qlib(config)

    # 3. Build the dataset (with its custom handler) from the config.
    try:
        from qlib.utils import init_instance_by_config
        dataset = init_instance_by_config(config['task']['dataset'])
        print("✓ 自定义Handler创建成功")

        # 4. Contract info.
        handler = dataset.handler
        print(f"✓ 期货合约: {handler.instruments}")
        print(f"Handler类型: {type(handler)}")

        # 5. Factor / label configuration, when the handler exposes it.
        if hasattr(handler, 'get_feature_config'):
            features = handler.get_feature_config()
            print(f"✓ 因子配置获取成功,共 {len(features)} 个因子:")
            for i, feature in enumerate(features, 1):
                print(f" {i}. {feature}")
        if hasattr(handler, 'get_label_config'):
            print(f"✓ 标签配置: {handler.get_label_config()}")

        # 6. Prepare every segment and report shapes.
        try:
            train_data = dataset.prepare("train")
            valid_data = dataset.prepare("valid")
            test_data = dataset.prepare("test")
        except Exception as e:
            print(f"⚠ 数据准备失败: {e}")
            return False

        print(f"✓ 训练数据形状: {_get_data_shape(train_data)}")
        print(f"✓ 验证数据形状: {_get_data_shape(valid_data)}")
        print(f"✓ 测试数据形状: {_get_data_shape(test_data)}")

        # Time range of the training data; prepare() may return a frame or a
        # list/tuple of frames, so both are handled.
        try:
            train_df = _first_frame(train_data)
            if train_df is not None and hasattr(train_df, 'index') \
                    and hasattr(train_df.index, 'get_level_values'):
                dates = train_df.index.get_level_values('datetime')
                print(f"✓ 训练数据时间范围: {dates.min()} 到 {dates.max()}")
        except Exception as e:
            print(f"⚠ 无法获取时间范围: {e}")

        return True

    except Exception as e:
        print(f"❌ 数据加载测试失败: {e}")
        import traceback
        traceback.print_exc()
        return False


def load_and_test_model(contract="EBL8.DCE", freq="15min"):
    """Load the latest saved model for *contract*/*freq* and validate its test predictions.

    Expects ``models/<instrument>/<freq>/<instrument>_<freq>_latest.pkl`` (and
    optionally ``..._latest_info.json``) produced by the training script.

    Returns:
        The loaded model, or None when no model file exists or the config
        lists no instrument.
    """
    print("=== 测试期货模型 ===")

    config = _generate_config(contract, freq)

    # Instruments and qlib frequency come from the handler config.
    handler_kwargs = config['task']['dataset']['kwargs']['handler']['kwargs']
    instruments = handler_kwargs['instruments']
    qlib_freq = handler_kwargs['freq']

    # Models are saved per single contract: use the first one of a list.
    if isinstance(instruments, (list, tuple)):
        if not instruments:
            print("❌ 配置中没有合约,无法定位模型文件")
            return None
        instrument = instruments[0]
    else:
        instrument = instruments

    model_dir = os.path.join('models', instrument, freq)
    model_file = os.path.join(model_dir, f'{instrument}_{freq}_latest.pkl')
    if not os.path.exists(model_file):
        print(f"❌ 模型文件不存在: {model_file}")
        print("请先运行 qlib_train_with_generator.py 训练模型")
        return None

    _init_qlib(config)

    # NOTE: joblib.load unpickles arbitrary objects; only load locally trained models.
    model = joblib.load(model_file)
    print(f"✓ 模型加载成功: {model_file}")
    print(f"模型类型: {type(model)}")

    # Optional metadata written next to the model; best-effort only.
    try:
        info_file = os.path.join(model_dir, f'{instrument}_{freq}_latest_info.json')
        with open(info_file, 'r', encoding='utf-8') as f:
            model_info = json.load(f)
        print(f"✓ 模型训练时间: {model_info['timestamp']}")
        print(f"✓ 特征数量: {len(model_info['data_config']['features'])}")
        print(f"✓ 合约: {model_info['data_config']['instruments']}")
    except (OSError, ValueError, KeyError, TypeError) as e:
        print(f"⚠ 无法加载模型信息: {e}")

    # Rebuild the dataset and predict on the test segment.
    print("创建dataset用于预测...")
    from qlib.utils import init_instance_by_config
    dataset = init_instance_by_config(config['task']['dataset'])

    print("测试模型预测...")
    predictions = model.predict(dataset, segment="test")
    print(f"✓ 预测成功,结果形状: {predictions.shape}")

    validate_top_predictions_simple(predictions, dataset, qlib_freq)

    # TODO: analyze_futures_predictions / generate_trading_signals are not
    # defined in this script yet.
    return model


def validate_top_predictions(predictions, dataset, top_n=TOP_N):
    """Verbose check of the *top_n* highest and lowest predictions.

    Uses qlib's default (daily) frequency for prices, shows up to 3 following
    closes per point and grades each side's directional accuracy. When prices
    are unavailable the selected points are listed unverified. Errors are
    printed, not raised.
    """
    print("\n=== 验证预测准确性 ===")

    try:
        prepared = _prediction_frame(predictions)
        if prepared is None:
            print("⚠ 预测数据中没有时间信息,无法进行时间点验证")
            return
        pred_df, score_col = prepared

        top_predictions = pred_df.nlargest(top_n, score_col)
        bottom_predictions = pred_df.nsmallest(top_n, score_col)

        print(f"预测收益最高的{top_n}个时间点验证 (做多信号):")
        print(f"{'序号':<4} {'时间':<20} {'预测收益':<12} {'验证结果'}")
        print("-" * 70)

        try:
            price_df = _fetch_close_prices(
                dataset.handler.instruments,
                pred_df['datetime'].min(),
                pred_df['datetime'].max() + PRICE_LOOKAHEAD,
            )
            if price_df is None:
                print("⚠ 无法获取价格数据进行验证")
                _print_unverified(top_predictions, bottom_predictions, score_col, top_n)
                return

            # Long side.
            hits_high, verified_high = _check_points_verbose(top_predictions, price_df, score_col)
            _report_side("高预测点", "做多", hits_high, verified_high)

            # Short side.
            print(f"\n" + "=" * 70)
            print(f"预测收益最低的{top_n}个时间点验证 (做空信号):")
            print(f"{'序号':<4} {'时间':<20} {'预测收益':<12} {'验证结果'}")
            print("-" * 70)
            hits_low, verified_low = _check_points_verbose(bottom_predictions, price_df, score_col)
            _report_side("低预测点", "做空", hits_low, verified_low)

            # Combined statistics over the points that could be verified.
            total_verified = verified_high + verified_low
            total_hits = hits_high + hits_low
            if total_verified > 0:
                overall_accuracy = total_hits / total_verified * 100
                print(f"\n综合验证结果: {total_hits}/{total_verified} = {overall_accuracy:.1f}%")
                print(f"预测值与实际收益的数量级差异分析:")
                print(f" - 这可能表明模型预测的是相对强弱,而非绝对收益率")
                print(f" - 建议关注预测方向的准确性,而非绝对数值")

        except Exception as e:
            print(f"⚠ 获取价格数据失败: {e}")
            _print_unverified(top_predictions, bottom_predictions, score_col, top_n)

    except Exception as e:
        print(f"⚠ 预测验证失败: {e}")


def main():
    """Run the data-loading check, then the model load/prediction check."""
    try:
        print("=== 期货合约模型测试 ===")

        # Test parameters.
        contract = "EBL8.DCE"
        freq = "day"

        print(f"测试合约: {contract}")
        print(f"测试频率: {freq}")

        # 1. Data loading.
        print("\n第一步:测试数据加载...")
        if not test_data_loading(contract, freq):
            print("❌ 数据加载测试失败,请检查配置和数据")
            return

        print("\n" + "=" * 60)

        # 2. Model loading and prediction (if a model exists).
        print("第二步:测试模型加载和预测...")
        model = load_and_test_model(contract, freq)

        if model is not None:
            print(f"\n🎉 期货模型测试完成!")
            print(f"\n💡 实盘交易建议:")
            print(f"1. 根据预测信号调整仓位")
            print(f"2. 设置合理的止损止盈")
            print(f"3. 控制单次交易风险")
            print(f"4. 定期重新训练模型")
        else:
            print(f"\n⚠ 数据加载正常,但模型不存在")
            print(f"请先运行 qlib_train_with_generator.py 训练模型")

    except Exception as e:
        print(f"❌ 测试失败: {e}")
        import traceback
        traceback.print_exc()


def validate_top_predictions_simple(predictions, dataset, qlib_freq, top_n=TOP_N):
    """Check the direction of the *top_n* highest and lowest predictions.

    For every selected bar, the close of the next bar is fetched from qlib at
    *qlib_freq* and the sign of the predicted score is compared with the sign
    of the realised close-to-close return. Accuracy is computed only over
    points that could actually be verified (skipped points no longer count as
    misses). Errors are printed, not raised.

    Args:
        predictions: Output of ``model.predict``, indexed by (at least) ``datetime``.
        dataset: qlib dataset used for prediction; its handler supplies instruments.
        qlib_freq: qlib frequency string used to query ``$close``.
        top_n: Number of highest and of lowest predictions to verify.
    """
    print("\n=== 验证预测准确性 ===")

    try:
        prepared = _prediction_frame(predictions)
        if prepared is None:
            print("⚠ 预测数据中没有时间信息")
            return
        pred_df, score_col = prepared

        # Fetch closes from qlib at the configured frequency.
        try:
            instruments = dataset.handler.instruments
            start_time = pred_df['datetime'].min()
            end_time = pred_df['datetime'].max() + PRICE_LOOKAHEAD

            print(f"从qlib获取价格数据: {instruments}")
            print(f"使用频率: {qlib_freq}")
            print(f"时间范围: {start_time} 到 {end_time}")

            price_df = _fetch_close_prices(instruments, start_time, end_time, qlib_freq)
            if price_df is None:
                print("⚠ 无法获取价格数据")
                return

            print(f"获取到 {len(price_df)} 条价格数据")
            print(f"实际时间范围: {price_df['datetime'].min()} 到 {price_df['datetime'].max()}")
            print(f"价格列名: {CLOSE_COL}")

        except Exception as e:
            print(f"⚠ 获取价格数据失败: {e}")
            return

        # Highest predictions (long signals).
        _print_simple_header(f"预测收益最高的{top_n}个时间点:")
        hits_high, verified_high = _check_points_simple(
            pred_df.nlargest(top_n, score_col), price_df, score_col)

        # Lowest predictions (short signals).
        _print_simple_header(f"\n预测收益最低的{top_n}个时间点:")
        hits_low, verified_low = _check_points_simple(
            pred_df.nsmallest(top_n, score_col), price_df, score_col)

        # Accuracy over verified points only.
        print(f"\n验证结果: {_ratio(hits_high + hits_low, verified_high + verified_low)}")
        print(f"做多准确率: {_ratio(hits_high, verified_high)}")
        print(f"做空准确率: {_ratio(hits_low, verified_low)}")

    except Exception as e:
        print(f"⚠ 验证失败: {e}")


if __name__ == "__main__":
    main()
浙公网安备 33010602011771号