From 00e771ec0b8a3f2c360960ed47564def4d6b6b1a Mon Sep 17 00:00:00 2001 From: Michal Szutenberg Date: Mon, 11 Apr 2022 13:26:06 +0200 Subject: [PATCH] Add T=bfloat16 to custom_ops registration --- .../custom_ops/image/cc/ops/distort_image_ops.cc | 5 +++-- tensorflow_addons/custom_ops/image/cc/ops/image_ops.cc | 6 +++--- tensorflow_addons/custom_ops/image/cc/ops/resampler_ops.cc | 6 +++--- .../custom_ops/layers/cc/ops/embedding_bag_ops.cc | 4 ++-- 4 files changed, 11 insertions(+), 10 deletions(-) diff --git a/tensorflow_addons/custom_ops/image/cc/ops/distort_image_ops.cc b/tensorflow_addons/custom_ops/image/cc/ops/distort_image_ops.cc index 916d9d9eb1..a09d850c2e 100644 --- a/tensorflow_addons/custom_ops/image/cc/ops/distort_image_ops.cc +++ b/tensorflow_addons/custom_ops/image/cc/ops/distort_image_ops.cc @@ -31,7 +31,8 @@ REGISTER_OP("Addons>AdjustHsvInYiq") .Input("scale_s: float") .Input("scale_v: float") .Output("output: T") - .Attr("T: {uint8, int8, int16, int32, int64, half, float, double}") + .Attr( + "T: {uint8, int8, int16, int32, int64, half, float, double, bfloat16}") .SetShapeFn([](InferenceContext* c) { ShapeHandle images, delta_h, scale_s, scale_v; @@ -70,4 +71,4 @@ output: The hsv-adjusted image or images. No clipping will be done in this op. )Doc"); } // end namespace addons -} // namespace tensorflow \ No newline at end of file +} // namespace tensorflow diff --git a/tensorflow_addons/custom_ops/image/cc/ops/image_ops.cc b/tensorflow_addons/custom_ops/image/cc/ops/image_ops.cc index 52c2a0c62f..3de289656b 100644 --- a/tensorflow_addons/custom_ops/image/cc/ops/image_ops.cc +++ b/tensorflow_addons/custom_ops/image/cc/ops/image_ops.cc @@ -55,7 +55,7 @@ components: Component ids for each pixel in "image". Same shape as "image". Zero REGISTER_OP("Addons>EuclideanDistanceTransform") .Input("images: uint8") - .Attr("dtype: {float16, float32, float64}") + .Attr("dtype: {bfloat16, float16, float32, float64}") .Output("transformed_images: dtype") .SetShapeFn(shape_inference::UnchangedShape) .Doc(EuclideanDistanceTransformDoc); @@ -65,9 +65,9 @@ REGISTER_OP("Addons>ImageConnectedComponents") .Output("components: int64") .Attr( "dtype: {int64, int32, uint16, int16, uint8, int8, half, float, " - "double, bool, string}") + "bfloat16, double, bool, string}") .SetShapeFn(shape_inference::UnchangedShape) .Doc(ImageConnectedComponentsDoc); } // end namespace addons -} // namespace tensorflow \ No newline at end of file +} // namespace tensorflow diff --git a/tensorflow_addons/custom_ops/image/cc/ops/resampler_ops.cc b/tensorflow_addons/custom_ops/image/cc/ops/resampler_ops.cc index ac22a54b25..51673bea58 100644 --- a/tensorflow_addons/custom_ops/image/cc/ops/resampler_ops.cc +++ b/tensorflow_addons/custom_ops/image/cc/ops/resampler_ops.cc @@ -29,7 +29,7 @@ REGISTER_OP("Addons>Resampler") .Input("data: T") .Input("warp: T") .Output("output: T") - .Attr("T: {half, float, double}") + .Attr("T: {bfloat16, half, float, double}") .SetShapeFn([](InferenceContext* c) { ShapeHandle data; ShapeHandle warp; @@ -53,7 +53,7 @@ REGISTER_OP("Addons>ResamplerGrad") .Input("grad_output: T") .Output("grad_data: T") .Output("grad_warp: T") - .Attr("T: {half, float, double}") + .Attr("T: {bfloat16, half, float, double}") .SetShapeFn([](InferenceContext* c) { c->set_output(0, c->input(0)); c->set_output(1, c->input(1)); @@ -62,4 +62,4 @@ REGISTER_OP("Addons>ResamplerGrad") .Doc(R"doc(Resampler Grad op.)doc"); } // namespace addons -} // namespace tensorflow \ No newline at end of file +} // namespace tensorflow diff --git a/tensorflow_addons/custom_ops/layers/cc/ops/embedding_bag_ops.cc b/tensorflow_addons/custom_ops/layers/cc/ops/embedding_bag_ops.cc index c63a355d51..709692f108 100644 --- a/tensorflow_addons/custom_ops/layers/cc/ops/embedding_bag_ops.cc +++ b/tensorflow_addons/custom_ops/layers/cc/ops/embedding_bag_ops.cc @@ -28,7 +28,7 @@ REGISTER_OP("Addons>EmbeddingBag") .Input("params: T") .Input("weights: T") .Output("output: T") - .Attr("T: {half, float, double}") + .Attr("T: {bfloat16, half, float, double}") .Attr("Tindices: {int32, int64}") .Attr("combiner: {'SUM', 'MEAN'} = 'SUM'") .SetShapeFn([](InferenceContext* c) { @@ -51,7 +51,7 @@ REGISTER_OP("Addons>EmbeddingBagGrad") .Input("grads: T") .Output("params_grads: T") .Output("weights_grads: T") - .Attr("T: {half, float, double}") + .Attr("T: {bfloat16, half, float, double}") .Attr("Tindices: {int32, int64}") .Attr("combiner: {'SUM', 'MEAN'} = 'SUM'") .SetShapeFn([](InferenceContext* c) {