@@ -452,6 +452,7 @@ def __init__(
         self.rope_fusion = neox_args.rope_fusion
         self.attention_type = neox_args.attention_config[layer_number]
         self.use_flash_attention = self.attention_type == "flash"
+        self.use_ring_attention = self.attention_type == "ring"
         self.use_triton = (
             self.use_flash_attention
             and self.pos_emb == "alibi"
@@ -460,7 +461,7 @@ def __init__(
                 >= packaging.version.Version("2.4.0.post1")
             )
         )
-        self.sparse = self.attention_type not in ("global", "flash")
+        self.sparse = self.attention_type not in ("global", "flash", "ring")
 
         if self.gqa:
             assert not self.sparse
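
For context, `attention_config[layer_number]` is what produces the per-layer attention type strings compared against above, so the new `"ring"` type would presumably be selected through the same config list as `"flash"` or `"global"`. A hedged sketch of such a fragment, written as a Python dict rather than the usual YAML/JSON config file; the key layout is recalled from the standard GPT-NeoX `attention_config` format and is not taken from this diff:

```python
# Hypothetical GPT-NeoX config fragment (illustrative only): each entry is
# [[list_of_attention_types], num_layers_or_"all"], which neox_args expands
# into the per-layer attention_config list indexed by layer_number above.
config_fragment = {
    "attention_config": [[["ring"], "all"]],  # use ring attention on every layer
}
```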
@@ -489,6 +490,12 @@ def __init__(
             self.flash_triton_fn = flash_attn_unpadded_unpacked_func_triton
             self.flash_qkv_fn = flash_attn_func
             self.flash_varlen_qkv_fn = flash_attn_varlen_func
+        elif self.use_ring_attention:
+            from ring_flash_attn.zigzag_ring_flash_attn import (
+                zigzag_ring_flash_attn_func,
+            )
+
+            self.ring_attn_fn = zigzag_ring_flash_attn_func
         else:
             self.scale_mask_softmax = FusedScaleMaskSoftmax(
                 input_in_fp16=self.fp16,
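
The new branch takes `zigzag_ring_flash_attn_func` from the external `ring_flash_attn` package, which implements ring attention on top of flash-attn with a zigzag sequence sharding. A minimal standalone sketch of driving that function the way this diff does; the shapes, dtype, and use of the default world group in place of a dedicated context-parallel group are assumptions, and it needs a multi-GPU `torchrun` launch (plus flash-attn) to actually execute:

```python
import torch
import torch.distributed as dist
from ring_flash_attn.zigzag_ring_flash_attn import zigzag_ring_flash_attn_func

# Assumes one process per GPU (e.g. torchrun --nproc-per-node=N) with NCCL available.
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())

b, local_sq, heads, hn = 2, 1024, 16, 64  # each rank holds its shard of the sequence
q, k, v = (
    torch.randn(b, local_sq, heads, hn, device="cuda", dtype=torch.bfloat16)
    for _ in range(3)
)

# Same positional/keyword layout as the self.ring_attn_fn(...) calls added below.
out = zigzag_ring_flash_attn_func(
    q, k, v,
    0.0,                     # dropout_p
    softmax_scale=None,      # defaults to 1/sqrt(hn)
    causal=True,
    group=dist.group.WORLD,  # stands in for mpu.get_context_parallel_group()
)
print(out.shape)  # (b, local_sq, heads, hn)
```

The zigzag variant expects each rank's q/k/v to hold a zigzag-interleaved slice of the full sequence, which is what keeps the causal workload balanced across ranks; see the sketch after the `ring_attention` hunk below.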
@@ -736,6 +743,96 @@ def flash_attention(self, query_layer, key_layer, value_layer):
 
         return matmul_result
 
+    def ring_attention(self, query_layer, key_layer, value_layer):
+        # [b, np, sq, sk]
+        output_size = (
+            query_layer.size(1),
+            query_layer.size(2),
+            query_layer.size(0),
+            key_layer.size(0),
+        )
+
+        # [sk, b, np, hn] -> [b, sk, np, hn]
+        key_layer = key_layer.transpose(0, 1).reshape(
+            output_size[0], output_size[3], self.num_kv_heads_per_partition, -1
+        )
+        value_layer = value_layer.transpose(0, 1).reshape(
+            output_size[0], output_size[3], self.num_kv_heads_per_partition, -1
+        )
+
+        # [sq, b, np, hn] -> [b, sq, np, hn]
+        query_layer = query_layer.transpose(0, 1).reshape(
+            output_size[0], output_size[2], output_size[1], -1
+        )
+
+        # only pass in window_size or alibi_slopes kwarg
+        # if we use Sliding Window Attention / AliBi.
+        # Flash attn defaults to (-1,-1), or
+        # does not have this kwarg prior to v2.3.0
+        extra_kwargs = (
+            {"window_size": (self.sliding_window_width, -1)}
+            if self.sliding_window_width is not None
+            else {}
+        )
+        if self.pos_emb == "alibi":
+            extra_kwargs["alibi_slopes"] = self.alibi_embed.slopes.to(
+                query_layer.device
+            ).to(torch.float32)
+
+        if not self.training:
+            batch_size = output_size[0]
+            max_seqlen_q = output_size[2]
+            max_seqlen_k = output_size[3]
+
+            cu_seqlens_q = torch.arange(
+                0,
+                (batch_size + 1) * max_seqlen_q,
+                step=max_seqlen_q,
+                dtype=torch.int32,
+                device=query_layer.device,
+            )
+
+            cu_seqlens_k = torch.arange(
+                0,
+                (batch_size + 1) * max_seqlen_k,
+                step=max_seqlen_k,
+                dtype=torch.int32,
+                device=key_layer.device,
+            )
+
+            q_shape = query_layer.shape
+            k_shape = key_layer.shape
+            v_shape = value_layer.shape
+            is_causal = max_seqlen_q == max_seqlen_k
+            output = self.ring_attn_fn(
+                query_layer,
+                key_layer,
+                value_layer,
+                0.0,
+                softmax_scale=None,
+                causal=is_causal,
+                group=mpu.get_context_parallel_group(),
+                **extra_kwargs,
+            )
+            output = output.reshape(q_shape)
+        else:
+            output = self.ring_attn_fn(
+                query_layer,
+                key_layer,
+                value_layer,
+                self.dropout_p if self.training else 0.0,
+                softmax_scale=None,
+                causal=True,
+                group=mpu.get_context_parallel_group(),
+                **extra_kwargs,
+            )
+
+        matmul_result = output
+        # [b, sq, np, hn] -> [b, np, sq, hn]
+        matmul_result = matmul_result.transpose(1, 2)
+
+        return matmul_result
+
     def sparse_attention(self, query_layer, key_layer, value_layer, attention_mask):
         # TODO: sparse attn dropout?
         # TODO: pad to block size
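
Note that `ring_attention` itself never reshards anything: it expects `query_layer`/`key_layer`/`value_layer` to already be this rank's local sequence shard within the context-parallel group, laid out in the zigzag pattern the kernel assumes (with N ranks the sequence is cut into 2N chunks and rank i holds chunks i and 2N-1-i, which balances causal-attention work). A small illustrative sketch of that layout in plain torch; the helper name is made up here, and the actual sharding lives in the data path rather than in this hunk:

```python
import torch

def zigzag_shard(seq: torch.Tensor, rank: int, world_size: int, dim: int = 0) -> torch.Tensor:
    """Return this rank's local shard: chunks `rank` and `2*world_size - 1 - rank`."""
    chunks = seq.chunk(2 * world_size, dim=dim)
    return torch.cat([chunks[rank], chunks[2 * world_size - 1 - rank]], dim=dim)

# Toy example: 16 token positions sharded across a 4-rank context-parallel group.
positions = torch.arange(16)
for r in range(4):
    print(r, zigzag_shard(positions, r, 4).tolist())
# 0 -> [0, 1, 14, 15]   1 -> [2, 3, 12, 13]   2 -> [4, 5, 10, 11]   3 -> [6, 7, 8, 9]
```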
@@ -831,7 +928,7 @@ def gqa_project(self, hidden_states, attention_mask, layer_past=None):
         value_layer = value_layer.view(*new_kv_shape)
 
         # if not using Flash attention, we repeat K/V heads to match Q head counts
-        if not self.use_flash_attention:
+        if not (self.use_flash_attention or self.use_ring_attention):
             key_layer = torch.repeat_interleave(
                 key_layer,
                 repeats=int(
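
Widening this condition means neither fused path replicates K/V heads: flash-attn (which the ring wrapper calls underneath) handles grouped-query attention directly when the number of query heads is a multiple of the number of K/V heads, so only the fallback matmul path still needs the explicit `repeat_interleave`. A toy sketch of what that replication does, assuming the `[sq, b, heads, hn]` layout and 16 query heads over 4 K/V heads; the `dim` and head counts are illustrative, since the diff truncates the actual call:

```python
import torch

sq, b, kv_heads, q_heads, hn = 8, 2, 4, 16, 64
key_layer = torch.randn(sq, b, kv_heads, hn)

# Each K/V head is repeated q_heads // kv_heads times along the head dimension,
# so the non-fused attention path sees one K/V head per query head.
key_layer = torch.repeat_interleave(key_layer, repeats=q_heads // kv_heads, dim=2)
print(key_layer.shape)  # torch.Size([8, 2, 16, 64])
```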
@@ -945,6 +1042,8 @@ def forward(self, hidden_states, attention_mask, layer_past=None):
 
         if self.use_flash_attention:
             context_layer = self.flash_attention(query_layer, key_layer, value_layer)
+        elif self.use_ring_attention:
+            context_layer = self.ring_attention(query_layer, key_layer, value_layer)
         elif not self.sparse:
             context_layer = self.attention(
                 query_layer, key_layer, value_layer, layer_past, attention_mask