
Commit 18ceb67

SkafteNicki and Borda authored
Add documentation on multi test dataloaders (#21215)

add documentation on multi test dataloaders
Co-authored-by: Jirka Borovec <[email protected]>

1 parent e9368a7 commit 18ceb67

File tree

1 file changed

+115
-0
lines changed


docs/source-pytorch/common/evaluation_intermediate.rst

Lines changed: 115 additions & 0 deletions
@@ -134,6 +134,121 @@ you can also pass in an :doc:`datamodules <../data/datamodule>` that have overri
# test (pass in datamodule)
trainer.test(datamodule=dm)


Test with Multiple DataLoaders
==============================

When you need to evaluate your model on multiple test datasets in a single ``trainer.test()`` call (e.g., different
domains, conditions, or evaluation scenarios), PyTorch Lightning supports multiple test dataloaders out of the box.

To use multiple test dataloaders, simply return a list of dataloaders from your ``test_dataloader()`` method:

.. code-block:: python

    class LitModel(L.LightningModule):
        def test_dataloader(self):
            return [
                DataLoader(clean_test_dataset, batch_size=32),
                DataLoader(noisy_test_dataset, batch_size=32),
                DataLoader(adversarial_test_dataset, batch_size=32),
            ]

When using multiple test dataloaders, your ``test_step`` method **must** include a ``dataloader_idx`` parameter:

.. code-block:: python

    def test_step(self, batch, batch_idx, dataloader_idx: int = 0):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)

        # Use dataloader_idx to handle different test scenarios
        return {"test_loss": loss}

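Because all dataloaders share a single ``test_step``, any per-scenario logic has to branch on ``dataloader_idx``. A minimal sketch of that dispatch pattern, with hypothetical scenario names and thresholds (none of these identifiers come from Lightning's API):

.. code-block:: python

    # Hypothetical per-scenario configuration, keyed by dataloader index.
    SCENARIOS = {
        0: {"name": "clean", "acc_threshold": 0.95},
        1: {"name": "noisy", "acc_threshold": 0.80},
        2: {"name": "adversarial", "acc_threshold": 0.50},
    }


    def scenario_for(dataloader_idx: int) -> dict:
        """Resolve settings for a dataloader, falling back to a generic entry."""
        return SCENARIOS.get(
            dataloader_idx,
            {"name": f"dataset_{dataloader_idx}", "acc_threshold": 0.0},
        )

A ``test_step`` could then call ``scenario_for(dataloader_idx)`` to pick thresholds or preprocessing per dataset; the fallback keeps the code working if extra dataloaders are added later.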
Logging Metrics Per Dataloader
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Lightning provides automatic support for logging metrics per dataloader:

.. code-block:: python

    def test_step(self, batch, batch_idx, dataloader_idx: int = 0):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = (y_hat.argmax(dim=1) == y).float().mean()

        # Lightning automatically adds a "/dataloader_idx_X" suffix
        self.log("test_loss", loss, add_dataloader_idx=True)
        self.log("test_acc", acc, add_dataloader_idx=True)

        return loss

This will create metrics like ``test_loss/dataloader_idx_0``, ``test_loss/dataloader_idx_1``, etc.

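Downstream code can rely on that suffix convention to regroup results per dataloader. A small sketch, assuming a hypothetical flat metrics dict shaped like the suffixed keys above:

.. code-block:: python

    # Hypothetical flat metrics dict, shaped like Lightning's suffixed keys.
    metrics = {
        "test_loss/dataloader_idx_0": 0.21,
        "test_acc/dataloader_idx_0": 0.94,
        "test_loss/dataloader_idx_1": 0.48,
        "test_acc/dataloader_idx_1": 0.81,
    }


    def group_by_dataloader(flat: dict) -> dict:
        """Regroup "name/dataloader_idx_N" keys into {N: {name: value}}."""
        grouped: dict = {}
        for key, value in flat.items():
            name, _, suffix = key.partition("/dataloader_idx_")
            idx = int(suffix) if suffix else 0
            grouped.setdefault(idx, {})[name] = value
        return grouped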
For more meaningful metric names, you can use custom naming; just make sure that the individual names are
unique across dataloaders:

.. code-block:: python

    def test_step(self, batch, batch_idx, dataloader_idx: int = 0):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = (y_hat.argmax(dim=1) == y).float().mean()

        # Define meaningful names for each dataloader
        dataloader_names = {0: "clean", 1: "noisy", 2: "adversarial"}
        dataset_name = dataloader_names.get(dataloader_idx, f"dataset_{dataloader_idx}")

        # Log with custom names and disable the automatic suffix
        self.log(f"test_loss_{dataset_name}", loss, add_dataloader_idx=False)
        self.log(f"test_acc_{dataset_name}", acc, add_dataloader_idx=False)

Processing Entire Datasets Per Dataloader
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

To perform calculations on the entire test dataset for each dataloader (e.g., computing overall metrics, creating
visualizations), accumulate results during ``test_step`` and process them in ``on_test_epoch_end``:

.. code-block:: python

    class LitModel(L.LightningModule):
        def __init__(self):
            super().__init__()
            # Store outputs per dataloader
            self.test_outputs = {}

        def test_step(self, batch, batch_idx, dataloader_idx: int = 0):
            x, y = batch
            y_hat = self(x)
            loss = F.cross_entropy(y_hat, y)

            # Initialize and store results
            if dataloader_idx not in self.test_outputs:
                self.test_outputs[dataloader_idx] = {"predictions": [], "targets": []}
            self.test_outputs[dataloader_idx]["predictions"].append(y_hat)
            self.test_outputs[dataloader_idx]["targets"].append(y)
            return loss

        def on_test_epoch_end(self):
            for dataloader_idx, outputs in self.test_outputs.items():
                # Concatenate all predictions and targets for this dataloader
                all_predictions = torch.cat(outputs["predictions"], dim=0)
                all_targets = torch.cat(outputs["targets"], dim=0)

                # Calculate metrics on the entire dataset, log and create visualizations
                overall_accuracy = (all_predictions.argmax(dim=1) == all_targets).float().mean()
                self.log(f"test_overall_acc_dataloader_{dataloader_idx}", overall_accuracy)
                self._save_results(all_predictions, all_targets, dataloader_idx)

            self.test_outputs.clear()

.. note::
    When using multiple test dataloaders, ``trainer.test()`` returns a list of results, one for each dataloader:

    .. code-block:: python

        results = trainer.test(model)
        print(f"Results from {len(results)} test dataloaders:")
        for i, result in enumerate(results):
            print(f"Dataloader {i}: {result}")

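Each entry in that list is a dict of the metrics logged for the corresponding dataloader, so the list can be post-processed like any other Python data. A sketch of aggregating across dataloaders, assuming the custom metric names from the logging example above (the numeric values are illustrative):

.. code-block:: python

    # Illustrative shape of trainer.test(model) output with three dataloaders.
    results = [
        {"test_loss_clean": 0.20, "test_acc_clean": 0.95},
        {"test_loss_noisy": 0.55, "test_acc_noisy": 0.78},
        {"test_loss_adversarial": 1.30, "test_acc_adversarial": 0.42},
    ]

    # Collect every accuracy metric and average across dataloaders.
    accuracies = [v for result in results for k, v in result.items() if k.startswith("test_acc")]
    mean_acc = sum(accuracies) / len(accuracies)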
----------


**********