Description & Motivation
The Trainer's base precision class inherits from Fabric's base Precision: https://github.com/Lightning-AI/lightning/blob/master/src/lightning/pytorch/plugins/precision/precision_plugin.py#L31
Fabric's Precision defines a convert_module hook (https://github.com/Lightning-AI/lightning/blob/master/src/lightning/fabric/plugins/precision/precision.py#L38-L43), which the Fabric object calls during setup.
However, the Trainer never calls it.
Pitch
Should connect() call it? https://github.com/Lightning-AI/lightning/blob/master/src/lightning/pytorch/plugins/precision/precision_plugin.py#L37-L41
For instance, this would let the Trainer support true half precision through a plugin that inherits from Fabric's HalfPrecision plugin, since that plugin defines convert_module: https://github.com/Lightning-AI/lightning/blob/master/src/lightning/fabric/plugins/precision/half.py#L39-L40
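A minimal, framework-free sketch of the pitch follows. It assumes the Trainer-side hook has the shape connect(model, optimizers, lr_schedulers) and returns the same three things. Every class here (ToyModule, Precision, HalfPrecision, PrecisionPlugin, HalfPrecisionPlugin) and the fabric_setup helper is a hypothetical stand-in that only mirrors the relationships described above, not Lightning's actual code. ToyModule replaces a real torch.nn.Module so the example runs without PyTorch.

```python
# Hypothetical sketch: all classes are illustrative stand-ins, NOT Lightning's code.
from typing import Any, List, Tuple


class ToyModule:
    """Stand-in for torch.nn.Module that only tracks a dtype name."""

    def __init__(self) -> None:
        self.dtype = "float32"

    def to(self, dtype: str) -> "ToyModule":
        self.dtype = dtype
        return self


class Precision:
    """Stand-in for Fabric's base Precision."""

    def convert_module(self, module: ToyModule) -> ToyModule:
        # Default hook: leave the module untouched.
        return module


class HalfPrecision(Precision):
    """Stand-in for Fabric's HalfPrecision: casts the whole module."""

    def __init__(self, dtype: str = "float16") -> None:
        self.dtype = dtype

    def convert_module(self, module: ToyModule) -> ToyModule:
        return module.to(self.dtype)


def fabric_setup(precision: Precision, module: ToyModule) -> ToyModule:
    # Mirrors the behavior described in the issue: Fabric calls the hook on setup.
    return precision.convert_module(module)


class PrecisionPlugin(Precision):
    """Stand-in for the Trainer's base precision class."""

    def connect(
        self, model: ToyModule, optimizers: List[Any], lr_schedulers: List[Any]
    ) -> Tuple[ToyModule, List[Any], List[Any]]:
        # Proposed change (assumption): call the inherited hook here,
        # which the Trainer currently never does.
        model = self.convert_module(model)
        return model, optimizers, lr_schedulers


class HalfPrecisionPlugin(PrecisionPlugin, HalfPrecision):
    """Hypothetical Trainer plugin reusing the half-precision conversion.

    MRO: connect() comes from PrecisionPlugin, while convert_module()
    and __init__() come from HalfPrecision.
    """


if __name__ == "__main__":
    model, _, _ = HalfPrecisionPlugin().connect(ToyModule(), [], [])
    print(model.dtype)  # float16
```

With the base plugin, the default convert_module returns the module unchanged, so existing Trainer behavior is preserved. Only subclasses that override the hook change anything.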
Alternatives
Don't do it.
Additional context
This discussion stems from the Graphcore folks wanting to support true half precision.