
Commit 9f4fead

shs037 authored and tensorflower-gardener committed
Add more documentation for gradient_accumulation_steps in keras optimizer.
PiperOrigin-RevId: 469310667
1 parent 9e25eee commit 9f4fead

File tree

1 file changed: +17 −5 lines changed

tensorflow_privacy/privacy/optimizers/dp_optimizer_keras.py

Lines changed: 17 additions & 5 deletions
@@ -105,11 +105,23 @@ class DPOptimizerClass(cls):  # pylint: disable=empty-docstring
     opt.minimize(loss, var_list=[var])
     ```
 
-    Note that when using this feature effective batch size is
-    `gradient_accumulation_steps * one_step_batch_size` where
-    `one_step_batch_size` size of the batch which is passed to single step
-    of the optimizer. Thus user may have to adjust learning rate, weight decay
-    and possibly other training hyperparameters accordingly.
+    Note that when using this feature,
+    1. effective batch size is `gradient_accumulation_steps * one_step_batch_size`
+    where `one_step_batch_size` is the size of the batch passed to a single step
+    of the optimizer. Thus the user may have to adjust the learning rate, weight
+    decay and possibly other training hyperparameters accordingly.
+    2. effective noise (the noise to be used for privacy computation) is
+    `noise_multiplier * sqrt(gradient_accumulation_steps)`, as the optimizer
+    adds noise of `self._noise_multiplier` to every step. Thus the user may have
+    to adjust the `noise_multiplier` or the privacy computation.
+    Additionally, the user may need to adjust the batch size in the data
+    generator, or the number of calls to the data generator, depending on the
+    training framework used. For example, when using Keras model.fit(...) with a
+    user-defined data generator, one may need to make the data generator return
+    `one_step_batch_size` examples each time and scale the `steps_per_epoch`
+    by `gradient_accumulation_steps`. This is because the data generator is
+    called `steps_per_epoch` times per epoch, and one call now returns only
+    `one_step_batch_size` (instead of `effective_batch_size`) examples.
     """.format(
         base_class='tf.keras.optimizers.' + cls.__name__,
         short_base_class=cls.__name__,
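
For reference, a minimal sketch of the arithmetic the new docstring text describes. The helper name `effective_training_params` is hypothetical, written for this note rather than taken from tensorflow_privacy; it only spells out the formulas in points 1 and 2 of the diff above.

```python
import math

# Hypothetical helper (not part of tensorflow_privacy): spells out the
# effective batch size and effective noise formulas from the docstring above.
def effective_training_params(one_step_batch_size, noise_multiplier,
                              gradient_accumulation_steps):
  # An update is applied once per `gradient_accumulation_steps` optimizer
  # calls, so the batch size relevant for tuning the learning rate, weight
  # decay, etc. is the accumulated one.
  effective_batch_size = gradient_accumulation_steps * one_step_batch_size
  # Independent Gaussian noise is added at every step, so the accumulated
  # update carries noise scaled by sqrt(gradient_accumulation_steps).
  effective_noise_multiplier = (
      noise_multiplier * math.sqrt(gradient_accumulation_steps))
  return effective_batch_size, effective_noise_multiplier

# Example: 4 accumulation steps of 64 examples each behave like one step on
# a batch of 256, with effective noise multiplier 1.0 * sqrt(4) = 2.0 for
# the privacy computation.
print(effective_training_params(64, 1.0, 4))  # (256, 2.0)
```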
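Likewise, a sketch of the `model.fit(...)` adjustment from the closing paragraph of the diff. The toy model and random data are assumptions made only to keep the snippet self-contained; a DP optimizer constructed with `gradient_accumulation_steps` would take the place of the plain `'sgd'` used here.

```python
import numpy as np
import tensorflow as tf

num_examples = 1024
one_step_batch_size = 64
gradient_accumulation_steps = 4
effective_batch_size = one_step_batch_size * gradient_accumulation_steps  # 256

# Toy data, just to make the sketch runnable.
x = np.random.normal(size=(num_examples, 10)).astype(np.float32)
y = np.random.randint(0, 2, size=(num_examples, 1)).astype(np.float32)

def batch_generator():
  # Each call yields `one_step_batch_size` examples, not the effective batch.
  while True:
    idx = np.random.randint(0, num_examples, size=one_step_batch_size)
    yield x[idx], y[idx]

model = tf.keras.Sequential([tf.keras.layers.Dense(1, activation='sigmoid')])
# A DP optimizer created with gradient_accumulation_steps would go here;
# plain SGD keeps the sketch self-contained.
model.compile(optimizer='sgd', loss='binary_crossentropy')

# One effective step consumes `gradient_accumulation_steps` generator calls,
# so steps_per_epoch is scaled up by the same factor to cover a full epoch.
steps_per_epoch = (num_examples // effective_batch_size
                   ) * gradient_accumulation_steps  # 4 * 4 = 16
model.fit(batch_generator(), steps_per_epoch=steps_per_epoch, epochs=1)
```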

Comments (0)