True half-precision support in Fabric #17287
Merged
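This PR adds a HalfPrecision plugin to Fabric so a model can run entirely in
torch.float16 or torch.bfloat16 ("true" half precision, in contrast to mixed
precision, which keeps float32 weights and only autocasts selected operations).
A minimal usage sketch, assuming the "16-true"/"bf16-true" flags this PR wires
into the Fabric constructor; the exact entry point may differ:

    import torch
    from lightning.fabric import Fabric

    fabric = Fabric(accelerator="cpu", devices=1, precision="bf16-true")

    model = torch.nn.Linear(2, 2)
    model = fabric.setup(model)  # convert_module() casts every parameter to bfloat16
    assert model.weight.dtype == torch.bfloat16
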
Commits (46):
be48ffa  model instantiation (awaelchli)
5d68d2f  strategy implementations (awaelchli)
df5c9ed  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
1779460  tests (awaelchli)
b92d056  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
de21ae5  connect precision (awaelchli)
3153e0c  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
70c25df  tests (awaelchli)
3fb0c50  ddp (awaelchli)
ccc9b8d  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
d93b7c9  update (awaelchli)
cb94829  Merge remote-tracking branch 'origin/fabric/half-precision' into fabr… (awaelchli)
5f57343  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
2fd241a  ddp test (awaelchli)
9f80ea3  ddp test (awaelchli)
5a72dca  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
e1e1852  reset (awaelchli)
4c07eae  notebook (awaelchli)
9b0f0de  notebook (awaelchli)
bb2321f  notebook (awaelchli)
6a14bdd  add test (awaelchli)
b6f8e1a  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
71d9308  Merge branch 'master' into fabric/half-precision (awaelchli)
66d3a20  fsdp tests (awaelchli)
3207812  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
1a852e1  comments (awaelchli)
1805a60  reset (awaelchli)
f368a66  Revert "reset" (awaelchli)
edd35dd  Merge branch 'master' into fabric/module-init (awaelchli)
edf7135  changelog (awaelchli)
65e0d22  add changelog (awaelchli)
2993281  add test (awaelchli)
e2aa4c3  add test (awaelchli)
caa469a  Merge branch 'fabric/module-init' into fabric/half-precision (awaelchli)
63faed7  document true half precision (awaelchli)
e9463a6  changelog (awaelchli)
c114f45  Merge branch 'master' into fabric/half-precision (awaelchli)
57e9ee8  fix import (awaelchli)
808f9d4  fix merge error (awaelchli)
a81152a  ignore weirdo type error (awaelchli)
043911f  Update docs/source-fabric/fundamentals/precision.rst (awaelchli)
898ee2a  Update src/lightning/fabric/plugins/precision/half.py (awaelchli)
9e581fa  Update src/lightning/fabric/plugins/precision/half.py (awaelchli)
9cbd557  update default (awaelchli)
f8b02e1  mypy (awaelchli)
3fac938  Merge branch 'master' into fabric/half-precision (Borda)
File: src/lightning/fabric/plugins/precision/half.py (new file)
@@ -0,0 +1,69 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Any, Generator, Literal

import torch
from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor
from torch.nn import Module

from lightning.fabric.plugins.precision.precision import Precision
from lightning.fabric.plugins.precision.utils import _convert_fp_tensor


class HalfPrecision(Precision):
    """Plugin for training with half precision.

    Args:
        precision: Whether to use ``torch.float16`` (``'16-true'``) or ``torch.bfloat16`` (``'bf16-true'``).
    """

    precision: Literal["bf16-true", "16-true"] = "16-true"

    def __init__(self, precision: Literal["bf16-true", "16-true"] = "16-true") -> None:
        self.precision = precision
        self._desired_input_dtype = torch.bfloat16 if precision == "bf16-true" else torch.float16

    def convert_module(self, module: Module) -> Module:
        return module.to(dtype=self._desired_input_dtype)

    @contextmanager
    def module_init_context(self) -> Generator[None, None, None]:
        """A context manager to change the default tensor type when initializing the parameters in a module.

        See: :func:`torch.set_default_dtype`
        """
        default_dtype = torch.get_default_dtype()
        torch.set_default_dtype(self._desired_input_dtype)
        yield
        torch.set_default_dtype(default_dtype)

    @contextmanager
    def forward_context(self) -> Generator[None, None, None]:
        """A context manager to change the default tensor type when tensors get created during the module's
        forward.

        See: :func:`torch.set_default_dtype`
        """
        default_dtype = torch.get_default_dtype()
        torch.set_default_dtype(self._desired_input_dtype)
        yield
        torch.set_default_dtype(default_dtype)

    def convert_input(self, data: Any) -> Any:
        return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_input_dtype)

    def convert_output(self, data: Any) -> Any:
        return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())
File: unit tests for the HalfPrecision plugin (new file)
@@ -0,0 +1,75 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch

from lightning.fabric.plugins.precision import HalfPrecision


@pytest.mark.parametrize(
    "precision, expected_dtype",
    [
        ("bf16-true", torch.bfloat16),
        ("16-true", torch.half),
    ],
)
def test_selected_dtype(precision, expected_dtype):
    plugin = HalfPrecision(precision=precision)
    assert plugin.precision == precision
    assert plugin._desired_input_dtype == expected_dtype


@pytest.mark.parametrize(
    "precision, expected_dtype",
    [
        ("bf16-true", torch.bfloat16),
        ("16-true", torch.half),
    ],
)
def test_module_init_context(precision, expected_dtype):
    plugin = HalfPrecision(precision=precision)
    with plugin.module_init_context():
        model = torch.nn.Linear(2, 2)
        assert torch.get_default_dtype() == expected_dtype
    assert model.weight.dtype == expected_dtype


@pytest.mark.parametrize(
    "precision, expected_dtype",
    [
        ("bf16-true", torch.bfloat16),
        ("16-true", torch.half),
    ],
)
def test_forward_context(precision, expected_dtype):
    precision = HalfPrecision(precision=precision)
    assert torch.get_default_dtype() == torch.float32
    with precision.forward_context():
        assert torch.get_default_dtype() == expected_dtype
    assert torch.get_default_dtype() == torch.float32


@pytest.mark.parametrize(
    "precision, expected_dtype",
    [
        ("bf16-true", torch.bfloat16),
        ("16-true", torch.half),
    ],
)
def test_convert_module(precision, expected_dtype):
    precision = HalfPrecision(precision=precision)
    module = torch.nn.Linear(2, 2)
    assert module.weight.dtype == module.bias.dtype == torch.float32
    module = precision.convert_module(module)
    assert module.weight.dtype == module.bias.dtype == expected_dtype
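
The tests above do not exercise convert_input()/convert_output(). A quick sketch of
the expected behavior, assuming _convert_fp_tensor casts only floating-point tensors
and leaves others untouched (as its name suggests):

    import torch
    from lightning.fabric.plugins.precision import HalfPrecision

    plugin = HalfPrecision(precision="16-true")
    batch = {"x": torch.randn(2, 2), "y": torch.zeros(2, dtype=torch.long)}
    batch = plugin.convert_input(batch)
    assert batch["x"].dtype == torch.float16  # floating-point tensor is cast
    assert batch["y"].dtype == torch.long     # integer labels pass through

    out = plugin.convert_output(batch["x"])
    assert out.dtype == torch.get_default_dtype()  # cast back to float32 by default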