Support unfixed kv heads number #1416

Merged
41 changes: 20 additions & 21 deletions src/cpp/src/cache_manager.hpp
@@ -46,8 +46,6 @@ class CacheManager {
}
OPENVINO_ASSERT(m_key_cache.size() == m_value_cache.size());
m_num_allocated_kv_blocks = num_kv_blocks;
- ov::Shape value_cache_shape = set_first_dim_and_make_static(m_device_config.get_value_cache_shape(), num_kv_blocks);
- ov::Shape key_cache_shape = set_first_dim_and_make_static(m_device_config.get_key_cache_shape(), num_kv_blocks);

const std::string device_name = m_device_config.get_device();

@@ -56,6 +54,8 @@ class CacheManager {

if (device_name.find("GPU") == std::string::npos) {// Allocate KV caches
for (size_t decoder_layer_id = 0; decoder_layer_id < m_device_config.get_num_layers(); ++decoder_layer_id) {
+ ov::Shape value_cache_shape = set_first_dim_and_make_static(m_device_config.get_value_cache_shape(decoder_layer_id), num_kv_blocks);
+ ov::Shape key_cache_shape = set_first_dim_and_make_static(m_device_config.get_key_cache_shape(decoder_layer_id), num_kv_blocks);
ov::Tensor key_cache(m_device_config.get_cache_precision(), key_cache_shape);
ov::Tensor value_cache(m_device_config.get_cache_precision(), value_cache_shape);

@@ -104,6 +104,8 @@ class CacheManager {
} else {
auto remote_context = m_core.get_default_context(device_name);
for (size_t decoder_layer_id = 0; decoder_layer_id < m_device_config.get_num_layers(); ++decoder_layer_id) {
+ ov::Shape value_cache_shape = set_first_dim_and_make_static(m_device_config.get_value_cache_shape(decoder_layer_id), num_kv_blocks);
+ ov::Shape key_cache_shape = set_first_dim_and_make_static(m_device_config.get_key_cache_shape(decoder_layer_id), num_kv_blocks);
ov::Tensor key_cache = remote_context.create_tensor(m_device_config.get_cache_precision(),
key_cache_shape);
ov::Tensor value_cache = remote_context.create_tensor(m_device_config.get_cache_precision(),
@@ -142,30 +144,27 @@ class CacheManager {
}

void copy_blocks(const std::map<size_t, std::list<size_t>>& block_copy_map) {
- ov::Shape key_shape = set_first_dim_and_make_static(m_device_config.get_key_cache_shape(), m_num_allocated_kv_blocks);
- ov::Shape value_shape = set_first_dim_and_make_static(m_device_config.get_value_cache_shape(), m_num_allocated_kv_blocks);
-
- ov::Coordinate key_src_start_roi(key_shape.size(), 0);
- ov::Coordinate key_src_end_roi = key_shape;
- ov::Coordinate key_dst_start_roi(key_shape.size(), 0);
- ov::Coordinate key_dst_end_roi = key_shape;
-
- ov::Coordinate value_src_start_roi(value_shape.size(), 0);
- ov::Coordinate value_src_end_roi = value_shape;
- ov::Coordinate value_dst_start_roi(value_shape.size(), 0);
- ov::Coordinate value_dst_end_roi = value_shape;

for (const auto & blocks_pair : block_copy_map) {
size_t src_block_id = blocks_pair.first;
- key_src_end_roi[0] = (key_src_start_roi[0] = src_block_id) + 1;
- value_src_end_roi[0] = (value_src_start_roi[0] = src_block_id) + 1;

const std::list<size_t>& dst_block_ids = blocks_pair.second;
for (size_t dst_block_id : dst_block_ids) {
- key_dst_end_roi[0] = (key_dst_start_roi[0] = dst_block_id) + 1;
- value_dst_end_roi[0] = (value_dst_start_roi[0] = dst_block_id) + 1;

for (size_t decoder_layer_id = 0; decoder_layer_id < m_device_config.get_num_layers(); ++decoder_layer_id) {
+ ov::Shape key_shape = set_first_dim_and_make_static(m_device_config.get_key_cache_shape(decoder_layer_id), m_num_allocated_kv_blocks);
+ ov::Shape value_shape = set_first_dim_and_make_static(m_device_config.get_value_cache_shape(decoder_layer_id), m_num_allocated_kv_blocks);
+ ov::Coordinate key_src_start_roi(key_shape.size(), 0);
+ ov::Coordinate key_src_end_roi = key_shape;
+ ov::Coordinate key_dst_start_roi(key_shape.size(), 0);
+ ov::Coordinate key_dst_end_roi = key_shape;
+
+ ov::Coordinate value_src_start_roi(value_shape.size(), 0);
+ ov::Coordinate value_src_end_roi = value_shape;
+ ov::Coordinate value_dst_start_roi(value_shape.size(), 0);
+ ov::Coordinate value_dst_end_roi = value_shape;
+ key_src_end_roi[0] = (key_src_start_roi[0] = src_block_id) + 1;
+ value_src_end_roi[0] = (value_src_start_roi[0] = src_block_id) + 1;
+ key_dst_end_roi[0] = (key_dst_start_roi[0] = dst_block_id) + 1;
+ value_dst_end_roi[0] = (value_dst_start_roi[0] = dst_block_id) + 1;

ov::Tensor key_src_cache_roi(m_key_cache[decoder_layer_id], key_src_start_roi, key_src_end_roi);
ov::Tensor key_dst_cache_roi(m_key_cache[decoder_layer_id], key_dst_start_roi, key_dst_end_roi);

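With per-layer cache shapes, `copy_blocks` now rebuilds the ROI coordinates inside the layer loop, since each layer's tensors can differ in KV-head count. Below is a minimal sketch of the same ROI-copy pattern in isolation — a hypothetical standalone helper, not the PR's code, assuming the `[num_blocks, num_kv_heads, block_size, head_size]` layout used above:

```cpp
#include <openvino/runtime/tensor.hpp>

// Copy one KV block (the first dimension) from src_block to dst_block within a
// single layer's cache tensor shaped [num_blocks, num_kv_heads, block_size, head_size].
void copy_block(ov::Tensor& cache, size_t src_block, size_t dst_block) {
    const ov::Shape shape = cache.get_shape();
    ov::Coordinate src_begin(shape.size(), 0), src_end = shape;
    ov::Coordinate dst_begin(shape.size(), 0), dst_end = shape;
    src_begin[0] = src_block; src_end[0] = src_block + 1;  // one-block ROI
    dst_begin[0] = dst_block; dst_end[0] = dst_block + 1;

    // ROI tensors are views over the same buffer; copy_to performs the block copy.
    ov::Tensor src_roi(cache, src_begin, src_end);
    ov::Tensor dst_roi(cache, dst_begin, dst_end);
    src_roi.copy_to(dst_roi);
}
```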
61 changes: 40 additions & 21 deletions src/cpp/src/device_config.hpp
@@ -12,8 +12,9 @@
namespace ov::genai {
class DeviceConfig {
ov::element::Type m_kv_cache_type;
- ov::PartialShape m_key_cache_shape, m_value_cache_shape;
- ov::Shape::value_type m_num_kv_heads, m_head_size, m_num_decoder_layers;
+ std::vector<ov::PartialShape> m_key_cache_shape, m_value_cache_shape;
+ std::vector<ov::Shape::value_type> m_num_kv_heads;
+ ov::Shape::value_type m_head_size, m_num_decoder_layers;
size_t m_num_kv_blocks = 0;
size_t m_block_size = 0;
size_t m_cache_size = 0;
@@ -88,11 +89,14 @@ class DeviceConfig {
}
}

- void set_model_params(size_t num_kv_heads, size_t head_size, size_t num_decoder_layers) {
- m_num_kv_heads = num_kv_heads;
+ void set_model_params(std::vector<size_t> num_kv_heads, size_t head_size, size_t num_decoder_layers) {
m_head_size = head_size;
m_num_decoder_layers = num_decoder_layers;

+ m_num_kv_heads.assign(num_kv_heads.begin(), num_kv_heads.end());
+ m_key_cache_shape.reserve(m_num_decoder_layers);
+ m_value_cache_shape.reserve(m_num_decoder_layers);

if (m_device == "CPU") {
// Scale, zero point and quantized data will be stored together.
// The layout for per token per head:
@@ -104,21 +108,32 @@
}

if (m_num_kv_blocks == 0 && m_cache_size > 0) {
+ size_t block_size = 0;
size_t size_in_bytes = m_cache_size * 1024 * 1024 * 1024;
- m_num_kv_blocks = size_in_bytes / (m_num_decoder_layers * 2 * m_num_kv_heads * m_block_size * m_head_size * m_kv_cache_type.size());
+ for (size_t layer_id = 0; layer_id < m_num_decoder_layers; layer_id++) {
+ block_size += 2 * m_num_kv_heads[layer_id] * m_block_size * m_head_size * m_kv_cache_type.size();
+ }
+ m_num_kv_blocks = size_in_bytes / block_size;
}

- m_key_cache_shape = m_value_cache_shape = ov::PartialShape{ov::Dimension::dynamic(),
- ov::Dimension(m_num_kv_heads),
- ov::Dimension(m_block_size),
- ov::Dimension(m_head_size)};
-
- if (m_device.find("GPU") != std::string::npos) {
- // Update key shape, as the key's shape is different from the value's shape
- m_key_cache_shape = ov::PartialShape{ov::Dimension::dynamic(),
- ov::Dimension(m_num_kv_heads),
- ov::Dimension(m_head_size),
- ov::Dimension(m_block_size)};
+ for (size_t layer_id = 0; layer_id < m_num_decoder_layers; layer_id++) {
+ m_key_cache_shape.push_back(ov::PartialShape{ov::Dimension::dynamic(),
+ ov::Dimension(m_num_kv_heads[layer_id]),
+ ov::Dimension(m_block_size),
+ ov::Dimension(m_head_size)});
+
+ m_value_cache_shape.push_back(ov::PartialShape{ov::Dimension::dynamic(),
+ ov::Dimension(m_num_kv_heads[layer_id]),
+ ov::Dimension(m_block_size),
+ ov::Dimension(m_head_size)});
+
+ if (m_device.find("GPU") != std::string::npos) {
+ // Update key shape, as the key's shape is different from the value's shape
+ m_key_cache_shape.push_back(ov::PartialShape{ov::Dimension::dynamic(),
+ ov::Dimension(m_num_kv_heads[layer_id]),
+ ov::Dimension(m_head_size),
+ ov::Dimension(m_block_size)});
+ }
}
}

@@ -134,14 +149,14 @@
return m_num_decoder_layers;
}

- ov::PartialShape get_key_cache_shape() const {
+ ov::PartialShape get_key_cache_shape(size_t id) const {
OPENVINO_ASSERT(m_key_cache_shape.size());
- return m_key_cache_shape;
+ return m_key_cache_shape[id];
}

- ov::PartialShape get_value_cache_shape() const {
+ ov::PartialShape get_value_cache_shape(size_t id) const {
OPENVINO_ASSERT(m_value_cache_shape.size());
- return m_value_cache_shape;
+ return m_value_cache_shape[id];
}

size_t get_num_kv_blocks() const {
@@ -153,7 +168,11 @@
}

size_t get_block_size_in_bytes() const {
- return m_num_decoder_layers * 2 * m_num_kv_heads * m_block_size * m_head_size * get_cache_precision().size();
+ size_t block_size = 0;
+ for (size_t layer_id = 0; layer_id < m_num_decoder_layers; layer_id++) {
+ block_size += 2 * m_num_kv_heads[layer_id] * m_block_size * m_head_size * get_cache_precision().size();
+ }
+ return block_size;
}
};
}
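Both the block-budget computation and `get_block_size_in_bytes()` now sum one term per layer instead of multiplying a single head count by `m_num_decoder_layers`; with a uniform head vector the two agree exactly. A small self-contained sketch of the arithmetic with made-up values (4 layers, non-uniform heads, 2-byte cache elements — none of these numbers come from the PR):

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
    // Illustrative values only: 4 decoder layers with non-uniform KV head counts.
    std::vector<size_t> num_kv_heads = {8, 8, 4, 4};
    const size_t block_size = 32, head_size = 64, elem_bytes = 2;  // fp16 cache
    const size_t cache_bytes = 1ull << 30;                         // 1 GiB budget

    // Mirrors the new loop: each layer contributes key + value (factor 2).
    size_t block_bytes = 0;
    for (size_t heads : num_kv_heads)
        block_bytes += 2 * heads * block_size * head_size * elem_bytes;

    std::printf("bytes per block:  %zu\n", block_bytes);               // 196608
    std::printf("blocks in budget: %zu\n", cache_bytes / block_bytes); // 5461
    return 0;
}
```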
20 changes: 13 additions & 7 deletions src/cpp/src/utils/paged_attention_transformations.cpp
@@ -53,15 +53,21 @@ void set_kv_cache_type_and_shape(std::shared_ptr<ov::Model> model, DeviceConfig&
OPENVINO_ASSERT(key_cache_params.count(key_cache_param_name) != 0, "key_cache.0 tensor not found among model parameters");
ov::PartialShape k_shape = key_cache_params[key_cache_param_name]->get_partial_shape();
OPENVINO_ASSERT(k_shape.rank().get_length() == 3, "KV cache shape is expected to have rank 3, while shape is ", k_shape);
- size_t num_kv_heads = k_shape[1].get_length(), head_size = k_shape[2].get_length();

+ size_t head_size = k_shape[2].get_length();
+ std::vector<size_t> num_kv_heads(num_layers);
+ for (size_t idx = 0; idx < num_layers; idx++) {
+ size_t num_heads = key_cache_params[std::string("key_cache.") + std::to_string(idx)]->get_partial_shape()[1].get_length();
+ num_kv_heads[idx] = num_heads;
+ }
device_config.set_model_params(num_kv_heads, head_size, num_layers);

- for (auto it_k = key_cache_params.begin(), it_v = value_cache_params.begin(); it_k != key_cache_params.end();++it_k, ++it_v) {
- it_k->second->set_element_type(device_config.get_cache_precision());
- it_v->second->set_element_type(device_config.get_cache_precision());
- it_k->second->set_partial_shape(device_config.get_key_cache_shape());
- it_v->second->set_partial_shape(device_config.get_value_cache_shape());
+ for (size_t idx = 0; idx < num_layers; idx++) {
+ auto k = key_cache_params[std::string("key_cache.") + std::to_string(idx)];
+ auto v = value_cache_params[std::string("value_cache.") + std::to_string(idx)];
+ k->set_element_type(device_config.get_cache_precision());
+ v->set_element_type(device_config.get_cache_precision());
+ k->set_partial_shape(device_config.get_key_cache_shape(idx));
+ v->set_partial_shape(device_config.get_value_cache_shape(idx));
+ }

model->validate_nodes_and_infer_types();
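The transformation reads each layer's head count from the model's `key_cache.<idx>` parameter shape before fixing the cache shapes. A hedged sketch of how such a name-to-parameter map could be collected from a transformed model — the helper name and the exact lookup are assumptions for illustration; the PR's surrounding code, elided above, may differ:

```cpp
#include <map>
#include <memory>
#include <string>
#include <openvino/openvino.hpp>

// Gather model parameters whose friendly names start with a given prefix,
// e.g. "key_cache." or "value_cache." after the PagedAttention transformation.
std::map<std::string, std::shared_ptr<ov::op::v0::Parameter>>
collect_cache_params(const std::shared_ptr<ov::Model>& model, const std::string& prefix) {
    std::map<std::string, std::shared_ptr<ov::op::v0::Parameter>> params;
    for (const auto& param : model->get_parameters()) {
        const std::string& name = param->get_friendly_name();
        if (name.rfind(prefix, 0) == 0)  // starts_with
            params.emplace(name, param);
    }
    return params;
}
```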
13 changes: 9 additions & 4 deletions tests/cpp/cache_manager.cpp
@@ -54,7 +54,8 @@ TEST(TestCacheManager, test_cache_size_param) {
const std::string device = "CPU";
ov::genai::DeviceConfig device_config(core, scheduler_config, "CPU");
size_t num_decoder_layers = 12;
- device_config.set_model_params(12, 64, num_decoder_layers);
+ std::vector<size_t> num_kv_heads(12, 12);
+ device_config.set_model_params(num_kv_heads, 64, num_decoder_layers);

ov::InferRequest request = core.compile_model(get_dummy_model(num_decoder_layers)).create_infer_request();
auto cache_manager = std::make_shared<ov::genai::CacheManager>(device_config, request, core);
@@ -76,7 +77,8 @@ TEST(TestCacheManager, test_kv_blocks_param) {
const std::string device = "CPU";
ov::genai::DeviceConfig device_config(core, scheduler_config, "CPU");
size_t num_decoder_layers = 12;
- device_config.set_model_params(12, 64, num_decoder_layers);
+ std::vector<size_t> num_kv_heads(12, 12);
+ device_config.set_model_params(num_kv_heads, 64, num_decoder_layers);

ov::InferRequest request = core.compile_model(get_dummy_model(num_decoder_layers)).create_infer_request();
auto cache_manager = std::make_shared<ov::genai::CacheManager>(device_config, request, core);
@@ -97,9 +99,12 @@ TEST(TestCacheManager, test_dynamic_cache_increase) {
ov::genai::DeviceConfig device_config(core, scheduler_config, "CPU");
size_t num_decoder_layers = 12;
size_t head_size = 64;
- size_t num_kv_heads = 12;
+ std::vector<size_t> num_kv_heads(12, 12);
device_config.set_model_params(num_kv_heads, head_size, num_decoder_layers);
- size_t block_size_in_bytes = num_decoder_layers * 2 * num_kv_heads * device_config.get_block_size() * head_size * device_config.get_cache_precision().size();
+ size_t block_size_in_bytes = 0;
+ for (size_t layer_id = 0; layer_id < num_decoder_layers; layer_id++) {
+ block_size_in_bytes += 2 * num_kv_heads[layer_id] * device_config.get_block_size() * head_size * device_config.get_cache_precision().size();
+ }


ov::InferRequest request = core.compile_model(get_dummy_model(num_decoder_layers)).create_infer_request();
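The rewritten test computes the block footprint with the same per-layer loop as `DeviceConfig`. For a uniform head vector like `num_kv_heads(12, 12)` the loop reduces to the old closed-form product, which a quick standalone check can confirm (illustrative values, not taken from the tests):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

int main() {
    const size_t layers = 12, heads = 12, block = 32, head_size = 64, elem = 2;
    std::vector<size_t> num_kv_heads(layers, heads);  // uniform, as in the tests

    size_t summed = 0;
    for (size_t h : num_kv_heads)
        summed += 2 * h * block * head_size * elem;  // new per-layer formula

    // Old formula: num_layers * 2 * heads * block * head_size * elem.
    assert(summed == layers * 2 * heads * block * head_size * elem);
    return 0;
}
```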
2 changes: 1 addition & 1 deletion tests/cpp/device_config.cpp
@@ -18,7 +18,7 @@ TEST(TestDeviceConfig, kv_cache_precision_u8) {
const std::string device = "CPU";
size_t num_decoder_layers = 12;
size_t head_size = 64, head_size_u8 = head_size + 8;
- size_t num_kv_heads = 12;
+ std::vector<size_t> num_kv_heads(12, 12);

ov::genai::DeviceConfig device_config_default(core, scheduler_config, "CPU");
device_config_default.set_model_params(num_kv_heads, head_size_u8, num_decoder_layers);
2 changes: 1 addition & 1 deletion tests/cpp/scheduler.cpp
@@ -44,7 +44,7 @@ std::shared_ptr<CacheManager> init_cache_manager(SchedulerConfig scheduler_config
size_t num_decoder_layers = 12;
ov::InferRequest request = core.compile_model(get_model(num_decoder_layers)).create_infer_request();
size_t head_size = 64, head_size_u8 = head_size + 8;
- size_t num_kv_heads = 12;
+ std::vector<size_t> num_kv_heads(12, 12);
ov::genai::DeviceConfig device_config(core, scheduler_config, "CPU");
device_config.set_model_params(num_kv_heads, head_size_u8, num_decoder_layers);
return std::make_shared<CacheManager>(device_config, request, core);