@@ -13,6 +13,7 @@

 from vllm.attention import AttentionMetadata, get_attn_backend
 from vllm.config import VllmConfig
+from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.model_loader import get_model
@@ -265,8 +266,9 @@ def _dummy_run(
         torch._dynamo.mark_dynamic(t, 0)
         torch._dynamo.mark_dynamic(p, 0)
         # Dummy run.
-        self.model(token_ids, position_ids, attn_metadata, input_lens, t, p,
-                   num_samples, kv_caches)
+        with set_forward_context(attn_metadata, self.vllm_config, 0):
+            self.model(token_ids, position_ids, attn_metadata, input_lens, t,
+                       p, num_samples, kv_caches)

     def warmup_model(
         self,
@@ -663,10 +665,13 @@ def execute_model(
             input_lens = model_input.input_lens[i:i + 1].to(self.device)
             t = model_input.t[i:i + 1].to(self.device)
             p = model_input.p[i:i + 1].to(self.device)
-            output_token_ids = self.model(token_ids, position_ids,
-                                          attn_metadata, input_lens, t, p,
-                                          model_input.num_samples,
-                                          kv_caches)
+            with set_forward_context(model_input.attn_metadata,
+                                     self.vllm_config,
+                                     model_input.virtual_engine):
+                output_token_ids = self.model(token_ids, position_ids,
+                                              attn_metadata, input_lens, t,
+                                              p, model_input.num_samples,
+                                              kv_caches)
             next_token_ids.append(output_token_ids[0])
             start_idx = end_idx

@@ -711,10 +716,13 @@ def execute_model(
             input_lens = model_input.input_lens.to(self.device)
             for i in range(num_steps):
                 slot_mapping = attn_metadata.slot_mapping
-                output_token_ids = self.model(token_ids, position_ids,
-                                              attn_metadata, input_lens, t, p,
-                                              model_input.num_samples,
-                                              kv_caches)
+                with set_forward_context(model_input.attn_metadata,
+                                         self.vllm_config,
+                                         model_input.virtual_engine):
+                    output_token_ids = self.model(token_ids, position_ids,
+                                                  attn_metadata, input_lens, t,
+                                                  p, model_input.num_samples,
+                                                  kv_caches)
                 self.cached_step_outputs.append(output_token_ids)

                 if i < num_steps - 1:
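For context on what the new wrapper does: set_forward_context (from vllm.forward_context, imported in the first hunk) is a context manager that publishes per-step state, here the attention metadata and the virtual-engine index used under pipeline parallelism, so that code running inside the model's forward pass can look it up instead of having it threaded through every call signature. The sketch below illustrates that pattern under stated assumptions: a module-level slot read back through a get_forward_context() helper. It is a simplified stand-in, not vLLM's actual implementation, which also records state derived from the vllm_config argument.

from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Optional

# Module-level slot holding the state for the forward pass in flight.
_forward_context: Optional["ForwardContext"] = None


@dataclass
class ForwardContext:
    attn_metadata: Any   # per-step attention metadata (slot mappings, etc.)
    virtual_engine: int  # virtual engine index under pipeline parallelism


def get_forward_context() -> "ForwardContext":
    # Called by code inside the model's forward pass to read the
    # current step's state.
    assert _forward_context is not None, (
        "get_forward_context() called outside a set_forward_context() block")
    return _forward_context


@contextmanager
def set_forward_context(attn_metadata, vllm_config, virtual_engine=0):
    # Publish per-step state for the duration of one forward pass and
    # restore the previous value on exit, even if the forward raises.
    # vllm_config is accepted for signature parity with the calls in the
    # diff; this simplified sketch does not use it.
    global _forward_context
    prev = _forward_context
    _forward_context = ForwardContext(attn_metadata, virtual_engine)
    try:
        yield
    finally:
        _forward_context = prev

With this shape, each wrapped self.model(...) call in the diff makes its attn_metadata visible to anything that calls get_forward_context() during that step, which is also why the dummy run in _dummy_run needs the same wrapper as the real execute_model paths.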