Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
snnn committed Jan 18, 2024
1 parent b6aa1d4 commit ed3125f
Show file tree
Hide file tree
Showing 9 changed files with 84 additions and 180 deletions.
132 changes: 68 additions & 64 deletions onnxruntime/test/onnx/TestCase.cc
Original file line number Diff line number Diff line change
Expand Up @@ -267,12 +267,12 @@ void LoopDataFile(int test_data_pb_fd, bool is_input, const TestModelInfo& model
} // namespace

#if !defined(ORT_MINIMAL_BUILD)
std::unique_ptr<TestModelInfo> TestModelInfo::LoadOnnxModel(_In_ const PATH_CHAR_TYPE* model_url) {
std::unique_ptr<TestModelInfo> TestModelInfo::LoadOnnxModel(const std::filesystem::path& model_url) {
return std::make_unique<OnnxModelInfo>(model_url);
}
#endif

std::unique_ptr<TestModelInfo> TestModelInfo::LoadOrtModel(_In_ const PATH_CHAR_TYPE* model_url) {
std::unique_ptr<TestModelInfo> TestModelInfo::LoadOrtModel(const std::filesystem::path& model_url) {
return std::make_unique<OnnxModelInfo>(model_url, true);
}

Expand All @@ -290,7 +290,7 @@ class OnnxTestCase : public ITestCase {
mutable std::vector<std::string> debuginfo_strings_;
mutable onnxruntime::OrtMutex m_;

std::vector<std::basic_string<PATH_CHAR_TYPE>> test_data_dirs_;
std::vector<std::filesystem::path> test_data_dirs_;

std::string GetDatasetDebugInfoString(size_t dataset_id) const override {
std::lock_guard<OrtMutex> l(m_);
Expand Down Expand Up @@ -343,7 +343,7 @@ class OnnxTestCase : public ITestCase {

size_t GetDataCount() const override { return test_data_dirs_.size(); }
const std::string& GetNodeName() const override { return model_info_->GetNodeName(); }
const PATH_CHAR_TYPE* GetModelUrl() const override { return model_info_->GetModelUrl(); }
const std::filesystem::path& GetModelUrl() const override { return model_info_->GetModelUrl(); }
const std::string& GetTestCaseName() const override { return test_case_name_; }
std::string GetTestCaseVersion() const override { return model_info_->GetNominalOpsetVersion(); }

Expand Down Expand Up @@ -396,7 +396,14 @@ static std::string trim_str(const std::string& in) {
return s;
}

static bool read_config_file(const std::basic_string<PATH_CHAR_TYPE>& path, std::map<std::string, std::string>& fc) {
/**
* @brief Read a text file that each line is a key value pair separated by ':'
* @param path File path
* @param fc output key value pairs
* @return True, success. False, the file doesn't exist or could be read.
*/
static bool ReadConfigFile(const std::filesystem::path& path, std::map<std::string, std::string>& fc) {
if (!std::filesystem::exists(path)) return false;
std::ifstream infile(path);
if (!infile.good()) {
return false;
Expand Down Expand Up @@ -474,10 +481,10 @@ void OnnxTestCase::LoadTestData(size_t id, onnxruntime::test::HeapBuffer& b,
ORT_THROW("index out of bound");
}

PATH_STRING_TYPE test_data_pb = ConcatPathComponent(
test_data_dirs_[id], (is_input ? ORT_TSTR("inputs.pb") : ORT_TSTR("outputs.pb")));
std::filesystem::path test_data_pb =
test_data_dirs_[id] / (is_input ? ORT_TSTR("inputs.pb") : ORT_TSTR("outputs.pb"));
int test_data_pb_fd;
auto st = Env::Default().FileOpenRd(test_data_pb, test_data_pb_fd);
auto st = Env::Default().FileOpenRd(test_data_pb.string(), test_data_pb_fd);
if (st.IsOK()) { // has an all-in-one input file
std::ostringstream oss;
{
Expand Down Expand Up @@ -505,20 +512,24 @@ void OnnxTestCase::LoadTestData(size_t id, onnxruntime::test::HeapBuffer& b,
std::vector<PATH_STRING_TYPE> test_data_pb_files;

const PATH_STRING_TYPE& dir_path = test_data_dirs_[id];
LoopDir(dir_path,
[&test_data_pb_files, &dir_path, is_input](const PATH_CHAR_TYPE* filename, OrtFileType f_type) -> bool {
if (filename[0] == '.') return true;
if (f_type != OrtFileType::TYPE_REG) return true;
std::basic_string<PATH_CHAR_TYPE> filename_str = filename;
if (!HasExtensionOf(filename_str, ORT_TSTR("pb"))) return true;
const std::basic_string<PATH_CHAR_TYPE> file_prefix =
is_input ? ORT_TSTR("input_") : ORT_TSTR("output_");
if (!filename_str.compare(0, file_prefix.length(), file_prefix)) {
std::basic_string<PATH_CHAR_TYPE> p = ConcatPathComponent(dir_path, filename_str);
test_data_pb_files.push_back(p);
}
return true;
});
std::filesystem::path dir_fs_path(dir_path);
if (!std::filesystem::exists(dir_fs_path)) return;

for (auto const& dir_entry : std::filesystem::directory_iterator(dir_fs_path)) {
if (!dir_entry.is_regular_file()) continue;
const std::filesystem::path& path = dir_entry.path();
if (!path.filename().has_extension()) {
continue;
}
if (path.filename().extension().compare(ORT_TSTR(".pb")) != 0) continue;
const std::basic_string<PATH_CHAR_TYPE> file_prefix =
is_input ? ORT_TSTR("input_") : ORT_TSTR("output_");
auto filename_str = path.filename().string();
if (filename_str.compare(0, file_prefix.length(), file_prefix) == 0) {
std::basic_string<PATH_CHAR_TYPE> p = ConcatPathComponent(dir_path, filename_str);
test_data_pb_files.push_back(p);
}
}

SortFileNames(test_data_pb_files);

Expand Down Expand Up @@ -691,11 +702,13 @@ void OnnxTestCase::ConvertTestData(const ONNX_NAMESPACE::OptionalProto& test_dat
OnnxTestCase::OnnxTestCase(const std::string& test_case_name, _In_ std::unique_ptr<TestModelInfo> model,
double default_per_sample_tolerance, double default_relative_per_sample_tolerance)
: test_case_name_(test_case_name), model_info_(std::move(model)) {
std::basic_string<PATH_CHAR_TYPE> test_case_dir = model_info_->GetDir();

std::filesystem::path test_case_dir = model_info_->GetDir();
if (!std::filesystem::exists(test_case_dir)) {
ORT_THROW("test case dir doesn't exist");
}
// parse config
std::basic_string<PATH_CHAR_TYPE> config_path =
ConcatPathComponent(test_case_dir, ORT_TSTR("config.txt"));
std::filesystem::path config_path =
test_case_dir / ORT_TSTR("config.txt");
/* Note: protobuf-lite doesn't support reading protobuf files as text-format. Config.txt is exactly that.
That's the reason I've to parse the file in a different way to read the configs. Currently
this affects 2 tests - fp16_tiny_yolov2 and fp16_inception_v1. It's not clear why we've to use protobuf
Expand All @@ -705,7 +718,7 @@ OnnxTestCase::OnnxTestCase(const std::string& test_case_name, _In_ std::unique_p
per_sample_tolerance_ = default_per_sample_tolerance;
relative_per_sample_tolerance_ = default_relative_per_sample_tolerance;
post_processing_ = false;
if (read_config_file(config_path, fc)) {
if (ReadConfigFile(config_path, fc)) {
if (fc.count("per_sample_tolerance") > 0) {
per_sample_tolerance_ = stod(fc["per_sample_tolerance"]);
}
Expand All @@ -716,16 +729,11 @@ OnnxTestCase::OnnxTestCase(const std::string& test_case_name, _In_ std::unique_p
post_processing_ = fc["post_processing"] == "true";
}
}

LoopDir(test_case_dir, [&test_case_dir, this](const PATH_CHAR_TYPE* filename, OrtFileType f_type) -> bool {
if (filename[0] == '.') return true;
if (f_type == OrtFileType::TYPE_DIR) {
std::basic_string<PATH_CHAR_TYPE> p = ConcatPathComponent(test_case_dir, filename);
test_data_dirs_.push_back(p);
debuginfo_strings_.push_back(ToUTF8String(p));
}
return true;
});
for (auto const& dir_entry : std::filesystem::directory_iterator(test_case_dir)) {
if (!dir_entry.is_directory()) continue;
test_data_dirs_.push_back(dir_entry.path());
debuginfo_strings_.push_back(ToUTF8String(dir_entry.path().string()));
}
}

void LoadTests(const std::vector<std::basic_string<PATH_CHAR_TYPE>>& input_paths,
Expand All @@ -737,20 +745,19 @@ void LoadTests(const std::vector<std::basic_string<PATH_CHAR_TYPE>>& input_paths
const std::function<void(std::unique_ptr<ITestCase>)>& process_function) {
std::vector<std::basic_string<PATH_CHAR_TYPE>> paths(input_paths);
while (!paths.empty()) {
std::basic_string<PATH_CHAR_TYPE> node_data_root_path = paths.back();
std::filesystem::path node_data_root_path = paths.back();
paths.pop_back();
std::basic_string<PATH_CHAR_TYPE> my_dir_name = GetLastComponent(node_data_root_path);
LoopDir(node_data_root_path, [&](const PATH_CHAR_TYPE* filename, OrtFileType f_type) -> bool {
if (filename[0] == '.') return true;
if (f_type == OrtFileType::TYPE_DIR) {
std::basic_string<PATH_CHAR_TYPE> p = ConcatPathComponent(node_data_root_path, filename);
paths.push_back(p);
return true;
if (!std::filesystem::exists(node_data_root_path)) continue;
std::filesystem::path my_dir_name = node_data_root_path.filename();
for (auto const& dir_entry : std::filesystem::directory_iterator(node_data_root_path)) {
if (dir_entry.is_directory()) {
paths.push_back(dir_entry.path());
continue;
}

std::basic_string<PATH_CHAR_TYPE> filename_str = filename;
bool is_onnx_format = HasExtensionOf(filename_str, ORT_TSTR("onnx"));
bool is_ort_format = HasExtensionOf(filename_str, ORT_TSTR("ort"));
if (!dir_entry.is_regular_file()) continue;
std::filesystem::path filename_str = dir_entry.path().filename();
bool is_onnx_format = filename_str.has_extension() && (filename_str.extension().compare(ORT_TSTR(".onnx")) == 0);
bool is_ort_format = filename_str.has_extension() && (filename_str.extension().compare(ORT_TSTR(".ort")) == 0);
bool is_valid_model = false;

#if !defined(ORT_MINIMAL_BUILD)
Expand All @@ -759,42 +766,40 @@ void LoadTests(const std::vector<std::basic_string<PATH_CHAR_TYPE>>& input_paths

is_valid_model = is_valid_model || is_ort_format;
if (!is_valid_model)
return true;
continue;

std::basic_string<PATH_CHAR_TYPE> test_case_name = my_dir_name;
std::basic_string<PATH_CHAR_TYPE> test_case_name = my_dir_name.native();
if (test_case_name.compare(0, 5, ORT_TSTR("test_")) == 0) test_case_name = test_case_name.substr(5);

if (!whitelisted_test_cases.empty() && std::find(whitelisted_test_cases.begin(), whitelisted_test_cases.end(),
test_case_name) == whitelisted_test_cases.end()) {
return true;
continue;
}
if (disabled_tests.find(test_case_name) != disabled_tests.end()) return true;

std::basic_string<PATH_CHAR_TYPE> p = ConcatPathComponent(node_data_root_path, filename_str);
if (disabled_tests.find(test_case_name) != disabled_tests.end()) continue;

std::unique_ptr<TestModelInfo> model_info;

if (is_onnx_format) {
#if !defined(ORT_MINIMAL_BUILD)
model_info = TestModelInfo::LoadOnnxModel(p.c_str());
model_info = TestModelInfo::LoadOnnxModel(dir_entry.path());
#else
ORT_THROW("onnx model is not supported in this build");
#endif
} else if (is_ort_format) {
model_info = TestModelInfo::LoadOrtModel(p.c_str());
model_info = TestModelInfo::LoadOrtModel(dir_entry.path());
} else {
ORT_NOT_IMPLEMENTED(ToUTF8String(filename_str), " is not supported");
}

auto test_case_dir = model_info->GetDir();
auto test_case_name_in_log = test_case_name + ORT_TSTR(" in ") + test_case_dir;
auto test_case_name_in_log = test_case_name + ORT_TSTR(" in ") + test_case_dir.native();

#if !defined(ORT_MINIMAL_BUILD) && !defined(USE_QNN)
// to skip some models like *-int8 or *-qdq
if ((reinterpret_cast<OnnxModelInfo*>(model_info.get()))->HasDomain(ONNX_NAMESPACE::AI_ONNX_TRAINING_DOMAIN) ||
(reinterpret_cast<OnnxModelInfo*>(model_info.get()))->HasDomain(ONNX_NAMESPACE::AI_ONNX_PREVIEW_TRAINING_DOMAIN)) {
fprintf(stderr, "Skip test case:: %s %s\n", ToUTF8String(test_case_name_in_log).c_str(), " as it has training domain");
return true;
continue;
}
#endif

Expand All @@ -809,7 +814,7 @@ void LoadTests(const std::vector<std::basic_string<PATH_CHAR_TYPE>>& input_paths
});
if (!has_test_data) {
fprintf(stderr, "Skip test case:: %s %s\n", ToUTF8String(test_case_name_in_log).c_str(), " due to no test data");
return true;
continue;
}

if (broken_tests) {
Expand All @@ -820,7 +825,7 @@ void LoadTests(const std::vector<std::basic_string<PATH_CHAR_TYPE>>& input_paths
(opset_version == TestModelInfo::unknown_version || iter->broken_opset_versions_.empty() ||
iter->broken_opset_versions_.find(opset_version) != iter->broken_opset_versions_.end())) {
fprintf(stderr, "Skip test case:: %s %s\n", ToUTF8String(test_case_name_in_log).c_str(), " due to broken_tests");
return true;
continue;
}
}

Expand All @@ -829,7 +834,7 @@ void LoadTests(const std::vector<std::basic_string<PATH_CHAR_TYPE>>& input_paths
std::string keyword = *iter2;
if (ToUTF8String(test_case_name).find(keyword) != std::string::npos) {
fprintf(stderr, "Skip test case:: %s %s\n", ToUTF8String(test_case_name_in_log).c_str(), " as it is in broken test keywords");
return true;
continue;
}
}
}
Expand All @@ -841,8 +846,7 @@ void LoadTests(const std::vector<std::basic_string<PATH_CHAR_TYPE>>& input_paths
tolerances.relative(tolerance_key));
fprintf(stdout, "Load Test Case: %s\n", ToUTF8String(test_case_name_in_log).c_str());
process_function(std::move(l));
return true;
});
}
}
}

Expand Down
18 changes: 7 additions & 11 deletions onnxruntime/test/onnx/TestCase.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <mutex>
#include <unordered_map>
#include <unordered_set>
#include <filesystem>
#include <core/common/common.h>
#include <core/common/status.h>
#include <core/platform/path_lib.h>
Expand All @@ -31,7 +32,7 @@ class ITestCase {
virtual void LoadTestData(size_t id, onnxruntime::test::HeapBuffer& b,
std::unordered_map<std::string, Ort::Value>& name_data_map,
bool is_input) const = 0;
virtual const PATH_CHAR_TYPE* GetModelUrl() const = 0;
virtual const std::filesystem::path& GetModelUrl() const = 0;
virtual const std::string& GetNodeName() const = 0;
virtual const ONNX_NAMESPACE::ValueInfoProto* GetInputInfoFromModel(size_t i) const = 0;
virtual const ONNX_NAMESPACE::ValueInfoProto* GetOutputInfoFromModel(size_t i) const = 0;
Expand All @@ -50,14 +51,9 @@ class ITestCase {

class TestModelInfo {
public:
virtual const PATH_CHAR_TYPE* GetModelUrl() const = 0;
virtual std::basic_string<PATH_CHAR_TYPE> GetDir() const {
std::basic_string<PATH_CHAR_TYPE> test_case_dir;
auto st = onnxruntime::GetDirNameFromFilePath(GetModelUrl(), test_case_dir);
if (!st.IsOK()) {
ORT_THROW("GetDirNameFromFilePath failed");
}
return test_case_dir;
virtual const std::filesystem::path& GetModelUrl() const = 0;
virtual std::filesystem::path GetDir() const {
return GetModelUrl().parent_path();
}
virtual const std::string& GetNodeName() const = 0;
virtual const ONNX_NAMESPACE::ValueInfoProto* GetInputInfoFromModel(size_t i) const = 0;
Expand All @@ -70,10 +66,10 @@ class TestModelInfo {
virtual ~TestModelInfo() = default;

#if !defined(ORT_MINIMAL_BUILD)
static std::unique_ptr<TestModelInfo> LoadOnnxModel(_In_ const PATH_CHAR_TYPE* model_url);
static std::unique_ptr<TestModelInfo> LoadOnnxModel(const std::filesystem::path& model_url);
#endif

static std::unique_ptr<TestModelInfo> LoadOrtModel(_In_ const PATH_CHAR_TYPE* model_url);
static std::unique_ptr<TestModelInfo> LoadOrtModel(const std::filesystem::path& model_url);

static const std::string unknown_version;
};
Expand Down
6 changes: 3 additions & 3 deletions onnxruntime/test/onnx/onnx_model_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

using namespace onnxruntime;

OnnxModelInfo::OnnxModelInfo(_In_ const PATH_CHAR_TYPE* model_url, bool is_ort_model)
OnnxModelInfo::OnnxModelInfo(const std::filesystem::path& model_url, bool is_ort_model)
: model_url_(model_url) {
if (is_ort_model) {
InitOrtModelInfo(model_url);
Expand All @@ -38,7 +38,7 @@ static void RepeatedPtrFieldToVector(const ::google::protobuf::RepeatedPtrField<
}
}

void OnnxModelInfo::InitOnnxModelInfo(_In_ const PATH_CHAR_TYPE* model_url) { // parse model
void OnnxModelInfo::InitOnnxModelInfo(const std::filesystem::path& model_url) { // parse model
int model_fd;
auto st = Env::Default().FileOpenRd(model_url, model_fd);
if (!st.IsOK()) {
Expand Down Expand Up @@ -91,7 +91,7 @@ void OnnxModelInfo::InitOnnxModelInfo(_In_ const PATH_CHAR_TYPE* model_url) { /

#endif // #if !defined(ORT_MINIMAL_BUILD)

void OnnxModelInfo::InitOrtModelInfo(_In_ const PATH_CHAR_TYPE* model_url) {
void OnnxModelInfo::InitOrtModelInfo(const std::filesystem::path& model_url) {
std::vector<uint8_t> bytes;
size_t num_bytes = 0;
const auto model_location = ToWideString(model_url);
Expand Down
10 changes: 5 additions & 5 deletions onnxruntime/test/onnx/onnx_model_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,16 @@ class OnnxModelInfo : public TestModelInfo {
std::vector<ONNX_NAMESPACE::ValueInfoProto> input_value_info_;
std::vector<ONNX_NAMESPACE::ValueInfoProto> output_value_info_;
std::unordered_map<std::string, int64_t> domain_to_version_;
const std::basic_string<PATH_CHAR_TYPE> model_url_;
const std::filesystem::path model_url_;

#if !defined(ORT_MINIMAL_BUILD)
void InitOnnxModelInfo(_In_ const PATH_CHAR_TYPE* model_url);
void InitOnnxModelInfo(const std::filesystem::path& model_url);
#endif

void InitOrtModelInfo(_In_ const PATH_CHAR_TYPE* model_url);
void InitOrtModelInfo(const std::filesystem::path& model_url);

public:
OnnxModelInfo(_In_ const PATH_CHAR_TYPE* model_url, bool is_ort_model = false);
OnnxModelInfo(const std::filesystem::path& path, bool is_ort_model = false);
bool HasDomain(const std::string& name) const {
return domain_to_version_.find(name) != domain_to_version_.end();
}
Expand All @@ -32,7 +32,7 @@ class OnnxModelInfo : public TestModelInfo {
return iter == domain_to_version_.end() ? -1 : iter->second;
}

const PATH_CHAR_TYPE* GetModelUrl() const override { return model_url_.c_str(); }
const std::filesystem::path& GetModelUrl() const override { return model_url_; }
std::string GetNominalOpsetVersion() const override { return onnx_nominal_opset_vesion_; }

const std::string& GetNodeName() const override { return node_name_; }
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/test/onnx/testcase_request.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ bool TestCaseRequestContext::SetupSession() {
ORT_TRY {
const auto* test_case_name = test_case_.GetTestCaseName().c_str();
session_opts_.SetLogId(test_case_name);
Ort::Session session{env_, test_case_.GetModelUrl(), session_opts_};
Ort::Session session{env_, test_case_.GetModelUrl().native().c_str(), session_opts_};
session_ = std::move(session);
LOGF_DEFAULT(INFO, "Testing %s\n", test_case_name);
return true;
Expand Down
Loading

0 comments on commit ed3125f

Please sign in to comment.