Bug description
(I'll point to Fabric for the motivation, but this same issue exists in the Trainer)
The connector logic for choosing bf16 precision with XLA is flawed:
https://github.com/Lightning-AI/lightning/blob/master/src/lightning/fabric/connector.py#L435-L441
where XLABf16Precision basically just sets XLA_USE_BF16=1:
https://github.com/Lightning-AI/lightning/blob/master/src/lightning/fabric/plugins/precision/xlabf16.py#L25
In https://pytorch.org/xla/release/2.0/, it's explained that
XLA_USE_BF16: If set to 1, transforms all the PyTorch Float values into BiFloat16 when sending to the TPU device. Note that when using XLA_USE_BF16=1 tensor arithmetic will be done in reduced precision and so tensors will not be accurate if accumulated over time.
and
If XLA_USE_BF16 is set, then torch.float and torch.double are both bfloat16 on TPUs.
This corresponds to what you would expect with precision="bf16-true".
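To illustrate what that means in practice, here is a rough sketch; it assumes a TPU host with torch_xla installed, and the printed dtype reflects how XLA tensors report their PyTorch dtype rather than the on-device storage:

```python
import os

# Same switch the plugin flips; set here before any tensors are moved to the device.
os.environ["XLA_USE_BF16"] = "1"

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x = torch.randn(1024, 1024).to(device)  # created as float32 on the host
y = x @ x
# The data is stored and the matmul is executed in bfloat16 on the TPU,
# i.e. everything behaves like full-bf16 ("-true"), not mixed precision.
print(y.dtype)  # still reports torch.float32 on the PyTorch side
```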
However, the connector logic only selects this plugin with precision="bf16-mixed", even though mixed precision training is not currently supported with XLA (this is tracked in #17927).
Fixing this will be a breaking change.
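For concreteness, the corrected mapping could look roughly like this (the function name and error message are illustrative, not the actual patch; the import mirrors the linked file path):

```python
from lightning.fabric.plugins.precision.xlabf16 import XLABf16Precision


def _choose_xla_bf16_plugin(precision: str) -> XLABf16Precision:
    if precision == "bf16-true":
        # Full-bfloat16 training is what XLA_USE_BF16=1 actually provides,
        # so this should be the value that selects the plugin.
        return XLABf16Precision()
    if precision == "bf16-mixed":
        # Mixed precision is not currently supported with XLA (#17927),
        # so selecting the plugin here is misleading.
        raise ValueError("`bf16-mixed` is not supported with XLA, use `bf16-true` instead")
    raise ValueError(f"Unsupported precision with XLA: {precision!r}")
```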
What version are you seeing the problem on?
v2.0
How to reproduce the bug
Fabric(precision="bf16-true") will not set XLA_USE_BF16=1.
(The Trainer currently does not implement precision="bf16-true".)
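A minimal check on a TPU machine (the accelerator argument and the printed values are my expectation, not captured output):

```python
import os

from lightning.fabric import Fabric

Fabric(accelerator="tpu", precision="bf16-true")
print(os.environ.get("XLA_USE_BF16"))  # expected "1", but currently prints None

Fabric(accelerator="tpu", precision="bf16-mixed")
print(os.environ.get("XLA_USE_BF16"))  # this is the combination that sets it today
```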
Error messages and logs
None
Environment
Current master
More info
We should also support the same with XLA_USE_F16 for 16-true precision.
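A companion plugin could follow the same pattern; the class name below is hypothetical and the behavior mirrors the bf16 plugin:

```python
import os


class XLAHalfPrecisionSketch:
    """Hypothetical plugin for precision="16-true" on XLA devices."""

    def __init__(self) -> None:
        # torch_xla casts float tensors to float16 on the device when this is set.
        os.environ["XLA_USE_F16"] = "1"

    def teardown(self) -> None:
        os.environ.pop("XLA_USE_F16", None)
```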
cc @carmocca @justusschock @awaelchli @JackCaoG @steventk-g @Liyang90