diff --git a/imagenet/main.py b/imagenet/main.py index e828ea0fa7..2b1c7f9a49 100644 --- a/imagenet/main.py +++ b/imagenet/main.py @@ -215,6 +215,9 @@ def main_worker(gpu, ngpus_per_node, args): best_acc1 = best_acc1.to(args.gpu) model.load_state_dict(checkpoint['state_dict']) optimizer.load_state_dict(checkpoint['optimizer']) + if args.lr != parser.get_default('lr'): + # resume with newly specified learning rate + optimizer.param_groups[0]['lr'] = args.lr scheduler.load_state_dict(checkpoint['scheduler']) print("=> loaded checkpoint '{}' (epoch {})" .format(args.resume, checkpoint['epoch'])) @@ -293,8 +296,8 @@ def main_worker(gpu, ngpus_per_node, args): 'arch': args.arch, 'state_dict': model.state_dict(), 'best_acc1': best_acc1, - 'optimizer' : optimizer.state_dict(), - 'scheduler' : scheduler.state_dict() + 'optimizer': optimizer.state_dict(), + 'scheduler': scheduler.state_dict() }, is_best)