最近关注了一个AI量化投资框架qlib,略做笔记,因为某些无法实现原因,不得不放弃实战。
首先这个框架仅仅是对特征有效,并没有其它关联,比如有些股是有时间规律的,比如农产品,还有一些有环境 规律 ,比如A股一般是权重先涨,稳住大盘后,换成小盘股唱戏。还有一些是政策类的,比如出台了某些可能是较长影响的政策利好某板块的。
当然这些可以自己在训练的时候当成因子写进去,这个就比较花时间和精力去做了。
首先把源码从github git下来 地址:https://github.com/microsoft/qlib
有个注意情况 ,项目主目录和源码目录不要用qlib做为目录 ,改个名,比如 qlib_src这样的。
是python环境 ,我用的是poetry 搞的虚拟 环境 ,后面所有都是在当前项目的虚拟环境中进行。
然后进入源码目录安装打包安装以及相关依赖 pip install -e .
然后开始相关准备,最开始是准备数据,它的示例中是从yahoo财经下载的,这不是开玩笑嘛,不过有个捷径,自行下载:https://github.com/chenditc/investment_data/releases
说实话,如果要是搞的不是A股,请自行准备数据,这个我花了很长时间,因为qlib对数据的要求太高了,或太多不支持了。
我从通达信获取的数据,不仅有日线,还有分时的,不得不说,通达信还是很不错的,无偿提供。
拿到数据,需要清洗,以下是清洗脚本:
1 #!/usr/bin/env python3 2 # -*- coding: utf-8 -*- 3 """ 4 简单的CSV数据清洗脚本 - 为qlib准备数据 5 """ 6 7 import pandas as pd 8 import os 9 10 def convert_filename(original_filename): 11 """转换文件名格式:品种代码.市场代码.周期.csv -> 品种代码.交易所代码.周期.csv""" 12 # 市场代码到交易所代码的映射 13 market_mapping = { 14 '30': 'SHFE', # 上海期货交易所 15 '28': 'CZCE', # 郑州商品交易所 16 '29': 'DCE', # 大连商品交易所 17 '31': 'INE', # 上海国际能源交易中心 18 '66': 'GFEX', # 广州期货交易所 19 '42': 'ZS' # 指数 20 } 21 22 parts = original_filename.split('.') 23 if len(parts) >= 4: # 品种代码.市场代码.周期.csv 24 symbol = parts[0] 25 market_code = parts[1] 26 period = parts[2] 27 extension = parts[3] 28 29 # 转换市场代码 30 exchange_code = market_mapping.get(market_code, market_code) 31 new_filename = f"{symbol}.{exchange_code}.{period}.{extension}" 32 33 print(f"文件名转换: {original_filename} -> {new_filename}") 34 return new_filename 35 36 # 如果格式不匹配,返回原文件名 37 return original_filename 38 39 def format_time_column(time_value): 40 """格式化时间列:905 -> 09:05""" 41 time_str = str(int(time_value)).zfill(4) # 确保4位数字,前面补0 42 return f"{time_str[:2]}:{time_str[2:]}" 43 44 def clean_csv_file(input_file, output_file): 45 """清洗单个CSV文件""" 46 print(f"正在处理: {input_file}") 47 48 # 读取原始文件 49 with open(input_file, 'r', encoding='gbk') as f: 50 lines = f.readlines() 51 52 # 去掉第一行(提示信息)和最后一行 53 data_lines = lines[1:-1] 54 55 # 处理标题行(第二行,现在是第一行) 56 header_line = data_lines[0].strip() 57 58 # 中文到英文的映射 59 header_mapping = { 60 '日期': 'date', 61 '时间': 'time', 62 '开盘': 'open', 63 '最高': 'high', 64 '最低': 'low', 65 '收盘': 'close', 66 '成交量': 'volume', 67 '持仓量': 'open_interest', 68 '结算价': 'settlement' 69 } 70 71 # 将tab分隔的标题转为英文逗号分隔 72 headers = [h.strip() for h in header_line.split('\t')] 73 english_headers = [header_mapping.get(h, h) for h in headers] 74 75 # 创建输出目录 76 os.makedirs(os.path.dirname(output_file), exist_ok=True) 77 78 # 先写入临时文件,然后格式化时间列 79 temp_file = output_file + '.tmp' 80 with open(temp_file, 'w', encoding='utf-8') as f: 81 # 写入英文标题 82 f.write(','.join(english_headers) + '\n') 83 # 写入数据行 84 for line in data_lines[1:]: 85 f.write(line) 86 87 # 读取临时文件并格式化时间列 88 df = pd.read_csv(temp_file, encoding='utf-8') 89 90 # 如果有时间列,格式化时间 91 if 'time' in df.columns: 92 df['time'] = df['time'].apply(format_time_column) 93 94 # 关键修改:如果是分钟级数据(1m或5m),合并日期和时间字段为datetime 95 filename = os.path.basename(output_file) 96 is_minute_data = any(period in filename for period in ['1m', '5m']) 97 98 if is_minute_data and 'date' in df.columns and 'time' in df.columns: 99 period_type = '1分钟' if '1m' in filename else '5分钟' 100 print(f"检测到{period_type}原始数据,正在合并日期和时间字段...") 101 102 # 合并日期和时间 103 df['datetime'] = pd.to_datetime(df['date'].astype(str) + ' ' + df['time'].astype(str)) 104 105 # 添加instrument列(从文件名提取) 106 instrument_name = filename.split('.')[0] + '.' + filename.split('.')[1] # 如RBL8.SHFE 107 df['instrument'] = instrument_name 108 109 # 重新排列列顺序,符合qlib标准格式 110 qlib_columns = ['datetime', 'instrument'] 111 for col in ['open', 'high', 'low', 'close', 'volume']: 112 if col in df.columns: 113 qlib_columns.append(col) 114 115 # 添加其他列 116 for col in df.columns: 117 if col not in qlib_columns and col not in ['date', 'time']: 118 qlib_columns.append(col) 119 120 df = df[qlib_columns] 121 122 print(f"已合并datetime字段,数据形状: {df.shape}") 123 print(f"时间范围: {df['datetime'].min()} 到 {df['datetime'].max()}") 124 125 # 如果是日线数据,也需要处理datetime格式 126 elif 'date' in df.columns and 'time' not in df.columns: 127 print("检测到日线数据,正在处理日期格式...") 128 129 # 日线数据只有日期,添加默认时间 130 df['datetime'] = pd.to_datetime(df['date'].astype(str) + ' 15:00:00') 131 132 # 添加instrument列 133 filename = os.path.basename(output_file) 134 instrument_name = filename.split('.')[0] + '.' + filename.split('.')[1] 135 df['instrument'] = instrument_name 136 137 # 重新排列列顺序 138 qlib_columns = ['datetime', 'instrument'] 139 for col in ['open', 'high', 'low', 'close', 'volume']: 140 if col in df.columns: 141 qlib_columns.append(col) 142 143 # 添加其他列 144 for col in df.columns: 145 if col not in qlib_columns and col not in ['date', 'time']: 146 qlib_columns.append(col) 147 148 df = df[qlib_columns] 149 150 print(f"已处理日线数据,数据形状: {df.shape}") 151 print(f"时间范围: {df['datetime'].min()} 到 {df['datetime'].max()}") 152 153 # 保存最终文件 154 df.to_csv(output_file, index=False, encoding='utf-8') 155 156 # 删除临时文件 157 os.remove(temp_file) 158 159 print(f"已保存到: {output_file}") 160 return output_file 161 162 def find_csv_files(directory='.'): 163 """查找目录中的所有CSV文件""" 164 csv_files = [] 165 for file in os.listdir(directory): 166 if file.endswith('.csv') and os.path.isfile(file): 167 csv_files.append(file) 168 return sorted(csv_files) 169 170 def main(): 171 """主函数""" 172 # 自动查找所有CSV文件 173 csv_files = find_csv_files() 174 175 if not csv_files: 176 print("当前目录没有找到CSV文件") 177 return 178 179 print(f"找到 {len(csv_files)} 个CSV文件:") 180 for file in csv_files: 181 print(f" - {file}") 182 183 # 创建输出目录 184 output_dir = 'cleaned_data' 185 os.makedirs(output_dir, exist_ok=True) 186 187 # 清洗所有文件 188 cleaned_files = [] 189 for input_file in csv_files: 190 try: 191 # 转换文件名 192 new_filename = convert_filename(input_file) 193 output_file = os.path.join(output_dir, new_filename) 194 cleaned_file = clean_csv_file(input_file, output_file) 195 cleaned_files.append(os.path.basename(cleaned_file)) 196 except Exception as e: 197 print(f"处理文件 {input_file} 时出错: {e}") 198 199 # 从5分钟数据生成更高周期数据(15m, 30m, 60m) 200 print(f"\n开始从5分钟数据生成更高周期数据...") 201 all_generated_files = [] 202 for cleaned_file in cleaned_files: 203 file_path = os.path.join(output_dir, cleaned_file) 204 generated = generate_higher_timeframes(file_path, output_dir) 205 all_generated_files.extend(generated) 206 207 print(f"\n数据清洗完成!") 208 print(f"清洗后的文件保存在: {output_dir} 目录") 209 print(f"原始数据文件: {len(cleaned_files)} 个 (1m, 5m, 日线)") 210 print(f"合成的高周期文件: {len(all_generated_files)} 个 (15m, 30m, 60m)") 211 print(f"总计: {len(cleaned_files + all_generated_files)} 个文件") 212 213 def is_session_start_time(time_str): 214 """判断是否是交易时段开始时间(09:05或21:05)""" 215 if isinstance(time_str, str) and ':' in time_str: 216 return time_str in ['09:05', '21:05'] 217 else: 218 time_num = int(time_str) 219 return time_num in [905, 2105] 220 221 def detect_trading_sessions(df): 222 """检测交易时段边界""" 223 sessions = [] 224 current_session = [] 225 226 # 检查是否有datetime列 227 has_datetime_column = 'datetime' in df.columns 228 229 for i, row in df.iterrows(): 230 if has_datetime_column: 231 # 从datetime列提取时间 232 datetime_obj = pd.to_datetime(row['datetime']) 233 time_str = datetime_obj.strftime('%H:%M') 234 else: 235 # 使用原有的time列 236 time_str = row['time'] 237 238 if i == 0: 239 current_session = [i] 240 else: 241 # 如果当前时间是新时段开始时间,结束上一个时段 242 if is_session_start_time(time_str) and len(current_session) > 0: 243 sessions.append(current_session) 244 current_session = [i] 245 else: 246 current_session.append(i) 247 248 if current_session: 249 sessions.append(current_session) 250 251 return sessions 252 253 def resample_session_data(df_session, period_minutes, base_minutes=5, has_datetime_column=False): 254 """按你的逻辑进行重采样""" 255 if len(df_session) == 0: 256 return [] 257 258 # 根据周期确定每组K线个数 259 klines_per_group = period_minutes // base_minutes # 如:15分钟÷5分钟=3根,15分钟÷1分钟=15根 260 261 resampled_data = [] 262 i = 0 263 264 while i < len(df_session): 265 # 收集当前组的K线 266 group_indices = [] 267 268 # 正常情况:收集指定数量的K线 269 for j in range(klines_per_group): 270 if i + j < len(df_session): 271 group_indices.append(i + j) 272 else: 273 break 274 275 # 检查是否需要提前结束当前组 276 # 如果下一根K线是新时段开始(09:05或21:05),当前组就结束 277 next_index = i + len(group_indices) 278 if next_index < len(df_session): 279 if has_datetime_column: 280 # 如果有datetime列,从中提取时间 281 next_datetime = pd.to_datetime(df_session.iloc[next_index]['datetime']) 282 next_time = next_datetime.strftime('%H:%M') 283 else: 284 next_time = df_session.iloc[next_index]['time'] 285 286 if is_session_start_time(next_time): 287 # 下一根是新时段开始,当前组结束(不管够不够数) 288 pass # 使用当前已收集的K线 289 290 # 如果已经到了时段末尾,也要结束当前组 291 if next_index >= len(df_session): 292 # 已经是最后几根K线了 293 pass 294 295 # 合并当前组的K线 296 if group_indices: 297 group = df_session.iloc[group_indices] 298 299 # 根据是否有datetime列来处理时间 300 if has_datetime_column: 301 # 使用最后一根K线的datetime 302 last_datetime = pd.to_datetime(group.iloc[-1]['datetime']) 303 new_kline = { 304 'datetime': last_datetime, 305 'instrument': group.iloc[0]['instrument'], 306 'open': group.iloc[0]['open'], 307 'high': group['high'].max(), 308 'low': group['low'].min(), 309 'close': group.iloc[-1]['close'], 310 'volume': group['volume'].sum() 311 } 312 313 # 添加其他列 314 for col in group.columns: 315 if col not in ['datetime', 'instrument', 'open', 'high', 'low', 'close', 'volume']: 316 if col == 'open_interest': 317 new_kline[col] = group.iloc[-1][col] 318 elif col == 'settlement': 319 new_kline[col] = group.iloc[-1][col] 320 321 print(f"合并K线组: {pd.to_datetime(group.iloc[0]['datetime']).strftime('%H:%M')}-{last_datetime.strftime('%H:%M')} ({len(group)}根) -> {last_datetime.strftime('%H:%M')}") 322 else: 323 # 原有逻辑,使用date和time列 324 new_kline = { 325 'date': group.iloc[0]['date'], 326 'time': group.iloc[-1]['time'], # 关键:使用最后一根K线的时间 327 'open': group.iloc[0]['open'], 328 'high': group['high'].max(), 329 'low': group['low'].min(), 330 'close': group.iloc[-1]['close'], 331 'volume': group['volume'].sum(), 332 'open_interest': group.iloc[-1]['open_interest'] 333 } 334 335 if 'settlement' in group.columns: 336 new_kline['settlement'] = group.iloc[-1]['settlement'] 337 338 print(f"合并K线组: {group.iloc[0]['time']}-{group.iloc[-1]['time']} ({len(group)}根) -> {new_kline['time']}") 339 340 resampled_data.append(new_kline) 341 342 # 移动到下一组 343 i += len(group_indices) 344 345 # 如果下一根是新时段开始,跳出当前时段的处理 346 if next_index < len(df_session): 347 if has_datetime_column: 348 next_datetime = pd.to_datetime(df_session.iloc[next_index]['datetime']) 349 next_time = next_datetime.strftime('%H:%M') 350 else: 351 next_time = df_session.iloc[next_index]['time'] 352 353 if is_session_start_time(next_time): 354 break 355 356 return resampled_data 357 358 def resample_kline_data(df, period_minutes, base_minutes=5): 359 """智能重采样K线数据,考虑交易时段边界""" 360 print(f"智能重采样:从{base_minutes}分钟生成{period_minutes}分钟K线数据") 361 362 # 检查是否有datetime列 363 has_datetime_column = 'datetime' in df.columns 364 365 # 检测交易时段 366 sessions = detect_trading_sessions(df) 367 print(f"检测到 {len(sessions)} 个交易时段") 368 369 all_resampled_data = [] 370 371 # 对每个交易时段分别处理 372 for session_idx, session_indices in enumerate(sessions): 373 df_session = df.iloc[session_indices] 374 session_data = resample_session_data(df_session, period_minutes, base_minutes, has_datetime_column) 375 all_resampled_data.extend(session_data) 376 377 if len(session_data) > 0: 378 if has_datetime_column: 379 start_datetime = pd.to_datetime(df_session.iloc[0]['datetime']) 380 end_datetime = pd.to_datetime(df_session.iloc[-1]['datetime']) 381 start_time = start_datetime.strftime('%H:%M') 382 end_time = end_datetime.strftime('%H:%M') 383 else: 384 start_time = df_session.iloc[0]['time'] 385 end_time = df_session.iloc[-1]['time'] 386 387 print(f"时段{session_idx+1}: {start_time}-{end_time}, 原始{len(df_session)}根 -> 重采样{len(session_data)}根") 388 389 return pd.DataFrame(all_resampled_data) 390 391 def generate_higher_timeframes(file_path, output_dir): 392 """只从5分钟数据生成15分钟、30分钟、60分钟数据""" 393 base_filename = os.path.basename(file_path) 394 395 # 只处理5分钟数据 396 if '5m' not in base_filename: 397 return [] 398 399 print(f"正在从5分钟数据生成更高周期: {base_filename}") 400 401 # 读取5分钟数据 402 df = pd.read_csv(file_path, encoding='utf-8') 403 404 # 检查是否有datetime列 405 if 'datetime' not in df.columns: 406 print(f"跳过 {base_filename}:没有datetime列") 407 return [] 408 409 generated_files = [] 410 411 # 从5分钟数据生成:15m, 30m, 60m 412 periods = { 413 '15m': 15, 414 '30m': 30, 415 '60m': 60 416 } 417 418 for period_name, period_minutes in periods.items(): 419 try: 420 # 重采样数据(基础周期是5分钟) 421 resampled_df = resample_kline_data(df.copy(), period_minutes, base_minutes=5) 422 423 if len(resampled_df) == 0: 424 print(f"警告:{period_name} 重采样结果为空") 425 continue 426 427 # 生成新文件名 428 new_filename = base_filename.replace('5m', period_name) 429 new_file_path = os.path.join(output_dir, new_filename) 430 431 # 保存文件 432 resampled_df.to_csv(new_file_path, index=False, encoding='utf-8') 433 generated_files.append(new_filename) 434 435 print(f"已生成 {period_name} 数据: {new_filename}") 436 print(f" 数据形状: {resampled_df.shape}") 437 print(f" 时间范围: {resampled_df['datetime'].min()} 到 {resampled_df['datetime'].max()}") 438 439 except Exception as e: 440 print(f"生成 {period_name} 数据时出错: {e}") 441 import traceback 442 traceback.print_exc() 443 444 return generated_files 445 if __name__ == "__main__": 446 main()
浙公网安备 33010602011771号