
Commit 17c8be0

Fix the GPU memory usage of ZeRO-Offload (only update stage_1_and_2.py) (#7309)
Signed-off-by: Armin Zhu <[email protected]>

Fixes the GPU memory usage of ZeRO-Offload with stages 1 and 2. Before the fix, GPU memory usage was about 3x the size of the FP16 parameters (params_FP16), because the H2D data copy transferred the partition in a different (32-bit) data type. After the fix, GPU memory usage is about 1x params_FP16; the H2D copy now goes through a 16-bit pinned host memory buffer.
1 parent b666844 commit 17c8be0
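To make the 3x figure concrete: the old path moved the FP32 partition to the accelerator first and only down-cast it during the copy, so the FP16 destination and a transient FP32 copy of the same partition coexisted on the device. A minimal accounting sketch follows (the partition size is illustrative, not taken from the commit):

```python
numel = 100 * 1024 * 1024          # illustrative partition size (not from the commit)
fp16_bytes = numel * 2             # resident bit16 weights on the GPU: 1x params_FP16
fp32_temp_bytes = numel * 4        # transient FP32 copy created by fp32_partition.to(device)

# Old path: bit16_partition.copy_(fp32_partition.to(device)) keeps the FP16
# destination and a full FP32 temporary on the GPU at the same time.
peak_old = fp16_bytes + fp32_temp_bytes     # = 3x params_FP16

# New path: the FP32 -> FP16 cast happens on the CPU into a pinned buffer,
# so only 16-bit data is ever transferred to (or resident on) the GPU.
peak_new = fp16_bytes                       # = 1x params_FP16

print(peak_old / fp16_bytes, peak_new / fp16_bytes)   # 3.0 1.0
```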

File tree

1 file changed, +17 -2 lines changed

deepspeed/runtime/zero/stage_1_and_2.py

Lines changed: 17 additions & 2 deletions
@@ -258,6 +258,10 @@ def __init__(self,
         # that this process will update
         self.single_partition_of_fp32_groups = []

+        # a 16-bit CPU param buffer for cpu offload
+        if self.cpu_offload:
+            self.param_buffer_of_bit16_for_cpu_offload_groups = []
+
         # param partition info

         # These are the parameters in each group that will not be updated by this process directly

@@ -406,6 +410,16 @@ def __init__(self,

             if self.cpu_offload:
                 weights_partition = get_accelerator().pin_memory(weights_partition)
+                temp_dtype = self.parallel_partitioned_bit16_groups[i][partition_id].dtype
+                temp_buffer_bit16 = torch.full(weights_partition.shape,
+                                               fill_value=0.0,
+                                               dtype=temp_dtype,
+                                               device=weights_partition.device)
+                if self.cpu_offload_pin_memory:
+                    temp_pinned = get_accelerator().pin_memory(temp_buffer_bit16)
+                    self.param_buffer_of_bit16_for_cpu_offload_groups.append(temp_pinned)
+                else:
+                    self.param_buffer_of_bit16_for_cpu_offload_groups.append(temp_buffer_bit16)

             self.single_partition_of_fp32_groups.append(weights_partition)

@@ -1887,8 +1901,9 @@ def step(self, closure=None):
                 # bit16_partitions[partition_id].data.copy_(fp32_partition.data)
                 bit16_partitions = self.parallel_partitioned_bit16_groups[i]
                 fp32_partition = self.single_partition_of_fp32_groups[i]
-                bit16_partitions[partition_id].data.copy_(
-                    fp32_partition.to(get_accelerator().current_device_name()).data)
+                bit16_partition_buffer = self.param_buffer_of_bit16_for_cpu_offload_groups[i]
+                bit16_partition_buffer.data.copy_(fp32_partition.data)
+                bit16_partitions[partition_id].data.copy_(bit16_partition_buffer.data, non_blocking=True)

             self.timers(OPTIMIZER_STEP_TIMER).stop()
         else:
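As a standalone illustration of the staging pattern the hunks above introduce: the `get_accelerator()` / `pin_memory()` calls are the same DeepSpeed accelerator APIs used in the diff, but the helper names and surrounding structure below are hypothetical, not the optimizer's actual method layout.

```python
import torch
from deepspeed.accelerator import get_accelerator

def make_bit16_staging_buffer(fp32_partition, bit16_dtype, pin=True):
    # Hypothetical helper: one staging buffer per partition, allocated once at
    # init time and reused every step, mirroring
    # param_buffer_of_bit16_for_cpu_offload_groups in the diff.
    buf = torch.zeros(fp32_partition.shape, dtype=bit16_dtype, device=fp32_partition.device)
    return get_accelerator().pin_memory(buf) if pin else buf

def update_device_weights(bit16_partition, fp32_partition, staging_buf):
    # Down-cast FP32 -> 16-bit on the host ...
    staging_buf.data.copy_(fp32_partition.data)
    # ... then transfer only 16-bit bytes; the pinned buffer lets the
    # non_blocking H2D copy run asynchronously.
    bit16_partition.data.copy_(staging_buf.data, non_blocking=True)
```

Keeping the buffer in pinned host memory is what makes the `non_blocking=True` copy effective: transfers from pageable memory would fall back to a synchronous staging copy inside the driver.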
