Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AUTO] Refine the logic of creating HW plugins in AUTO #27691

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
18 changes: 9 additions & 9 deletions src/inference/src/dev/core_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -783,8 +783,8 @@ ov::SoPtr<ov::ICompiledModel> ov::CoreImpl::compile_model(const std::shared_ptr<
// will consume ov::cache_dir if plugin not support it
auto cacheManager = parsed._core_config.get_cache_config_for_device(plugin, parsed._config)._cacheManager;
// Skip caching for proxy plugin. HW plugin will load network from the cache
if (cacheManager && device_supports_model_caching(plugin) && !is_proxy_device(plugin)) {
CacheContent cacheContent{cacheManager, parsed._core_config.get_enable_mmap()};
if (cacheManager && device_supports_model_caching(plugin, parsed._config) && !is_proxy_device(plugin)) {
CacheContent cacheContent{cacheManager};
cacheContent.blobId = ov::ModelCache::compute_hash(model, create_compile_config(plugin, parsed._config));
std::unique_ptr<CacheGuardEntry> lock = cacheGuard.get_hash_lock(cacheContent.blobId);
res = load_model_from_cache(cacheContent, plugin, parsed._config, ov::SoPtr<ov::IRemoteContext>{}, [&]() {
Expand Down Expand Up @@ -817,8 +817,8 @@ ov::SoPtr<ov::ICompiledModel> ov::CoreImpl::compile_model(const std::shared_ptr<
// will consume ov::cache_dir if plugin not support it
auto cacheManager = parsed._core_config.get_cache_config_for_device(plugin, parsed._config)._cacheManager;
// Skip caching for proxy plugin. HW plugin will load network from the cache
if (cacheManager && device_supports_model_caching(plugin) && !is_proxy_device(plugin)) {
CacheContent cacheContent{cacheManager, parsed._core_config.get_enable_mmap()};
if (cacheManager && device_supports_model_caching(plugin, parsed._config) && !is_proxy_device(plugin)) {
CacheContent cacheContent{cacheManager};
cacheContent.blobId = ov::ModelCache::compute_hash(model, create_compile_config(plugin, parsed._config));
std::unique_ptr<CacheGuardEntry> lock = cacheGuard.get_hash_lock(cacheContent.blobId);
res = load_model_from_cache(cacheContent, plugin, parsed._config, context, [&]() {
Expand All @@ -841,7 +841,7 @@ ov::SoPtr<ov::ICompiledModel> ov::CoreImpl::compile_model(const std::string& mod
// will consume ov::cache_dir if plugin not support it
auto cacheManager = parsed._core_config.get_cache_config_for_device(plugin, parsed._config)._cacheManager;

if (cacheManager && device_supports_model_caching(plugin) && !is_proxy_device(plugin)) {
if (cacheManager && device_supports_model_caching(plugin, parsed._config) && !is_proxy_device(plugin)) {
// Skip caching for proxy plugin. HW plugin will load network from the cache
CoreConfig::remove_core_skip_cache_dir(parsed._config);
CacheContent cacheContent{cacheManager, parsed._core_config.get_enable_mmap(), model_path};
Expand Down Expand Up @@ -869,8 +869,8 @@ ov::SoPtr<ov::ICompiledModel> ov::CoreImpl::compile_model(const std::string& mod
// will consume ov::cache_dir if plugin not support it
auto cacheManager = parsed._core_config.get_cache_config_for_device(plugin, parsed._config)._cacheManager;
// Skip caching for proxy plugin. HW plugin will load network from the cache
if (cacheManager && device_supports_model_caching(plugin) && !is_proxy_device(plugin)) {
CacheContent cacheContent{cacheManager, parsed._core_config.get_enable_mmap()};
if (cacheManager && device_supports_model_caching(plugin, parsed._config) && !is_proxy_device(plugin)) {
CacheContent cacheContent{cacheManager};
cacheContent.blobId =
ov::ModelCache::compute_hash(model_str, weights, create_compile_config(plugin, parsed._config));
std::unique_ptr<CacheGuardEntry> lock = cacheGuard.get_hash_lock(cacheContent.blobId);
Expand Down Expand Up @@ -1387,8 +1387,8 @@ bool ov::CoreImpl::device_supports_internal_property(const ov::Plugin& plugin, c
return util::contains(plugin.get_property(ov::internal::supported_properties), key);
}

bool ov::CoreImpl::device_supports_model_caching(const ov::Plugin& plugin) const {
return plugin.supports_model_caching();
bool ov::CoreImpl::device_supports_model_caching(const ov::Plugin& plugin, const ov::AnyMap& arguments) const {
return plugin.supports_model_caching(arguments);
}

bool ov::CoreImpl::device_supports_cache_dir(const ov::Plugin& plugin) const {
Expand Down
2 changes: 1 addition & 1 deletion src/inference/src/dev/core_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ class CoreImpl : public ov::ICore, public std::enable_shared_from_this<ov::ICore
const ov::SoPtr<ov::IRemoteContext>& context,
std::function<ov::SoPtr<ov::ICompiledModel>()> compile_model_lambda) const;

bool device_supports_model_caching(const ov::Plugin& plugin) const;
bool device_supports_model_caching(const ov::Plugin& plugin, const ov::AnyMap& origConfig = {}) const;

bool device_supports_property(const ov::Plugin& plugin, const ov::PropertyName& key) const;
bool device_supports_internal_property(const ov::Plugin& plugin, const ov::PropertyName& key) const;
Expand Down
9 changes: 5 additions & 4 deletions src/inference/src/dev/plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,11 @@ ov::Any ov::Plugin::get_property(const std::string& name, const AnyMap& argument
return {m_ptr->get_property(name, arguments), {m_so}};
}

bool ov::Plugin::supports_model_caching() const {
bool ov::Plugin::supports_model_caching(const ov::AnyMap& arguments) const {
bool supported(false);
supported = util::contains(get_property(ov::supported_properties), ov::device::capabilities) &&
util::contains(get_property(ov::device::capabilities), ov::device::capability::EXPORT_IMPORT) &&
util::contains(get_property(ov::internal::supported_properties), ov::internal::caching_properties);
supported =
util::contains(get_property(ov::supported_properties), ov::device::capabilities) &&
util::contains(get_property(ov::device::capabilities, arguments), ov::device::capability::EXPORT_IMPORT) &&
util::contains(get_property(ov::internal::supported_properties), ov::internal::caching_properties);
return supported;
}
2 changes: 1 addition & 1 deletion src/inference/src/dev/plugin.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class Plugin {
T get_property(const ov::Property<T, M>& property, const AnyMap& arguments) const {
return get_property(property.name(), arguments).template as<T>();
}
bool supports_model_caching() const;
bool supports_model_caching(const AnyMap& arguments = {}) const;
};

} // namespace ov
Expand Down
2 changes: 2 additions & 0 deletions src/plugins/auto/src/auto_schedule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ void AutoSchedule::init() {
auto load_device_task = [&](AutoCompileContext* context_ptr, const std::shared_ptr<ov::Model>& model) {
try_to_compile_model(*context_ptr, model);
if (context_ptr->m_is_load_success) {
// release cloned model here
const_cast<std::shared_ptr<ov::Model>&>(model).reset();
if (context_ptr->m_worker_name.empty()) {
context_ptr->m_worker_name = context_ptr->m_device_info.device_name;
}
Expand Down
117 changes: 96 additions & 21 deletions src/plugins/auto/src/plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ std::vector<DeviceInformation> Plugin::parse_meta_devices(const std::string& pri
auto device_id = get_core()->get_property(device_name, ov::device::id);
return device_id;
} catch (ov::Exception&) {
LOG_DEBUG_TAG("get default device id failed for ", device_name.c_str());
LOG_DEBUG_TAG("get default device id failed for %s", device_name.c_str());
return "";
}
};
Expand All @@ -188,7 +188,6 @@ std::vector<DeviceInformation> Plugin::parse_meta_devices(const std::string& pri
bool enable_device_priority = (prioritiesIter != properties.end()) &&
check_priority_config(prioritiesIter->second.as<std::string>());

auto device_list = get_core()->get_available_devices();
for (auto && d : devices_with_requests) {
auto opening_bracket = d.find_first_of('(');
auto closing_bracket = d.find_first_of(')', opening_bracket);
Expand All @@ -206,9 +205,27 @@ std::vector<DeviceInformation> Plugin::parse_meta_devices(const std::string& pri
ov::DeviceIDParser parsed{device_name};
std::string deviceid = parsed.get_device_id();
std::vector<std::string> same_type_devices;
// if AUTO:GPU case, replace GPU with GPU.0 and GPU.1

if (deviceid.empty()) {
for (auto&& device : device_list) {
// if AUTO:GPU case, replace GPU with GPU.0 and GPU.1
std::vector<std::string> device_list_with_id = {};
try {
auto device_id_list = get_core()
->get_property(parsed.get_device_name(), ov::available_devices.name(), {})
.as<std::vector<std::string>>();
for (auto&& device_id : device_id_list) {
if (device_id.empty())
continue;
device_list_with_id.push_back(parsed.get_device_name() + "." + device_id);
}
if (device_id_list.empty()) {
    device_list_with_id.push_back(parsed.get_device_name());
}
} catch (const ov::Exception&) {
device_list_with_id.push_back(parsed.get_device_name());
LOG_DEBUG_TAG("Failed to get available devices for %s", parsed.get_device_name().c_str());
}
for (auto&& device : device_list_with_id) {
if (device.find(device_name) != std::string::npos) {
same_type_devices.push_back(std::move(device));
}
Expand Down Expand Up @@ -281,11 +298,29 @@ ov::Any Plugin::get_property(const std::string& name, const ov::AnyMap& argument
} else if (name == ov::device::full_name) {
return decltype(ov::device::full_name)::value_type {get_device_name()};
} else if (name == ov::device::capabilities.name()) {
auto device_list = get_core()->get_available_devices();
std::vector<std::string> device_list = arguments.count(ov::device::priorities.name())
? m_plugin_config.parse_priorities_devices(
arguments.at(ov::device::priorities.name()).as<std::string>())
: get_core()->get_available_devices();
bool enable_startup_cpu = arguments.count(ov::intel_auto::enable_startup_fallback.name())
? arguments.at(ov::intel_auto::enable_startup_fallback.name()).as<bool>()
: true;
bool enable_runtime_cpu = arguments.count(ov::intel_auto::enable_runtime_fallback.name())
? arguments.at(ov::intel_auto::enable_runtime_fallback.name()).as<bool>()
: true;
bool enable_cpu = enable_startup_cpu || enable_runtime_cpu;
std::vector<std::string> capabilities;
for (auto const & device : device_list) {
auto devCapabilities = get_core()->get_property(device, ov::device::capabilities);
capabilities.insert(capabilities.end(), devCapabilities.begin(), devCapabilities.end());
for (auto const& device : device_list) {
auto real_device = device[0] == '-' ? device.substr(1) : device;
if (real_device.find("CPU") != std::string::npos && !enable_cpu) {
continue;
}
try {
auto devCapabilities = get_core()->get_property(real_device, ov::device::capabilities);
capabilities.insert(capabilities.end(), devCapabilities.begin(), devCapabilities.end());
} catch (const ov::Exception&) {
LOG_DEBUG_TAG("Failed to get capabilities for device: %s", device.c_str());
}
}
std::sort(capabilities.begin(), capabilities.end());
capabilities.resize(std::distance(capabilities.begin(), std::unique(capabilities.begin(), capabilities.end())));
Expand Down Expand Up @@ -460,7 +495,14 @@ std::shared_ptr<ov::ICompiledModel> Plugin::compile_model_impl(const std::string
if (is_cumulative) {
impl = std::make_shared<AutoCumuCompiledModel>(cloned_model, shared_from_this(), device_context, auto_s_context, scheduler);
} else {
impl = std::make_shared<AutoCompiledModel>(cloned_model, shared_from_this(), device_context, auto_s_context, scheduler);
auto model = auto_s_context->m_model;
if (std::static_pointer_cast<AutoSchedule>(scheduler)->m_compile_context[ACTUALDEVICE].m_is_already) {
// release cloned model here if actual device finish compiling model.
model.reset();
auto_s_context->m_model.reset();
}
impl =
std::make_shared<AutoCompiledModel>(model, shared_from_this(), device_context, auto_s_context, scheduler);
}
return impl;
}
Expand Down Expand Up @@ -656,7 +698,6 @@ void Plugin::register_priority(const unsigned int& priority, const std::string&
std::string Plugin::get_device_list(const ov::AnyMap& properties) const {
std::string all_devices;
std::string device_architecture;
auto device_list = get_core()->get_available_devices();
auto device_list_config = properties.find(ov::device::priorities.name());
auto get_gpu_architecture = [&](const std::string& name) -> std::string {
try {
Expand All @@ -667,17 +708,8 @@ std::string Plugin::get_device_list(const ov::AnyMap& properties) const {
}
return "";
};
for (auto&& device : device_list) {
// filter out the supported devices
if (device.find("GPU") != std::string::npos) {
device_architecture = get_gpu_architecture(device);
}
if (!m_plugin_config.is_supported_device(device, device_architecture))
continue;
all_devices += device + ",";
}
std::vector<std::string> devices_merged;
if (device_list_config != properties.end() && !device_list_config->second.empty()) {
if (device_list_config != properties.end() && !(device_list_config->second.as<std::string>().empty())) {
auto priorities = device_list_config->second;
// parsing the string and splitting the comma-separated tokens
std::vector<std::string> devices_to_be_merged = m_plugin_config.parse_priorities_devices(priorities.as<std::string>());
Expand Down Expand Up @@ -718,7 +750,9 @@ std::string Plugin::get_device_list(const ov::AnyMap& properties) const {
return device.find(".") == std::string::npos ? device + ".0" : device;
};
if (devices_to_be_merged.empty()) {
auto device_list = get_core()->get_available_devices();
for (auto&& device : device_list) {
all_devices += device + ",";
if (device.find("GPU") != std::string::npos) {
device_architecture = get_gpu_architecture(device);
}
Expand All @@ -728,8 +762,38 @@ std::string Plugin::get_device_list(const ov::AnyMap& properties) const {
}
} else {
for (auto&& device : devices_to_be_merged) {
ov::DeviceIDParser parsed{device};
std::vector<std::string> device_list = {};
try {
if (parsed.get_device_name().find("CPU") != std::string::npos) {
bool enable_startup_cpu =
properties.count(ov::intel_auto::enable_startup_fallback.name())
? properties.at(ov::intel_auto::enable_startup_fallback.name()).as<bool>()
: true;
bool enable_runtime_cpu =
properties.count(ov::intel_auto::enable_runtime_fallback.name())
? properties.at(ov::intel_auto::enable_runtime_fallback.name()).as<bool>()
: true;
// Skip to load CPU device if both startup and runtime fallback are disabled
if (!enable_startup_cpu && !enable_runtime_cpu)
continue;
}
auto device_id_list = get_core()
->get_property(parsed.get_device_name(), ov::available_devices.name(), {})
.as<std::vector<std::string>>();
for (auto&& device_id : device_id_list) {
if (device_id.empty())
continue;
device_list.push_back(parsed.get_device_name() + "." + device_id);
}
if (device_id_list.empty()) {
    device_list.push_back(parsed.get_device_name());
}
} catch (const ov::Exception&) {
device_list.push_back(parsed.get_device_name());
LOG_DEBUG_TAG("no available devices found for %s", device.c_str());
}
if (!is_any_dev(device, device_list)) {
ov::DeviceIDParser parsed{device};
auto iter = std::find(devices_merged.begin(), devices_merged.end(), parsed.get_device_name());
if (iter != devices_merged.end() && parsed.get_device_name() != device && parsed.get_device_id() == "0")
// The device is the device with default device ID (eg. GPU.0) and
Expand Down Expand Up @@ -761,7 +825,18 @@ std::string Plugin::get_device_list(const ov::AnyMap& properties) const {
std::for_each(devices_merged.begin(), devices_merged.end(), [&all_devices](const std::string& device) {
all_devices += device + ",";
});
} else {
auto device_list = get_core()->get_available_devices();
for (auto&& device : device_list) {
if (device.find("GPU") != std::string::npos) {
device_architecture = get_gpu_architecture(device);
}
if (!m_plugin_config.is_supported_device(device, device_architecture))
continue;
all_devices += device + ",";
}
}

if (all_devices.empty()) {
OPENVINO_THROW("Please, check environment due to no supported devices can be used");
}
Expand Down
Loading
Loading