[Feature Request] Balancing computation with zigzag blocking #2

@zhuzilin

Description

Currently, the implementation splits the input sequence into n blocks. For example, with 4 GPUs it is split into:

b0 | b1 | b2 | b3

However, this results in uneven computation: because of the causal attention mask, the GPU that holds b3 does around 4 times more computation than the GPU that holds b0.
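To make the imbalance concrete, here is a minimal sketch that counts the unmasked (query, key) pairs each rank has to process under the contiguous split. The function name `causal_pairs_per_rank` and the sequence length are made up for illustration and are not part of the repository; the sketch also assumes the sequence length divides evenly by the number of GPUs.

```python
def causal_pairs_per_rank(seq_len: int, world_size: int) -> list[int]:
    """Count unmasked (query, key) pairs per rank for a contiguous split.

    Hypothetical helper for illustration only. Assumes seq_len is
    divisible by world_size.
    """
    block = seq_len // world_size
    counts = []
    for rank in range(world_size):
        start, end = rank * block, (rank + 1) * block
        # Under a causal mask, the query at position q attends to keys 0..q,
        # which is q + 1 keys.
        counts.append(sum(q + 1 for q in range(start, end)))
    return counts


counts = causal_pairs_per_rank(seq_len=4096, world_size=4)
for rank, c in enumerate(counts):
    print(f"rank {rank}: {c} pairs ({c / counts[0]:.2f}x rank 0)")
```

Counted by blocks visited, the last rank touches 4 key blocks versus 1 for the first rank, which matches the "around 4 times" estimate. Counted by individual unmasked pairs the gap is closer to 7x, since the diagonal block is only half filled.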

If we instead split the input sequence into 2n blocks and give each GPU one block from the front and one from the back, then with 4 GPUs the split becomes:

b0,b7 | b1,b6 | b2,b5 | b3,b4

Then all GPUs will have the same amount of computation, and theoretically the latency should decrease by half.
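The proposed assignment can be sketched as follows. This is only an illustration of the idea, not the repository's implementation: `zigzag_blocks` and `zigzag_pairs_per_rank` are hypothetical names, and the sketch assumes the sequence length is divisible by 2n.

```python
def zigzag_blocks(world_size: int) -> list[tuple[int, int]]:
    """Rank r gets block r and block 2n - 1 - r out of 2n equal blocks."""
    return [(r, 2 * world_size - 1 - r) for r in range(world_size)]


def zigzag_pairs_per_rank(seq_len: int, world_size: int) -> list[int]:
    """Count unmasked (query, key) pairs per rank for the zigzag split.

    Hypothetical helper for illustration only. Assumes seq_len is
    divisible by 2 * world_size.
    """
    block = seq_len // (2 * world_size)
    counts = []
    for first, second in zigzag_blocks(world_size):
        total = 0
        for b in (first, second):
            start = b * block
            # The query at position q attends to q + 1 keys under a causal mask.
            total += sum(q + 1 for q in range(start, start + block))
        counts.append(total)
    return counts


print(zigzag_blocks(4))                 # [(0, 7), (1, 6), (2, 5), (3, 4)]
print(zigzag_pairs_per_rank(4096, 4))   # every rank gets the same count
```

Pairing block r with block 2n - 1 - r works because the cost of a block grows linearly with its index, so each pair sums to the same total. The sketch only counts total attention work per rank; it does not model the per-step communication of the ring.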
