Skip to content

Commit 43de8bc

Browse files
andyxningdjmmoss
authored andcommitted
[Bugfix] fix IntermediateTensors equal method (vllm-project#23027)
Signed-off-by: Andy Xie <[email protected]> Signed-off-by: Duncan Moss <[email protected]>
1 parent 445e353 commit 43de8bc

File tree

2 files changed

+45
-3
lines changed

2 files changed

+45
-3
lines changed

tests/test_sequence.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

44
import pytest
5+
import torch
56

67
from vllm.model_executor.layers.sampler import SamplerOutput
7-
from vllm.sequence import (CompletionSequenceGroupOutput, SequenceData,
8-
SequenceOutput)
8+
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
9+
SequenceData, SequenceOutput)
910

1011
from .core.utils import create_dummy_prompt
1112

@@ -98,3 +99,38 @@ def test_sequence_group_stage():
9899
assert seq_group.is_prefill() is True
99100
seq_group.update_num_computed_tokens(1)
100101
assert seq_group.is_prefill() is False
102+
103+
104+
def test_sequence_intermediate_tensors_equal():
105+
106+
class AnotherIntermediateTensors(IntermediateTensors):
107+
pass
108+
109+
intermediate_tensors = IntermediateTensors({})
110+
another_intermediate_tensors = AnotherIntermediateTensors({})
111+
assert intermediate_tensors != another_intermediate_tensors
112+
113+
empty_intermediate_tensors_1 = IntermediateTensors({})
114+
empty_intermediate_tensors_2 = IntermediateTensors({})
115+
assert empty_intermediate_tensors_1 == empty_intermediate_tensors_2
116+
117+
different_key_intermediate_tensors_1 = IntermediateTensors(
118+
{"1": torch.zeros([2, 4], dtype=torch.int32)})
119+
difference_key_intermediate_tensors_2 = IntermediateTensors(
120+
{"2": torch.zeros([2, 4], dtype=torch.int32)})
121+
assert (different_key_intermediate_tensors_1
122+
!= difference_key_intermediate_tensors_2)
123+
124+
same_key_different_value_intermediate_tensors_1 = IntermediateTensors(
125+
{"1": torch.zeros([2, 4], dtype=torch.int32)})
126+
same_key_different_value_intermediate_tensors_2 = IntermediateTensors(
127+
{"1": torch.zeros([2, 5], dtype=torch.int32)})
128+
assert (same_key_different_value_intermediate_tensors_1
129+
!= same_key_different_value_intermediate_tensors_2)
130+
131+
same_key_same_value_intermediate_tensors_1 = IntermediateTensors(
132+
{"1": torch.zeros([2, 4], dtype=torch.int32)})
133+
same_key_same_value_intermediate_tensors_2 = IntermediateTensors(
134+
{"1": torch.zeros([2, 4], dtype=torch.int32)})
135+
assert (same_key_same_value_intermediate_tensors_1 ==
136+
same_key_same_value_intermediate_tensors_2)

vllm/sequence.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1163,7 +1163,13 @@ def __len__(self):
11631163
return len(self.tensors)
11641164

11651165
def __eq__(self, other: object):
1166-
return isinstance(other, self.__class__) and self
1166+
if not isinstance(other, self.__class__):
1167+
return False
1168+
if self.tensors.keys() != other.tensors.keys():
1169+
return False
1170+
return all(
1171+
torch.equal(self.tensors[k], other.tensors[k])
1172+
for k in self.tensors)
11671173

11681174
def __repr__(self) -> str:
11691175
return f"IntermediateTensors(tensors={self.tensors})"

0 commit comments

Comments
 (0)