 Functionality of swapping optimizer tensors to/from (NVMe) storage devices.
 """

-import torch
-
 from deepspeed.utils.logging import logger
 from deepspeed.ops.op_builder import AsyncIOBuilder
 from deepspeed import comm as dist
@@ -63,71 +61,98 @@ def initialize_from_swapped_fp16_params(self, fp16_partitions_info, fp16_num_ele
     def flush_gradients(self):
         self._flush_gradient_swapper(self.gradient_swapper)

+    def release_swap_buffers(self, parameter):
+        swap_info = self._get_param_swap_info(parameter)
+        if swap_info is None:
+            return
+        swap_info.release_memory()
+
+        self.swap_buffer_manager.free(swap_info.swap_buffers)
+        swap_info.swap_buffers = []
+
     def swap_in_optimizer_state(self, parameter, async_parameter=None):
         swap_info = self._get_param_swap_info(parameter)
         if swap_info is None:
             return

         self._flush_gradient_swapper(self.gradient_swapper)

-        required_buffer_count = len(swap_info.tensors) + (1 if swap_info.has_gradients() else 0)
+        required_buffer_count = swap_info.num_tensors() + (1 if swap_info.has_gradients() else 0)
         aligned_numel = self._io_aligned_numel(swap_info.numel())
         pinned_buffers = self.swap_buffer_manager.allocate(num_elems=aligned_numel,
                                                            count=required_buffer_count,
                                                            dtype=parameter.dtype)
         assert pinned_buffers is not None
-        self.allocated_swap_buffers = pinned_buffers.copy()
+        swap_info.swap_buffers = pinned_buffers.copy()

         self._start_timer(SWAP_IN_PARAM_TIMER)
         self._swap_in_parameter(aio_handle=self.aio_handle,
                                 parameter=parameter,
-                                dest_buffers=pinned_buffers[:required_buffer_count])
+                                dest_buffers=pinned_buffers[:swap_info.num_tensors()])
         self._stop_timer(SWAP_IN_PARAM_TIMER)
         self.timer_names.add(SWAP_IN_PARAM_TIMER)

-        self._start_timer(SWAP_IN_GRADIENT_TIMER)
-        self._swap_in_gradients(aio_handle=self.aio_handle, parameter=parameter, dest_buffer=pinned_buffers[-1])
-        self._stop_timer(SWAP_IN_GRADIENT_TIMER)
-        self.timer_names.add(SWAP_IN_GRADIENT_TIMER)
-
-    def swap_out_optimizer_state(self, parameter, async_swap=False):
-        swap_info = self._get_param_swap_info(parameter=parameter)
-
-        if swap_info is None:
-            return
-
-        self._start_timer(SWAP_OUT_PARAM_TIMER)
-        pinned_tensors, pinned_paths, unpinned_tensors, unpinned_paths = self._separate_pinned_tensors(swap_info)
-        swap_bytes = sum([self._io_aligned_numel(t.numel()) * t.element_size() for t in swap_info.tensors])
+        if swap_info.has_gradients():
+            self._start_timer(SWAP_IN_GRADIENT_TIMER)
+            self._swap_in_gradients(aio_handle=self.aio_handle, parameter=parameter, dest_buffer=pinned_buffers[-1])
+            self._stop_timer(SWAP_IN_GRADIENT_TIMER)
+            self.timer_names.add(SWAP_IN_GRADIENT_TIMER)

+    def _swap_out_optimizer_state(self, swap_info):
+        pinned_tensors, pinned_paths = swap_info.get_swap_buffers_and_paths(True)
         WRITE_TIMER = 'swap_submit_write'
         self._start_timer(WRITE_TIMER)

         swap_out_tensors(self.aio_handle, pinned_tensors, pinned_paths)
         assert self.aio_handle.wait() == len(pinned_tensors)
-        for t in pinned_tensors:
-            t.data = torch.Tensor()

+        unpinned_tensors, unpinned_paths = swap_info.get_swap_buffers_and_paths(False)
         if len(unpinned_tensors) > 0:
             pinned_buffers = self.swap_buffer_manager.allocate_all(num_elems=self.largest_numel, dtype=self.dtype)
             self._swap_out_unpinned_tensors(aio_handle=self.aio_handle,
                                             unpinned_tensors=unpinned_tensors,
                                             dest_paths=unpinned_paths,
                                             pinned_buffers=pinned_buffers)
-            self.allocated_swap_buffers += pinned_buffers
+            swap_info.swap_buffers += pinned_buffers.copy()

-            for t in unpinned_tensors:
-                t.data = torch.Tensor()
         self._stop_timer(WRITE_TIMER)
+        self._log_timers([WRITE_TIMER])
+
+    def writeback_optimizer_state_and_gradients(self, parameter, write_opt_state, write_gradients):
+        swap_info = self._get_param_swap_info(parameter=parameter)
+
+        if swap_info is None:
+            return

-        self.swap_buffer_manager.free(self.allocated_swap_buffers)
-        self.allocated_swap_buffers = []
+        if write_opt_state:
+            self._swap_out_optimizer_state(swap_info)

+        if write_gradients and swap_info.has_gradients():
+            param_gradients = swap_info.swapped_gradients.values()
+            swap_buffers = [parameter.grad.narrow(0, grad.offset, grad.length) for grad in param_gradients]
+            swap_paths = [grad.path for grad in param_gradients]
+            swap_out_tensors(self.aio_handle, swap_buffers, swap_paths)
+            assert len(swap_buffers) == self.aio_handle.wait()
+            if swap_info.unswapped_gradients:
+                swap_info.write_unswapped_gradients(src_buffer=parameter.grad)
+
+        self.release_swap_buffers(parameter)
+
+    def swap_out_optimizer_state(self, parameter, async_swap=False):
+        swap_info = self._get_param_swap_info(parameter=parameter)
+
+        if swap_info is None:
+            return
+
+        swap_bytes = sum(
+            [self._io_aligned_numel(t.numel()) * t.element_size() for t in swap_info.get_compute_tensors()])
+
+        self._start_timer(SWAP_OUT_PARAM_TIMER)
+        self._swap_out_optimizer_state(swap_info)
+        self.release_swap_buffers(parameter)
         self._stop_timer(SWAP_OUT_PARAM_TIMER)
         self.timer_names.add(SWAP_OUT_PARAM_TIMER)

-        self._log_timers([WRITE_TIMER])
-
         if DEBUG_MODE and dist.get_rank() == 0:
             logger.info(f'optimizer_param_swap_out: {(swap_bytes / (1024**3)):5.2f} GB')
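
Note (not part of the diff): a minimal usage sketch of the new split between writing optimizer state back and releasing the swap buffers, assuming a caller that already holds the swapper and the parameter; step_with_nvme_state and run_optimizer_step are illustrative placeholders, only the swapper methods shown above are real.

# Hypothetical sketch; only the swapper methods from this diff are assumed to exist.
def step_with_nvme_state(swapper, parameter, persist=True):
    swapper.swap_in_optimizer_state(parameter)  # read fp32 state (and gradients, if any were swapped)
    run_optimizer_step(parameter)               # placeholder for the caller's update of the swapped-in state
    if persist:
        # Write the updated state and gradients back to NVMe; this also releases the buffers.
        swapper.writeback_optimizer_state_and_gradients(parameter, write_opt_state=True, write_gradients=True)
    else:
        # Drop the swapped-in copies without writing them back.
        swapper.release_swap_buffers(parameter)
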
@@ -142,16 +167,20 @@ def _swap_in_parameter(self, aio_handle, parameter, dest_buffers):
         if swap_info is None:
             return

-        assert len(swap_info.tensors) <= len(dest_buffers)
+        num_swap_tensors = swap_info.num_tensors()
+        assert num_swap_tensors <= len(dest_buffers)

-        swap_lengths = [self._io_aligned_numel(swap_info.numel())] * len(swap_info.tensors)
+        swap_lengths = [self._io_aligned_numel(swap_info.numel())] * num_swap_tensors
         swap_buffers = get_sized_buffers(dest_buffers, swap_lengths)

+        compute_lengths = [swap_info.numel()] * num_swap_tensors
+        compute_buffers = get_sized_buffers(dest_buffers, compute_lengths)
+
         READ_TIMER = 'swap_submit_read_param'
         WAIT_TIMER = 'swap_wait_read_param'

         self._start_timer(READ_TIMER)
-        swap_in_tensors(aio_handle, swap_buffers, swap_info.swap_paths)
+        swap_in_tensors(aio_handle, swap_buffers, swap_info.get_swap_paths())
         self._stop_timer(READ_TIMER)

         swap_bytes = sum([buffer.numel() * buffer.element_size() for buffer in swap_buffers])
@@ -160,40 +189,19 @@ def _swap_in_parameter(self, aio_handle, parameter, dest_buffers):
         aio_handle.wait()
         self._stop_timer(WAIT_TIMER)

-        compute_lengths = [swap_info.numel()] * len(swap_info.tensors)
-        compute_buffers = get_sized_buffers(dest_buffers, compute_lengths)
-        for t, buffer in zip(swap_info.tensors, compute_buffers):
-            t.data = buffer.data
+        swap_info.set_swap_buffers(dest_buffers, self._io_aligned_numel(swap_info.numel()))

         self._log_timers([READ_TIMER, WAIT_TIMER])
         if DEBUG_MODE and dist.get_rank() == 0:
             logger.info(f'optimizer_param_swap_in: {(swap_bytes / (1024**3)):5.2f} GB')

-    def _separate_pinned_tensors(self, swap_info):
-        pinned_tensors = []
-        pinned_paths = []
-
-        unpinned_tensors = []
-        unpinned_paths = []
-
-        for tensor, path in zip(swap_info.tensors, swap_info.swap_paths):
-            if get_accelerator().is_pinned(tensor):
-                pinned_tensors.append(tensor)
-                pinned_paths.append(path)
-            else:
-                unpinned_tensors.append(tensor)
-                unpinned_paths.append(path)
-
-        return pinned_tensors, pinned_paths, unpinned_tensors, unpinned_paths
-
     def _swap_in_pinned_gradients(self, aio_handle, parameter, gradient_tensor):
         swap_info = self.swap_params_info[OptimizerSwapper.parameter_id(parameter)]
         param_gradients = swap_info.swapped_gradients.values()
         swap_buffers = [gradient_tensor.narrow(0, grad.offset, grad.length) for grad in param_gradients]
         swap_paths = [grad.path for grad in param_gradients]
         SWAP_READ_GRADIENTS = 'swap_submit_read_gradient'
         SWAP_WAIT_GRADIENTS = 'swap_submit_wait_gradient'
-
         self._start_timer(SWAP_READ_GRADIENTS)
         swap_in_tensors(aio_handle, swap_buffers, swap_paths)
         self._stop_timer(SWAP_READ_GRADIENTS)
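
Note (not part of the diff): the removed _separate_pinned_tensors helper is assumed to have moved onto the swap-info object as get_swap_buffers_and_paths(pinned). A standalone sketch of that partitioning, under that assumption and grounded in the removed code:

from deepspeed.accelerator import get_accelerator

# Sketch only: functional equivalent of the removed helper, expressed for a
# single pinned/unpinned selection as get_swap_buffers_and_paths(pinned) implies.
def get_swap_buffers_and_paths(tensors, paths, pinned):
    # Keep the (tensor, path) pairs whose pinned status matches the request.
    pairs = [(t, p) for t, p in zip(tensors, paths) if get_accelerator().is_pinned(t) == pinned]
    return [t for t, _ in pairs], [p for _, p in pairs]
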