
XLA's 16-bit precision plugin selection is incorrect #18172


Bug description

(I'll point to Fabric for the motivation, but this same issue exists in the Trainer)

The connector logic to choose bf16 precision with XLA is flawed:

https://github.com/Lightning-AI/lightning/blob/master/src/lightning/fabric/connector.py#L435-L441

Here, XLABf16Precision essentially just sets XLA_USE_BF16=1: https://github.com/Lightning-AI/lightning/blob/master/src/lightning/fabric/plugins/precision/xlabf16.py#L25
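
For reference, a minimal sketch of what that plugin amounts to (illustrative only, not the actual source; see the link above):

```python
import os

# Minimal sketch of the linked plugin: per the link above, it boils down to
# exporting the env var so that torch_xla downcasts floats on the TPU.
class XLABf16PrecisionSketch:
    precision = "bf16"

    def __init__(self) -> None:
        # torch_xla reads XLA_USE_BF16 at device initialization and maps
        # torch.float/torch.double tensors to bfloat16 on the TPU.
        os.environ["XLA_USE_BF16"] = "1"
```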

The PyTorch/XLA documentation (https://pytorch.org/xla/release/2.0/) explains that:

XLA_USE_BF16: If set to 1, transforms all the PyTorch Float values into BiFloat16 when sending to the TPU device. Note that when using XLA_USE_BF16=1 tensor arithmetic will be done in reduced precision and so tensors will not be accurate if accumulated over time.

and

If XLA_USE_BF16 is set, then torch.float and torch.double are both bfloat16 on TPUs.

This corresponds to what you would expect from precision="bf16-true".

However, the connector logic only selects this plugin for precision="bf16-mixed", even though mixed precision training is not currently supported with XLA (tracked in #17927).
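
A hedged sketch of the mapping the connector should implement instead (illustrative names only, not the connector's real helpers):

```python
# Illustrative sketch only; the real selection lives in
# lightning/fabric/connector.py (linked above).
def select_xla_precision(precision: str) -> str:
    if precision == "bf16-true":
        # Full-bf16 semantics match what XLA_USE_BF16=1 actually does,
        # so this is where XLABf16Precision should be selected.
        return "XLABf16Precision"
    if precision == "bf16-mixed":
        # Mixed precision on XLA is not implemented yet.
        raise NotImplementedError("bf16-mixed is not supported with XLA, see #17927")
    # Everything else falls back to the plain 32-bit plugin.
    return "Precision"
```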

Fixing this will be a breaking change.

What version are you seeing the problem on?

v2.0

How to reproduce the bug

Fabric(precision="bf16-true") will not set XLA_USE_BF16=1.

(the Trainer currently does not implement precision="bf16-true")
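
A sketch of the reproduction (assumes a TPU/XLA environment; the key point is that the env-var-based plugin is never chosen):

```python
import os
from lightning.fabric import Fabric

# Assumes a TPU/XLA environment. Per the report, "bf16-true" does not
# resolve to the XLA bf16 plugin, so the env var is never exported.
fabric = Fabric(accelerator="tpu", precision="bf16-true")
print(os.environ.get("XLA_USE_BF16"))  # -> None on current master, expected "1"
```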

Error messages and logs

None

Environment

Current master

More info

We should also support the same behavior via XLA_USE_F16 for 16-true precision.
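
A hypothetical sketch of that companion plugin (no such class exists today; XLA_USE_F16 mirrors XLA_USE_BF16 but targets float16):

```python
import os

# Hypothetical companion to the bf16 plugin, for precision="16-true".
# XLA_USE_F16=1 tells torch_xla to downcast floats to float16 on the TPU.
class XLAFp16PrecisionSketch:
    precision = "16"

    def __init__(self) -> None:
        os.environ["XLA_USE_F16"] = "1"
```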

cc @carmocca @justusschock @awaelchli @JackCaoG @steventk-g @Liyang90
