FastCorrect&Fairseq学习笔记

一 工作说明:

  FastCorrect,字面意思就是快速纠错;这项主要是对asr的识别结果进行纠错,提升识别率;

  目前大部分的纠错模型采用了基于注意力机制的端到端自回归模型(seq2seq model to correct an ASR output sentence autoregressively)结构。这种结构延迟较大;为此,微软亚洲研究院机器学习组与微软 Azure 语音团队合作,推出了 FastCorrect 系列工作,提出了低延迟的纠错模型;相关研究论文已被 NeurIPS 2021 和 EMNLP 2021 收录;

  Fairseq,是一个开源的序列建模工具,由Facebook AI Research于2017年9月推出,Fairseq基于python&pytorch,更加简单,人性化;主要应用场景是nlp任务,支持多种常用模型;

  FastCorrect就是基于Fairseq工具进行训练的模型;

  本文记录FastCorrect的学习过程,中间对机器学习,Fairseq的学习,理解和记录;初入此道,欢迎讨论;

  FastCorrect git:https://github.com/microsoft/NeuralSpeech/tree/master/FastCorrect

  Fairseq git:https://github.com/facebookresearch/fairseq

  Fairseq 文档:https://fairseq.readthedocs.io/en/latest/command_line_tools.html

 

二 Fairseq训练流程&核心代码阅读:

1 训练命令解读:

  训练命令:

    fairseq-train $DATA_PATH --task fastcorrect \
    --arch fastcorrect --lr 5e-4 --lr-scheduler inverse_sqrt \
    --length-loss-factor 0.5 \
    --noise full_mask \
    --src-with-werdur \
    --dur-predictor-type "v2" \
    --dropout 0.3 --weight-decay 0.0001 \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
    --criterion fc_loss --label-smoothing 0.1 \
    --max-tokens 9000 \
    --werdur-max-predict 3 \
    --assist-edit-loss \
    --save-dir $SAVE_DIR \
    --user-dir $EXP_HOME/FastCorrect \
    --left-pad-target False --left-pad-source False \
    --encoder-layers 6 --decoder-layers 6 \
    --max-epoch 30 --update-freq 4 --fp16 --num-workers 8 \
    --share-all-embeddings --encoder-embed-dim=512 --decoder-embed-dim=512

  这个命令的含义,可以自己去查Fairseq的文档,这里罗列出来:

    --task,意思是任务类型,默认是“translation”(翻译);包括translation_from_pretrained_bart等Fairseq自带的任务类型,这里设置为本任务fastcorrect ;

    --arch,意思是model architecture,模型结构,可选项包括,transformer,lstm等;这里采用自带的结构fastcorrect 

     --lr 学习率,为初始学习率,后续可能被--lr-scheduler修改;--lr-scheduler 是lr更新计划,这里采用inverse_sqrt 方法;

    --length-loss-factor 0.5
     --noise full_mask 
     --src-with-werdur 
     --dur-predictor-type "v2"

     --dropout 字面意思就是丢弃,这里指的是在训练模型时,丢弃一部分数据来防止过拟合;

    --weight-decay

    --optimizer adam --adam-betas '(0.9, 0.98)' 参数优化策略;

    --clip-norm 0.0 
    --criterion fc_loss 训练准则;

    --label-smoothing 0.1 
    --max-tokens 9000 一个batch最大的token数量;
    --werdur-max-predict 3 
    --assist-edit-loss 
    --save-dir $SAVE_DIR 存储checkpoints的路径,checkpoint即模型;
    --user-dir $EXP_HOME/FastCorrect 一个包含扩展的python模块,这里的扩展是指模型结构或者任务,和task是相对的,一般不适用官方规定的arch时需要手动设置这个路径;
    --left-pad-target False --left-pad-source False 
    --encoder-layers 6 --decoder-layers 6 
    --max-epoch 30 当达到这个30个epoch的时候,停止训练;

    --update-freq 4 参数更新频率,每4个batch更新参数;

    --fp16 使用FP16

      --num-workers 8 8个子线程用于load数据;
    --share-all-embeddings

     --encoder-embed-dim=512 --decoder-embed-dim=512

 

 

posted on 2022-09-19 16:24  MyTD21  阅读(563)  评论(0编辑  收藏  举报

导航