
Inconsistent API for on_predict_epoch_end #8479

@ananthsub

🚀 Feature

Background
We are auditing the Lightning components and APIs to assess opportunities for improvement.

One item that came up was on_predict_epoch_end(), defined on the ModelHooks mixin. It accepts an outputs: List[Any] argument, which is inconsistent with the other model hooks of the same type: on_train_epoch_end, on_validation_epoch_end, and on_test_epoch_end. In those stages, the collected step outputs go to the separate training_epoch_end, validation_epoch_end, and test_epoch_end hooks instead.
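To make the mismatch concrete, the snippet below uses a hypothetical stand-in class, HooksSketch, that reproduces only the hook parameter lists as this issue describes them. It is not Lightning's ModelHooks source and ignores any deprecated parameters the real hooks may still carry; inspect.signature simply lists each hook's parameters besides self.

```python
import inspect
from typing import Any, List


class HooksSketch:
    """Hypothetical stand-in for the ModelHooks mixin (parameter lists only)."""

    def on_train_epoch_end(self) -> None: ...
    def on_validation_epoch_end(self) -> None: ...
    def on_test_epoch_end(self) -> None: ...
    def on_predict_epoch_end(self, outputs: List[Any]) -> None: ...


def extra_params(hook_name: str) -> List[str]:
    """Return the parameter names a hook takes besides `self`."""
    sig = inspect.signature(getattr(HooksSketch, hook_name))
    return [name for name in sig.parameters if name != "self"]


for stage in ("train", "validation", "test", "predict"):
    hook = f"on_{stage}_epoch_end"
    print(hook, extra_params(hook))
# on_train_epoch_end []
# on_validation_epoch_end []
# on_test_epoch_end []
# on_predict_epoch_end ['outputs']
```

Only the predict variant takes an extra argument, which is the asymmetry the pitch aims to remove.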

Motivation

API consistency with other epoch end model hooks.

Pitch

  • Add a prediction_epoch_end hook to the LightningModule
  • Deprecate the outputs argument of on_predict_epoch_end in v1.5 and remove it entirely in v1.7
  • Update the prediction loop to cache predictions only if prediction_epoch_end is implemented
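A minimal sketch of the proposed loop behavior, under several assumptions: SketchModule is a hypothetical, stripped-down stand-in for a LightningModule; overrides() is a hypothetical helper that detects an implemented hook by comparing class attributes; and prediction_epoch_end is assumed to run before on_predict_epoch_end. None of these names come from Lightning itself, and the real prediction loop is considerably more involved.

```python
from typing import Any, Iterable, List, Optional


class SketchModule:
    """Hypothetical, stripped-down stand-in for a LightningModule."""

    def predict_step(self, batch: Any, batch_idx: int) -> Any:
        return batch

    def prediction_epoch_end(self, outputs: List[Any]) -> None:
        """Proposed hook: receives every prediction produced this epoch."""

    def on_predict_epoch_end(self) -> None:
        """Proposed signature: no `outputs` argument."""


def overrides(module: SketchModule, hook_name: str) -> bool:
    """Hypothetical helper: True if the module's class redefines `hook_name`."""
    return getattr(type(module), hook_name) is not getattr(SketchModule, hook_name)


def run_predict_epoch(module: SketchModule, batches: Iterable[Any]) -> Optional[List[Any]]:
    """Run one prediction epoch, caching outputs only when they will be consumed."""
    cache = overrides(module, "prediction_epoch_end")
    outputs: List[Any] = []
    for batch_idx, batch in enumerate(batches):
        prediction = module.predict_step(batch, batch_idx)
        if cache:
            outputs.append(prediction)
    if cache:
        module.prediction_epoch_end(outputs)
    module.on_predict_epoch_end()  # called either way, without outputs
    return outputs if cache else None


# Usage: Doubler never asks for the outputs, so nothing is cached for it.
class Doubler(SketchModule):
    def predict_step(self, batch: Any, batch_idx: int) -> Any:
        return batch * 2


# Collector implements prediction_epoch_end, so the loop caches for it.
class Collector(Doubler):
    def prediction_epoch_end(self, outputs: List[Any]) -> None:
        self.total = sum(outputs)


print(run_predict_epoch(Doubler(), [1, 2, 3]))  # None
collector = Collector()
print(run_predict_epoch(collector, [1, 2, 3]))  # [2, 4, 6]
print(collector.total)  # 12
```

The override check is what makes the caching opt-in: a module that only implements predict_step presumably never accumulates a full epoch of predictions in memory.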

Alternatively, users can avoid this hook entirely and cache their prediction outputs as needed by implementing that logic directly in predict_step.
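A minimal sketch of that workaround, assuming a hypothetical plain-Python class CachingModule and a hand-written drive() function standing in for the hook calls the Trainer would make. A real implementation would subclass LightningModule and let Trainer.predict invoke the hooks. The module resets its cache at epoch start, appends in predict_step, and reads the cache from an on_predict_epoch_end that takes no outputs argument, matching the proposed signature.

```python
from typing import List


class CachingModule:
    """Hypothetical module that caches its own predictions in predict_step."""

    def __init__(self) -> None:
        self.predictions: List[List[int]] = []
        self.flat: List[int] = []

    def on_predict_epoch_start(self) -> None:
        # Reset the cache so predictions from a previous run do not leak in.
        self.predictions = []

    def predict_step(self, batch: List[int], batch_idx: int) -> List[int]:
        prediction = [x * 2 for x in batch]
        self.predictions.append(prediction)  # cache manually instead of relying on the loop
        return prediction

    def on_predict_epoch_end(self) -> None:
        # No `outputs` argument: read the module's own cache instead.
        self.flat = [p for batch in self.predictions for p in batch]


def drive(module: CachingModule, batches: List[List[int]]) -> None:
    """Hypothetical stand-in for the hook calls a prediction run would make."""
    module.on_predict_epoch_start()
    for batch_idx, batch in enumerate(batches):
        module.predict_step(batch, batch_idx)
    module.on_predict_epoch_end()


module = CachingModule()
drive(module, [[1, 2], [3]])
print(module.flat)  # [2, 4, 6]
```

Because the user owns the cache, they also choose what to keep (for example, only a running summary rather than every prediction).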

Alternatives

Keep as is?

Additional context

cc @Borda @tchaton @justusschock @awaelchli @carmocca @ninginthecloud @daniellepintz @rohitgr7

Metadata

Labels: design (includes a design discussion), feature (is an improvement or enhancement), hooks (related to the hooks API), trainer: predict