@@ -946,104 +946,87 @@ def test_dataloaders_load_only_once_no_sanity_check(tmpdir):
946
946
assert tracker .mock_calls == expected_sequence
947
947
948
948
949
- @pytest .mark .parametrize ("n" , [1 , 2 ])
950
- def test_dataloaders_load_every_n_epochs (tmpdir , n ):
951
- train_reload_epochs , val_reload_epochs = [], []
952
-
953
- class TestModel (BoringModel ):
954
- def train_dataloader (self ):
955
- train_reload_epochs .append (self .current_epoch )
956
- return super ().train_dataloader ()
957
-
958
- def val_dataloader (self ):
959
- val_reload_epochs .append (self .current_epoch )
960
- return super ().val_dataloader ()
961
-
962
- model = TestModel ()
963
-
964
- trainer = Trainer (
965
- default_root_dir = tmpdir ,
966
- limit_train_batches = 0.3 ,
967
- limit_val_batches = 0.3 ,
968
- reload_dataloaders_every_n_epochs = n ,
969
- max_epochs = 5 ,
970
- )
971
-
972
- tracker = Mock ()
973
- model .train_dataloader = Mock (wraps = model .train_dataloader )
974
- model .val_dataloader = Mock (wraps = model .val_dataloader )
975
- model .test_dataloader = Mock (wraps = model .test_dataloader )
976
-
977
- tracker .attach_mock (model .train_dataloader , "train_dataloader" )
978
- tracker .attach_mock (model .val_dataloader , "val_dataloader" )
979
- tracker .attach_mock (model .test_dataloader , "test_dataloader" )
980
-
981
- trainer .fit (model )
982
- trainer .test (model )
983
-
984
- # Verify the sequence
985
- expected_sequence = [call .val_dataloader (), call .train_dataloader ()] # Sanity check first
986
- if n == 1 :
987
- expected_sequence += [call .train_dataloader (), call .val_dataloader ()] * 4
988
- elif n == 2 :
989
- expected_sequence += [call .train_dataloader (), call .val_dataloader ()] * 2
990
- expected_sequence += [call .test_dataloader ()]
991
-
992
- assert tracker .mock_calls == expected_sequence
993
-
994
- # Verify epoch of reloads
995
- if n == 1 :
996
- assert train_reload_epochs == [0 , 1 , 2 , 3 , 4 ]
997
- assert val_reload_epochs == [0 , 1 , 2 , 3 , 4 ]
998
- elif n == 2 :
999
- assert train_reload_epochs == [0 , 2 , 4 ]
1000
- assert val_reload_epochs == [0 , 2 , 4 ]
1001
-
1002
-
1003
949
@pytest .mark .parametrize (
1004
- "n, train_reload_epochs_expect, val_reload_epochs_expect" ,
950
+ (
951
+ "num_sanity_val_steps, check_val_every_n_epoch, reload_dataloaders_every_n_epochs,"
952
+ " train_reload_epochs_expect,val_reload_epochs_expect,val_step_epochs_expect"
953
+ ),
1005
954
[
1006
- # Sanity check at epoch 0 creates a validation dataloader, but validation is
1007
- # checked (and in this case reloaded) every n epochs starting from epoch n-1
1008
- (3 , [0 , 2 , 4 , 6 , 8 ], [0 , 2 , 5 , 8 ]),
1009
- (5 , [0 , 2 , 4 , 6 , 8 ], [0 , 4 , 9 ]),
955
+ # general case where sanity check reloads the dataloaders for validation on current_epoch=0
956
+ (0 , 1 , 1 , [0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ], [0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ], [0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ]),
957
+ (1 , 1 , 1 , [0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ], [1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ], [0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ]),
958
+ # case where check_val_every_n_epoch < reload_dataloaders_every_n_epochs so expected val_reload_epoch
959
+ # and val_step_epoch will be different
960
+ (0 , 1 , 2 , [0 , 2 , 4 , 6 , 8 ], [0 , 2 , 4 , 6 , 8 ], [0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ]),
961
+ (1 , 1 , 2 , [0 , 2 , 4 , 6 , 8 ], [2 , 4 , 6 , 8 ], [0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ]),
962
+ (0 , 3 , 4 , [0 , 4 , 8 ], [2 , 8 ], [2 , 5 , 8 ]),
963
+ (1 , 3 , 4 , [0 , 4 , 8 ], [2 , 8 ], [2 , 5 , 8 ]),
964
+ # case where check_val_every_n_epoch > reload_dataloaders_every_n_epochs so expected val_reload_epoch
965
+ # and val_step_epoch will be same
966
+ (0 , 2 , 1 , [0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ], [1 , 3 , 5 , 7 , 9 ], [1 , 3 , 5 , 7 , 9 ]),
967
+ (1 , 2 , 1 , [0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ], [1 , 3 , 5 , 7 , 9 ], [1 , 3 , 5 , 7 , 9 ]),
968
+ (0 , 3 , 2 , [0 , 2 , 4 , 6 , 8 ], [2 , 5 , 8 ], [2 , 5 , 8 ]),
969
+ (1 , 3 , 2 , [0 , 2 , 4 , 6 , 8 ], [2 , 5 , 8 ], [2 , 5 , 8 ]),
970
+ (0 , 5 , 2 , [0 , 2 , 4 , 6 , 8 ], [4 , 9 ], [4 , 9 ]),
971
+ (1 , 5 , 2 , [0 , 2 , 4 , 6 , 8 ], [4 , 9 ], [4 , 9 ]),
972
+ # case where check_val_every_n_epoch = reload_dataloaders_every_n_epochs so expected val_reload_epoch
973
+ # and val_step_epoch will be same
974
+ (0 , 2 , 2 , [0 , 2 , 4 , 6 , 8 ], [1 , 3 , 5 , 7 , 9 ], [1 , 3 , 5 , 7 , 9 ]),
975
+ (1 , 2 , 2 , [0 , 2 , 4 , 6 , 8 ], [1 , 3 , 5 , 7 , 9 ], [1 , 3 , 5 , 7 , 9 ]),
1010
976
],
1011
977
)
1012
978
def test_dataloaders_load_every_n_epochs_infrequent_val (
1013
- tmpdir , n , train_reload_epochs_expect , val_reload_epochs_expect
979
+ tmpdir ,
980
+ num_sanity_val_steps ,
981
+ check_val_every_n_epoch ,
982
+ reload_dataloaders_every_n_epochs ,
983
+ train_reload_epochs_expect ,
984
+ val_reload_epochs_expect ,
985
+ val_step_epochs_expect ,
1014
986
):
1015
987
"""Test dataloader reload behavior when infrequently checking validation set (via check_val_every_n_epoch)"""
1016
- train_reload_epochs , val_reload_epochs = [], []
988
+ sanity_val_check_epochs , train_reload_epochs , val_reload_epochs = [], [], []
989
+ sanity_val_step_epochs , val_step_epochs = [], []
1017
990
1018
991
class TestModel (BoringModel ):
1019
992
def train_dataloader (self ):
1020
993
train_reload_epochs .append (self .current_epoch )
1021
994
return super ().train_dataloader ()
1022
995
1023
996
def val_dataloader (self ):
1024
- val_reload_epochs .append (self .current_epoch )
997
+ if self .trainer .sanity_checking :
998
+ sanity_val_check_epochs .append (self .current_epoch )
999
+ else :
1000
+ val_reload_epochs .append (self .current_epoch )
1025
1001
return super ().val_dataloader ()
1026
1002
1003
+ def validation_step (self , * args , ** kwargs ):
1004
+ if self .trainer .sanity_checking :
1005
+ sanity_val_step_epochs .append (self .current_epoch )
1006
+ else :
1007
+ val_step_epochs .append (self .current_epoch )
1008
+
1009
+ return super ().validation_step (* args , ** kwargs )
1010
+
1027
1011
model = TestModel ()
1028
1012
1029
1013
trainer = Trainer (
1030
1014
default_root_dir = tmpdir ,
1031
- limit_train_batches = 0.3 ,
1032
- limit_val_batches = 0.3 ,
1033
- check_val_every_n_epoch = n ,
1034
- reload_dataloaders_every_n_epochs = 2 ,
1015
+ limit_train_batches = 1 ,
1016
+ limit_val_batches = 1 ,
1017
+ check_val_every_n_epoch = check_val_every_n_epoch ,
1018
+ reload_dataloaders_every_n_epochs = reload_dataloaders_every_n_epochs ,
1035
1019
max_epochs = 10 ,
1020
+ num_sanity_val_steps = num_sanity_val_steps ,
1036
1021
)
1037
- model .test_dataloader = Mock (wraps = model .test_dataloader )
1038
-
1039
1022
trainer .fit (model )
1040
- trainer .test (model )
1041
1023
1042
1024
# Verify epoch of reloads
1025
+ sanity_val_check_epochs_expect = [0 ] if num_sanity_val_steps else []
1026
+ assert sanity_val_check_epochs == sanity_val_step_epochs == sanity_val_check_epochs_expect
1043
1027
assert train_reload_epochs == train_reload_epochs_expect
1044
1028
assert val_reload_epochs == val_reload_epochs_expect
1045
-
1046
- model .test_dataloader .assert_called_once ()
1029
+ assert val_step_epochs == val_step_epochs_expect
1047
1030
1048
1031
1049
1032
def test_dataloaders_load_every_n_epochs_frequent_val (tmpdir ):
0 commit comments