Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

graph: backend: dnnl: encode mem address into constant cache key #2312

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/graph/backend/dnnl/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,7 @@ dnnl::memory::format_tag get_format_tag(const dnnl::memory::desc &md) {
return format_tag;
}

size_t generate_constant_cache_key(
size_t generate_constant_md_hash(
size_t part_id, const std::vector<dnnl::memory::desc> &const_mds) {
size_t key = 0;
key = hash_combine(key, part_id);
Expand Down
2 changes: 1 addition & 1 deletion src/graph/backend/dnnl/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ std::string get_format_tag_str(const dnnl::memory::desc &md);

dnnl::memory::format_tag get_format_tag(const dnnl::memory::desc &md);

size_t generate_constant_cache_key(
size_t generate_constant_md_hash(
size_t part_id, const std::vector<dnnl::memory::desc> &const_mds);

#ifndef NDEBUG
Expand Down
14 changes: 10 additions & 4 deletions src/graph/backend/dnnl/kernels/batch_norm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ status_t batch_norm_fwd_t::compile_impl(const dnnl_partition_impl_t *part,
return this->memory_planner_.get_exec_args_set().clone();
};

constant_key_ = generate_constant_cache_key(part->id(),
const_md_hash_ = generate_constant_md_hash(part->id(),
memory_planner_.get_exec_args_set().get_persistent_mem_desc_list());

return status::success;
Expand All @@ -135,9 +135,11 @@ status_t batch_norm_fwd_t::execute_impl(const stream_t *g_stream,

constant_cache_t::cached_t c_buffer;
if (enabled_constant_cache()) {
const size_t encoded_key
= encode_constant_cache_key(inputs, const_md_hash_);
std::promise<constant_cache_t::cached_t> c_promise;
constant_cache_t::value_t cached_value
= dnnl_constant_cache_get_or_add(p_engine_, constant_key_,
= dnnl_constant_cache_get_or_add(p_engine_, encoded_key,
memory_planner_.total_internal_persistent_size(),
c_promise.get_future());
bool is_from_cache = cached_value.valid();
Expand Down Expand Up @@ -204,9 +206,11 @@ status_t batch_norm_fwd_t::sycl_execute_impl(const stream_t *g_stream,

constant_cache_t::cached_t c_buffer;
if (enabled_constant_cache()) {
const size_t encoded_key
= encode_constant_cache_key(inputs, const_md_hash_);
std::promise<constant_cache_t::cached_t> c_promise;
constant_cache_t::value_t cached_value
= dnnl_constant_cache_get_or_add(p_engine_, constant_key_,
= dnnl_constant_cache_get_or_add(p_engine_, encoded_key,
memory_planner_.total_internal_persistent_size(),
c_promise.get_future());
bool is_from_cache = cached_value.valid();
Expand Down Expand Up @@ -279,9 +283,11 @@ status_t batch_norm_fwd_t::ocl_execute_impl(const stream_t *g_stream,

constant_cache_t::cached_t c_buffer;
if (enabled_constant_cache()) {
const size_t encoded_key
= encode_constant_cache_key(inputs, const_md_hash_);
std::promise<constant_cache_t::cached_t> c_promise;
constant_cache_t::value_t cached_value
= dnnl_constant_cache_get_or_add(p_engine_, constant_key_,
= dnnl_constant_cache_get_or_add(p_engine_, encoded_key,
memory_planner_.total_internal_persistent_size(),
c_promise.get_future());
bool is_from_cache = cached_value.valid();
Expand Down
2 changes: 1 addition & 1 deletion src/graph/backend/dnnl/kernels/batch_norm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ struct batch_norm_fwd_t : public kernel_base_t {
memory_planner_t memory_planner_;

std::function<std::shared_ptr<execution_args_set_t>()> resource_ctor_;
constant_cache_t::key_t constant_key_ = 0;
size_t const_md_hash_ = 0;

public:
batch_norm_fwd_t() {
Expand Down
4 changes: 2 additions & 2 deletions src/graph/backend/dnnl/kernels/conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ status_t conv_fwd_t<quantized>::compile_impl(const dnnl_partition_impl_t *part,
return this->memory_planner_.get_exec_args_set().clone();
};

constant_key_ = generate_constant_cache_key(part->id(),
const_md_hash_ = generate_constant_md_hash(part->id(),
memory_planner_.get_exec_args_set().get_persistent_mem_desc_list());

return status::success;
Expand Down Expand Up @@ -202,7 +202,7 @@ status_t conv_bwd_data_t::compile_impl(const dnnl_partition_impl_t *part,
return this->memory_planner_.get_exec_args_set().clone();
};

constant_key_ = generate_constant_cache_key(part->id(),
const_md_hash_ = generate_constant_md_hash(part->id(),
memory_planner_.get_exec_args_set().get_persistent_mem_desc_list());

return status::success;
Expand Down
12 changes: 9 additions & 3 deletions src/graph/backend/dnnl/kernels/conv_base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,11 @@ status_t conv_base_t::execute_impl(const stream_t *g_stream,

constant_cache_t::cached_t c_buffer;
if (enabled_constant_cache()) {
const size_t encoded_key
= encode_constant_cache_key(inputs, const_md_hash_);
std::promise<constant_cache_t::cached_t> c_promise;
constant_cache_t::value_t cached_value
= dnnl_constant_cache_get_or_add(p_engine_, constant_key_,
= dnnl_constant_cache_get_or_add(p_engine_, encoded_key,
memory_planner_.total_internal_persistent_size(),
c_promise.get_future());
bool is_from_cache = cached_value.valid();
Expand Down Expand Up @@ -132,9 +134,11 @@ status_t conv_base_t::sycl_execute_impl(const stream_t *g_stream,

constant_cache_t::cached_t c_buffer;
if (enabled_constant_cache()) {
const size_t encoded_key
= encode_constant_cache_key(inputs, const_md_hash_);
std::promise<constant_cache_t::cached_t> c_promise;
constant_cache_t::value_t cached_value
= dnnl_constant_cache_get_or_add(p_engine_, constant_key_,
= dnnl_constant_cache_get_or_add(p_engine_, encoded_key,
memory_planner_.total_internal_persistent_size(),
c_promise.get_future());
bool is_from_cache = cached_value.valid();
Expand Down Expand Up @@ -207,9 +211,11 @@ status_t conv_base_t::ocl_execute_impl(const stream_t *g_stream,

constant_cache_t::cached_t c_buffer;
if (enabled_constant_cache()) {
const size_t encoded_key
= encode_constant_cache_key(inputs, const_md_hash_);
std::promise<constant_cache_t::cached_t> c_promise;
constant_cache_t::value_t cached_value
= dnnl_constant_cache_get_or_add(p_engine_, constant_key_,
= dnnl_constant_cache_get_or_add(p_engine_, encoded_key,
memory_planner_.total_internal_persistent_size(),
c_promise.get_future());
bool is_from_cache = cached_value.valid();
Expand Down
2 changes: 1 addition & 1 deletion src/graph/backend/dnnl/kernels/conv_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ struct conv_base_t : public kernel_base_t {

std::function<std::shared_ptr<execution_args_set_t>()> resource_ctor_;

constant_cache_t::key_t constant_key_ = 0;
size_t const_md_hash_ = 0;

public:
conv_base_t() {
Expand Down
4 changes: 2 additions & 2 deletions src/graph/backend/dnnl/kernels/conv_transpose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ status_t conv_transpose_fwd_t<quantized>::compile_impl(
return this->memory_planner_.get_exec_args_set().clone();
};

constant_key_ = generate_constant_cache_key(part->id(),
const_md_hash_ = generate_constant_md_hash(part->id(),
memory_planner_.get_exec_args_set().get_persistent_mem_desc_list());

return status::success;
Expand Down Expand Up @@ -183,7 +183,7 @@ status_t conv_transpose_bwd_data_t::compile_impl(
return this->memory_planner_.get_exec_args_set().clone();
};

constant_key_ = generate_constant_cache_key(part->id(),
const_md_hash_ = generate_constant_md_hash(part->id(),
memory_planner_.get_exec_args_set().get_persistent_mem_desc_list());

return status::success;
Expand Down
14 changes: 10 additions & 4 deletions src/graph/backend/dnnl/kernels/eltwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ status_t eltwise_fwd_t<quantized>::compile_impl(
return this->memory_planner_.get_exec_args_set().clone();
};

constant_key_ = generate_constant_cache_key(part->id(),
const_md_hash_ = generate_constant_md_hash(part->id(),
memory_planner_.get_exec_args_set().get_persistent_mem_desc_list());

return status::success;
Expand Down Expand Up @@ -138,9 +138,11 @@ status_t eltwise_fwd_t<quantized>::execute_impl(const stream_t *g_stream,

constant_cache_t::cached_t c_buffer;
if (enabled_constant_cache()) {
const size_t encoded_key
= encode_constant_cache_key(inputs, const_md_hash_);
std::promise<constant_cache_t::cached_t> c_promise;
constant_cache_t::value_t cached_value
= dnnl_constant_cache_get_or_add(p_engine_, constant_key_,
= dnnl_constant_cache_get_or_add(p_engine_, encoded_key,
memory_planner_.total_internal_persistent_size(),
c_promise.get_future());
bool is_from_cache = cached_value.valid();
Expand Down Expand Up @@ -208,9 +210,11 @@ status_t eltwise_fwd_t<quantized>::sycl_execute_impl(const stream_t *g_stream,

constant_cache_t::cached_t c_buffer;
if (enabled_constant_cache()) {
const size_t encoded_key
= encode_constant_cache_key(inputs, const_md_hash_);
std::promise<constant_cache_t::cached_t> c_promise;
constant_cache_t::value_t cached_value
= dnnl_constant_cache_get_or_add(p_engine_, constant_key_,
= dnnl_constant_cache_get_or_add(p_engine_, encoded_key,
memory_planner_.total_internal_persistent_size(),
c_promise.get_future());
bool is_from_cache = cached_value.valid();
Expand Down Expand Up @@ -284,9 +288,11 @@ status_t eltwise_fwd_t<quantized>::ocl_execute_impl(const stream_t *g_stream,

constant_cache_t::cached_t c_buffer;
if (enabled_constant_cache()) {
const size_t encoded_key
= encode_constant_cache_key(inputs, const_md_hash_);
std::promise<constant_cache_t::cached_t> c_promise;
constant_cache_t::value_t cached_value
= dnnl_constant_cache_get_or_add(p_engine_, constant_key_,
= dnnl_constant_cache_get_or_add(p_engine_, encoded_key,
memory_planner_.total_internal_persistent_size(),
c_promise.get_future());
bool is_from_cache = cached_value.valid();
Expand Down
2 changes: 1 addition & 1 deletion src/graph/backend/dnnl/kernels/eltwise.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ struct eltwise_fwd_t : public kernel_base_t {

std::function<std::shared_ptr<execution_args_set_t>()> resource_ctor_;

constant_cache_t::key_t constant_key_ = 0;
size_t const_md_hash_ = 0;

public:
eltwise_fwd_t() {
Expand Down
14 changes: 10 additions & 4 deletions src/graph/backend/dnnl/kernels/group_norm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ status_t group_norm_fwd_t::compile_impl(const dnnl_partition_impl_t *part,
return this->memory_planner_.get_exec_args_set().clone();
};

constant_key_ = generate_constant_cache_key(part->id(),
const_md_hash_ = generate_constant_md_hash(part->id(),
memory_planner_.get_exec_args_set().get_persistent_mem_desc_list());

return status::success;
Expand Down Expand Up @@ -143,9 +143,11 @@ status_t group_norm_fwd_t::execute_impl(const stream_t *g_stream,

constant_cache_t::cached_t c_buffer;
if (enabled_constant_cache()) {
const size_t encoded_key
= encode_constant_cache_key(inputs, const_md_hash_);
std::promise<constant_cache_t::cached_t> c_promise;
constant_cache_t::value_t cached_value
= dnnl_constant_cache_get_or_add(p_engine_, constant_key_,
= dnnl_constant_cache_get_or_add(p_engine_, encoded_key,
memory_planner_.total_internal_persistent_size(),
c_promise.get_future());
bool is_from_cache = cached_value.valid();
Expand Down Expand Up @@ -212,9 +214,11 @@ status_t group_norm_fwd_t::sycl_execute_impl(const stream_t *g_stream,

constant_cache_t::cached_t c_buffer;
if (enabled_constant_cache()) {
const size_t encoded_key
= encode_constant_cache_key(inputs, const_md_hash_);
std::promise<constant_cache_t::cached_t> c_promise;
constant_cache_t::value_t cached_value
= dnnl_constant_cache_get_or_add(p_engine_, constant_key_,
= dnnl_constant_cache_get_or_add(p_engine_, encoded_key,
memory_planner_.total_internal_persistent_size(),
c_promise.get_future());
bool is_from_cache = cached_value.valid();
Expand Down Expand Up @@ -287,9 +291,11 @@ status_t group_norm_fwd_t::ocl_execute_impl(const stream_t *g_stream,

constant_cache_t::cached_t c_buffer;
if (enabled_constant_cache()) {
const size_t encoded_key
= encode_constant_cache_key(inputs, const_md_hash_);
std::promise<constant_cache_t::cached_t> c_promise;
constant_cache_t::value_t cached_value
= dnnl_constant_cache_get_or_add(p_engine_, constant_key_,
= dnnl_constant_cache_get_or_add(p_engine_, encoded_key,
memory_planner_.total_internal_persistent_size(),
c_promise.get_future());
bool is_from_cache = cached_value.valid();
Expand Down
2 changes: 1 addition & 1 deletion src/graph/backend/dnnl/kernels/group_norm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ struct group_norm_fwd_t : public kernel_base_t {

std::function<std::shared_ptr<execution_args_set_t>()> resource_ctor_;

constant_cache_t::key_t constant_key_ = 0;
size_t const_md_hash_ = 0;

public:
group_norm_fwd_t() {
Expand Down
13 changes: 13 additions & 0 deletions src/graph/backend/dnnl/kernels/kernel_base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,19 @@ bool kernel_base_t::enabled_constant_cache() const {
return enabled;
}

size_t kernel_base_t::encode_constant_cache_key(
const std::vector<tensor_t> &inputs, size_t cache_key) const {
// Encode the constant memory address into cache key for differentiation
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With this change, the original constant_key_ is no longer a cache key. I would suggest renaming these variables as well as the generate_constant_cache_key function.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the review! Renamed to const_md_hash_ and generate_constant_md_hash accordingly; please review again.

size_t encoded_cache_key = cache_key;
for (const auto &in : inputs) {
if (logical_tensor_wrapper_t(in.get_logical_tensor()).is_constant()) {
encoded_cache_key = hash_combine(encoded_cache_key,
reinterpret_cast<uintptr_t>(in.get_data_handle()));
}
}
return encoded_cache_key;
}

// Returns a read-only reference to the in-place pairs held by this kernel
// (the inplace_pairs_ member). No copy is made.
const std::vector<inplace_pair_t> &kernel_base_t::get_inplace_pairs() const {
    return inplace_pairs_;
}
Expand Down
3 changes: 3 additions & 0 deletions src/graph/backend/dnnl/kernels/kernel_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ struct kernel_base_t {

bool enabled_constant_cache() const;

// Derives the key used to look up the constant cache at execution time.
// Starting from `cache_key`, folds in (via hash_combine) the data-handle
// address of every tensor in `inputs` whose logical tensor is marked
// constant, and returns the combined value. Non-constant inputs do not
// affect the result.
size_t encode_constant_cache_key(
const std::vector<tensor_t> &inputs, size_t cache_key) const;

const std::vector<inplace_pair_t> &get_inplace_pairs() const;

protected:
Expand Down
14 changes: 10 additions & 4 deletions src/graph/backend/dnnl/kernels/large_partition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ status_t larger_partition_kernel_t::compile_impl(
return this->memory_planner_.get_exec_args_set().clone();
};

constant_key_ = generate_constant_cache_key(part->id(),
const_md_hash_ = generate_constant_md_hash(part->id(),
memory_planner_.get_exec_args_set().get_persistent_mem_desc_list());

return status::success;
Expand All @@ -248,9 +248,11 @@ status_t larger_partition_kernel_t::execute_impl(const stream_t *g_stream,

constant_cache_t::cached_t c_buffer;
if (enabled_constant_cache()) {
const size_t encoded_key
= encode_constant_cache_key(inputs, const_md_hash_);
std::promise<constant_cache_t::cached_t> c_promise;
constant_cache_t::value_t cached_value
= dnnl_constant_cache_get_or_add(p_engine_, constant_key_,
= dnnl_constant_cache_get_or_add(p_engine_, encoded_key,
memory_planner_.total_internal_persistent_size(),
c_promise.get_future());
bool is_from_cache = cached_value.valid();
Expand Down Expand Up @@ -316,9 +318,11 @@ status_t larger_partition_kernel_t::sycl_execute_impl(const stream_t *g_stream,

constant_cache_t::cached_t c_buffer;
if (enabled_constant_cache()) {
const size_t encoded_key
= encode_constant_cache_key(inputs, const_md_hash_);
std::promise<constant_cache_t::cached_t> c_promise;
constant_cache_t::value_t cached_value
= dnnl_constant_cache_get_or_add(p_engine_, constant_key_,
= dnnl_constant_cache_get_or_add(p_engine_, encoded_key,
memory_planner_.total_internal_persistent_size(),
c_promise.get_future());
bool is_from_cache = cached_value.valid();
Expand Down Expand Up @@ -389,9 +393,11 @@ status_t larger_partition_kernel_t::ocl_execute_impl(const stream_t *g_stream,

constant_cache_t::cached_t c_buffer;
if (enabled_constant_cache()) {
const size_t encoded_key
= encode_constant_cache_key(inputs, const_md_hash_);
std::promise<constant_cache_t::cached_t> c_promise;
constant_cache_t::value_t cached_value
= dnnl_constant_cache_get_or_add(p_engine_, constant_key_,
= dnnl_constant_cache_get_or_add(p_engine_, encoded_key,
memory_planner_.total_internal_persistent_size(),
c_promise.get_future());
bool is_from_cache = cached_value.valid();
Expand Down
2 changes: 1 addition & 1 deletion src/graph/backend/dnnl/kernels/large_partition.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class larger_partition_kernel_t : public kernel_base_t {

std::function<std::shared_ptr<execution_args_set_t>()> resource_ctor_;

constant_cache_t::key_t constant_key_ = 0;
size_t const_md_hash_ = 0;

std::once_flag once_flag_;
subgraph_visualizer_t vis_;
Expand Down
Loading
Loading