Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions tests/v1/spec_decode/test_eagle.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,14 +115,15 @@ def test_prepare_inputs():
("eagle3", lambda k: _create_proposer("eagle3", k), eagle3_dir,
('model', 'embed_tokens')),
])
@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
@mock.patch('vllm.v1.spec_decode.eagle.ModelRegistry')
@mock.patch('vllm.v1.spec_decode.eagle.get_model_loader')
@mock.patch('vllm.v1.spec_decode.eagle.set_default_torch_dtype')
@mock.patch('vllm.v1.spec_decode.eagle.set_current_vllm_config')
def test_load_model(mock_set_config, mock_set_dtype, mock_get_loader,
mock_registry, mock_get_layers, method, proposer_helper,
draft_model_dir, target_attribute_path):
mock_registry, mock_get_layers, mock_get_pp_group, method,
proposer_helper, draft_model_dir, target_attribute_path):

# Setup mock for model class
mock_model_cls = mock.MagicMock()
Expand Down Expand Up @@ -158,6 +159,10 @@ def __exit__(self, exc_type, exc_val, exc_tb):
# Make mock_get_layers return different values for each call
mock_get_layers.side_effect = [target_attn_layers, all_attn_layers]

mock_pp_group = mock.MagicMock()
mock_pp_group.world_size = 2 if method == "eagle" else 1
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we also need to cover eagle3?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

else fall into eagle3, algin with logic in 177 to 184

mock_get_pp_group.return_value = mock_pp_group

# Setup model loader mock
mock_loader = mock.MagicMock()
mock_get_loader.return_value = mock_loader
Expand Down