Add SAVED_CHECKPOINT event to Checkpoint handler #3440
@JeevanChevula thanks for the PR. However, let's rework the API of the new feature you are working on:
# checkpoint.py
class CheckpointEvents(EventEnum):
    SAVED_CHECKPOINT = "saved_checkpoint"

class Checkpoint(...):
    SAVED_CHECKPOINT = CheckpointEvents.SAVED_CHECKPOINT
    ...
from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint, global_step_from_engine
trainer = ...
evaluator = ...
# Setup Accuracy metric computation on evaluator.
# evaluator.state.metrics contain 'accuracy',
# which will be used to define ``score_function`` automatically.
# Run evaluation on epoch completed event
# ...
to_save = {'model': model}
handler = Checkpoint(
    to_save, '/tmp/models',
    n_saved=2, filename_prefix='best',
    score_name="accuracy",
    global_step_transform=global_step_from_engine(trainer)
)
evaluator.add_event_handler(Events.COMPLETED, handler)
# ---- New API with Checkpoint.SAVED_CHECKPOINT event: -----
@evaluator.on(Checkpoint.SAVED_CHECKPOINT)
def notify_when_saved(eval_engine, chkpt_handler):  # we should pass to the attached handlers the engine and the checkpoint instance.
    assert eval_engine is evaluator
    assert chkpt_handler is handler
    print("Saved checkpoint:", chkpt_handler.last_checkpoint)
# ---- End of New API with Checkpoint.SAVED_CHECKPOINT event: -----
trainer.run(data_loader, max_epochs=10)
> ["best_model_9_accuracy=0.77.pt", "best_model_10_accuracy=0.78.pt", ]
Let me know what you think.
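To make the proposed flow concrete, here is a minimal pure-Python sketch of a checkpoint handler that fires a custom event after saving and passes itself to the attached handlers. `MiniEngine` and `MiniCheckpoint` are hypothetical stand-ins written for this illustration; they are not Ignite's `Engine` or `Checkpoint`, and the real implementation may differ.

```python
from enum import Enum
from typing import Any, Callable, Dict, List, Optional


class CheckpointEvents(Enum):
    SAVED_CHECKPOINT = "saved_checkpoint"


class MiniEngine:
    """Hypothetical stand-in for an engine; NOT ignite.engine.Engine."""

    def __init__(self) -> None:
        self._handlers: Dict[Enum, List[Callable[..., Any]]] = {}

    def on(self, event: Enum) -> Callable:
        # Decorator that registers a handler for the given event.
        def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
            self._handlers.setdefault(event, []).append(fn)
            return fn

        return decorator

    def fire(self, event: Enum, *event_args: Any) -> None:
        # Each handler receives the engine first, then the fire-time arguments.
        for fn in self._handlers.get(event, []):
            fn(self, *event_args)


class MiniCheckpoint:
    """Hypothetical checkpoint handler that only records a file name."""

    SAVED_CHECKPOINT = CheckpointEvents.SAVED_CHECKPOINT

    def __init__(self) -> None:
        self.last_checkpoint: Optional[str] = None

    def __call__(self, engine: MiniEngine, filename: str) -> None:
        self.last_checkpoint = filename  # pretend the file was written here
        engine.fire(self.SAVED_CHECKPOINT, self)  # notify after the "save"


evaluator = MiniEngine()
handler = MiniCheckpoint()
seen: List[Optional[str]] = []


@evaluator.on(MiniCheckpoint.SAVED_CHECKPOINT)
def notify_when_saved(eval_engine, chkpt_handler):
    assert eval_engine is evaluator
    assert chkpt_handler is handler
    seen.append(chkpt_handler.last_checkpoint)


handler(evaluator, "best_model_10_accuracy=0.78.pt")
print("Saved checkpoint:", seen[-1])
```

The point of the sketch is only the ordering: the event fires after the save has completed, so `last_checkpoint` is already set when the handler runs.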
Thanks for the suggestion. I’ll work on updating the PR to follow the API approach you mentioned with
Implementation note: I implemented the EventEnum-based SAVED_CHECKPOINT event as requested. However, Ignite's event system only supports single-parameter handlers: the originally requested two-parameter signature, handler(engine, checkpoint_handler), failed during event firing and registration. The current implementation uses a single parameter, with the checkpoint accessed via engine._current_checkpoint_handler. All 61 core tests pass, confirming that the feature works without breaking existing functionality. The 3 distributed test errors are pre-existing infrastructure issues unrelated to this change.
Thanks for working on this PR, @JeevanChevula.
I left a few more comments to improve the PR.
@@ -460,11 +470,15 @@ def __call__(self, engine: Engine) -> None:
    if self.include_self:
        # Now that we've updated _saved, we can add our own state_dict.
        checkpoint["checkpointer"] = self.state_dict()

    # Store reference to self in engine for event handlers to access
    engine._current_checkpoint_handler = self
Can you explain this code? I do not understand why we need this.
This is a workaround for a limitation of Ignite's event system. You originally requested a two-parameter handler signature, handler(engine, checkpoint_handler), but Ignite's fire_event() only supports single parameters and rejects handlers expecting additional arguments. This line stores the checkpoint reference in the engine so handlers can access it via engine._current_checkpoint_handler.
I have to check the details of this limitation. Unfortunately, engine._current_checkpoint_handler is not good enough for public API usage.
Alternatively, we can pass the instance of the checkpointer as an arg when attaching:
checkpoint = Checkpoint(...)
@trainer.on(Checkpoint.SAVED_CHECKPOINT, checkpoint)
def handler(engine, chkpt_handler):
    assert engine is trainer
    assert chkpt_handler is checkpoint
Maybe we can skip automatic chkpt_handler arg injection:
checkpoint = Checkpoint(...)
@trainer.on(Checkpoint.SAVED_CHECKPOINT)
def handler(engine):
    assert engine is trainer
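The two alternatives above can be sketched with a toy dispatcher that stores extra arguments at attach time and replays them when the event fires. This is a minimal illustration under stated assumptions: `MiniEngine` is a hypothetical stand-in, not Ignite's `Engine`, and a plain string takes the place of the `SAVED_CHECKPOINT` event.

```python
from typing import Any, Callable, Dict, List, Tuple

SAVED_CHECKPOINT = "saved_checkpoint"  # plain string standing in for the event


class MiniEngine:
    """Hypothetical stand-in for an engine; NOT ignite.engine.Engine."""

    def __init__(self) -> None:
        self._handlers: Dict[str, List[Tuple[Callable[..., Any], tuple]]] = {}

    def on(self, event: str, *args: Any) -> Callable:
        # Extra positional arguments are stored next to the handler at attach time.
        def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
            self._handlers.setdefault(event, []).append((fn, args))
            return fn

        return decorator

    def fire(self, event: str) -> None:
        # The event itself carries no arguments; the bound ones are replayed here.
        for fn, args in self._handlers.get(event, []):
            fn(self, *args)


trainer = MiniEngine()
checkpoint = object()  # placeholder for a checkpoint handler instance
calls: List[tuple] = []


@trainer.on(SAVED_CHECKPOINT, checkpoint)
def with_checkpoint(engine, chkpt_handler):
    # Variant 1: the checkpoint instance is passed explicitly when attaching.
    calls.append(("with", engine is trainer, chkpt_handler is checkpoint))


@trainer.on(SAVED_CHECKPOINT)
def engine_only(engine):
    # Variant 2: no injection; the handler only receives the engine.
    calls.append(("only", engine is trainer))


trainer.fire(SAVED_CHECKPOINT)
```

With attach-time binding, the code that fires the event does not need to know which handlers expect the checkpoint instance, so both signatures can coexist on the same event.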
Pushing the current implementation with working SAVED_CHECKPOINT event functionality. I will add proper Google-style docstrings with version directives by Monday, per the contributing guidelines.
@JeevanChevula please rebase your PR branch; it now has some extra commits.
…istration instead of class constant
- Register event dynamically when first used
- Event fires correctly after successful checkpoint save
- Tested locally and working properly
…ixed code formatting with black
d500fc8 to d81faa9
Fixes #934
This PR adds a "saved_checkpoint" event that fires after successful checkpoint saves.
Usage: