@@ -134,6 +134,121 @@ you can also pass in an :doc:`datamodules <../data/datamodule>` that have overri
# test (pass in datamodule)
trainer.test(datamodule=dm)
+
+Test with Multiple DataLoaders
+==============================
+
+When you need to evaluate your model on multiple test datasets in a single ``trainer.test()`` call (e.g., different
+domains, conditions, or evaluation scenarios), PyTorch Lightning supports multiple test dataloaders out of the box.
+
+To use multiple test dataloaders, simply return a list of dataloaders from your ``test_dataloader()`` method:
+
+.. code-block:: python
+
+    class LitModel(L.LightningModule):
+        def test_dataloader(self):
+            return [
+                DataLoader(clean_test_dataset, batch_size=32),
+                DataLoader(noisy_test_dataset, batch_size=32),
+                DataLoader(adversarial_test_dataset, batch_size=32),
+            ]
+
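+You can also skip ``test_dataloader()`` entirely and pass a list of dataloaders directly to ``trainer.test()``.
+A minimal sketch, reusing the datasets from the example above:
+
+.. code-block:: python
+
+    clean_loader = DataLoader(clean_test_dataset, batch_size=32)
+    noisy_loader = DataLoader(noisy_test_dataset, batch_size=32)
+
+    # each dataloader in the list is evaluated in turn
+    trainer.test(model, dataloaders=[clean_loader, noisy_loader])
+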
+When using multiple test dataloaders, your ``test_step`` method **must** include a ``dataloader_idx`` parameter:
+
+.. code-block:: python
+
+    def test_step(self, batch, batch_idx, dataloader_idx: int = 0):
+        x, y = batch
+        y_hat = self(x)
+        loss = F.cross_entropy(y_hat, y)
+
+        # Use dataloader_idx to handle different test scenarios
+        return {'test_loss': loss}
+
+Logging Metrics Per Dataloader
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Lightning automatically keeps logged metrics separate for each dataloader:
+
+.. code-block:: python
+
+    def test_step(self, batch, batch_idx, dataloader_idx: int = 0):
+        x, y = batch
+        y_hat = self(x)
+        loss = F.cross_entropy(y_hat, y)
+        acc = (y_hat.argmax(dim=1) == y).float().mean()
+
+        # Lightning automatically appends the "/dataloader_idx_X" suffix
+        self.log('test_loss', loss, add_dataloader_idx=True)
+        self.log('test_acc', acc, add_dataloader_idx=True)
+
+        return loss
+
+This will create metrics like ``test_loss/dataloader_idx_0``, ``test_loss/dataloader_idx_1``, etc. Note that
+``add_dataloader_idx=True`` is the default, so the suffix is added even if you omit the argument.
+
+For more meaningful metric names, you can disable the automatic suffix and name the metrics yourself; in that case,
+make sure each name is unique across dataloaders.
+
+.. code-block:: python
+
+    def test_step(self, batch, batch_idx, dataloader_idx: int = 0):
+        x, y = batch
+        y_hat = self(x)
+        loss = F.cross_entropy(y_hat, y)
+        acc = (y_hat.argmax(dim=1) == y).float().mean()
+
+        # Define meaningful names for each dataloader
+        dataloader_names = {0: "clean", 1: "noisy", 2: "adversarial"}
+        dataset_name = dataloader_names.get(dataloader_idx, f"dataset_{dataloader_idx}")
+
+        # Log with custom names and disable the automatic suffix
+        self.log(f'test_loss_{dataset_name}', loss, add_dataloader_idx=False)
+        self.log(f'test_acc_{dataset_name}', acc, add_dataloader_idx=False)
+
+Processing Entire Datasets Per Dataloader
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+To perform calculations on the entire test dataset for each dataloader (e.g., computing overall metrics, creating
+visualizations), accumulate results during ``test_step`` and process them in ``on_test_epoch_end``:
+
+.. code-block:: python
+
+    class LitModel(L.LightningModule):
+        def __init__(self):
+            super().__init__()
+            # Store outputs per dataloader
+            self.test_outputs = {}
+
+        def test_step(self, batch, batch_idx, dataloader_idx: int = 0):
+            x, y = batch
+            y_hat = self(x)
+            loss = F.cross_entropy(y_hat, y)
+
+            # Initialize and store results for this dataloader
+            if dataloader_idx not in self.test_outputs:
+                self.test_outputs[dataloader_idx] = {'predictions': [], 'targets': []}
+            self.test_outputs[dataloader_idx]['predictions'].append(y_hat)
+            self.test_outputs[dataloader_idx]['targets'].append(y)
+            return loss
+
+        def on_test_epoch_end(self):
+            for dataloader_idx, outputs in self.test_outputs.items():
+                # Concatenate all predictions and targets for this dataloader
+                all_predictions = torch.cat(outputs['predictions'], dim=0)
+                all_targets = torch.cat(outputs['targets'], dim=0)
+
+                # Calculate metrics on the entire dataset, log, and create visualizations
+                overall_accuracy = (all_predictions.argmax(dim=1) == all_targets).float().mean()
+                self.log(f'test_overall_acc_dataloader_{dataloader_idx}', overall_accuracy)
+                self._save_results(all_predictions, all_targets, dataloader_idx)
+
+            self.test_outputs.clear()
+
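+A caveat for distributed runs: under a strategy such as DDP, each process accumulates only its own shard of the data,
+so the metrics above would be per-process. A minimal sketch of one way to gather everything first with
+``self.all_gather`` (same ``test_outputs`` structure as above; note that ``DistributedSampler`` may duplicate a few
+samples to even out the shards):
+
+.. code-block:: python
+
+    def on_test_epoch_end(self):
+        for dataloader_idx, outputs in self.test_outputs.items():
+            predictions = torch.cat(outputs['predictions'], dim=0)
+            targets = torch.cat(outputs['targets'], dim=0)
+
+            # gather the shards from every process (adds a leading world_size dim)
+            if self.trainer.world_size > 1:
+                predictions = self.all_gather(predictions).flatten(0, 1)
+                targets = self.all_gather(targets).flatten(0, 1)
+
+            overall_accuracy = (predictions.argmax(dim=1) == targets).float().mean()
+            self.log(f'test_overall_acc_dataloader_{dataloader_idx}', overall_accuracy)
+
+        self.test_outputs.clear()
+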
+.. note::
+    When using multiple test dataloaders, ``trainer.test()`` returns a list of results, one for each dataloader:
+
+    .. code-block:: python
+
+        results = trainer.test(model)
+        print(f"Results from {len(results)} test dataloaders:")
+        for i, result in enumerate(results):
+            print(f"Dataloader {i}: {result}")
+
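+The same list-of-dataloaders pattern works from a datamodule. A minimal sketch, with a hypothetical ``EvalDataModule``
+and the datasets from the examples above:
+
+.. code-block:: python
+
+    class EvalDataModule(L.LightningDataModule):
+        def test_dataloader(self):
+            # return a list, exactly as in LightningModule.test_dataloader
+            return [
+                DataLoader(clean_test_dataset, batch_size=32),
+                DataLoader(noisy_test_dataset, batch_size=32),
+            ]
+
+    trainer.test(model, datamodule=EvalDataModule())
+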
----------
**********