Merged
123 changes: 72 additions & 51 deletions vllm/model_executor/model_loader/bitsandbytes_loader.py
@@ -20,8 +20,6 @@
get_tensor_model_parallel_world_size)
# yapf: enable
from vllm.logger import init_logger
# yapf conflicts with isort for this block
# yapf: disable
from vllm.model_executor.layers.linear import (LinearBase,
MergedColumnParallelLinear,
QKVParallelLinear,
@@ -39,6 +37,8 @@
set_weight_attrs)
from vllm.platforms import current_platform

# yapf conflicts with isort for this block

logger = init_logger(__name__)


@@ -54,11 +54,17 @@ def __init__(self, load_config: LoadConfig):
self.unsharded_weights_modules: list[str] = []
# Save the module names that are sharded by column.
self.column_sharded_weights_modules: list[str] = []
# Modules whose weights might be fused on disk;
# we need their output_sizes to shard them in flight correctly with TP.
self.maybe_fused_weights_modules: dict[str, list[int]] = {}
# Store all module names (from transformers) that support
# BNB quantization.
self.target_modules: list[str] = []
# mapping weight names from transformers to vllm.
self.weight_mapper: Callable = lambda name: name
self.pre_quant: bool = False
self.load_8bit: bool = False
self.is_pool_model: bool = False

def _get_weight_files(
self,
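As an aside on the new maybe_fused_weights_modules field added to __init__ above: some checkpoints store two logical projections as one fused tensor on disk, and the recorded output_sizes let the loader split that tensor into per-module shards before applying TP. A minimal sketch of that split, with a hypothetical module name and sizes (not taken from this PR):

    import torch

    # Hypothetical example: "mlp.gate_up_proj" was saved fused on disk and is
    # built from two sub-projections of 11008 rows each.
    maybe_fused_weights_modules = {"mlp.gate_up_proj": [11008, 11008]}

    def split_fused_weight(name: str, weight: torch.Tensor):
        # Split the fused on-disk tensor into per-sub-module shards along
        # dim 0, using the recorded output sizes.
        return tuple(torch.split(weight, maybe_fused_weights_modules[name], dim=0))

    fused = torch.empty(22016, 4096)
    gate, up = split_fused_weight("mlp.gate_up_proj", fused)
    assert gate.shape == up.shape == (11008, 4096)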
@@ -134,13 +140,14 @@ def _prepare_weights(self, model_name_or_path: str,
return hf_weights_files, use_safetensors

def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
def _maybe_pool_model(module_name:str):

def _maybe_pool_model(module_name: str):
# For pool model, we need to add the prefix `model.`
# for the weight name if possible.
if self.is_pool_model and self.target_modules[0]. \
startswith("model.") and not module_name.startswith(
"model."):
return "model."+module_name
return "model." + module_name

return module_name
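The nested helper above is small enough to demonstrate standalone. A self-contained sketch of the same prefixing rule, with a hypothetical target-module list (the weight names are illustrative only):

    # Standalone sketch of the _maybe_pool_model rule; the target_modules
    # entries below are hypothetical.
    is_pool_model = True
    target_modules = ["model.layers.0.self_attn.qkv_proj"]

    def maybe_pool_model(module_name: str) -> str:
        # Pooling models expect a leading "model." on weight names when the
        # collected target modules carry that prefix.
        if (is_pool_model and target_modules[0].startswith("model.")
                and not module_name.startswith("model.")):
            return "model." + module_name
        return module_name

    assert maybe_pool_model("layers.0.mlp.down_proj") == "model.layers.0.mlp.down_proj"
    assert maybe_pool_model("model.embed_tokens") == "model.embed_tokens"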

@@ -159,17 +166,14 @@ def _maybe_pool_model(module_name:str):
# mapping weight names from transformers to vllm while preserving
# original names.
mapped_name = self.weight_mapper(org_name)
mapped_name=_maybe_pool_model(mapped_name)

mapped_name = _maybe_pool_model(mapped_name)

yield org_name, mapped_name, param

def _get_quantized_weights_iterator(
self,
model_name_or_path: str,
revision: Optional[str],
pre_quant: bool,
load_8bit: bool,
) -> tuple[Generator[tuple[str, torch.Tensor], None, None], dict[str,
Any]]:
"""Get an iterator to the model weights with bitsandbytes quantization,
@@ -192,8 +196,8 @@ def _get_quantized_weights_iterator(

quant_state_dict: dict[str, Any] = {}

if pre_quant:
if load_8bit:
if self.pre_quant:
if self.load_8bit:
return self._quantized_8bit_generator(
hf_weights_files, use_safetensors,
quant_state_dict), quant_state_dict
@@ -390,10 +394,13 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,
yield org_weight_name, processed_weight

def _get_bnb_target_modules(self, model: nn.Module) -> None:

"""
Identify and collect all modules that support BitsAndBytes
quantization.
"""
Comment on lines +397 to +400 (Contributor, severity: medium):

This docstring's line breaks can be improved for better readability. Consider formatting it on a single line or adjusting the breaks to avoid awkward mid-sentence splits.

Suggested change:

        """Identify and collect all modules that support BitsAndBytes quantization."""

for name, module in model.named_modules():
if (isinstance(module, LinearBase) and
hasattr(module.quant_method, "quant_config")):
if (isinstance(module, LinearBase)
and hasattr(module.quant_method, "quant_config")):
if modules_info := self.modules_mapping.get_sub_modules(name):
# Map vllm's names to transformers's names.
rep_name, sub_modules = modules_info
@@ -409,29 +416,11 @@ def _get_bnb_target_modules(self, model: nn.Module) -> None:
), "vllm currently does not support BNB quantization for"
f" {type(model).__name__}"

def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
if not hasattr(model, "load_weights"):
raise AttributeError(
"The required method 'load_weights' is not defined in class"
f" {type(model).__name__}.")

if not hasattr(model, "packed_modules_mapping"):
raise AttributeError(
f"Model {type(model).__name__} does not support BitsAndBytes "
"quantization yet. No 'packed_modules_mapping' found.")
self.is_pool_model=is_pooling_model(model)

self.modules_mapping = ParamMapping(get_packed_modules_mapping(model))

# For some models like Molmo, we need to use hf_to_vllm_mapper
# to ensure correct loading of weights.
if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name)

# Modules whose weights might have fused on disk
# we need their output_sizes to make shard in flight correctly with TP
self.maybe_fused_weights_modules: dict[str, list[int]] = {}
self._get_bnb_target_modules(model)
def _classify_module_sharding(self, model: nn.Module):
"""
Categorize modules based on their weight sharding requirements
for tensor parallelism.
"""
Comment on lines +420 to +423 (Contributor, severity: medium):

This docstring's line breaks can be improved for better readability. Consider formatting it on a single line or adjusting the breaks to avoid awkward mid-sentence splits.

Suggested change:

        """Categorize modules based on their weight sharding requirements for tensor parallelism."""

for name, module in model.named_modules():
# Some modules like `ReplicatedLinear` should not have their weights
# sharded. The reason for implementing it this way is to avoid new
@@ -449,40 +438,71 @@ def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
elif isinstance(module, (RowParallelLinear, )):
self.column_sharded_weights_modules.append(name)

self.model_type = type(model).__name__
def _verify_model_compatibility(self, model: nn.Module,
model_config: ModelConfig) -> None:
"""
Verify that the model is compatible with BitsAndBytes quantization.
"""
if not hasattr(model, "load_weights"):
raise AttributeError(
"The required method 'load_weights' is not defined in class"
f" {type(model).__name__}.")

logger.info("Loading weights with BitsAndBytes quantization. "
"May take a while ...")
if not hasattr(model, "packed_modules_mapping"):
raise AttributeError(
f"Model {type(model).__name__} does not support BitsAndBytes "
"quantization yet. No 'packed_modules_mapping' found.")

quant_config = getattr(model_config.hf_config, "quantization_config",
None)

pre_quant = False
if quant_config is not None:
quant_method = quant_config.get("quant_method")
if quant_method == "bitsandbytes":
pre_quant = True
self.pre_quant = True
else:
raise ValueError(
f"BitsAndBytes loader does not support {quant_method} "
"quantization")

# The quant_states in pre_quantized models cannot work with a split
# weight tensor. So TP does not work with pre_quantized bnb models.
if pre_quant and get_tensor_model_parallel_world_size() > 1:
if self.pre_quant and get_tensor_model_parallel_world_size() > 1:
raise ValueError(
"Prequant BitsAndBytes models with tensor parallelism is not "
"supported. Please try with pipeline parallelism.")
if self.pre_quant:
self.load_8bit = quant_config.get("load_in_8bit", False)

def _initialize_loader_state(self, model: nn.Module,
model_config: ModelConfig) -> None:
"""
Initialize the loader's internal state based on the model and
configuration.
"""
Comment on lines +478 to +481 (Contributor, severity: medium):

This docstring's line breaks can be improved for better readability. Consider formatting it on a single line or adjusting the breaks to avoid awkward mid-sentence splits.

Suggested change:

        """Initialize the loader's internal state based on the model and configuration."""

self.is_pool_model = is_pooling_model(model)
self.modules_mapping = ParamMapping(get_packed_modules_mapping(model))

load_8bit = False
if pre_quant:
load_8bit = quant_config.get("load_in_8bit", False)
# For some models like Molmo, we need to use hf_to_vllm_mapper
# to ensure correct loading of weights.
if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name)

qweight_iterator, quant_state_dict = (
self._get_quantized_weights_iterator(model_config.model,
model_config.revision,
pre_quant, load_8bit))
self._get_bnb_target_modules(model)
self._classify_module_sharding(model)

def load_weights(self, model: nn.Module,
model_config: ModelConfig) -> None:

self._verify_model_compatibility(model, model_config)
self._initialize_loader_state(model, model_config)

logger.info("Loading weights with BitsAndBytes quantization. "
"May take a while ...")
qweight_iterator, quant_state_dict = (
self._get_quantized_weights_iterator(
model_config.model,
model_config.revision,
))
weights_to_load = {name for name, _ in model.named_parameters()}
loaded_weights = model.load_weights(qweight_iterator)
# Some models may have weights loading tracker unimplemented.
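To summarize the two rules that _verify_model_compatibility enforces earlier in this hunk (only bitsandbytes quant configs are accepted, and pre-quantized checkpoints reject TP), here is a condensed, self-contained sketch; the function name and example config dicts are illustrative, not part of the PR:

    # Condensed sketch of the compatibility rules in _verify_model_compatibility.
    # The example config dicts are hypothetical.
    def check_bnb_compat(quant_config, tp_world_size: int):
        pre_quant = False
        if quant_config is not None:
            quant_method = quant_config.get("quant_method")
            if quant_method == "bitsandbytes":
                pre_quant = True
            else:
                raise ValueError(
                    f"BitsAndBytes loader does not support {quant_method} "
                    "quantization")
        # quant_states of pre-quantized weights cannot be split, so TP is out.
        if pre_quant and tp_world_size > 1:
            raise ValueError(
                "Prequant BitsAndBytes models with tensor parallelism is not "
                "supported. Please try with pipeline parallelism.")
        return pre_quant, (quant_config or {}).get("load_in_8bit", False)

    assert check_bnb_compat({"quant_method": "bitsandbytes"}, 1) == (True, False)
    assert check_bnb_compat(None, 4) == (False, False)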
@@ -562,10 +582,11 @@ def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
offsets = torch.tensor(offsets).cpu()
set_weight_attrs(param, {"bnb_shard_offsets": offsets})

if load_8bit:
if self.load_8bit:
set_weight_attrs(
param, {"matmul_state": [None] * len(quant_states)})
torch.cuda.empty_cache()

def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model, model_config.revision)

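For end-to-end context, this loader is what runs when vLLM is pointed at a bitsandbytes checkpoint. A hedged usage sketch follows; the model ID is only an example, and flag requirements vary across vLLM versions (older releases require load_format="bitsandbytes" explicitly):

    from vllm import LLM, SamplingParams

    # Example pre-quantized 4-bit checkpoint; any BNB-quantized model works here.
    llm = LLM(
        model="unsloth/tinyllama-bnb-4bit",
        quantization="bitsandbytes",
        load_format="bitsandbytes",
        dtype="bfloat16",
    )
    out = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
    print(out[0].outputs[0].text)

Note that, per the check in this diff, combining a pre-quantized BNB model with tensor_parallel_size > 1 raises a ValueError whose message suggests pipeline parallelism instead.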