Commit 862f2ef

[XPU] Fix the bug of LoRA logits on the XPU platform (#24081)
Signed-off-by: chzhang <[email protected]>
1 parent 2fd1a40 commit 862f2ef

File tree

3 files changed: +15 −5 lines

vllm/lora/layers.py

Lines changed: 1 addition & 1 deletion

@@ -1151,7 +1151,7 @@ def _get_logits(
         lora_logits = lora_logits.mT
         indices_padded = self.punica_wrapper.sampler_indices_padded
 
-        if current_platform.is_tpu():
+        if current_platform.is_tpu() or current_platform.is_xpu():
             indices_padded = indices_padded[:logits.size(0)]
 
         lora_logits = (lora_logits.reshape(
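For context, a minimal sketch (not vLLM code) of why the padded sampler indices are sliced to logits.size(0) before use: on platforms that preallocate the index tensor to a fixed maximum batch size, the trailing padding entries do not correspond to live tokens, so only the first logits.size(0) entries are kept. Tensor shapes and the padding value below are assumptions for illustration.

# Hedged sketch, not vLLM code: padded index tensors are preallocated to a
# fixed maximum batch size, so they are sliced to the live batch size before
# being used to index the logits. Shapes and padding values are assumptions.
import torch

num_tokens = 3                            # live batch size, logits.size(0)
max_batch = 8                             # preallocated padded length
lora_logits = torch.randn(max_batch, 16)

indices_padded = torch.full((max_batch,), max_batch - 1, dtype=torch.long)
indices_padded[:num_tokens] = torch.tensor([2, 0, 1])

# Slicing keeps only the entries that correspond to real tokens; the
# padding rows are never gathered.
indices = indices_padded[:num_tokens]
selected = lora_logits[indices]           # shape: (num_tokens, 16)
print(selected.shape)                     # torch.Size([3, 16])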

vllm/lora/punica_wrapper/punica_xpu.py

Lines changed: 10 additions & 3 deletions

@@ -225,6 +225,13 @@ def add_lora_linear(self,
                             add_inputs=True,
                             **kwargs)
 
+    @property
+    def sampler_indices_padded(self) -> torch.Tensor:
+        """
+        This property provides access to padded sampler indices.
+        """
+        return self._sampler_indices_padded[:]
+
     def add_lora_logits(self,
                         y: torch.Tensor,
                         x: torch.Tensor,
@@ -259,11 +266,11 @@ def add_lora_logits(self,
         buffer = torch.zeros((x.size(0), r),
                              dtype=torch.float32,
                              device=x.device)
-
-        bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale)
+        sampler_indices = torch.narrow(self._sampler_indices, 0, 0, x.size(0))
+        bgmv_shrink(x, lora_a_stacked, buffer, sampler_indices, scale)
         bgmv_expand(buffer,
                     lora_b_stacked,
                     y,
-                    self.sampler_indices,
+                    sampler_indices,
                     add_inputs=True)
         return y.view_as(y_org)
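As a quick reference, torch.narrow(tensor, dim, start, length) returns a view of length elements along dim starting at start, so the new sampler_indices is simply the first x.size(0) entries of the stored index tensor. A small standalone sketch (tensor contents are illustrative only):

# Hedged sketch, not vLLM code: torch.narrow trims the preallocated sampler
# index tensor to the number of rows in the current input, equivalent to
# slicing with [:x.size(0)].
import torch

_sampler_indices = torch.arange(8)   # preallocated to a max batch size (assumption)
x = torch.randn(3, 4)                # current batch of hidden states

sampler_indices = torch.narrow(_sampler_indices, 0, 0, x.size(0))
assert torch.equal(sampler_indices, _sampler_indices[:x.size(0)])
print(sampler_indices)               # tensor([0, 1, 2])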

vllm/platforms/xpu.py

Lines changed: 4 additions & 1 deletion

@@ -91,7 +91,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             cache_config.block_size = 64
 
         # lazy import to avoid circular import
-        from vllm.config import CUDAGraphMode
+        from vllm.config import CompilationLevel, CUDAGraphMode
         compilation_config = vllm_config.compilation_config
         if compilation_config.cudagraph_mode is None or \
             compilation_config.cudagraph_mode.max_cudagraph_mode() \
@@ -100,6 +100,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 "cudagraphs. Fallback to cudagraph_mode=NONE")
             compilation_config.cudagraph_mode = CUDAGraphMode.NONE
 
+        if vllm_config.lora_config is not None:
+            compilation_config.level = CompilationLevel.NO_COMPILATION
+
         # check and update parallel config
         parallel_config = vllm_config.parallel_config
         parallel_config.worker_cls = "vllm.v1.worker.xpu_worker.XPUWorker"
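The last change guards the compilation level: when a LoRA config is present, the XPU platform falls back to no compilation. A minimal sketch of the same guard pattern, using placeholder config classes rather than the real vllm.config types:

# Hedged sketch with placeholder classes (not the real vllm.config types),
# showing the guard the patch adds: force the no-compilation level whenever
# a LoRA config is attached to the engine config.
from dataclasses import dataclass
from enum import Enum
from typing import Optional

class CompilationLevel(Enum):
    NO_COMPILATION = 0
    PIECEWISE = 3

@dataclass
class CompilationConfig:
    level: CompilationLevel = CompilationLevel.PIECEWISE

@dataclass
class EngineConfig:
    compilation_config: CompilationConfig
    lora_config: Optional[object] = None

def check_and_update_config(cfg: EngineConfig) -> None:
    # Mirror of the added guard: LoRA present -> disable compilation on XPU.
    if cfg.lora_config is not None:
        cfg.compilation_config.level = CompilationLevel.NO_COMPILATION

cfg = EngineConfig(CompilationConfig(), lora_config=object())
check_and_update_config(cfg)
print(cfg.compilation_config.level)  # CompilationLevel.NO_COMPILATION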
