
Commit e88e0fe

tchaton authored and Borda committed
[App] Fixed Multi Node and add examples (#15557)
(cherry picked from commit 8202331)
1 parent ea136f5 commit e88e0fe

File tree

15 files changed: +298 -57 lines changed

examples/app_multi_node/README.md

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
# Lightning & Multi Node Training

Lightning makes multi-node training simple by providing a straightforward interface to orchestrate compute and data.

## Multi Node with raw PyTorch

You can run the raw PyTorch multi-node example with the following command:

```bash
lightning run app app_torch_work.py
```

## Multi Node with raw PyTorch + Lite

You can run the raw PyTorch + Lite multi-node example with the following command:

```bash
lightning run app app_lite_work.py
```

## Multi Node with PyTorch Lightning

Lightning supports running PyTorch Lightning from a script or within a Lightning Work.

### Multi Node PyTorch Lightning Script

```bash
lightning run app app_pl_script.py
```

### Multi Node PyTorch Lightning Work

```bash
lightning run app app_pl_work.py
```

## Multi Node with any framework

```bash
lightning run app app_generic_work.py
```
Lines changed: 4 additions & 3 deletions
@@ -1,4 +1,4 @@
-import lightning.app as L
+import lightning as L
 from lightning.app.components import MultiNode
 
 
@@ -7,16 +7,17 @@ def run(
         self,
         main_address: str,
         main_port: int,
+        num_nodes: int,
         node_rank: int,
     ):
-        print(f"ADD YOUR DISTRIBUTED CODE: {main_address} {main_port} {node_rank}")
+        print(f"ADD YOUR DISTRIBUTED CODE: {main_address} {main_port} {num_nodes} {node_rank}.")
 
 
 compute = L.CloudCompute("gpu")
 app = L.LightningApp(
     MultiNode(
         AnyDistributedComponent,
-        nodes=2,
+        num_nodes=2,
         cloud_compute=compute,
     )
 )
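
Assembled from the two hunks above, the updated file reads roughly as follows; note that the file path is not shown in this view, and the `class AnyDistributedComponent(L.LightningWork):` line is inferred from the hunk context rather than visible in the diff:

```python
import lightning as L
from lightning.app.components import MultiNode


class AnyDistributedComponent(L.LightningWork):
    def run(
        self,
        main_address: str,
        main_port: int,
        num_nodes: int,
        node_rank: int,
    ):
        # MultiNode hands each node the master address/port, the total node
        # count, and its own rank: enough to bootstrap any distributed setup.
        print(f"ADD YOUR DISTRIBUTED CODE: {main_address} {main_port} {num_nodes} {node_rank}.")


compute = L.CloudCompute("gpu")
app = L.LightningApp(
    MultiNode(
        AnyDistributedComponent,
        num_nodes=2,
        cloud_compute=compute,
    )
)
```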
Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
import os

import torch

import lightning as L
from lightning.app.components import MultiNode
from lightning.lite import LightningLite


def distributed_train(lite: LightningLite):
    # 1. Prepare distributed model and optimizer
    model = torch.nn.Linear(32, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    model, optimizer = lite.setup(model, optimizer)
    criterion = torch.nn.MSELoss()

    # 2. Train the model for 50 steps.
    for step in range(50):
        model.zero_grad()
        x = torch.randn(64, 32).to(lite.device)
        output = model(x)
        loss = criterion(output, torch.ones_like(output))
        print(f"global_rank: {lite.global_rank} step: {step} loss: {loss}")
        lite.backward(loss)
        optimizer.step()

    # 3. Verify all processes have the same weights at the end of training.
    weight = model.module.weight.clone()
    torch.distributed.all_reduce(weight)
    assert torch.equal(model.module.weight, weight / lite.world_size)

    print("Multi Node Distributed Training Done!")


class PyTorchDistributed(L.LightningWork):
    def run(
        self,
        main_address: str,
        main_port: int,
        num_nodes: int,
        node_rank: int,
    ):
        os.environ["MASTER_ADDR"] = main_address
        os.environ["MASTER_PORT"] = str(main_port)
        os.environ["NODE_RANK"] = str(node_rank)

        lite = LightningLite(accelerator="auto", devices="auto", strategy="ddp_spawn", num_nodes=num_nodes)
        lite.launch(function=distributed_train)


compute = L.CloudCompute("gpu-fast-multi")  # 4xV100
app = L.LightningApp(
    MultiNode(
        PyTorchDistributed,
        num_nodes=2,
        cloud_compute=compute,
    )
)
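
Assuming this new file is the `app_lite_work.py` referenced in the README above (the file path is not visible in this view), it is launched like the other examples:

```bash
lightning run app app_lite_work.py
```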
File renamed without changes.
Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
import os

import lightning as L
from lightning.app.components import MultiNode
from lightning.pytorch.demos.boring_classes import BoringModel


class PyTorchLightningDistributed(L.LightningWork):
    def run(
        self,
        main_address: str,
        main_port: int,
        num_nodes: int,
        node_rank: int,
    ):
        os.environ["MASTER_ADDR"] = main_address
        os.environ["MASTER_PORT"] = str(main_port)
        os.environ["NODE_RANK"] = str(node_rank)

        model = BoringModel()
        trainer = L.Trainer(
            max_epochs=10,
            devices="auto",
            accelerator="auto",
            num_nodes=num_nodes,
            strategy="ddp_spawn",  # Only spawn-based strategies are supported for now.
        )
        trainer.fit(model)


compute = L.CloudCompute("gpu-fast-multi")  # 4xV100
app = L.LightningApp(
    MultiNode(
        PyTorchLightningDistributed,
        num_nodes=2,
        cloud_compute=compute,
    )
)
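
If this file is the README's `app_pl_work.py` (again, the path is not shown in this view), launching it fits `BoringModel` with DDP spawn across 2 nodes of 4 GPUs each:

```bash
lightning run app app_pl_work.py
```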
Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
import torch
from torch.nn.parallel.distributed import DistributedDataParallel

import lightning as L
from lightning.app.components import MultiNode


def distributed_train(local_rank: int, main_address: str, main_port: int, num_nodes: int, node_rank: int, nprocs: int):
    # 1. Set up the distributed environment.
    global_rank = local_rank + node_rank * nprocs
    world_size = num_nodes * nprocs

    if torch.distributed.is_available() and not torch.distributed.is_initialized():
        torch.distributed.init_process_group(
            "nccl" if torch.cuda.is_available() else "gloo",
            rank=global_rank,
            world_size=world_size,
            init_method=f"tcp://{main_address}:{main_port}",
        )

    # 2. Prepare distributed model
    model = torch.nn.Linear(32, 2)
    device = torch.device(f"cuda:{local_rank}") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)  # Move parameters to the target device before wrapping in DDP.
    device_ids = [local_rank] if torch.cuda.is_available() else None  # DDP expects a list of device ids.
    model = DistributedDataParallel(model, device_ids=device_ids)

    # 3. Prepare loss and optimizer
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # 4. Train the model for 50 steps.
    for step in range(50):
        model.zero_grad()
        x = torch.randn(64, 32).to(device)
        output = model(x)
        loss = criterion(output, torch.ones_like(output))
        print(f"global_rank: {global_rank} step: {step} loss: {loss}")
        loss.backward()
        optimizer.step()

    # 5. Verify all processes have the same weights at the end of training.
    weight = model.module.weight.clone()
    torch.distributed.all_reduce(weight)
    assert torch.equal(model.module.weight, weight / world_size)

    print("Multi Node Distributed Training Done!")


class PyTorchDistributed(L.LightningWork):
    def run(
        self,
        main_address: str,
        main_port: int,
        num_nodes: int,
        node_rank: int,
    ):
        nprocs = torch.cuda.device_count() if torch.cuda.is_available() else 1
        torch.multiprocessing.spawn(
            distributed_train, args=(main_address, main_port, num_nodes, node_rank, nprocs), nprocs=nprocs
        )


compute = L.CloudCompute("gpu-fast-multi")  # 4xV100
app = L.LightningApp(
    MultiNode(
        PyTorchDistributed,
        num_nodes=2,
        cloud_compute=compute,
    )
)
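
This spawn-based example plausibly corresponds to the README's `app_torch_work.py` (the path is not visible in this view); if so, it is run with:

```bash
lightning run app app_torch_work.py
```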

examples/app_multi_node/bare/.gitignore

Lines changed: 0 additions & 2 deletions
This file was deleted.

examples/app_multi_node/bare/multi_node.py

Lines changed: 0 additions & 36 deletions
This file was deleted.
File renamed without changes.

examples/app_multi_node/train.py

Lines changed: 3 additions & 3 deletions
@@ -1,7 +1,7 @@
-from pytorch_lightning import Trainer
-from pytorch_lightning.demos.boring_classes import BoringModel
+import lightning as L
+from lightning.pytorch.demos.boring_classes import BoringModel
 
 if __name__ == "__main__":
     model = BoringModel()
-    trainer = Trainer(max_epochs=1)
+    trainer = L.Trainer(max_epochs=1)
     trainer.fit(model)
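
Per the README, this plain training script is presumably executed across nodes by `app_pl_script.py` (not shown in this commit view) rather than being launched directly:

```bash
lightning run app app_pl_script.py
```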
