@@ -218,8 +218,27 @@ def test_lr_monitor_multi_lrs(tmpdir, logging_interval: str):
     """Test that learning rates are extracted and logged for multi lr schedulers."""

     class CustomBoringModel(BoringModel):
-        def training_step(self, batch, batch_idx, optimizer_idx):
-            return super().training_step(batch, batch_idx)
+        def __init__(self):
+            super().__init__()
+            self.automatic_optimization = False
+
+        def training_step(self, batch, batch_idx):
+            opt1, opt2 = self.optimizers()
+
+            loss = self.loss(self.step(batch))
+            opt1.zero_grad()
+            self.manual_backward(loss)
+            opt1.step()
+
+            loss = self.loss(self.step(batch))
+            opt2.zero_grad()
+            self.manual_backward(loss)
+            opt2.step()
+
+        def on_train_epoch_end(self):
+            scheduler1, scheduler2 = self.lr_schedulers()
+            scheduler1.step()
+            scheduler2.step()

         def configure_optimizers(self):
             optimizer1 = optim.Adam(self.parameters(), lr=1e-2)
@@ -262,8 +281,22 @@ def test_lr_monitor_no_lr_scheduler_multi_lrs(tmpdir, logging_interval: str):
     """Test that learning rates are extracted and logged for multi optimizers but no lr scheduler."""

     class CustomBoringModel(BoringModel):
-        def training_step(self, batch, batch_idx, optimizer_idx):
-            return super().training_step(batch, batch_idx)
+        def __init__(self):
+            super().__init__()
+            self.automatic_optimization = False
+
+        def training_step(self, batch, batch_idx):
+            opt1, opt2 = self.optimizers()
+
+            loss = self.loss(self.step(batch))
+            opt1.zero_grad()
+            self.manual_backward(loss)
+            opt1.step()
+
+            loss = self.loss(self.step(batch))
+            opt2.zero_grad()
+            self.manual_backward(loss)
+            opt2.step()

         def configure_optimizers(self):
             optimizer1 = optim.Adam(self.parameters(), lr=1e-2)
@@ -421,22 +454,46 @@ def test_multiple_optimizers_basefinetuning(tmpdir):
     class TestModel(BoringModel):
         def __init__(self):
             super().__init__()
+            self.automatic_optimization = False
             self.backbone = torch.nn.Sequential(
                 torch.nn.Linear(32, 32), torch.nn.Linear(32, 32), torch.nn.Linear(32, 32), torch.nn.ReLU(True)
             )
             self.layer = torch.nn.Linear(32, 2)

-        def training_step(self, batch, batch_idx, optimizer_idx):
-            return super().training_step(batch, batch_idx)
+        def training_step(self, batch, batch_idx):
+            opt1, opt2, opt3 = self.optimizers()
+
+            # optimizer 1
+            loss = self.step(batch)
+            self.manual_backward(loss)
+            opt1.step()
+            opt1.zero_grad()
+
+            # optimizer 2
+            loss = self.step(batch)
+            self.manual_backward(loss)
+            opt2.step()
+            opt2.zero_grad()
+
+            # optimizer 3
+            loss = self.step(batch)
+            self.manual_backward(loss)
+            opt3.step()
+            opt3.zero_grad()
+
+        def on_train_epoch_end(self) -> None:
+            lr_sched1, lr_sched2 = self.lr_schedulers()
+            lr_sched1.step()
+            lr_sched2.step()

         def forward(self, x):
             return self.layer(self.backbone(x))

         def configure_optimizers(self):
             parameters = list(filter(lambda p: p.requires_grad, self.parameters()))
-            opt = optim.Adam(parameters, lr=0.1)
+            opt = optim.SGD(parameters, lr=0.1)
             opt_2 = optim.Adam(parameters, lr=0.1)
-            opt_3 = optim.Adam(parameters, lr=0.1)
+            opt_3 = optim.AdamW(parameters, lr=0.1)
             optimizers = [opt, opt_2, opt_3]
             schedulers = [
                 optim.lr_scheduler.StepLR(opt, step_size=1, gamma=0.5),
@@ -452,24 +509,24 @@ def on_train_epoch_start(self, trainer, pl_module) -> None:
                 assert num_param_groups == 3
             elif trainer.current_epoch == 1:
                 assert num_param_groups == 4
-                assert list(lr_monitor.lrs) == ["lr-Adam-1", "lr-Adam-2", "lr-Adam/pg1", "lr-Adam/pg2"]
+                assert list(lr_monitor.lrs) == ["lr-Adam", "lr-AdamW", "lr-SGD/pg1", "lr-SGD/pg2"]
             elif trainer.current_epoch == 2:
                 assert num_param_groups == 5
                 assert list(lr_monitor.lrs) == [
-                    "lr-Adam-2",
+                    "lr-AdamW",
+                    "lr-SGD/pg1",
+                    "lr-SGD/pg2",
                     "lr-Adam/pg1",
                     "lr-Adam/pg2",
-                    "lr-Adam-1/pg1",
-                    "lr-Adam-1/pg2",
                 ]
             else:
                 expected = [
-                    "lr-Adam-2",
+                    "lr-AdamW",
+                    "lr-SGD/pg1",
+                    "lr-SGD/pg2",
                     "lr-Adam/pg1",
                     "lr-Adam/pg2",
-                    "lr-Adam-1/pg1",
-                    "lr-Adam-1/pg2",
-                    "lr-Adam-1/pg3",
+                    "lr-Adam/pg3",
                 ]
                 assert list(lr_monitor.lrs) == expected

@@ -481,12 +538,12 @@ def freeze_before_training(self, pl_module):

         def finetune_function(self, pl_module, epoch: int, optimizer, opt_idx: int):
             """Called when the epoch begins."""
-            if epoch == 1 and opt_idx == 0:
+            if epoch == 1 and isinstance(optimizer, torch.optim.SGD):
                 self.unfreeze_and_add_param_group(pl_module.backbone[0], optimizer, lr=0.1)
-            if epoch == 2 and opt_idx == 1:
+            if epoch == 2 and isinstance(optimizer, torch.optim.Adam):
                 self.unfreeze_and_add_param_group(pl_module.layer, optimizer, lr=0.1)

-            if epoch == 3 and opt_idx == 1:
+            if epoch == 3 and isinstance(optimizer, torch.optim.Adam):
                 assert len(optimizer.param_groups) == 2
                 self.unfreeze_and_add_param_group(pl_module.backbone[1], optimizer, lr=0.1)
                 assert len(optimizer.param_groups) == 3
@@ -507,22 +564,22 @@ def finetune_function(self, pl_module, epoch: int, optimizer, opt_idx: int):
     trainer.fit(model)

     expected = [0.1, 0.1, 0.1, 0.1, 0.1]
-    assert lr_monitor.lrs["lr-Adam-2"] == expected
+    assert lr_monitor.lrs["lr-AdamW"] == expected

     expected = [0.1, 0.05, 0.025, 0.0125, 0.00625]
-    assert lr_monitor.lrs["lr-Adam/pg1"] == expected
+    assert lr_monitor.lrs["lr-SGD/pg1"] == expected

     expected = [0.1, 0.05, 0.025, 0.0125]
-    assert lr_monitor.lrs["lr-Adam/pg2"] == expected
+    assert lr_monitor.lrs["lr-SGD/pg2"] == expected

     expected = [0.1, 0.05, 0.025, 0.0125, 0.00625]
-    assert lr_monitor.lrs["lr-Adam-1/pg1"] == expected
+    assert lr_monitor.lrs["lr-Adam/pg1"] == expected

     expected = [0.1, 0.05, 0.025]
-    assert lr_monitor.lrs["lr-Adam-1/pg2"] == expected
+    assert lr_monitor.lrs["lr-Adam/pg2"] == expected

     expected = [0.1, 0.05]
-    assert lr_monitor.lrs["lr-Adam-1/pg3"] == expected
+    assert lr_monitor.lrs["lr-Adam/pg3"] == expected


 def test_lr_monitor_multiple_param_groups_no_lr_scheduler(tmpdir):
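
The key-naming behaviour the rewritten assertions rely on can be reproduced outside the test suite. The sketch below is illustrative only: `ToyModel` and its random dataset are invented for the example and are not part of this diff, while the rest uses the public `pytorch_lightning` / `torch` APIs the tests exercise. With two optimizers of different classes under manual optimization, `LearningRateMonitor` disambiguates keys by class name (`lr-SGD`, `lr-Adam`) instead of appending the `-1`/`-2` counters needed when several optimizers share a class; extra parameter groups still get `/pg1`, `/pg2`, ... suffixes.

```python
# Illustrative sketch only -- not part of the diff above. ToyModel and the random
# dataset are made up for the example.
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import LearningRateMonitor


class ToyModel(LightningModule):
    def __init__(self):
        super().__init__()
        # manual optimization, as in the updated tests: training_step drives
        # both optimizers itself instead of receiving an optimizer_idx
        self.automatic_optimization = False
        self.layer = nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        opt_sgd, opt_adam = self.optimizers()

        loss = self.layer(batch[0]).sum()
        opt_sgd.zero_grad()
        self.manual_backward(loss)
        opt_sgd.step()

        # recompute the loss: the first backward pass freed the graph
        loss = self.layer(batch[0]).sum()
        opt_adam.zero_grad()
        self.manual_backward(loss)
        opt_adam.step()

    def configure_optimizers(self):
        # two *different* optimizer classes -> keys "lr-SGD" and "lr-Adam";
        # two Adam instances would instead be logged as "lr-Adam" and "lr-Adam-1"
        return [optim.SGD(self.parameters(), lr=0.1), optim.Adam(self.parameters(), lr=0.01)]


if __name__ == "__main__":
    train_data = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)
    lr_monitor = LearningRateMonitor(logging_interval="step")
    trainer = Trainer(max_epochs=1, limit_train_batches=4, callbacks=[lr_monitor], enable_checkpointing=False)
    trainer.fit(ToyModel(), train_data)
    print(sorted(lr_monitor.lrs))  # e.g. ['lr-Adam', 'lr-SGD']
```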