Handling Variable-Length Data for RNNs
Here is the code first; I will add the explanation later.
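```python
import torch
from torch.nn.utils.rnn import pad_sequence

def custom_collate(batch):
    # Separate data and labels: batch is a list of (sequence, label) pairs
    data, labels = zip(*batch)
    # Pad every sequence to the length of the longest one in the batch;
    # with batch_first=True the result has shape (batch, max_len, ...).
    # The padding value should be one that cannot occur in the real data.
    padded_data = pad_sequence(list(data), batch_first=True, padding_value=-1)
    # Convert labels to a tensor
    labels = torch.tensor(labels, dtype=torch.long)
    return padded_data, labels

# trainDatasets is the author's own dataset; each item is (sequence_tensor, label)
dataloadersTrain = torch.utils.data.DataLoader(trainDatasets, batch_size=2, collate_fn=custom_collate, shuffle=True, num_workers=0)
```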
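The collate function above pads each batch but throws away every sequence's real length, which an RNN usually needs so it can skip the padded steps. Below is a minimal sketch of one common follow-up, not the author's method: the collate function also returns lengths, and the batch is passed through `pack_padded_sequence` before the RNN. The toy dataset, `collate_with_lengths`, the feature size of 3, and the single-layer `nn.GRU` are all hypothetical choices for illustration; padding uses 0 here because packing makes the padding value irrelevant to the RNN.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
from torch.utils.data import DataLoader

torch.manual_seed(0)

# Hypothetical toy dataset: (sequence, label) pairs of different lengths,
# each sequence shaped (seq_len, feature_dim) with feature_dim = 3.
toy_dataset = [
    (torch.randn(5, 3), 0),
    (torch.randn(2, 3), 1),
    (torch.randn(4, 3), 0),
    (torch.randn(3, 3), 1),
]

def collate_with_lengths(batch):
    """Pad a batch and also return the true length of every sequence."""
    data, labels = zip(*batch)
    # Lengths must be a 1-D int64 tensor on the CPU for pack_padded_sequence
    lengths = torch.tensor([seq.size(0) for seq in data], dtype=torch.long)
    padded = pad_sequence(list(data), batch_first=True, padding_value=0.0)
    labels = torch.tensor(labels, dtype=torch.long)
    return padded, lengths, labels

# Hypothetical model: one-layer GRU with 8 hidden units
rnn = nn.GRU(input_size=3, hidden_size=8, batch_first=True)
loader = DataLoader(toy_dataset, batch_size=2, shuffle=True, collate_fn=collate_with_lengths)

for padded, lengths, labels in loader:
    # enforce_sorted=False lets PyTorch sort by length internally,
    # so the batch does not have to be in descending-length order
    packed = pack_padded_sequence(padded, lengths, batch_first=True, enforce_sorted=False)
    packed_out, h_n = rnn(packed)
    # h_n has shape (num_layers, batch, hidden) and holds each sequence's
    # state at its real last step, returned in the original batch order
    out, out_lengths = pad_packed_sequence(packed_out, batch_first=True)
    # out has shape (batch, max_len_in_batch, hidden); padded steps are zeros
    print(padded.shape, lengths.tolist(), out.shape, h_n.shape)
```

With this setup, `h_n` can be fed directly into a classifier head, since it already ignores the padded positions of shorter sequences.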