对IterableDataset进行训练、验证集划分
IterableDataset不能随机访问和shuffle的背景
IterableDataset适用于无法全部加载到内存或无法随机访问的数据源(如流式数据、大型文本文件、数据库、远程数据等)。- 它不支持索引和shuffle,即不能像
Dataset那样直接切片或乱序采样。
面对巨大的数据源时的训练集和验证集划分方案
1. 基于数据流的顺序分割(常用)
假设你的数据是一个连续流(如大文件、Kafka流、数据库游标等),可以采用如下方式:
-
顺序分割法
比如前80%数据作为训练集,后20%作为验证集。这种方法不会打乱数据,适合数据本身没有强时序依赖或分布均匀的场景。class SplitIterableDataset(torch.utils.data.IterableDataset): def __init__(self, data_iter, split='train', split_ratio=0.8): self.data_iter = data_iter self.split = split self.split_ratio = split_ratio def __iter__(self): for idx, item in enumerate(self.data_iter): if self.split == 'train': if idx < int(self.split_ratio * self._get_total_length()): yield item else: if idx >= int(self.split_ratio * self._get_total_length()): yield item def _get_total_length(self): # 这里需要你能获得数据总长度,否则只能用分片法 pass注意:如果数据源无法获得总长度,可以用分片法(见下)。
2. 分片法(Sharding)
-
每个worker只处理部分数据
比如:第0个worker处理第0、N、2N...条数据,第1个worker处理第1、N+1、2N+1...,以此类推。 -
用于训练/验证划分:
你可以规定所有偶数编号为训练集,奇数编号为验证集,或者用某种哈希函数决定属于哪一集。class ShardedIterableDataset(torch.utils.data.IterableDataset): def __init__(self, data_iter, split='train', split_ratio=0.8): self.data_iter = data_iter self.split = split self.split_ratio = split_ratio def __iter__(self): for idx, item in enumerate(self.data_iter): if self.split == 'train': if hash(idx) % 10 < int(self.split_ratio * 10): yield item else: if hash(idx) % 10 >= int(self.split_ratio * 10): yield item这种方法可以在数据量极大、无法提前统计总量的场景下使用。
3. 基于特征的划分
如果每条数据有唯一ID(如用户ID、样本ID),可以对ID做哈希:
- 例如:
hash(id) % 10 < 8归为训练集,其余为验证集,这样可以保证同一个ID始终只在一个集合里。
4. 提前处理/生成索引文件
- 预处理阶段
在数据准备阶段,提前对数据做一次遍历,生成训练/验证的索引文件(如保存样本ID或偏移量),然后分别用不同的IterableDataset加载。 - 适合数据量极大但可以离线处理一次的场景。
总结建议
- 顺序分割:适合数据无时序依赖,且样本分布均匀。
- 哈希分割:适合有唯一标识的数据,可保证分布均匀且无泄漏。
- 分片法/取模法:适合无法提前知道总量的数据流。
- 提前生成索引:适合可离线处理一次的大型数据。
参考代码(哈希分割)
class HashSplitIterableDataset(torch.utils.data.IterableDataset):
def __init__(self, data_iter, split='train', split_ratio=0.8, id_fn=lambda x: x['id']):
self.data_iter = data_iter
self.split = split
self.split_ratio = split_ratio
self.id_fn = id_fn
def __iter__(self):
for item in self.data_iter:
idx = hash(self.id_fn(item))
if self.split == 'train':
if idx % 10 < int(self.split_ratio * 10):
yield item
else:
if idx % 10 >= int(self.split_ratio * 10):
yield item
如有具体数据源类型或数据结构,可进一步细化实现建议。

浙公网安备 33010602011771号