Conversation

@Armannas Armannas commented Jun 4, 2025

Problem

When selecting mps (tested on a MacBook Pro M4) as the accelerator in Fabric with precision="bf16-mixed" or "16-mixed", the device type passed to torch.autocast is incorrectly hardcoded to "cuda". This triggers the following warning and disables automatic mixed precision:

.../torch/amp/autocast_mode.py:266: UserWarning: User provided device_type of 'cuda', but CUDA is not available. Disabling.

Minimal Reproducible Example

import warnings
from lightning.fabric import Fabric

fabric = Fabric(accelerator="mps", precision="bf16-mixed")
fabric.launch()

with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")
    with fabric.autocast():
        pass

    for warning in w:
        assert "device_type of 'cuda'" not in str(warning.message), \
            "Fabric autocast used incorrect device_type='cuda' on MPS"

Proposed Solution

In lightning/fabric/connector.py, the device is currently selected using:

device = "cpu" if self._accelerator_flag == "cpu" else "cuda"

This can be corrected to:

device = self._accelerator_flag if self._accelerator_flag in ("cpu", "mps") else "cuda"

This should ensure that torch.autocast receives the correct device type when enabling Automatic Mixed Precision in Fabric.

Open to feedback on test placement or implementation details.


📚 Documentation preview 📚: https://pytorch-lightning--20876.org.readthedocs.build/en/20876/

@github-actions github-actions bot added the fabric lightning.fabric.Fabric label Jun 4, 2025
Borda commented Jun 5, 2025

Looks great! Can you pls add a test for this selection...
Since we use M1 machines in CI, this shall be feasible 🦩

codecov bot commented Jun 5, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 87%. Comparing base (6675932) to head (26d15ad).
⚠️ Report is 183 commits behind head on master.

Additional details and impacted files
@@           Coverage Diff           @@
##           master   #20876   +/-   ##
=======================================
  Coverage      87%      87%           
=======================================
  Files         268      268           
  Lines       23411    23411           
=======================================
  Hits        20360    20360           
  Misses       3051     3051           

@Borda Borda merged commit 14b6c3e into Lightning-AI:master Jun 5, 2025
140 of 142 checks passed
@Armannas Armannas deleted the bug/issue_mps_amp_support branch June 5, 2025 17:32
Borda pushed a commit that referenced this pull request Jun 19, 2025
…bric accelerator (#20876)

* Make sure MPS is used when chosen as accelerator in Fabric
* Added mps tests to connector and Fabric

---------

Co-authored-by: Haga Device <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
(cherry picked from commit 14b6c3e)