Skip to content

Commit

Permalink
Jiminha/gatherndindex (#445)
Browse files Browse the repository at this point in the history
  • Loading branch information
jiminha authored and sayantan-nervana committed Jan 27, 2020
1 parent 45611ed commit 658b7b4
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 5 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ Once TensorFlow's dependencies are installed, clone the `ngraph-bridge` repo:

git clone https://github.com/tensorflow/ngraph-bridge.git
cd ngraph-bridge
git checkout v0.19.0-rc9
git checkout v0.19.0-rc10

Run the following Python script to build TensorFlow, nGraph, and the bridge. Use Python 3.5:

Expand Down
5 changes: 3 additions & 2 deletions ngraph_bridge/ngraph_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2239,7 +2239,8 @@ static Status TranslateGatherNdOp(const Node* op,

auto ng_params_shape = ng_params->get_shape();
size_t ng_params_rank = ng_params_shape.size();
size_t ng_indices_rank = ng_indices->get_shape().size();
auto ng_indices_shape = ng_indices->get_shape();
size_t ng_indices_rank = ng_indices_shape.size();

for (size_t i = 0; i < ng_params_rank; i++) {
if (ng_params_shape[i] == 0) {
Expand All @@ -2250,7 +2251,7 @@ static Status TranslateGatherNdOp(const Node* op,
}
}

if ((ng_indices_rank - 1) > ng_params_rank) {
if ((ng_indices_shape[ng_indices_rank - 1]) > ng_params_rank) {
return errors::InvalidArgument(
"The last dimension of indices can be at most the rank of params");
}
Expand Down
2 changes: 1 addition & 1 deletion ngraph_bridge/version.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
// candidate such as v0.7.0-rc0
// The code in master will always have the last released version number
// with a suffix of '-master'
#define NG_TF_VERSION_SUFFIX "-rc9"
#define NG_TF_VERSION_SUFFIX "-rc10"

#define VERSION_STR_HELPER(x) #x
#define VERSION_STR(x) VERSION_STR_HELPER(x)
Expand Down
2 changes: 1 addition & 1 deletion python/setup.in.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def get_tag(self):

setup(
name='ngraph_tensorflow_bridge',
version='0.19.0rc9',
version='0.19.0rc10',
description='Intel nGraph compiler and runtime for TensorFlow',
long_description=long_description,
long_description_content_type="text/markdown",
Expand Down
46 changes: 46 additions & 0 deletions test/python/test_gathernd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# ==============================================================================
# Copyright 2018-2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""nGraph TensorFlow bridge gather_nd operation test
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import pytest

import tensorflow as tf
import os
import numpy as np

from common import NgraphTest


class TestGatherNDOperations(NgraphTest):

def test_gather_nd(self):
val = tf.placeholder(tf.float32, shape=(5, 10))
indices = np.zeros([1, 3, 3, 1], dtype=np.int32)
out = tf.gather_nd(val, indices, batch_dims=0, name='output')

def run_test(sess):
return sess.run((out,),
feed_dict={val: np.arange(50).reshape([5, 10])})[0]

self.with_ngraph(run_test)

assert (
self.with_ngraph(run_test) == self.without_ngraph(run_test)).all()

0 comments on commit 658b7b4

Please sign in to comment.