
Commit 42cf0db

1SAA committed:

Add support for colossalai 0.1.11 (#15888)

Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>

1 parent: 42c0884

File tree

3 files changed: +65 −20 lines


src/pytorch_lightning/CHANGELOG.md

Lines changed: 3 additions & 0 deletions

```diff
@@ -33,6 +33,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added a warning when `self.log(..., logger=True)` is called without a configured logger ([#15814](https://github.com/Lightning-AI/lightning/pull/15814))
 
 
+- Added support for colossalai 0.1.11 ([#15888](https://github.com/Lightning-AI/lightning/pull/15888))
+
+
 - Added `LightningCLI` support for optimizer and learning schedulers via callable type dependency injection ([#15869](https://github.com/Lightning-AI/lightning/pull/15869))
 
 
```
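For orientation, a minimal, hypothetical usage of the strategy this commit updates, assuming a single-GPU machine with `colossalai` installed; `MyModel` stands in for any `LightningModule`:

```python
import pytorch_lightning as pl
from pytorch_lightning.strategies import ColossalAIStrategy

trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    precision=16,  # the ColossalAI precision plugin works in fp16
    strategy=ColossalAIStrategy(),
)
# trainer.fit(MyModel())  # MyModel is a placeholder LightningModule
```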
src/pytorch_lightning/strategies/colossalai.py

Lines changed: 61 additions & 18 deletions

```diff
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import math
 from typing import Any, Callable, Dict, List, Mapping, Optional, TYPE_CHECKING, Union
 
 import torch
```
```diff
@@ -33,11 +34,11 @@
 from pytorch_lightning.strategies.strategy import TBroadcast
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities.enums import PrecisionType
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.types import STEP_OUTPUT
 
 _COLOSSALAI_AVAILABLE = RequirementCache("colossalai")
+_COLOSSALAI_GREATER_0_1_10 = RequirementCache("colossalai>0.1.10")
 if TYPE_CHECKING and _COLOSSALAI_AVAILABLE:
     with _patch_cuda_is_available():
         from colossalai.utils.model.colo_init_context import ColoInitContext
```
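The new `_COLOSSALAI_GREATER_0_1_10` gate works because `RequirementCache` evaluates truthy only when the installed distribution satisfies the specifier. A simplified stand-in, for illustration only (the real class lives in `lightning-utilities`):

```python
from importlib.metadata import PackageNotFoundError, version

from packaging.requirements import Requirement


def _requirement_met(req_str: str) -> bool:
    """Rough sketch of the check RequirementCache performs."""
    req = Requirement(req_str)
    try:
        installed = version(req.name)
    except PackageNotFoundError:
        return False  # the package is not installed at all
    return req.specifier.contains(installed, prereleases=True)


# With colossalai 0.1.11 installed:
# _requirement_met("colossalai>0.1.10") -> True
```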
```diff
@@ -130,7 +131,7 @@ def __init__(
         force_outputs_fp32: bool = False,
         gpu_margin_mem_ratio: float = 0.0,
         chunk_search_range: int = 64 * 1024**2,
-        chunk_search_n_grids: int = 1024,
+        chunk_search_n_grids: int = 4096,
         min_chunk_size: Optional[int] = None,
         initial_scale: float = 2**16,
         min_scale: float = 1,
```
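The new default interacts with `chunk_search_range`: as the `setup_precision_plugin` change below shows, the range is divided by the grid count to get a search interval, so quadrupling `chunk_search_n_grids` makes the scan four times finer. A quick check with the 64 MiB signature default:

```python
import math

chunk_search_range = 64 * 1024**2  # 64 MiB, the signature default above

print(math.ceil(chunk_search_range / 1024))  # 65536 -> old step size in bytes
print(math.ceil(chunk_search_range / 4096))  # 16384 -> new, finer step size
```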
```diff
@@ -146,7 +147,7 @@ def __init__(
         precision_plugin: Optional[ColossalAIPrecisionPlugin] = None,
     ) -> None:
         if not _COLOSSALAI_AVAILABLE:
-            raise MisconfigurationException(
+            raise ModuleNotFoundError(
                 "To use the `ColossalAIStrategy`, please install `colossalai` first. "
                 "Download `colossalai` by consulting `https://colossalai.org/download`."
             )
```
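With the builtin `ModuleNotFoundError`, callers no longer need a Lightning-specific import to handle the missing dependency; a small usage sketch:

```python
from pytorch_lightning.strategies import ColossalAIStrategy

try:
    strategy = ColossalAIStrategy()
except ModuleNotFoundError as err:
    strategy = None  # fall back, e.g. to a default strategy
    print(f"colossalai is unavailable: {err}")
```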
```diff
@@ -237,7 +238,8 @@ def _post_init_method(self, module: torch.nn.Module, *args: Any, **kwargs: Any)
                 if getattr(module, "_colossalai_module", False) is True:
                     return
                 super()._post_init_method(module, *args, **kwargs)
-                module._colossalai_module = True  # type: ignore[assignment]
+                for sub_module in module.modules():
+                    sub_module._colossalai_module = True  # type: ignore[assignment]
 
         return ModelShardedContext()
 
```
```diff
@@ -264,23 +266,54 @@ def setup_precision_plugin(self) -> None:
             )
         assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
         pl_module = self.model
-        process_group = ProcessGroup()
+
         if not hasattr(pl_module, "_colossalai_zero"):
-            if self.use_chunk:
-                chunk_size = self.chunk_size or ChunkManager.search_chunk_size(
-                    self.model, **self.chunk_size_search_kwargs
+            if not _COLOSSALAI_GREATER_0_1_10:
+                if self.use_chunk:
+                    chunk_size = self.chunk_size or ChunkManager.search_chunk_size(
+                        self.model, **self.chunk_size_search_kwargs
+                    )
+                else:
+                    chunk_size = None
+                process_group = ProcessGroup()
+                chunk_manager = ChunkManager(
+                    chunk_size,
+                    process_group,
+                    self.enable_distributed_storage,
+                    GeminiManager.get_default_device(self.placement_policy),
                 )
+                gemini_manager = GeminiManager(self.placement_policy, chunk_manager)
+                model = _LightningModuleWrapperBase(self.model)
+                self.model = ZeroDDP(model, gemini_manager, self.force_outputs_fp32)
             else:
-                chunk_size = None
-            chunk_manager = ChunkManager(
-                chunk_size,
-                process_group,
-                self.enable_distributed_storage,
-                GeminiManager.get_default_device(self.placement_policy),
-            )
-            gemini_manager = GeminiManager(self.placement_policy, chunk_manager)
-            model = _LightningModuleWrapperBase(self.model)
-            self.model = ZeroDDP(model, gemini_manager, self.force_outputs_fp32)
+                with _patch_cuda_is_available():
+                    from colossalai.nn.parallel import GeminiDDP
+                    from colossalai.utils import get_current_device
+                if not self.use_chunk:
+                    raise ValueError("`ColossalAIStrategy` must use chunk in versions higher than 0.1.10")
+                chunk_search_range: int = self.chunk_size_search_kwargs.get(
+                    "search_range", 32 * 1024**2
+                )  # type: ignore[assignment]
+                search_range_mb: float = chunk_search_range / 1024**2
+                search_n_grids: int = self.chunk_size_search_kwargs.get("n_grids", 4096)  # type: ignore[assignment]
+                search_interval: int = math.ceil(chunk_search_range / search_n_grids)
+                min_chunk_size_mb: float = self.chunk_size_search_kwargs.get(
+                    "min_chunk_size", 32 * 1024**2
+                )  # type: ignore[assignment]
+                min_chunk_size_mb /= 1024**2
+
+                model = _LightningModuleWrapperBase(self.model)
+                self.model = GeminiDDP(
+                    module=model,
+                    device=get_current_device(),
+                    placement_policy=self.placement_policy,
+                    pin_memory=True,
+                    force_outputs_fp32=self.force_outputs_fp32,
+                    search_range_mb=search_range_mb,
+                    hidden_dim=search_interval,
+                    min_chunk_size_mb=min_chunk_size_mb,
+                )
+
             assert self.model is not None
             pl_module._colossalai_zero = [self.model]  # type: ignore[assignment]
         else:
```
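The byte-valued kwargs kept for backward compatibility are converted here into the MB-scaled arguments the newer `GeminiDDP` API takes. A worked example using the fallback defaults from the code above:

```python
import math

search_range = 32 * 1024**2    # "search_range" fallback, in bytes
n_grids = 4096                 # "n_grids" fallback
min_chunk_size = 32 * 1024**2  # "min_chunk_size" fallback, in bytes

search_range_mb = search_range / 1024**2             # 32.0 MB
search_interval = math.ceil(search_range / n_grids)  # 8192, passed as hidden_dim
min_chunk_size_mb = min_chunk_size / 1024**2         # 32.0 MB
print(search_range_mb, search_interval, min_chunk_size_mb)
```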
```diff
@@ -329,10 +362,20 @@ def setup(self, trainer: "pl.Trainer") -> None:
         self.accelerator.setup(trainer)
         assert self.lightning_module is not None
         self.lightning_module._device = self.root_device
+        self.ignore_no_grad_parameters(self.root_device)
         self.setup_optimizers(trainer)
         self.setup_precision_plugin()
         self.model_to_device()
 
+    def ignore_no_grad_parameters(self, running_device: torch.device) -> None:
+        # for those parameters with no gradients
+        # we should ignore them on DDP and move them to CUDA
+        assert self.model is not None
+        for param in self.model.parameters():
+            if not param.requires_grad:
+                setattr(param, "_ddp_to_ignore", True)
+                param.data = param.data.to(running_device)
+
     def model_to_device(self) -> None:
         assert self.lightning_module is not None
         pl_module = self.lightning_module
```
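A self-contained sketch of what `ignore_no_grad_parameters` does: parameters with `requires_grad=False` get the `_ddp_to_ignore` tag (the attribute colossalai's DDP wrapper consults for parameters it should leave alone) and are moved to the running device by hand, since the wrapper will not manage them. The `Head` module is hypothetical:

```python
import torch


class Head(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.frozen = torch.nn.Linear(4, 4)
        self.frozen.requires_grad_(False)  # a frozen, gradient-free layer


model = Head()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for param in model.parameters():
    if not param.requires_grad:
        setattr(param, "_ddp_to_ignore", True)  # tell DDP to skip this one
        param.data = param.data.to(device)      # move it ourselves
```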

tests/tests_pytorch/strategies/test_colossalai.py

Lines changed: 1 addition & 2 deletions

```diff
@@ -26,7 +26,6 @@
 from pytorch_lightning.plugins.precision import ColossalAIPrecisionPlugin
 from pytorch_lightning.strategies import ColossalAIStrategy
 from pytorch_lightning.strategies.colossalai import _COLOSSALAI_AVAILABLE
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests_pytorch.helpers.datamodules import ClassifDataModule
 from tests_pytorch.helpers.runif import RunIf
 
```
```diff
@@ -39,7 +38,7 @@ def test_invalid_colosalai(monkeypatch):
 
     monkeypatch.setattr(colossal_strategy, "_COLOSSALAI_AVAILABLE", False)
     with pytest.raises(
-        MisconfigurationException,
+        ModuleNotFoundError,
         match="To use the `ColossalAIStrategy`, please install `colossalai` first. "
         "Download `colossalai` by consulting `https://colossalai.org/download`.",
     ):
```