File tree Expand file tree Collapse file tree 2 files changed +12
-5
lines changed
tensorflow_privacy/privacy/dp_query Expand file tree Collapse file tree 2 files changed +12
-5
lines changed Original file line number Diff line number Diff line change @@ -268,10 +268,11 @@ def derive_metrics(self, global_state):
268
268
269
269
def _zeros_like (arg ):
270
270
"""A `zeros_like` function that also works for `tf.TensorSpec`s."""
271
- try :
272
- arg = tf .convert_to_tensor (value = arg )
273
- except TypeError :
274
- pass
271
+ if not isinstance (arg , tf .TensorSpec ):
272
+ try :
273
+ arg = tf .convert_to_tensor (value = arg )
274
+ except TypeError :
275
+ pass
275
276
return tf .zeros (arg .shape , arg .dtype )
276
277
277
278
Original file line number Diff line number Diff line change 18
18
19
19
import distutils
20
20
import math
21
- from typing import Optional
21
+ from typing import Any , Optional
22
22
23
23
import attr
24
24
import dp_accounting
@@ -136,6 +136,12 @@ def initial_global_state(self):
136
136
arity = self ._arity ,
137
137
inner_query_state = self ._inner_query .initial_global_state ())
138
138
139
+ def initial_sample_state (self , template : Optional [Any ] = None ):
140
+ """Implements `tensorflow_privacy.DPQuery.initial_sample_state`."""
141
+ unprocessed_sample_state = super ().initial_sample_state (template )
142
+ sample_params = self .derive_sample_params (self .initial_global_state ())
143
+ return self .preprocess_record (sample_params , unprocessed_sample_state )
144
+
139
145
def derive_sample_params (self , global_state ):
140
146
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
141
147
return (global_state .arity ,
You can’t perform that action at this time.
0 commit comments