@@ -77,6 +77,10 @@ class generation_outputs(ctypes.Structure):
77
77
_fields_ = [("status" , ctypes .c_int ),
78
78
("text" , ctypes .c_char * 32768 )]
79
79
80
+ class token_count_outputs (ctypes .Structure ):
81
+ _fields_ = [("count" , ctypes .c_int ),
82
+ ("ids" , ctypes .POINTER (ctypes .c_int ))]
83
+
80
84
handle = None
81
85
82
86
def getdirpath ():
@@ -218,7 +222,7 @@ def init_library():
218
222
handle .get_total_gens .restype = ctypes .c_int
219
223
handle .get_last_stop_reason .restype = ctypes .c_int
220
224
handle .abort_generate .restype = ctypes .c_bool
221
- handle .token_count .restype = ctypes . c_int
225
+ handle .token_count .restype = token_count_outputs
222
226
handle .get_pending_output .restype = ctypes .c_char_p
223
227
224
228
def load_model (model_filename ):
@@ -729,8 +733,11 @@ def do_POST(self):
729
733
try :
730
734
genparams = json .loads (body )
731
735
countprompt = genparams .get ('prompt' , "" )
732
- count = handle .token_count (countprompt .encode ("UTF-8" ))
733
- response_body = (json .dumps ({"value" : count }).encode ())
736
+ rawcountdata = handle .token_count (countprompt .encode ("UTF-8" ))
737
+ countlimit = rawcountdata .count if (rawcountdata .count >= 0 and rawcountdata .count < 50000 ) else 0
738
+ # the above protects the server in case the count limit got corrupted
739
+ countdata = [rawcountdata .ids [i ] for i in range (countlimit )]
740
+ response_body = (json .dumps ({"value" : len (countdata ),"ids" : countdata }).encode ())
734
741
735
742
except Exception as e :
736
743
utfprint ("Count Tokens - Body Error: " + str (e ))
0 commit comments