Commit f9a6573

[App] Expose Run Work Executor (#15561)
1 parent 3ea8903 commit f9a6573

File tree: 25 files changed, +497 / -133 lines

examples/app_multi_node/README.md

Lines changed: 17 additions & 7 deletions
@@ -6,36 +6,46 @@ Lightning makes multi-node training simple by providing a simple interface
 
 You can run the multi-node raw PyTorch by running the following commands.
 
+Here is an example where you spawn your processes yourself.
+
+```bash
+lightning run app train_pytorch.py
+```
+
+or you can use the built-in component for it.
+
 ```bash
-lightning run app app_torch_work.py
+lightning run app train_pytorch_spawn.py
 ```
 
 ## Multi Node with raw PyTorch + Lite
 
 You can run the multi-node raw PyTorch and Lite by running the following commands.
 
 ```bash
-lightning run app app_lite_work.py
+lightning run app train_lite.py
 ```
 
+Using Lite, you retain control over your loops while getting access to all of Lightning's distributed strategies with minimal changes.
+
 ## Multi Node with PyTorch Lightning
 
 Lightning supports running PyTorch Lightning from a script or within a Lightning Work.
 
-### Multi Node PyTorch Lightning Script
+You can either run a script directly
 
 ```bash
-lightning run app app_pl_script.py
+lightning run app train_pl_script.py
 ```
 
-### Multi Node PyTorch Lightning Work
+or run your code within a work.
 
 ```bash
-lightning run app app_pl_work.py
+lightning run app train_pl.py
 ```
 
 ## Multi Node with any frameworks
 
 ```bash
-lightning run app app_generic_work.py
+lightning run app train_any.py
 ```
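
Each of the commands above points at a small app file with the same overall shape: a `LightningWork` whose `run` method receives the multi-node coordinates, wrapped in a multi-node component together with a `CloudCompute`. Below is a minimal sketch of that pattern, assuming the generic `MultiNode` component is importable from `lightning.app.components` like the other components in this commit, with the `run` signature inferred from the `train_any.py` diff further down:

```python
import lightning as L
from lightning.app.components import MultiNode


class AnyDistributedComponent(L.LightningWork):
    def run(
        self,
        main_address: str,
        main_port: int,
        num_nodes: int,
        node_rank: int,
    ):
        # MultiNode hands every node the address/port of the main node plus its own
        # rank, so any framework can bootstrap its process group from these values.
        print(f"ADD YOUR DISTRIBUTED CODE: {main_address} {main_port} {num_nodes} {node_rank}.")


app = L.LightningApp(
    MultiNode(
        AnyDistributedComponent,
        num_nodes=2,
        cloud_compute=L.CloudCompute("gpu"),
    )
)
```

You would launch such a file with `lightning run app train_any.py`, as the README commands show.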

examples/app_multi_node/app_lite_work.py

Lines changed: 0 additions & 59 deletions
This file was deleted.

examples/app_multi_node/app_pl_work.py

Lines changed: 0 additions & 38 deletions
This file was deleted.

examples/app_multi_node/app_generic_work.py renamed to examples/app_multi_node/train_any.py

Lines changed: 1 addition & 2 deletions
@@ -13,11 +13,10 @@ def run(
         print(f"ADD YOUR DISTRIBUTED CODE: {main_address} {main_port} {num_nodes} {node_rank}.")
 
 
-compute = L.CloudCompute("gpu")
 app = L.LightningApp(
     MultiNode(
         AnyDistributedComponent,
         num_nodes=2,
-        cloud_compute=compute,
+        cloud_compute=L.CloudCompute("gpu"),
     )
 )

examples/app_multi_node/train_lite.py

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+import torch
+
+import lightning as L
+from lightning.app.components import LiteMultiNode
+from lightning.lite import LightningLite
+
+
+class LitePyTorchDistributed(L.LightningWork):
+    @staticmethod
+    def run():
+        # 1. Create LightningLite.
+        lite = LightningLite(strategy="ddp", precision="bf16")
+
+        # 2. Prepare distributed model and optimizer.
+        model = torch.nn.Linear(32, 2)
+        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+        model, optimizer = lite.setup(model, optimizer)
+        criterion = torch.nn.MSELoss()
+
+        # 3. Train the model for 50 steps.
+        for step in range(50):
+            model.zero_grad()
+            x = torch.randn(64, 32).to(lite.device)
+            output = model(x)
+            loss = criterion(output, torch.ones_like(output))
+            print(f"global_rank: {lite.global_rank} step: {step} loss: {loss}")
+            lite.backward(loss)
+            optimizer.step()
+
+
+# Run over 2 nodes of 4 x V100
+app = L.LightningApp(
+    LiteMultiNode(
+        LitePyTorchDistributed,
+        cloud_compute=L.CloudCompute("gpu-fast-multi"),  # 4 x V100
+        num_nodes=2,
+    )
+)
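
The loop above trains on freshly sampled random tensors. If you want real batches, Lite can also wrap a `DataLoader`; here is a small sketch of how the `run` body could be extended, assuming `setup_dataloaders` is available on this `LightningLite` version (the dataset and the class name `LitePyTorchDistributedWithData` are illustrative, not part of the commit):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

import lightning as L
from lightning.lite import LightningLite


class LitePyTorchDistributedWithData(L.LightningWork):
    @staticmethod
    def run():
        lite = LightningLite(strategy="ddp", precision="bf16")

        # Illustrative regression data; setup_dataloaders is expected to attach a
        # distributed sampler so each rank sees its own shard and to move batches
        # to the right device.
        dataset = TensorDataset(torch.randn(1024, 32), torch.ones(1024, 2))
        dataloader = lite.setup_dataloaders(DataLoader(dataset, batch_size=64))

        model = torch.nn.Linear(32, 2)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        model, optimizer = lite.setup(model, optimizer)
        criterion = torch.nn.MSELoss()

        for step, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            print(f"global_rank: {lite.global_rank} step: {step} loss: {loss}")
            lite.backward(loss)
            optimizer.step()
```

Such a work would be passed to `LiteMultiNode` exactly like `LitePyTorchDistributed` above.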

examples/app_multi_node/train_pl.py

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+import lightning as L
+from lightning.app.components import PyTorchLightningMultiNode
+from lightning.pytorch.demos.boring_classes import BoringModel
+
+
+class PyTorchLightningDistributed(L.LightningWork):
+    @staticmethod
+    def run():
+        model = BoringModel()
+        trainer = L.Trainer(
+            max_epochs=10,
+            strategy="ddp",
+        )
+        trainer.fit(model)
+
+
+# Run over 2 nodes of 4 x V100
+app = L.LightningApp(
+    PyTorchLightningMultiNode(
+        PyTorchLightningDistributed,
+        num_nodes=2,
+        cloud_compute=L.CloudCompute("gpu-fast-multi"),  # 4 x V100
+    )
+)
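
`BoringModel` is Lightning's demo `LightningModule`; any module of your own can take its place in the work above. Here is a minimal illustrative substitute (the architecture, data, and the name `TinyRegression` are made up for this sketch and are not part of the commit):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

import lightning as L


class TinyRegression(L.LightningModule):
    """Illustrative stand-in for BoringModel: one linear layer trained with MSE."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)

    def train_dataloader(self):
        dataset = TensorDataset(torch.randn(1024, 32), torch.ones(1024, 2))
        return DataLoader(dataset, batch_size=64)
```

Swapping `BoringModel()` for `TinyRegression()` inside `run` leaves the multi-node wiring untouched.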

examples/app_multi_node/app_pl_script.py renamed to examples/app_multi_node/train_pl_script.py

Lines changed: 2 additions & 1 deletion
@@ -2,9 +2,10 @@
 from lightning.app.components import LightningTrainingComponent
 from lightning.app.utilities.packaging.cloud_compute import CloudCompute
 
+# Run over 2 nodes of 4 x V100
 app = L.LightningApp(
     LightningTrainingComponent(
-        "train.py",
+        "pl_boring_script.py",
         num_nodes=2,
         cloud_compute=CloudCompute("gpu-fast-multi"),
     ),
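
The component now launches `pl_boring_script.py` instead of `train.py`. That script is not part of this diff; purely as an illustration of the kind of script `LightningTrainingComponent` runs on every node, it could be as small as the following (contents hypothetical, modeled on the `BoringModel` usage elsewhere in this commit):

```python
# Hypothetical contents of pl_boring_script.py -- not shown in this diff.
import lightning as L
from lightning.pytorch.demos.boring_classes import BoringModel

if __name__ == "__main__":
    trainer = L.Trainer(max_epochs=10, strategy="ddp")
    trainer.fit(BoringModel())
```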

examples/app_multi_node/app_torch_work.py renamed to examples/app_multi_node/train_pytorch.py

Lines changed: 2 additions & 9 deletions
@@ -38,13 +38,6 @@ def distributed_train(local_rank: int, main_address: str, main_port: int, num_no
         loss.backward()
         optimizer.step()
 
-    # 5. Verify all processes have the same weights at the end of training.
-    weight = model.module.weight.clone()
-    torch.distributed.all_reduce(weight)
-    assert torch.equal(model.module.weight, weight / world_size)
-
-    print("Multi Node Distributed Training Done!")
-
 
 class PyTorchDistributed(L.LightningWork):
     def run(
@@ -60,11 +53,11 @@ def run(
         )
 
 
-compute = L.CloudCompute("gpu-fast-multi")  # 4xV100
+# Run over 2 nodes of 4 x V100
 app = L.LightningApp(
     MultiNode(
         PyTorchDistributed,
         num_nodes=2,
-        cloud_compute=compute,
+        cloud_compute=L.CloudCompute("gpu-fast-multi"),  # 4 x V100
    )
 )
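
`train_pytorch.py` is the example the README describes as spawning your own processes: the hunks above only show the end of the training loop and the app wiring, while the part of the file that launches one process per local GPU is not included in the diff. Below is a rough sketch of that launch step, assuming `torch.multiprocessing.spawn` is used and that `run` receives the same coordinates as the other examples (the stub body of `distributed_train` merely stands in for the real function defined earlier in the file):

```python
import torch

import lightning as L


def distributed_train(local_rank: int, main_address: str, main_port: int, num_nodes: int, node_rank: int):
    # Stand-in for the real distributed_train in train_pytorch.py: it would set up
    # the process group against main_address/main_port and run the DDP training loop.
    print(f"node {node_rank} / local_rank {local_rank} -> {main_address}:{main_port} ({num_nodes} nodes)")


class PyTorchDistributed(L.LightningWork):
    def run(
        self,
        main_address: str,
        main_port: int,
        num_nodes: int,
        node_rank: int,
    ):
        # "Spawn your processes yourself": start one process per local GPU
        # (or a single CPU process); each process receives its local rank first.
        nprocs = torch.cuda.device_count() if torch.cuda.is_available() else 1
        torch.multiprocessing.spawn(
            distributed_train,
            args=(main_address, main_port, num_nodes, node_rank),
            nprocs=nprocs,
        )
```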

examples/app_multi_node/train_pytorch_spawn.py

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+import torch
+from torch.nn.parallel.distributed import DistributedDataParallel
+
+import lightning as L
+from lightning.app.components import PyTorchSpawnMultiNode
+
+
+class PyTorchDistributed(L.LightningWork):
+
+    # Note: Only staticmethods are supported for now with `PyTorchSpawnMultiNode`
+    @staticmethod
+    def run(
+        world_size: int,
+        node_rank: int,
+        global_rank: str,
+        local_rank: int,
+    ):
+        # 1. Prepare distributed model
+        model = torch.nn.Linear(32, 2)
+        device = torch.device(f"cuda:{local_rank}") if torch.cuda.is_available() else torch.device("cpu")
+        device_ids = device if torch.cuda.is_available() else None
+        model = DistributedDataParallel(model, device_ids=device_ids).to(device)
+
+        # 2. Prepare loss and optimizer
+        criterion = torch.nn.MSELoss()
+        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+
+        # 3. Train the model for 50 steps.
+        for step in range(50):
+            model.zero_grad()
+            x = torch.randn(64, 32).to(device)
+            output = model(x)
+            loss = criterion(output, torch.ones_like(output))
+            print(f"global_rank: {global_rank} step: {step} loss: {loss}")
+            loss.backward()
+            optimizer.step()
+
+
+# Run over 2 nodes of 4 x V100
+app = L.LightningApp(
+    PyTorchSpawnMultiNode(
+        PyTorchDistributed,
+        num_nodes=2,
+        cloud_compute=L.CloudCompute("gpu-fast-multi"),  # 4 x V100
+    )
+)
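
One small note on the model setup above: PyTorch's `DistributedDataParallel` documents `device_ids` as a list containing the single device of the wrapped module, and modules are typically moved to their device before wrapping. Here is an editorial variant along those lines, offered as a suggestion rather than as what the commit ships:

```python
import torch
from torch.nn.parallel.distributed import DistributedDataParallel


def build_ddp_model(local_rank: int) -> torch.nn.Module:
    # Move the module to its device first, then wrap it. On GPU, device_ids is a
    # one-element list with this process's device index; on CPU it stays None.
    device = torch.device(f"cuda:{local_rank}") if torch.cuda.is_available() else torch.device("cpu")
    model = torch.nn.Linear(32, 2).to(device)
    device_ids = [local_rank] if torch.cuda.is_available() else None
    return DistributedDataParallel(model, device_ids=device_ids)
```

Constructing `DistributedDataParallel` still requires an initialized process group, which in this example is presumably set up before `run` is called.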

0 commit comments