Skip to content

Commit ac8ef69

Browse files
committed
fix
1 parent 9c5dfe7 commit ac8ef69

File tree

1 file changed

+15
-18
lines changed

1 file changed

+15
-18
lines changed

src/guidellm/utils/preprocessing_sharegpt_data.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import os
44
import re
55
from pathlib import Path
6-
from typing import Callable, Optional
6+
from typing import Optional
77

88
import numpy as np
99
from datasets import load_dataset
@@ -13,34 +13,31 @@
1313
MAX_CHAR = 1000
1414

1515

class TokenCounter:
    """Count tokens in text using a Hugging Face tokenizer, loaded lazily.

    Construction is cheap: the tokenizer is only downloaded/loaded on the
    first call to ``estimate_num_tokens``.
    """

    def __init__(self, model_name: str = "mistralai/Mistral-7B-Instruct-v0.2"):
        # Model whose tokenizer defines the token counts.
        self.model_name = model_name
        # Loaded lazily by _initialize_tokenizer(); quoted forward reference
        # so the annotation never needs AutoTokenizer at evaluation time.
        self._tokenizer: Optional["AutoTokenizer"] = None

    def _initialize_tokenizer(self) -> None:
        """Load the tokenizer once; no-op on subsequent calls.

        Raises:
            RuntimeError: If the tokenizer cannot be initialized.
        """
        if self._tokenizer is None:
            # Silence fork-related warnings from the tokenizers library.
            os.environ["TOKENIZERS_PARALLELISM"] = "false"
            try:
                self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            except (OSError, ImportError, ValueError) as e:
                raise RuntimeError(f"Failed to initialize tokenizer: {e}") from e

    def estimate_num_tokens(self, text: str) -> int:
        """Return the number of tokens in ``text`` per the model's tokenizer.

        Returns 0 if the tokenizer is unavailable after initialization.

        Raises:
            ValueError: If tokenizing ``text`` fails.
        """
        self._initialize_tokenizer()

        if self._tokenizer is None:
            return 0

        try:
            encoding = self._tokenizer(text, return_tensors=None)
            return len(encoding["input_ids"])
        except (AttributeError, TypeError, RuntimeError) as e:
            raise ValueError(f"Error processing text: {e}") from e
4441

4542
def extract_and_save_with_filtering(file):
4643
"""substract human prompts and apply filtering conditions"""
@@ -93,7 +90,7 @@ def extract_and_save_with_filtering(file):
9390
with Path(sharegpt_file).open("r", encoding="utf-8") as file:
9491
data = json.load(file)
9592

96-
estimate_tokens = create_token_estimator()
93+
counter = TokenCounter()
9794
num_of_ids = len(data)
9895
data = data[: int(num_of_ids * args.parse)]
9996
for d in data:
@@ -102,9 +99,9 @@ def extract_and_save_with_filtering(file):
10299
gpt_tokens = []
103100
for conv in d["conversations"]:
104101
if conv["from"] == "human":
105-
human_tokens.append(estimate_tokens(conv["value"]))
102+
human_tokens.append(counter.estimate_num_tokens(conv["value"]))
106103
if conv["from"] == "gpt":
107-
token_number = estimate_tokens(conv["value"])
104+
token_number = counter.estimate_num_tokens(conv["value"])
108105
conv["num_tokens"] = token_number
109106
gpt_tokens.append(token_number)
110107
if len(human_tokens) == 0:

0 commit comments

Comments
 (0)