
Multi-node training with FSDP results in weird behaviour #18008

@srikanthsrnvs

Bug description

Here is my training script for LLaMA 65B:

import os
import sys
import time
from functools import partial
from pathlib import Path

import lightning as L
import torch
from lightning.fabric.strategies import FSDPStrategy
from prepare import (calculate_training_steps, get_batch, load_datasets,
                     prepare, validate)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy, size_based_auto_wrap_policy
from torch.nn.utils.init import skip_init
from torch.distributed.fsdp import (
   FullyShardedDataParallel,
   CPUOffload,
)

# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from lit_llama.model import Block, LLaMA, LLaMAConfig
from lit_llama.utils import save_model_checkpoint

# compilation fails as it does not support torch.complex64 for RoPE
# compile = False

beta1 = 0.9
beta2 = 0.95
grad_clip = 1.0



def train(
    output_dir: str,
    data_dir: str,
    eval_interval: int = 2000,
    eval_iterations: int = 200,
    log_interval: int = 10,
    learning_rate: float = 6e-4,
    batch_size: int = 1,
    epochs: int = 3,
    context_length: int = 4096,
    model_size: str = "65B",
    weight_decay: float = 1e-1,
):
    prepped_data_dir = prepare(
        input_file=data_dir,
        tokenizer_path=f"/scratch/data/models/llama-{model_size.lower()}/tokenizer.model"
    )
    train_data, val_data = load_datasets(prepped_data_dir)
    train_steps = calculate_training_steps(train_data, batch_size, epochs, context_length)

    print("Total number of training iterations: ", train_steps)

    ############################# TRAINING PREP #########################################
    auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
    strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy, activation_checkpointing=Block, limit_all_gathers=True, cpu_offload=True, precision="16")

    fabric = L.Fabric(precision="bf16-mixed", strategy=strategy)

    # fabric.launch()
    fabric.seed_everything(1337 + fabric.global_rank)

    if fabric.global_rank == 0:
        os.makedirs(output_dir, exist_ok=True)
    ######################################################################################

    config = LLaMAConfig.from_name(model_size)
    config.block_size = context_length

    with fabric.init_module(empty_init=True):
        print("Loading model...")
        model = LLaMA(config).bfloat16()
        # Main machine, no need to load the weights again
        print("Loading checkpoint...")
        checkpoint = torch.load(f'/scratch/data/models/llama-{model_size.lower()}/lit-llama.pth')
        print("Loading state dict...")
        model.load_state_dict(checkpoint, strict=False)

    fabric.barrier()
    print("Setting up modules...")
    model = fabric.setup_module(model)

    print("Setting up the optimizers...")
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False)
    optimizer = fabric.setup_optimizers(optimizer)

    loop(
        fabric,
        model,
        optimizer,
        train_data,
        val_data,
        eval_interval,
        eval_iterations,
        output_dir,
        batch_size,
        log_interval,
        train_steps
    )

############################### MAIN TRAINING LOOP ########################################
def loop(
    fabric,
    model,
    optimizer,
    train_data,
    val_data,
    eval_interval,
    eval_iterations,
    output_dir,
    batch_size,
    log_interval,
    train_steps
):
    """The training loop.

    Loosely based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT.
    """
    iter_num = 0

    print("Beginning training loop ---------->")
    while True:
        # TODO: add learning rate scheduling

        # evaluate the loss on train/val sets and write checkpoints
        if iter_num > 0 and iter_num % eval_interval == 0:
            val_loss = validate(fabric, model, val_data, eval_iterations)
            fabric.print(f"step {iter_num}: val loss {val_loss:.4f}")
            fabric.print(f"Saving checkpoint to {output_dir}")
            save_model_checkpoint(fabric, model, os.path.join(output_dir, f"iter-{iter_num:06d}-ckpt.pth"))

        t0 = time.time()

        input_ids, targets = get_batch(
            fabric,
            train_data,
            block_size=model.config.block_size,  # type: ignore[union-attr,arg-type]
            batch_size=batch_size
        )
        logits = model(input_ids)
        loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)

        fabric.backward(loss)

        # TODO: Gradient clipping
        # if grad_clip != 0.0:
        #     fabric.clip_gradients(model, optimizer, max_norm=grad_clip)

        optimizer.step()
        optimizer.zero_grad()

        dt = time.time() - t0
        if iter_num % log_interval == 0:
            fabric.print(f"iter {iter_num}: loss {loss.item():.4f}, time: {dt*1000:.2f}ms")
        iter_num += 1

        if iter_num > train_steps:
            save_model_checkpoint(fabric, model, os.path.join(output_dir, f"iter-{iter_num:06d}-ckpt.pth"))
            break

I run this script using:

lightning run model \
    --main-port=29415 \
    --node-rank=0 \
    --main-address=IP_ADDRESS \
    --accelerator=cuda \
    --devices=8 \
    --num-nodes=2 \
    train.py \
        --ARGS...
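
and on the second machine, the same command with only the node rank changed:

lightning run model \
    --main-port=29415 \
    --node-rank=1 \
    --main-address=IP_ADDRESS \
    --accelerator=cuda \
    --devices=8 \
    --num-nodes=2 \
    train.py \
        --ARGS...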

With the two machines launched with different node ranks, I'm seeing two bugs, although it might just be something I'm overlooking:

  1. When training the 30B model, the weights are initialized properly on the main node, but on the child node all of the weights are dumped onto GPU 1, which runs out of memory. No other GPUs are used. This happens in the LLaMA(config) step.
  2. When training the 65B model, the weights are initialized on node 1, but that node only has 8 GPUs with 80 GB each (640 GB total), which is not enough to load LLaMA, so I get an OOM again. I need the model sharded across both the GPUs and the nodes, yet it appears to load a full copy of LLaMA on each node rather than sharding a single copy across the nodes (see the diagnostic sketch below).

Is there something I'm missing, or is this a bug? Thanks!
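
For what it's worth, here is a rough diagnostic snippet (my own addition, not part of the script above) that I would place immediately after model = fabric.setup_module(model) to check whether the parameters really end up sharded across all 16 ranks. It only uses the standard torch.cuda memory counters and Fabric's rank attributes, with torch and fabric coming from the script above:

# Sharding sanity check, placed right after `model = fabric.setup_module(model)`.
# With FSDP I would expect each of the 16 ranks to report roughly 1/16 of the
# parameters and a similar amount of allocated memory; one rank holding
# (almost) everything would match the behaviour described above.
# Note: the exact numel reported per rank depends on FSDP internals
# (flat params vs. use_orig_params), so treat the memory counters as the
# primary signal.
local_numel = sum(p.numel() for p in model.parameters())
alloc_gb = torch.cuda.memory_allocated() / 1e9
peak_gb = torch.cuda.max_memory_allocated() / 1e9

fabric.barrier()
# Plain print (not fabric.print) so that every rank reports, not only rank 0.
print(
    f"[rank {fabric.global_rank}/{fabric.world_size}] "
    f"local params: {local_numel:,} | allocated: {alloc_gb:.1f} GB | peak: {peak_gb:.1f} GB",
    flush=True,
)

If one rank reports far more allocated memory than the others, or if each node's total adds up to a full model copy, that would line up with the behaviour described above.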

What version are you seeing the problem on?

v2.0

How to reproduce the bug

See the script and launch commands above.

Error messages and logs

CUDA out of memory (OOM).

Environment

Current environment
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0):
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):

More info

No response

cc @carmocca @justusschock @awaelchli

Labels

bug (Something isn't working), fabric (lightning.fabric.Fabric), ver: 2.0.x
