SWA,模型参数平均代码
import os
from argparse import ArgumentParser
import torch
def main():
    """Average the weights of a series of checkpoints (SWA) and save the result.

    Command-line options:
        --model_dir          directory holding the ``iter_<id>.pth`` checkpoints
        --starting_model_id  id of the first checkpoint to average
        --ending_model_id    id of the last checkpoint to average (inclusive)
        --interval           id distance between consecutive checkpoints
        --save_dir           directory the averaged checkpoint is written to

    Every loaded checkpoint is indexed by ``'state_dict'``. The last
    checkpoint is reused as the container for the output, so its other
    entries are carried over unchanged. The output file is named
    ``swa_<starting_id>-<ending_id>.pth``.

    Exits through ``parser.error`` when the interval is not positive, the id
    range is empty, or an expected checkpoint file is missing.
    """
    # Doubled backslashes throughout: the original ended in '\pth', an invalid
    # escape sequence that triggers a warning. The resulting path is the same.
    default_dir = ('E:\\code\\swa_object_detection-master'
                   '\\swa_object_detection-master\\pth')

    parser = ArgumentParser()
    parser.add_argument(
        '--model_dir',
        default=default_dir,
        help='the directory where checkpoints are saved')
    parser.add_argument(
        '--starting_model_id',
        default=6000,
        type=int,
        help='the id of the starting checkpoint for averaging, e.g. 1')
    parser.add_argument(
        '--ending_model_id',
        default=16000,
        type=int,
        help='the id of the ending checkpoint for averaging, e.g. 12')
    parser.add_argument(
        '--interval',
        default=2000,
        type=int,
        help='the id step between two consecutive checkpoints')
    parser.add_argument(
        '--save_dir',
        default=default_dir,
        help='the directory for saving the SWA model')
    args = parser.parse_args()

    model_dir = args.model_dir
    starting_id = args.starting_model_id
    ending_id = args.ending_model_id

    if args.interval <= 0:
        parser.error('--interval must be a positive integer')

    # `ending_id + 1` makes the range inclusive: the ending checkpoint is part
    # of the average, matching the help text and the output file name.
    model_ids = list(range(starting_id, ending_id + 1, args.interval))
    if not model_ids:
        parser.error('no checkpoint ids between %d and %d'
                     % (starting_id, ending_id))

    checkpoint_paths = [
        os.path.join(model_dir, 'iter_' + str(i) + '.pth')
        for i in model_ids
    ]
    missing = [path for path in checkpoint_paths if not os.path.isfile(path)]
    if missing:
        parser.error('checkpoint file(s) not found: ' + ', '.join(missing))

    # NOTE: torch.load unpickles the files; only use trusted checkpoints.
    models = [torch.load(path) for path in checkpoint_paths]
    model_num = len(models)

    # The last checkpoint supplies the key set and the output container.
    ref_model = models[-1]
    state_dict = ref_model['state_dict']
    new_state_dict = state_dict.copy()

    for key in state_dict.keys():
        # NOTE(review): the accumulator starts as a Python float, so
        # integer-typed entries presumably come out as floating point --
        # confirm this is acceptable for the consumers of the checkpoint.
        sum_weight = 0.0
        for m in models:
            sum_weight += m['state_dict'][key]
        new_state_dict[key] = sum_weight / model_num
    ref_model['state_dict'] = new_state_dict

    save_model_name = ('swa_' + str(starting_id) + '-' + str(ending_id)
                       + '.pth')
    if args.save_dir is not None:
        save_dir = os.path.join(args.save_dir, save_model_name)
    else:
        save_dir = os.path.join(model_dir, save_model_name)
    torch.save(ref_model, save_dir)
    print('Model is saved at', save_dir)


if __name__ == '__main__':
    main()
@article{zhang2020swa,
title={SWA Object Detection},
author={Zhang, Haoyang and Wang, Ying and Dayoub, Feras and S{\"u}nderhauf, Niko},
journal={arXiv preprint arXiv:2012.12645},
year={2020}
}
原始代码:github.com/hyz-xmaster/swa_object_detection
浙公网安备 33010602011771号