@awaelchli awaelchli commented Nov 19, 2023

What does this PR do?

This PR is a temporary fix for an issue where the DDP strategies that call set_world_ranks set the rank_zero_only.rank attribute only for the fabric/pytorch utilities, but NOT for the lightning_utilities package. As a result, logs appear on ranks > 0 with spawn-based strategies.

The decision to outsource these utilities into a separate package made all of this very brittle: we now need to maintain two globals instead of one. Along with this fix, I will open an issue requesting that these utilities be moved back into the Lightning package so that we only need to maintain one global variable (rank_zero_only.rank).
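To illustrate the two-globals problem, here is a minimal, self-contained sketch. It does not use the actual Lightning or lightning_utilities code. All names in it (make_rank_zero_only, framework_rank_zero_only, utilities_rank_zero_only, set_world_ranks_buggy, set_world_ranks_fixed) are hypothetical stand-ins, and it only assumes that each package carries its own decorator with its own rank attribute.

```python
import functools


def make_rank_zero_only():
    """Build an independent decorator that carries its own `rank` attribute."""

    def rank_zero_only(fn):
        @functools.wraps(fn)
        def wrapped(*args, **kwargs):
            # Only run the wrapped function when this decorator believes it is on rank 0.
            if rank_zero_only.rank == 0:
                return fn(*args, **kwargs)
            return None

        return wrapped

    rank_zero_only.rank = 0  # default value before any strategy sets the rank
    return rank_zero_only


# Two separate globals, standing in for the two packages (hypothetical names).
framework_rank_zero_only = make_rank_zero_only()
utilities_rank_zero_only = make_rank_zero_only()


@framework_rank_zero_only
def framework_info(message):
    return message


@utilities_rank_zero_only
def utilities_info(message):
    return message


def set_world_ranks_buggy(global_rank):
    # Updates only one of the two globals, so the other still thinks it is rank 0.
    framework_rank_zero_only.rank = global_rank


def set_world_ranks_fixed(global_rank):
    # Keeps both globals in sync.
    framework_rank_zero_only.rank = utilities_rank_zero_only.rank = global_rank


if __name__ == "__main__":
    set_world_ranks_buggy(1)
    print(framework_info("hello"), utilities_info("hello"))
    set_world_ranks_fixed(1)
    print(framework_info("hello"), utilities_info("hello"))
```

In this toy model, a process on rank 1 that goes through the buggy setter still emits messages through the second decorator, which corresponds to the duplicated logs described above. Setting both attributes together silences them on ranks > 0.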

Minimal repro (observe that the info and warning outputs are duplicated):

import torch
from lightning.pytorch import LightningModule, Trainer
from torch.utils.data import DataLoader, Dataset


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(max_epochs=1, devices=2, strategy="ddp_spawn", accelerator="cpu")
    trainer.fit(model, train_data)


if __name__ == "__main__":
    run()

Discovered while debugging CI flakiness. This fix should also resolve the timeouts in some of the ddp-spawn tests, such as test_result_reduce_ddp.


📚 Documentation preview 📚: https://pytorch-lightning--19030.org.readthedocs.build/en/19030/

cc @Borda @tchaton @carmocca @justusschock @awaelchli

@awaelchli awaelchli added the bug (Something isn't working), priority: 0 (High priority task), and strategy: ddp (DistributedDataParallel) labels Nov 19, 2023
@github-actions github-actions bot added the fabric (lightning.fabric.Fabric) and pl (Generic label for PyTorch Lightning package) labels and removed the strategy: ddp (DistributedDataParallel) label Nov 19, 2023
@awaelchli awaelchli added this to the 2.1.x milestone Nov 19, 2023

github-actions bot commented Nov 19, 2023

⚡ Required checks status: All passing 🟢

Groups summary

🟢 pytorch_lightning: Tests workflow
Check ID Status
pl-cpu (macOS-11, lightning, 3.8, 1.12, oldest) success
pl-cpu (macOS-11, lightning, 3.9, 1.12) success
pl-cpu (macOS-11, lightning, 3.10, 1.13) success
pl-cpu (macOS-11, lightning, 3.10, 2.0) success
pl-cpu (macOS-11, lightning, 3.10, 2.1) success
pl-cpu (ubuntu-20.04, lightning, 3.8, 1.12, oldest) success
pl-cpu (ubuntu-20.04, lightning, 3.9, 1.12) success
pl-cpu (ubuntu-20.04, lightning, 3.10, 1.13) success
pl-cpu (ubuntu-20.04, lightning, 3.10, 2.0) success
pl-cpu (ubuntu-20.04, lightning, 3.10, 2.1) success
pl-cpu (windows-2022, lightning, 3.8, 1.12, oldest) success
pl-cpu (windows-2022, lightning, 3.9, 1.12) success
pl-cpu (windows-2022, lightning, 3.10, 1.13) success
pl-cpu (windows-2022, lightning, 3.10, 2.0) success
pl-cpu (windows-2022, lightning, 3.10, 2.1) success
pl-cpu (macOS-11, pytorch, 3.8, 1.13) success
pl-cpu (ubuntu-20.04, pytorch, 3.8, 1.13) success
pl-cpu (windows-2022, pytorch, 3.8, 1.13) success
pl-cpu (macOS-12, pytorch, 3.11, 2.0) success
pl-cpu (macOS-12, pytorch, 3.11, 2.1) success
pl-cpu (ubuntu-22.04, pytorch, 3.11, 2.0) success
pl-cpu (ubuntu-22.04, pytorch, 3.11, 2.1) success
pl-cpu (windows-2022, pytorch, 3.11, 2.0) success
pl-cpu (windows-2022, pytorch, 3.11, 2.1) success

These checks are required after the changes to src/lightning/fabric/strategies/ddp.py, src/lightning/fabric/strategies/fsdp.py, src/lightning/pytorch/strategies/ddp.py, src/lightning/pytorch/strategies/fsdp.py, tests/tests_pytorch/conftest.py, tests/tests_pytorch/models/test_torchscript.py.

🟢 pytorch_lightning: Azure GPU
Check ID Status
pytorch-lightning (GPUs) (testing Lightning | latest) success
pytorch-lightning (GPUs) (testing PyTorch | latest) success

These checks are required after the changes to src/lightning/pytorch/strategies/ddp.py, src/lightning/pytorch/strategies/fsdp.py, tests/tests_pytorch/conftest.py, tests/tests_pytorch/models/test_torchscript.py, src/lightning/fabric/strategies/ddp.py, src/lightning/fabric/strategies/fsdp.py.

🟢 pytorch_lightning: Benchmarks
Check ID Status
lightning.Benchmarks success

These checks are required after the changes to src/lightning/fabric/strategies/ddp.py, src/lightning/fabric/strategies/fsdp.py, src/lightning/pytorch/strategies/ddp.py, src/lightning/pytorch/strategies/fsdp.py.

🟢 fabric: Docs
Check ID Status
docs-make (fabric, doctest) success
docs-make (fabric, html) success

These checks are required after the changes to src/lightning/fabric/strategies/ddp.py, src/lightning/fabric/strategies/fsdp.py.

🟢 pytorch_lightning: Docs
Check ID Status
docs-make (pytorch, doctest) success
docs-make (pytorch, html) success

These checks are required after the changes to src/lightning/pytorch/strategies/ddp.py, src/lightning/pytorch/strategies/fsdp.py.

🟢 lightning_fabric: CPU workflow
Check ID Status
fabric-cpu (macOS-11, lightning, 3.8, 1.12, oldest) success
fabric-cpu (macOS-11, lightning, 3.9, 1.12) success
fabric-cpu (macOS-11, lightning, 3.10, 1.13) success
fabric-cpu (macOS-11, lightning, 3.10, 2.0) success
fabric-cpu (macOS-11, lightning, 3.11, 2.1) success
fabric-cpu (ubuntu-20.04, lightning, 3.8, 1.12, oldest) success
fabric-cpu (ubuntu-20.04, lightning, 3.9, 1.12) success
fabric-cpu (ubuntu-20.04, lightning, 3.10, 1.13) success
fabric-cpu (ubuntu-20.04, lightning, 3.10, 2.0) success
fabric-cpu (ubuntu-20.04, lightning, 3.11, 2.1) success
fabric-cpu (windows-2022, lightning, 3.8, 1.12, oldest) success
fabric-cpu (windows-2022, lightning, 3.9, 1.12) success
fabric-cpu (windows-2022, lightning, 3.10, 1.13) success
fabric-cpu (windows-2022, lightning, 3.10, 2.0) success
fabric-cpu (windows-2022, lightning, 3.11, 2.1) success
fabric-cpu (macOS-11, fabric, 3.8, 1.13) success
fabric-cpu (ubuntu-20.04, fabric, 3.8, 1.13) success
fabric-cpu (windows-2022, fabric, 3.8, 1.13) success
fabric-cpu (macOS-12, fabric, 3.11, 2.0) success
fabric-cpu (macOS-12, fabric, 3.11, 2.1) success
fabric-cpu (ubuntu-22.04, fabric, 3.11, 2.0) success
fabric-cpu (ubuntu-22.04, fabric, 3.11, 2.1) success
fabric-cpu (windows-2022, fabric, 3.11, 2.0) success
fabric-cpu (windows-2022, fabric, 3.11, 2.1) success

These checks are required after the changes to src/lightning/fabric/strategies/ddp.py, src/lightning/fabric/strategies/fsdp.py, tests/tests_fabric/conftest.py.

🟢 lightning_fabric: Azure GPU
Check ID Status
lightning-fabric (GPUs) (testing Fabric | latest) success
lightning-fabric (GPUs) (testing Lightning | latest) success

These checks are required after the changes to src/lightning/fabric/strategies/ddp.py, src/lightning/fabric/strategies/fsdp.py, tests/tests_fabric/conftest.py.

🟢 mypy
Check ID Status
mypy success

These checks are required after the changes to src/lightning/fabric/strategies/ddp.py, src/lightning/fabric/strategies/fsdp.py, src/lightning/pytorch/strategies/ddp.py, src/lightning/pytorch/strategies/fsdp.py.

🟢 install
Check ID Status
install-pkg (ubuntu-22.04, app, 3.8) success
install-pkg (ubuntu-22.04, app, 3.11) success
install-pkg (ubuntu-22.04, fabric, 3.8) success
install-pkg (ubuntu-22.04, fabric, 3.11) success
install-pkg (ubuntu-22.04, pytorch, 3.8) success
install-pkg (ubuntu-22.04, pytorch, 3.11) success
install-pkg (ubuntu-22.04, lightning, 3.8) success
install-pkg (ubuntu-22.04, lightning, 3.11) success
install-pkg (ubuntu-22.04, notset, 3.8) success
install-pkg (ubuntu-22.04, notset, 3.11) success
install-pkg (macOS-12, app, 3.8) success
install-pkg (macOS-12, app, 3.11) success
install-pkg (macOS-12, fabric, 3.8) success
install-pkg (macOS-12, fabric, 3.11) success
install-pkg (macOS-12, pytorch, 3.8) success
install-pkg (macOS-12, pytorch, 3.11) success
install-pkg (macOS-12, lightning, 3.8) success
install-pkg (macOS-12, lightning, 3.11) success
install-pkg (macOS-12, notset, 3.8) success
install-pkg (macOS-12, notset, 3.11) success
install-pkg (windows-2022, app, 3.8) success
install-pkg (windows-2022, app, 3.11) success
install-pkg (windows-2022, fabric, 3.8) success
install-pkg (windows-2022, fabric, 3.11) success
install-pkg (windows-2022, pytorch, 3.8) success
install-pkg (windows-2022, pytorch, 3.11) success
install-pkg (windows-2022, lightning, 3.8) success
install-pkg (windows-2022, lightning, 3.11) success
install-pkg (windows-2022, notset, 3.8) success
install-pkg (windows-2022, notset, 3.11) success

These checks are required after the changes to src/lightning/fabric/strategies/ddp.py, src/lightning/fabric/strategies/fsdp.py, src/lightning/pytorch/strategies/ddp.py, src/lightning/pytorch/strategies/fsdp.py.


Thank you for your contribution! 💜

Note
This comment is automatically generated and updates for 60 minutes every 180 seconds. If you have any other questions, contact carmocca for help.

@awaelchli awaelchli added the strategy: ddp (DistributedDataParallel) label Nov 19, 2023

codecov bot commented Nov 19, 2023

Codecov Report

Merging #19030 (fe3ccf1) into master (4f4c890) will decrease coverage by 30%.
The diff coverage is 100%.

Additional details and impacted files
@@            Coverage Diff             @@
##           master   #19030      +/-   ##
==========================================
- Coverage      84%      54%     -30%     
==========================================
  Files         443      438       -5     
  Lines       36154    36060      -94     
==========================================
- Hits        30260    19461   -10799     
- Misses       5894    16599   +10705     

@awaelchli awaelchli added the fun (Staff contributions outside working hours - to differentiate from the "community" label) label Nov 20, 2023

@carmocca carmocca left a comment

This is why I always suggest that only the fabric/pl imports are used (e.g. #16178)

Unfortunately, PyCharm seems to prefer importing utilities when using auto-import, so it's easy to miss

@mergify mergify bot added the ready (PRs ready to be merged) label Nov 20, 2023
@carmocca carmocca merged commit f652e6c into master Nov 20, 2023
@carmocca carmocca deleted the tests/flaky-reduce-ddp branch November 20, 2023 15:49
Borda pushed a commit that referenced this pull request Dec 19, 2023
lantiga pushed a commit that referenced this pull request Dec 20, 2023