1 change: 1 addition & 0 deletions .gitmodules
@@ -4,3 +4,4 @@
[submodule "3rdparty/llama.cpp"]
path = 3rdparty/llama.cpp
url = https://github.com/kaleid-liner/llama.cpp
branch = master-rebased
2 changes: 1 addition & 1 deletion 3rdparty/llama.cpp
Submodule llama.cpp updated 916 files
4 changes: 4 additions & 0 deletions README.md
@@ -12,6 +12,8 @@

## News

- 10/10/2024 🚀🚀: By updating and rebasing our llama.cpp version, T-MAC now supports more models (e.g., qwen2), and end-to-end performance is further improved by 10~15%! Try qwen2 using [the official GPTQ model](https://huggingface.co/Qwen/Qwen2-7B-Instruct-GPTQ-Int4) (see the download sketch after this README diff).

- 08/21/2024 🎉🎉: T-MAC paper is accepted by EuroSys 2025.

- 08/17/2024 🚀: T-MAC now supports 1/2/4-bit quantized models of (almost) any architecture in GPTQ format.
@@ -32,6 +34,8 @@ T-MAC achieves a token generation throughput of 20 tokens/sec with a single core

## End-2-End Speedup

> All of the following data was profiled with llama.cpp b2794 (May 2024). After updating the llama.cpp version, the latest T-MAC and the baseline are both further optimized by 10~15%.

We evaluate the token generation performance of different models on five different devices: Surface Laptop 7, Apple M2-Ultra, Jetson AGX Orin, Raspberry Pi 5 and Surface Book 3. Check [datasheet](docs/profiling_data.md) for more details.

> We evaluate BitNet-3B and Llama-2-7B (W2) with T-MAC 2-bit and llama.cpp Q2_K, and evaluate Llama-2-7B (W4) with T-MAC 4-bit and llama.cpp Q4_0.
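
As a companion to the 10/10/2024 news item above, here is a minimal sketch of fetching the official Qwen2 GPTQ checkpoint with huggingface_hub before handing it to the T-MAC pipeline. The local directory name is an arbitrary example, and the exact tools/run_pipeline.py flags are omitted because they depend on the target device.

```python
# Sketch: download the Qwen2 GPTQ checkpoint referenced in the news item above.
# Assumes huggingface_hub is installed; local_dir is an arbitrary example path.
from huggingface_hub import snapshot_download

model_dir = snapshot_download(
    repo_id="Qwen/Qwen2-7B-Instruct-GPTQ-Int4",
    local_dir="models/Qwen2-7B-Instruct-GPTQ-Int4",
)
print(f"Downloaded to {model_dir}; point tools/run_pipeline.py at this directory.")
```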
43 changes: 41 additions & 2 deletions python/t_mac/model_utils.py
@@ -5,9 +5,13 @@
import os
import logging
import json
import configparser

import numpy as np

from t_mac.weights import preprocess_weights



logger = logging.getLogger("model_utils")

@@ -113,7 +117,7 @@ def __init__(self, dir_model: Path):
self.is_safetensors = len(self.part_names) > 0
if not self.is_safetensors:
self.part_names = _Model.get_model_part_names(self.dir_model, "pytorch_model", ".bin")

@staticmethod
def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> List[str]:
part_names: list[str] = []
@@ -170,7 +174,7 @@ def extract_kernel_shapes(self):
raise RuntimeError("Models in {} not in GPTQ format".format(self.dir_model))

return ks

@staticmethod
def load_hparams(dir_model):
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
@@ -192,6 +196,7 @@ def extract_kernel_shapes(model_arch: Optional[str] = "gptq-auto", model_dir: Op

def get_quantization_config(model_dir: Optional[str] = None) -> dict:
hparams = _Model.load_hparams(Path(model_dir))
# GPTQ
quantization_config = hparams.get("quantization_config", {})
desc_act = quantization_config.get("desc_act", False)
assert not desc_act, "desc_act=True currently unsupported by T-MAC"
@@ -200,11 +205,45 @@ def get_quantization_config(model_dir: Optional[str] = None) -> dict:
bits = quantization_config.get("bits", 0)
sym = quantization_config.get("sym", False)
quant_method = quantization_config.get("quant_method", "")
# BitNet
weight_bits = hparams.get("weight_bits", 0)

return {
"quantizer": quantizer,
"group_size": group_size,
"bits": bits,
"sym": sym,
"quant_method": quant_method,
"weight_bits": weight_bits,
}


def preprocess_for_t_mac(
kcfg_file: str,
w: np.ndarray,
scales: np.ndarray,
zeros: Optional[np.ndarray] = None,
bits: int = 2,
g: int = 4,
) -> np.ndarray:

M, K = w.shape
cf = configparser.ConfigParser()
cf.read(kcfg_file)
secs = cf.sections()
found = False
for sec in secs:
sec_splits = str(sec).split('_')
if sec_splits[-4] == "m" + str(M * bits) and sec_splits[-3] == "k" + str(K):
bm = int(cf.get(sec, 'bm'))
kfactor = int(cf.get(sec, 'kfactor'))
simd_n_in = int(cf.get(sec, 'simd_n_in'))
simd_n_out = int(cf.get(sec, 'simd_n_out'))
found = True
break

if not found:
raise KeyError("GEMM of shape ({}, {}) is not found in {}. Please compile the kernels using T-MAC first.".format(M, K, kcfg_file))

w, scales = preprocess_weights(w, scales, zeros, bits=bits, g=g, bm=bm, kfactor=kfactor, simd_n_in=simd_n_in, simd_n_out=simd_n_out)
return np.concatenate([w.flatten(), scales.astype(np.float32).view(np.uint8).flatten()])
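
For orientation, here is a minimal usage sketch of the new preprocess_for_t_mac helper added in this file. The kernel-config path, weight shape, and group size are hypothetical placeholders; the corresponding (M*bits, K) GEMM shape must already exist in the compiled kernel config, otherwise the helper raises KeyError as shown above.

```python
# Sketch only: paths, shapes, and the group size below are hypothetical placeholders.
import numpy as np
from t_mac.model_utils import preprocess_for_t_mac

M, K, bits, group_size = 4096, 4096, 2, 128
w = np.random.randint(0, 2 ** bits, size=(M, K), dtype=np.uint8)    # quantized weights
scales = np.random.rand(M, K // group_size).astype(np.float32)      # per-group scales

# kcfg.ini must come from a prior T-MAC kernel compilation covering this GEMM shape.
packed = preprocess_for_t_mac("install/lib/kcfg.ini", w, scales, bits=bits)
print(packed.shape, packed.dtype)  # flat buffer: packed weights followed by scale bytes
```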
2 changes: 1 addition & 1 deletion python/t_mac/version.py
@@ -1 +1 @@
__version__ = "1.0.0a3"
__version__ = "1.0.0a4"
40 changes: 21 additions & 19 deletions tools/run_pipeline.py
@@ -12,7 +12,7 @@
logger = logging.getLogger("run_pipeline")


def run_command(command, pwd):
def run_command(command, pwd, ignore_errors=False):
print(f" Running command in {pwd}:")
print(f" {' '.join(command)}")
os.makedirs(FLAGS.logs_dir, exist_ok=True)
@@ -21,8 +21,9 @@ def run_command(command, pwd):
try:
subprocess.check_call(command, cwd=pwd, stdout=fp, stderr=fp)
except subprocess.CalledProcessError as err:
print(RED + f"Please check {log_file} for what's wrong" + RESET)
exit(-1)
if not ignore_errors:
print(RED + f"Please check {log_file} for what's wrong" + RESET)
exit(-1)
return log_file


@@ -83,6 +84,8 @@ def compile_kernels():


def _clean_cmake(build_dir):
command = ['cmake', '--build', '.', '--target', 'clean']
run_command(command, build_dir, ignore_errors=True)
shutil.rmtree(os.path.join(build_dir, "CMakeFiles"), ignore_errors=True)
shutil.rmtree(os.path.join(build_dir, "CMakeCache.txt"), ignore_errors=True)

@@ -125,12 +128,13 @@ def convert_models():
llamacpp_dir = os.path.join(ROOT_DIR, "3rdparty", "llama.cpp")
command = [
'python',
'convert-hf-to-gguf-t-mac.py',
'convert_hf_to_gguf.py',
f'{model_dir}',
'--outtype',
f'{FLAGS.quant_type}',
'--outtype', f'{FLAGS.quant_type}',
'--outfile', f'{out_path}',
'--kcfg', f'{kcfg_path}',
'--enable-t-mac',
'--verbose',
]
run_command(command, llamacpp_dir)

@@ -140,10 +144,10 @@ def cmake_llamacpp():
cmake_prefix_path = os.path.join(ROOT_DIR, "install", "lib", "cmake", "t-mac")
command = [
'cmake', '..',
'-DLLAMA_TMAC=ON',
'-DGGML_TMAC=ON',
f'-DCMAKE_PREFIX_PATH={cmake_prefix_path}',
'-DCMAKE_BUILD_TYPE=Release',
'-DLLAMA_LLAMAFILE_DEFAULT=OFF',
'-DGGML_OPENMP=OFF',
]
if FLAGS.device == "android":
try:
@@ -154,15 +158,13 @@
command.append("-DANDROID_ABI=arm64-v8a")
command.append("-DANDROID_PLATFORM=android-23")
command.append("-DCMAKE_C_FLAGS=-march=armv8.2a+dotprod+fp16")
command.append("-DLLAMA_METAL=OFF")
command.append("-DLLAMA_ACCELERATE=OFF")
command.append("-DGGML_METAL=OFF")
command.append("-DGGML_ACCELERATE=OFF")
command.append("-DCMAKE_FIND_ROOT_PATH_MODE_PACKAGE=BOTH")
command.append("-DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH")
elif is_win():
if is_arm():
command.append("-DCMAKE_C_COMPILER=clang")
command.append("-DCMAKE_CXX_COMPILER=clang++")
command.append("-G Ninja")
command.append("--preset arm64-windows-llvm-release")
else:
command.append("-T ClangCL")
else:
@@ -176,26 +178,26 @@

def build_llamacpp():
build_dir = get_llamacpp_build_dir()
command = ['cmake', '--build', '.', '--target', 'main', 'llama-bench', '--config', 'Release']
command = ['cmake', '--build', '.', '--target', 'llama-cli', 'llama-bench', '--config', 'Release']
run_command(command, build_dir)


def run_inference():
build_dir = get_llamacpp_build_dir()
out_path = os.path.join(FLAGS.model_dir, f"ggml-model.{FLAGS.quant_type}.gguf")
if is_win():
main_path = os.path.join(build_dir, "bin", "Release", "main.exe")
main_path = os.path.join(build_dir, "bin", "Release", "llama-cli.exe")
if not os.path.exists(main_path):
main_path = os.path.join(build_dir, "bin", "main")
main_path = os.path.join(build_dir, "bin", "llama-cli")
else:
main_path = os.path.join(build_dir, "bin", "main")
main_path = os.path.join(build_dir, "bin", "llama-cli")
prompt = "Microsoft Corporation is an American multinational corporation and technology company headquartered in Redmond, Washington."
if FLAGS.device == "android":
remote_bin_path = os.path.join(FLAGS.remote_dir, "bin")
# TODO: verify in Windows
command = ['push', os.path.join(build_dir, "bin"), FLAGS.remote_dir]
run_adb_command(command, build_dir)
remote_main_path = os.path.join(remote_bin_path, "main")
remote_main_path = os.path.join(remote_bin_path, "llama-cli")
command = ['shell', 'chmod', '-R', '+x', remote_bin_path]
run_adb_command(command, build_dir)
remote_out_path = os.path.join(
@@ -276,7 +278,7 @@ def parse_args():
parser.add_argument("-gs", "--group_size", type=int, default=None, help="Don't set this argument if you don't know its meaning.")
parser.add_argument("-ags", "--act_group_size", type=int, default=None, help="Don't set this argument if you don't know its meaning.")
parser.add_argument("-ld", "--logs_dir", type=str, default="logs")
parser.add_argument("-q", "--quant_type", type=str, choices=["in", "i1", "i2", "i3", "i4"], default="in")
parser.add_argument("-q", "--quant_type", type=str, choices=["int_n", "f16", "f32"], default="int_n")
parser.add_argument("-zp", "--zero_point", action="store_true", help="Enforce enable zero_point. Required by EfficientQAT models.")
parser.add_argument("-nzp", "--no_zero_point", action="store_false", help="Enforce disable zero_point. Don't set this argument if you don't know its meaning.")
