diff --git a/axlearn/common/trainer.py b/axlearn/common/trainer.py index a2f200dfe..45e10ceec 100644 --- a/axlearn/common/trainer.py +++ b/axlearn/common/trainer.py @@ -352,6 +352,7 @@ def __init__( model=self.model, model_param_partition_specs=model_param_partition_specs, ) + register_sigterm_handler() self._maybe_record_event(measurement.Event.END_ACCELERATOR_INIT) @property @@ -1450,3 +1451,13 @@ def m_or_g(x, suffix=""): logging.warning("Attempt to parse cost_stats=%s but failed.", cost_stats) return analysis_results + +def register_sigterm_handler(): + original_sigterm_handler = signal.getsignal(signal.SIGTERM) + def sigterm_handler(signum, frame): + original_sigterm_handler(signum, frame) + + # system is being shutdown + if os.path.exists("/var/run/nologin") or os.path.exists("/run/nologin"): + raise SystemExit("Exiting without waiting checkpoint saving after system shutdown is detected.") + signal.signal(signal.SIGTERM, sigterm_handler)