@@ -29,7 +29,7 @@
 from torch.nn.parallel import DistributedDataParallel as NativeDDP

 from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
-from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model
+from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model, model_parameters
 from timm.utils import *
 from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
 from timm.optim import create_optimizer
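The newly imported `model_parameters` helper is used in `train_one_epoch` below to optionally leave the classifier head out of gradient clipping. A minimal sketch of what such a helper can look like, assuming the head's weight and bias are the last two parameter tensors (a heuristic that holds for typical timm classification models; this is an illustration, not necessarily timm's exact implementation):

```python
import torch.nn as nn


def model_parameters(model: nn.Module, exclude_head: bool = False):
    """Return model parameters, optionally dropping the classifier head.

    Assumption: the head's weight and bias are the final two parameter
    tensors, so slicing them off excludes the classifier from clipping.
    """
    if exclude_head:
        return list(model.parameters())[:-2]
    return model.parameters()
```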
@@ -116,7 +116,8 @@
                     help='weight decay (default: 0.0001)')
 parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                     help='Clip gradient norm (default: None, no clipping)')
-
+parser.add_argument('--clip-mode', type=str, default='norm',
+                    help='Gradient clipping mode. One of ("norm", "value", "agc")')


 # Learning rate schedule parameters
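The new "agc" option refers to adaptive gradient clipping, popularized by the NFNet paper ("High-Performance Large-Scale Image Recognition Without Normalization"): instead of clipping against a fixed threshold, each gradient is rescaled whenever its norm grows too large relative to the norm of the parameter it belongs to. The paper also recommends not clipping the final linear layer, which is why the call sites below exclude the head when "agc" is in the clip mode. A rough, self-contained sketch of the unit-wise rule; the function names and defaults here are illustrative, not timm's API:

```python
import torch


def unitwise_norm(x: torch.Tensor) -> torch.Tensor:
    """L2 norm per output unit: per-row for ndim >= 2 weights, whole-tensor otherwise."""
    if x.ndim <= 1:
        return x.pow(2).sum().sqrt()
    dims = tuple(range(1, x.ndim))
    return x.pow(2).sum(dim=dims, keepdim=True).sqrt()


@torch.no_grad()
def agc_clip_(parameters, clip_factor: float = 0.01, eps: float = 1e-3):
    """Rescale gradients whose unit-wise norm exceeds clip_factor * parameter norm."""
    for p in parameters:
        if p.grad is None:
            continue
        max_norm = unitwise_norm(p).clamp_(min=eps) * clip_factor
        grad_norm = unitwise_norm(p.grad).clamp_(min=1e-6)
        # scale is < 1 only where the gradient is disproportionately large
        scale = (max_norm / grad_norm).clamp_(max=1.0)
        p.grad.mul_(scale)
```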
@@ -637,11 +638,16 @@ def train_one_epoch(
         optimizer.zero_grad()
         if loss_scaler is not None:
             loss_scaler(
-                loss, optimizer, clip_grad=args.clip_grad, parameters=model.parameters(), create_graph=second_order)
+                loss, optimizer,
+                clip_grad=args.clip_grad, clip_mode=args.clip_mode,
+                parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
+                create_graph=second_order)
         else:
             loss.backward(create_graph=second_order)
             if args.clip_grad is not None:
-                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
+                dispatch_clip_grad(
+                    model_parameters(model, exclude_head='agc' in args.clip_mode),
+                    value=args.clip_grad, mode=args.clip_mode)
             optimizer.step()

         if model_ema is not None:
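On the non-AMP path, the hard-coded `torch.nn.utils.clip_grad_norm_` call is replaced by a dispatcher keyed on `--clip-mode` (on the AMP path the same arguments are forwarded through the loss scaler). A minimal sketch of what such a dispatcher could look like, assuming the three modes map onto PyTorch's built-in norm/value clipping plus the `agc_clip_` sketch above; the name `dispatch_clip_grad` mirrors the call site, but the exact signature of timm's helper is not shown in this diff:

```python
import torch


def dispatch_clip_grad(parameters, value: float, mode: str = 'norm', norm_type: float = 2.0):
    """Route gradient clipping to the behaviour selected by --clip-mode."""
    if mode == 'norm':
        # classic clipping of the total gradient norm; `value` is the max norm
        torch.nn.utils.clip_grad_norm_(parameters, value, norm_type=norm_type)
    elif mode == 'value':
        # clamp every gradient element into [-value, value]
        torch.nn.utils.clip_grad_value_(parameters, value)
    elif mode == 'agc':
        # adaptive clipping relative to parameter norms (see the agc_clip_ sketch
        # above); here the --clip-grad value is reused as the clip factor
        agc_clip_(parameters, clip_factor=value)
    else:
        raise ValueError(f'Unknown clip mode: {mode}')
```

With these pieces in place, `--clip-grad 0.01 --clip-mode agc` would apply adaptive clipping with a 0.01 clip factor, while existing commands that only pass `--clip-grad` keep the previous global-norm behaviour thanks to the `norm` default.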