System Info
This code, which uses bitsandbytes, now fails with newer Triton versions with the error "No module named 'triton.ops'":
from transformers import LlamaForCausalLM
from transformers import BitsAndBytesConfig
model = 'facebook/opt-350m'
model = LlamaForCausalLM.from_pretrained(model, quantization_config=BitsAndBytesConfig(load_in_8bit=True))
The reason is that int8_matmul_mixed_dequantize.py and int8_matmul_rowwise_dequantize.py check whether Triton is available, but do not take into account that newer Triton versions no longer include triton.ops.matmul_perf_model.
According to the Triton issue tracker, that module was intended more as a sample and has now been moved to the triton-lang/kernels project:
https://github.com/triton-lang/kernels/
triton-lang/kernels does not seem to offer a way to install it with pip (or I did not figure it out), so the matmul_perf_model.py code would probably need to be vendored into the bitsandbytes project.
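The root cause can be confirmed in isolation, without transformers or bitsandbytes, by attempting the same import that those two files perform; on Triton builds that have dropped triton.ops it raises the error shown in the stacktrace below:

# The import at the top of both files; on newer Triton releases it
# raises ModuleNotFoundError: No module named 'triton.ops'.
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time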
I got bitsandbytes to work by hacking int8_matmul_mixed_dequantize.py and int8_matmul_rowwise_dequantize.py to behave as if Triton were not installed, changing both functions to only return None:
import torch

def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias):
    return None

and

import torch

def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias):
    return None
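A less invasive variant of the same workaround (only a sketch, not tested against every bitsandbytes version) would be to guard the problematic import itself, so each file falls back to the stub only when triton.ops is actually missing:

import torch

try:
    # Present on older Triton releases, removed on newer ones.
    from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
    HAS_TRITON_OPS = True
except ImportError:
    HAS_TRITON_OPS = False

def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias):
    if not HAS_TRITON_OPS:
        return None  # same fallback as the hack above
    ...  # original Triton kernel path would go here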
Reproduction
Build a fairly recent Triton version; due to other restrictions, I used the Triton source code from October 29, 2024, at commit
ebce7f3a62af5242bbb3fe05876c5b3995eb2988
Execute the following code and you should get the error about the missing triton.ops module:
from transformers import LlamaForCausalLM
from transformers import BitsAndBytesConfig
model = 'facebook/opt-350m'
model = LlamaForCausalLM.from_pretrained(model, quantization_config=BitsAndBytesConfig(load_in_8bit=True))
Stacktrace
Traceback (most recent call last):
File "/opt/rocm_sdk_612/lib/python3.11/site-packages/transformers/utils/import_utils.py", line 1817, in _get_module
return importlib.import_module("." + module_name, self.__name__)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/rocm_sdk_612/lib/python3.11/importlib/__init__.py", line 126, in import_module
return _bootstrap._gcd_import(name[level:], package, level)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "<frozen importlib._bootstrap>", line 1204, in _gcd_import
File "<frozen importlib._bootstrap>", line 1176, in _find_and_load
File "<frozen importlib._bootstrap>", line 1147, in _find_and_load_unlocked
File "<frozen importlib._bootstrap>", line 690, in _load_unlocked
File "<frozen importlib._bootstrap_external>", line 940, in exec_module
File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
File "/opt/rocm_sdk_612/lib/python3.11/site-packages/transformers/integrations/bitsandbytes.py", line 21, in <module>
import bitsandbytes as bnb
File "/opt/rocm_sdk_612/lib/python3.11/site-packages/bitsandbytes/__init__.py", line 21, in <module>
from .nn import modules
File "/opt/rocm_sdk_612/lib/python3.11/site-packages/bitsandbytes/nn/__init__.py", line 17, in <module>
from .triton_based_modules import (
File "/opt/rocm_sdk_612/lib/python3.11/site-packages/bitsandbytes/nn/triton_based_modules.py", line 7, in <module>
from bitsandbytes.triton.int8_matmul_mixed_dequantize import (
File "/opt/rocm_sdk_612/lib/python3.11/site-packages/bitsandbytes/triton/int8_matmul_mixed_dequantize.py", line 12, in <module>
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
ModuleNotFoundError: No module named 'triton.ops'
Expected behavior
bitsandbytes should handle the situation where Triton is available but the triton.ops package cannot be imported.
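For example (a sketch of one possible check, not the actual bitsandbytes code), the availability test could probe for the triton.ops submodule instead of only the top-level triton package:

import importlib.util

def is_triton_ops_available():
    # True only when both triton and its ops submodule can be found;
    # newer Triton releases ship without triton.ops.
    if importlib.util.find_spec("triton") is None:
        return False
    return importlib.util.find_spec("triton.ops") is not None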