 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import math
 from typing import Any, Callable, Dict, List, Mapping, Optional, TYPE_CHECKING, Union

 import torch
@@ -33,11 +34,11 @@
 from pytorch_lightning.strategies.strategy import TBroadcast
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities.enums import PrecisionType
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.types import STEP_OUTPUT

 _COLOSSALAI_AVAILABLE = RequirementCache("colossalai")
+_COLOSSALAI_GREATER_0_1_10 = RequirementCache("colossalai>0.1.10")
 if TYPE_CHECKING and _COLOSSALAI_AVAILABLE:
     with _patch_cuda_is_available():
         from colossalai.utils.model.colo_init_context import ColoInitContext
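
`_COLOSSALAI_GREATER_0_1_10` evaluates truthy only when the installed `colossalai` satisfies the `>0.1.10` specifier; it gates the two setup paths further down. A minimal sketch of an equivalent check, assuming only the standard `importlib.metadata` and `packaging` libraries rather than the actual `RequirementCache` implementation:

from importlib.metadata import PackageNotFoundError, version

from packaging.specifiers import SpecifierSet
from packaging.version import Version


def _satisfies(package: str, specifier: str) -> bool:
    # True only if the package is installed and its version matches the
    # given specifier, e.g. ">0.1.10".
    try:
        installed = Version(version(package))
    except PackageNotFoundError:
        return False
    return installed in SpecifierSet(specifier)


_has_new_colossalai = _satisfies("colossalai", ">0.1.10")
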
@@ -130,7 +131,7 @@ def __init__(
         force_outputs_fp32: bool = False,
         gpu_margin_mem_ratio: float = 0.0,
         chunk_search_range: int = 64 * 1024**2,
-        chunk_search_n_grids: int = 1024,
+        chunk_search_n_grids: int = 4096,
         min_chunk_size: Optional[int] = None,
         initial_scale: float = 2**16,
         min_scale: float = 1,
@@ -146,7 +147,7 @@ def __init__(
         precision_plugin: Optional[ColossalAIPrecisionPlugin] = None,
     ) -> None:
         if not _COLOSSALAI_AVAILABLE:
-            raise MisconfigurationException(
+            raise ModuleNotFoundError(
                 "To use the `ColossalAIStrategy`, please install `colossalai` first. "
                 "Download `colossalai` by consulting `https://colossalai.org/download`."
             )
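
The two `__init__` hunks above raise the default grid count from 1024 to 4096 and switch the missing-dependency error to the built-in `ModuleNotFoundError`. A hypothetical usage sketch showing how these constructor arguments are passed when building the strategy (the keyword names come from the diff; the import path and `Trainer` wiring are the usual Lightning pattern and assumed here):

import pytorch_lightning as pl
from pytorch_lightning.strategies import ColossalAIStrategy

# Explicitly set the chunk-search parameters whose defaults are changed above:
# 4096 grids over a 64 MiB search range gives a much finer search step than 1024.
strategy = ColossalAIStrategy(
    chunk_search_range=64 * 1024**2,
    chunk_search_n_grids=4096,
)
trainer = pl.Trainer(accelerator="gpu", devices=1, precision=16, strategy=strategy)
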
@@ -237,7 +238,8 @@ def _post_init_method(self, module: torch.nn.Module, *args: Any, **kwargs: Any)
                 if getattr(module, "_colossalai_module", False) is True:
                     return
                 super()._post_init_method(module, *args, **kwargs)
-                module._colossalai_module = True  # type: ignore[assignment]
+                for sub_module in module.modules():
+                    sub_module._colossalai_module = True  # type: ignore[assignment]

         return ModelShardedContext()

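
The `_post_init_method` change marks every submodule, not just the top-level one, so modules that were already initialised inside a parent are recognised and skipped on later entries into the sharded context. A small stand-alone illustration of the difference, using only `torch.nn` (`_colossalai_module` is simply the marker attribute from the diff):

import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())

# Old behaviour: only the outer container carries the marker.
model._colossalai_module = True
print(getattr(model[0], "_colossalai_module", False))  # False -> child would be processed again

# New behaviour: modules() yields the module itself and every descendant,
# so the children carry the marker too and are skipped next time.
for sub_module in model.modules():
    sub_module._colossalai_module = True
print(getattr(model[0], "_colossalai_module", False))  # True
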
@@ -264,23 +266,54 @@ def setup_precision_plugin(self) -> None:
             )
         assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
         pl_module = self.model
-        process_group = ProcessGroup()
+
         if not hasattr(pl_module, "_colossalai_zero"):
-            if self.use_chunk:
-                chunk_size = self.chunk_size or ChunkManager.search_chunk_size(
-                    self.model, **self.chunk_size_search_kwargs
+            if not _COLOSSALAI_GREATER_0_1_10:
+                if self.use_chunk:
+                    chunk_size = self.chunk_size or ChunkManager.search_chunk_size(
+                        self.model, **self.chunk_size_search_kwargs
+                    )
+                else:
+                    chunk_size = None
+                process_group = ProcessGroup()
+                chunk_manager = ChunkManager(
+                    chunk_size,
+                    process_group,
+                    self.enable_distributed_storage,
+                    GeminiManager.get_default_device(self.placement_policy),
                 )
+                gemini_manager = GeminiManager(self.placement_policy, chunk_manager)
+                model = _LightningModuleWrapperBase(self.model)
+                self.model = ZeroDDP(model, gemini_manager, self.force_outputs_fp32)
             else:
-                chunk_size = None
-            chunk_manager = ChunkManager(
-                chunk_size,
-                process_group,
-                self.enable_distributed_storage,
-                GeminiManager.get_default_device(self.placement_policy),
-            )
-            gemini_manager = GeminiManager(self.placement_policy, chunk_manager)
-            model = _LightningModuleWrapperBase(self.model)
-            self.model = ZeroDDP(model, gemini_manager, self.force_outputs_fp32)
+                with _patch_cuda_is_available():
+                    from colossalai.nn.parallel import GeminiDDP
+                    from colossalai.utils import get_current_device
+                if not self.use_chunk:
+                    raise ValueError("`ColossalAIStrategy` must use chunk in versions higher than 0.1.10")
+                chunk_search_range: int = self.chunk_size_search_kwargs.get(
+                    "search_range", 32 * 1024**2
+                )  # type: ignore[assignment]
+                search_range_mb: float = chunk_search_range / 1024**2
+                search_n_grids: int = self.chunk_size_search_kwargs.get("n_grids", 4096)  # type: ignore[assignment]
+                search_interval: int = math.ceil(chunk_search_range / search_n_grids)
+                min_chunk_size_mb: float = self.chunk_size_search_kwargs.get(
+                    "min_chunk_size", 32 * 1024**2
+                )  # type: ignore[assignment]
+                min_chunk_size_mb /= 1024**2
+
+                model = _LightningModuleWrapperBase(self.model)
+                self.model = GeminiDDP(
+                    module=model,
+                    device=get_current_device(),
+                    placement_policy=self.placement_policy,
+                    pin_memory=True,
+                    force_outputs_fp32=self.force_outputs_fp32,
+                    search_range_mb=search_range_mb,
+                    hidden_dim=search_interval,
+                    min_chunk_size_mb=min_chunk_size_mb,
+                )
+
             assert self.model is not None
             pl_module._colossalai_zero = [self.model]  # type: ignore[assignment]
         else:
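
On the new (>0.1.10) path, the search settings collected in bytes are converted into what `GeminiDDP` takes: a search range and minimum chunk size in MB plus a per-grid search interval in bytes. A quick worked example of that arithmetic, using the 64 MiB search range and 4096 grids from `__init__` and the 32 MiB fallback for the minimum chunk size (values computed by hand; `GeminiDDP` itself is ColossalAI's API and not reproduced here):

import math

# Defaults from the diff: 64 MiB search range, 4096 grids, 32 MiB min chunk size.
chunk_search_range = 64 * 1024**2   # 67_108_864 bytes
search_n_grids = 4096
min_chunk_size = 32 * 1024**2       # 33_554_432 bytes

search_range_mb = chunk_search_range / 1024**2                    # 64.0 MB
search_interval = math.ceil(chunk_search_range / search_n_grids)  # 16_384 bytes per grid step
min_chunk_size_mb = min_chunk_size / 1024**2                      # 32.0 MB

print(search_range_mb, search_interval, min_chunk_size_mb)  # 64.0 16384 32.0
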
@@ -329,10 +362,20 @@ def setup(self, trainer: "pl.Trainer") -> None:
         self.accelerator.setup(trainer)
         assert self.lightning_module is not None
         self.lightning_module._device = self.root_device
+        self.ignore_no_grad_parameters(self.root_device)
         self.setup_optimizers(trainer)
         self.setup_precision_plugin()
         self.model_to_device()

+    def ignore_no_grad_parameters(self, running_device: torch.device) -> None:
+        # For parameters that require no gradient, we should ignore them
+        # on DDP and move them to CUDA ourselves.
+        assert self.model is not None
+        for param in self.model.parameters():
+            if not param.requires_grad:
+                setattr(param, "_ddp_to_ignore", True)
+                param.data = param.data.to(running_device)
+
     def model_to_device(self) -> None:
         assert self.lightning_module is not None
         pl_module = self.lightning_module
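
`ignore_no_grad_parameters` tags frozen parameters with `_ddp_to_ignore` and moves them to the running device itself, so the distributed wrapper never tries to shard or synchronise gradients for them. A hedged sketch of how such a flag is typically consumed on the wrapping side (the attribute name comes from the diff; the filtering logic below is an assumption for illustration, not ColossalAI's code):

from typing import List

import torch.nn as nn


def params_to_manage(module: nn.Module) -> List[nn.Parameter]:
    # Keep only parameters that were not explicitly flagged as "ignore for DDP";
    # these are the ones a wrapper would register for gradient sync/sharding.
    return [p for p in module.parameters() if not getattr(p, "_ddp_to_ignore", False)]


model = nn.Linear(8, 8)
model.bias.requires_grad_(False)
setattr(model.bias, "_ddp_to_ignore", True)  # mimic ignore_no_grad_parameters

managed = params_to_manage(model)
assert not any(p is model.bias for p in managed)  # frozen bias is skipped
assert any(p is model.weight for p in managed)    # trainable weight is kept
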