小隐的博客

人生在世,笑饮一生
  博客园  :: 首页  :: 新随笔  :: 联系 :: 订阅 订阅  :: 管理

AI量化qlib学习笔记 一:清洗数据

Posted on 2025-08-09 16:27  隐客  阅读(152)  评论(0)    收藏  举报

最近关注了一个AI量化投资框架qlib,略做笔记,因为某些无法实现原因,不得不放弃实战。

首先这个框架仅仅是对特征有效,并没有其它关联,比如有些股是有时间规律的,比如农产品,还有一些有环境 规律 ,比如A股一般是权重先涨,稳住大盘后,换成小盘股唱戏。还有一些是政策类的,比如出台了某些可能是较长影响的政策利好某板块的。

当然这些可以自己在训练的时候当成因子写进去,这个就比较花时间和精力去做了。

首先把源码从github git下来 地址:https://github.com/microsoft/qlib

有个注意情况 ,项目主目录和源码目录不要用qlib做为目录 ,改个名,比如 qlib_src这样的。

是python环境 ,我用的是poetry 搞的虚拟 环境 ,后面所有都是在当前项目的虚拟环境中进行。

然后进入源码目录安装打包安装以及相关依赖  pip install -e .

然后开始相关准备,最开始是准备数据,它的示例中是从yahoo财经下载的,这不是开玩笑嘛,不过有个捷径,自行下载:https://github.com/chenditc/investment_data/releases

说实话,如果要是搞的不是A股,请自行准备数据,这个我花了很长时间,因为qlib对数据的要求太高了,或太多不支持了。

我从通达信获取的数据,不仅有日线,还有分时的,不得不说,通达信还是很不错的,无偿提供。

拿到数据,需要清洗,以下是清洗脚本:

  1 #!/usr/bin/env python3
  2 # -*- coding: utf-8 -*-
  3 """
  4 简单的CSV数据清洗脚本 - 为qlib准备数据
  5 """
  6 
  7 import pandas as pd
  8 import os
  9 
 10 def convert_filename(original_filename):
 11     """转换文件名格式:品种代码.市场代码.周期.csv -> 品种代码.交易所代码.周期.csv"""
 12     # 市场代码到交易所代码的映射
 13     market_mapping = {
 14         '30': 'SHFE',  # 上海期货交易所
 15         '28': 'CZCE',  # 郑州商品交易所
 16         '29': 'DCE',   # 大连商品交易所
 17         '31': 'INE',   # 上海国际能源交易中心
 18         '66': 'GFEX',   # 广州期货交易所
 19         '42': 'ZS'   # 指数
 20     }
 21     
 22     parts = original_filename.split('.')
 23     if len(parts) >= 4:  # 品种代码.市场代码.周期.csv
 24         symbol = parts[0]
 25         market_code = parts[1]
 26         period = parts[2]
 27         extension = parts[3]
 28         
 29         # 转换市场代码
 30         exchange_code = market_mapping.get(market_code, market_code)
 31         new_filename = f"{symbol}.{exchange_code}.{period}.{extension}"
 32         
 33         print(f"文件名转换: {original_filename} -> {new_filename}")
 34         return new_filename
 35     
 36     # 如果格式不匹配,返回原文件名
 37     return original_filename
 38 
 39 def format_time_column(time_value):
 40     """格式化时间列:905 -> 09:05"""
 41     time_str = str(int(time_value)).zfill(4)  # 确保4位数字,前面补0
 42     return f"{time_str[:2]}:{time_str[2:]}"
 43 
 44 def clean_csv_file(input_file, output_file):
 45     """清洗单个CSV文件"""
 46     print(f"正在处理: {input_file}")
 47     
 48     # 读取原始文件
 49     with open(input_file, 'r', encoding='gbk') as f:
 50         lines = f.readlines()
 51     
 52     # 去掉第一行(提示信息)和最后一行
 53     data_lines = lines[1:-1]
 54     
 55     # 处理标题行(第二行,现在是第一行)
 56     header_line = data_lines[0].strip()
 57     
 58     # 中文到英文的映射
 59     header_mapping = {
 60         '日期': 'date',
 61         '时间': 'time', 
 62         '开盘': 'open',
 63         '最高': 'high',
 64         '最低': 'low',
 65         '收盘': 'close',
 66         '成交量': 'volume',
 67         '持仓量': 'open_interest',
 68         '结算价': 'settlement'
 69     }
 70     
 71     # 将tab分隔的标题转为英文逗号分隔
 72     headers = [h.strip() for h in header_line.split('\t')]
 73     english_headers = [header_mapping.get(h, h) for h in headers]
 74     
 75     # 创建输出目录
 76     os.makedirs(os.path.dirname(output_file), exist_ok=True)
 77     
 78     # 先写入临时文件,然后格式化时间列
 79     temp_file = output_file + '.tmp'
 80     with open(temp_file, 'w', encoding='utf-8') as f:
 81         # 写入英文标题
 82         f.write(','.join(english_headers) + '\n')
 83         # 写入数据行
 84         for line in data_lines[1:]:
 85             f.write(line)
 86     
 87     # 读取临时文件并格式化时间列
 88     df = pd.read_csv(temp_file, encoding='utf-8')
 89     
 90     # 如果有时间列,格式化时间
 91     if 'time' in df.columns:
 92         df['time'] = df['time'].apply(format_time_column)
 93     
 94     # 关键修改:如果是分钟级数据(1m或5m),合并日期和时间字段为datetime
 95     filename = os.path.basename(output_file)
 96     is_minute_data = any(period in filename for period in ['1m', '5m'])
 97     
 98     if is_minute_data and 'date' in df.columns and 'time' in df.columns:
 99         period_type = '1分钟' if '1m' in filename else '5分钟'
100         print(f"检测到{period_type}原始数据,正在合并日期和时间字段...")
101         
102         # 合并日期和时间
103         df['datetime'] = pd.to_datetime(df['date'].astype(str) + ' ' + df['time'].astype(str))
104         
105         # 添加instrument列(从文件名提取)
106         instrument_name = filename.split('.')[0] + '.' + filename.split('.')[1]  # 如RBL8.SHFE
107         df['instrument'] = instrument_name
108         
109         # 重新排列列顺序,符合qlib标准格式
110         qlib_columns = ['datetime', 'instrument']
111         for col in ['open', 'high', 'low', 'close', 'volume']:
112             if col in df.columns:
113                 qlib_columns.append(col)
114         
115         # 添加其他列
116         for col in df.columns:
117             if col not in qlib_columns and col not in ['date', 'time']:
118                 qlib_columns.append(col)
119         
120         df = df[qlib_columns]
121         
122         print(f"已合并datetime字段,数据形状: {df.shape}")
123         print(f"时间范围: {df['datetime'].min()} 到 {df['datetime'].max()}")
124     
125     # 如果是日线数据,也需要处理datetime格式
126     elif 'date' in df.columns and 'time' not in df.columns:
127         print("检测到日线数据,正在处理日期格式...")
128         
129         # 日线数据只有日期,添加默认时间
130         df['datetime'] = pd.to_datetime(df['date'].astype(str) + ' 15:00:00')
131         
132         # 添加instrument列
133         filename = os.path.basename(output_file)
134         instrument_name = filename.split('.')[0] + '.' + filename.split('.')[1]
135         df['instrument'] = instrument_name
136         
137         # 重新排列列顺序
138         qlib_columns = ['datetime', 'instrument']
139         for col in ['open', 'high', 'low', 'close', 'volume']:
140             if col in df.columns:
141                 qlib_columns.append(col)
142         
143         # 添加其他列
144         for col in df.columns:
145             if col not in qlib_columns and col not in ['date', 'time']:
146                 qlib_columns.append(col)
147         
148         df = df[qlib_columns]
149         
150         print(f"已处理日线数据,数据形状: {df.shape}")
151         print(f"时间范围: {df['datetime'].min()} 到 {df['datetime'].max()}")
152     
153     # 保存最终文件
154     df.to_csv(output_file, index=False, encoding='utf-8')
155     
156     # 删除临时文件
157     os.remove(temp_file)
158     
159     print(f"已保存到: {output_file}")
160     return output_file
161 
162 def find_csv_files(directory='.'):
163     """查找目录中的所有CSV文件"""
164     csv_files = []
165     for file in os.listdir(directory):
166         if file.endswith('.csv') and os.path.isfile(file):
167             csv_files.append(file)
168     return sorted(csv_files)
169 
170 def main():
171     """主函数"""
172     # 自动查找所有CSV文件
173     csv_files = find_csv_files()
174     
175     if not csv_files:
176         print("当前目录没有找到CSV文件")
177         return
178     
179     print(f"找到 {len(csv_files)} 个CSV文件:")
180     for file in csv_files:
181         print(f"  - {file}")
182     
183     # 创建输出目录
184     output_dir = 'cleaned_data'
185     os.makedirs(output_dir, exist_ok=True)
186     
187     # 清洗所有文件
188     cleaned_files = []
189     for input_file in csv_files:
190         try:
191             # 转换文件名
192             new_filename = convert_filename(input_file)
193             output_file = os.path.join(output_dir, new_filename)
194             cleaned_file = clean_csv_file(input_file, output_file)
195             cleaned_files.append(os.path.basename(cleaned_file))
196         except Exception as e:
197             print(f"处理文件 {input_file} 时出错: {e}")
198     
199     # 从5分钟数据生成更高周期数据(15m, 30m, 60m)
200     print(f"\n开始从5分钟数据生成更高周期数据...")
201     all_generated_files = []
202     for cleaned_file in cleaned_files:
203         file_path = os.path.join(output_dir, cleaned_file)
204         generated = generate_higher_timeframes(file_path, output_dir)
205         all_generated_files.extend(generated)
206     
207     print(f"\n数据清洗完成!")
208     print(f"清洗后的文件保存在: {output_dir} 目录")
209     print(f"原始数据文件: {len(cleaned_files)} 个 (1m, 5m, 日线)")
210     print(f"合成的高周期文件: {len(all_generated_files)} 个 (15m, 30m, 60m)")
211     print(f"总计: {len(cleaned_files + all_generated_files)} 个文件")
212 
213 def is_session_start_time(time_str):
214     """判断是否是交易时段开始时间(09:05或21:05)"""
215     if isinstance(time_str, str) and ':' in time_str:
216         return time_str in ['09:05', '21:05']
217     else:
218         time_num = int(time_str)
219         return time_num in [905, 2105]
220 
221 def detect_trading_sessions(df):
222     """检测交易时段边界"""
223     sessions = []
224     current_session = []
225     
226     # 检查是否有datetime列
227     has_datetime_column = 'datetime' in df.columns
228     
229     for i, row in df.iterrows():
230         if has_datetime_column:
231             # 从datetime列提取时间
232             datetime_obj = pd.to_datetime(row['datetime'])
233             time_str = datetime_obj.strftime('%H:%M')
234         else:
235             # 使用原有的time列
236             time_str = row['time']
237         
238         if i == 0:
239             current_session = [i]
240         else:
241             # 如果当前时间是新时段开始时间,结束上一个时段
242             if is_session_start_time(time_str) and len(current_session) > 0:
243                 sessions.append(current_session)
244                 current_session = [i]
245             else:
246                 current_session.append(i)
247     
248     if current_session:
249         sessions.append(current_session)
250     
251     return sessions
252 
253 def resample_session_data(df_session, period_minutes, base_minutes=5, has_datetime_column=False):
254     """按你的逻辑进行重采样"""
255     if len(df_session) == 0:
256         return []
257     
258     # 根据周期确定每组K线个数
259     klines_per_group = period_minutes // base_minutes  # 如:15分钟÷5分钟=3根,15分钟÷1分钟=15根
260     
261     resampled_data = []
262     i = 0
263     
264     while i < len(df_session):
265         # 收集当前组的K线
266         group_indices = []
267         
268         # 正常情况:收集指定数量的K线
269         for j in range(klines_per_group):
270             if i + j < len(df_session):
271                 group_indices.append(i + j)
272             else:
273                 break
274         
275         # 检查是否需要提前结束当前组
276         # 如果下一根K线是新时段开始(09:05或21:05),当前组就结束
277         next_index = i + len(group_indices)
278         if next_index < len(df_session):
279             if has_datetime_column:
280                 # 如果有datetime列,从中提取时间
281                 next_datetime = pd.to_datetime(df_session.iloc[next_index]['datetime'])
282                 next_time = next_datetime.strftime('%H:%M')
283             else:
284                 next_time = df_session.iloc[next_index]['time']
285             
286             if is_session_start_time(next_time):
287                 # 下一根是新时段开始,当前组结束(不管够不够数)
288                 pass  # 使用当前已收集的K线
289         
290         # 如果已经到了时段末尾,也要结束当前组
291         if next_index >= len(df_session):
292             # 已经是最后几根K线了
293             pass
294         
295         # 合并当前组的K线
296         if group_indices:
297             group = df_session.iloc[group_indices]
298             
299             # 根据是否有datetime列来处理时间
300             if has_datetime_column:
301                 # 使用最后一根K线的datetime
302                 last_datetime = pd.to_datetime(group.iloc[-1]['datetime'])
303                 new_kline = {
304                     'datetime': last_datetime,
305                     'instrument': group.iloc[0]['instrument'],
306                     'open': group.iloc[0]['open'],
307                     'high': group['high'].max(),
308                     'low': group['low'].min(),
309                     'close': group.iloc[-1]['close'],
310                     'volume': group['volume'].sum()
311                 }
312                 
313                 # 添加其他列
314                 for col in group.columns:
315                     if col not in ['datetime', 'instrument', 'open', 'high', 'low', 'close', 'volume']:
316                         if col == 'open_interest':
317                             new_kline[col] = group.iloc[-1][col]
318                         elif col == 'settlement':
319                             new_kline[col] = group.iloc[-1][col]
320                 
321                 print(f"合并K线组: {pd.to_datetime(group.iloc[0]['datetime']).strftime('%H:%M')}-{last_datetime.strftime('%H:%M')} ({len(group)}根) -> {last_datetime.strftime('%H:%M')}")
322             else:
323                 # 原有逻辑,使用date和time列
324                 new_kline = {
325                     'date': group.iloc[0]['date'],
326                     'time': group.iloc[-1]['time'],  # 关键:使用最后一根K线的时间
327                     'open': group.iloc[0]['open'],
328                     'high': group['high'].max(),
329                     'low': group['low'].min(),
330                     'close': group.iloc[-1]['close'],
331                     'volume': group['volume'].sum(),
332                     'open_interest': group.iloc[-1]['open_interest']
333                 }
334                 
335                 if 'settlement' in group.columns:
336                     new_kline['settlement'] = group.iloc[-1]['settlement']
337                 
338                 print(f"合并K线组: {group.iloc[0]['time']}-{group.iloc[-1]['time']} ({len(group)}根) -> {new_kline['time']}")
339             
340             resampled_data.append(new_kline)
341         
342         # 移动到下一组
343         i += len(group_indices)
344         
345         # 如果下一根是新时段开始,跳出当前时段的处理
346         if next_index < len(df_session):
347             if has_datetime_column:
348                 next_datetime = pd.to_datetime(df_session.iloc[next_index]['datetime'])
349                 next_time = next_datetime.strftime('%H:%M')
350             else:
351                 next_time = df_session.iloc[next_index]['time']
352             
353             if is_session_start_time(next_time):
354                 break
355     
356     return resampled_data
357 
358 def resample_kline_data(df, period_minutes, base_minutes=5):
359     """智能重采样K线数据,考虑交易时段边界"""
360     print(f"智能重采样:从{base_minutes}分钟生成{period_minutes}分钟K线数据")
361     
362     # 检查是否有datetime列
363     has_datetime_column = 'datetime' in df.columns
364     
365     # 检测交易时段
366     sessions = detect_trading_sessions(df)
367     print(f"检测到 {len(sessions)} 个交易时段")
368     
369     all_resampled_data = []
370     
371     # 对每个交易时段分别处理
372     for session_idx, session_indices in enumerate(sessions):
373         df_session = df.iloc[session_indices]
374         session_data = resample_session_data(df_session, period_minutes, base_minutes, has_datetime_column)
375         all_resampled_data.extend(session_data)
376         
377         if len(session_data) > 0:
378             if has_datetime_column:
379                 start_datetime = pd.to_datetime(df_session.iloc[0]['datetime'])
380                 end_datetime = pd.to_datetime(df_session.iloc[-1]['datetime'])
381                 start_time = start_datetime.strftime('%H:%M')
382                 end_time = end_datetime.strftime('%H:%M')
383             else:
384                 start_time = df_session.iloc[0]['time']
385                 end_time = df_session.iloc[-1]['time']
386             
387             print(f"时段{session_idx+1}: {start_time}-{end_time}, 原始{len(df_session)}根 -> 重采样{len(session_data)}根")
388     
389     return pd.DataFrame(all_resampled_data)
390 
391 def generate_higher_timeframes(file_path, output_dir):
392     """只从5分钟数据生成15分钟、30分钟、60分钟数据"""
393     base_filename = os.path.basename(file_path)
394     
395     # 只处理5分钟数据
396     if '5m' not in base_filename:
397         return []
398     
399     print(f"正在从5分钟数据生成更高周期: {base_filename}")
400     
401     # 读取5分钟数据
402     df = pd.read_csv(file_path, encoding='utf-8')
403     
404     # 检查是否有datetime列
405     if 'datetime' not in df.columns:
406         print(f"跳过 {base_filename}:没有datetime列")
407         return []
408     
409     generated_files = []
410     
411     # 从5分钟数据生成:15m, 30m, 60m
412     periods = {
413         '15m': 15,
414         '30m': 30,
415         '60m': 60
416     }
417     
418     for period_name, period_minutes in periods.items():
419         try:
420             # 重采样数据(基础周期是5分钟)
421             resampled_df = resample_kline_data(df.copy(), period_minutes, base_minutes=5)
422             
423             if len(resampled_df) == 0:
424                 print(f"警告:{period_name} 重采样结果为空")
425                 continue
426             
427             # 生成新文件名
428             new_filename = base_filename.replace('5m', period_name)
429             new_file_path = os.path.join(output_dir, new_filename)
430             
431             # 保存文件
432             resampled_df.to_csv(new_file_path, index=False, encoding='utf-8')
433             generated_files.append(new_filename)
434             
435             print(f"已生成 {period_name} 数据: {new_filename}")
436             print(f"  数据形状: {resampled_df.shape}")
437             print(f"  时间范围: {resampled_df['datetime'].min()} 到 {resampled_df['datetime'].max()}")
438             
439         except Exception as e:
440             print(f"生成 {period_name} 数据时出错: {e}")
441             import traceback
442             traceback.print_exc()
443     
444     return generated_files
445 if __name__ == "__main__":
446     main()

 

QQ交流