Skip to content

Conversation

tchaton
Copy link
Contributor

@tchaton tchaton commented Nov 6, 2022

What does this PR do?

This PR exposes RunWorkExecutor and provides some default one for MultiNode Component.

Example with PyTorch. Processes are automatically spawned for the users.

import torch
from torch.nn.parallel.distributed import DistributedDataParallel

import lightning as L
from lightning.app.components import PyTorchSpawnMultiNode


class PyTorchDistributed(L.LightningWork):

    # Note: Only staticmethod are support for now with `PyTorchSpawnMultiNode`
    @staticmethod
    def run(
        world_size: int,
        node_rank: int,
        global_rank: str,
        local_rank: int,
    ):
        # 1. Prepare distributed model
        model = torch.nn.Linear(32, 2)
        device = torch.device(f"cuda:{local_rank}") if torch.cuda.is_available() else torch.device("cpu")
        device_ids = device if torch.cuda.is_available() else None
        model = DistributedDataParallel(model, device_ids=device_ids).to(device)

        # 2. Prepare loss and optimizer
        criterion = torch.nn.MSELoss()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

        # 3. Train the model for 50 steps.
        for step in range(50):
            model.zero_grad()
            x = torch.randn(64, 32).to(device)
            output = model(x)
            loss = criterion(output, torch.ones_like(output))
            print(f"global_rank: {global_rank} step: {step} loss: {loss}")
            loss.backward()
            optimizer.step()


compute = L.CloudCompute("gpu-fast-multi")  # 4 x V100
app = L.LightningApp(
    PyTorchSpawnMultiNode(
        PyTorchDistributed,
        num_nodes=2,
        cloud_compute=compute,
    )
)

Example with LightningLite.

import torch

import lightning as L
from lightning.app.components import LiteMultiNode
from lightning.lite import LightningLite


class LitePyTorchDistributed(L.LightningWork):
    @staticmethod
    def run():
        # 1. Create LightningLite.
        lite = LightningLite(strategy="ddp", precision="bf16")

        # 2. Prepare distributed model and optimizer.
        model = torch.nn.Linear(32, 2)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        model, optimizer = lite.setup(model, optimizer)
        criterion = torch.nn.MSELoss()

        # 3. Train the model for 50 steps.
        for step in range(50):
            model.zero_grad()
            x = torch.randn(64, 32).to(lite.device)
            output = model(x)
            loss = criterion(output, torch.ones_like(output))
            print(f"global_rank: {lite.global_rank} step: {step} loss: {loss}")
            lite.backward(loss)
            optimizer.step()


app = L.LightningApp(
    LiteMultiNode(
        LitePyTorchDistributed,
        cloud_compute=L.CloudCompute("gpu-fast-multi"),  # 4 x V100,
        num_nodes=2,
    )
)
import lightning as L
from lightning.app.components import PyTorchLightningMultiNode
from lightning.pytorch.demos.boring_classes import BoringModel


class PyTorchLightningDistributed(L.LightningWork):
    @staticmethod
    def run():
        model = BoringModel()
        trainer = L.Trainer(
            max_epochs=10,
            strategy="ddp",
        )
        trainer.fit(model)


compute = L.CloudCompute("gpu-fast-multi")  # 4 x V100
app = L.LightningApp(
    PyTorchLightningMultiNode(
        PyTorchLightningDistributed,
        num_nodes=2,
        cloud_compute=compute,
    )
)

Fixes #<issue_number>

Does your PR introduce any breaking changes? If yes, please list them.

Before submitting

  • Was this discussed/approved via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or minor internal changes/refactors)

PR review

Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet-list:

  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

Did you have fun?

Make sure you had fun coding 🙃

cc @Borda

@mergify mergify bot added the ready PRs ready to be merged label Nov 8, 2022
Copy link
Member

@justusschock justusschock left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm, a few minor comments

@Borda Borda requested review from Borda and manskx and removed request for otaj, rohitgr7 and kaushikb11 November 8, 2022 10:05
@tchaton tchaton requested review from awaelchli and removed request for manskx November 8, 2022 10:47
@tchaton tchaton enabled auto-merge (squash) November 8, 2022 12:15
@tchaton tchaton merged commit f9a6573 into master Nov 8, 2022
@tchaton tchaton deleted the expose_work_runner branch November 8, 2022 12:55
Borda pushed a commit that referenced this pull request Nov 8, 2022
lexierule pushed a commit that referenced this pull request Nov 10, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
app (removed) Generic label for Lightning App package ready PRs ready to be merged
Projects
No open projects
Status: Done
Development

Successfully merging this pull request may close these issues.

7 participants