Handling Variable-Length Data for RNNs

Code first; I'll add the explanation later.
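```python
import torch
import torch.utils.data
from torch.nn.utils.rnn import pad_sequence


def custom_collate(batch):
    # Separate data and labels (each item is a (sequence_tensor, label) pair)
    data, labels = zip(*batch)
    # Pad sequences to have the same length; shorter ones are filled with -1
    padded_data = pad_sequence(data, batch_first=True, padding_value=-1)
    # Convert labels to tensor
    labels = torch.tensor(labels, dtype=torch.long)
    return padded_data, labels


# trainDatasets is the author's own Dataset object
dataloadersTrain = torch.utils.data.DataLoader(trainDatasets,
                                               batch_size=2,
                                               collate_fn=custom_collate,
                                               shuffle=True,
                                               num_workers=0)
```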
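`custom_collate` pads every sequence in a batch to the length of the longest one, so the batch can be stacked into a single `(batch, max_len, features)` tensor. The RNN itself, however, would still read the `-1` padding rows as real input unless it is told each sequence's true length. Below is a minimal sketch of one common way to handle that in PyTorch, assuming each sample is a `(seq_len, features)` float tensor plus an integer label. `ToySeqDataset`, `collate_with_lengths` and `GRUClassifier` are hypothetical names for illustration only. The packing step (`pack_padded_sequence` with `enforce_sorted=False`) is a standard PyTorch technique swapped in here, not something taken from the original post.

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
from torch.utils.data import DataLoader, Dataset


class ToySeqDataset(Dataset):
    """Hypothetical dataset: variable-length sequences of 3-dim feature vectors."""

    def __init__(self):
        torch.manual_seed(0)
        lengths = [5, 2, 4, 3]
        self.samples = [(torch.randn(n, 3), i % 2) for i, n in enumerate(lengths)]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


def collate_with_lengths(batch):
    # Hypothetical variant of custom_collate that also returns true lengths
    data, labels = zip(*batch)
    lengths = torch.tensor([seq.size(0) for seq in data], dtype=torch.long)
    padded_data = pad_sequence(data, batch_first=True, padding_value=-1)
    labels = torch.tensor(labels, dtype=torch.long)
    return padded_data, lengths, labels


class GRUClassifier(nn.Module):
    def __init__(self, input_size=3, hidden_size=8, num_classes=2):
        super().__init__()
        self.rnn = nn.GRU(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, padded, lengths):
        # Pack so the GRU stops at each sequence's real end and never sees the -1 padding;
        # lengths must live on the CPU
        packed = pack_padded_sequence(padded, lengths.cpu(), batch_first=True,
                                      enforce_sorted=False)
        # h_n: (num_layers, batch, hidden), restored to the original batch order
        _, h_n = self.rnn(packed)
        return self.fc(h_n[-1])


loader = DataLoader(ToySeqDataset(), batch_size=2,
                    collate_fn=collate_with_lengths, shuffle=False)
model = GRUClassifier()
for padded, lengths, labels in loader:
    logits = model(padded, lengths)
    print(padded.shape, lengths.tolist(), logits.shape)
```

With `shuffle=False`, the first batch holds sequences of length 5 and 2, so it is padded to shape `(2, 5, 3)`. Because of packing, the hidden state for the length-2 sequence is the same as if it had been run through the GRU on its own, without padding.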

posted @ 2024-07-17 21:22  chenanyee