|
2 | 2 | # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
3 | 3 |
|
4 | 4 | import pytest
|
| 5 | +import torch |
5 | 6 |
|
6 | 7 | from vllm.model_executor.layers.sampler import SamplerOutput
|
7 |
| -from vllm.sequence import (CompletionSequenceGroupOutput, SequenceData, |
8 |
| - SequenceOutput) |
| 8 | +from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, |
| 9 | + SequenceData, SequenceOutput) |
9 | 10 |
|
10 | 11 | from .core.utils import create_dummy_prompt
|
11 | 12 |
|
@@ -98,3 +99,38 @@ def test_sequence_group_stage():
|
98 | 99 | assert seq_group.is_prefill() is True
|
99 | 100 | seq_group.update_num_computed_tokens(1)
|
100 | 101 | assert seq_group.is_prefill() is False
|
| 102 | + |
| 103 | + |
| 104 | +def test_sequence_intermediate_tensors_equal(): |
| 105 | + |
| 106 | + class AnotherIntermediateTensors(IntermediateTensors): |
| 107 | + pass |
| 108 | + |
| 109 | + intermediate_tensors = IntermediateTensors({}) |
| 110 | + another_intermediate_tensors = AnotherIntermediateTensors({}) |
| 111 | + assert intermediate_tensors != another_intermediate_tensors |
| 112 | + |
| 113 | + empty_intermediate_tensors_1 = IntermediateTensors({}) |
| 114 | + empty_intermediate_tensors_2 = IntermediateTensors({}) |
| 115 | + assert empty_intermediate_tensors_1 == empty_intermediate_tensors_2 |
| 116 | + |
| 117 | + different_key_intermediate_tensors_1 = IntermediateTensors( |
| 118 | + {"1": torch.zeros([2, 4], dtype=torch.int32)}) |
| 119 | + difference_key_intermediate_tensors_2 = IntermediateTensors( |
| 120 | + {"2": torch.zeros([2, 4], dtype=torch.int32)}) |
| 121 | + assert (different_key_intermediate_tensors_1 |
| 122 | + != difference_key_intermediate_tensors_2) |
| 123 | + |
| 124 | + same_key_different_value_intermediate_tensors_1 = IntermediateTensors( |
| 125 | + {"1": torch.zeros([2, 4], dtype=torch.int32)}) |
| 126 | + same_key_different_value_intermediate_tensors_2 = IntermediateTensors( |
| 127 | + {"1": torch.zeros([2, 5], dtype=torch.int32)}) |
| 128 | + assert (same_key_different_value_intermediate_tensors_1 |
| 129 | + != same_key_different_value_intermediate_tensors_2) |
| 130 | + |
| 131 | + same_key_same_value_intermediate_tensors_1 = IntermediateTensors( |
| 132 | + {"1": torch.zeros([2, 4], dtype=torch.int32)}) |
| 133 | + same_key_same_value_intermediate_tensors_2 = IntermediateTensors( |
| 134 | + {"1": torch.zeros([2, 4], dtype=torch.int32)}) |
| 135 | + assert (same_key_same_value_intermediate_tensors_1 == |
| 136 | + same_key_same_value_intermediate_tensors_2) |
0 commit comments