Handling Variable-Length Data for RNNs

Code first; I'll add the explanation later.
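`import torch
from torch.nn.utils.rnn import pad_sequence

def custom_collate(batch):
    # Each item in the batch is a (sequence, label) pair; split them apart
    data, labels = zip(*batch)

    # Pad every sequence to the length of the longest one in the batch;
    # -1 marks the padded positions
    padded_data = pad_sequence(data, batch_first=True, padding_value=-1)

    # Convert the labels to a single integer tensor
    labels = torch.tensor(labels, dtype=torch.long)

    return padded_data, labels`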
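To see what the padding step produces, here is a small sketch on a made-up batch. The three toy sequences are hypothetical and only there for illustration; `pad_sequence` and its `batch_first` and `padding_value` arguments are the same ones used in the collate function.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Hypothetical toy batch: three 1-D sequences of lengths 3, 1, and 2
sequences = [
    torch.tensor([1.0, 2.0, 3.0]),
    torch.tensor([4.0]),
    torch.tensor([5.0, 6.0]),
]

# Shorter sequences are filled with -1 up to the longest length (3)
padded = pad_sequence(sequences, batch_first=True, padding_value=-1)

print(padded.shape)  # torch.Size([3, 3])
print(padded)
# tensor([[ 1.,  2.,  3.],
#         [ 4., -1., -1.],
#         [ 5.,  6., -1.]])
```

With `batch_first=True` the result has shape (batch, max_len, ...), which matches an RNN layer constructed with `batch_first=True`.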

dataloadersTrain = torch.utils.data.DataLoader(trainDatasets, batch_size=2, collate_fn=custom_collate, shuffle=True, num_workers=0)
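The post does not show how the padded batch is fed to the RNN. One common option, sketched here as an assumption rather than as the author's method, is to also return the true sequence lengths from the collate function and pack the batch with `pack_padded_sequence`, so the RNN skips the `-1` padding. The name `collate_with_lengths`, the toy batch, and the LSTM sizes are all hypothetical.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import (
    pad_sequence,
    pack_padded_sequence,
    pad_packed_sequence,
)

def collate_with_lengths(batch):
    # Hypothetical variant of custom_collate that also returns true lengths
    data, labels = zip(*batch)
    lengths = torch.tensor([len(seq) for seq in data], dtype=torch.long)
    padded = pad_sequence(data, batch_first=True, padding_value=-1)
    labels = torch.tensor(labels, dtype=torch.long)
    return padded, lengths, labels

# Toy batch (assumed shapes): each sequence is (T, 4) with an integer label
batch = [
    (torch.randn(5, 4), 0),
    (torch.randn(2, 4), 1),
    (torch.randn(3, 4), 0),
]
padded, lengths, labels = collate_with_lengths(batch)

rnn = nn.LSTM(input_size=4, hidden_size=8, batch_first=True)

# enforce_sorted=False lets the batch stay in its original order;
# lengths must be a CPU tensor (or a list)
packed = pack_padded_sequence(
    padded, lengths, batch_first=True, enforce_sorted=False
)
packed_out, (h_n, c_n) = rnn(packed)

# Unpack to a padded tensor again; padded time steps come back as zeros
out, out_lengths = pad_packed_sequence(packed_out, batch_first=True)

print(out.shape)    # torch.Size([3, 5, 8])
print(h_n.shape)    # torch.Size([1, 3, 8])
print(out_lengths)  # tensor([5, 2, 3])
```

Because the packed input carries the lengths, `h_n` holds the hidden state at each sequence's last real time step rather than at a padded position. To use this with the `DataLoader` above, `collate_fn=collate_with_lengths` would be passed instead, and the training loop would unpack three values per batch.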

posted @ 2024-07-17 21:22  chenanyee