Cast fp tensor to half when using 16-true precision #17727

@joao-alex-cunha

Description & Motivation

The goal of this feature is to support transferring tensors in half precision to an accelerator when using 16-true precision.
Currently, some strategies, such as DeepSpeed and IPUStrategy, support this behavior, but only for 16-mixed or bf16-mixed.
This is done in the batch_to_device method, which in turn calls _fp_to_half.
The definition of done for this PR is to extend the _fp_to_half method to cover 16-true as well.
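
For readers unfamiliar with the pattern, here is a minimal sketch of "cast on the host, then transfer". It is not Lightning's actual implementation: apply_to_tensors, batch_to_device_sketch and to_half_if_float are hypothetical names introduced only for illustration, and the batch layout is made up.

```python
from typing import Any, Callable

import torch
from torch import Tensor


def apply_to_tensors(batch: Any, fn: Callable[[Tensor], Tensor]) -> Any:
    """Hypothetical helper: apply fn to every tensor in a nested dict/list/tuple."""
    if isinstance(batch, Tensor):
        return fn(batch)
    if isinstance(batch, dict):
        return {key: apply_to_tensors(value, fn) for key, value in batch.items()}
    if isinstance(batch, (list, tuple)):
        return type(batch)(apply_to_tensors(value, fn) for value in batch)
    return batch  # non-tensor leaves (strings, ints, ...) pass through untouched


def batch_to_device_sketch(batch: Any, device: torch.device, cast: Callable[[Tensor], Tensor]) -> Any:
    # Step 1: cast while the data is still on the host.
    batch = apply_to_tensors(batch, cast)
    # Step 2: transfer the (now smaller) tensors to the target device.
    return apply_to_tensors(batch, lambda t: t.to(device))


def to_half_if_float(tensor: Tensor) -> Tensor:
    # Only floating-point tensors are cast; integer labels keep their dtype.
    return tensor.half() if torch.is_floating_point(tensor) else tensor


batch = {"x": torch.rand(2, 3), "y": torch.tensor([1, 0]), "meta": ["a", "b"]}
moved = batch_to_device_sketch(batch, torch.device("cpu"), to_half_if_float)
```

The device is "cpu" here only so the sketch runs anywhere; the ordering of the two steps is the point.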

Pitch

By changing the behavior of _fp_to_half when using 16-true precision, tensors are cast to half before being transferred to the accelerator(s).

It should be a matter of changing the method to:

def _fp_to_half(
    tensor: Tensor,
    precision: Literal[
        "64-true",
        "32-true",
        "16-mixed",
        "bf16-mixed",
        "16-true",
    ],
) -> Tensor:
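    # New: "16-true" is cast to half alongside "16-mixed".
    if str(precision) in ("16-mixed", "16-true"):
        return _convert_fp_tensor(tensor, torch.half)
    if str(precision) == "bf16-mixed":
        return _convert_fp_tensor(tensor, torch.bfloat16)
    # Every other precision leaves the tensor unchanged.
    return tensor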

As far as I understand, the _fp_to_half method is only used inside batch_to_device, so I don't expect any side effects from this change. Hopefully that is really the case.
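
To make the intended behavior easy to check in isolation, here is a self-contained sketch of the proposed function. It assumes that the private _convert_fp_tensor helper casts floating-point tensors only and returns other tensors untouched; _convert_fp_tensor_sketch is a hypothetical stand-in written under that assumption, and fp_to_half_sketch is likewise an illustrative name rather than Lightning's own function.

```python
import torch
from torch import Tensor


def _convert_fp_tensor_sketch(tensor: Tensor, dst_type: torch.dtype) -> Tensor:
    # Assumed behavior of the real helper: cast floating-point tensors only.
    return tensor.to(dst_type) if torch.is_floating_point(tensor) else tensor


def fp_to_half_sketch(tensor: Tensor, precision: str) -> Tensor:
    # "16-true" is handled together with "16-mixed", as proposed above.
    if str(precision) in ("16-mixed", "16-true"):
        return _convert_fp_tensor_sketch(tensor, torch.half)
    if str(precision) == "bf16-mixed":
        return _convert_fp_tensor_sketch(tensor, torch.bfloat16)
    return tensor


images = torch.rand(4, 3)              # float32 by default
labels = torch.tensor([0, 1, 2, 3])    # int64 by default

cast_images = fp_to_half_sketch(images, "16-true")   # becomes float16
cast_labels = fp_to_half_sketch(labels, "16-true")   # stays int64
bf16_images = fp_to_half_sketch(images, "bf16-mixed")
unchanged = fp_to_half_sketch(images, "32-true")     # returned as-is
```

Under this assumption, integer tensors such as class labels are not affected by the change, which is part of why no side effects are expected.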

Alternatives

I believe the alternative is to cast tensors to half precision inside the datamodule. This requires making the datamodule aware of the training precision, which might not be trivial to do.
Another alternative is to cast tensors after they are on the accelerator. This involves changing the model to prepend a casting operation. This seems easier, but it isn't equivalent to the proposal above: transferring data in half precision halves the payload of float32 tensors, which should speed up the host-to-device transfer in the training loop.
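
The difference between the proposal and the second alternative can be illustrated with plain PyTorch. This is only a sketch: the batch shape is arbitrary, payload_bytes is a hypothetical helper, and no timing is measured, only the number of bytes that would cross the host-to-device boundary.

```python
import torch
from torch import Tensor


def payload_bytes(tensor: Tensor) -> int:
    # Bytes occupied by the tensor's elements.
    return tensor.element_size() * tensor.nelement()


batch = torch.rand(8, 3, 32, 32)  # float32: 4 bytes per element

fp32_bytes = payload_bytes(batch)         # 98304 bytes
fp16_bytes = payload_bytes(batch.half())  # 49152 bytes

# Falls back to the CPU so the sketch also runs without an accelerator.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Proposal: cast on the host, then transfer the half-precision payload.
cast_then_move = batch.half().to(device)

# Alternative: transfer the float32 payload, then cast on the accelerator.
move_then_cast = batch.to(device).half()
```

Both orderings end with a float16 tensor on the device; they differ in how much data is transferred to get there.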

Additional context

No response

cc @Borda

Labels: feature (Is an improvement or enhancement)
