@@ -102,7 +102,7 @@ def get_swap_buffers_and_paths(self, pinned):
     def get_or_create_gradient_paths(self, offsets, lengths):
         gradient_paths = []
         for offset, length in zip(offsets, lengths):
-            if not offset in self.swapped_gradients.keys():
+            if offset not in self.swapped_gradients.keys():
                 path = os.path.join(self.swap_folder, f'{self.param_id}_gradient_{offset}_{length}.tensor.swp')
                 self.swapped_gradients[offset] = FlattenedTensorSwapInfo(path, length, offset)
 
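Note: every hunk in this commit applies the same idiomatic rewrite. Because `not` binds more loosely than `in`, `not x in y` already parses as `not (x in y)`; `x not in y` is the equivalent spelling recommended by PEP 8. A minimal, self-contained sketch (the dictionary contents here are hypothetical, standing in for swapped_gradients keyed by offset) demonstrating that the two forms are interchangeable:

# Hypothetical data; mirrors the swapped_gradients dict keyed by offset.
swapped_gradients = {0: 'grad_0.swp', 512: 'grad_512.swp'}

for offset in (0, 256):
    legacy = not offset in swapped_gradients       # discouraged spelling
    idiomatic = offset not in swapped_gradients    # PEP 8 spelling
    assert legacy == idiomatic                     # identical semantics
    if idiomatic:
        print(f'offset {offset} has no swap file yet')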
@@ -233,7 +233,7 @@ def _flush_gradient_swapper(self, gradient_swapper):
         self.timer_names.update(gradient_swapper.get_timer_names())
 
     def _swap_out_gradients(self, parameter, gradient_offsets, gradient_tensors, gradient_swapper):
-        if not OptimizerSwapper.parameter_id(parameter) in self.swap_params_info.keys():
+        if OptimizerSwapper.parameter_id(parameter) not in self.swap_params_info.keys():
             return
 
         swap_info = self.swap_params_info[OptimizerSwapper.parameter_id(parameter)]
@@ -471,7 +471,7 @@ def _retrieve_unswapped_grad_partitions(self, swap_info, dest_buffer):
         )
 
     def _get_state_tensors(self, parameter):
-        if not parameter in self.optimizer.state:
+        if parameter not in self.optimizer.state:
             return []
 
         tensor_list = []
@@ -490,7 +490,7 @@ def _update_param_state_info(self, swap_info, parameter):
 
     def _create_param_swap_info(self, parameter, numel):
         param_id = OptimizerSwapper.parameter_id(parameter)
-        assert not param_id in self.swap_params_info
+        assert param_id not in self.swap_params_info
 
         self.swap_params_info[param_id] = OptimizerStateSwapInfo(parameter=parameter,
                                                                  numel=numel,
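The later hunks use the operator in guard positions: an early return when a parameter was never registered for swapping, and an assertion that a parameter is registered at most once. A condensed sketch of that guard pattern, with a hypothetical SwapRegistry standing in for the real swapper classes:

class SwapRegistry:
    """Hypothetical stand-in for the swapper's bookkeeping."""

    def __init__(self):
        self.swap_params_info = {}

    def register(self, param_id):
        # Mirrors _create_param_swap_info: a parameter may only be
        # registered once.
        assert param_id not in self.swap_params_info
        self.swap_params_info[param_id] = []

    def swap_out(self, param_id, tensor):
        # Mirrors _swap_out_gradients: silently skip parameters that
        # were never registered for swapping.
        if param_id not in self.swap_params_info:
            return
        self.swap_params_info[param_id].append(tensor)


registry = SwapRegistry()
registry.register(7)
registry.swap_out(7, 'gradient-tensor')
registry.swap_out(8, 'dropped: id 8 was never registered')
print(registry.swap_params_info)   # {7: ['gradient-tensor']}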