
Commit 2b0bd92

dmahan93 authored and bclyang committed
Add context parallelism support
1 parent a6d6af0 commit 2b0bd92

15 files changed: +406 additions, -69 deletions


megatron/data/data_utils.py

Lines changed: 12 additions & 2 deletions
@@ -531,8 +531,12 @@ def build_train_valid_test_data_loaders(neox_args):
     else:
         pipe_load = True

-    # Data loader only on rank 0 of each model parallel group.
-    if mpu.get_model_parallel_rank() == 0 and pipe_load:
+    # Data loader only on rank 0 of each model and context parallel group.
+    if (
+        mpu.get_model_parallel_rank() == 0
+        and pipe_load
+        and mpu.get_context_parallel_rank() == 0
+    ):
         # Number of train/valid/test samples.
         if neox_args.train_iters is not None:
             train_iters = neox_args.train_iters
@@ -671,11 +675,17 @@ def build_train_valid_test_data_loaders(neox_args):
         # broadcast globally instead of just the model parallel group.
         torch.distributed.broadcast(flags, src=0)
     else:
+        # The same data should be used for the model parallel and context parallel groups
         torch.distributed.broadcast(
            flags,
            mpu.get_model_parallel_src_rank(),
            group=mpu.get_model_parallel_group(),
        )
+        torch.distributed.broadcast(
+            flags,
+            mpu.get_context_parallel_src_rank(),
+            group=mpu.get_context_parallel_group(),
+        )
     neox_args.do_train = flags[0].item()
     neox_args.do_valid = flags[1].item()
     neox_args.do_test = flags[2].item()
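
For orientation, a minimal single-process sketch of the flag distribution above: only the source rank of each model x context parallel group builds the data loaders, and the do_train/do_valid/do_test flags then reach every rank via one broadcast over the model parallel group followed by one over the context parallel group. The 2x2 grid and flag values below are illustrative, not taken from the commit.

    import itertools

    MP, CP = 2, 2          # illustrative model / context parallel sizes
    flags = {}             # (mp_rank, cp_rank) -> (do_train, do_valid, do_test)

    # Only the rank with mp_rank == 0 and cp_rank == 0 builds the loaders and
    # therefore knows the real flags.
    flags[(0, 0)] = (1, 1, 0)

    # Broadcast over the model parallel group (cp_rank fixed at 0) ...
    for mp in range(MP):
        flags[(mp, 0)] = flags[(0, 0)]

    # ... then over each context parallel group (mp_rank fixed).
    for mp, cp in itertools.product(range(MP), range(CP)):
        flags[(mp, cp)] = flags[(mp, 0)]

    assert all(v == (1, 1, 0) for v in flags.values())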

megatron/initialize.py

Lines changed: 11 additions & 5 deletions
@@ -158,16 +158,20 @@ def _initialize_distributed(neox_args):
     # Setup 3D topology.
     pp = neox_args.pipe_parallel_size if neox_args.pipe_parallel_size >= 1 else 1
     mp = neox_args.model_parallel_size if neox_args.model_parallel_size >= 1 else 1
+    cp = neox_args.context_parallel_size if neox_args.context_parallel_size >= 1 else 1
+    assert (
+        neox_args.world_size % (pp * mp * cp) == 0
+    ), f"world_size={neox_args.world_size}, pp={pp}, mp={mp}, cp={cp}"
     assert (
         neox_args.world_size % (pp * mp) == 0
     ), f"world_size={neox_args.world_size}, pp={pp}, mp={mp}"
+    # The data parallel ranks will be used for context parallel
+    # to piggyback the gradient all-reduce
     dp = neox_args.world_size // (pp * mp)
+    assert dp % cp == 0
+    from deepspeed.runtime.pipe.topology import ProcessTopology

-    from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology
-
-    # this does pipe on the most outside, then data, then model.
-    # PipeModelDataParallelTopology is just a wrapper over ProcessTopology that predefines this order.
-    topo = PipeModelDataParallelTopology(num_pp=pp, num_mp=mp, num_dp=dp)
+    topo = ProcessTopology(axes=["pipe", "data", "model"], dims=[pp, dp, mp])

     # Offset base seeds for the interior pipeline stages.
     # TODO: adjust last stage too once IO is improved.
@@ -186,6 +190,8 @@ def _initialize_distributed(neox_args):
     else:
         mpu.initialize_model_parallel(
             neox_args.model_parallel_size,
+            neox_args.pipe_parallel_size,
+            neox_args.context_parallel_size,
             topology=topo,
             fp32_allreduce=neox_args.fp32_allreduce,
         )
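
A small arithmetic sketch of the constraints asserted above; the sizes are made up. Context parallel ranks are carved out of the "data" axis of the ProcessTopology, so the number of truly independent data replicas that remain is dp // cp (an inference from the dp % cp == 0 assert and the piggyback comment, not stated explicitly in the diff).

    world_size = 32
    pp, mp, cp = 2, 2, 2               # illustrative degrees

    assert world_size % (pp * mp * cp) == 0
    dp = world_size // (pp * mp)       # size of the "data" axis in ProcessTopology -> 8
    assert dp % cp == 0                # context parallel groups fit inside the data axis
    effective_dp = dp // cp            # independent data replicas once cp ranks share one sequence -> 4

    print(f"pipe={pp} model={mp} data_axis={dp} context={cp} effective_data_parallel={effective_dp}")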

megatron/model/fused_layer_norm.py

Lines changed: 4 additions & 4 deletions
@@ -37,7 +37,7 @@ def __init__(
         normalized_shape,
         eps=1e-5,
         no_persist_layer_norm=True,
-        sequence_parallel=False,
+        context_parallel=False,
         apply_layernorm_1p=False,
         mem_efficient_ln=True,
     ):
@@ -92,11 +92,11 @@ def __init__(
             self.bias = Parameter(torch.Tensor(*normalized_shape))
         self.reset_parameters()
         self.no_persist_layer_norm = no_persist_layer_norm
-        self.sequence_parallel = sequence_parallel
+        self.context_parallel = context_parallel

         # set sequence parallelism flag on weight and bias parameters
-        setattr(self.weight, "sequence_parallel", self.sequence_parallel)
-        setattr(self.bias, "sequence_parallel", self.sequence_parallel)
+        setattr(self.weight, "context_parallel", self.context_parallel)
+        setattr(self.bias, "context_parallel", self.context_parallel)

     def reset_parameters(self):
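
The rename above tags the layer norm weight and bias with a context_parallel attribute instead of sequence_parallel. The diff does not show how that attribute is consumed downstream; the snippet below is only a hypothetical illustration of the tagging pattern, i.e. picking flagged parameters back out later, say for an extra gradient reduction.

    import torch

    weight = torch.nn.Parameter(torch.ones(1024))
    bias = torch.nn.Parameter(torch.zeros(1024))
    setattr(weight, "context_parallel", True)   # same pattern as the setattr calls above
    setattr(bias, "context_parallel", True)

    params = [weight, bias, torch.nn.Parameter(torch.zeros(1024))]
    cp_params = [p for p in params if getattr(p, "context_parallel", False)]
    print(len(cp_params))  # 2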

megatron/model/gpt2_model.py

Lines changed: 24 additions & 1 deletion
@@ -74,7 +74,30 @@ def cross_entropy(output, labels, _fp16=False):
     else:
         losses = mpu.vocab_parallel_cross_entropy(output.float().contiguous(), labels)
     loss_mask = loss_mask.view(-1)
-    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
+    loss_mask_sum = loss_mask.sum()
+    if mpu.get_context_parallel_world_size() > 1:
+        dt = loss_mask_sum.dtype
+        if dt == torch.bfloat16 and mpu.initialize.get_fp32_allreduce():
+            loss_mask_sum = loss_mask_sum.float()
+        torch.distributed.all_reduce(
+            loss_mask_sum,
+            op=torch.distributed.ReduceOp.SUM,
+            group=mpu.get_context_parallel_group(),
+        )
+        if dt == torch.bfloat16 and mpu.initialize.get_fp32_allreduce():
+            loss_mask_sum = loss_mask_sum.bfloat16()
+        loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask_sum
+        if dt == torch.bfloat16 and mpu.initialize.get_fp32_allreduce():
+            loss = loss.float()
+        torch.distributed.all_reduce(
+            loss,
+            op=torch.distributed.ReduceOp.SUM,
+            group=mpu.get_context_parallel_group(),
+        )
+        if dt == torch.bfloat16 and mpu.initialize.get_fp32_allreduce():
+            loss = loss.bfloat16()
+    else:
+        loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask_sum
     return loss

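
The math behind the two all-reduces, checked in a single process with the collective replaced by a plain sum: when the sequence is sharded across context parallel ranks, both the masked-loss numerator and the unmasked-token count must be summed over the group before dividing, otherwise ranks holding different numbers of unmasked tokens skew the mean. The tensors below are toy values.

    import torch

    losses = torch.tensor([1.0, 2.0, 3.0, 4.0])
    mask = torch.tensor([1.0, 1.0, 1.0, 0.0])
    full = (losses * mask).sum() / mask.sum()          # ground truth mean over the whole sequence

    # Split the sequence across 2 "context parallel ranks" and reduce both terms.
    shard_losses = losses.chunk(2)
    shard_masks = mask.chunk(2)
    numer = sum((l * m).sum() for l, m in zip(shard_losses, shard_masks))   # plays the role of all_reduce(loss)
    denom = sum(m.sum() for m in shard_masks)                               # plays the role of all_reduce(loss_mask_sum)

    assert torch.isclose(numer / denom, full)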

megatron/model/positional_embeddings.py

Lines changed: 22 additions & 1 deletion
@@ -14,6 +14,7 @@

 import torch
 import math
+import megatron.mpu as mpu


 class SinusoidalPositionalEmbedding(torch.nn.Module):
@@ -37,7 +38,13 @@ def forward(self, x, seq_dim=1):

 class RotaryEmbedding(torch.nn.Module):
     def __init__(
-        self, dim, max_seq_len, base=10000, precision=torch.half, save_inv_freqs=False
+        self,
+        dim,
+        max_seq_len,
+        base=10000,
+        precision=torch.half,
+        save_inv_freqs=False,
+        zigzag=True,
     ):
         super().__init__()
         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
@@ -49,6 +56,7 @@ def __init__(
         self.max_seq_len = max_seq_len
         self.base = base
         self.dim = dim
+        self.zigzag = zigzag  # seq parallel zigzag

         # precompute cos_cached, sin_cached in fp32
         cos_cached, sin_cached, inv_freq = self._prepare_cache(
@@ -64,6 +72,19 @@ def _prepare_cache(self, seq_len, precision, base):
         inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float() / self.dim))

         t = torch.arange(seq_len).type_as(inv_freq)
+        if mpu.get_context_parallel_world_size() > 1:
+            if not self.zigzag:
+                t_chunks = torch.chunk(t, mpu.get_context_parallel_world_size())
+                t = t_chunks[mpu.get_context_parallel_rank()].contiguous()
+            else:
+                t_chunks = torch.chunk(t, 2 * mpu.get_context_parallel_world_size())
+                t = torch.cat(
+                    (
+                        t_chunks[mpu.get_context_parallel_rank()],
+                        t_chunks[-(mpu.get_context_parallel_rank() + 1)],
+                    ),
+                    dim=0,
+                ).contiguous()
         freqs = torch.einsum("i,j->ij", t, inv_freq)
         emb = torch.cat((freqs, freqs), dim=-1)
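
A toy print-out of the zigzag split above, with the context parallel ranks simulated in a loop: each rank keeps one chunk from the front of the position index and the mirrored chunk from the back, matching the zigzag block layout that zigzag ring attention operates on. cp_world_size = 2 and seq_len = 8 are illustrative.

    import torch

    cp_world_size = 2
    t = torch.arange(8).float()                   # global position ids 0..7
    t_chunks = torch.chunk(t, 2 * cp_world_size)  # 4 chunks of 2 positions each

    for cp_rank in range(cp_world_size):
        t_local = torch.cat((t_chunks[cp_rank], t_chunks[-(cp_rank + 1)]), dim=0)
        print(cp_rank, t_local.tolist())
    # 0 [0.0, 1.0, 6.0, 7.0]
    # 1 [2.0, 3.0, 4.0, 5.0]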

megatron/model/transformer.py

Lines changed: 101 additions & 2 deletions
@@ -452,6 +452,7 @@ def __init__(
         self.rope_fusion = neox_args.rope_fusion
         self.attention_type = neox_args.attention_config[layer_number]
         self.use_flash_attention = self.attention_type == "flash"
+        self.use_ring_attention = self.attention_type == "ring"
         self.use_triton = (
             self.use_flash_attention
             and self.pos_emb == "alibi"
@@ -460,7 +461,7 @@
                 >= packaging.version.Version("2.4.0.post1")
             )
         )
-        self.sparse = self.attention_type not in ("global", "flash")
+        self.sparse = self.attention_type not in ("global", "flash", "ring")

         if self.gqa:
             assert not self.sparse
@@ -489,6 +490,12 @@
             self.flash_triton_fn = flash_attn_unpadded_unpacked_func_triton
             self.flash_qkv_fn = flash_attn_func
             self.flash_varlen_qkv_fn = flash_attn_varlen_func
+        elif self.use_ring_attention:
+            from ring_flash_attn.zigzag_ring_flash_attn import (
+                zigzag_ring_flash_attn_func,
+            )
+
+            self.ring_attn_fn = zigzag_ring_flash_attn_func
         else:
             self.scale_mask_softmax = FusedScaleMaskSoftmax(
                 input_in_fp16=self.fp16,
@@ -736,6 +743,96 @@ def flash_attention(self, query_layer, key_layer, value_layer):

         return matmul_result

+    def ring_attention(self, query_layer, key_layer, value_layer):
+        # [b, np, sq, sk]
+        output_size = (
+            query_layer.size(1),
+            query_layer.size(2),
+            query_layer.size(0),
+            key_layer.size(0),
+        )
+
+        # [sk, b, np, hn] -> [b, sk, np, hn] -> [b * sk, 1, np, hn]
+        key_layer = key_layer.transpose(0, 1).reshape(
+            output_size[0], output_size[3], self.num_kv_heads_per_partition, -1
+        )
+        value_layer = value_layer.transpose(0, 1).reshape(
+            output_size[0], output_size[3], self.num_kv_heads_per_partition, -1
+        )
+
+        # [sq, b, np, hn] -> [b, sq, np, hn]
+        query_layer = query_layer.transpose(0, 1).reshape(
+            output_size[0], output_size[2], output_size[1], -1
+        )
+
+        # only pass in window_size or alibi_slopes kwarg
+        # if we use Sliding Window Attention / AliBi.
+        # Flash attn defaults to (-1,-1), or
+        # does not have this kwarg prior to v2.3.0
+        extra_kwargs = (
+            {"window_size": (self.sliding_window_width, -1)}
+            if self.sliding_window_width is not None
+            else {}
+        )
+        if self.pos_emb == "alibi":
+            extra_kwargs["alibi_slopes"] = self.alibi_embed.slopes.to(
+                query_layer.device
+            ).to(torch.float32)
+
+        if not self.training:
+            batch_size = output_size[0]
+            max_seqlen_q = output_size[2]
+            max_seqlen_k = output_size[3]
+
+            cu_seqlens_q = torch.arange(
+                0,
+                (batch_size + 1) * max_seqlen_q,
+                step=max_seqlen_q,
+                dtype=torch.int32,
+                device=query_layer.device,
+            )
+
+            cu_seqlens_k = torch.arange(
+                0,
+                (batch_size + 1) * max_seqlen_k,
+                step=max_seqlen_k,
+                dtype=torch.int32,
+                device=key_layer.device,
+            )
+
+            q_shape = query_layer.shape
+            k_shape = key_layer.shape
+            v_shape = value_layer.shape
+            is_causal = max_seqlen_q == max_seqlen_k
+            output = self.ring_attn_fn(
+                query_layer,
+                key_layer,
+                value_layer,
+                0.0,
+                softmax_scale=None,
+                causal=is_causal,
+                group=mpu.get_context_parallel_group(),
+                **extra_kwargs,
+            )
+            output = output.reshape(q_shape)
+        else:
+            output = self.ring_attn_fn(
+                query_layer,
+                key_layer,
+                value_layer,
+                self.dropout_p if self.training else 0.0,
+                softmax_scale=None,
+                causal=True,
+                group=mpu.get_context_parallel_group(),
+                **extra_kwargs,
+            )
+
+        matmul_result = output
+        # [b, sq, np, hn] -> [b, np, sq, hn]
+        matmul_result = matmul_result.transpose(1, 2)
+
+        return matmul_result
+
     def sparse_attention(self, query_layer, key_layer, value_layer, attention_mask):
         # TODO: sparse attn dropout?
         # TODO: pad to block size
@@ -831,7 +928,7 @@ def gqa_project(self, hidden_states, attention_mask, layer_past=None):
             value_layer = value_layer.view(*new_kv_shape)

         # if not using Flash attention, we repeat K/V heads to match Q head counts
-        if not self.use_flash_attention:
+        if not (self.use_flash_attention or self.use_ring_attention):
             key_layer = torch.repeat_interleave(
                 key_layer,
                 repeats=int(
@@ -945,6 +1042,8 @@ def forward(self, hidden_states, attention_mask, layer_past=None):

         if self.use_flash_attention:
             context_layer = self.flash_attention(query_layer, key_layer, value_layer)
+        elif self.use_ring_attention:
+            context_layer = self.ring_attention(query_layer, key_layer, value_layer)
         elif not self.sparse:
             context_layer = self.attention(
                 query_layer, key_layer, value_layer, layer_past, attention_mask
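
Shape bookkeeping only, with no distributed setup and no ring_flash_attn import: the reshapes in ring_attention() move Megatron's [sq, b, np, hn] layout into the [b, sq, np, hn] layout the flash/ring kernels take, and the final transpose hands back [b, np, sq, hn]. The sizes below are arbitrary and the kernel call is replaced by an identity stand-in.

    import torch

    sq, b, np_heads, hn = 16, 2, 4, 8               # local (per context parallel rank) shapes
    query_layer = torch.randn(sq, b, np_heads, hn)  # [sq, b, np, hn]

    q = query_layer.transpose(0, 1).reshape(b, sq, np_heads, -1)  # [b, sq, np, hn]
    out = q                                                        # stand-in for self.ring_attn_fn(...)
    out = out.transpose(1, 2)                                      # [b, np, sq, hn]
    assert out.shape == (b, np_heads, sq, hn)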

megatron/model/utils.py

Lines changed: 2 additions & 2 deletions
@@ -373,14 +373,14 @@ def reduce_weight_grads_from_model_parallel_region(input_):

     # Bf16 convert
     dt = input_.dtype
-    if dt == torch.bfloat16 and mpu.get_fp32_allreduce():
+    if dt == torch.bfloat16 and mpu.initialize.get_fp32_allreduce():
         input_ = input_.float()

     # All-reduce.
     dist.all_reduce(input_, group=mpu.get_model_parallel_group())

     # Bf16 convert
-    if dt == torch.bfloat16 and mpu.get_fp32_allreduce():
+    if dt == torch.bfloat16 and mpu.initialize.get_fp32_allreduce():
         input_ = input_.bfloat16()

     return input_
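
The change above only reaches get_fp32_allreduce through mpu.initialize; the surrounding upcast / all-reduce / downcast pattern is still worth a sketch. Below, the collective is replaced by a local stack-and-sum so it runs in one process; the helper name and values are made up.

    import torch

    def reduce_like(grads, fp32_allreduce=True):
        # bf16 gradients are summed in fp32 when fp32_allreduce is on, then cast back
        dt = grads[0].dtype
        work = [g.float() if (dt == torch.bfloat16 and fp32_allreduce) else g for g in grads]
        reduced = torch.stack(work).sum(dim=0)    # stands in for dist.all_reduce(...)
        if dt == torch.bfloat16 and fp32_allreduce:
            reduced = reduced.bfloat16()
        return reduced

    grads = [torch.full((4,), 0.1, dtype=torch.bfloat16) for _ in range(8)]
    print(reduce_like(grads))                      # summed in fp32, returned as bf16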

megatron/mpu/__init__.py

Lines changed: 7 additions & 0 deletions
@@ -57,3 +57,10 @@

 from .utils import divide
 from .utils import split_tensor_along_last_dim
+from .data import zigzag_data
+from .initialize import (
+    get_context_parallel_group,
+    get_context_parallel_rank,
+    get_context_parallel_world_size,
+    get_context_parallel_src_rank,
+)
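
mpu now re-exports the context parallel accessors alongside zigzag_data. The body of mpu.data.zigzag_data is not part of this diff; the helper below is a hypothetical stand-in (name and signature are assumptions) that shows the zigzag sharding idea for a batch of token ids, mirroring the RotaryEmbedding chunking above.

    import torch

    def zigzag_shard(tokens: torch.Tensor, cp_rank: int, cp_world_size: int) -> torch.Tensor:
        # split the sequence dim into 2 * cp chunks and keep chunk r plus chunk -(r + 1)
        chunks = torch.chunk(tokens, 2 * cp_world_size, dim=1)
        return torch.cat((chunks[cp_rank], chunks[-(cp_rank + 1)]), dim=1).contiguous()

    batch = torch.arange(16).reshape(1, 16)                   # [batch, seq]
    print(zigzag_shard(batch, cp_rank=0, cp_world_size=2))    # columns 0-3 and 12-15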
