@@ -126,6 +126,9 @@ def _run_power_scaling(
126
126
trainer : "pl.Trainer" , model : "pl.LightningModule" , new_size : int , batch_arg_name : str , max_trials : int
127
127
) -> int :
128
128
"""Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered."""
129
+ # this flag records whether a batch size was run successfully before an OOM error was hit
130
+ # if so, we exit on OOM; otherwise we keep downscaling (within max_trials) since no working batch size has been found yet
131
+ any_success = False
129
132
for _ in range (max_trials ):
130
133
garbage_collection_cuda ()
131
134
@@ -137,22 +140,28 @@ def _run_power_scaling(
137
140
trainer .tuner ._run (model )
138
141
# Double in size
139
142
new_size , changed = _adjust_batch_size (trainer , batch_arg_name , factor = 2.0 , desc = "succeeded" )
143
+
144
+ if not changed :
145
+ break
146
+
147
+ # Force the train and val dataloaders to reset as the batch size has changed
148
+ trainer .reset_train_dataloader (model )
149
+ trainer .reset_val_dataloader (model )
150
+ any_success = True
140
151
except RuntimeError as exception :
141
152
# Only these errors should trigger an adjustment
142
153
if is_oom_error (exception ):
143
154
# If we fail in power mode, halve the size and stop only if an earlier trial succeeded
144
155
garbage_collection_cuda ()
145
156
new_size , _ = _adjust_batch_size (trainer , batch_arg_name , factor = 0.5 , desc = "failed" )
146
- break
157
+ # Force the train and val dataloaders to reset as the batch size has changed
158
+ trainer .reset_train_dataloader (model )
159
+ trainer .reset_val_dataloader (model )
160
+ if any_success :
161
+ break
147
162
else :
148
163
raise # some other error not memory related
149
164
150
- if changed :
151
- # Force the train dataloader to reset as the batch size has changed
152
- trainer .reset_train_dataloader (model )
153
- trainer .reset_val_dataloader (model )
154
- else :
155
- break
156
165
return new_size
157
166
158
167
@@ -189,13 +198,13 @@ def _run_binsearch_scaling(
189
198
else :
190
199
new_size , changed = _adjust_batch_size (trainer , batch_arg_name , factor = 2.0 , desc = "succeeded" )
191
200
192
- if changed :
193
- # Force the train dataloader to reset as the batch size has changed
194
- trainer .reset_train_dataloader (model )
195
- trainer .reset_val_dataloader (model )
196
- else :
201
+ if not changed :
197
202
break
198
203
204
+ # Force the train and val dataloaders to reset as the batch size has changed
205
+ trainer .reset_train_dataloader (model )
206
+ trainer .reset_val_dataloader (model )
207
+
199
208
except RuntimeError as exception :
200
209
# Only these errors should trigger an adjustment
201
210
if is_oom_error (exception ):
@@ -204,6 +213,11 @@ def _run_binsearch_scaling(
204
213
high = new_size
205
214
midval = (high + low ) // 2
206
215
new_size , _ = _adjust_batch_size (trainer , batch_arg_name , value = midval , desc = "failed" )
216
+
217
+ # Force the train and val dataloaders to reset as the batch size has changed
218
+ trainer .reset_train_dataloader (model )
219
+ trainer .reset_val_dataloader (model )
220
+
207
221
if high - low <= 1 :
208
222
break
209
223
else :
0 commit comments