🚀 Feature
We need a mechanism to set the epoch on the distributed sampler via `.set_epoch()`.
Motivation
To correctly handle shuffling with the `DistributedSampler` in DDP, the PyTorch user would normally call `sampler.set_epoch(epoch)` in their training loop. PL handles this automatically for the user, but in Lite the training loop is owned by the user, so the sampler has to be signaled in a different way.
Pitch
Provide a strategy-agnostic API to set the epoch. The idea is that you don't need to change your code or add boilerplate conditional logic to handle the sampler when you switch from DDP to a single device or vice versa.
Before
```python
...
train_dataloader, val_dataloader = self.setup_dataloaders(train_dataloader, val_dataloader)

for epoch in range(num_epochs):
    # Beginning of a new epoch
    # Boilerplate code
    if isinstance(train_dataloader.sampler, DistributedSampler):
        train_dataloader.sampler.set_epoch(epoch)
    for idx, data in enumerate(train_dataloader):
        ...
```
Now:
```python
...
train_dataloader, val_dataloader = self.setup_dataloaders(train_dataloader, val_dataloader)

for epoch in range(num_epochs):
    # Beginning of a new epoch
    # This is a no-op if not using a distributed sampler (DDP)
    self.set_epoch(epoch, train_dataloader)
    for idx, data in enumerate(train_dataloader):
        ...
```
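A minimal sketch of how such a helper could behave, written as a free function for illustration (the name `set_epoch` and its `(epoch, dataloader)` signature are part of this proposal, not an existing Lite API): it forwards the epoch to the dataloader's sampler when that sampler is a `DistributedSampler`, and does nothing otherwise.

```python
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


def set_epoch(epoch: int, dataloader: DataLoader) -> None:
    """Hypothetical helper mirroring the proposed ``self.set_epoch(epoch, train_dataloader)``.

    Under DDP this calls ``sampler.set_epoch(epoch)`` so shuffling differs per epoch;
    on a single device it is a no-op, so the training loop needs no conditional logic.
    """
    sampler = getattr(dataloader, "sampler", None)
    if isinstance(sampler, DistributedSampler):
        sampler.set_epoch(epoch)
```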
Alternatives
Don't introduce this; leave it to the user to handle this boilerplate code.
Additional context
If you enjoy Lightning, check out our other projects! ⚡
- Metrics: Machine learning metrics for distributed, scalable PyTorch applications.
- Lite: Enables pure PyTorch users to scale their existing code on any kind of device while retaining full control over their own loops and optimization logic.
- Flash: The fastest way to get a Lightning baseline! A collection of tasks for fast prototyping, baselining, fine-tuning, and solving problems with deep learning.
- Bolts: Pretrained SOTA Deep Learning models, callbacks, and more for research and production with PyTorch Lightning and PyTorch.
- Lightning Transformers: Flexible interface for high-performance research using SOTA Transformers leveraging PyTorch Lightning, Transformers, and Hydra.