scale_batch_size does not work anymore? #13696

@cschell

Description

🐛 Bug

After an update of PyTorch Lightning, the batch size scaling in trainer.tune no longer works: it exhausts the maximum number of configured trials without actually probing the training loop, and therefore reports batch sizes that are far too large and cause CUDA OOM errors during the subsequent fit.

After some code diving, I think this is due to a bug in the method _run_power_scaling, which sets trainer.fit_loop.global_step = 0 instead of trainer.fit_loop.epoch_loop.global_step = 0; the latter seems to be required after a recent refactor of FitLoop.
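
If that reading is correct, the fix should amount to resetting the counter on the epoch loop instead. A minimal sketch of the suspected change (the helper name _reset_fit_progress is made up for illustration; in 1.6.4 the reset happens inline in _run_power_scaling):

# sketch of the suspected fix in pytorch_lightning/tuner/batch_size_scaling.py
def _reset_fit_progress(trainer):  # hypothetical helper, for illustration only
    # what 1.6.4 does -- after the FitLoop refactor this assignment no longer
    # resets the counter the loop actually reads, so every trial finishes
    # immediately without running a training step:
    #     trainer.fit_loop.global_step = 0

    # what appears to be required now that global_step lives on the epoch loop:
    trainer.fit_loop.epoch_loop.global_step = 0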

Edit:

The script below works as expected with pytorch-lightning==1.5.9.

To Reproduce

import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer, LightningDataModule


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringDatamodule(LightningDataModule):
    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size

    def train_dataloader(self):
        return DataLoader(RandomDataset(1_000, 999_999), batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(RandomDataset(1_000, 999_999), batch_size=self.batch_size)


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(1_000, 10_000)
        self.batch_size = "unset"

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


def run():
    model = BoringModel().to("cuda:0")
    trainer = Trainer(
            default_root_dir=os.getcwd(),
            limit_train_batches=1,
            limit_val_batches=1,
            limit_test_batches=1,
            num_sanity_val_steps=0,
            max_epochs=1,
            enable_model_summary=False,
            auto_scale_batch_size=True,  # enable the batch size finder in trainer.tune
            gpus=1
    )

    # the model can be fitted with batch_size=1000 on a NVIDIA 2080Ti without OOM errors
    datamodule = BoringDatamodule(batch_size=1000)

    # on 1.6.4 this exhausts all 10 trials without running a single training
    # step, so the batch size it settles on was never actually probed
    trainer.tune(model=model, datamodule=datamodule, scale_batch_size_kwargs={"init_val": 1_000, "max_trials": 10})

    print(f"fitting with {datamodule.batch_size=} (fails on a NVIDIA 2080Ti)")
    trainer.fit(model, datamodule=datamodule)


if __name__ == "__main__":
    run()

Expected behavior

trainer.tune finds a batch size that does not provoke a CUDA OOM error.
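
Concretely, on 1.5.9 the run above behaves like this (a sketch; I am assuming trainer.tune returns its results as a dict keyed by tuner name, which may differ across versions):

result = trainer.tune(
    model=model,
    datamodule=datamodule,
    scale_batch_size_kwargs={"init_val": 1_000, "max_trials": 10},
)
# each trial runs a training step, backs off on OOM, and the largest
# working batch size is written back to the datamodule in place
print(result["scale_batch_size"])
print(datamodule.batch_size)  # same value as above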

Environment

  • CUDA:
    - GPU: NVIDIA GeForce RTX 2080 Ti
    - available: True
    - version: 11.3
  • Packages:
    - numpy: 1.21.6
    - pyTorch_debug: False
    - pyTorch_version: 1.12.0+cu113
    - pytorch-lightning: 1.6.4
    - tqdm: 4.64.0
  • System:
    - OS: Linux
    - architecture: 64bit, ELF
    - processor: x86_64
    - python: 3.8.10
    - version: 127-Ubuntu SMP Wed May 18 14:30:56 UTC 2022

cc @akihironitta @Borda @rohitgr7

Labels: bug (Something isn't working), tuner