Skip to content

Commit b886d0a

Browse files
committed
Cleaned typing to be in line with accelerate hyperparameters type resctrictions
1 parent 2c9db5d commit b886d0a

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

library/train_util.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3390,7 +3390,20 @@ def filter_sensitive_args(args: argparse.Namespace):
33903390
"output_dir",
33913391
"logging_dir",
33923392
]
3393-
filtered_args = {k: v for k, v in vars(args).items() if k not in sensitive_args + sensitive_path_args}
3393+
filtered_args = {}
3394+
for k, v in vars(args).items():
3395+
# filter out sensitive values
3396+
if k not in sensitive_args + sensitive_path_args:
3397+
#Accelerate values need to have type `bool`,`str`, `float`, `int`, or `None`.
3398+
if v is None or isinstance(v, bool) or isinstance(v, str) or isinstance(v, float) or isinstance(v, int):
3399+
filtered_args[k] = v
3400+
# accelerate does not support lists
3401+
elif isinstance(v, list):
3402+
filtered_args[k] = f"{v}"
3403+
# accelerate does not support objects
3404+
elif isinstance(v, object):
3405+
filtered_args[k] = f"{v}"
3406+
33943407
return filtered_args
33953408

33963409
# verify command line args for training

0 commit comments

Comments
 (0)