
Bagua training not working with manual optimization #12534

@lorcll

🐛 Bug

The Bagua strategy raises an AttributeError when running with manual optimization.

To Reproduce

import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.automatic_optimization = False  # use manual optimization

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        accelerator='gpu',
        gpus=1,
        strategy="bagua",
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)


if __name__ == "__main__":
    run()
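
For reference, the training_step above is intentionally minimal; a more typical manual-optimization step using the standard Lightning APIs self.optimizers() and self.manual_backward() is sketched below. The error reported here is raised inside the DDP wrapper's forward call, so the minimal step body is already enough to trigger it.

    def training_step(self, batch, batch_idx):
        # sketch of a fuller manual-optimization step, for illustration only
        opt = self.optimizers()
        opt.zero_grad()
        loss = self(batch).sum()
        self.manual_backward(loss)  # replaces loss.backward() under manual optimization
        opt.step()
        self.log("train_loss", loss)
        return {"loss": loss}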

Expected behavior

Training and testing complete without errors. Instead, the following traceback is raised:

Traceback (most recent call last):
  File "/home/wizard/test_bug/test_bug.py", line 70, in <module>                                             
    run()                                                                                                    
  File "/home/wizard/test_bug/test_bug.py", line 65, in run                                                  
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)                               
  File "/home/wizard/miniconda/envs/bug_test/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py
", line 771, in fit                                                                                          
    self._call_and_handle_interrupt(                                                                         
  File "/home/wizard/miniconda/envs/bug_test/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py
", line 722, in _call_and_handle_interrupt
    return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
  File "/home/wizard/miniconda/envs/bug_test/lib/python3.9/site-packages/pytorch_lightning/strategies/launche
rs/subprocess_script.py", line 93, in launch
    return function(*args, **kwargs)
  File "/home/wizard/miniconda/envs/bug_test/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py
", line 812, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/home/wizard/miniconda/envs/bug_test/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py
", line 1237, in _run
    results = self._run_stage()
  File "/home/wizard/miniconda/envs/bug_test/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py
", line 1324, in _run_stage 
    return self._run_train()
  File "/home/wizard/miniconda/envs/bug_test/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py
", line 1354, in _run_train 
    self.fit_loop.run()
  File "/home/wizard/miniconda/envs/bug_test/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", li
ne 204, in run 
   self.advance(*args, **kwargs)
  File "/home/wizard/miniconda/envs/bug_test/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py", line 269, in advance
    self._outputs = self.epoch_loop.run(self._data_fetcher)
  File "/home/wizard/miniconda/envs/bug_test/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 204, in run
    self.advance(*args, **kwargs)
  File "/home/wizard/miniconda/envs/bug_test/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 208, in advance
    batch_output = self.batch_loop.run(batch, batch_idx)
  File "/home/wizard/miniconda/envs/bug_test/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 204, in run
    self.advance(*args, **kwargs)
  File "/home/wizard/miniconda/envs/bug_test/lib/python3.9/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 90, in advance
    outputs = self.manual_loop.run(split_batch, batch_idx)
  File "/home/wizard/miniconda/envs/bug_test/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 204, in run
    self.advance(*args, **kwargs)
  File "/home/wizard/miniconda/envs/bug_test/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/manual_loop.py", line 115, in advance
    training_step_output = self.trainer._call_strategy_hook("training_step", *step_kwargs.values())
  File "/home/wizard/miniconda/envs/bug_test/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1766, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/home/wizard/miniconda/envs/bug_test/lib/python3.9/site-packages/pytorch_lightning/strategies/ddp.py", line 344, in training_step
    return self.model(*args, **kwargs)
  File "/home/wizard/miniconda/envs/bug_test/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/wizard/miniconda/envs/bug_test/lib/python3.9/site-packages/bagua/torch_api/data_parallel/distributed.py", line 171, in forward
    output = self.module(*inputs, **kwargs)
  File "/home/wizard/miniconda/envs/bug_test/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1128, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/home/wizard/miniconda/envs/bug_test/lib/python3.9/site-packages/pytorch_lightning/overrides/base.py", line 88, in forward
    trainer.model.require_backward_grad_sync = False  # type: ignore[assignment]
  File "/home/wizard/miniconda/envs/bug_test/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1233, in __setattr__
    object.__setattr__(self, name, value)
AttributeError: can't set attribute

Environment

* CUDA:
        - GPU:
                - Tesla T4
        - available:         True
        - version:           11.3
* Packages:
        - numpy:             1.22.3
        - pyTorch_debug:     False
        - pyTorch_version:   1.11.0
        - pytorch-lightning: 1.6.0
        - tqdm:              4.63.1
* System:
        - OS:                Linux
        - architecture:
                - 64bit
                - 
        - processor:         x86_64
        - python:            3.9.7
        - version:           #1 SMP Fri Apr 30 09:52:02 PDT 2021
        

Additional context

This is because the Bagua DDP wrapper redefines attributes such as require_backward_grad_sync on itself rather than delegating writes to the wrapped model; the original module is accessible through self.model.inner rather than self.model. A minimal sketch of the failure mode follows.
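
For illustration, here is a minimal, self-contained sketch of the failure mode (the class names are hypothetical stand-ins, not Bagua's actual classes): if the wrapper exposes require_backward_grad_sync only as a read-only property forwarding to the inner module, then assigning to it on the wrapper, as pytorch_lightning/overrides/base.py line 88 does, fails inside torch.nn.Module.__setattr__, while assigning through .inner succeeds.

import torch


class InnerModule(torch.nn.Module):
    # stand-in for the wrapped model that actually owns the writable flag
    def __init__(self):
        super().__init__()
        self.require_backward_grad_sync = True


class ToyBaguaLikeWrapper(torch.nn.Module):
    # hypothetical wrapper: exposes the flag only as a read-only property
    def __init__(self, inner):
        super().__init__()
        self.inner = inner

    @property
    def require_backward_grad_sync(self):
        return self.inner.require_backward_grad_sync


wrapper = ToyBaguaLikeWrapper(InnerModule())
try:
    wrapper.require_backward_grad_sync = False  # mirrors overrides/base.py line 88
except AttributeError as err:
    print(err)  # AttributeError: can't set attribute (message varies by Python version)

wrapper.inner.require_backward_grad_sync = False  # writing through the inner module works
print(wrapper.require_backward_grad_sync)  # False
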
Suggested fix: add a check on the strategy type.

  1. In pytorch_lightning/strategies/ddp.py line 365:

    def post_training_step(self):
        if not self.lightning_module.automatic_optimization:
            if self.strategy_name == "bagua":
                self.model.inner.require_backward_grad_sync = True
            else:
                self.model.require_backward_grad_sync = True

  2. In pytorch_lightning/overrides/base.py line 87:

                if not pl_module.automatic_optimization:
                    if trainer.strategy.strategy_name == "bagua":
                        trainer.model.inner.require_backward_grad_sync = False  # type: ignore[assignment]
                    else:
                        trainer.model.require_backward_grad_sync = False  # type: ignore[assignment]

I tested this solution and get the same results on a single GPU as with DDP on the unmodified code (I ran that comparison only to validate the results; I also verified that it runs on multiple GPUs). The change is also compatible with the other implemented strategies.

cc @awaelchli @wangraying @akihironitta
