
Introduce a mechanism to set the epoch on the sampler in LightningLite #14636

@awaelchli

Description


🚀 Feature

We need a mechanism to set the epoch on the distributed sampler via .set_epoch().

Motivation

To correctly handle shuffling with the DistributedSampler in DDP, the PyTorch user would normally call

sampler.set_epoch(epoch)

in their training loop. PL handles this automatically for the user, but in Lite the training loop is owned by the user, so the sampler epoch has to be communicated in a different way.
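For reference, the plain-PyTorch pattern looks roughly like this; a minimal sketch (num_replicas and rank are passed explicitly only so the snippet runs outside a launched DDP job):

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(100).float())
# In a real DDP run, DistributedSampler infers rank/world size from the process group.
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
dataloader = DataLoader(dataset, sampler=sampler, batch_size=10)

for epoch in range(3):
    # Without this call, every epoch reuses the shuffle order of epoch 0.
    sampler.set_epoch(epoch)
    for batch in dataloader:
        ...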

Pitch

Provide a strategy-agnostic API to set the epoch. The idea is that you don't need to change your code or add boilerplate conditional logic to handle the sampler when you switch from DDP to single-device training or vice versa.

Before:

....
train_dataloader, val_dataloader = self.setup_dataloaders(train_dataloader, val_dataloader)

for epoch in range(num_epochs):

    # Beginning of a new epoch

    # Boilerplate code
    if isinstance(train_dataloader.sampler, DistributedSampler):
        train_dataloader.sampler.set_epoch(epoch)


    for idx, data in enumerate(train_dataloader):
        ...

After:

....
train_dataloader, val_dataloader = self.setup_dataloaders(train_dataloader, val_dataloader)

for epoch in range(num_epochs):

    # Beginning of a new epoch

    # No-op unless the dataloader uses a DistributedSampler (i.e. under DDP)
    self.set_epoch(epoch, train_dataloader)

    for idx, data in enumerate(train_dataloader):
        ...
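A possible implementation inside Lite would be small. The method name and placement are exactly what this issue proposes; the sketch below only illustrates the intended no-op behaviour and is not an existing API:

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def set_epoch(self, epoch: int, dataloader: DataLoader) -> None:
    # Hypothetical helper on LightningLite (name and signature are part of this proposal).
    sampler = getattr(dataloader, "sampler", None)
    if isinstance(sampler, DistributedSampler):
        sampler.set_epoch(epoch)
    # Any other sampler, or no sampler at all: nothing to do, silently no-op.

Whether custom samplers that also expose a set_epoch() method should be forwarded to as well is an open design question.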

Alternatives

Don't introduce this; leave it to the user to handle the boilerplate.

Additional context

PyTorch docs for DistributedSampler.set_epoch
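For context on why the call matters: with shuffle=True, DistributedSampler derives its permutation deterministically from seed + epoch, roughly like this (simplified sketch of the PyTorch behaviour):

import torch

def _shuffled_indices(dataset_len: int, seed: int, epoch: int) -> list:
    # Simplified from DistributedSampler.__iter__ with shuffle=True: the permutation
    # is a deterministic function of (seed, epoch). If set_epoch() is never called,
    # every epoch re-uses the epoch-0 permutation on every rank.
    g = torch.Generator()
    g.manual_seed(seed + epoch)
    return torch.randperm(dataset_len, generator=g).tolist()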


If you enjoy Lightning, check out our other projects! ⚡

  • Metrics: Machine learning metrics for distributed, scalable PyTorch applications.

  • Lite: Enables pure PyTorch users to scale their existing code on any kind of device while retaining full control over their own loops and optimization logic.

  • Flash: The fastest way to get a Lightning baseline! A collection of tasks for fast prototyping, baselining, fine-tuning, and solving problems with deep learning.

  • Bolts: Pretrained SOTA Deep Learning models, callbacks, and more for research and production with PyTorch Lightning and PyTorch.

  • Lightning Transformers: Flexible interface for high-performance research using SOTA Transformers leveraging PyTorch Lightning, Transformers, and Hydra.

cc @Borda @carmocca @justusschock @awaelchli
