diff --git a/ngraph_bridge/ngraph_builder.cc b/ngraph_bridge/ngraph_builder.cc
index 7c7f7c8fe..d015139a5 100644
--- a/ngraph_bridge/ngraph_builder.cc
+++ b/ngraph_bridge/ngraph_builder.cc
@@ -247,6 +247,9 @@ static Status TensorDataToVector(const Tensor& tensor, std::vector<T>* vector) {
   // Else we have to convert.
   else {
     switch (dt) {
+      case DT_HALF:
+        ConvertTensorDataToVector<Eigen::half, T>(tensor, vector);
+        break;
       case DT_FLOAT:
         ConvertTensorDataToVector<float, T>(tensor, vector);
         break;
@@ -350,7 +353,9 @@ Builder::TF_NGRAPH_CONST_MAP() {
       {DataType::DT_UINT16, make_pair(MakeConstOp<uint16>, ng::element::u16)},
       {DataType::DT_BOOL,
-       make_pair(MakeConstOp<bool>, ng::element::boolean)}};
+       make_pair(MakeConstOp<bool>, ng::element::boolean)},
+      {DataType::DT_HALF,
+       make_pair(MakeConstOp<Eigen::half>, ng::element::f16)}};

   return the_map;
 }
diff --git a/ngraph_bridge/ngraph_mark_for_clustering.cc b/ngraph_bridge/ngraph_mark_for_clustering.cc
index fb1cf2bb6..8ab779822 100644
--- a/ngraph_bridge/ngraph_mark_for_clustering.cc
+++ b/ngraph_bridge/ngraph_mark_for_clustering.cc
@@ -77,9 +77,7 @@ static Status TypeConstraintOk(Node* node,
     for (const auto& name_and_set : itr->second) {
       auto& type_attr_name = name_and_set.first;
       auto& allowed_types = name_and_set.second;
-
       DataType dt;
-
       if (GetNodeAttr(node->attrs(), type_attr_name, &dt) != Status::OK() ||
           std::find(allowed_types.begin(), allowed_types.end(), dt) ==
               allowed_types.end()) {
@@ -566,6 +564,7 @@ const TypeConstraintMap& GetTypeConstraintMap() {
   type_constraint_map["NonMaxSuppressionV4"]["T"] = {
       DT_FLOAT};  // TF allows half too
   type_constraint_map["OneHot"]["T"] = NGraphDTypes();
+  type_constraint_map["OneHot"]["TI"] = NGraphIndexDTypes();
   type_constraint_map["Pack"]["T"] = NGraphDTypes();
   type_constraint_map["Pad"]["T"] = NGraphDTypes();
   type_constraint_map["Pad"]["Tpaddings"] = NGraphIndexDTypes();
diff --git a/ngraph_bridge/ngraph_utils.cc b/ngraph_bridge/ngraph_utils.cc
index d1aeb3975..1accaa7ac 100644
--- a/ngraph_bridge/ngraph_utils.cc
+++ b/ngraph_bridge/ngraph_utils.cc
@@ -293,23 +293,23 @@ void print_node_histogram(const std::unordered_map<string, int>& histogram,

 const gtl::ArraySlice<DataType>& NGraphDTypes() {
   static gtl::ArraySlice<DataType> result{
-      DT_FLOAT, DT_DOUBLE, DT_INT8,   DT_INT16,  DT_INT32,
-      DT_INT64, DT_UINT8,  DT_UINT16, DT_UINT32, DT_UINT64,
-      DT_BOOL,  DT_QINT8,  DT_QUINT8, DT_BFLOAT16};
+      DT_FLOAT, DT_DOUBLE, DT_INT8,   DT_INT16,    DT_INT32,
+      DT_INT64, DT_UINT8,  DT_UINT16, DT_UINT32,   DT_UINT64,
+      DT_BOOL,  DT_QINT8,  DT_QUINT8, DT_BFLOAT16, DT_HALF};
   return result;
 }

 const gtl::ArraySlice<DataType>& NGraphNumericDTypes() {
   static gtl::ArraySlice<DataType> result{
-      DT_FLOAT, DT_DOUBLE, DT_INT8,   DT_INT16,  DT_INT32,  DT_INT64,
-      DT_UINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_BFLOAT16};
+      DT_FLOAT, DT_DOUBLE, DT_INT8,   DT_INT16,  DT_INT32,    DT_INT64,
+      DT_UINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_BFLOAT16, DT_HALF};
   return result;
 }

 const gtl::ArraySlice<DataType>& NGraphNumericAndQuantizedDTypes() {
   static gtl::ArraySlice<DataType> result{
-      DT_FLOAT, DT_DOUBLE, DT_INT8,   DT_INT16,  DT_INT32, DT_INT64,
-      DT_UINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_QINT8, DT_QUINT8};
+      DT_FLOAT,  DT_DOUBLE, DT_INT8,   DT_INT16, DT_INT32,  DT_INT64, DT_UINT8,
+      DT_UINT16, DT_UINT32, DT_UINT64, DT_QINT8, DT_QUINT8, DT_HALF};
   return result;
 }

@@ -330,7 +330,8 @@ const gtl::ArraySlice<DataType>& NGraphSupportedQuantizedDTypes() {
 }

 const gtl::ArraySlice<DataType>& NGraphRealDTypes() {
-  static gtl::ArraySlice<DataType> result{DT_FLOAT, DT_DOUBLE, DT_BFLOAT16};
+  static gtl::ArraySlice<DataType> result{DT_FLOAT, DT_DOUBLE, DT_BFLOAT16,
+                                          DT_HALF};
   return result;
 }

diff --git a/ngraph_bridge/ngraph_utils.h b/ngraph_bridge/ngraph_utils.h
index c43102d91..d0facaab1 100644
--- a/ngraph_bridge/ngraph_utils.h
+++ b/ngraph_bridge/ngraph_utils.h
@@ -97,7 +97,9 @@ Status ValuesFromConstNode(const NodeDef& node,
     return errors::InvalidArgument("Node not a Const");
   }

-  if (node.attr().at("dtype").type() != DataTypeToEnum<T>::value) {
+  auto dt = node.attr().at("dtype").type();
+  if (node.attr().count("dtype") == 0 || node.attr().count("value") == 0 ||
+      (dt != DT_HALF && dt != DataTypeToEnum<T>::value)) {
     std::stringstream ss;
     ss << "Invalid data type defined for Const. Defined: "
        << node.attr().at("dtype").type();
@@ -151,25 +153,29 @@ Status ValuesFromConstNode(const NodeDef& node,
      switch (dt) {
        // TODO(amprocte/NGRAPH-2502): there are more element types to support
        // here
+        case DT_HALF:
+          val_size = tensor.half_val_size();
+          if (val_size > 0) val_i = static_cast<T>(tensor.half_val()[i]);
+          break;
        case DT_INT32:
          val_size = tensor.int_val_size();
-          if (val_size > 0) val_i = tensor.int_val()[i];
+          if (val_size > 0) val_i = static_cast<T>(tensor.int_val()[i]);
          break;
        case DT_INT64:
          val_size = tensor.int64_val_size();
-          if (val_size > 0) val_i = tensor.int64_val()[i];
+          if (val_size > 0) val_i = static_cast<T>(tensor.int64_val()[i]);
          break;
        case DT_FLOAT:
          val_size = tensor.float_val_size();
-          if (val_size > 0) val_i = tensor.float_val()[i];
+          if (val_size > 0) val_i = static_cast<T>(tensor.float_val()[i]);
          break;
        case DT_BOOL:
          val_size = tensor.bool_val_size();
-          if (val_size > 0) val_i = tensor.bool_val()[i];
+          if (val_size > 0) val_i = static_cast<T>(tensor.bool_val()[i]);
          break;
        case DT_DOUBLE:
          val_size = tensor.double_val_size();
-          if (val_size > 0) val_i = tensor.double_val()[i];
+          if (val_size > 0) val_i = static_cast<T>(tensor.double_val()[i]);
          break;
        default:
          NGRAPH_VLOG(0)
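Note on the DT_HALF handling above: the new branch in TensorDataToVector copies Eigen::half tensor data into a std::vector of the requested element type by widening each value. The sketch below is not part of the patch; the helper name HalfTensorToFloatVector is illustrative only, and it assumes a TensorFlow build environment. It shows the same kind of half-to-float widening for the common case.

#include <cstdint>
#include <vector>

#include "tensorflow/core/framework/tensor.h"

// Hypothetical helper (not in ngraph-bridge) illustrating how a DT_HALF
// tensor's elements are widened when materialized as a float vector.
void HalfTensorToFloatVector(const tensorflow::Tensor& tensor,
                             std::vector<float>* out) {
  // flat<Eigen::half>() exposes the DT_HALF buffer element by element.
  auto data = tensor.flat<Eigen::half>();
  out->resize(data.size());
  for (int64_t i = 0; i < data.size(); ++i) {
    // Each 16-bit value is widened to float via Eigen::half's conversion.
    (*out)[i] = static_cast<float>(data(i));
  }
}

Eigen::half is the element type TensorFlow uses for DT_HALF, so converting through it preserves IEEE fp16 semantics rather than reinterpreting the raw 16-bit pattern.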