[Quantization] Improve BitsAndBytesModelLoader #20242
Merged
@@ -20,8 +20,6 @@
     get_tensor_model_parallel_world_size)
 # yapf: enable
 from vllm.logger import init_logger
-# yapf conflicts with isort for this block
-# yapf: disable
 from vllm.model_executor.layers.linear import (LinearBase,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
@@ -39,6 +37,8 @@
     set_weight_attrs)
 from vllm.platforms import current_platform

+# yapf conflicts with isort for this block
+
 logger = init_logger(__name__)
@@ -54,11 +54,17 @@ def __init__(self, load_config: LoadConfig):
         self.unsharded_weights_modules: list[str] = []
         # Save the module names that are sharded by column.
         self.column_sharded_weights_modules: list[str] = []
+        # Modules whose weights might have fused on disk
+        # we need their output_sizes to make shard in flight correctly with TP
+        self.maybe_fused_weights_modules: dict[str, list[int]] = {}
         # Store all module names (from transformers) that support
         # BNB quantization.
         self.target_modules: list[str] = []
         # mapping weight names from transformers to vllm.
         self.weight_mapper: Callable = lambda name: name
+        self.pre_quant: bool = False
+        self.load_8bit: bool = False
+        self.is_pool_model: bool = False

     def _get_weight_files(
         self,
@@ -134,13 +140,14 @@ def _prepare_weights(self, model_name_or_path: str,
         return hf_weights_files, use_safetensors

     def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
-        def _maybe_pool_model(module_name:str):
+
+        def _maybe_pool_model(module_name: str):
             # For pool model, we need to add the prefix `model.`
             # for the weight name if possible.
             if self.is_pool_model and self.target_modules[0]. \
                 startswith("model.") and not module_name.startswith(
                     "model."):
-                return "model."+module_name
+                return "model." + module_name

             return module_name
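For readers skimming the hunk above: the prefix rule can be read as a small pure function. The following is an illustrative sketch only, with standalone parameters instead of the loader's actual attributes:

    def maybe_pool_model(module_name: str, is_pool_model: bool,
                         target_modules: list[str]) -> str:
        # Pooling models expose their weights under a leading "model."
        # prefix; re-add it when a checkpoint name lacks it so lookups
        # against target_modules line up.
        if (is_pool_model and target_modules[0].startswith("model.")
                and not module_name.startswith("model.")):
            return "model." + module_name
        return module_name

    # A pooling-model weight name gains the prefix ...
    assert maybe_pool_model("layers.0.proj", True,
                            ["model.layers.0.proj"]) == "model.layers.0.proj"
    # ... while non-pooling models pass through unchanged.
    assert maybe_pool_model("layers.0.proj", False,
                            ["model.layers.0.proj"]) == "layers.0.proj"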
@@ -159,17 +166,14 @@ def _maybe_pool_model(module_name:str):
             # mapping weight names from transformers to vllm while preserving
             # original names.
             mapped_name = self.weight_mapper(org_name)
-            mapped_name=_maybe_pool_model(mapped_name)
-
+            mapped_name = _maybe_pool_model(mapped_name)

             yield org_name, mapped_name, param

     def _get_quantized_weights_iterator(
         self,
         model_name_or_path: str,
         revision: Optional[str],
-        pre_quant: bool,
-        load_8bit: bool,
     ) -> tuple[Generator[tuple[str, torch.Tensor], None, None], dict[str,
                                                                      Any]]:
         """Get an iterator to the model weights with bitsandbytes quantization,
@@ -192,8 +196,8 @@ def _get_quantized_weights_iterator(

         quant_state_dict: dict[str, Any] = {}

-        if pre_quant:
-            if load_8bit:
+        if self.pre_quant:
+            if self.load_8bit:
                 return self._quantized_8bit_generator(
                     hf_weights_files, use_safetensors,
                     quant_state_dict), quant_state_dict
@@ -390,10 +394,13 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,
             yield org_weight_name, processed_weight

     def _get_bnb_target_modules(self, model: nn.Module) -> None:
-
+        """
+        Identify and collect all modules that support BitsAndBytes
+        quantization.
+        """
         for name, module in model.named_modules():
-            if (isinstance(module, LinearBase) and
-                    hasattr(module.quant_method, "quant_config")):
+            if (isinstance(module, LinearBase)
+                    and hasattr(module.quant_method, "quant_config")):
                 if modules_info := self.modules_mapping.get_sub_modules(name):
                     # Map vllm's names to transformers's names.
                     rep_name, sub_modules = modules_info
@@ -409,29 +416,11 @@ def _get_bnb_target_modules(self, model: nn.Module) -> None:
             ), "vllm currently does not support BNB quantization for"
             f" {type(model).__name__}"

-    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
-        if not hasattr(model, "load_weights"):
-            raise AttributeError(
-                "The required method 'load_weights' is not defined in class"
-                f" {type(model).__name__}.")
-
-        if not hasattr(model, "packed_modules_mapping"):
-            raise AttributeError(
-                f"Model {type(model).__name__} does not support BitsAndBytes "
-                "quantization yet. No 'packed_modules_mapping' found.")
-        self.is_pool_model=is_pooling_model(model)
-
-        self.modules_mapping = ParamMapping(get_packed_modules_mapping(model))
-
-        # For some models like Molmo, we need to use hf_to_vllm_mapper
-        # to ensure correct loading of weights.
-        if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
-            self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name)
-
-        # Modules whose weights might have fused on disk
-        # we need their output_sizes to make shard in flight correctly with TP
-        self.maybe_fused_weights_modules: dict[str, list[int]] = {}
-        self._get_bnb_target_modules(model)
+    def _classify_module_sharding(self, model: nn.Module):
+        """
+        Categorize modules based on their weight sharding requirements
+        for tensor parallelism.
+        """
         for name, module in model.named_modules():
             # Some modules like `ReplicatedLinear` should not have their weights
             # sharded. The reason for implementing it this way is to avoid new
@@ -449,40 +438,71 @@ def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
             elif isinstance(module, (RowParallelLinear, )):
                 self.column_sharded_weights_modules.append(name)

-        self.model_type = type(model).__name__
+    def _verify_model_compatibility(self, model: nn.Module,
+                                    model_config: ModelConfig) -> None:
+        """
+        Verify that the model is compatible with BitsAndBytes quantization.
+        """
+        if not hasattr(model, "load_weights"):
+            raise AttributeError(
+                "The required method 'load_weights' is not defined in class"
+                f" {type(model).__name__}.")

-        logger.info("Loading weights with BitsAndBytes quantization. "
-                    "May take a while ...")
+        if not hasattr(model, "packed_modules_mapping"):
+            raise AttributeError(
+                f"Model {type(model).__name__} does not support BitsAndBytes "
+                "quantization yet. No 'packed_modules_mapping' found.")

         quant_config = getattr(model_config.hf_config, "quantization_config",
                                None)

-        pre_quant = False
         if quant_config is not None:
             quant_method = quant_config.get("quant_method")
             if quant_method == "bitsandbytes":
-                pre_quant = True
+                self.pre_quant = True
             else:
                 raise ValueError(
                     f"BitsAndBytes loader does not support {quant_method} "
                     "quantization")

         # The quant_states in pre_quantized models cannot work with a split
         # weight tensor. So TP does not work with pre_quantized bnb models.
-        if pre_quant and get_tensor_model_parallel_world_size() > 1:
+        if self.pre_quant and get_tensor_model_parallel_world_size() > 1:
             raise ValueError(
                 "Prequant BitsAndBytes models with tensor parallelism is not "
                 "supported. Please try with pipeline parallelism.")
+        if self.pre_quant:
+            self.load_8bit = quant_config.get("load_in_8bit", False)

+    def _initialize_loader_state(self, model: nn.Module,
+                                 model_config: ModelConfig) -> None:
+        """
+        Initialize the loader's internal state based on the model and
+        configuration.
+        """
+        self.is_pool_model = is_pooling_model(model)
+        self.modules_mapping = ParamMapping(get_packed_modules_mapping(model))

-        load_8bit = False
-        if pre_quant:
-            load_8bit = quant_config.get("load_in_8bit", False)
+        # For some models like Molmo, we need to use hf_to_vllm_mapper
+        # to ensure correct loading of weights.
+        if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
+            self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name)

-        qweight_iterator, quant_state_dict = (
-            self._get_quantized_weights_iterator(model_config.model,
-                                                 model_config.revision,
-                                                 pre_quant, load_8bit))
+        self._get_bnb_target_modules(model)
+        self._classify_module_sharding(model)

+    def load_weights(self, model: nn.Module,
+                     model_config: ModelConfig) -> None:
+        self._verify_model_compatibility(model, model_config)
+        self._initialize_loader_state(model, model_config)

+        logger.info("Loading weights with BitsAndBytes quantization. "
+                    "May take a while ...")
+        qweight_iterator, quant_state_dict = (
+            self._get_quantized_weights_iterator(
+                model_config.model,
+                model_config.revision,
+            ))
         weights_to_load = {name for name, _ in model.named_parameters()}
         loaded_weights = model.load_weights(qweight_iterator)
         # Some models may have weights loading tracker unimplemented.
@@ -562,10 +582,11 @@ def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
             offsets = torch.tensor(offsets).cpu()
             set_weight_attrs(param, {"bnb_shard_offsets": offsets})

-            if load_8bit:
+            if self.load_8bit:
                 set_weight_attrs(
                     param, {"matmul_state": [None] * len(quant_states)})
+        torch.cuda.empty_cache()

     def download_model(self, model_config: ModelConfig) -> None:
         self._prepare_weights(model_config.model, model_config.revision)
Review comment: This docstring's line breaks can be improved for better readability. Consider formatting it on a single line or adjusting the breaks to avoid awkward mid-sentence splits.

Suggested change:

    """Identify and collect all modules that support BitsAndBytes quantization."""