分类任务
import numpy as np
import evaluate
metric = evaluate.load("accuracy")
def compute_metrics(eval_pred):
logits, labels = eval_pred
predictions = np.argmax(logits, axis=-1)
return metric.compute(predictions=predictions, references=labels)
logits 是模型的输出
labels是真实标签
用 numpy 的 argmax 函数沿着最后一个维度(即每个样本的类别维度)找到分数最大的索引,这些索引即为模型的预测类别
返回准确率
生成任务
BLEU、ROUGE、METEOR 等,这些指标用于比较生成的文本和参考文本
浙公网安备 33010602011771号