import os
import re
from pathlib import Path
- from typing import Callable, Optional
+ from typing import Optional

import numpy as np
from datasets import load_dataset
@@ -13,34 +13,31 @@
MAX_CHAR = 1000


- def create_token_estimator(
-     model_name: str = "mistralai/Mistral-7B-Instruct-v0.2",
- ) -> Callable[[str], int]:
-     _tokenizer: Optional[AutoTokenizer] = None
+ class TokenCounter:
+     def __init__(self, model_name: str = "mistralai/Mistral-7B-Instruct-v0.2"):
+         self.model_name = model_name
+         self._tokenizer: Optional[AutoTokenizer] = None

-     def initialize() -> None:
-         nonlocal _tokenizer
-         if _tokenizer is None:
+     def _initialize_tokenizer(self) -> None:
+         if self._tokenizer is None:
            os.environ["TOKENIZERS_PARALLELISM"] = "false"
            try:
-                 _tokenizer = AutoTokenizer.from_pretrained(model_name)
+                 self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            except (OSError, ImportError, ValueError) as e:
                raise RuntimeError(f"Failed to initialize tokenizer: {e}") from e

-     def estimate_num_tokens(text: str) -> int:
-         initialize()
+     def estimate_num_tokens(self, text: str) -> int:
+         self._initialize_tokenizer()

-         if _tokenizer is None:
+         if self._tokenizer is None:
            return 0

        try:
-             encoding = _tokenizer(text, return_tensors=None)
+             encoding = self._tokenizer(text, return_tensors=None)
            return len(encoding["input_ids"])
        except (AttributeError, TypeError, RuntimeError) as e:
            raise ValueError(f"Error processing text: {e}") from e

-     return estimate_num_tokens
-

def extract_and_save_with_filtering(file):
    """substract human prompts and apply filtering conditions"""
@@ -93,7 +90,7 @@ def extract_and_save_with_filtering(file):
    with Path(sharegpt_file).open("r", encoding="utf-8") as file:
        data = json.load(file)

-     estimate_tokens = create_token_estimator()
+     counter = TokenCounter()
    num_of_ids = len(data)
    data = data[: int(num_of_ids * args.parse)]
    for d in data:
@@ -102,9 +99,9 @@ def extract_and_save_with_filtering(file):
        gpt_tokens = []
        for conv in d["conversations"]:
            if conv["from"] == "human":
-                 human_tokens.append(estimate_tokens(conv["value"]))
+                 human_tokens.append(counter.estimate_num_tokens(conv["value"]))
            if conv["from"] == "gpt":
-                 token_number = estimate_tokens(conv["value"])
+                 token_number = counter.estimate_num_tokens(conv["value"])
                conv["num_tokens"] = token_number
                gpt_tokens.append(token_number)
        if len(human_tokens) == 0:
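To make the filtering pass above easier to follow in isolation, here is a standalone sketch of the same per-conversation loop on a toy ShareGPT-style record; a whitespace split stands in for TokenCounter so the snippet runs without downloading a tokenizer (that substitution is an assumption for illustration only):

    record = {
        "id": "example-0",
        "conversations": [
            {"from": "human", "value": "What is the capital of France?"},
            {"from": "gpt", "value": "The capital of France is Paris."},
        ],
    }

    def fake_estimate_num_tokens(text: str) -> int:
        # Crude stand-in for TokenCounter.estimate_num_tokens.
        return len(text.split())

    human_tokens, gpt_tokens = [], []
    for conv in record["conversations"]:
        if conv["from"] == "human":
            human_tokens.append(fake_estimate_num_tokens(conv["value"]))
        if conv["from"] == "gpt":
            token_number = fake_estimate_num_tokens(conv["value"])
            conv["num_tokens"] = token_number  # annotate the gpt turn in place
            gpt_tokens.append(token_number)

    # Records with no human turns are dropped by the filtering step.
    if len(human_tokens) == 0:
        print("skipping record", record["id"])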