Skip to content

Commit a9e9831

Browse files
committed
[Fix] Fix mistakes in model name in run_pipeline.
1 parent 9029315 commit a9e9831

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

tools/run_pipeline.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -221,8 +221,9 @@ def build_llamacpp():
221221

222222
def run_inference():
223223
build_dir = get_llamacpp_build_dir()
224-
model_name = f"{os.path.split(FLAGS.model_dir)[-1]}.{str(FLAGS.inference_type).upper()}.gguf"
225-
out_path = os.path.join(FLAGS.model_dir, model_name)
224+
model_dir = str(FLAGS.model_dir).rstrip('\\').rstrip('/')
225+
model_name = f"{os.path.split(model_dir)[-1]}.{str(FLAGS.inference_type).upper()}.gguf"
226+
out_path = os.path.join(model_dir, model_name)
226227
if is_win():
227228
main_path = os.path.join(build_dir, "bin", "Release", "llama-cli.exe")
228229
if not os.path.exists(main_path):
@@ -240,7 +241,7 @@ def run_inference():
240241
run_adb_command(command, build_dir)
241242
remote_out_path = os.path.join(
242243
FLAGS.remote_dir,
243-
f"{os.path.basename(FLAGS.model_dir)}-{os.path.basename(out_path)}",
244+
f"{os.path.basename(model_dir)}-{os.path.basename(out_path)}",
244245
)
245246
if not FLAGS.skip_push_model:
246247
command = ['push', out_path, remote_out_path]
@@ -277,8 +278,9 @@ def run_inference():
277278

278279
def run_llama_bench():
279280
build_dir = get_llamacpp_build_dir()
280-
model_name = f"{os.path.split(FLAGS.model_dir)[-1]}.{str(FLAGS.inference_type).upper()}.gguf"
281-
out_path = os.path.join(FLAGS.model_dir, model_name)
281+
model_dir = str(FLAGS.model_dir).rstrip('\\').rstrip('/')
282+
model_name = f"{os.path.split(model_dir)[-1]}.{str(FLAGS.inference_type).upper()}.gguf"
283+
out_path = os.path.join(model_dir, model_name)
282284
if is_win():
283285
main_path = os.path.join(build_dir, "bin", "Release", "llama-bench.exe")
284286
if not os.path.exists(main_path):
@@ -296,7 +298,7 @@ def run_llama_bench():
296298
run_adb_command(command, build_dir)
297299
remote_out_path = os.path.join(
298300
FLAGS.remote_dir,
299-
f"{os.path.basename(FLAGS.model_dir)}-{os.path.basename(out_path)}",
301+
f"{os.path.basename(model_dir)}-{os.path.basename(out_path)}",
300302
)
301303
if not FLAGS.skip_push_model:
302304
command = ['push', out_path, remote_out_path]

0 commit comments

Comments
 (0)