diff --git a/README.md b/README.md
index 4d080e2..0bc919f 100644
--- a/README.md
+++ b/README.md
@@ -386,13 +386,14 @@ Note | Source
6 | `tf.nn.ctc_loss` backprop | NS | NS | NS | TDO |
7 | Fused sofmax/crossentropy:
`tf.nn.*_cross_entropy_with_logits`
backprop | NS | NS | NS | NS |
-Note | Source | TF < 2.4 | NGC 20.03+ | TF 2.4 |
-----:|:----------------------------------------------------------------------------------------------------------------------------------------|:----------|:-----------|:-------|
- 8 | `tf.image.resize` with `method=ResizeMethod.BILINEAR`
and `tf.keras.layers.UpSampling2D` with
`interpolation='bilinear'` backprop | NS | TDO | TDO |
- 9 | `tf.image.resize` with `method=ResizeMethod.NEAREST`
and `tf.keras.layers.UpSampling2D` with
`interpolation='nearest'` backprop | NS | NS | NS |
- 10 | `tf.math.segment_sum` and `tf.math.unsorted_segment_sum`
forward, and `tf.gather` and `tfa.image.dense_image_warp`
backprop | NS | NS | NS |
- 11 | `tf.image.crop_and_resize` backprop to `image` (on CPU
or GPU) and backprop to `boxes` | NS | NS | NS |
- 12 | `tf.sparse.sparse_dense_matmul` forward | NS | NS | NS |
+Note | Source | TF < 2.4 | NGC 20.03+ | TF 2.4 |
+----:|:--------------------------------------------------------------------------------------------------------------------------------------------------------------|:----------|:-----------|:-------|
+ 8 | `tf.image.resize` with `method=ResizeMethod.BILINEAR`
and `tf.keras.layers.UpSampling2D` with
`interpolation='bilinear'` backprop | NS | TDO | TDO |
+ 9 | `tf.image.resize` with `method=ResizeMethod.NEAREST`
and `tf.keras.layers.UpSampling2D` with
`interpolation='nearest'` backprop | NS | NS | NS |
+ 10 | `tf.math.segment_sum`, `tf.math.unsorted_segment_sum`,
and `tf.convert_to_tensor` forward.
And `tf.gather` and `tfa.image.dense_image_warp`
backprop | NS | NS | NS |
+ 11 | `tf.image.crop_and_resize` backprop to `image` (on CPU
or GPU) and backprop to `boxes` | NS | NS | NS |
+ 12 | `tf.sparse.sparse_dense_matmul` forward | NS | NS | NS |
+ 13 | `tf.math.unsorted_segment_mean`,
`tf.math.unsorted_segment_prod`, and
`tf.math.unsorted_segment_sqrt` forward | NS | NS | NS |
##### Key to the Solutions Referenced Above
@@ -479,11 +480,36 @@ Note | Source
issues [#12](https://github.com/NVIDIA/framework-determinism/issues/12) and
[#24](https://github.com/NVIDIA/framework-determinism/issues/24))
10. Segment reduction ops `tf.math.segment_sum` and
- `tf.math.unsorted_segment_sum` have nondeterministic forward operation on
- GPU. Other ops that are dependent on these ops, including `tf.gather` and
- `tfa.image.dense_image_warp` (both in backprop), therefore also operate
- nondeterministically. See
- [Issue 39751](https://github.com/tensorflow/tensorflow/issues/39751).
+ `tf.math.unsorted_segment_sum` can exhibit nondeterministic forward
+ operation when running on a GPU. `tf.convert_to_tensor`, when fed with
+ (sparse) `tf.IndexedSlices`, uses this potentially nondeterminitic
+ segment sum functionality in its forward direction and therefore may
+ introduce truly random noise into its output when a slice index is
+ represented more than twice in its input (such as when reducing the word
+ embedding gradients from multiple instances of the same word in a sentence
+ or across a batch of sentences). `tf.gather` is often used to select word
+ embeddings from an embedding matrix in a model's forward direction and
+ `tf.gather`'s backprop generates sparse gradients conveyed as
+ `tf.IndexedSlices`. The reduction of the back-propagated sparse gradients
+ from `tf.gather` by `tf.convert_to_tensor` can therefore introduce truly
+ random noise into an embedding trainable variable. A lower-performance
+ work-around for this nondeterminism related to the use of `tf.gather` is
+ to use `tf.linalg.matmul` instead:
+
+ ```
+ # inputs_embeds = tf.gather(embeddings, input_ids)
+ input_embeds = tf.dtypes.cast(
+ tf.one_hot(input_ids, embeddings.shape[0]),
+ embeddings.dtype) @ embeddings
+ ```
+
+ Note that the backward (and forward) functionality of `tf.gather` itself
+ _is_ deterministic. The backprop for `tfa.image.dense_image_warp` may
+ introduce truly random noise because it also uses the nondeterministic
+ segment sum functionality. See
+ [Issue 39751](https://github.com/tensorflow/tensorflow/issues/39751). A
+ patch that will make the segment sum ops function deterministically is in
+ development.
11. Backprop to `image` on `tf.image.crop_and_resize` introduces
nondeterministic noise when running on either CPU or GPU. Backprop to
`boxes` introduces nondeterministic noise when running on GPU. See
@@ -493,6 +519,13 @@ Note | Source
12. The forward path of `tf.sparse.sparse_dense_matmul` introduces
nondeterminism for `tf.float32` and (allegedly) for `tf.float64`. See
TF [Issue 18037](https://github.com/tensorflow/tensorflow/issues/18037).
+ 13. Based on initial work from [Lin Lan](https://github.com/llan-ml), we may
+ have have ruled-out nondeterminism in other `tf.math.segment_*` ops beyond
+ `tf.math.segment_sum` and in other `tf.math_unsorted_segment_*` ops beyond
+ `tf.math.unsorted_segment_sum`, `tf.math.unsorted_segment_mean`,
+ `tf.math.unsorted_segment_prod`, and `tf.math_unsorted_segment_sqrt`; see
+ [issue 31](https://github.com/NVIDIA/framework-determinism/issues/31).
+ Also see note 10, above.
#### Other Possible GPU-Specific Sources of Non-Determinism
@@ -558,7 +591,7 @@ This section catalogs relevant links.
### TensorFlow Issues
-GitHiub issues in the TensorFlow project:
+GitHub issues in the TensorFlow project:
Number | Title | Date Opened | Status |
--------------------------------------------------------------:|:-----------------------------------------------------------------------------------------|:------------|:-------|
@@ -590,7 +623,8 @@ GitHub issues in dependent or related projects:
### TensorFlow Pull Requests
The following pull requests (and some inidividual commits) are those in the
-TensorFlow GitHub repo that are directly related to this project. As we have
+TensorFlow GitHub repo (`github.com/tensorflow/tensorflow`) that are directly
+related to this project. As we have
[discovered](scripts/README.md#find-tensorflow-commits), 1.8% of all commits
seem to reference, or have some relationship with, "determinism" or
"deterministic". As of 2020-01-30, that was 1,391 commits.
@@ -618,7 +652,8 @@ ID | Title
[38089](https://github.com/tensorflow/tensorflow/pull/38089) | Add reminder to test deterministic cuDNN CTC loss | closed | | |
[38509](https://github.com/tensorflow/tensorflow/pull/38509) | List deterministic op func bug fixes in v2.2
release notes | merged | 2020-04-15 | 2.2 |
[39243](https://github.com/tensorflow/tensorflow/pull/39243) | GPU-deterministic tf.image.resize (bilinear) | merged | 2020-09-22 | 2.4 |
-
+[44717](https://github.com/tensorflow/tensorflow/pull/44717) | Add to rel notes: deterministic tf.image.resize (bilinear) | merged | 2020-11-13 | 2.4 |
+
Notes:
1. These are individual commits.
@@ -628,6 +663,15 @@ Notes:
[1004]: https://github.com/tensorflow/tensorflow/commit/8b7a3db0b6e09415b5640be4986fb4d7c6e5209a
[1005]: https://github.com/tensorflow/tensorflow/commit/9e096debc4a0909deb69970f38bee7b77e5e5f7d
+### Other TensorFlow Organization Pull Requests
+
+These are relevant pull requests against repositories in
+`github.com/tensorflow` other than `github.com/tensorflow/tensorflow`
+
+ Repository | Number | Title | Date Opened | Status |
+:-----------|---------------------------------------------------------:|:----------------------------------------------------------------------|:------------|:-------|
+ community | [346](https://github.com/tensorflow/community/pull/346) | RFC: Enhancing determinism in TF | 2021-01-19 | Open |
+
### PyTorch Pull Requests
ID | Title | Status | Date Merged | Version |
@@ -685,6 +729,7 @@ Andrew Kerr,
Xiang Bo Kong,
Nicolas Koumchatzky,
Jorge Albericio Latorre,
+Lin Lan,
Simon Layton,
Ned Letcher,
Jose Alvarez Lopez,
diff --git a/fwd9m/tensorflow/__init__.py b/fwd9m/tensorflow/__init__.py
index 975356c..602413b 100644
--- a/fwd9m/tensorflow/__init__.py
+++ b/fwd9m/tensorflow/__init__.py
@@ -19,4 +19,4 @@
# What follows is the public API for fwd9m.tensorflow
from .enable_determinism import _enable_determinism as enable_determinism
-from .patch import _patch as patch # deprecated
+from .patch import _patch as patch # deprecated
\ No newline at end of file
diff --git a/fwd9m/tensorflow/enable_determinism.py b/fwd9m/tensorflow/enable_determinism.py
index 53a77de..e82d71f 100644
--- a/fwd9m/tensorflow/enable_determinism.py
+++ b/fwd9m/tensorflow/enable_determinism.py
@@ -23,46 +23,49 @@
import tensorflow as tf
-from .patch import _patch_bias_add
-from .patch import _patch_unsorted_segment_sum
-from .patch import _patch_segment_sum
-from ..utils import _Version as Version
-from ..version import __version__ as package_version
+# By calling the deprecated patch API here, we continue to test its effect
+# without having to test it explicitly. Note that this form of import
+# necessarily breaks the Google Python Style Guide rule to import packages
+# and modules only (and not individual functions).
+from ..tensorflow import patch as patch_bias_add
+from . import patch_segment_sum
+from . import patch_unsorted_segment_sum
+from . import patch_softmax_xent
+from . import patch_sparse_softmax_xent
+from .. import utils
+from .. import version
def _enable_determinism(seed=None):
"""Provides a best-effort recipe to increase framework determinism when
running on GPUs.
-
Call this method either before or after explicitly importing TensorFlow,
but always before constructing any graphs.
-
- This function cannot address all possible sources of non-determinism. Please
+ This function cannot address all possible sources of non-determinism. Please
see further instructions at https://github.com/NVIDIA/framework-determinism
to understand how to use it in a larger deterministic context.
-
Arguments:
seed:
-
Returns: None
"""
- tf_vers = Version(tf.version.VERSION)
+ tf_vers = utils._Version(tf.version.VERSION)
ngc_tf_container_version_string = os.environ.get('NVIDIA_TENSORFLOW_VERSION')
if ngc_tf_container_version_string:
in_ngc_cont = True
- ngc_vers = Version(ngc_tf_container_version_string)
+ ngc_vers = utils._Version(ngc_tf_container_version_string)
else:
in_ngc_cont = False
if not in_ngc_cont and tf_vers.between('1.14', '2.0'):
os.environ['TF_CUDNN_DETERMINISTIC'] = '1'
- _patch_bias_add()
+ patch_bias_add(_silent=True)
if in_ngc_cont and ngc_vers.at_least('19.06') or tf_vers.at_least('2.1'):
os.environ['TF_DETERMINISTIC_OPS'] = '1'
if in_ngc_cont and ngc_vers.at_least('19.06') or tf_vers.at_least('1.14'):
- _patch_unsorted_segment_sum()
- _patch_segment_sum()
- # Apply the fused softmax/cross-entropy patch here
+ patch_segment_sum._patch_segment_sum()
+ patch_unsorted_segment_sum._patch_unsorted_segment_sum()
+ patch_softmax_xent._patch_softmax_xent()
+ patch_sparse_softmax_xent._patch_sparse_softmax_xent()
pass
# TODO: Add other recipe items (e.g. seed)
print("%s (version %s) has been applied to TensorFlow "
- "version %s" % (__name__, package_version,
+ "version %s" % (__name__, version.__version__,
tf_vers.original_version_string))
diff --git a/fwd9m/tensorflow/patch.py b/fwd9m/tensorflow/patch.py
index 6ef217d..866c0d5 100644
--- a/fwd9m/tensorflow/patch.py
+++ b/fwd9m/tensorflow/patch.py
@@ -37,155 +37,40 @@
import sys
import tensorflow as tf
-from tensorflow.python.eager import context
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.keras import backend as K
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import nn
-from tensorflow.python.ops import nn_ops
-from tensorflow.python.ops import gen_math_ops
-from ..utils import _Version as Version
-from ..version import __version__ as package_version
+from . import patch_bias_add
+from .. import utils
+from .. import version
# This function was used to patch tf.nn.bias_add in a limited range of stock
# TensorFlow versions. It is now deprecated and we are no longer developing it.
# enable_determinism should be used.
-def _patch():
+def _patch(_silent=False):
"""Patches TensorFlow to increase determinism when running on GPUs.
-
Calling this method either before or after explicitly importing TensorFlow,
but always before constructing any graphs, will increase the determinsism
when running on GPUs.
-
Returns: nothing
-
Raises:
TypeError (1) if a patch is not available for the installed version of
TensorFlow (either because it doesn't need one or because one has not
yet been implemented), or (2) if there is an attempt to apply the patch
inside an NGC TF container (where it should not be needed).
"""
- print("WARNING: %s has been deprecated. Please use enable_determinism (which "
- "supports all versions of TensorFlow)." % __name__)
+ if not _silent:
+ print("WARNING: %s has been deprecated. Please use enable_determinism "
+ "(which supports all versions of TensorFlow)." % __name__)
if os.environ.get('NVIDIA_TENSORFLOW_VERSION'):
raise TypeError("%s: TensorFlow inside NGC containers does not "
"require patching" % __name__)
- tf_vers = Version(tf.version.VERSION)
+ tf_vers = utils._Version(tf.version.VERSION)
if tf_vers.between('1.14', '2.0'):
os.environ['TF_CUDNN_DETERMINISTIC'] = '1'
- _patch_bias_add()
- # Apply the fused softmax/cross-entropy patch here
- print("TensorFlow version %s has been patched using %s version %s" %
- (tf_vers.original_version_string, __name__,
- package_version))
+ patch_bias_add._patch_bias_add()
+ if not _silent:
+ print("TensorFlow version %s has been patched using %s version %s" %
+ (tf_vers.original_version_string, __name__,
+ version.__version__))
else:
raise TypeError("%s: No patch available for version %s of TensorFlow" %
(__name__, tf_vers.original_version_string))
-
-def _patch_bias_add():
- _new_bias_add.__doc__ = tf.nn.bias_add.__doc__
- tf.nn.bias_add = _new_bias_add # access via public API
- nn.bias_add = _new_bias_add # called from tf.keras.layers.convolutional.Conv
- nn_ops.bias_add = _new_bias_add # called from tests
-
-# The original, pre-patched method can be viewed at
-# https://github.com/tensorflow/tensorflow/blob/v1.14.0/tensorflow/python/ops/nn_ops.py#L2628
-#
-# This patched version of bias_add does not implement some of the error checks
-# provided by the original op. For more information, see the list of test cases
-# excluded from the testing of the patched op functionality.
-def _new_bias_add(value, bias, data_format=None, name=None):
- """ERROR: docstring should have been added programatically. """
- with ops.name_scope(name, "BiasAdd", [value, bias]) as name:
- if data_format is not None:
- if data_format.startswith("NC"):
- data_format = "NCHW"
- elif data_format.startswith("N") and data_format.endswith("C"):
- data_format = "NHWC"
- else:
- raise ValueError("data_format must be of the form `N...C` or `NC...`")
-
- if not context.executing_eagerly():
- value = ops.convert_to_tensor(value, name="input")
- bias = ops.convert_to_tensor(bias, dtype=value.dtype, name="bias")
-
- if data_format == 'NCHW':
- broadcast_shape_head = [1, array_ops.size(bias)]
- broadcast_shape_tail = array_ops.ones(array_ops.rank(value) - 2,
- dtype=dtypes.int32)
- broadcast_shape = array_ops.concat(
- [broadcast_shape_head, broadcast_shape_tail], 0)
- return math_ops.add(
- value, array_ops.reshape(bias, broadcast_shape), name=name)
- else: # data_format == 'NHWC' or data_format == None
- return math_ops.add(value, bias, name=name)
-
-
-def _patch_unsorted_segment_sum():
- _new_unsorted_segment_sum.__doc__ = tf.math.unsorted_segment_sum.__doc__
- math_ops.unsorted_segment_sum = _new_unsorted_segment_sum # access via public API
- tf.math.unsorted_segment_sum = _new_unsorted_segment_sum # access via public API
-
-def _patch_segment_sum():
- _new_segment_sum.__doc__ = tf.math.segment_sum.__doc__
- math_ops.segment_sum = _new_segment_sum # access via public API
- tf.math.segment_sum = _new_segment_sum # access via public API
-
-# The original, pre-patched function is automatically-generated. Therefore, we
-# cannot provide a URL to its location in the source repository.
-# For the history of this patch, please refer to
-# https://github.com/tensorflow/tensorflow/issues/39751
-def _new_unsorted_segment_sum(data, segment_ids, num_segments, name=None):
- """ERROR: docstring should have been added programatically. """
- with ops.name_scope(
- name, "UnsortedSegmentSum", [data, segment_ids, num_segments]) as name:
- # Note that data can be a vector-like list (or an n-dimensional
- # tensor-like list of lists). We convert to tensor here to replicate the
- # behavior of the pre-existing op.
- data = tf.convert_to_tensor(data)
-
- # Note that this patch does not provide determinism when the dtype of the
- # data argument is tf.float64 or tf.complex128.
- orig_dtype = data.dtype
- if 'float' in str(orig_dtype):
- data = tf.cast(data, dtype=tf.float64)
- elif 'complex' in str(orig_dtype):
- data = tf.cast(data, dtype=tf.complex128)
-
- if not context.executing_eagerly():
- data = ops.convert_to_tensor(data, name="input_data")
- segment_ids = ops.convert_to_tensor(segment_ids, name="segment_ids")
- num_segments = ops.convert_to_tensor(num_segments, name="num_segments")
-
- result = gen_math_ops.unsorted_segment_sum(data, segment_ids, num_segments)
- return tf.cast(result, dtype=orig_dtype)
-
-# The original, pre-patched function is automatically-generated. Therefore, we
-# cannot provide a URL to its location in the source repository.
-# For the history of this patch, please refer to
-# https://github.com/tensorflow/tensorflow/issues/39751
-def _new_segment_sum(data, segment_ids, name=None):
- """ERROR: docstring should have been added programatically. """
- with ops.name_scope(name, "SegmentSum", [data, segment_ids]) as name:
- # Note that data can be a vector-like list (or an n-dimensional
- # tensor-like list of lists). We convert to tensor here to replicate the
- # behavior of the pre-existing op.
- data = tf.convert_to_tensor(data)
-
- # Note that this patch does not provide determinism when the dtype of the
- # data argument is tf.float64 or tf.complex128.
- orig_dtype = data.dtype
- if 'float' in str(orig_dtype):
- data = tf.cast(data, dtype=tf.float64)
- elif 'complex' in str(orig_dtype):
- data = tf.cast(data, dtype=tf.complex128)
-
- if not context.executing_eagerly():
- data = ops.convert_to_tensor(data, name="input_data")
- segment_ids = ops.convert_to_tensor(segment_ids, name="segment_ids")
-
- result = gen_math_ops.segment_sum(data, segment_ids)
- return tf.cast(result, dtype=orig_dtype)
diff --git a/fwd9m/tensorflow/patch_bias_add.py b/fwd9m/tensorflow/patch_bias_add.py
new file mode 100644
index 0000000..e6c8bc6
--- /dev/null
+++ b/fwd9m/tensorflow/patch_bias_add.py
@@ -0,0 +1,66 @@
+# Copyright 2020 NVIDIA Corporation. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ========================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.python.eager import context
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
+from tensorflow.python.ops import nn_ops
+
+def _patch_bias_add():
+ _new_bias_add.__doc__ = tf.nn.bias_add.__doc__
+ tf.nn.bias_add = _new_bias_add # access via public API
+ nn.bias_add = _new_bias_add # called from tf.keras.layers.convolutional.Conv
+ nn_ops.bias_add = _new_bias_add # called from tests
+
+# The original, pre-patched method can be viewed at
+# https://github.com/tensorflow/tensorflow/blob/v1.14.0/tensorflow/python/ops/nn_ops.py#L2628
+#
+# This patched version of bias_add does not implement some of the error checks
+# provided by the original op. For more information, see the list of test cases
+# excluded from the testing of the patched op functionality.
+def _new_bias_add(value, bias, data_format=None, name=None):
+ """ERROR: docstring should have been added programatically. """
+ with ops.name_scope(name, "BiasAdd", [value, bias]) as name:
+ if data_format is not None:
+ if data_format.startswith("NC"):
+ data_format = "NCHW"
+ elif data_format.startswith("N") and data_format.endswith("C"):
+ data_format = "NHWC"
+ else:
+ raise ValueError("data_format must be of the form `N...C` or `NC...`")
+
+ if not context.executing_eagerly():
+ value = ops.convert_to_tensor(value, name="input")
+ bias = ops.convert_to_tensor(bias, dtype=value.dtype, name="bias")
+
+ if data_format == 'NCHW':
+ broadcast_shape_head = [1, array_ops.size(bias)]
+ broadcast_shape_tail = array_ops.ones(array_ops.rank(value) - 2,
+ dtype=dtypes.int32)
+ broadcast_shape = array_ops.concat(
+ [broadcast_shape_head, broadcast_shape_tail], 0)
+ return math_ops.add(
+ value, array_ops.reshape(bias, broadcast_shape), name=name)
+ else: # data_format == 'NHWC' or data_format == None
+ return math_ops.add(value, bias, name=name)
diff --git a/fwd9m/tensorflow/patch_segment_sum.py b/fwd9m/tensorflow/patch_segment_sum.py
new file mode 100644
index 0000000..cda72ae
--- /dev/null
+++ b/fwd9m/tensorflow/patch_segment_sum.py
@@ -0,0 +1,67 @@
+# Copyright 2020 NVIDIA Corporation. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ========================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import warnings
+
+import tensorflow as tf
+
+from tensorflow.python.eager import context
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_math_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.framework import dtypes
+
+# NOTE: This patch only provides GPU-determinism for data type float16/32 and
+# bfloat16.
+
+def _patch_segment_sum():
+ _new_segment_sum.__doc__ = tf.math.segment_sum.__doc__
+ math_ops.segment_sum = _new_segment_sum
+ tf.math.segment_sum = _new_segment_sum # access via public API
+
+# The original, pre-patched function is automatically-generated. Therefore, we
+# cannot provide a URL to its location in the source repository.
+# For the history of this patch, please refer to
+# https://github.com/tensorflow/tensorflow/issues/39751
+def _new_segment_sum(data, segment_ids, name=None):
+ """ERROR: docstring should have been added programatically. """
+ with ops.name_scope(name, "SegmentSum", [data, segment_ids]) as name:
+ if not context.executing_eagerly():
+ # Note that data can be a vector-like list (or an n-dimensional
+ # tensor-like list of lists). We convert to tensor here to replicate the
+ # behavior of the pre-existing op.
+ data = ops.convert_to_tensor(data, name="input_data")
+ segment_ids = ops.convert_to_tensor(segment_ids, name="segment_ids")
+
+ orig_dtype = data.dtype
+
+ if orig_dtype is dtypes.float32:
+ data = tf.cast(data, dtype=tf.float64)
+ elif orig_dtype is dtypes.float16:
+ data = tf.cast(data, dtype=tf.float32)
+ elif orig_dtype is dtypes.bfloat16:
+ data = tf.cast(data, dtype=tf.float32)
+ elif orig_dtype is dtypes.float64:
+ warnings.warn(
+ "Data type %s is not supported for GPU-determinism" %
+ data.dtype, UserWarning)
+
+ result = gen_math_ops.segment_sum(data, segment_ids)
+
+ return tf.cast(result, dtype=orig_dtype)
diff --git a/fwd9m/tensorflow/patch_softmax_xent.py b/fwd9m/tensorflow/patch_softmax_xent.py
new file mode 100644
index 0000000..bccb0f7
--- /dev/null
+++ b/fwd9m/tensorflow/patch_softmax_xent.py
@@ -0,0 +1,81 @@
+# Copyright 2021 NVIDIA Corporation. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ========================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+import functools
+import numbers
+import os
+
+import numpy as np
+
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors_impl
+from tensorflow.python.framework import ops
+from tensorflow.python.keras import backend as K
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import clip_ops
+from tensorflow.python.ops import gen_math_ops
+from tensorflow.python.ops import gen_nn_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.util import deprecation
+from tensorflow.python.util import dispatch
+from tensorflow.python.util.deprecation import deprecated_args
+from tensorflow.python.util.deprecation import deprecated_argument_lookup
+from tensorflow.python.util.tf_export import tf_export
+
+# NOTE: This patch provides GPU-determinism for
+# `tf.nn.softmax_cross_entropy_with_logits` via patching the op
+# `gen_nn_ops.softmax_cross_entropy_with_logit` with sequential calling of
+# softmax, logarithm and reduce_sum which are known deterministic.
+
+def _patch_softmax_xent():
+ gen_nn_ops.softmax_cross_entropy_with_logits = _new_soft_xent_op
+
+# The original, pre-patched python wrapper can be viewed at
+# gen_nn_ops.py which is a auto-generated code and the c++ code implementation
+# is \core\kernels\xent_op.cc.
+
+def _new_soft_xent_op(features, labels, name=None):
+
+ if not context.executing_eagerly():
+ features = ops.convert_to_tensor(features)
+ labels = ops.convert_to_tensor(labels)
+ features_rank = array_ops.shape(features).shape
+ labels_rank = array_ops.shape(labels).shape
+ else:
+ features_rank = array_ops.rank(features)
+ labels_rank = array_ops.rank(labels)
+
+ if features_rank == 1 or labels_rank == 1:
+ raise ValueError("must be 2d")
+ elif features_rank == 3 or labels_rank == 3:
+ raise ValueError("rank 2, but is rank 3")
+
+ softmax = tf.nn.softmax(logits=features, axis=-1)
+ epsilon_ = constant_op.constant(K.epsilon(), dtype=softmax.dtype.base_dtype)
+ softmax = clip_ops.clip_by_value(softmax, epsilon_, 1. - epsilon_)
+ # ??? * needs the data type to be the same
+ bp = (softmax - labels)
+ return -tf.reduce_sum(tf.math.log(softmax) * labels, axis=-1), bp
+
diff --git a/fwd9m/tensorflow/patch_sparse_softmax_xent.py b/fwd9m/tensorflow/patch_sparse_softmax_xent.py
new file mode 100644
index 0000000..8b09782
--- /dev/null
+++ b/fwd9m/tensorflow/patch_sparse_softmax_xent.py
@@ -0,0 +1,180 @@
+# Copyright 2021 NVIDIA Corporation. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ========================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+import functools
+import numbers
+import os
+
+import numpy as np
+
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors_impl
+from tensorflow.python.framework import ops
+from tensorflow.python.keras import backend as K
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import clip_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import gen_nn_ops
+from tensorflow.python.ops import nn
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.util import dispatch
+from tensorflow.python.util.tf_export import tf_export
+
+
+# NOTE: This patch provides GPU-determinism for
+# `tf.nn.sparse_softmax_cross_entropy_with_logits` via overriding the fused op
+# `gen_nn_ops.sparse_softmax_cross_entropy_with_logit` with turning labels into
+# one_hot encoding and calling patched
+# gen__nn_ops.softmax_cross_entropy_with_logits.
+
+def _patch_sparse_softmax_xent():
+ _new_sparse_softmax_xent_with_logits.__doc__ = \
+ tf.nn.sparse_softmax_cross_entropy_with_logits.__doc__
+ tf.nn.sparse_softmax_cross_entropy_with_logits = \
+ _new_sparse_softmax_xent_with_logits # access via public API
+ nn.sparse_softmax_cross_entropy_with_logits = \
+ _new_sparse_softmax_xent_with_logits
+ nn_ops.sparse_softmax_cross_entropy_with_logits = \
+ _new_sparse_softmax_xent_with_logits
+ # NOTE: Since enable_determinism
+ # patches gen_nn_ops.softmax_cross_entropy_with_logits and other ops
+ # universally, there is no need to patch here.
+
+# The original, pre-patched python wrapper
+# `nn.sparse_softmax_cross_entropy_with_logits` can be found at
+# https://github.com/tensorflow/tensorflow/blob/0c95acca049a05756f63bec731dbe9a11f9d8382/tensorflow/python/ops/nn_ops.py#L4066
+# The fused op `gen_nn_ops.sparse_softmax_cross_entropy_with_logit` is
+# automatically-generated. Therefore, we cannot provide a URL to its location in
+# the source repository.
+
+@tf_export("nn.sparse_softmax_cross_entropy_with_logits", v1=[])
+@dispatch.add_dispatch_support
+def sparse_softmax_cross_entropy_with_logits_v2(labels, logits, name=None):
+ return nn.sparse_softmax_cross_entropy_with_logits(
+ labels=labels, logits=logits, name=name)
+
+def _ensure_xent_args(name, sentinel, labels, logits):
+ # Make sure that all arguments were passed as named arguments.
+ if sentinel is not None:
+ raise ValueError("Only call `%s` with "
+ "named arguments (labels=..., logits=..., ...)" % name)
+ if labels is None or logits is None:
+ raise ValueError("Both labels and logits must be provided.")
+
+@tf_export(v1=["nn.sparse_softmax_cross_entropy_with_logits"])
+@dispatch.add_dispatch_support
+def _new_sparse_softmax_xent_with_logits(
+ _sentinel=None, # pylint: disable=invalid-name
+ labels=None,
+ logits=None,
+ name=None):
+ _ensure_xent_args("sparse_softmax_cross_entropy_with_logits", _sentinel,
+ labels, logits)
+
+ # TODO(pcmurray) Raise an error when the label is not an index in
+ # [0, num_classes). Note: This could break users who call this with bad
+ # labels, but disregard the bad results.
+
+ # Reshape logits and labels to rank 2.
+ with ops.name_scope(name, "SparseSoftmaxCrossEntropyWithLogits",
+ [labels, logits]):
+ labels = ops.convert_to_tensor(labels)
+ logits = ops.convert_to_tensor(logits)
+ precise_logits = math_ops.cast(logits, dtypes.float32) if (dtypes.as_dtype(
+ logits.dtype) == dtypes.float16) else logits
+
+ # Store label shape for result later.
+ labels_static_shape = labels.get_shape()
+ labels_shape = array_ops.shape(labels)
+ static_shapes_fully_defined = (
+ labels_static_shape.is_fully_defined() and
+ logits.get_shape()[:-1].is_fully_defined())
+ if logits.get_shape().ndims is not None and logits.get_shape().ndims == 0:
+ raise ValueError(
+ "Logits cannot be scalars - received shape %s." % logits.get_shape())
+ if logits.get_shape().ndims is not None and (
+ labels_static_shape.ndims is not None and
+ labels_static_shape.ndims != logits.get_shape().ndims - 1):
+ raise ValueError("Rank mismatch: Rank of labels (received %s) should "
+ "equal rank of logits minus 1 (received %s)." %
+ (labels_static_shape.ndims, logits.get_shape().ndims))
+ if (static_shapes_fully_defined and
+ labels_static_shape != logits.get_shape()[:-1]):
+ raise ValueError("Shape mismatch: The shape of labels (received %s) "
+ "should equal the shape of logits except for the last "
+ "dimension (received %s)." % (labels_static_shape,
+ logits.get_shape()))
+
+ # Check if no reshapes are required.
+ if logits.get_shape().ndims == 2:
+ # Override of `gen_nn_ops.sparse_xent_with_logit`
+ if labels.get_shape().ndims is None:
+ raise errors_impl.InvalidArgumentError(
+ None, None, ".*labels must be 1-D.*")
+ # raise errors_impl.OpError(None, None, "labels must be 1-D", errors_impl.OpError)
+ onehot_encoding = tf.one_hot(labels, precise_logits.shape[-1],
+ dtype=dtypes.as_dtype(precise_logits.dtype))
+# cost = _core_op(labels=onehot_encoding, logits=precise_logits)
+
+ cost, _ = gen_nn_ops.softmax_cross_entropy_with_logits(
+ precise_logits, onehot_encoding, name=name)
+
+ if precise_logits.dtype == dtypes.float16:
+ return math_ops.cast(cost, dtypes.float16)
+ else:
+ return cost
+
+ # Perform a check of the dynamic shapes if the static shapes are not fully
+ # defined.
+ shape_checks = []
+ if not static_shapes_fully_defined:
+ shape_checks.append(
+ check_ops.assert_equal(
+ array_ops.shape(labels),
+ array_ops.shape(logits)[:-1]))
+ with ops.control_dependencies(shape_checks):
+ # Reshape logits to 2 dim, labels to 1 dim.
+ num_classes = array_ops.shape(logits)[array_ops.rank(logits) - 1]
+ precise_logits = array_ops.reshape(precise_logits, [-1, num_classes])
+ labels = array_ops.reshape(labels, [-1])
+ if labels.get_shape().ndims is None:
+ raise errors_impl.InvalidArgumentError(None, None,
+ ".*labels must be 1-D.*")
+ # The second output tensor of `gen_nn_ops.sparse_xent_with_logits`
+ # contains the gradients. But it's used in _CrossEntropyGrad() in nn_grad
+ # but not here.
+ # cost, _ = gen_nn_ops.sparse_softmax_cross_entropy_with_logits(
+ # precise_logits, labels, name=name)
+
+ onehot_encoding = tf.one_hot(labels, num_classes)
+ cost, _ = gen_nn_ops.softmax_cross_entropy_with_logits(precise_logits, onehot_encoding,name=name)
+
+ cost = array_ops.reshape(cost, labels_shape)
+ cost.set_shape(labels_static_shape)
+
+ if logits.dtype == dtypes.float16:
+ return math_ops.cast(cost, dtypes.float16)
+ else:
+ return cost
+
diff --git a/fwd9m/tensorflow/patch_unsorted_segment_sum.py b/fwd9m/tensorflow/patch_unsorted_segment_sum.py
new file mode 100644
index 0000000..aebdc1c
--- /dev/null
+++ b/fwd9m/tensorflow/patch_unsorted_segment_sum.py
@@ -0,0 +1,71 @@
+# Copyright 2020 NVIDIA Corporation. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ========================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import warnings
+
+import tensorflow as tf
+
+from tensorflow.python.eager import context
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_math_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.framework import dtypes as dtypes_lib
+
+# NOTE: This patch only provides GPU-determinism for data type float16/32,
+# complex64 and bfloat16.
+
+def _patch_unsorted_segment_sum():
+ _new_unsorted_segment_sum.__doc__ = tf.math.unsorted_segment_sum.__doc__
+ math_ops.unsorted_segment_sum = _new_unsorted_segment_sum
+ tf.math.unsorted_segment_sum = _new_unsorted_segment_sum # via public API
+
+# The original, pre-patched function is automatically-generated. Therefore, we
+# cannot provide a URL to its location in the source repository.
+# For the history of this patch, please refer to
+# https://github.com/tensorflow/tensorflow/issues/39751
+def _new_unsorted_segment_sum(data, segment_ids, num_segments, name=None):
+ """ERROR: docstring should have been added programatically. """
+ with ops.name_scope(
+ name, "UnsortedSegmentSum", [data, segment_ids, num_segments]) as name:
+ # Note that data can be a vector-like list (or an n-dimensional
+ # tensor-like list of lists). We convert to tensor here to replicate the
+ # behavior of the pre-existing op.
+ data = ops.convert_to_tensor(data, name="input_data")
+ segment_ids = ops.convert_to_tensor(segment_ids, name="segment_ids")
+ num_segments = ops.convert_to_tensor(num_segments, name="num_segments")
+
+ orig_dtype = data.dtype
+ if orig_dtype is dtypes_lib.float32:
+ data = tf.cast(data, dtype=tf.float64)
+ elif orig_dtype is dtypes_lib.float16:
+ data = tf.cast(data, dtype=tf.float32)
+ elif orig_dtype is dtypes_lib.complex64:
+ data = tf.cast(data, dtype=tf.complex128)
+ elif orig_dtype is dtypes_lib.bfloat16:
+ data = tf.cast(data, dtype=tf.float32)
+ elif orig_dtype is dtypes_lib.float64 or dtypes_lib.complex128:
+ warnings.warn(
+ "Data type %s is not supported for GPU-determinism" % data.dtype,
+ UserWarning)
+
+ result = gen_math_ops.unsorted_segment_sum(data, segment_ids, num_segments)
+
+ return tf.cast(result, dtype=orig_dtype)
diff --git a/fwd9m/utils.py b/fwd9m/utils.py
index ca18ac3..7021440 100644
--- a/fwd9m/utils.py
+++ b/fwd9m/utils.py
@@ -70,3 +70,11 @@ def between(self, oldest_version, newest_version):
return True
else:
return False
+
+ def equals(self, target_version):
+ """Is the version equal to the version provided?"""
+ target_major, target_minor = self._only_major_and_minor(target_version)
+ if (self.major == target_major and self.minor == target_minor):
+ return True
+ else:
+ return False
diff --git a/pytorch.md b/pytorch.md
index 12d6da1..9f995e3 100644
--- a/pytorch.md
+++ b/pytorch.md
@@ -7,7 +7,9 @@ models, but our level of experience, so far, is not as extensive as for
TensorFlow.
PyTorch documentation includes some guidance for attaining GPU-determinism on
-its [reproducibility page][1], which we have contributed to.
+its [reproducibility page][1], which we have contributed to. Please refer to
+that page also because it probably has different or additional information to
+this current one.
Getting reproducible functionality on a single GPU, as with other frameworks,
involves several considerations:
@@ -30,13 +32,18 @@ np.random.seed(SEED) # if you're using numpy
torch.manual_seed(SEED) # torch.cuda.manual_seed_all(SEED) is not required
```
+It's often worth confirming that the trainable variables are being reproducibly
+initialized by creating and printing some kind of digest of all the trainable
+variables before beginning to train. Appropriate digests include a sum or a
+hash.
+
## Data Loader
You'll need to make sure that your data loader process is reproducible, so that
the sequence of examples or batches of examples delivered to your model are
-prefectly reproducible. If you have a mutli-threaded data loader, then it's
+perfectly reproducible. If you have a mutlithreaded data loader, then it's
important not to share PRNG state between threads. There may be other
-dataloader restrictions that I'm not yet aware of.
+data loader restrictions that I'm not yet aware of.
Reproducible inter-epoch re-shuffling can be attained by creating
an instance (`self.g`) of `torch.Generator` in your
@@ -47,7 +54,7 @@ def set_epoch(self, epoch):
self.epoch = epoch
if self.shuffle:
# We want every epoch to shuffle differently, but in a reproducible way.
- # Therefore, reset the generator differently buy reproducibly on each
+ # Therefore, reset the generator differently but reproducibly on each
# epoch. It is recommended for the seed to have a good balance of zero and
# one bits.
# See https://pytorch.org/docs/stable/generated/torch.Generator.html
@@ -60,7 +67,7 @@ Then call `set_epoch` at the start of each epoch.
Once the trainable variables are initializing reproducibly and training
examples are being delivered reproducibly, the next step is to maximally enable
-deterministic ops. The way you do this currently (in version 1.6) of PyTorch
+deterministic ops. The way you do this in versions of PyTorch earlier than 1.7
is a follows:
```
@@ -74,7 +81,7 @@ libraries: convolution, max pooling, and CTC loss (all three from cuDNN), and
batch matrix-matrix product (from cuBLAS).
The second line disables dynamic selection of cuDNN convolution algorithms
-and ensures that the algorithm select itself is reproducible.
+and ensures that the algorithm selection itself is reproducible.
The [reproducibilty page][1] contains a reasonable but non-comprehensive list of
ops the are nondeterminsitic on GPU. Using these will cause nondeterminism to
@@ -90,33 +97,35 @@ criteria must be met, as described in the PyTorch [documentation][4] for
`torch.nn.CTCLoss`. Another way of obtaining determinsitic CTC functionality
is to use [WarpCTC][2].
-PyTorch 1.7 will include a new function, `torch.set_determinism`, which will
-preclude the need to set eithe `torch.backends.cudnn.determinsitic` or
-`torch.backends.cudnn.benchmark`. An additional advantage of using this this
-function is that it will cause an exception to be thrown if you try to use an
-op that could inject nondeterminism into your model. It's impossible for an
-exception to be thrown in all circumstances when nondeterminism could be
-introduced by an op, let alone by the many other possible sources, but this
-feature will reduce the amount of time spend isolating sources of nondeterminism
-coming from ops that have already been identified as currently not able to
-operate deterministically on a GPU.
-
-## Save and Resume
-
-When saving your model, you will need to save not only the `model.state_dict()`
-but also the `optimizer.state_dict()` (which includes the current
-learning rate and any other learning rate scheduler state), the iteration/epoch
-counter, `torch.cuda.GradScaler` statistics, as well as the following PRNG
-states:
-
-```
-save_checkpoint["torch_rng_state"] = torch.get_rng_state()
-save_checkpoint["torch_cuda_rng_state"] = torch.cuda.get_rng_state()
-save_checkpoint["numpy_rng_state"] = np.random.get_state()
-save_checkpoint["python_rng_state"] = random.getstate()
-```
-
-Please also refer to the [Saving and Loading Models][3] documentation.
+PyTorch 1.7 includes a new function, [`torch.set_deterministic`][5], which
+precludes the need to set either `torch.backends.cudnn.deterministic` or
+`torch.backends.cudnn.benchmark`. An additional advantage of using this function
+is that it will cause an exception to be thrown if you try to use an op that
+could inject nondeterminism into your model. It's impossible for an exception to
+be thrown in all circumstances when nondeterminism could be introduced by an op,
+let alone by the many other possible sources, but this feature will reduce the
+amount of time spent isolating sources of nondeterminism coming from ops that
+have already been identified as currently not able to operate deterministically
+on a GPU.
+
+## Reproducible Checkpointing
+
+To save state and later resume reproducibly (ending the training process
+exactly as if it had not been interrupted) you should `torch.save` and
+`torch.load` the following state (as needed) using [the approach][6] given in
+the PyTorch documentation, including the [guidance][7] for saving and loading
+GPU state:
+
+ * data loader state,
+ * `model.state_dict()`,
+ * `optimizer.state_dict()`, which includes the current learning rate and any
+ other learning rate scheduler state,
+ * epoch / iteration counter,
+ * `torch.cuda.GradScaler` statistics,
+ * `torch.get_rng_state()`,
+ * `torch.cuda.get_rng_state()`,
+ * `np.random.get_state()`, and
+ * `random.getstate()`
## Multi-GPU
@@ -143,3 +152,6 @@ the content on this page.
[2]: https://github.com/SeanNaren/warp-ctc
[3]: https://pytorch.org/tutorials/beginner/saving_loading_models.html
[4]: https://pytorch.org/docs/stable/generated/torch.nn.CTCLoss.html
+[5]: https://pytorch.org/docs/stable/generated/torch.set_deterministic.html
+[6]: https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training
+[7]: https://pytorch.org/tutorials/beginner/saving_loading_models.html#save-on-gpu-load-on-gpu
diff --git a/test/devel.sh b/test/devel.sh
index 8c443cb..86e047a 100755
--- a/test/devel.sh
+++ b/test/devel.sh
@@ -2,17 +2,8 @@
set -e # If any test fails, this script will exit and forward the error code
-./container.sh tensorflow/tensorflow:2.3.0-gpu python test_patch_segment_reduction.py
+IMAGE=tensorflow/tensorflow:2.3.0-gpu
+#IMAGE=nvcr.io/nvidia/tensorflow:19.06-py3
+#IMAGE=gitlab-master.nvidia.com:5005/dl/dgx/tensorflow:master-py3-devel
-# The segment sum patch has been shown to pass on the following NGC containers:
-# 19.06-py2/3
-# 19.07-py2
-# 19.09-py2/3
-# 19.11-tf1/2-py3
-# 19.12-tf1/2-py3
-# 20.01-tf1/2-py3
-# 20.06-tf1/2-py3
-# 20.08-tf1/2-py3
-# 20.09-tf2-py3
-# and the following stock TensorFlow containers:
-# ?
+./container.sh ${IMAGE} python test_patch_softmax_xent.py
diff --git a/test/segment_reduction_helper.py b/test/segment_reduction_helper.py
new file mode 100644
index 0000000..5bdf60c
--- /dev/null
+++ b/test/segment_reduction_helper.py
@@ -0,0 +1,150 @@
+# Copyright 2020 NVIDIA Corporation. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ========================================================================
+
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functional tests for segment reduction ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import itertools
+import os
+import sys
+import unittest
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes as dtypes_lib
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradient_checker
+from tensorflow.python.ops import gradient_checker_v2
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+sys.path.insert(0, '..')
+import fwd9m.tensorflow as fwd9m_tensorflow
+import utils
+
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Simplifies logging
+
+# Notes:
+# 0. These notes are relevant to this current file and also
+# test_patch_segment_sum.py and test_patch_unsorted_segment_sum.py
+# 1. The ops were expected to operate deterministically on the CPU and they do
+# indeed operate deterministically if forcely pinned to the CPU with
+# tf.device('/device:CPU:0'). What is not fully understood is why when they
+# are placed on the CPU using self.session(use_gpu=False), the ops still
+# introduce nondeterminism. By setting the log_device_placement parameter in
+# the session config to True under these conditions, we are able to confirm
+# that the ops are running on the CPU.
+# 2. To capture nondeterminism, random input data is necessary.
+# 3. The nondeterminism of dtypes_lib.float64, dtypes_lib.complex128 cannot be
+# removed by this patch, so they are not tested.
+# 4. The regular op tests below, represented by all the test classes except the
+# final two, which have names ending in "Deterministic", were taken from
+# tensorflow/python/kernel_tests/segment_reduction_ops_test.py
+# (as of 2020-08-02); URL to file-at-commit:
+# https://github.com/tensorflow/tensorflow/blob/6371d4a38cfb122a8d9b2a03d5f56444e95462b0/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
+# 5. The names of most of the upstream test classes are confusing (even more so
+# in the context of their limited use here), so the names have been changed
+# here, as appropriate, along with comments to indicate the original test
+# class names.
+
+
+class SegmentReductionHelper(test.TestCase):
+
+ def _random_input(self, input_shape, dtype=dtypes_lib.int32):
+ np.random.seed(hash(dtype) % 256)
+
+ np_values = np.random.random(input_shape).astype(dtype.as_numpy_dtype)
+ # Add a non-zero imaginary component to complex types.
+ if dtype.is_complex:
+ np_values -= 1j * np_values
+ return constant_op.constant(
+ np_values, shape=input_shape, dtype=dtype), np_values
+
+ def _input(self, input_shape, dtype=dtypes_lib.int32):
+ num_elem = 1
+ for x in input_shape:
+ num_elem *= x
+ values = np.arange(1, num_elem + 1)
+ np_values = values.reshape(input_shape).astype(dtype.as_numpy_dtype)
+ # Add a non-zero imaginary component to complex types.
+ if dtype.is_complex:
+ np_values -= 1j * np_values
+ return constant_op.constant(
+ np_values, shape=input_shape, dtype=dtype), np_values
+
+ def _randomDataOp(self, shape, data_type, seed):
+ if seed is not None:
+ np.random.seed(seed)
+ return constant_op.constant(np.random.random_sample(shape), dtype=data_type)
+
+ def _segmentReduce(self, indices, x, op1, op2=None, num_segments=None,
+ initial_value=0):
+ if not x.size:
+ return np.array([])
+ indices = np.asarray(indices)
+ if num_segments is None:
+ num_segments = indices[-1] + 1
+ output = [None] * num_segments
+ slice_shape = x.shape[indices.ndim:]
+ x_flat = x.reshape((indices.size,) + slice_shape)
+ for i, index in enumerate(indices.ravel()):
+ if (output[index] is not None) and op1 == np.max:
+ for j in range(0, output[index].shape[0]):
+ output[index][j] = op1([output[index][j], x_flat[i][j]])
+ elif output[index] is not None:
+ output[index] = op1(output[index], x_flat[i])
+ else:
+ output[index] = x_flat[i]
+ # zero initialize values that are still uncalcuated.
+ initial_value_slice = np.ones(slice_shape) * initial_value
+ output = [o if o is not None else initial_value_slice for o in output]
+ if op2 is not None:
+ output = [op2(o) for o in output]
+ output = [o.reshape(slice_shape) for o in output]
+ return np.array(output)
+
+ def _mean_cum_op(self, x, y):
+ return (x[0] + y, x[1] + 1) if isinstance(x, tuple) else (x + y, 2)
+
+ def _mean_reduce_op(self, x):
+ return x[0] / x[1] if isinstance(x, tuple) else x
+
+ def _sqrt_n_reduce_op(self, x):
+ return x[0] / np.sqrt(x[1]) if isinstance(x, tuple) else x
diff --git a/test/test_misc.py b/test/test_misc.py
index 70dafe7..a5a1860 100644
--- a/test/test_misc.py
+++ b/test/test_misc.py
@@ -17,15 +17,13 @@
import unittest
sys.path.insert(0, '..')
-from fwd9m import __version__ as fwd9m_version
import fwd9m
-from get_version import get_version
+import get_version
class TestMisc(unittest.TestCase):
def test_version(self):
- expected_version = get_version()
- self.assertEqual(fwd9m_version, expected_version)
+ expected_version = get_version.get_version()
self.assertEqual(fwd9m.__version__, expected_version)
if __name__ == '__main__':
diff --git a/test/test_patch_bias_add.py b/test/test_patch_bias_add.py
index 9a7049e..aac6d88 100644
--- a/test/test_patch_bias_add.py
+++ b/test/test_patch_bias_add.py
@@ -34,9 +34,14 @@
import os
import sys
+sys.path.insert(0, '..')
import numpy as np
import tensorflow as tf
+
+from . import utils as test_utils
+from fwd9m import utils as package_utils
+from fwd9m import tensorflow as fwd9m_tensorflow
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
@@ -48,13 +53,8 @@
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import nn
from tensorflow.python.ops import nn_ops
-import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
from tensorflow.python.platform import test
-sys.path.insert(0, '..')
-import fwd9m.tensorflow as fwd9m_tensorflow
-import utils
-
# The tests in the following class were originally copied from
# https://github.com/tensorflow/tensorflow/blob/v1.14.0/tensorflow/python/kernel_tests/bias_op_test.py
# and were then enhanced.
@@ -401,7 +401,7 @@ def bias_gradients(local_seed):
@test_util.run_in_graph_and_eager_modes
def testDeterministicGradients(self):
- with utils.force_gpu_session(self):
+ with test_utils.force_gpu_session(self):
# There are problems with using force_gpu=True and cached_session with
# both eager mode and graph mode in the same test. Using a non-cached
# session and putting everything inside the same session context is
@@ -413,7 +413,13 @@ def testDeterministicGradients(self):
# deterministically by default. I don't know if this is true for
# all layer configurations. These cases are still being tested here,
# for completeness.
- for data_rank in (1, 2, 3):
+ # TF1.13 only includes 2 add a note here for users
+ if package_utils._Version(tf.version.VERSION).equals("1.13"):
+ data_ranks = (2,)
+ else:
+ data_ranks = (1, 2, 3)
+
+ for data_rank in data_ranks:
for data_type in (dtypes.float16, dtypes.float32, dtypes.float64):
self._testDeterministicGradientsCase(op_binding, data_layout,
data_rank, data_type)
diff --git a/test/test_patch_segment_reduction.py b/test/test_patch_segment_reduction.py
deleted file mode 100644
index 9e22687..0000000
--- a/test/test_patch_segment_reduction.py
+++ /dev/null
@@ -1,828 +0,0 @@
-# Copyright 2020 NVIDIA Corporation. All Rights Reserved
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ========================================================================
-
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Functional tests for segment reduction ops."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import itertools
-import os
-import sys
-import unittest
-
-import numpy as np
-import tensorflow as tf
-from tensorflow.python.eager import context
-from tensorflow.python.eager import backprop
-from tensorflow.python.client import session
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes as dtypes_lib
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import test_util
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gradient_checker
-from tensorflow.python.ops import gradient_checker_v2
-from tensorflow.python.ops import gradients_impl
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import variables
-from tensorflow.python.platform import test
-
-sys.path.insert(0, '..')
-import fwd9m.tensorflow as fwd9m_tensorflow
-import utils
-
-os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Simplifies logging
-
-# Notes:
-# 1. The ops were expected to operate deterministically on the CPU and they do
-# indeed operate deterministically if forcely pinned to the CPU with
-# tf.device('/device:CPU:0'). What is not fully understood is why when they
-# are placed on the CPU using self.session(use_gpu=False), the ops still
-# introduce nondeterminism. By setting the log_device_placement parameter in
-# the session config to True under these conditions, we are able to confirm
-# that the ops are running on the CPU.
-# 2. To capture nondeterminism, random input data is necessary.
-# 3. The nondeterminism of dtypes_lib.float64, dtypes_lib.complex128 cannot be
-# removed by this patch, so they are not tested.
-# 4. The regular op tests below, represented by all the test classes except the
-# final two, which have names ending in "Deterministic", were taken from
-# tensorflow/python/kernel_tests/segment_reduction_ops_test.py
-# (as of 2020-08-02); URL to file-at-commit:
-# https://github.com/tensorflow/tensorflow/blob/6371d4a38cfb122a8d9b2a03d5f56444e95462b0/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
-# 5. The names of most of the upstream test classes are confusing (even more so
-# in the context of their limited use here), so the names have been changed
-# here, as appropriate, along with comments to indicate the original test
-# class names.
-
-
-class SegmentReductionHelper(test.TestCase):
-
- def _random_input(self, input_shape, dtype=dtypes_lib.int32):
- np.random.seed(hash(dtype) % 256)
-
- np_values = np.random.random(input_shape).astype(dtype.as_numpy_dtype)
- # Add a non-zero imaginary component to complex types.
- if dtype.is_complex:
- np_values -= 1j * np_values
- return constant_op.constant(
- np_values, shape=input_shape, dtype=dtype), np_values
-
- def _input(self, input_shape, dtype=dtypes_lib.int32):
- num_elem = 1
- for x in input_shape:
- num_elem *= x
- values = np.arange(1, num_elem + 1)
- np_values = values.reshape(input_shape).astype(dtype.as_numpy_dtype)
- # Add a non-zero imaginary component to complex types.
- if dtype.is_complex:
- np_values -= 1j * np_values
- return constant_op.constant(
- np_values, shape=input_shape, dtype=dtype), np_values
-
- def _segmentReduce(self, indices, x, op1, op2=None, num_segments=None,
- initial_value=0):
- if not x.size:
- return np.array([])
- indices = np.asarray(indices)
- if num_segments is None:
- num_segments = indices[-1] + 1
- output = [None] * num_segments
- slice_shape = x.shape[indices.ndim:]
- x_flat = x.reshape((indices.size,) + slice_shape)
- for i, index in enumerate(indices.ravel()):
- if (output[index] is not None) and op1 == np.max:
- for j in range(0, output[index].shape[0]):
- output[index][j] = op1([output[index][j], x_flat[i][j]])
- elif output[index] is not None:
- output[index] = op1(output[index], x_flat[i])
- else:
- output[index] = x_flat[i]
- # zero initialize values that are still uncalcuated.
- initial_value_slice = np.ones(slice_shape) * initial_value
- output = [o if o is not None else initial_value_slice for o in output]
- if op2 is not None:
- output = [op2(o) for o in output]
- output = [o.reshape(slice_shape) for o in output]
- return np.array(output)
-
- def _mean_cum_op(self, x, y):
- return (x[0] + y, x[1] + 1) if isinstance(x, tuple) else (x + y, 2)
-
- def _mean_reduce_op(self, x):
- return x[0] / x[1] if isinstance(x, tuple) else x
-
- def _sqrt_n_reduce_op(self, x):
- return x[0] / np.sqrt(x[1]) if isinstance(x, tuple) else x
-
-
-# Upstream class name: SegmentReductionOpTest
-class SegmentSumTest(SegmentReductionHelper):
-
- def testValues(self):
- dtypes = [
- dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int64,
- dtypes_lib.int32, dtypes_lib.complex64, dtypes_lib.complex128
- ]
-
- # Each item is np_op1, np_op2, tf_op
- ops_list = [(np.add, None, math_ops.segment_sum)]
-
- # A subset of ops has been enabled for complex numbers
- complex_ops_list = [(np.add, None, math_ops.segment_sum)]
-
- n = 10
- shape = [n, 2]
- indices = [i // 3 for i in range(n)]
- for dtype in dtypes:
- if dtype in (dtypes_lib.complex64, dtypes_lib.complex128):
- curr_ops_list = complex_ops_list
- else:
- curr_ops_list = ops_list
- for use_gpu in [True, False]:
- with self.cached_session(use_gpu=use_gpu):
- tf_x, np_x = self._input(shape, dtype=dtype)
- for np_op1, np_op2, tf_op in curr_ops_list:
- np_ans = self._segmentReduce(indices, np_x, np_op1, np_op2)
- s = tf_op(data=tf_x, segment_ids=indices)
- tf_ans = self.evaluate(s)
- self.assertAllClose(np_ans, tf_ans)
- # NOTE(mrry): The static shape inference that computes
- # `tf_ans.shape` can only infer that sizes from dimension 1
- # onwards, because the size of dimension 0 is data-dependent
- # and may therefore vary dynamically.
- self.assertAllEqual(np_ans.shape[1:], tf_ans.shape[1:])
-
- @test_util.run_deprecated_v1
- def testSegmentIdsShape(self):
- shape = [4, 4]
- tf_x, _ = self._input(shape)
- indices = constant_op.constant([0, 1, 2, 2], shape=[2, 2])
- with self.assertRaises(ValueError):
- math_ops.segment_sum(data=tf_x, segment_ids=indices)
-
- @test_util.run_deprecated_v1
- def testSegmentIdsSize(self):
- shape = [4, 4]
- for use_gpu in [True, False]:
- with self.cached_session(use_gpu=use_gpu):
- tf_x, _ = self._input(shape)
- indices = [0, 1]
- s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
- with self.assertRaisesOpError("segment_ids should be the same size"):
- self.evaluate(s)
-
- @test_util.run_deprecated_v1
- def testSegmentIdsValid(self):
- # This is a baseline for the following SegmentIdsInvalid* tests.
- shape = [4, 4]
- for use_gpu in [True, False]:
- with self.cached_session(use_gpu=use_gpu):
- tf_x, _ = self._input(shape, dtype=dtypes_lib.float32)
- indices = [0, 0, 0, 1]
- result = math_ops.segment_sum(data=tf_x, segment_ids=indices).eval()
- self.assertAllEqual([[15, 18, 21, 24], [13, 14, 15, 16]], result)
-
- def testSegmentIdsGreaterThanZero(self):
- shape = [4, 4]
- for use_gpu in [True, False]:
- with self.cached_session(use_gpu=use_gpu):
- tf_x, np_x = self._input(shape, dtype=dtypes_lib.float32)
- indices = [1, 1, 2, 2]
- np_ans = self._segmentReduce(indices, np_x, np.add)
- s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
- tf_ans = self.evaluate(s)
- self.assertAllClose(np_ans, tf_ans)
-
- def testSegmentIdsHole(self):
- shape = [4, 4]
- for use_gpu in [True, False]:
- with self.cached_session(use_gpu=use_gpu):
- tf_x, np_x = self._input(shape, dtype=dtypes_lib.float32)
- indices = [0, 0, 3, 3]
- np_ans = self._segmentReduce(indices, np_x, np.add)
- s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
- tf_ans = self.evaluate(s)
- self.assertAllClose(np_ans, tf_ans)
-
- @test_util.run_deprecated_v1
- def testSegmentIdsInvalid1(self):
- shape = [4, 4]
- with self.cached_session():
- tf_x, _ = self._input(shape)
- indices = [-1, -1, 0, 0]
- s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
- with self.assertRaisesOpError(
- r"Segment id -1 out of range \[0, 1\), possibly because "
- "'segment_ids' input is not sorted."):
- self.evaluate(s)
-
- @test_util.run_deprecated_v1
- def testSegmentIdsInvalid2(self):
- shape = [4, 4]
- with self.cached_session():
- tf_x, _ = self._input(shape)
- indices = [0, 1, 0, 1]
- s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
- with self.assertRaisesOpError("segment ids are not increasing"):
- self.evaluate(s)
-
- @test_util.run_deprecated_v1
- def testSegmentIdsInvalid3(self):
- shape = [4, 4]
- with self.cached_session():
- tf_x, _ = self._input(shape)
- indices = [0, 1, 2, 0]
- s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
- with self.assertRaisesOpError(
- r"Segment id 1 out of range \[0, 1\), possibly "
- "because 'segment_ids' input is not sorted."):
- self.evaluate(s)
-
- @test_util.run_deprecated_v1
- def testSegmentIdsInvalid4(self):
- shape = [4, 4]
- for use_gpu in [True, False]:
- with self.cached_session(use_gpu=use_gpu):
- tf_x, _ = self._input(shape, dtype=dtypes_lib.float32)
- indices = [0, 0, 0, -1]
- s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
- with self.assertRaisesOpError("segment ids must be >= 0"):
- self.evaluate(s)
-
- @test_util.run_deprecated_v1
- def testSegmentIdsInvalid5(self):
- shape = [4, 4]
- for use_gpu in [True, False]:
- with self.cached_session(use_gpu=use_gpu):
- tf_x, _ = self._input(shape, dtype=dtypes_lib.float32)
- indices = [0, 0, 0, -2]
- s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
- with self.assertRaisesOpError("segment ids must be >= 0"):
- self.evaluate(s)
-
- @test_util.run_deprecated_v1
- def testGradient(self):
- shape = [4, 4]
- indices = [0, 1, 2, 2]
- for tf_op in [
- math_ops.segment_sum]:
- with self.cached_session():
- tf_x, np_x = self._input(shape, dtype=dtypes_lib.float64)
- s = tf_op(data=tf_x, segment_ids=indices)
- jacob_t, jacob_n = gradient_checker.compute_gradient(
- tf_x,
- shape,
- s, [3, 4],
- x_init_value=np_x.astype(np.double),
- delta=1)
- self.assertAllClose(jacob_t, jacob_n)
-
- # Method removed because it only tests math_ops.segment_mean
- # def testDataInvalid(self):
- # ...
-
-
-# Upstream class name: UnsortedSegmentTest
-class UnsortedSegmentSumTest(SegmentReductionHelper):
-
- def __init__(self, methodName='runTest'):
- # Each item is np_op1, np_op2, tf_op, initial_value functor
- self.ops_list = [(np.add, None,
- math_ops.unsorted_segment_sum, lambda t: 0)]
-
- # A subset of ops has been enabled for complex numbers
- self.complex_ops_list = [(np.add, None,
- math_ops.unsorted_segment_sum, lambda t: 0)]
- self.differentiable_dtypes = [dtypes_lib.float16, dtypes_lib.float32,
- dtypes_lib.float64]
- self.all_dtypes = (self.differentiable_dtypes +
- [dtypes_lib.bfloat16,
- dtypes_lib.int64, dtypes_lib.int32,
- dtypes_lib.complex64, dtypes_lib.complex128])
- super(UnsortedSegmentSumTest, self).__init__(methodName=methodName)
-
- def testValues(self):
- indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3])
- num_segments = 12
- for indices in indices_flat, indices_flat.reshape(5, 2):
- shape = indices.shape + (2,)
- for dtype in self.all_dtypes:
- ops_list = self.complex_ops_list if dtype.is_complex else self.ops_list
- tf_x, np_x = self._input(shape, dtype=dtype)
- for use_gpu in [True, False]:
- with self.cached_session(use_gpu=True):
- for np_op1, np_op2, tf_op, init_op in ops_list:
- # sqrt_n doesn't support integers
- if (np_op2 == self._sqrt_n_reduce_op and dtype.is_integer):
- continue
- # todo(philjd): enable this test once real_div supports bfloat16
- if (np_op2 in [self._sqrt_n_reduce_op, self._mean_reduce_op] and
- dtype == dtypes_lib.bfloat16):
- continue
- np_ans = self._segmentReduce(
- indices, np_x, np_op1, np_op2, num_segments=num_segments,
- initial_value=init_op(dtype))
- s = tf_op(tf_x, segment_ids=indices, num_segments=num_segments)
- tf_ans = self.evaluate(s)
- if dtype is dtypes_lib.bfloat16:
- tf_ans = tf_ans.astype(np.float32)
- self.assertAllCloseAccordingToType(np_ans, tf_ans)
- self.assertShapeEqual(np_ans, s)
-
- def testNumSegmentsTypes(self):
- dtypes = [dtypes_lib.int32, dtypes_lib.int64]
- indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3])
- num_segments = 12
- for indices in indices_flat, indices_flat.reshape(5, 2):
- shape = indices.shape + (2,)
- for dtype in dtypes:
- with self.cached_session(use_gpu=True):
- tf_x, np_x = self._input(shape)
- num_segments_constant = constant_op.constant(
- num_segments, dtype=dtype)
- np_ans = self._segmentReduce(
- indices, np_x, np.add, op2=None, num_segments=num_segments)
- s = math_ops.unsorted_segment_sum(
- data=tf_x,
- segment_ids=indices,
- num_segments=num_segments_constant)
- tf_ans = self.evaluate(s)
- self.assertAllClose(np_ans, tf_ans)
- self.assertShapeEqual(np_ans, s)
-
- @test_util.run_deprecated_v1
- def testGradientsTFGradients(self):
- num_cols = 2
- indices_flat = np.array([0, 4, 0, -1, 3, -1, 4, 7, 7, 3])
- num_segments = max(indices_flat) + 3
- for dtype in self.differentiable_dtypes:
- ops_list = self.complex_ops_list if dtype.is_complex else self.ops_list
- for indices in indices_flat, indices_flat.reshape(5, 2):
- shape = indices.shape + (num_cols,)
- # test CPU and GPU as tf.gather behaves differently on each device
- for use_gpu in [False, True]:
- with self.cached_session(use_gpu=use_gpu):
- for _, _, tf_op, _ in ops_list:
- tf_x, np_x = self._input(shape, dtype=dtype)
- s = tf_op(tf_x, indices, num_segments)
- jacob_t, jacob_n = gradient_checker.compute_gradient(
- tf_x,
- shape,
- s, [num_segments, num_cols],
- x_init_value=np_x,
- delta=1.)
- self.assertAllCloseAccordingToType(jacob_t, jacob_n,
- half_atol=1e-2)
-
- def _computeGradient(self, tf_op, indices, num_segments,
- shape, num_cols, dtype):
- tf_x, np_x = self._input(shape, dtype=dtype)
- if context.executing_eagerly():
- def f(x):
- return tf_op(x, indices, num_segments)
-
- gradient_tape_jacob_t, jacob_n = gradient_checker_v2.compute_gradient(
- f, [tf_x], delta=1.0)
- self.assertAllClose(jacob_n, gradient_tape_jacob_t)
- else:
- with self.cached_session():
- s = tf_op(tf_x, indices, num_segments)
- jacob_t, jacob_n = gradient_checker.compute_gradient(
- tf_x,
- shape,
- s, [num_segments, num_cols],
- x_init_value=np_x,
- delta=1)
- self.assertAllClose(jacob_t, jacob_n)
-
- # This method has been enhanced to run on older versions of TensorFlow
- @test_util.run_in_graph_and_eager_modes
- def testGradientsGradientTape(self):
- num_cols = 2
- indices_flat = np.array([0, 4, 0, -1, 3, -1, 4, 7, 7, 3])
- num_segments = max(indices_flat) + 3
- for dtype in self.differentiable_dtypes:
- ops_list = self.complex_ops_list if dtype.is_complex else self.ops_list
- for indices in indices_flat, indices_flat.reshape(5, 2):
- shape = indices.shape + (num_cols,)
- # test CPU and GPU as tf.gather behaves differently on each device
- for use_gpu in [test_util.use_gpu, test_util.force_cpu]:
- with use_gpu():
- for _, _, tf_op, _ in ops_list:
- self._computeGradient(tf_op, indices, num_segments, shape,
- num_cols, dtype)
-
- # Method removed because it only tests math_ops.unsorted_segment_prod
- # def testProdGrad(self):
- # ...
-
- @test_util.run_deprecated_v1
- def testGradientMatchesSegmentSum(self):
- # Strategy: compute the gradient for UnsortedSegmentSum and SegmentSum
- # and compare the outputs, which should be identical.
- # NB: for this test to work, indices must be valid for SegmentSum, namely
- # it must be sorted, the indices must be contiguous, and num_segments
- # must be max(indices) + 1.
- indices = [0, 0, 1, 1, 1, 2, 3, 4, 5]
- n = len(indices)
- num_cols = 2
- shape = [n, num_cols]
- num_segments = max(indices) + 1
- for dtype in self.differentiable_dtypes:
- with self.cached_session(use_gpu=True):
- tf_x, np_x = self._input(shape, dtype=dtype)
- # Results from UnsortedSegmentSum
- unsorted_s = math_ops.unsorted_segment_sum(
- data=tf_x, segment_ids=indices, num_segments=num_segments)
- unsorted_jacob_t, unsorted_jacob_n = (
- gradient_checker.compute_gradient(tf_x, shape, unsorted_s,
- [num_segments, num_cols],
- x_init_value=np_x, delta=1))
-
- # Results from SegmentSum
- sorted_s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
- sorted_jacob_t, sorted_jacob_n = gradient_checker.compute_gradient(
- tf_x,
- shape,
- sorted_s, [num_segments, num_cols],
- x_init_value=np_x,
- delta=1)
- self.assertAllClose(unsorted_jacob_t, sorted_jacob_t)
- self.assertAllClose(unsorted_jacob_n, sorted_jacob_n)
-
- @test_util.run_deprecated_v1
- def testBadIndices(self):
- # Note: GPU kernel does not return the out-of-range error needed for this
- # test, so this test is marked as cpu-only.
- # Note: With PR #13055 a negative index will be ignored silently.
- with self.session(use_gpu=False):
- for bad in [[2]], [[7]]:
- unsorted = math_ops.unsorted_segment_sum([[17]], bad, num_segments=2)
- with self.assertRaisesOpError(
- r"segment_ids\[0,0\] = %d is out of range \[0, 2\)" % bad[0][0]):
- self.evaluate(unsorted)
-
- @test_util.run_deprecated_v1
- def testEmptySecondDimension(self):
- dtypes = [np.float16, np.float32, np.float64, np.int64, np.int32,
- np.complex64, np.complex128]
- with self.session(use_gpu=True):
- for dtype in dtypes:
- for itype in (np.int32, np.int64):
- data = np.zeros((2, 0), dtype=dtype)
- segment_ids = np.array([0, 1], dtype=itype)
- unsorted = math_ops.unsorted_segment_sum(data, segment_ids, 2)
- self.assertAllEqual(unsorted.eval(), np.zeros((2, 0), dtype=dtype))
-
- def testDropNegatives(self):
- # Note: the test is done by replacing segment_ids with 8 to -1
- # for index and replace values generated by numpy with 0.
- indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3])
- num_segments = 12
- for indices in indices_flat, indices_flat.reshape(5, 2):
- shape = indices.shape + (2,)
- for dtype in self.all_dtypes:
- with self.session(use_gpu=True):
- tf_x, np_x = self._input(shape, dtype=dtype)
- np_ans = self._segmentReduce(
- indices, np_x, np.add, op2=None, num_segments=num_segments)
- # Replace np_ans[8] with 0 for the value
- np_ans[8:] = 0
- # Replace 8 with -1 in indices
- np.place(indices, indices == 8, [-1])
- s = math_ops.unsorted_segment_sum(
- data=tf_x, segment_ids=indices, num_segments=num_segments)
- tf_ans = self.evaluate(s)
- self.assertAllClose(np_ans, tf_ans)
- self.assertShapeEqual(np_ans, s)
-
-
-class SegmentReductionOpBenchmark(test.Benchmark):
-
- outer_dim_options = [2**x for x in range(9, 14, 2)]
- ratio_options = [2**x for x in range(1, 6, 2)]
- inner_dim_options = [2**x for x in range(9, 14, 2)]
- # randomly generated sizes with less alignments
- inner_dim_options += [
- 1120, 1215, 1856, 1302, 1329, 1531, 1313, 1672, 1851, 1584
- ]
- dtype_options = [np.float32, np.float64]
- options = (outer_dim_options, ratio_options, inner_dim_options, dtype_options)
- # pylint: disable=g-long-lambda
- op_functors = [lambda vc, vs, seg_ids:
- ("sorted", math_ops.segment_sum(vc, vs)),
- lambda vc, vs, seg_ids:
- ("unsorted",
- math_ops.unsorted_segment_sum(vc, vs, seg_ids[-1]+1))]
- # pylint: enable=g-long-lambda
- repeat = 10
-
- def _npTypeToStr(self, t):
- if t == np.float32:
- return "fp32"
- if t == np.float64:
- return "fp64"
-
- def _runGraph(self, op_functor, outer_dim, ratio, inner_dim, dtype):
- output_outer_dim = int(outer_dim / ratio)
- const = np.random.randint(5, size=(outer_dim, inner_dim))
- seg_ids = np.sort(np.random.randint(output_outer_dim, size=outer_dim))
- vs = variables.Variable(seg_ids.astype(np.int32))
- with ops.device("/gpu:0"):
- vc = variables.Variable(const.astype(dtype))
- name, op = op_functor(vc, vs, seg_ids)
- with session.Session() as sess:
- variables.global_variables_initializer().run()
- r = self.run_op_benchmark(
- sess,
- op,
- min_iters=self.repeat,
- name="_".join(
- map(str,
- [name, outer_dim, ratio, inner_dim,
- self._npTypeToStr(dtype)])))
- return name, r["wall_time"]
-
- def benchmarkSegmentSumGPU(self):
- if not test.is_gpu_available(cuda_only=True):
- return
- for outer_dim, ratio, inner_dim, dtype in itertools.product(*self.options):
- op_functor = self.op_functors[0]
- with ops.Graph().as_default():
- self._runGraph(op_functor, outer_dim, ratio, inner_dim, dtype)
-
- def benchmarkUnsortedSegmentSumGPU(self):
- if not test.is_gpu_available(cuda_only=True):
- return
- for outer_dim, ratio, inner_dim, dtype in itertools.product(*self.options):
- op_functor = self.op_functors[1]
- with ops.Graph().as_default():
- self._runGraph(op_functor, outer_dim, ratio, inner_dim, dtype)
-
-
-class SegmentSumDeterministicTest(SegmentReductionHelper):
-
- def __init__(self, methodName='runTest'):
- # Each item is np_op1, np_op2, tf_op, initial_value functor
- self.ops_list = [(np.add, None,
- math_ops.segment_sum, lambda t: 0),
- (np.add, None,
- tf.math.segment_sum, lambda t: 0)]
-
- # A subset of ops has been enabled for complex numbers
- self.complex_ops_list = [(np.add, None,
- math_ops.segment_sum, lambda t: 0),
- (np.add, None,
- tf.math.segment_sum, lambda t: 0)]
-
- self.differentiable_dtypes = [dtypes_lib.float16, dtypes_lib.float32]
-
- self.all_dtypes = (self.differentiable_dtypes +
- [dtypes_lib.bfloat16,
- dtypes_lib.int64, dtypes_lib.int32,
- dtypes_lib.complex64])
- self.repeat_count = 5
- super(SegmentSumDeterministicTest,
- self).__init__(methodName=methodName)
-
- def _testForwardCase(self, dtype, indices, ops_list, shape):
- # have to use float to exec nond9m
- tf_x, _ = self._random_input(shape, dtype=dtype)
- # with utils.force_gpu_session(self):
- with self.session(use_gpu=True):
- for _, _, tf_op, _ in ops_list:
- run_ref = tf_op(data=tf_x, segment_ids=indices, name="tf_op_output")
- for i in range(self.repeat_count):
- self.assertAllEqual(tf_op(data=tf_x, segment_ids=indices), run_ref)
-
- def _testBackwardCase(self, dtype, indices, tf_op, shape):
- numpy_seed = 123
-
- def _randomDataOp(shape, data_type, seed):
- if seed is not None:
- np.random.seed(seed)
- return constant_op.constant(np.random.random_sample(shape),
- dtype=data_type)
-
- input_val = _randomDataOp(shape, dtype, seed=None)
- output_shape = [indices[-1]+1, shape[1]]
- if context.executing_eagerly():
- def op_gradients(local_seed):
- with backprop.GradientTape() as tape:
- tape.watch(input_val)
- op_output = tf_op(input_val, indices)
- upstream_gradients = _randomDataOp(output_shape, dtype, local_seed)
- gradient_injector_output = op_output * upstream_gradients
- return tape.gradient(gradient_injector_output, input_val)
-
- for i in range(self.repeat_count):
- local_seed = numpy_seed + i # select different upstream gradients
- result_a = op_gradients(local_seed)
- result_b = op_gradients(local_seed)
- self.assertAllEqual(result_a, result_b)
-
- else:
- op_output = tf_op(input_val, indices)
- upstream_gradients = array_ops.placeholder(dtype, shape=output_shape,
- name='upstream_gradients')
- gradient_injector_output = op_output * upstream_gradients
- op_gradients = gradients_impl.gradients(
- gradient_injector_output,
- input_val,
- grad_ys=None,
- colocate_gradients_with_ops=True)[0]
-
- for i in range(self.repeat_count):
- feed_dict = {upstream_gradients:np.random.random(output_shape)}
- result_a = op_gradients.eval(feed_dict=feed_dict)
- result_b = op_gradients.eval(feed_dict=feed_dict)
- self.assertAllEqual(result_a, result_b)
-
- @test_util.run_in_graph_and_eager_modes
- def testForward(self):
- num_cols = 8
- num_segments = 32
- segment_size = 256
-
- shape = [segment_size, num_cols]
- indices = np.random.randint(low=0, high=num_segments, size=(segment_size,))
- indices = np.sort(indices)
-
- for dtype in self.all_dtypes:
- ops_list = self.complex_ops_list if dtype.is_complex else self.ops_list
- self._testForwardCase(dtype, indices, ops_list, shape)
-
- # The backward operation is not known or expected to introduce nondeterminism
- # but we're testing it for completeness.
- @test_util.run_in_graph_and_eager_modes
- def testBackward(self):
- gradient_test = True
- num_cols = 8
- num_segments = 32
- segment_size = 256
- shape = [segment_size, num_cols]
- indices = np.random.randint(low=0, high=num_segments, size=(segment_size,))
- indices = np.sort(indices)
-
- with utils.force_gpu_session(self):
- # with self.session(force_gpu=True):#force_gpu=True leads to XLA issue
- for dtype in self.differentiable_dtypes:
- ops_list = self.complex_ops_list if dtype.is_complex else self.ops_list
- for _, _, tf_op, _ in ops_list:
- self._testBackwardCase(dtype, indices, tf_op, shape)
-
-
-class UnsortedSegmentSumDeterministicTest(SegmentReductionHelper):
-
- def __init__(self, methodName='runTest'):
- # Each item is np_op1, np_op2, tf_op, initial_value functor
- self.ops_list = [(np.add, None,
- math_ops.unsorted_segment_sum, lambda t: 0),
- (np.add, None,
- tf.math.unsorted_segment_sum, lambda t: 0)]
-
- # A subset of ops has been enabled for complex numbers
- self.complex_ops_list = [(np.add, None,
- math_ops.unsorted_segment_sum, lambda t: 0),
- (np.add, None,
- tf.math.unsorted_segment_sum, lambda t: 0)]
-
- self.differentiable_dtypes = [dtypes_lib.float16, dtypes_lib.float32]
- self.all_dtypes = (self.differentiable_dtypes +
- [dtypes_lib.bfloat16,
- dtypes_lib.int64, dtypes_lib.int32,
- dtypes_lib.complex64])
- self.repeat_count = 5
- super(
- UnsortedSegmentSumDeterministicTest, self).__init__(
- methodName=methodName)
-
- def _testForwardCase(self, dtype, indices, num_segments, num_cols, ops_list,
- shape):
- x, _ = self._random_input(shape, dtype=dtype)
- def forward(tf_op):
- s = tf_op(x, indices, num_segments)
- tf_ans = self.evaluate(s)
- return tf_ans
-
- # with utils.force_gpu_session(self):
- with self.session(use_gpu=True):
- for _, _, tf_op, _ in ops_list:
- run_ref = forward(tf_op)
- for i in range(self.repeat_count):
- self.assertAllEqual(forward(tf_op), run_ref)
-
- def _testBackwardCase(self, dtype, indices, num_segments, op_binding, shape):
- numpy_seed = 123
- _, _, tf_op, _ = op_binding
-
- def _randomDataOp(shape, data_type, seed):
- if seed is not None:
- np.random.seed(seed)
- return constant_op.constant(np.random.random_sample(shape),
- dtype=data_type)
-
- input_val = _randomDataOp(shape, dtype, seed=None)
-
- if context.executing_eagerly():
- def op_gradients(local_seed):
- with backprop.GradientTape() as tape:
- tape.watch(input_val)
- op_output = tf_op(input_val, indices, num_segments)
- upstream_gradients = _randomDataOp(op_output.shape, dtype, local_seed)
- gradient_injector_output = op_output * upstream_gradients
- return tape.gradient(gradient_injector_output, input_val)
-
- for i in range(self.repeat_count):
- local_seed = numpy_seed + i # select different upstream gradients
- result_a = op_gradients(local_seed)
- result_b = op_gradients(local_seed)
- self.assertAllEqual(result_a, result_b)
-
- else:
- op_output = tf_op(input_val, indices, num_segments)
- output_shape = op_output.shape
- upstream_gradients = array_ops.placeholder(dtype, shape=output_shape,
- name='upstream_gradients')
- gradient_injector_output = op_output * upstream_gradients
- op_gradients = gradients_impl.gradients(
- gradient_injector_output,
- input_val,
- grad_ys=None,
- colocate_gradients_with_ops=True)[0]
-
- for i in range(self.repeat_count):
- feed_dict = {upstream_gradients:np.random.random(output_shape)}
- result_a = op_gradients.eval(feed_dict=feed_dict)
- result_b = op_gradients.eval(feed_dict=feed_dict)
- self.assertAllEqual(result_a, result_b)
-
- @test_util.run_in_graph_and_eager_modes
- def testForward(self):
- num_cols = 2
- num_rows = 64
- num_segments = 64
- segment_size = num_cols * num_rows
- indices_flat = np.random.randint(low=-1, high=num_segments,
- size=(segment_size,))
-
- for dtype in self.all_dtypes:
- ops_list = self.complex_ops_list if dtype.is_complex else self.ops_list
- for indices in indices_flat, indices_flat.reshape(num_rows, num_cols):
- shape = indices.shape + (num_cols,)
- self._testForwardCase(
- dtype, indices, num_segments, num_cols, ops_list, shape)
-
- # The backward operation is not known or expected to introduce nondeterminism
- # but we're testing it for completeness.
- @test_util.run_in_graph_and_eager_modes
- def testBackward(self):
- num_cols = 2
- num_rows = 64
- num_segments = 64
- segment_size = num_cols * num_rows
- indices_flat = np.random.randint(low=-1, high=num_segments,
- size=(segment_size,))
-
- with utils.force_gpu_session(self):
- # with self.session(force_gpu=True):
- for dtype in self.differentiable_dtypes:
- ops_list = self.complex_ops_list if dtype.is_complex else self.ops_list
- for op_binding in ops_list:
- for indices in indices_flat, indices_flat.reshape(num_rows, num_cols):
- shape = indices.shape + (num_cols,)
- self._testBackwardCase(
- dtype, indices, num_segments, op_binding, shape)
-
-
-if __name__ == "__main__":
- fwd9m_tensorflow.enable_determinism()
- test.main()
diff --git a/test/test_patch_segment_sum.py b/test/test_patch_segment_sum.py
new file mode 100644
index 0000000..4d0af4a
--- /dev/null
+++ b/test/test_patch_segment_sum.py
@@ -0,0 +1,386 @@
+# Copyright 2020 NVIDIA Corporation. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ========================================================================
+
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functional tests for segment reduction ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import sys
+import warnings
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.client import session
+from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes as dtypes_lib
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradient_checker
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+from segment_reduction_helper import SegmentReductionHelper
+
+sys.path.insert(0, '..')
+import fwd9m.tensorflow as fwd9m_tensorflow
+import utils
+
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Simplifies logging
+
+# The tests in the following class were originally copied from
+# https://github.com/tensorflow/tensorflow/blob/1e9b9b1568d550e6779d2ddd5d193968254d3029/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
+# and were then enhanced.
+
+# NOTE: Op `gen_math_ops.segment_sum` has GPU kernels for the following data
+# types float16/32/64. The dynamic patch adopts a "super-accumulator" approach
+# which does the operation in higher precision with necessary pre-conversion
+# and post-conversion. Also note that integer operation generally has no issue
+# with the non-associativity of floating-point rounding errors. Therefore the
+# patch will not provide determinism for float64 or integer operands. For
+# bfloat16, no GPU kernel is available for TF version less than(and equal to)
+# 2.3. But it is likely that the patched ops will operate, in any given
+# configuration, faster using float32 on GPU than using bfloat16 on a CPU.
+# Therefore, we demonstrate a proof-of-concept for rapidly providing accelerated
+# GPU support in frameworks for new data formats before they are implemented
+# natively in hardware.
+
+# Upstream class name: SegmentReductionOpTest
+class SegmentSumTest(SegmentReductionHelper):
+
+ def testValues(self):
+ dtypes = [
+ dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int64,
+ dtypes_lib.int32, dtypes_lib.complex64, dtypes_lib.complex128
+ ]
+
+ # Each item is np_op1, np_op2, tf_op
+ ops_list = [(np.add, None, math_ops.segment_sum)]
+
+ # A subset of ops has been enabled for complex numbers
+ complex_ops_list = [(np.add, None, math_ops.segment_sum)]
+
+ n = 10
+ shape = [n, 2]
+ indices = [i // 3 for i in range(n)]
+ for dtype in dtypes:
+ if dtype in (dtypes_lib.complex64, dtypes_lib.complex128):
+ curr_ops_list = complex_ops_list
+ else:
+ curr_ops_list = ops_list
+ for use_gpu in [True, False]:
+ with self.cached_session(use_gpu=use_gpu):
+ tf_x, np_x = self._input(shape, dtype=dtype)
+ for np_op1, np_op2, tf_op in curr_ops_list:
+ np_ans = self._segmentReduce(indices, np_x, np_op1, np_op2)
+ s = tf_op(data=tf_x, segment_ids=indices)
+ tf_ans = self.evaluate(s)
+ self.assertAllClose(np_ans, tf_ans)
+ # NOTE(mrry): The static shape inference that computes
+ # `tf_ans.shape` can only infer that sizes from dimension 1
+ # onwards, because the size of dimension 0 is data-dependent
+ # and may therefore vary dynamically.
+ self.assertAllEqual(np_ans.shape[1:], tf_ans.shape[1:])
+
+ @test_util.run_deprecated_v1
+ def testSegmentIdsShape(self):
+ shape = [4, 4]
+ tf_x, _ = self._input(shape)
+ indices = constant_op.constant([0, 1, 2, 2], shape=[2, 2])
+ with self.assertRaises(ValueError):
+ math_ops.segment_sum(data=tf_x, segment_ids=indices)
+
+ @test_util.run_deprecated_v1
+ def testSegmentIdsSize(self):
+ shape = [4, 4]
+ for use_gpu in [True, False]:
+ with self.cached_session(use_gpu=use_gpu):
+ tf_x, _ = self._input(shape)
+ indices = [0, 1]
+ s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
+ with self.assertRaisesOpError("segment_ids should be the same size"):
+ self.evaluate(s)
+
+ @test_util.run_deprecated_v1
+ def testSegmentIdsValid(self):
+ # This is a baseline for the following SegmentIdsInvalid* tests.
+ shape = [4, 4]
+ for use_gpu in [True, False]:
+ with self.cached_session(use_gpu=use_gpu):
+ tf_x, _ = self._input(shape, dtype=dtypes_lib.float32)
+ indices = [0, 0, 0, 1]
+ result = math_ops.segment_sum(data=tf_x, segment_ids=indices).eval()
+ self.assertAllEqual([[15, 18, 21, 24], [13, 14, 15, 16]], result)
+
+ def testSegmentIdsGreaterThanZero(self):
+ shape = [4, 4]
+ for use_gpu in [True, False]:
+ with self.cached_session(use_gpu=use_gpu):
+ tf_x, np_x = self._input(shape, dtype=dtypes_lib.float32)
+ indices = [1, 1, 2, 2]
+ np_ans = self._segmentReduce(indices, np_x, np.add)
+ s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
+ tf_ans = self.evaluate(s)
+ self.assertAllClose(np_ans, tf_ans)
+
+ def testSegmentIdsHole(self):
+ shape = [4, 4]
+ for use_gpu in [True, False]:
+ with self.cached_session(use_gpu=use_gpu):
+ tf_x, np_x = self._input(shape, dtype=dtypes_lib.float32)
+ indices = [0, 0, 3, 3]
+ np_ans = self._segmentReduce(indices, np_x, np.add)
+ s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
+ tf_ans = self.evaluate(s)
+ self.assertAllClose(np_ans, tf_ans)
+
+ @test_util.run_deprecated_v1
+ def testSegmentIdsInvalid1(self):
+ shape = [4, 4]
+ with self.cached_session():
+ tf_x, _ = self._input(shape)
+ indices = [-1, -1, 0, 0]
+ s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
+ with self.assertRaisesOpError(
+ r"Segment id -1 out of range \[0, 1\), possibly because "
+ "'segment_ids' input is not sorted."):
+ self.evaluate(s)
+
+ @test_util.run_deprecated_v1
+ def testSegmentIdsInvalid2(self):
+ shape = [4, 4]
+ with self.cached_session():
+ tf_x, _ = self._input(shape)
+ indices = [0, 1, 0, 1]
+ s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
+ with self.assertRaisesOpError("segment ids are not increasing"):
+ self.evaluate(s)
+
+ @test_util.run_deprecated_v1
+ def testSegmentIdsInvalid3(self):
+ shape = [4, 4]
+ with self.cached_session():
+ tf_x, _ = self._input(shape)
+ indices = [0, 1, 2, 0]
+ s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
+ with self.assertRaisesOpError(
+ r"Segment id 1 out of range \[0, 1\), possibly "
+ "because 'segment_ids' input is not sorted."):
+ self.evaluate(s)
+
+ @test_util.run_deprecated_v1
+ def testSegmentIdsInvalid4(self):
+ shape = [4, 4]
+ for use_gpu in [True, False]:
+ with self.cached_session(use_gpu=use_gpu):
+ tf_x, _ = self._input(shape, dtype=dtypes_lib.float32)
+ indices = [0, 0, 0, -1]
+ s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
+ with self.assertRaisesOpError("segment ids must be >= 0"):
+ self.evaluate(s)
+
+ @test_util.run_deprecated_v1
+ def testSegmentIdsInvalid5(self):
+ shape = [4, 4]
+ for use_gpu in [True, False]:
+ with self.cached_session(use_gpu=use_gpu):
+ tf_x, _ = self._input(shape, dtype=dtypes_lib.float32)
+ indices = [0, 0, 0, -2]
+ s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
+ with self.assertRaisesOpError("segment ids must be >= 0"):
+ self.evaluate(s)
+
+ @test_util.run_deprecated_v1
+ def testGradient(self):
+ shape = [4, 4]
+ indices = [0, 1, 2, 2]
+ for tf_op in [
+ math_ops.segment_sum]:
+ with self.cached_session():
+ tf_x, np_x = self._input(shape, dtype=dtypes_lib.float64)
+ s = tf_op(data=tf_x, segment_ids=indices)
+ jacob_t, jacob_n = gradient_checker.compute_gradient(
+ tf_x,
+ shape,
+ s, [3, 4],
+ x_init_value=np_x.astype(np.double),
+ delta=1)
+ self.assertAllClose(jacob_t, jacob_n)
+
+ # Method removed because it only tests math_ops.segment_mean
+ # def testDataInvalid(self):
+ # ...
+
+
+class SegmentSumDeterministicTest(SegmentReductionHelper):
+
+ def __init__(self, methodName='runTest'):
+ # Each item is np_op1, np_op2, tf_op, initial_value functor
+ self.ops_list = [(np.add, None,
+ math_ops.segment_sum, lambda t: 0),
+ (np.add, None,
+ tf.math.segment_sum, lambda t: 0)]
+
+ # A subset of ops has been enabled for complex numbers
+ self.complex_ops_list = [(np.add, None,
+ math_ops.segment_sum, lambda t: 0),
+ (np.add, None,
+ tf.math.segment_sum, lambda t: 0)]
+
+ self.differentiable_dtypes = [dtypes_lib.float16, dtypes_lib.float32]
+
+ self.all_dtypes = (self.differentiable_dtypes + [dtypes_lib.bfloat16])
+ self.repeat_count = 5
+ super(SegmentSumDeterministicTest,
+ self).__init__(methodName=methodName)
+
+ def _testBackwardCase(self, dtype, indices, tf_op, shape):
+ numpy_seed = 123
+
+ input_val = self._randomDataOp(shape, dtype, seed=None)
+ output_shape = [indices[-1]+1, shape[1]]
+ if context.executing_eagerly():
+ def op_gradients(local_seed):
+ with backprop.GradientTape() as tape:
+ tape.watch(input_val)
+ op_output = tf_op(input_val, indices)
+ upstream_gradients = self._randomDataOp(output_shape, dtype, local_seed)
+ gradient_injector_output = op_output * upstream_gradients
+ return tape.gradient(gradient_injector_output, input_val)
+
+ for i in range(self.repeat_count):
+ local_seed = numpy_seed + i # select different upstream gradients
+ result_a = op_gradients(local_seed)
+ result_b = op_gradients(local_seed)
+ self.assertAllEqual(result_a, result_b)
+
+ else:
+ op_output = tf_op(input_val, indices)
+ upstream_gradients = array_ops.placeholder(dtype, shape=output_shape,
+ name='upstream_gradients')
+ gradient_injector_output = op_output * upstream_gradients
+ op_gradients = gradients_impl.gradients(
+ gradient_injector_output,
+ input_val,
+ grad_ys=None,
+ colocate_gradients_with_ops=True)[0]
+
+ for i in range(self.repeat_count):
+ feed_dict = {upstream_gradients:np.random.random(output_shape)}
+ result_a = op_gradients.eval(feed_dict=feed_dict)
+ result_b = op_gradients.eval(feed_dict=feed_dict)
+ self.assertAllEqual(result_a, result_b)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testForward(self):
+ num_cols = 8
+ num_segments = 32
+ segment_size = 256
+
+ shape = [segment_size, num_cols]
+ indices = np.random.randint(low=0, high=num_segments, size=(segment_size,))
+ indices = np.sort(indices)
+
+ with utils.force_gpu_session(self):
+ for dtype in self.all_dtypes:#(dtypes_lib.complex64,)
+ ops_list = self.complex_ops_list if dtype.is_complex \
+ else self.ops_list
+ tf_x, _ = self._random_input(shape, dtype=dtype)
+ # have to use float to exec nond9m
+ for _, _, tf_op, _ in ops_list:
+ for _ in range(self.repeat_count):
+ result_a = tf_op(data=tf_x, segment_ids=indices)
+ result_b = tf_op(data=tf_x, segment_ids=indices)
+ self.assertAllEqual(result_a, result_b)
+
+ # The backward operation is not known or expected to introduce nondeterminism
+ # but we're testing it for completeness.
+ @test_util.run_in_graph_and_eager_modes
+ def testBackward(self):
+ num_cols = 8
+ num_segments = 32
+ segment_size = 256
+ shape = [segment_size, num_cols]
+ indices = np.random.randint(low=0, high=num_segments, size=(segment_size,))
+ indices = np.sort(indices)
+
+ with utils.force_gpu_session(self):
+ # with self.session(force_gpu=True):#force_gpu=True leads to XLA issue
+ for dtype in self.differentiable_dtypes:
+ ops_list = self.complex_ops_list if dtype.is_complex \
+ else self.ops_list
+ for _, _, tf_op, _ in ops_list:
+ self._testBackwardCase(dtype, indices, tf_op, shape)
+
+ # Op `gen_math_ops.segment_sum()` is not patched for data type float64 on GPU.
+ # A warning will be thrown to indicate users float64 is still exposed to
+ # GPU-nondeterminism.
+ @test_util.run_in_graph_and_eager_modes
+ def testNonSupportedDataTypes(self):
+ shape = [10, 2]
+ indices = [i // 3 for i in range(10)]
+ non_supported_types = (dtypes_lib.float64,)
+ with utils.force_gpu_session(self):
+ for dtype in non_supported_types:
+ ops_list = self.complex_ops_list if dtype.is_complex \
+ else self.ops_list
+ tf_x, _ = self._input(shape, dtype)
+ for _, _, tf_op, _ in ops_list:
+ with warnings.catch_warnings(record=True) as w:
+ warnings.simplefilter("always")
+ s = tf_op(data=tf_x, segment_ids=indices)
+ self.evaluate(s)
+ self.assertEqual(len(w), 1)
+ self.assertIsInstance(w[0].message, UserWarning)
+ self.assertTrue("GPU-determinism" in str(w[-1].message))
+
+class SegmentReductionTestMisc(test.TestCase):
+
+ def testSDocstring(self):
+ op = tf.math.segment_sum
+ docstring = op.__doc__
+
+ if not docstring: # falsy (None or "")
+ self.fail("The patched op %s has no docstring" % op.__name__)
+ if docstring.startswith('ERROR'):
+ self.fail("The docstring for the patched op %s has not been assigned"
+ % op.__name__)
+
+
+if __name__ == "__main__":
+ fwd9m_tensorflow.enable_determinism()
+ test.main()
diff --git a/test/test_patch_softmax_xent.py b/test/test_patch_softmax_xent.py
new file mode 100644
index 0000000..5f157ea
--- /dev/null
+++ b/test/test_patch_softmax_xent.py
@@ -0,0 +1,505 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# Copyright 2021 NVIDIA Corporation. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ========================================================================
+"""Tests for SoftmaxCrossEntropyWithLogits op."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import itertools
+import sys
+
+import numpy as np
+import tensorflow as tf
+from tensorflow.python.client import session
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_nn_ops
+from tensorflow.python.ops import gradient_checker
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
+from tensorflow.python.platform import test
+
+sys.path.insert(0, '..')
+import fwd9m.tensorflow as fwd9m_tensorflow
+import utils
+
+# The tests in the following class were originally copied from
+# https://github.com/tensorflow/tensorflow/blob/b36436b087bd8e8701ef51718179037cccdfc26e/tensorflow/python/kernel_tests/xent_op_test.py
+# and were then enhanced.
+
+class XentTest(test.TestCase):
+
+ def _npXent(self, features, labels, dim=-1):
+ if dim == -1:
+ dim = len(features.shape) - 1
+ one_only_on_dim = list(features.shape)
+ one_only_on_dim[dim] = 1
+ e = np.exp(
+ features - np.reshape(np.amax(features, axis=dim), one_only_on_dim))
+ probs = e / np.reshape(np.sum(e, axis=dim), one_only_on_dim)
+ bp = (probs - labels)
+ l = -np.sum(labels * np.log(probs + 1.0e-20), axis=dim)
+ return l, bp
+
+ # TODO(b/123860949): The values are constant folded for XLA, so placeholders
+ # are needed.
+ def _testXent(self,
+ np_features,
+ np_labels,
+ use_gpu=False,
+ with_placeholders=False):
+ np_loss, np_backprop = self._npXent(np_features, np_labels)
+ with self.cached_session(use_gpu=use_gpu) as sess:
+ if with_placeholders:
+ features_placeholder = array_ops.placeholder(np_features.dtype)
+ labels_placeholder = array_ops.placeholder(np_labels.dtype)
+ loss, backprop = gen_nn_ops.softmax_cross_entropy_with_logits(
+ labels=labels_placeholder, features=features_placeholder)
+ tf_loss, tf_backprop = sess.run([loss, backprop],
+ feed_dict={
+ labels_placeholder: np_labels,
+ features_placeholder: np_features
+ })
+ else:
+ loss, backprop = gen_nn_ops.softmax_cross_entropy_with_logits(
+ np_features, np_labels)
+ tf_loss, tf_backprop = self.evaluate([loss, backprop])
+ self.assertAllCloseAccordingToType(np_loss, tf_loss, half_rtol=1e-2)
+ self.assertAllCloseAccordingToType(np_backprop, tf_backprop)
+
+ def _testXentWrapper(self, np_features, np_labels, dim=-1, use_gpu=False):
+ np_loss, _ = self._npXent(np_features, np_labels, dim=dim)
+ with self.cached_session(use_gpu=use_gpu) as sess:
+ loss = nn_ops.softmax_cross_entropy_with_logits(
+ labels=np_labels, logits=np_features, dim=dim)
+ tf_loss = self.evaluate(loss)
+ print("np_loss:", np_loss)
+ print("tf_loss:", tf_loss)
+ self.assertAllCloseAccordingToType(np_loss, tf_loss)
+
+ # TODO(b/123860949): The values are constant folded for XLA, so placeholders
+ # are needed.
+ def _testAll(self, features, labels, with_placeholders=False):
+ self._testXent(
+ features, labels, use_gpu=False, with_placeholders=with_placeholders)
+ self._testXent(
+ features, labels, use_gpu=True, with_placeholders=with_placeholders)
+
+ def _testSingleClass(self, use_gpu=False):
+ for dtype in np.float16, np.float32:
+ with self.cached_session(use_gpu=use_gpu) as sess:
+ loss, backprop = gen_nn_ops.softmax_cross_entropy_with_logits(
+ np.array([[1.], [-1.], [0.]]).astype(dtype),
+ np.array([[-1.], [0.], [1.]]).astype(dtype))
+ tf_loss, tf_backprop = self.evaluate([loss, backprop])
+ self.assertAllClose([0.0, 0.0, 0.0], tf_loss)
+ self.assertAllClose([[2.0], [1.0], [0.0]], tf_backprop)
+
+ def testSingleClass(self):
+ self._testSingleClass(True)
+ self._testSingleClass(False)
+
+ @test_util.run_deprecated_v1
+ def testRankTooLarge(self):
+ for dtype in np.float16, np.float32:
+ np_features = np.array([[[1., 1., 1., 1.]], [[1., 2., 3.,
+ 4.]]]).astype(dtype)
+ np_labels = np.array([[[0., 0., 0., 1.]], [[0., .5, .5,
+ 0.]]]).astype(dtype)
+ self.assertRaisesRegex(ValueError, "rank 2, but is rank 3",
+ gen_nn_ops.softmax_cross_entropy_with_logits,
+ np_features, np_labels)
+
+ def testNpXent(self):
+ # We create 2 batches of logits for testing.
+ # batch 0 is the boring uniform distribution: 1, 1, 1, 1, with target 3.
+ # batch 1 has a bit of difference: 1, 2, 3, 4, with soft targets (1, 2).
+ features = [[1., 1., 1., 1.], [1., 2., 3., 4.]]
+ labels = [[0., 0., 0., 1.], [0., .5, .5, 0.]]
+
+ # For batch 0, we expect the uniform distribution: 0.25, 0.25, 0.25, 0.25
+ # With a hard target 3, the backprop is [0.25, 0.25, 0.25, -0.75]
+ # The loss for this batch is -log(0.25) = 1.386
+ #
+ # For batch 1, we have:
+ # exp(0) = 1
+ # exp(1) = 2.718
+ # exp(2) = 7.389
+ # exp(3) = 20.085
+ # SUM = 31.192
+ # So we have as probabilities:
+ # exp(0) / SUM = 0.032
+ # exp(1) / SUM = 0.087
+ # exp(2) / SUM = 0.237
+ # exp(3) / SUM = 0.644
+ # With a soft target (1, 2), the backprop is
+ # [0.032, 0.087 - 0.5 = -0.413, 0.237 - 0.5 = -0.263, 0.644]
+ # The loss for this batch is [0.5 * -log(0.087), 0.5 * -log(0.237)]
+ # = [1.3862, 1.9401]
+ np_loss, np_backprop = self._npXent(np.array(features), np.array(labels))
+ self.assertAllClose(
+ np.array([[0.25, 0.25, 0.25, -0.75], [0.0321, -0.4129, -0.2632,
+ 0.6439]]),
+ np_backprop,
+ rtol=1.e-3,
+ atol=1.e-3)
+ self.assertAllClose(
+ np.array([1.3862, 1.9401]), np_loss, rtol=1.e-3, atol=1.e-3)
+
+ def testShapeBroadcast(self):
+ np_f = np.array([[1., 2., 3., 4.],
+ [1., 2., 3., 4.]]).astype(np.float32)
+ np_l = np.array([[0., 0., 0., 1.],
+ [0., .5, .5, 0.]]).astype(np.float32)
+ np_loss, np_backprop = self._npXent(np_f, np_l)
+ tf_f = constant_op.constant(
+ np.array([[1., 2., 3., 4.]]).astype(np.float32))
+ tf_l = constant_op.constant(
+ np.array([[0., 0., 0., 1.], [0., .5, .5, 0.]]).astype(np.float32))
+ for use_gpu in [False, True]:
+ with self.cached_session(use_gpu=use_gpu) as sess:
+ loss, backprop = gen_nn_ops.softmax_cross_entropy_with_logits(
+ tf_f, tf_l)
+ tf_loss, tf_backprop = self.evaluate([loss, backprop])
+ self.assertAllCloseAccordingToType(np_loss, tf_loss)
+ self.assertAllCloseAccordingToType(np_backprop, tf_backprop)
+
+ # TODO(b/123860949): The values are constant folded for XLA, so placeholders
+ # are needed.
+ @test_util.run_deprecated_v1
+ def testFeatureBroadcast(self):
+ self._testAll(
+ np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float16),
+ np.array([[0., 0., 0., 1.]]).astype(np.float16),
+ with_placeholders=True)
+ self._testAll(
+ np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float16),
+ np.array([[0.], [2.]]).astype(np.float16),
+ with_placeholders=True)
+
+ @test_util.run_deprecated_v1
+ def testShapeMismatch(self):
+ with self.cached_session():
+ with self.assertRaises(ValueError):
+ gen_nn_ops.softmax_cross_entropy_with_logits(
+ [[0., 1.], [2., 3.]], [[0., 1., 0.], [1., 0., 0.]])
+
+ @test_util.run_deprecated_v1
+ def testNotMatrix(self):
+ with self.cached_session():
+ with self.assertRaises(ValueError):
+ gen_nn_ops.softmax_cross_entropy_with_logits([0., 1., 2., 3.],
+ [0., 1., 0., 1.])
+
+ def testHalf(self):
+ self._testAll(
+ np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float16),
+ np.array([[0., 0., 0., 1.], [0., .5, .5, 0.]]).astype(np.float16))
+
+ def testFloat(self):
+ self._testAll(
+ np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float32),
+ np.array([[0., 0., 0., 1.], [0., .5, .5, 0.]]).astype(np.float32))
+
+ def testDouble(self):
+ self._testAll(
+ np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float64),
+ np.array([[0., 0., 0., 1.], [0., .5, .5, 0.]]).astype(np.float64))
+
+ def testLargeNegative(self):
+ np_features = np.array(
+ [[-1000., 0., 1000., 2000.], [1., 2., 3., 4.]]).astype(np.float32)
+ np_labels = np.array(
+ [[0., 0., 0., 1.], [0., .5, .5, 0.]]).astype(np.float32)
+
+ loss = nn_ops.softmax_cross_entropy_with_logits(
+ labels=np_labels, logits=np_features)
+ tf.debugging.check_numerics(
+ loss, "Nan in loss when logit has large negative Num")
+
+
+ @test_util.run_deprecated_v1
+ def testGradient(self):
+ with self.cached_session() as sess:
+ l = constant_op.constant(
+ [0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.0, 0.5],
+ shape=[3, 4],
+ dtype=dtypes.float64,
+ name="l")
+ f = constant_op.constant(
+ [0.1, 0.2, 0.3, 0.4, 0.1, 0.4, 0.9, 1.6, 0.1, 0.8, 2.7, 6.4],
+ shape=[3, 4],
+ dtype=dtypes.float64,
+ name="f")
+ x = nn_ops.softmax_cross_entropy_with_logits(
+ labels=l, logits=f, name="xent")
+ err = gradient_checker.compute_gradient_error(f, [3, 4], x, [3])
+
+ # Check that no extra computation performed. When only first derivative is requested,
+ # second derivative must not be computed. So when there is no second derivative,
+ # there is no `BatchMatMul` op in the graph.
+ op_names = [
+ op.op_def.name for op in sess.graph.get_operations() if op.op_def
+ ]
+ self.assertNotIn("BatchMatMul", op_names)
+ self.assertNotIn("BatchMatMulV2", op_names)
+
+ print("cross entropy gradient err = ", err)
+ self.assertLess(err, 5e-8)
+
+ @test_util.run_deprecated_v1
+ def testGradientLabelWithV2(self):
+ with self.cached_session():
+ l = constant_op.constant(
+ [0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.0, 0.5],
+ shape=[3, 4],
+ dtype=dtypes.float64,
+ name="l")
+ f = constant_op.constant(
+ [0.1, 0.2, 0.3, 0.4, 0.1, 0.4, 0.9, 1.6, 0.1, 0.8, 2.7, 6.4],
+ shape=[3, 4],
+ dtype=dtypes.float64,
+ name="f")
+ x = nn_ops.softmax_cross_entropy_with_logits_v2(
+ labels=l, logits=f, name="xent")
+ err = gradient_checker.compute_gradient_error(l, [3, 4], x, [3])
+
+ self.assertLess(err, 5e-8)
+
+ @test_util.run_deprecated_v1
+ def testSecondGradient(self):
+ with self.cached_session() as sess:
+ l = constant_op.constant(
+ [
+ 0.0, 0.0, 1.0 / 3, 0.0, 1.0 / 3, 0.0, 0.0, 0.0, 0.0, 0.5 / 3, 0.0,
+ 0.5 / 3
+ ],
+ shape=[12],
+ dtype=dtypes.float64,
+ name="l")
+ f = constant_op.constant(
+ [0.1, 0.2, 0.3, 0.4, 0.1, 0.4, 0.9, 1.6, 0.1, 0.8, 2.7, 6.4],
+ shape=[12],
+ dtype=dtypes.float64,
+ name="f")
+ x = nn_ops.softmax_cross_entropy_with_logits(
+ labels=l, logits=f, name="xent")
+ loss = math_ops.reduce_sum(x)
+
+ gradients = gradients_impl.gradients(loss, [f])[0]
+
+ err = gradient_checker.compute_gradient_error(f, [12], gradients, [12])
+
+ # Check that second derivative is calculated. is it important? Ian comment?
+ # (it is equivalent to being `BatchMatMul` op in the graph because of implementation of xentropy grad)
+ op_names = [
+ op.op_def.name for op in sess.graph.get_operations() if op.op_def
+ ]
+ # self.assertIn("BatchMatMulV2", op_names)
+
+ print("cross entropy hessian err = ", err)
+ self.assertLess(err, 5e-8)
+
+ def testWrapper(self):
+ features = np.array([[[1., 1., 1., 1.], [1., 2., 3., 4.]],
+ [[2., 3., 4., 5.], [6., 7., 8., 9.]],
+ [[5., 4., 3., 2.], [1., 2., 3., 4.]]]).astype(
+ np.float32)
+ labels = np.array([[[0., 0., 0., 1.], [0., 1., 0., 0.]],
+ [[0., 0.5, 0.5, 0.], [0.5, 0.5, 0., 0.]],
+ [[0., 1., 0., 0.], [0., 0., 1., 0.]]]).astype(
+ np.float32)
+ self._testXentWrapper(features, labels, dim=0, use_gpu=False)
+ self._testXentWrapper(features, labels, dim=0, use_gpu=True)
+ self._testXentWrapper(features, labels, dim=1, use_gpu=False)
+ self._testXentWrapper(features, labels, dim=1, use_gpu=True)
+ self._testXentWrapper(features, labels, dim=-1, use_gpu=False)
+ self._testXentWrapper(features, labels, dim=-1, use_gpu=True)
+
+ def testZeroDimension(self):
+ features = np.zeros([0, 2, 4]).astype(np.float32)
+ labels = np.zeros([0, 2, 4]).astype(np.float32)
+ np_loss, _ = self._npXent(features, labels)
+ with self.session(use_gpu=True) as sess:
+ loss = nn_ops.softmax_cross_entropy_with_logits(
+ labels=labels, logits=features)
+ tf_loss = self.evaluate(loss)
+ self.assertAllEqual(np_loss, tf_loss)
+
+
+class XentBenchmark(test.Benchmark):
+
+ def benchmarkZeroDimension(self):
+ for (m, n, p, use_gpu) in itertools.product(
+ [128],
+ [10, 100, 1000, 10000, 100000],
+ [0.001, 0.01, 0.5, 0.99, 1.0],
+ [False]):
+ k = int(p * n)
+ if k == 0:
+ continue
+ name = "zero_dimension_m_%d_n_%d_k_%g_use_gpu_%s" % (m, n, k, use_gpu)
+ device = "/%s:0" % ("gpu" if use_gpu else "cpu")
+ with ops.Graph().as_default():
+ with ops.device(device):
+ labels = array_ops.zeros([0, 2, 4], dtype=dtypes.float32)
+ logits = array_ops.zeros([0, 2, 4], dtype=dtypes.float32)
+ op = nn_ops.softmax_cross_entropy_with_logits(
+ labels=labels, logits=logits)
+ with session.Session() as sess:
+ r = self.run_op_benchmark(sess, op, min_iters=100, name=name)
+ gb_processed_input = m * n / 1.0e9
+ throughput = gb_processed_input / r["wall_time"]
+ print("Benchmark: %s \t wall_time: %0.03g s \t "
+ "Throughput: %0.03g GB/s" % (name, r["wall_time"], throughput))
+ sys.stdout.flush()
+
+ def benchmarkSingleClass(self):
+ for (m, n, p, use_gpu) in itertools.product(
+ [128],
+ [10, 100, 1000, 10000, 100000],
+ [0.001, 0.01, 0.5, 0.99, 1.0],
+ [False]):
+ k = int(p * n)
+ if k == 0:
+ continue
+ name = "single_class_m_%d_n_%d_k_%g_use_gpu_%s" % (m, n, k, use_gpu)
+ device = "/%s:0" % ("gpu" if use_gpu else "cpu")
+ with ops.Graph().as_default():
+ with ops.device(device):
+ labels = constant_op.constant([[1.], [-1.], [0.]],
+ dtype=dtypes.float32)
+ logits = constant_op.constant([[-1.], [0.], [1.]],
+ dtype=dtypes.float32)
+ op = nn_ops.softmax_cross_entropy_with_logits(
+ labels=labels, logits=logits)
+ with session.Session() as sess:
+ r = self.run_op_benchmark(sess, op, min_iters=100, name=name)
+ gb_processed_input = m * n / 1.0e9
+ throughput = gb_processed_input / r["wall_time"]
+ print("Benchmark: %s \t wall_time: %0.03g s \t "
+ "Throughput: %0.03g GB/s" % (name, r["wall_time"], throughput))
+ sys.stdout.flush()
+
+class SoftmaxXentDeterministicTest(tf.test.TestCase):
+
+ def _randomInts(self, shape, high, dtype):
+ return tf.constant(
+ np.random.randint(low=0, high=high, size=shape).astype(dtype))
+
+ def _randomFloats(self, shape, dtype, normalized_rows=False):
+ a = (2 * np.random.random_sample(shape) - 1).astype(dtype)
+
+ if normalized_rows:
+ def normalize(row):
+ return row / row.sum()
+ a = np.apply_along_axis(normalize, 1, a)
+
+ return tf.constant(a)
+
+ def gradients(self, seed, output_shape, output_dtype, labels, logits):
+ np.random.seed(seed)
+ upstream_gradients = self._randomFloats(output_shape, output_dtype)
+
+ with tf.GradientTape(persistent=True) as tape:
+ tape.watch(logits)
+ op_output = tf.nn.softmax_cross_entropy_with_logits(
+ labels=labels, logits=logits)
+ gradient_injector_output = op_output * upstream_gradients
+
+ return tape.gradient(gradient_injector_output, logits)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testForward(self):
+ batch_size = 1024
+ classes_count = 1000
+ logits_shape = (batch_size, classes_count)
+ logits_dtype = np.float32
+ logits = self._randomFloats(logits_shape, logits_dtype)
+
+ labels_shape = logits_shape
+ labels_dtype = logits_dtype
+ labels = self._randomFloats(labels_shape, labels_dtype,
+ normalized_rows=True)
+
+ with utils.force_gpu_session(self):
+ repeat_count = 5
+ for _ in range(repeat_count):
+ result_a = nn_ops.softmax_cross_entropy_with_logits(
+ labels=labels, logits=logits)
+ result_b = nn_ops.softmax_cross_entropy_with_logits(
+ labels=labels, logits=logits)
+ self.assertAllEqual(result_a, result_b)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testDistributionLabelsDeterministicGradients(self):
+ with utils.force_gpu_session(self):
+ batch_size = 1024
+ classes_count = 1000
+ logits_shape = (batch_size, classes_count)
+ logits_dtype = np.float32
+ logits = self._randomFloats(logits_shape, logits_dtype)
+
+ labels_shape = logits_shape
+ labels_dtype = logits_dtype
+ labels = self._randomFloats(labels_shape, labels_dtype,
+ normalized_rows=True)
+ output_shape = (batch_size)
+ output_dtype = logits_dtype
+
+ args = (output_shape, output_dtype, labels, logits)
+ repeat_count = 5
+ for seed in range(repeat_count):
+ result_a = self.gradients(seed, *args)
+ result_b = self.gradients(seed, *args)
+ self.assertAllEqual(result_a, result_b)
+
+class SoftmaxXentTestMisc(test.TestCase):
+
+ def testSDocstring(self):
+ op = tf.nn.softmax_cross_entropy_with_logits
+ docstring = op.__doc__
+ if not docstring: # falsy (None or "")
+ self.fail("The patched op %s has no docstring" % op.__name__)
+ if docstring.startswith('ERROR'):
+ self.fail("The docstring for the patched op %s has not been assigned"
+ % op.__name__)
+
+
+if __name__ == "__main__":
+ fwd9m_tensorflow.enable_determinism()
+ test.main()
+
diff --git a/test/test_patch_sparse_softmax_xent.py b/test/test_patch_sparse_softmax_xent.py
new file mode 100644
index 0000000..fc050f6
--- /dev/null
+++ b/test/test_patch_sparse_softmax_xent.py
@@ -0,0 +1,499 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# Copyright 2021 NVIDIA Corporation. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ========================================================================
+"""Tests for SparseSoftmaxCrossEntropyWithLogits op."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+import time
+import unittest
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.eager import context
+from tensorflow.python.eager import backprop as backprop_lib
+
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.client import device_lib
+from tensorflow.python.client import session
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors_impl
+from tensorflow.python.framework import ops as ops_lib
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_nn_ops
+from tensorflow.python.ops import gradient_checker
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import sparse_ops
+from tensorflow.python.platform import app
+from tensorflow.python.platform import test
+
+sys.path.insert(0, '..')
+import fwd9m.tensorflow as fwd9m_tensorflow
+import utils
+
+from fwd9m.utils import _Version as Version
+tf_version = Version(tf.version.VERSION)
+
+# The tests in the following class were originally copied from
+# https://github.com/tensorflow/tensorflow/blob/582c8d236cb079023657287c318ff26adb239002/tensorflow/python/kernel_tests/sparse_xent_op_test.py
+# and were then enhanced.
+
+class SparseXentTest(test.TestCase):
+
+ def _npXent(self, features, labels):
+ features = np.reshape(features, [-1, features.shape[-1]])
+ labels = np.reshape(labels, [-1])
+ batch_dim = 0
+ class_dim = 1
+ batch_size = features.shape[batch_dim]
+ e = np.exp(features - np.reshape(
+ np.amax(
+ features, axis=class_dim), [batch_size, 1]))
+ probs = e / np.reshape(np.sum(e, axis=class_dim), [batch_size, 1])
+ labels_mat = np.zeros_like(probs).astype(probs.dtype)
+ labels_mat[np.arange(batch_size), labels] = 1.0
+ bp = (probs - labels_mat)
+ l = -np.sum(labels_mat * np.log(probs + 1.0e-20), axis=1)
+ return l, bp
+
+ def _testXent(self, np_features, np_labels):
+ np_loss, np_backprop = self._npXent(np_features, np_labels)
+ with self.cached_session(use_gpu=True) as sess:
+ loss, backprop = gen_nn_ops.sparse_softmax_cross_entropy_with_logits(
+ np_features, np_labels)
+ tf_loss, tf_backprop = self.evaluate([loss, backprop])
+ self.assertAllCloseAccordingToType(np_loss, tf_loss)
+ self.assertAllCloseAccordingToType(np_backprop, tf_backprop)
+
+ def testSingleClass(self):
+ for label_dtype in np.int32, np.int64:
+ with self.cached_session(use_gpu=True) as sess:
+ loss, backprop = gen_nn_ops.sparse_softmax_cross_entropy_with_logits(
+ np.array([[1.], [-1.], [0.]]).astype(np.float32),
+ np.array([0, 0, 0]).astype(label_dtype))
+ tf_loss, tf_backprop = self.evaluate([loss, backprop])
+ self.assertAllClose([0.0, 0.0, 0.0], tf_loss)
+ self.assertAllClose([[0.0], [0.0], [0.0]], tf_backprop)
+
+ @test_util.run_deprecated_v1
+ @test_util.disable_xla("XLA cannot assert inside of a kernel.")
+ def testInvalidLabel(self):
+ features = [[1., 1., 1., 1.], [1., 1., 1., 1.], [1., 2., 3., 4.],
+ [1., 2., 3., 4.]]
+ labels = [4, 3, 0, -1]
+
+ if test.is_built_with_gpu_support() and utils.is_gpu_available_xla():
+ with self.session(use_gpu=True) as sess:
+ loss, backprop = (
+ gen_nn_ops.sparse_softmax_cross_entropy_with_logits(
+ features, labels))
+ tf_loss, tf_backprop = self.evaluate([loss, backprop])
+ self.assertAllClose(
+ [[np.nan] * 4, [0.25, 0.25, 0.25, -0.75],
+ [-0.968, 0.087, 0.237, 0.6439], [np.nan] * 4],
+ tf_backprop,
+ rtol=1e-3,
+ atol=1e-3)
+ self.assertAllClose(
+ [np.nan, 1.3862, 3.4420, np.nan], tf_loss, rtol=1e-3, atol=1e-3)
+
+ with self.session(use_gpu=False) as sess:
+ loss, backprop = (
+ gen_nn_ops.sparse_softmax_cross_entropy_with_logits(features, labels))
+ with self.assertRaisesOpError("Received a label value of"):
+ self.evaluate([loss, backprop])
+
+ def testNpXent(self):
+ # We create 2 batches of logits for testing.
+ # batch 0 is the boring uniform distribution: 1, 1, 1, 1, with target 3.
+ # batch 1 has a bit of difference: 1, 2, 3, 4, with target 0.
+ features = [[1., 1., 1., 1.], [1., 2., 3., 4.]]
+ labels = [3, 0]
+
+ # For batch 0, we expect the uniform distribution: 0.25, 0.25, 0.25, 0.25
+ # With a hard target 3, the backprop is [0.25, 0.25, 0.25, -0.75]
+ # The loss for this batch is -log(0.25) = 1.386
+ #
+ # For batch 1, we have:
+ # exp(0) = 1
+ # exp(1) = 2.718
+ # exp(2) = 7.389
+ # exp(3) = 20.085
+ # SUM = 31.192
+ # So we have as probabilities:
+ # exp(0) / SUM = 0.032
+ # exp(1) / SUM = 0.087
+ # exp(2) / SUM = 0.237
+ # exp(3) / SUM = 0.644
+ # With a hard 1, the backprop is [0.032 - 1.0 = -0.968, 0.087, 0.237, 0.644]
+ # The loss for this batch is [1.0 * -log(0.25), 1.0 * -log(0.032)]
+ # = [1.3862, 3.4420]
+ np_loss, np_backprop = self._npXent(np.array(features), np.array(labels))
+ self.assertAllClose(
+ np.array([[0.25, 0.25, 0.25, -0.75], [-0.968, 0.087, 0.237, 0.6439]]),
+ np_backprop,
+ rtol=1.e-3,
+ atol=1.e-3)
+ self.assertAllClose(
+ np.array([1.3862, 3.4420]), np_loss, rtol=1.e-3, atol=1.e-3)
+
+ def testShapeMismatch(self):
+ with self.session(use_gpu=True):
+ with self.assertRaisesRegexp(ValueError, ".*Rank mismatch:*"):
+ nn_ops.sparse_softmax_cross_entropy_with_logits(
+ labels=[[0, 2]], logits=[[0., 1.], [2., 3.], [2., 3.]])
+
+ def testScalar(self):
+ with self.session(use_gpu=True):
+ with self.assertRaisesRegexp(ValueError, ".*Logits cannot be scalars*"):
+ nn_ops.sparse_softmax_cross_entropy_with_logits(
+ labels=constant_op.constant(0), logits=constant_op.constant(1.0))
+
+ @test_util.run_deprecated_v1
+ def testLabelsPlaceholderScalar(self):
+ with self.session(use_gpu=True):
+ labels = array_ops.placeholder(np.int32)
+ # (Ian) Since `gen_nn_ops.*` has been overridden, the way exception is thrown
+ # has been changed.
+ with self.assertRaisesOpError("labels must be 1-D"):
+ y = nn_ops.sparse_softmax_cross_entropy_with_logits(
+ labels=labels, logits=[[7.]])
+ # y.eval(feed_dict={labels: 0})
+
+ def testVector(self):
+ with self.session(use_gpu=True):
+ loss = nn_ops.sparse_softmax_cross_entropy_with_logits(
+ labels=constant_op.constant(0), logits=constant_op.constant([1.0]))
+ self.assertAllClose(0.0, self.evaluate(loss))
+
+ def testFloat(self):
+ for label_dtype in np.int32, np.int64:
+ self._testXent(
+ np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float32),
+ np.array([3, 0]).astype(label_dtype))
+
+ def testDouble(self):
+ for label_dtype in np.int32, np.int64:
+ self._testXent(
+ np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float64),
+ np.array([0, 3]).astype(label_dtype))
+
+ def testHalf(self):
+ for label_dtype in np.int32, np.int64:
+ self._testXent(
+ np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float16),
+ np.array([3, 0]).astype(label_dtype))
+
+ def testEmpty(self):
+ self._testXent(np.zeros((0, 3)), np.zeros((0,), dtype=np.int32))
+
+ @test_util.run_deprecated_v1
+ def testGradient(self):
+ with self.session(use_gpu=True) as sess:
+ l = constant_op.constant([3, 0, 1], name="l")
+ f = constant_op.constant(
+ [0.1, 0.2, 0.3, 0.4, 0.1, 0.4, 0.9, 1.6, 0.1, 0.8, 2.7, 6.4],
+ shape=[3, 4],
+ dtype=dtypes.float64,
+ name="f")
+ x = nn_ops.sparse_softmax_cross_entropy_with_logits(
+ labels=l, logits=f, name="xent")
+ err = gradient_checker.compute_gradient_error(f, [3, 4], x, [3])
+
+ # Check that no extra computation performed. When only first derivative is
+ # requested, second derivative must not be computed. So when there is no
+ # second derivative, there is no `BatchMatMul` op in the graph.
+ op_names = [
+ op.op_def.name for op in sess.graph.get_operations() if op.op_def
+ ]
+ self.assertNotIn("BatchMatMul", op_names)
+ self.assertNotIn("BatchMatMulV2", op_names)
+
+ self.assertLess(err, 5e-8)
+
+ @unittest.skipIf(
+ tf_version.at_most('2.1'),
+ "Currently there is no way to take the second derivative of \
+ sparse_softmax_cross_entropy_with_logits due to the fused implementation's \
+ interaction with tf.gradients() ")
+ @test_util.run_deprecated_v1
+ def testSecondGradient(self):
+ with self.session() as sess:
+ l = constant_op.constant([3, 0, 1], name="l")
+ f = constant_op.constant(
+ [0.3, 0.4, 0.1, 1.2, 0.1, 1.9, 0.1, 0.7, 0.8, 0.2, 1.3, 1.3],
+ shape=[3, 4],
+ dtype=dtypes.float64,
+ name="f")
+ x = nn_ops.sparse_softmax_cross_entropy_with_logits(
+ labels=l, logits=f, name="xent")
+
+ gradients = gradients_impl.gradients(x, [f])[0]
+ err = gradient_checker.compute_gradient_error(f, [3, 4], gradients,
+ [3, 4])
+
+ # Check that second derivative is calculated.
+ # (it is equivalent to being `BatchMatMul` op in the graph because of
+ # implementation of xentropy grad)
+ op_names = [
+ op.op_def.name for op in sess.graph.get_operations() if op.op_def
+ ]
+ # self.assertIn("BatchMatMulV2", op_names)
+
+ self.assertLess(err, 5e-8)
+
+ @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+ def _testHighDim(self, features, labels):
+ np_loss, np_backprop = self._npXent(np.array(features), np.array(labels))
+ # manually reshape loss
+ np_loss = np.reshape(np_loss, np.array(labels).shape)
+ tf_loss = nn_ops.sparse_softmax_cross_entropy_with_logits(
+ labels=labels, logits=features)
+ if not context.executing_eagerly():
+ # (Ian) Since the deterministic solution has overrided
+ # `gen_nn_ops.sparse_softmax_cross_entropy_with_logits` which contains the
+ # gradients as the second output tensor. It is used in _CrossEntropyGrad()
+ # in nn_grad but not here. Not need to test here.
+ # https://github.com/tensorflow/tensorflow/blob/11659c3dcaffb5ccbaa464f2ef1f4bde7ed5c49f/tensorflow/python/ops/nn_grad.py#L544
+ # tf_backprop = tf_loss.op.inputs[0].op.outputs[1]
+ pass
+ else:
+ with backprop_lib.GradientTape() as tape:
+ features = constant_op.constant(features)
+ tape.watch(features)
+ tf_backprop = tape.gradient(
+ nn_ops.sparse_softmax_cross_entropy_with_logits(
+ labels=labels, logits=features), [features])[0]
+ tf_backprop = array_ops.reshape(tf_backprop, np_backprop.shape)
+
+ self.assertAllCloseAccordingToType(np_loss, tf_loss)
+ self.assertAllCloseAccordingToType(np_backprop, tf_backprop)
+
+ def testHighDim(self):
+ features = [[[1., 1., 1., 1.]], [[1., 2., 3., 4.]]]
+ labels = [[3], [0]]
+ self._testHighDim(features, labels)
+
+ def testHighDim2(self):
+ features = [[[1., 1., 1., 1.], [2., 2., 2., 2.]],
+ [[1., 2., 3., 4.], [5., 6., 7., 8.]]]
+ labels = [[3, 2], [0, 3]]
+ self._testHighDim(features, labels)
+
+ @test_util.run_deprecated_v1
+ def testScalarHandling(self):
+ with self.session(use_gpu=False) as sess:
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+ ".*labels must be 1-D.*"):
+ labels = array_ops.placeholder(dtypes.int32, shape=[None, 1])
+ logits = array_ops.placeholder(dtypes.float32, shape=[None, 3])
+ ce = nn_ops.sparse_softmax_cross_entropy_with_logits(
+ labels=array_ops.squeeze(labels), logits=logits)
+ labels_v2 = np.zeros((1, 1), dtype=np.int32)
+ logits_v2 = np.random.randn(1, 3)
+ sess.run([ce], feed_dict={labels: labels_v2, logits: logits_v2})
+
+def _sparse_vs_dense_xent_benchmark_dense(labels, logits):
+ labels = array_ops.identity(labels)
+ logits = array_ops.identity(logits)
+ with ops_lib.device("/cpu:0"): # Sparse-to-dense must be on CPU
+ batch_size = array_ops.shape(logits)[0]
+ num_entries = array_ops.shape(logits)[1]
+ length = batch_size * num_entries
+ labels += num_entries * math_ops.range(batch_size)
+ target = sparse_ops.sparse_to_dense(labels,
+ array_ops.stack([length]), 1.0, 0.0)
+ target = array_ops.reshape(target, array_ops.stack([-1, num_entries]))
+ crossent = nn_ops.softmax_cross_entropy_with_logits(
+ labels=target, logits=logits, name="SequenceLoss/CrossEntropy")
+ crossent_sum = math_ops.reduce_sum(crossent)
+ grads = gradients_impl.gradients([crossent_sum], [logits])[0]
+
+ return (crossent_sum, grads)
+
+def _sparse_vs_dense_xent_benchmark_sparse(labels, logits):
+ # Using sparse_softmax_cross_entropy_with_logits
+ labels = labels.astype(np.int64)
+ labels = array_ops.identity(labels)
+ logits = array_ops.identity(logits)
+ crossent = nn_ops.sparse_softmax_cross_entropy_with_logits(
+ logits, labels, name="SequenceLoss/CrossEntropy")
+ crossent_sum = math_ops.reduce_sum(crossent)
+ grads = gradients_impl.gradients([crossent_sum], [logits])[0]
+
+ return (crossent_sum, grads)
+
+def sparse_vs_dense_xent_benchmark(batch_size, num_entries, use_gpu):
+ config = config_pb2.ConfigProto()
+ config.allow_soft_placement = True
+ config.gpu_options.per_process_gpu_memory_fraction = 0.3
+ labels = np.random.randint(num_entries, size=batch_size).astype(np.int32)
+ logits = np.random.randn(batch_size, num_entries).astype(np.float32)
+
+ def _timer(sess, ops):
+ # Warm in
+ for _ in range(20):
+ sess.run(ops)
+
+ # Timing run
+ start = time.time()
+ for _ in range(20):
+ sess.run(ops)
+ end = time.time()
+
+ return (end - start) / 20.0 # Average runtime per iteration
+
+ # Using sparse_to_dense and softmax_cross_entropy_with_logits
+ with session.Session(config=config) as sess:
+ if not use_gpu:
+ with ops_lib.device("/cpu:0"):
+ ops = _sparse_vs_dense_xent_benchmark_dense(labels, logits)
+ else:
+ ops = _sparse_vs_dense_xent_benchmark_dense(labels, logits)
+ delta_dense = _timer(sess, ops)
+
+ # Using sparse_softmax_cross_entropy_with_logits
+ with session.Session(config=config) as sess:
+ if not use_gpu:
+ with test_util.device("/cpu:0"):
+ ops = _sparse_vs_dense_xent_benchmark_sparse(labels, logits)
+ else:
+ ops = _sparse_vs_dense_xent_benchmark_sparse(labels, logits)
+ delta_sparse = _timer(sess, ops)
+
+ print("%d \t %d \t %s \t %f \t %f \t %f" % (batch_size, num_entries, use_gpu,
+ delta_dense, delta_sparse,
+ delta_sparse / delta_dense))
+
+def main(_):
+ print("Sparse Xent vs. SparseToDense + Xent")
+ print("batch \t depth \t gpu \t dt(dense) \t dt(sparse) "
+ "\t dt(sparse)/dt(dense)")
+ for use_gpu in (False, True):
+ for batch_size in (32, 64, 128):
+ for num_entries in (100, 1000, 10000):
+ sparse_vs_dense_xent_benchmark(batch_size, num_entries, use_gpu)
+ sparse_vs_dense_xent_benchmark(32, 100000, use_gpu)
+ sparse_vs_dense_xent_benchmark(8, 1000000, use_gpu)
+
+class SparseSoftmaxXentDeterministicTest(tf.test.TestCase):
+
+ def _randomInts(self, shape, high, dtype):
+ return tf.constant(
+ np.random.randint(low=0, high=high, size=shape).astype(dtype))
+
+ def _randomFloats(self, shape, dtype, normalized_rows=False):
+ a = (2 * np.random.random_sample(shape) - 1).astype(dtype)
+
+ if normalized_rows:
+ def normalize(row):
+ return row / row.sum()
+ a = np.apply_along_axis(normalize, 1, a)
+
+ return tf.constant(a)
+
+ def gradients(self, seed, labels, logits):
+ np.random.seed(seed)
+ output_dtype = logits.dtype.as_numpy_dtype
+ output_shape = labels.shape
+ upstream_gradients = self._randomFloats(output_shape, output_dtype)
+
+ with tf.GradientTape(persistent=True) as tape:
+ tape.watch(logits)
+ op_output = tf.nn.sparse_softmax_cross_entropy_with_logits(
+ labels=labels, logits=logits)
+ gradient_injector_output = op_output * upstream_gradients
+
+ return tape.gradient(gradient_injector_output, logits)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testForward(self):
+ batch_size = 128
+ num_entries = 100000
+
+ labels = self._randomInts(batch_size, num_entries, np.int64)
+ logits = self._randomFloats((batch_size, num_entries), np.float32)
+
+ labels = array_ops.identity(labels)
+ logits = array_ops.identity(logits)
+
+ with utils.force_gpu_session(self):
+ repeat_count = 5
+ for _ in range(repeat_count):
+ result_a = nn_ops.sparse_softmax_cross_entropy_with_logits(
+ labels=labels, logits=logits)
+ result_b = nn_ops.sparse_softmax_cross_entropy_with_logits(
+ labels=labels, logits=logits)
+ self.assertAllEqual(result_a, result_b)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testDeterministicGradients(self):
+ batch_size = 128
+ num_entries = 100000
+
+ labels = self._randomInts(batch_size, num_entries, np.int64)
+ logits = self._randomFloats((batch_size, num_entries), np.float32)
+
+ labels = array_ops.identity(labels)
+ logits = array_ops.identity(logits)
+
+ with utils.force_gpu_session(self):
+ repeat_count = 5
+ for seed in range(repeat_count):
+ result_a = self.gradients(seed, labels, logits)
+ result_b = self.gradients(seed, labels, logits)
+ self.assertAllEqual(result_a, result_b)
+
+class SparseSoftmaxXentTestMisc(test.TestCase):
+
+ def testSDocstring(self):
+ op = tf.nn.sparse_softmax_cross_entropy_with_logits
+ docstring = op.__doc__
+ if not docstring: # falsy (None or "")
+ self.fail("The patched op %s has no docstring" % op.__name__)
+ if docstring.startswith('ERROR'):
+ self.fail("The docstring for the patched op %s has not been assigned"
+ % op.__name__)
+
+if __name__ == "__main__":
+ if "--benchmarks" in sys.argv:
+ sys.argv.remove("--benchmarks")
+ app.run()
+ else:
+ fwd9m_tensorflow.enable_determinism()
+ test.main()
+
diff --git a/test/test_patch_unsorted_segment_sum.py b/test/test_patch_unsorted_segment_sum.py
new file mode 100644
index 0000000..5de7516
--- /dev/null
+++ b/test/test_patch_unsorted_segment_sum.py
@@ -0,0 +1,449 @@
+# Copyright 2020 NVIDIA Corporation. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ========================================================================
+
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functional tests for unsorted segment reduction ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import sys
+import warnings
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.client import session
+from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes as dtypes_lib
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradient_checker
+from tensorflow.python.ops import gradient_checker_v2
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+from segment_reduction_helper import SegmentReductionHelper
+
+sys.path.insert(0, '..')
+import fwd9m.tensorflow as fwd9m_tensorflow
+import utils
+
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Simplifies logging
+
+# The tests in the following class were originally copied from
+# https://github.com/tensorflow/tensorflow/blob/1e9b9b1568d550e6779d2ddd5d193968254d3029/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
+# and were then enhanced.
+
+# NOTE: gen_math_ops.unsorted_segment_sum has GPU kernels for the following
+# data types, float16/32/64, complex64/128. The dynamic patch adopts a
+# "super-accumulator" approach which does the operation in higher precision with
+# necessary pre-conversion and post-conversion. Also note that integer operation
+# generally has no issue with the non-associativity of floating-point rounding
+# errors. Therefore the patch will not provide determinism for float64,
+# complex128 or integer operands. For bfloat16, no GPU kernel is available for
+# TF version less than(and equal to) 2.3. But it is likely that the patched ops
+# will operate, in any given configuration, faster using float32 on GPU than
+# using bfloat16 on a CPU. Therefore, we demonstrate a proof-of-concept for
+# rapidly providing accelerated GPU support in frameworks for new data formats
+# before they are implemented natively in hardware.
+
+# Upstream class name: UnsortedSegmentTest
+class UnsortedSegmentSumTest(SegmentReductionHelper):
+
+ def __init__(self, methodName='runTest'):
+ # Each item is np_op1, np_op2, tf_op, initial_value functor
+ self.ops_list = [(np.add, None,
+ math_ops.unsorted_segment_sum, lambda t: 0)]
+
+ # A subset of ops has been enabled for complex numbers
+ self.complex_ops_list = [(np.add, None,
+ math_ops.unsorted_segment_sum, lambda t: 0)]
+ self.differentiable_dtypes = [dtypes_lib.float16, dtypes_lib.float32,
+ dtypes_lib.float64]
+ self.all_dtypes = (self.differentiable_dtypes +
+ [dtypes_lib.bfloat16,
+ dtypes_lib.int64, dtypes_lib.int32,
+ dtypes_lib.complex64, dtypes_lib.complex128])
+ super(UnsortedSegmentSumTest, self).__init__(methodName=methodName)
+
+ def testValues(self):
+ indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3])
+ num_segments = 12
+ for indices in indices_flat, indices_flat.reshape(5, 2):
+ shape = indices.shape + (2,)
+ for dtype in self.all_dtypes:
+ ops_list = self.complex_ops_list if dtype.is_complex else self.ops_list
+ tf_x, np_x = self._input(shape, dtype=dtype)
+ for use_gpu in [True, False]:
+ with self.cached_session(use_gpu=True):
+ for np_op1, np_op2, tf_op, init_op in ops_list:
+ # sqrt_n doesn't support integers
+ if (np_op2 == self._sqrt_n_reduce_op and dtype.is_integer):
+ continue
+ # todo(philjd): enable this test once real_div supports bfloat16
+ if (np_op2 in [self._sqrt_n_reduce_op, self._mean_reduce_op] and
+ dtype == dtypes_lib.bfloat16):
+ continue
+ np_ans = self._segmentReduce(
+ indices, np_x, np_op1, np_op2, num_segments=num_segments,
+ initial_value=init_op(dtype))
+ s = tf_op(tf_x, segment_ids=indices, num_segments=num_segments)
+ tf_ans = self.evaluate(s)
+ if dtype is dtypes_lib.bfloat16:
+ tf_ans = tf_ans.astype(np.float32)
+ self.assertAllCloseAccordingToType(np_ans, tf_ans)
+ self.assertShapeEqual(np_ans, s)
+
+ def testNumSegmentsTypes(self):
+ dtypes = [dtypes_lib.int32, dtypes_lib.int64]
+ indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3])
+ num_segments = 12
+ for indices in indices_flat, indices_flat.reshape(5, 2):
+ shape = indices.shape + (2,)
+ for dtype in dtypes:
+ with self.cached_session(use_gpu=True):
+ tf_x, np_x = self._input(shape)
+ num_segments_constant = constant_op.constant(
+ num_segments, dtype=dtype)
+ np_ans = self._segmentReduce(
+ indices, np_x, np.add, op2=None, num_segments=num_segments)
+ s = math_ops.unsorted_segment_sum(
+ data=tf_x,
+ segment_ids=indices,
+ num_segments=num_segments_constant)
+ tf_ans = self.evaluate(s)
+ self.assertAllClose(np_ans, tf_ans)
+ self.assertShapeEqual(np_ans, s)
+
+ @test_util.run_deprecated_v1
+ def testGradientsTFGradients(self):
+ num_cols = 2
+ indices_flat = np.array([0, 4, 0, -1, 3, -1, 4, 7, 7, 3])
+ num_segments = max(indices_flat) + 3
+ for dtype in self.differentiable_dtypes:
+ ops_list = self.complex_ops_list if dtype.is_complex else self.ops_list
+ for indices in indices_flat, indices_flat.reshape(5, 2):
+ shape = indices.shape + (num_cols,)
+ # test CPU and GPU as tf.gather behaves differently on each device
+ for use_gpu in [False, True]:
+ with self.cached_session(use_gpu=use_gpu):
+ for _, _, tf_op, _ in ops_list:
+ tf_x, np_x = self._input(shape, dtype=dtype)
+ s = tf_op(tf_x, indices, num_segments)
+ jacob_t, jacob_n = gradient_checker.compute_gradient(
+ tf_x,
+ shape,
+ s, [num_segments, num_cols],
+ x_init_value=np_x,
+ delta=1.)
+ self.assertAllCloseAccordingToType(jacob_t, jacob_n,
+ half_atol=1e-2)
+
+ def _computeGradient(self, tf_op, indices, num_segments,
+ shape, num_cols, dtype):
+ tf_x, np_x = self._input(shape, dtype=dtype)
+ if context.executing_eagerly():
+ def f(x):
+ return tf_op(x, indices, num_segments)
+
+ gradient_tape_jacob_t, jacob_n = gradient_checker_v2.compute_gradient(
+ f, [tf_x], delta=1.0)
+ self.assertAllClose(jacob_n, gradient_tape_jacob_t)
+ else:
+ with self.cached_session():
+ s = tf_op(tf_x, indices, num_segments)
+ jacob_t, jacob_n = gradient_checker.compute_gradient(
+ tf_x,
+ shape,
+ s, [num_segments, num_cols],
+ x_init_value=np_x,
+ delta=1)
+ self.assertAllClose(jacob_t, jacob_n)
+
+ # This method has been enhanced to run on older versions of TensorFlow
+ @test_util.run_in_graph_and_eager_modes
+ def testGradientsGradientTape(self):
+ num_cols = 2
+ indices_flat = np.array([0, 4, 0, -1, 3, -1, 4, 7, 7, 3])
+ num_segments = max(indices_flat) + 3
+ for dtype in self.differentiable_dtypes:
+ ops_list = self.complex_ops_list if dtype.is_complex else self.ops_list
+ for indices in indices_flat, indices_flat.reshape(5, 2):
+ shape = indices.shape + (num_cols,)
+ # test CPU and GPU as tf.gather behaves differently on each device
+ # fwd9m note: the upstream test uses test_util.use_gpu, which seems to
+ # suffer from the same problem, and presumably does the same thing, as
+ # self.session(force_gpu=true). So we replaced test_util.use_gpu with
+ # utils.force_gpu_session(self).
+ for use_gpu in [utils.force_gpu_session(self), test_util.force_cpu()]:
+ with use_gpu:
+ # with utils.force_gpu_session(self):
+ for _, _, tf_op, _ in ops_list:
+ self._computeGradient(tf_op, indices, num_segments, shape,
+ num_cols, dtype)
+
+ # Method removed because it only tests math_ops.unsorted_segment_prod
+ # def testProdGrad(self):
+ # ...
+
+ @test_util.run_deprecated_v1
+ def testGradientMatchesSegmentSum(self):
+ # Strategy: compute the gradient for UnsortedSegmentSum and SegmentSum
+ # and compare the outputs, which should be identical.
+ # NB: for this test to work, indices must be valid for SegmentSum, namely
+ # it must be sorted, the indices must be contiguous, and num_segments
+ # must be max(indices) + 1.
+ indices = [0, 0, 1, 1, 1, 2, 3, 4, 5]
+ n = len(indices)
+ num_cols = 2
+ shape = [n, num_cols]
+ num_segments = max(indices) + 1
+ for dtype in self.differentiable_dtypes:
+ with self.cached_session(use_gpu=True):
+ tf_x, np_x = self._input(shape, dtype=dtype)
+ # Results from UnsortedSegmentSum
+ unsorted_s = math_ops.unsorted_segment_sum(
+ data=tf_x, segment_ids=indices, num_segments=num_segments)
+ unsorted_jacob_t, unsorted_jacob_n = (
+ gradient_checker.compute_gradient(tf_x, shape, unsorted_s,
+ [num_segments, num_cols],
+ x_init_value=np_x, delta=1))
+
+ # Results from SegmentSum
+ sorted_s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
+ sorted_jacob_t, sorted_jacob_n = gradient_checker.compute_gradient(
+ tf_x,
+ shape,
+ sorted_s, [num_segments, num_cols],
+ x_init_value=np_x,
+ delta=1)
+ self.assertAllClose(unsorted_jacob_t, sorted_jacob_t)
+ self.assertAllClose(unsorted_jacob_n, sorted_jacob_n)
+
+ @test_util.run_deprecated_v1
+ def testBadIndices(self):
+ # Note: GPU kernel does not return the out-of-range error needed for this
+ # test, so this test is marked as cpu-only.
+ # Note: With PR #13055 a negative index will be ignored silently.
+ with self.session(use_gpu=False):
+ for bad in [[2]], [[7]]:
+ unsorted = math_ops.unsorted_segment_sum([[17]], bad, num_segments=2)
+ with self.assertRaisesOpError(
+ r"segment_ids\[0,0\] = %d is out of range \[0, 2\)" % bad[0][0]):
+ self.evaluate(unsorted)
+
+ @test_util.run_deprecated_v1
+ def testEmptySecondDimension(self):
+ dtypes = [np.float16, np.float32, np.float64, np.int64, np.int32,
+ np.complex64, np.complex128]
+ with self.session(use_gpu=True):
+ for dtype in dtypes:
+ for itype in (np.int32, np.int64):
+ data = np.zeros((2, 0), dtype=dtype)
+ segment_ids = np.array([0, 1], dtype=itype)
+ unsorted = math_ops.unsorted_segment_sum(data, segment_ids, 2)
+ self.assertAllEqual(unsorted.eval(), np.zeros((2, 0), dtype=dtype))
+
+ def testDropNegatives(self):
+ # Note: the test is done by replacing segment_ids with 8 to -1
+ # for index and replace values generated by numpy with 0.
+ indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3])
+ num_segments = 12
+ for indices in indices_flat, indices_flat.reshape(5, 2):
+ shape = indices.shape + (2,)
+ for dtype in self.all_dtypes:
+ with self.session(use_gpu=True):
+ tf_x, np_x = self._input(shape, dtype=dtype)
+ np_ans = self._segmentReduce(
+ indices, np_x, np.add, op2=None, num_segments=num_segments)
+ # Replace np_ans[8] with 0 for the value
+ np_ans[8:] = 0
+ # Replace 8 with -1 in indices
+ np.place(indices, indices == 8, [-1])
+ s = math_ops.unsorted_segment_sum(
+ data=tf_x, segment_ids=indices, num_segments=num_segments)
+ tf_ans = self.evaluate(s)
+ self.assertAllClose(np_ans, tf_ans)
+ self.assertShapeEqual(np_ans, s)
+
+
+class UnsortedSegmentSumDeterministicTest(SegmentReductionHelper):
+
+ def __init__(self, methodName='runTest'):
+ # Each item is np_op1, np_op2, tf_op, initial_value functor
+ self.ops_list = [(np.add, None,
+ math_ops.unsorted_segment_sum, lambda t: 0),
+ (np.add, None,
+ tf.math.unsorted_segment_sum, lambda t: 0)]
+
+ # A subset of ops has been enabled for complex numbers
+ self.complex_ops_list = [(np.add, None,
+ math_ops.unsorted_segment_sum, lambda t: 0),
+ (np.add, None,
+ tf.math.unsorted_segment_sum, lambda t: 0)]
+
+ self.differentiable_dtypes = [dtypes_lib.float16, dtypes_lib.float32]
+
+ self.all_dtypes = (self.differentiable_dtypes +
+ [dtypes_lib.complex64, dtypes_lib.bfloat16])
+ self.repeat_count = 5
+ super(
+ UnsortedSegmentSumDeterministicTest, self).__init__(
+ methodName=methodName)
+
+ def _testBackwardCase(self, dtype, indices, num_segments, op_binding, shape):
+ numpy_seed = 123
+ _, _, tf_op, _ = op_binding
+
+ input_val = self._randomDataOp(shape, dtype, seed=None)
+
+ if context.executing_eagerly():
+ def op_gradients(local_seed):
+ with backprop.GradientTape() as tape:
+ tape.watch(input_val)
+ op_output = tf_op(input_val, indices, num_segments)
+ upstream_gradients = self._randomDataOp(op_output.shape,
+ dtype, local_seed)
+ gradient_injector_output = op_output * upstream_gradients
+ return tape.gradient(gradient_injector_output, input_val)
+
+ for i in range(self.repeat_count):
+ local_seed = numpy_seed + i # select different upstream gradients
+ result_a = op_gradients(local_seed)
+ result_b = op_gradients(local_seed)
+ self.assertAllEqual(result_a, result_b)
+
+ else:
+ op_output = tf_op(input_val, indices, num_segments)
+ output_shape = op_output.shape
+ upstream_gradients = array_ops.placeholder(dtype, shape=output_shape,
+ name='upstream_gradients')
+ gradient_injector_output = op_output * upstream_gradients
+ op_gradients = gradients_impl.gradients(
+ gradient_injector_output,
+ input_val,
+ grad_ys=None,
+ colocate_gradients_with_ops=True)[0]
+
+ for i in range(self.repeat_count):
+ feed_dict = {upstream_gradients:np.random.random(output_shape)}
+ result_a = op_gradients.eval(feed_dict=feed_dict)
+ result_b = op_gradients.eval(feed_dict=feed_dict)
+ self.assertAllEqual(result_a, result_b)
+
+
+ # The backward operation is not known or expected to introduce nondeterminism
+ # but we're testing it for completeness.
+ @test_util.run_in_graph_and_eager_modes
+ def testBackward(self):
+ num_cols = 2
+ num_rows = 64
+ num_segments = 64
+ segment_size = num_cols * num_rows
+ indices_flat = np.random.randint(low=-1, high=num_segments,
+ size=(segment_size,))
+
+ with utils.force_gpu_session(self):
+ for dtype in self.differentiable_dtypes:
+ for indices in indices_flat, indices_flat.reshape(num_rows, num_cols):
+ ops_list = self.complex_ops_list if dtype.is_complex \
+ else self.ops_list
+ for op_binding in ops_list:
+ shape = indices.shape + (num_cols,)
+ self._testBackwardCase(dtype, indices, num_segments,
+ op_binding, shape)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testForward(self):
+ num_cols = 2
+ num_rows = 64
+ num_segments = 64
+ segment_size = num_cols * num_rows
+ indices_flat = np.random.randint(low=-1, high=num_segments,
+ size=(segment_size,))
+ with utils.force_gpu_session(self):
+ for dtype in self.all_dtypes:
+ for indices in indices_flat, indices_flat.reshape(num_rows, num_cols):
+ shape = indices.shape + (num_cols,)
+ ops_list = self.complex_ops_list if dtype.is_complex else self.ops_list
+ x, _ = self._random_input(shape, dtype=dtype)
+
+ for _, _, tf_op, _ in ops_list:
+ for _ in range(self.repeat_count):
+ result_a = self.evaluate(tf_op(x, indices, num_segments))
+ result_b = self.evaluate(tf_op(x, indices, num_segments))
+ self.assertAllEqual(result_a, result_b)
+
+
+ # Op `gen_math_ops.unsorted_segment_sum()` is not patched for data type
+ # float64 and complex128 on GPU. A warning will be thrown to indicate users
+ # float64/complex128 is still exposed to GPU-nondeterminism.
+ @test_util.run_deprecated_v1
+ def testNonSupportedDataTypes(self):
+ non_supported_types = (dtypes_lib.float64, dtypes_lib.complex128)
+ indices = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3])
+ num_segments = 12
+ shape = indices.shape + (2,)
+ with utils.force_gpu_session(self):
+ for dtype in non_supported_types:
+ ops_list = self.complex_ops_list if dtype.is_complex \
+ else self.ops_list
+ tf_x, _ = self._input(shape, dtype)
+
+ for _, _, tf_op, _ in ops_list:
+ with warnings.catch_warnings(record=True) as w:
+ warnings.simplefilter("always")
+ s = tf_op(tf_x, indices, num_segments)
+ self.evaluate(s)
+ self.assertEqual(len(w), 1)
+ self.assertIsInstance(w[0].message, UserWarning)
+ self.assertTrue("GPU-determinism" in str(w[-1].message))
+
+
+class SegmentReductionTestMisc(test.TestCase):
+
+ def testSDocstring(self):
+ op = tf.math.unsorted_segment_sum
+ docstring = op.__doc__
+
+ if not docstring: # falsy (None or "")
+ self.fail("The patched op %s has no docstring" % op.__name__)
+ if docstring.startswith('ERROR'):
+ self.fail("The docstring for the patched op %s has not been assigned"
+ % op.__name__)
+
+if __name__ == "__main__":
+ fwd9m_tensorflow.enable_determinism()
+ test.main()
diff --git a/test/utils.py b/test/utils.py
index 6eb1498..fd0f3bd 100644
--- a/test/utils.py
+++ b/test/utils.py
@@ -1,6 +1,9 @@
import tensorflow as tf
+from tensorflow.python.platform import test
+
from fwd9m.utils import _Version as Version
+
# Notes about force_gpu_session:
#
# In TF1.15 and TF2.0, an apparent bug in tf.test.TestCase::session prevents us
@@ -61,3 +64,17 @@ def force_gpu_session(test_object):
return test_object.session(use_gpu=True)
else:
return test_object.session(force_gpu=True)
+
+def is_gpu_available_xla():
+ tf_version = Version(tf.version.VERSION)
+ if tf_version.in_list(['1.15', '2.0']):
+ print("WARNING:"
+ "an exception will not be thrown if there is no GPU present.")
+ gpus = tf.config.experimental.list_physical_devices('GPU')
+ if len(gpus)>0:
+ return True
+ else:
+ print("WARNING: no GPU present.")
+ return False
+ else:
+ return test.is_gpu_available()
\ No newline at end of file