Hi, BPT and RingAttention are awesome works! Thanks a lot for open-sourcing the code.
I have a question about the second equation in the following snapshot taken from the paper: I am having difficulty deriving that the LHS and RHS are equal.

- Should the scaling factor be $\exp(\max(Q_i K_j^T) - \max_i)$ instead of $\exp(Q_i K_j^T - \max_i)$, i.e., is the maximum symbol missing?
- Even with the above fixed, and following *Online Normalizer Calculation for Softmax*, should the scaling factor be applied to both the numerator and the denominator, as is done in the pseudo-code in the paper (L43-45) and in the implementation below?
`ringattention/ringattention/ringattention_jax.py`, lines 142 to 144 at `aef108a`:

```python
correction = rearrange(jnp.exp(prev_max_score_chunk - max_score_chunk), 'b h q -> b q h')[..., None]
numerator_chunk = numerator_chunk * correction + exp_values
denominator_chunk = denominator_chunk * jnp.exp(prev_max_score_chunk - max_score_chunk) + exp_weights.sum(axis=-1)
```
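To make the second point concrete, here is a minimal NumPy sketch of the recurrence I have in mind (toy shapes and variable names of my own, not the repo's code): the correction factor $\exp(\text{prev\\_max} - \text{new\\_max})$ rescales both the running numerator and the running denominator, and the final ratio matches the full softmax.

```python
import numpy as np

rng = np.random.default_rng(0)
q = rng.normal(size=(1, 8))    # one query, head_dim = 8
k = rng.normal(size=(32, 8))   # 32 keys, processed in 4 chunks of 8
v = rng.normal(size=(32, 8))

# Reference: full (non-blockwise) softmax attention for the single query.
scores = q @ k.T                                     # (1, 32)
weights = np.exp(scores - scores.max())
ref = (weights @ v) / weights.sum()

# Online normalizer calculation, chunk by chunk.
max_score = -np.inf
numerator = np.zeros((1, 8))
denominator = 0.0
for k_chunk, v_chunk in zip(np.split(k, 4), np.split(v, 4)):
    s = q @ k_chunk.T                                # (1, chunk) scores
    new_max = max(max_score, s.max())                # running max over all seen scores
    correction = np.exp(max_score - new_max)         # rescales BOTH partial sums below
    exp_s = np.exp(s - new_max)
    numerator = numerator * correction + exp_s @ v_chunk
    denominator = denominator * correction + exp_s.sum()
    max_score = new_max

out = numerator / denominator
assert np.allclose(out, ref)                         # matches the full softmax result
```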
But I may be missing something in the paper. Any guidance would be much appreciated.
Thanks a lot in advance.