Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
snnn committed Jan 17, 2024
1 parent bc219ed commit b6aa1d4
Showing 1 changed file with 63 additions and 64 deletions.
127 changes: 63 additions & 64 deletions onnxruntime/test/providers/cpu/model_tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@

#include <iostream>
#include <iterator>
#include <string>
#include <codecvt>
#include <locale>
#include <filesystem>
#include <utility>
#include <unordered_map>
#include <gtest/gtest.h>

#include "core/session/onnxruntime_c_api.h"
Expand All @@ -15,9 +21,6 @@
#include <core/platform/path_lib.h>
#include "default_providers.h"
#include "test/onnx/TestCase.h"
#include <string>
#include <codecvt>
#include <locale>

#ifdef USE_DNNL
#include "core/providers/dnnl/dnnl_provider_factory.h"
Expand Down Expand Up @@ -361,46 +364,46 @@ TEST_P(ModelTest, Run) {
}

using ORT_STRING_VIEW = std::basic_string_view<ORTCHAR_T>;
static ORT_STRING_VIEW opset7 = ORT_TSTR("opset7");
static ORT_STRING_VIEW opset8 = ORT_TSTR("opset8");
static ORT_STRING_VIEW opset9 = ORT_TSTR("opset9");
static ORT_STRING_VIEW opset10 = ORT_TSTR("opset10");
static ORT_STRING_VIEW opset11 = ORT_TSTR("opset11");
static ORT_STRING_VIEW opset12 = ORT_TSTR("opset12");
static ORT_STRING_VIEW opset13 = ORT_TSTR("opset13");
static ORT_STRING_VIEW opset14 = ORT_TSTR("opset14");
static ORT_STRING_VIEW opset15 = ORT_TSTR("opset15");
static ORT_STRING_VIEW opset16 = ORT_TSTR("opset16");
static ORT_STRING_VIEW opset17 = ORT_TSTR("opset17");
static ORT_STRING_VIEW opset18 = ORT_TSTR("opset18");
static constexpr ORT_STRING_VIEW opset7 = ORT_TSTR("opset7");
static constexpr ORT_STRING_VIEW opset8 = ORT_TSTR("opset8");
static constexpr ORT_STRING_VIEW opset9 = ORT_TSTR("opset9");
static constexpr ORT_STRING_VIEW opset10 = ORT_TSTR("opset10");
static constexpr ORT_STRING_VIEW opset11 = ORT_TSTR("opset11");
static constexpr ORT_STRING_VIEW opset12 = ORT_TSTR("opset12");
static constexpr ORT_STRING_VIEW opset13 = ORT_TSTR("opset13");
static constexpr ORT_STRING_VIEW opset14 = ORT_TSTR("opset14");
static constexpr ORT_STRING_VIEW opset15 = ORT_TSTR("opset15");
static constexpr ORT_STRING_VIEW opset16 = ORT_TSTR("opset16");
static constexpr ORT_STRING_VIEW opset17 = ORT_TSTR("opset17");
static constexpr ORT_STRING_VIEW opset18 = ORT_TSTR("opset18");
// TODO: enable opset19 tests
// static ORT_STRING_VIEW opset19 = ORT_TSTR("opset19");
// static constexpr ORT_STRING_VIEW opset19 = ORT_TSTR("opset19");

static ORT_STRING_VIEW provider_name_cpu = ORT_TSTR("cpu");
static ORT_STRING_VIEW provider_name_tensorrt = ORT_TSTR("tensorrt");
static constexpr ORT_STRING_VIEW provider_name_cpu = ORT_TSTR("cpu");
static constexpr ORT_STRING_VIEW provider_name_tensorrt = ORT_TSTR("tensorrt");
#ifdef USE_MIGRAPHX
static ORT_STRING_VIEW provider_name_migraphx = ORT_TSTR("migraphx");
static constexpr ORT_STRING_VIEW provider_name_migraphx = ORT_TSTR("migraphx");
#endif
static ORT_STRING_VIEW provider_name_openvino = ORT_TSTR("openvino");
static ORT_STRING_VIEW provider_name_cuda = ORT_TSTR("cuda");
static constexpr ORT_STRING_VIEW provider_name_openvino = ORT_TSTR("openvino");
static constexpr ORT_STRING_VIEW provider_name_cuda = ORT_TSTR("cuda");
#ifdef USE_ROCM
static ORT_STRING_VIEW provider_name_rocm = ORT_TSTR("rocm");
static constexpr ORT_STRING_VIEW provider_name_rocm = ORT_TSTR("rocm");
#endif
static ORT_STRING_VIEW provider_name_dnnl = ORT_TSTR("dnnl");
static constexpr ORT_STRING_VIEW provider_name_dnnl = ORT_TSTR("dnnl");
// For any non-Android system, NNAPI will only be used for ort model converter
#if defined(USE_NNAPI) && defined(__ANDROID__)
static ORT_STRING_VIEW provider_name_nnapi = ORT_TSTR("nnapi");
static constexpr ORT_STRING_VIEW provider_name_nnapi = ORT_TSTR("nnapi");
#endif
#ifdef USE_RKNPU
static ORT_STRING_VIEW provider_name_rknpu = ORT_TSTR("rknpu");
static constexpr ORT_STRING_VIEW provider_name_rknpu = ORT_TSTR("rknpu");
#endif
#ifdef USE_ACL
static ORT_STRING_VIEW provider_name_acl = ORT_TSTR("acl");
static constexpr ORT_STRING_VIEW provider_name_acl = ORT_TSTR("acl");
#endif
#ifdef USE_ARMNN
static ORT_STRING_VIEW provider_name_armnn = ORT_TSTR("armnn");
static constexpr ORT_STRING_VIEW provider_name_armnn = ORT_TSTR("armnn");
#endif
static ORT_STRING_VIEW provider_name_dml = ORT_TSTR("dml");
static constexpr ORT_STRING_VIEW provider_name_dml = ORT_TSTR("dml");

::std::vector<::std::basic_string<ORTCHAR_T>> GetParameterStrings() {
// Map key is provider name(CPU, CUDA, etc). Value is the ONNX node tests' opsets to run.
Expand Down Expand Up @@ -598,7 +601,7 @@ ::std::vector<::std::basic_string<ORTCHAR_T>> GetParameterStrings() {
ORT_TSTR("SSD"), // needs to run symbolic shape inference shape first
ORT_TSTR("size") // INVALID_ARGUMENT: Cannot find binding of given name: x
};
std::vector<std::basic_string<ORTCHAR_T>> paths;
std::vector<std::filesystem::path> paths;

for (std::pair<ORT_STRING_VIEW, std::vector<ORT_STRING_VIEW>> kvp : provider_names) {
// Setup ONNX node tests. The test data is preloaded on our CI build machines.
Expand Down Expand Up @@ -627,7 +630,7 @@ ::std::vector<::std::basic_string<ORTCHAR_T>> GetParameterStrings() {
}
#endif

ORT_STRING_VIEW provider_name = kvp.first;
const ORT_STRING_VIEW provider_name = kvp.first;
std::unordered_set<std::basic_string<ORTCHAR_T>> all_disabled_tests(std::begin(immutable_broken_tests),
std::end(immutable_broken_tests));
if (provider_name == provider_name_cuda) {
Expand Down Expand Up @@ -682,45 +685,41 @@ ::std::vector<::std::basic_string<ORTCHAR_T>> GetParameterStrings() {
all_disabled_tests.insert(ORT_TSTR("fp16_tiny_yolov2"));

while (!paths.empty()) {
std::basic_string<ORTCHAR_T> node_data_root_path = paths.back();
std::filesystem::path node_data_root_path = paths.back();
paths.pop_back();
std::basic_string<ORTCHAR_T> my_dir_name = GetLastComponent(node_data_root_path);
ORT_TRY {
LoopDir(node_data_root_path, [&](const ORTCHAR_T* filename, OrtFileType f_type) -> bool {
if (filename[0] == ORT_TSTR('.'))
return true;
if (f_type == OrtFileType::TYPE_DIR) {
std::basic_string<PATH_CHAR_TYPE> p = ConcatPathComponent(node_data_root_path, filename);
paths.push_back(p);
return true;
}
std::basic_string<PATH_CHAR_TYPE> filename_str = filename;
if (!HasExtensionOf(filename_str, ORT_TSTR("onnx")))
return true;
std::basic_string<PATH_CHAR_TYPE> test_case_name = my_dir_name;
if (test_case_name.compare(0, 5, ORT_TSTR("test_")) == 0)
test_case_name = test_case_name.substr(5);
if (all_disabled_tests.find(test_case_name) != all_disabled_tests.end())
return true;
if (!std::filesystem::exists(node_data_root_path) || !std::filesystem::is_directory(node_data_root_path)) {
continue;
}
for (auto const& dir_entry : std::filesystem::directory_iterator(node_data_root_path)) {
if (dir_entry.is_directory()) {
paths.push_back(dir_entry.path());
continue;
}
const std::filesystem::path& path = dir_entry.path();
if (!path.filename().has_extension()) {
continue;
}
if (path.filename().extension().compare(ORT_TSTR(".onnx")) != 0) continue;
std::basic_string<PATH_CHAR_TYPE> test_case_name = path.parent_path().filename().string();
if (test_case_name.compare(0, 5, ORT_TSTR("test_")) == 0)
test_case_name = test_case_name.substr(5);
if (all_disabled_tests.find(test_case_name) != all_disabled_tests.end())
continue;

#ifdef DISABLE_ML_OPS
auto starts_with = [](const std::basic_string<PATH_CHAR_TYPE>& find_in,
const std::basic_string<PATH_CHAR_TYPE>& find_what) {
return find_in.compare(0, find_what.size(), find_what) == 0;
};
if (starts_with(test_case_name, ORT_TSTR("XGBoost_")) || starts_with(test_case_name, ORT_TSTR("coreml_")) ||
starts_with(test_case_name, ORT_TSTR("scikit_")) || starts_with(test_case_name, ORT_TSTR("libsvm_"))) {
return true;
}
auto starts_with = [](const std::basic_string<PATH_CHAR_TYPE>& find_in,
const std::basic_string<PATH_CHAR_TYPE>& find_what) {
return find_in.compare(0, find_what.size(), find_what) == 0;
};
if (starts_with(test_case_name, ORT_TSTR("XGBoost_")) || starts_with(test_case_name, ORT_TSTR("coreml_")) ||
starts_with(test_case_name, ORT_TSTR("scikit_")) || starts_with(test_case_name, ORT_TSTR("libsvm_"))) {
continue;
}
#endif
std::basic_ostringstream<PATH_CHAR_TYPE> oss;
oss << provider_name << ORT_TSTR("_") << ConcatPathComponent(node_data_root_path, filename_str);
v.emplace_back(oss.str());
return true;
});
std::basic_ostringstream<PATH_CHAR_TYPE> oss;
oss << provider_name << ORT_TSTR("_") << path.string();
v.emplace_back(oss.str());
}
ORT_CATCH(const std::exception&) {
} // ignore non-exist dir
}
}
return v;
Expand Down

0 comments on commit b6aa1d4

Please sign in to comment.