Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/lightning/fabric/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ def _check_and_init_precision(self) -> Precision:
if self._precision_input == "16-mixed"
else "Using bfloat16 Automatic Mixed Precision (AMP)"
)
device = "cpu" if self._accelerator_flag == "cpu" else "cuda"
device = self._accelerator_flag if self._accelerator_flag in ("cpu", "mps") else "cuda"
return MixedPrecision(precision=self._precision_input, device=device) # type: ignore[arg-type]

raise RuntimeError("No precision set")
Expand Down
7 changes: 7 additions & 0 deletions tests/tests_fabric/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,13 @@ def test_unsupported_strategy_types_on_cpu_and_fallback():
assert isinstance(connector.strategy, DDPStrategy)


@RunIf(mps=True)
@pytest.mark.parametrize("precision", ["16-mixed", "bf16-mixed"])
def test_mps_enabled_with_float16_or_bfloat16_precision(precision):
    """The connector must report ``mps`` as the precision plugin's device for (bf)16-mixed on MPS.

    Guards against the plugin silently defaulting to ``cuda`` when the MPS accelerator is selected.
    """
    connector = _Connector(accelerator="mps", precision=precision)
    assert connector.precision.device == "mps"


def test_invalid_accelerator_choice():
    """An unknown accelerator name must be rejected with a descriptive ``ValueError``."""
    expected_message = "You selected an invalid accelerator name: `accelerator='cocofruit'`"
    with pytest.raises(ValueError, match=expected_message):
        _Connector(accelerator="cocofruit")
Expand Down
17 changes: 17 additions & 0 deletions tests/tests_fabric/test_fabric.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
from contextlib import nullcontext
from re import escape
from unittest import mock
Expand Down Expand Up @@ -735,6 +736,22 @@ def test_autocast():
fabric._precision.forward_context().__exit__.assert_called()


@RunIf(mps=True)
@pytest.mark.parametrize("precision", ["16-mixed", "bf16-mixed"])
def test_autocast_does_not_use_cuda_on_mps(precision):
    """With (bf)16-mixed precision on MPS, ``Fabric.autocast`` must not warn about a CUDA device_type."""
    fabric = Fabric(accelerator="mps", precision=precision)
    fabric.launch()

    # Capture every warning emitted while entering/exiting the autocast context.
    with warnings.catch_warnings(record=True) as captured:
        warnings.simplefilter("always")
        with fabric.autocast():
            pass

    # No captured warning may mention a CUDA device_type (would indicate a wrong autocast target).
    assert not any("device_type of 'cuda'" in str(entry.message) for entry in captured)


def test_no_backward_sync():
"""Test that `Fabric.no_backward_sync()` validates the strategy and model is compatible."""
fabric = Fabric(devices=1)
Expand Down
Loading