Key transformers code (work in progress)
1. Training argument configuration
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    # dataloader_num_workers=4,
    num_train_epochs=epochNo,
    save_strategy='epoch',
    evaluation_strategy=evaluation_strategy,  # full-data toggle: 'no' if constants.ifFullData else 'epoch' (renamed eval_strategy in newer transformers releases)
    logging_steps=50,
    save_total_limit=save_total_limit,  # maximum number of checkpoints to keep
    metric_for_best_model='eval_cider',  # changed to CIDEr; compute_metrics must return a 'cider' key
    greater_is_better=True,
    learning_rate=lr,
    warmup_ratio=0.03,
    seed=userSeed,
    overwrite_output_dir=True,
    per_device_eval_batch_size=batchsize,
    per_device_train_batch_size=batchsize,
    output_dir=outputPath,
    do_train=True,
    do_eval=do_eval,  # full-data toggle: False if constants.ifFullData else True
    predict_with_generate=True,
    label_smoothing_factor=0.1 if constants.isSMOOTH else 0.0,
)
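Several values in this call depend on two project flags. Going by the comments, evaluation is switched off when constants.ifFullData is set (presumably because training on the full data leaves no held-out split), and label smoothing is enabled by constants.isSMOOTH. A minimal sketch of collecting those toggles in one place follows; the helper name build_toggle_kwargs is hypothetical and not part of transformers.

```python
def build_toggle_kwargs(full_data: bool, smooth: bool) -> dict:
    """Hypothetical helper: the flag-dependent Seq2SeqTrainingArguments kwargs.

    full_data: training on the full dataset, so there is no held-out eval split.
    smooth:    whether to enable label smoothing.
    """
    return {
        # No eval split when training on the full data, so skip evaluation.
        "evaluation_strategy": "no" if full_data else "epoch",
        "do_eval": not full_data,
        # 0.1 is the smoothing value used in the original config.
        "label_smoothing_factor": 0.1 if smooth else 0.0,
    }


if __name__ == "__main__":
    print(build_toggle_kwargs(full_data=True, smooth=False))
    # {'evaluation_strategy': 'no', 'do_eval': False, 'label_smoothing_factor': 0.0}
```

The returned dict could be unpacked into the call, e.g. Seq2SeqTrainingArguments(output_dir=outputPath, ..., **build_toggle_kwargs(constants.ifFullData, constants.isSMOOTH)), as long as the same keys are not also passed explicitly.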
2. Building data with Datasets
First define a dict whose values are lists:
results={'summarization':[],'article':[]}
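Each key is a column name and each list holds that column's values, one entry per example, so all lists must end up the same length. Raw data often arrives row by row instead; a minimal sketch of filling the dict from (article, summary) pairs is below. The helper to_columns and the sample pairs are hypothetical.

```python
def to_columns(pairs):
    """Turn (article, summary) pairs into the column-oriented dict Dataset.from_dict expects."""
    results = {"summarization": [], "article": []}
    for article, summary in pairs:
        # Append to both lists together so every column keeps the same length.
        results["article"].append(article)
        results["summarization"].append(summary)
    return results


if __name__ == "__main__":
    print(to_columns([("article one", "summary one"), ("article two", "summary two")]))
    # {'summarization': ['summary one', 'summary two'], 'article': ['article one', 'article two']}
```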
Then convert it into a Dataset:
from datasets import Dataset
import torch

results = Dataset.from_dict(results)
print(isinstance(results, torch.utils.data.IterableDataset))  # datasets.Dataset is map-style, so this prints False
(I use this post as a personal notebook.)