
Call Fabric's precision plugin convert_module from the Trainer's connect #17655

@carmocca

Description & Motivation

The Trainer's base precision class inherits from Fabric's base Precision: https://github.com/Lightning-AI/lightning/blob/master/src/lightning/pytorch/plugins/precision/precision_plugin.py#L31

Fabric's Precision defines the convert_module hook, which the Fabric object calls on setup: https://github.com/Lightning-AI/lightning/blob/master/src/lightning/fabric/plugins/precision/precision.py#L38-L43

However, nothing in the Trainer calls it.

Pitch

Should the Trainer plugin's connect() call convert_module? https://github.com/Lightning-AI/lightning/blob/master/src/lightning/pytorch/plugins/precision/precision_plugin.py#L37-L41

For instance, this could be used to support true half precision in the Trainer by inheriting from Fabric's HalfPrecision plugin, which defines convert_module: https://github.com/Lightning-AI/lightning/blob/master/src/lightning/fabric/plugins/precision/half.py#L39-L40
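A minimal sketch of the idea, using toy stand-ins rather than the real Lightning classes so that it runs without any dependencies. The class names (ToyModule, FabricPrecision, FabricHalfPrecision, TrainerPrecisionPlugin, TrainerHalfPrecisionPlugin) are hypothetical, and the bodies only approximate what the linked code does; the point is just to show connect() delegating to convert_module:

```python
# Toy stand-ins only: none of these classes are Lightning's real implementations.


class ToyModule:
    """Stand-in for torch.nn.Module that only tracks a dtype label."""

    def __init__(self, dtype="float32"):
        self.dtype = dtype

    def half(self):
        # Mimics the effect of casting all parameters to half precision.
        self.dtype = "float16"
        return self


class FabricPrecision:
    """Stand-in for Fabric's base Precision: the hook is a no-op by default."""

    def convert_module(self, module):
        return module


class FabricHalfPrecision(FabricPrecision):
    """Stand-in for Fabric's HalfPrecision: the hook casts the module."""

    def convert_module(self, module):
        return module.half()


class TrainerPrecisionPlugin(FabricPrecision):
    """Stand-in for the Trainer's base precision plugin."""

    def connect(self, model, optimizers, lr_schedulers):
        # Proposed change (an assumption of this sketch): call the inherited hook here.
        model = self.convert_module(model)
        return model, optimizers, lr_schedulers


class TrainerHalfPrecisionPlugin(TrainerPrecisionPlugin, FabricHalfPrecision):
    """Gets true half precision purely by inheriting Fabric's convert_module."""


model, _, _ = TrainerHalfPrecisionPlugin().connect(ToyModule(), [], [])
print(model.dtype)
```

With the base plugin the hook stays a no-op, so existing Trainer precision plugins would be unaffected unless they override convert_module.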

Alternatives

Do nothing and leave connect() as it is.

Additional context

This discussion stems from the Graphcore folks wanting to support true half precision.

cc @Borda @carmocca @justusschock @awaelchli


Labels: feature (Is an improvement or enhancement), precision: amp (Automatic Mixed Precision), trainer
