Python 多线程曲线救国

搜了好久的python多线程,用过线程池,但是在我电脑上跑个稍微复杂的函数就不行了,思来想去,pytorch的Dataloader不就是现成的嘛,真的是。

from torch.utils.data import Dataset
from torch.utils.data import DataLoader


class DataSet_h(Dataset):
    def __init__(self, s11):
        super(DataSet_h, self).__init__()
        # 这里的数组就是我们取实际的数据了
        self.Arr = list(s11[0].value_counts().index[:])

    def __len__(self):
        # 数组的长度
        return len(self.Arr)

    def __getitem__(self, item):
        # 取数据的时候要按着自己定义模型需要来,一般都会有x, y
        return self.Arr[item]

trainDataSet = DataSet_h(s11)
trainDataLoader = DataLoader(trainDataSet, batch_size=1)

for i, batch in enumerate(tqdm(trainDataLoader)):
    # 里面你的代码
    print(batch[0])
    pass
posted @ 2022-09-01 22:16  赫凯  阅读(19)  评论(0)    收藏  举报