3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -30,6 +30,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a warning when `self.log(..., logger=True)` is called without a configured logger ([#15814](https://github.com/Lightning-AI/lightning/pull/15814))


- Added support for colossalai 0.1.11 ([#15888](https://github.com/Lightning-AI/lightning/pull/15888))


- Added `LightningCLI` support for optimizer and learning schedulers via callable type dependency injection ([#15869](https://github.com/Lightning-AI/lightning/pull/15869))


79 changes: 61 additions & 18 deletions src/pytorch_lightning/strategies/colossalai.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Any, Callable, Dict, List, Mapping, Optional, TYPE_CHECKING, Union

import torch
@@ -33,11 +34,11 @@
from pytorch_lightning.strategies.strategy import TBroadcast
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import STEP_OUTPUT

_COLOSSALAI_AVAILABLE = RequirementCache("colossalai")
_COLOSSALAI_GREATER_0_1_10 = RequirementCache("colossalai>0.1.10")
if TYPE_CHECKING and _COLOSSALAI_AVAILABLE:
with _patch_cuda_is_available():
from colossalai.utils.model.colo_init_context import ColoInitContext
@@ -130,7 +131,7 @@ def __init__(
force_outputs_fp32: bool = False,
gpu_margin_mem_ratio: float = 0.0,
chunk_search_range: int = 64 * 1024**2,
chunk_search_n_grids: int = 1024,
chunk_search_n_grids: int = 4096,
min_chunk_size: Optional[int] = None,
initial_scale: float = 2**16,
min_scale: float = 1,
@@ -146,7 +147,7 @@ def __init__(
precision_plugin: Optional[ColossalAIPrecisionPlugin] = None,
) -> None:
if not _COLOSSALAI_AVAILABLE:
raise MisconfigurationException(
raise ModuleNotFoundError(
"To use the `ColossalAIStrategy`, please install `colossalai` first. "
"Download `colossalai` by consulting `https://colossalai.org/download`."
)
@@ -237,7 +238,8 @@ def _post_init_method(self, module: torch.nn.Module, *args: Any, **kwargs: Any)
if getattr(module, "_colossalai_module", False) is True:
return
super()._post_init_method(module, *args, **kwargs)
module._colossalai_module = True # type: ignore[assignment]
for sub_module in module.modules():
sub_module._colossalai_module = True # type: ignore[assignment]

return ModelShardedContext()

@@ -264,23 +266,54 @@ def setup_precision_plugin(self) -> None:
)
assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
pl_module = self.model
process_group = ProcessGroup()

if not hasattr(pl_module, "_colossalai_zero"):
if self.use_chunk:
chunk_size = self.chunk_size or ChunkManager.search_chunk_size(
self.model, **self.chunk_size_search_kwargs
if not _COLOSSALAI_GREATER_0_1_10:
if self.use_chunk:
chunk_size = self.chunk_size or ChunkManager.search_chunk_size(
self.model, **self.chunk_size_search_kwargs
)
else:
chunk_size = None
process_group = ProcessGroup()
chunk_manager = ChunkManager(
chunk_size,
process_group,
self.enable_distributed_storage,
GeminiManager.get_default_device(self.placement_policy),
)
gemini_manager = GeminiManager(self.placement_policy, chunk_manager)
model = _LightningModuleWrapperBase(self.model)
self.model = ZeroDDP(model, gemini_manager, self.force_outputs_fp32)
else:
chunk_size = None
chunk_manager = ChunkManager(
chunk_size,
process_group,
self.enable_distributed_storage,
GeminiManager.get_default_device(self.placement_policy),
)
gemini_manager = GeminiManager(self.placement_policy, chunk_manager)
model = _LightningModuleWrapperBase(self.model)
self.model = ZeroDDP(model, gemini_manager, self.force_outputs_fp32)
with _patch_cuda_is_available():
from colossalai.nn.parallel import GeminiDDP
from colossalai.utils import get_current_device
if not self.use_chunk:
raise ValueError("`ColossalAIStrategy` must use chunk in versions higher than 0.1.10")
chunk_search_range: int = self.chunk_size_search_kwargs.get(
"search_range", 32 * 1024**2
) # type: ignore[assignment]
search_range_mb: float = chunk_search_range / 1024**2
search_n_grids: int = self.chunk_size_search_kwargs.get("n_grids", 4096) # type: ignore[assignment]
search_interval: int = math.ceil(chunk_search_range / search_n_grids)
min_chunk_size_mb: float = self.chunk_size_search_kwargs.get(
"min_chunk_size", 32 * 1024**2
) # type: ignore[assignment]
min_chunk_size_mb /= 1024**2
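# For reference, with the fallback values above: search_range = 32 * 1024**2 bytes
# gives search_range_mb = 32.0; with n_grids = 4096, search_interval = ceil(33554432 / 4096) = 8192
# (passed to GeminiDDP as `hidden_dim`); min_chunk_size = 32 * 1024**2 bytes gives min_chunk_size_mb = 32.0.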

model = _LightningModuleWrapperBase(self.model)
self.model = GeminiDDP(
module=model,
device=get_current_device(),
placement_policy=self.placement_policy,
pin_memory=True,
force_outputs_fp32=self.force_outputs_fp32,
search_range_mb=search_range_mb,
hidden_dim=search_interval,
min_chunk_size_mb=min_chunk_size_mb,
)

assert self.model is not None
pl_module._colossalai_zero = [self.model] # type: ignore[assignment]
else:
@@ -329,10 +362,20 @@ def setup(self, trainer: "pl.Trainer") -> None:
self.accelerator.setup(trainer)
assert self.lightning_module is not None
self.lightning_module._device = self.root_device
self.ignore_no_grad_parameters(self.root_device)
self.setup_optimizers(trainer)
self.setup_precision_plugin()
self.model_to_device()

def ignore_no_grad_parameters(self, running_device: torch.device) -> None:
# parameters that do not require gradients are excluded from DDP
# and moved to the running device manually
assert self.model is not None
for param in self.model.parameters():
if not param.requires_grad:
setattr(param, "_ddp_to_ignore", True)
param.data = param.data.to(running_device)

def model_to_device(self) -> None:
assert self.lightning_module is not None
pl_module = self.lightning_module
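For orientation, a minimal usage sketch of how the chunk-search arguments above reach the new GeminiDDP path. Values are illustrative; `use_chunk` and `placement_policy` are assumed to be constructor arguments (only their `self.` attributes appear in this diff), and `precision=16` matches the `ColossalAIPrecisionPlugin` the strategy pairs with.

import pytorch_lightning as pl
from pytorch_lightning.strategies import ColossalAIStrategy

# chunk_search_range / chunk_search_n_grids / min_chunk_size feed `chunk_size_search_kwargs`,
# which `setup_precision_plugin` converts into `search_range_mb`, `hidden_dim` (the search
# interval) and `min_chunk_size_mb` for GeminiDDP.
strategy = ColossalAIStrategy(
    use_chunk=True,                   # assumed flag; required on colossalai > 0.1.10
    chunk_search_range=64 * 1024**2,  # bytes
    chunk_search_n_grids=4096,        # new default introduced here
    min_chunk_size=32 * 1024**2,      # bytes
    placement_policy="cuda",          # assumed constructor argument
)
trainer = pl.Trainer(accelerator="gpu", devices=1, precision=16, strategy=strategy)
# trainer.fit(model, datamodule=dm)  # user-defined LightningModule / LightningDataModule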
3 changes: 1 addition & 2 deletions tests/tests_pytorch/strategies/test_colossalai.py
@@ -26,7 +26,6 @@
from pytorch_lightning.plugins.precision import ColossalAIPrecisionPlugin
from pytorch_lightning.strategies import ColossalAIStrategy
from pytorch_lightning.strategies.colossalai import _COLOSSALAI_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests_pytorch.helpers.datamodules import ClassifDataModule
from tests_pytorch.helpers.runif import RunIf

@@ -39,7 +38,7 @@ def test_invalid_colosalai(monkeypatch):

monkeypatch.setattr(colossal_strategy, "_COLOSSALAI_AVAILABLE", False)
with pytest.raises(
MisconfigurationException,
ModuleNotFoundError,
match="To use the `ColossalAIStrategy`, please install `colossalai` first. "
"Download `colossalai` by consulting `https://colossalai.org/download`.",
):
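One downstream implication of the exception change above, sketched for callers that previously guarded strategy construction on `MisconfigurationException` (assumes `colossalai` is not installed in the environment):

from pytorch_lightning.strategies import ColossalAIStrategy

try:
    strategy = ColossalAIStrategy()
except ModuleNotFoundError as err:  # was MisconfigurationException before this change
    print(f"colossalai is unavailable: {err}")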