Commit 9765b46

tjruwase, loadams, Masahiro Tanaka (tohtana), and Guanhua Wang authored and committed
Enable ZeRO set/get APIs for NVMe offload (deepspeedai#7046)
- Extend the APIs for [debugging](https://deepspeed.readthedocs.io/en/latest/zero3.html#debugging) and [modifying](https://deepspeed.readthedocs.io/en/latest/zero3.html#modifying-partitioned-states) ZeRO partitioned states to support NVMe offload.
- Add a vectorized update API. This is performance-critical for NVMe offloading scenarios.

---------

Signed-off-by: Olatunji Ruwase <[email protected]>
Signed-off-by: Masahiro Tanaka <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Masahiro Tanaka <[email protected]>
Co-authored-by: Masahiro Tanaka <[email protected]>
Co-authored-by: Guanhua Wang <[email protected]>
Signed-off-by: Max Kovalenko <[email protected]>
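For context, the debugging/modifying APIs referenced above are the `safe_get_*`/`safe_set_*` helpers in `deepspeed.utils`. The sketch below is an illustration only (not code from this commit) of the pattern they support, assuming `model_engine` was created by `deepspeed.initialize()` with a ZeRO-3 config that offloads optimizer state to NVMe (config omitted here):

```python
# Minimal usage sketch (illustration only, not code from this commit).
# Assumes `model_engine` is a deepspeed.initialize() engine whose ZeRO-3 config
# offloads optimizer state to NVMe; the config itself is not shown.
import torch
from deepspeed.utils import safe_get_full_fp32_param, safe_set_full_fp32_param

for name, param in model_engine.module.named_parameters():
    # Gather the full fp32 copy of a partitioned (possibly NVMe-resident) parameter.
    full_value = safe_get_full_fp32_param(param)
    if full_value is not None:
        # Modify it and push the change back through the partitioned/offloaded state.
        safe_set_full_fp32_param(param, torch.zeros_like(full_value))
```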
1 parent 5d647a5 commit 9765b46

13 files changed (+480, -166 lines)


.github/workflows/nv-torch-latest-v100.yml

Lines changed: 1 addition & 1 deletion
@@ -55,5 +55,5 @@ jobs:
       run: |
         unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
         cd tests
-        pytest $PYTEST_OPTS --forked -n 8 unit/ --torch_ver="2.6" --cuda_ver="12.4"
+        pytest -x $PYTEST_OPTS --forked -n 8 unit/ --torch_ver="2.6" --cuda_ver="12.4"
         pytest $PYTEST_OPTS --forked -m 'sequential' unit/ --torch_ver="2.6" --cuda_ver="12.4"

deepspeed/runtime/swap_tensor/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -2,3 +2,4 @@
 # SPDX-License-Identifier: Apache-2.0

 # DeepSpeed Team
+from .utils import MIN_SWAPPABLE_BYTES

deepspeed/runtime/swap_tensor/optimizer_utils.py

Lines changed: 65 additions & 19 deletions
@@ -26,45 +26,78 @@ def __init__(self, path, length, offset):
         self.length = length


+class SwapTensorContext(object):
+
+    def __init__(self, tensor, swap_folder):
+        self.compute_tensor = tensor
+        self.swap_tensor = torch.Tensor()
+        self.swap_path = os.path.join(swap_folder, f'{OptimizerSwapper.parameter_id(tensor)}.tensor.swp')
+
+    def release_memory(self):
+        self.compute_tensor.data = torch.Tensor()
+        self.swap_tensor.data = torch.Tensor()
+
+    def set_buffers(self, compute_buffer, swap_buffer):
+        self.compute_tensor.data = compute_buffer.data
+        self.swap_tensor.data = swap_buffer.data
+
+
 class OptimizerStateSwapInfo(object):

     def __init__(self, parameter, numel, base_folder):
         self.tensors = []
         self.param_id = OptimizerSwapper.parameter_id(parameter)
         self.swap_folder = base_folder
-        self.swap_paths = []
         self.swapped_gradients = {}
         self.unswapped_gradients = {}
         self.tensor_numel = numel
         self.tensor_dtype = parameter.dtype
         self.tensor_device = parameter.device
         self.has_state_tensors = False
+        self.swap_buffers = []
         self._add_tensors([parameter])

     def numel(self):
         return self.tensor_numel

     def has_gradients(self):
-        return self.swapped_gradients or self.unswapped_gradients
+        return bool(self.swapped_gradients) or bool(self.unswapped_gradients)

     def _add_tensors(self, tensor_list):
         for t in tensor_list:
-            self.tensors.append(t)
-            self.swap_paths.append(os.path.join(self.swap_folder, f'{OptimizerSwapper.parameter_id(t)}.tensor.swp'))
+            self.tensors.append(SwapTensorContext(t, self.swap_folder))

     def add_state_tensors(self, tensor_list):
         self.has_state_tensors = True
         self._add_tensors(tensor_list)

+    def num_tensors(self):
+        return len(self.tensors)
+
     def device(self):
         return self.tensor_device

     def dtype(self):
         return self.tensor_dtype

     def release_memory(self):
-        for tensor in self.tensors:
-            tensor.data = torch.Tensor()
+        for t in self.tensors:
+            t.release_memory()
+
+    def get_compute_tensors(self):
+        return [t.compute_tensor for t in self.tensors]
+
+    def get_swap_paths(self):
+        return [t.swap_path for t in self.tensors]
+
+    def get_swap_buffers_and_paths(self, pinned):
+        swap_buffers = []
+        swap_paths = []
+        select_tensors = [t for t in self.tensors if get_accelerator().is_pinned(t.compute_tensor) == pinned]
+        for t in select_tensors:
+            swap_buffers.append(t.swap_tensor if pinned else t.compute_tensor)
+            swap_paths.append(t.swap_path)
+        return swap_buffers, swap_paths

     def get_or_create_gradient_paths(self, offsets, lengths):
         gradient_paths = []
@@ -77,11 +110,15 @@ def get_or_create_gradient_paths(self, offsets, lengths):

         return gradient_paths

-    def set_swap_buffers(self, buffers):
-        compute_lengths = [self.numel()] * len(self.tensors)
+    def set_swap_buffers(self, buffers, aligned_numel):
+        num_tensors = len(self.tensors)
+        compute_lengths = [self.numel()] * num_tensors
         compute_buffers = get_sized_buffers(buffers, compute_lengths)
-        for t, buffer in zip(self.tensors, compute_buffers):
-            t.data = buffer.data
+        swap_lengths = [aligned_numel] * num_tensors
+        swap_buffers = get_sized_buffers(buffers, swap_lengths)
+
+        for i, t in enumerate(self.tensors):
+            t.set_buffers(compute_buffer=compute_buffers[i], swap_buffer=swap_buffers[i])

     def get_swap_gradient_buffers(self, swap_buffer):
         assert self.numel() <= swap_buffer.numel()
@@ -91,7 +128,7 @@ def get_swap_gradient_paths(self):
         return [grad.path for grad in self.swapped_gradients.values()]

     def get_unpinned_state_tensors(self):
-        return [t for t in self.tensors if not get_accelerator().is_pinned(t)]
+        return [t.compute_tensor for t in self.tensors if not get_accelerator().is_pinned(t.compute_tensor)]

     def read_unswapped_gradients(self, dest_buffer):
         num_elem_count = 0
@@ -102,6 +139,15 @@ def read_unswapped_gradients(self, dest_buffer):

         return num_elem_count

+    def write_unswapped_gradients(self, src_buffer):
+        num_elem_count = 0
+        for offset, grad_partition in self.unswapped_gradients.items():
+            src_tensor = src_buffer.narrow(0, offset, grad_partition.numel())
+            grad_partition.data.copy_(src_tensor.data)
+            num_elem_count += grad_partition.numel()
+
+        return num_elem_count
+
     def release_unswapped_gradients(self):
         self.unswapped_gradients = {}

@@ -158,10 +204,10 @@ def purge_state(self):
         swap_info.tensors = [swap_info.tensors[0]]
         swap_info.has_state_tensors = False

-    def swappable_tensor(self, param=None, numel=None):
-        assert param is not None or numel is not None, "Either param or numel must be provided"
-        if param is not None:
-            return self.min_aio_bytes <= (param.numel() * self.swap_element_size)
+    def is_swappable_tensor(self, tensor=None, numel=None):
+        assert tensor is not None or numel is not None, "Either tensor or numel must be provided"
+        if tensor is not None:
+            return self.min_aio_bytes <= (tensor.numel() * self.swap_element_size)
         return self.min_aio_bytes <= (numel * self.swap_element_size)

     def init_timers(self):
@@ -201,7 +247,7 @@ def _swap_out_gradients(self, parameter, gradient_offsets, gradient_tensors, gra

         self._start_timer(SWAP_OUT_GRADIENT_TIMER)
         for tensor, offset in zip(aligned_gradients, aligned_offsets):
-            if not self.swappable_tensor(param=tensor):
+            if not self.is_swappable_tensor(tensor=tensor):
                 swap_info.unswapped_gradients[offset] = tensor
                 continue

@@ -355,7 +401,7 @@ def _get_swap_paths(self, parameters, num_elems):
         ]
         assert len(swap_info_list) == len(num_elems)

-        swap_paths = [info.swap_paths[0] for info in swap_info_list]
+        swap_paths = [info.tensors[0].swap_path for info in swap_info_list]
         return swap_paths

     def _swap_out_unpinned_tensors(self, aio_handle, unpinned_tensors, dest_paths, pinned_buffers):
@@ -386,7 +432,7 @@ def _adjust_for_misaligned_lengths(self, tensors, offsets):
         new_offsets = []

         for orig_tensor, orig_offset in zip(tensors, offsets):
-            if not self.swappable_tensor(param=orig_tensor):
+            if not self.is_swappable_tensor(tensor=orig_tensor):
                 new_tensors.append(orig_tensor)
                 new_offsets.append(orig_offset)
                 continue
@@ -430,7 +476,7 @@ def _get_state_tensors(self, parameter):

         tensor_list = []
         for state_name, value in self.optimizer.state[parameter].items():
-            if torch.is_tensor(value):
+            if torch.is_tensor(value) and self.is_swappable_tensor(tensor=value):
                 value.ds_id = state_name + '-' + parameter.ds_id
                 tensor_list.append(value)
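To make the `SwapTensorContext` change concrete, here is a standalone illustration (not DeepSpeed code, with made-up sizes) of the buffer-aliasing idea behind `set_buffers()`: the compute view spans exactly `numel` elements while the swap view spans the IO-aligned length of the same underlying buffer, so NVMe reads and writes stay block-aligned without an extra copy.

```python
# Standalone illustration (not DeepSpeed code) of the compute-view / swap-view
# split used by SwapTensorContext. Sizes below are made up for the example.
import torch

numel = 1000          # logical number of elements in the optimizer state partition
aligned_numel = 1024  # hypothetical IO-aligned length (rounded up to the AIO block size)

# In DeepSpeed this buffer would come from the pinned swap buffer pool.
buffer = torch.empty(aligned_numel, dtype=torch.float32)

compute_view = buffer.narrow(0, 0, numel)        # what the optimizer reads and updates
swap_view = buffer.narrow(0, 0, aligned_numel)   # what is handed to the async I/O engine

# Both views alias the same storage, so no copy is needed between compute and I/O.
assert compute_view.data_ptr() == swap_view.data_ptr()
```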

deepspeed/runtime/swap_tensor/partitioned_optimizer_swapper.py

Lines changed: 61 additions & 53 deletions
@@ -6,8 +6,6 @@
 Functionality of swapping optimizer tensors to/from (NVMe) storage devices.
 """

-import torch
-
 from deepspeed.utils.logging import logger
 from deepspeed.ops.op_builder import AsyncIOBuilder
 from deepspeed import comm as dist
@@ -63,71 +61,98 @@ def initialize_from_swapped_fp16_params(self, fp16_partitions_info, fp16_num_ele
     def flush_gradients(self):
         self._flush_gradient_swapper(self.gradient_swapper)

+    def release_swap_buffers(self, parameter):
+        swap_info = self._get_param_swap_info(parameter)
+        if swap_info is None:
+            return
+        swap_info.release_memory()
+
+        self.swap_buffer_manager.free(swap_info.swap_buffers)
+        swap_info.swap_buffers = []
+
     def swap_in_optimizer_state(self, parameter, async_parameter=None):
         swap_info = self._get_param_swap_info(parameter)
         if swap_info is None:
             return

         self._flush_gradient_swapper(self.gradient_swapper)

-        required_buffer_count = len(swap_info.tensors) + (1 if swap_info.has_gradients() else 0)
+        required_buffer_count = swap_info.num_tensors() + (1 if swap_info.has_gradients() else 0)
         aligned_numel = self._io_aligned_numel(swap_info.numel())
         pinned_buffers = self.swap_buffer_manager.allocate(num_elems=aligned_numel,
                                                            count=required_buffer_count,
                                                            dtype=parameter.dtype)
         assert pinned_buffers is not None
-        self.allocated_swap_buffers = pinned_buffers.copy()
+        swap_info.swap_buffers = pinned_buffers.copy()

         self._start_timer(SWAP_IN_PARAM_TIMER)
         self._swap_in_parameter(aio_handle=self.aio_handle,
                                 parameter=parameter,
-                                dest_buffers=pinned_buffers[:required_buffer_count])
+                                dest_buffers=pinned_buffers[:swap_info.num_tensors()])
         self._stop_timer(SWAP_IN_PARAM_TIMER)
         self.timer_names.add(SWAP_IN_PARAM_TIMER)

-        self._start_timer(SWAP_IN_GRADIENT_TIMER)
-        self._swap_in_gradients(aio_handle=self.aio_handle, parameter=parameter, dest_buffer=pinned_buffers[-1])
-        self._stop_timer(SWAP_IN_GRADIENT_TIMER)
-        self.timer_names.add(SWAP_IN_GRADIENT_TIMER)
-
-    def swap_out_optimizer_state(self, parameter, async_swap=False):
-        swap_info = self._get_param_swap_info(parameter=parameter)
-
-        if swap_info is None:
-            return
-
-        self._start_timer(SWAP_OUT_PARAM_TIMER)
-        pinned_tensors, pinned_paths, unpinned_tensors, unpinned_paths = self._separate_pinned_tensors(swap_info)
-        swap_bytes = sum([self._io_aligned_numel(t.numel()) * t.element_size() for t in swap_info.tensors])
+        if swap_info.has_gradients():
+            self._start_timer(SWAP_IN_GRADIENT_TIMER)
+            self._swap_in_gradients(aio_handle=self.aio_handle, parameter=parameter, dest_buffer=pinned_buffers[-1])
+            self._stop_timer(SWAP_IN_GRADIENT_TIMER)
+            self.timer_names.add(SWAP_IN_GRADIENT_TIMER)

+    def _swap_out_optimizer_state(self, swap_info):
+        pinned_tensors, pinned_paths = swap_info.get_swap_buffers_and_paths(True)
         WRITE_TIMER = 'swap_submit_write'
         self._start_timer(WRITE_TIMER)

         swap_out_tensors(self.aio_handle, pinned_tensors, pinned_paths)
         assert self.aio_handle.wait() == len(pinned_tensors)
-        for t in pinned_tensors:
-            t.data = torch.Tensor()

+        unpinned_tensors, unpinned_paths = swap_info.get_swap_buffers_and_paths(False)
         if len(unpinned_tensors) > 0:
             pinned_buffers = self.swap_buffer_manager.allocate_all(num_elems=self.largest_numel, dtype=self.dtype)
             self._swap_out_unpinned_tensors(aio_handle=self.aio_handle,
                                             unpinned_tensors=unpinned_tensors,
                                             dest_paths=unpinned_paths,
                                             pinned_buffers=pinned_buffers)
-            self.allocated_swap_buffers += pinned_buffers
+            swap_info.swap_buffers += pinned_buffers.copy()

-            for t in unpinned_tensors:
-                t.data = torch.Tensor()
         self._stop_timer(WRITE_TIMER)
+        self._log_timers([WRITE_TIMER])
+
+    def writeback_optimizer_state_and_gradients(self, parameter, write_opt_state, write_gradients):
+        swap_info = self._get_param_swap_info(parameter=parameter)
+
+        if swap_info is None:
+            return

-        self.swap_buffer_manager.free(self.allocated_swap_buffers)
-        self.allocated_swap_buffers = []
+        if write_opt_state:
+            self._swap_out_optimizer_state(swap_info)

+        if write_gradients and swap_info.has_gradients():
+            param_gradients = swap_info.swapped_gradients.values()
+            swap_buffers = [parameter.grad.narrow(0, grad.offset, grad.length) for grad in param_gradients]
+            swap_paths = [grad.path for grad in param_gradients]
+            swap_out_tensors(self.aio_handle, swap_buffers, swap_paths)
+            assert len(swap_buffers) == self.aio_handle.wait()
+            if swap_info.unswapped_gradients:
+                swap_info.write_unswapped_gradients(src_buffer=parameter.grad)
+
+        self.release_swap_buffers(parameter)
+
+    def swap_out_optimizer_state(self, parameter, async_swap=False):
+        swap_info = self._get_param_swap_info(parameter=parameter)
+
+        if swap_info is None:
+            return
+
+        swap_bytes = sum(
+            [self._io_aligned_numel(t.numel()) * t.element_size() for t in swap_info.get_compute_tensors()])
+
+        self._start_timer(SWAP_OUT_PARAM_TIMER)
+        self._swap_out_optimizer_state(swap_info)
+        self.release_swap_buffers(parameter)
         self._stop_timer(SWAP_OUT_PARAM_TIMER)
         self.timer_names.add(SWAP_OUT_PARAM_TIMER)

-        self._log_timers([WRITE_TIMER])
-
         if DEBUG_MODE and dist.get_rank() == 0:
             logger.info(f'optimizer_param_swap_out: {(swap_bytes/(1024**3)):5.2f} GB')

@@ -142,16 +167,20 @@ def _swap_in_parameter(self, aio_handle, parameter, dest_buffers):
         if swap_info is None:
             return

-        assert len(swap_info.tensors) <= len(dest_buffers)
+        num_swap_tensors = swap_info.num_tensors()
+        assert num_swap_tensors <= len(dest_buffers)

-        swap_lengths = [self._io_aligned_numel(swap_info.numel())] * len(swap_info.tensors)
+        swap_lengths = [self._io_aligned_numel(swap_info.numel())] * num_swap_tensors
         swap_buffers = get_sized_buffers(dest_buffers, swap_lengths)

+        compute_lengths = [swap_info.numel()] * num_swap_tensors
+        compute_buffers = get_sized_buffers(dest_buffers, compute_lengths)
+
         READ_TIMER = 'swap_submit_read_param'
         WAIT_TIMER = 'swap_wait_read_param'

         self._start_timer(READ_TIMER)
-        swap_in_tensors(aio_handle, swap_buffers, swap_info.swap_paths)
+        swap_in_tensors(aio_handle, swap_buffers, swap_info.get_swap_paths())
         self._stop_timer(READ_TIMER)

         swap_bytes = sum([buffer.numel() * buffer.element_size() for buffer in swap_buffers])
@@ -160,40 +189,19 @@ def _swap_in_parameter(self, aio_handle, parameter, dest_buffers):
         aio_handle.wait()
         self._stop_timer(WAIT_TIMER)

-        compute_lengths = [swap_info.numel()] * len(swap_info.tensors)
-        compute_buffers = get_sized_buffers(dest_buffers, compute_lengths)
-        for t, buffer in zip(swap_info.tensors, compute_buffers):
-            t.data = buffer.data
+        swap_info.set_swap_buffers(dest_buffers, self._io_aligned_numel(swap_info.numel()))

         self._log_timers([READ_TIMER, WAIT_TIMER])
         if DEBUG_MODE and dist.get_rank() == 0:
             logger.info(f'optimizer_param_swap_in: {(swap_bytes/(1024**3)):5.2f} GB')

-    def _separate_pinned_tensors(self, swap_info):
-        pinned_tensors = []
-        pinned_paths = []
-
-        unpinned_tensors = []
-        unpinned_paths = []
-
-        for tensor, path in zip(swap_info.tensors, swap_info.swap_paths):
-            if get_accelerator().is_pinned(tensor):
-                pinned_tensors.append(tensor)
-                pinned_paths.append(path)
-            else:
-                unpinned_tensors.append(tensor)
-                unpinned_paths.append(path)
-
-        return pinned_tensors, pinned_paths, unpinned_tensors, unpinned_paths
-
     def _swap_in_pinned_gradients(self, aio_handle, parameter, gradient_tensor):
         swap_info = self.swap_params_info[OptimizerSwapper.parameter_id(parameter)]
         param_gradients = swap_info.swapped_gradients.values()
         swap_buffers = [gradient_tensor.narrow(0, grad.offset, grad.length) for grad in param_gradients]
         swap_paths = [grad.path for grad in param_gradients]
         SWAP_READ_GRADIENTS = 'swap_submit_read_gradient'
         SWAP_WAIT_GRADIENTS = 'swap_submit_wait_gradient'
-
         self._start_timer(SWAP_READ_GRADIENTS)
         swap_in_tensors(aio_handle, swap_buffers, swap_paths)
         self._stop_timer(SWAP_READ_GRADIENTS)
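A hypothetical caller of the new `writeback_optimizer_state_and_gradients` path (sketch only, not code from this diff) would round-trip NVMe-offloaded state as follows: swap it in, mutate the compute tensors in place, then write the result back, which also releases the pinned swap buffers. Here `swapper` stands for a `PartitionedOptimizerSwapper` instance and `param` for a partitioned parameter whose optimizer state lives on NVMe.

```python
# Hypothetical caller sketch (not part of this diff). `swapper` and `param`
# are stand-ins: a PartitionedOptimizerSwapper instance and a partitioned
# parameter whose optimizer state is offloaded to NVMe.
import torch

swapper.swap_in_optimizer_state(param)

# Edit the swapped-in optimizer state tensors in place (e.g. zero the moments).
for state_value in swapper.optimizer.state[param].values():
    if torch.is_tensor(state_value):
        state_value.zero_()

# Persist the modified state back to NVMe; this also frees the pinned swap buffers.
swapper.writeback_optimizer_state_and_gradients(param, write_opt_state=True, write_gradients=False)
```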
