
Commit e7d2de4

heheda12345 authored and Isotr0py committed
[Kernel] unified_attention for Attention.forward (vllm-project#11967)
Signed-off-by: Chen Zhang <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
1 parent cd80332 commit e7d2de4

File tree: 10 files changed (+87, -45 lines)

vllm/attention/layer.py (14 additions, 12 deletions)

@@ -134,15 +134,10 @@ def forward(
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
+        _kv_cache: torch.Tensor,
+        _attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
-
-        if self.use_direct_call:
-            return self.impl.forward(query, key, value, kv_cache,
-                                     attn_metadata, self._k_scale,
-                                     self._v_scale)
-        elif self.use_output:
+        if self.use_output:
             output = torch.empty_like(query)
             hidden_size = query.size(-1)
             # Reshape the query, key, and value tensors.
@@ -154,12 +149,19 @@ def forward(
                 key = key.view(-1, self.num_kv_heads, self.head_size)
             if value is not None:
                 value = value.view(-1, self.num_kv_heads, self.head_size)
-            torch.ops.vllm.unified_attention_with_output(
-                query, key, value, output, self.layer_name)
+            if self.use_direct_call:
+                unified_attention_with_output(query, key, value, output,
+                                              self.layer_name)
+            else:
+                torch.ops.vllm.unified_attention_with_output(
+                    query, key, value, output, self.layer_name)
             return output.view(-1, hidden_size)
         else:
-            return torch.ops.vllm.unified_attention(query, key, value,
-                                                    self.layer_name)
+            if self.use_direct_call:
+                return unified_attention(query, key, value, self.layer_name)
+            else:
+                return torch.ops.vllm.unified_attention(
+                    query, key, value, self.layer_name)
 
     def extra_repr(self) -> str:
         s = f"head_size={self.impl.head_size}"  # type: ignore
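For context, the change above makes Attention.forward always route through the unified attention entry points: with use_direct_call set, the plain Python functions are invoked; otherwise the registered torch.ops.vllm custom ops are used. The snippet below is a hypothetical, self-contained sketch of that dispatch pattern (a "demo" op namespace, a toy attention body, and torch.library registration), not the vLLM implementation itself.

import torch
from torch.library import Library

# Hypothetical "demo" namespace; vLLM registers its ops under "vllm".
_demo_lib = Library("demo", "FRAGMENT")


def unified_attention(query: torch.Tensor, key: torch.Tensor,
                      value: torch.Tensor) -> torch.Tensor:
    # Toy attention body. The real op looks up the layer's KV cache and
    # attention metadata from the current forward context by layer name.
    scores = torch.softmax(query @ key.transpose(-1, -2), dim=-1)
    return scores @ value


_demo_lib.define(
    "unified_attention(Tensor query, Tensor key, Tensor value) -> Tensor")
_demo_lib.impl("unified_attention", unified_attention,
               "CompositeExplicitAutograd")


def attention_forward(query: torch.Tensor, key: torch.Tensor,
                      value: torch.Tensor,
                      use_direct_call: bool) -> torch.Tensor:
    if use_direct_call:
        # Call the Python function directly, skipping the op dispatcher
        # (useful for backends whose graphs are not captured via the op).
        return unified_attention(query, key, value)
    return torch.ops.demo.unified_attention(query, key, value)


# Both paths compute the same result on the same inputs.
q = k = v = torch.randn(2, 4, 8)
assert torch.allclose(attention_forward(q, k, v, True),
                      attention_forward(q, k, v, False))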

vllm/utils.py (0 additions, 1 deletion)

@@ -2171,5 +2171,4 @@ def bind_kv_cache(
         forward_ctx = ctx[layer_name]
         assert len(forward_ctx.kv_cache) == len(kv_cache)
         for ve, ve_kv_cache in enumerate(kv_cache):
-            assert forward_ctx.kv_cache[ve].numel() == 0
             forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx]
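The deleted assertion is small but load-bearing: it previously required each layer's KV-cache slot to be empty before binding, which would break the re-binding that workers below now perform (bind a profiling cache, then later replace it with the real cache or with empty placeholders). The following is a hypothetical reconstruction of the binding pattern, inferred only from the call sites in this diff; _LayerContext and bind_kv_cache_sketch are stand-in names, and the real bind_kv_cache derives each layer's cache index from its name rather than from sort order.

from dataclasses import dataclass, field
from typing import Dict, List

import torch


@dataclass
class _LayerContext:
    # One KV-cache slot per virtual engine (pipeline-parallel copy).
    kv_cache: List[torch.Tensor] = field(
        default_factory=lambda: [torch.tensor([])])


def bind_kv_cache_sketch(ctx: Dict[str, _LayerContext],
                         kv_cache: List[List[torch.Tensor]]) -> None:
    # kv_cache is indexed [virtual_engine][layer]; each attention layer in
    # the static forward context picks up its own tensor. The assert removed
    # above used to require the slot to be empty, forbidding re-binding.
    for layer_idx, layer_name in enumerate(sorted(ctx)):
        forward_ctx = ctx[layer_name]
        assert len(forward_ctx.kv_cache) == len(kv_cache)
        for ve, ve_kv_cache in enumerate(kv_cache):
            forward_ctx.kv_cache[ve] = ve_kv_cache[layer_idx]


# Usage: two layers, one virtual engine; binding twice now works.
ctx = {"layers.0.attn": _LayerContext(), "layers.1.attn": _LayerContext()}
bind_kv_cache_sketch(ctx, [[torch.zeros(2, 16), torch.zeros(2, 16)]])
bind_kv_cache_sketch(ctx, [[torch.tensor([]), torch.tensor([])]])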

vllm/worker/hpu_model_runner.py (11 additions, 2 deletions)

@@ -28,6 +28,7 @@
 from vllm.attention import AttentionMetadata, get_attn_backend
 from vllm.config import DeviceConfig, VllmConfig
 from vllm.distributed.parallel_state import get_world_group
+from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.lora.layers import LoRAMapping
 from vllm.lora.request import LoRARequest
@@ -40,7 +41,8 @@
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import (IntermediateTensors, SequenceData,
                            SequenceGroupMetadata)
-from vllm.utils import is_pin_memory_available, make_tensor_with_pad
+from vllm.utils import (bind_kv_cache, is_pin_memory_available,
+                        make_tensor_with_pad)
 from vllm.worker.model_runner_base import (
     ModelRunnerBase, ModelRunnerInputBase,
     _add_attn_metadata_broadcastable_dict,
@@ -1286,6 +1288,9 @@ def create_dummy_seq_group_metadata(self,
     def profile_run(self) -> None:
         num_layers = self.model_config.get_num_layers(self.parallel_config)
         kv_caches = [None] * num_layers
+        bind_kv_cache(
+            self.vllm_config.compilation_config.static_forward_context,
+            [kv_caches])
         max_seq_len = self.bucketing_global_state.prompt_seq_bucket_cfg[-1]
         max_batch_size = min(self.max_num_batched_tokens // max_seq_len,
                              self.scheduler_config.max_num_seqs)
@@ -1943,7 +1948,11 @@ def execute_model(
                 f"graphs{'T' if use_graphs else 'F'}")
         else:
             model_event_name = 'model_executable'
-        with self.profiler.record_event('internal', model_event_name):
+        with set_forward_context(
+                model_input.attn_metadata, self.vllm_config,
+                model_input.virtual_engine), \
+                self.profiler.record_event(
+                    'internal', model_event_name):
             hidden_states = self.model.forward(
                 **execute_model_kwargs,
                 selected_token_indices=sampling_metadata.selected_token_indices
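The HPU runner (like the Neuron, OpenVINO, TPU, and XPU runners below) now wraps model execution in set_forward_context, so attention layers fetch per-step state by layer name instead of receiving kv_cache and attn_metadata as explicit arguments. A minimal, hypothetical illustration of that pattern follows; set_forward_context_sketch and get_forward_context_sketch are stand-ins, not vLLM's actual vllm.forward_context API.

from contextlib import contextmanager
from typing import Any, Optional

_FORWARD_CONTEXT: Optional[dict] = None


@contextmanager
def set_forward_context_sketch(attn_metadata: Any, vllm_config: Any,
                               virtual_engine: int = 0):
    # Stash per-step state in a module-level slot for the duration of one
    # forward pass, restoring whatever was there before on exit.
    global _FORWARD_CONTEXT
    prev = _FORWARD_CONTEXT
    _FORWARD_CONTEXT = {
        "attn_metadata": attn_metadata,
        "vllm_config": vllm_config,
        "virtual_engine": virtual_engine,
    }
    try:
        yield
    finally:
        _FORWARD_CONTEXT = prev


def get_forward_context_sketch() -> dict:
    assert _FORWARD_CONTEXT is not None, "forward context is not set"
    return _FORWARD_CONTEXT


# Usage mirrors the runner changes: wrap the model call, read inside layers.
with set_forward_context_sketch(attn_metadata=None, vllm_config=None,
                                virtual_engine=0):
    step_state = get_forward_context_sketch()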

vllm/worker/hpu_worker.py (3 additions, 0 deletions)

@@ -20,6 +20,7 @@
 from vllm.model_executor import set_random_seed
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sequence import ExecuteModelRequest
+from vllm.utils import bind_kv_cache
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.hpu_model_runner import HPUModelRunner
 from vllm.worker.model_runner_base import ModelRunnerBase
@@ -215,6 +216,8 @@ def _init_cache_engine(self):
             self.cache_engine[ve].gpu_cache
             for ve in range(self.parallel_config.pipeline_parallel_size)
         ]
+        bind_kv_cache(self.compilation_config.static_forward_context,
+                      self.hpu_cache)
 
     def _warm_up_model(self) -> None:
         # NOTE(kzawora): We should use virtual engine index here

vllm/worker/neuron_model_runner.py (10 additions, 7 deletions)

@@ -8,6 +8,7 @@
 from transformers_neuronx.config import GenerationConfig
 
 from vllm.config import VllmConfig
+from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.sampler import SamplerOutput
@@ -314,13 +315,15 @@ def execute_model(
             raise ValueError(
                 "NeuronModelRunner does not support multi-step execution.")
 
-        hidden_states = self.model(
-            input_ids=model_input.input_tokens,
-            positions=model_input.input_positions,
-            input_block_ids=model_input.input_block_ids,
-            **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
-                                         device=self.device),
-        )
+        with set_forward_context(None, self.vllm_config, 0):
+            hidden_states = self.model(
+                input_ids=model_input.input_tokens,
+                positions=model_input.input_positions,
+                input_block_ids=model_input.input_block_ids,
+                **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs
+                                             or {},
+                                             device=self.device),
+            )
 
         # Compute the logits only if the on-device sampling is turned off as
         # on-device sampling outputs the token ids.

vllm/worker/openvino_model_runner.py (3 additions, 1 deletion)

@@ -8,6 +8,7 @@
 from vllm.attention import get_attn_backend
 from vllm.attention.backends.openvino import OpenVINOAttentionMetadata
 from vllm.config import VllmConfig
+from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.sampler import SamplerOutput
@@ -350,7 +351,8 @@ def execute_model(
                                          device=self.device),
         }
 
-        hidden_states = model_executable(**execute_model_kwargs)
+        with set_forward_context(attn_metadata, self.vllm_config, 0):
+            hidden_states = model_executable(**execute_model_kwargs)
 
         # Compute the logits.
         logits = self.model.compute_logits(hidden_states, sampling_metadata)

vllm/worker/openvino_worker.py (11 additions, 2 deletions)

@@ -20,6 +20,7 @@
 from vllm.platforms import current_platform
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
+from vllm.utils import bind_kv_cache
 from vllm.worker.openvino_model_runner import OpenVINOModelRunner
 from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
 
@@ -339,6 +340,8 @@ def _init_cache_engine(self) -> None:
             ov_device,
         )
         self.kv_cache = self.cache_engine.kv_cache
+        bind_kv_cache(self.compilation_config.static_forward_context,
+                      [self.kv_cache])
         self.model_runner.block_size = self.cache_engine.block_size
 
         assert self.kv_cache is not None
@@ -507,12 +510,18 @@ def model_profile_run():
 
             self.model_runner.block_size = tmp_cache_config.block_size
 
+            bind_kv_cache(self.compilation_config.static_forward_context,
+                          profiling_cache_engine.kv_cache)
             # Run the model with the dummy inputs.
             self.model_runner.execute_model(seqs,
                                             profiling_cache_engine.kv_cache)
 
-            # explicitly delete temporary KV cache manager to free KV cache
-            # when real inputs will be passed to OV
+            # Explicitly revert bind_kv_cache and delete temporary KV cache
+            # manager to free KV cache when real inputs will be passed to OV
+            bind_kv_cache(self.compilation_config.static_forward_context, [[
+                torch.tensor([])
+                for _ in range(len(profiling_cache_engine.kv_cache))
+            ]])
             del profiling_cache_engine
 
             logger.info(

vllm/worker/tpu_model_runner.py (18 additions, 10 deletions)

@@ -13,6 +13,7 @@
 
 from vllm.attention import AttentionMetadata, get_attn_backend
 from vllm.config import VllmConfig
+from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.model_loader import get_model
@@ -265,8 +266,9 @@ def _dummy_run(
             torch._dynamo.mark_dynamic(t, 0)
             torch._dynamo.mark_dynamic(p, 0)
         # Dummy run.
-        self.model(token_ids, position_ids, attn_metadata, input_lens, t, p,
-                   num_samples, kv_caches)
+        with set_forward_context(attn_metadata, self.vllm_config, 0):
+            self.model(token_ids, position_ids, attn_metadata, input_lens, t,
+                       p, num_samples, kv_caches)
 
     def warmup_model(
         self,
@@ -663,10 +665,13 @@ def execute_model(
                 input_lens = model_input.input_lens[i:i + 1].to(self.device)
                 t = model_input.t[i:i + 1].to(self.device)
                 p = model_input.p[i:i + 1].to(self.device)
-                output_token_ids = self.model(token_ids, position_ids,
-                                              attn_metadata, input_lens, t, p,
-                                              model_input.num_samples,
-                                              kv_caches)
+                with set_forward_context(model_input.attn_metadata,
+                                         self.vllm_config,
+                                         model_input.virtual_engine):
+                    output_token_ids = self.model(token_ids, position_ids,
+                                                  attn_metadata, input_lens, t,
+                                                  p, model_input.num_samples,
+                                                  kv_caches)
                 next_token_ids.append(output_token_ids[0])
                 start_idx = end_idx
 
@@ -711,10 +716,13 @@ def execute_model(
            input_lens = model_input.input_lens.to(self.device)
            for i in range(num_steps):
                slot_mapping = attn_metadata.slot_mapping
-                output_token_ids = self.model(token_ids, position_ids,
-                                              attn_metadata, input_lens, t, p,
-                                              model_input.num_samples,
-                                              kv_caches)
+                with set_forward_context(model_input.attn_metadata,
+                                         self.vllm_config,
+                                         model_input.virtual_engine):
+                    output_token_ids = self.model(token_ids, position_ids,
+                                                  attn_metadata, input_lens, t,
+                                                  p, model_input.num_samples,
+                                                  kv_caches)
                self.cached_step_outputs.append(output_token_ids)
 
                if i < num_steps - 1:

vllm/worker/tpu_worker.py (5 additions, 1 deletion)

@@ -12,7 +12,7 @@
 from vllm.logger import init_logger
 from vllm.model_executor import set_random_seed
 from vllm.sequence import ExecuteModelRequest
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, bind_kv_cache, get_dtype_size
 from vllm.worker.tpu_model_runner import ExecutionMode, TPUModelRunner
 from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
                                      LoraNotSupportedWorkerBase, WorkerBase,
@@ -108,6 +108,8 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
                       torch.tensor([], dtype=torch.float32,
                                    device=self.device))
                      for _ in range(num_layers)]
+        bind_kv_cache(self.compilation_config.static_forward_context,
+                      [kv_caches])
         self.model_runner._dummy_run(
             batch_size=1,
             seq_len=self.scheduler_config.max_num_batched_tokens,
@@ -170,6 +172,8 @@ def initialize_cache(
                                       device="cpu")
             cpu_v_cache = torch.zeros_like(cpu_k_cache)
             self.cpu_cache.append((cpu_k_cache, cpu_v_cache))
+        bind_kv_cache(self.compilation_config.static_forward_context,
+                      [self.tpu_cache])
         self._warmup_model()
 
     def _warmup_model(self) -> None:

vllm/worker/xpu_model_runner.py (12 additions, 9 deletions)

@@ -12,6 +12,7 @@
 from vllm.attention import get_attn_backend
 from vllm.config import VllmConfig
 from vllm.distributed import get_pp_group
+from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY, InputRegistry
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadataCache
@@ -562,15 +563,17 @@ def execute_model(
         if (self.observability_config is not None
                 and self.observability_config.collect_model_forward_time):
             model_forward_start_time = time.time()
-
-        hidden_or_intermediate_states = model_executable(
-            input_ids=model_input.input_tokens,
-            positions=model_input.input_positions,
-            kv_caches=kv_caches,
-            attn_metadata=model_input.attn_metadata,
-            intermediate_tensors=intermediate_tensors,
-            **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
-                                         device=self.device))
+        with set_forward_context(model_input.attn_metadata, self.vllm_config,
+                                 model_input.virtual_engine):
+            hidden_or_intermediate_states = model_executable(
+                input_ids=model_input.input_tokens,
+                positions=model_input.input_positions,
+                kv_caches=kv_caches,
+                attn_metadata=model_input.attn_metadata,
+                intermediate_tensors=intermediate_tensors,
+                **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs
+                                             or {},
+                                             device=self.device))
         # Compute the logits in the last pipeline stage.
         if not get_pp_group().is_last_rank:
             return hidden_or_intermediate_states
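Taken together, the worker-side bind_kv_cache calls and the runner-side set_forward_context wrappers mean Attention.forward no longer consumes its kv_cache and attn_metadata arguments. That appears to be why the signature in vllm/attention/layer.py renames them to the placeholder names _kv_cache and _attn_metadata rather than dropping them, presumably because existing model code still passes them positionally.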
