Extend input tests

vshampor committed Nov 18, 2024
1 parent eff8595 commit 7932290
Showing 5 changed files with 364 additions and 303 deletions.
@@ -132,14 +132,15 @@ void regmodule_offline_transformations(py::module m) {

     m_offline_transformations.def(
         "paged_attention_transformation",
-        [](std::shared_ptr<ov::Model> model, bool use_block_indices_inputs, bool use_score_outputs) {
+        [](std::shared_ptr<ov::Model> model, bool use_block_indices_inputs, bool use_score_outputs, bool allow_cache_rotation) {
             ov::pass::Manager manager;
-            manager.register_pass<ov::pass::SDPAToPagedAttention>(use_block_indices_inputs, use_score_outputs);
+            manager.register_pass<ov::pass::SDPAToPagedAttention>(use_block_indices_inputs, use_score_outputs, allow_cache_rotation);
             manager.run_passes(model);
         },
         py::arg("model"),
         py::arg("use_block_indices_inputs") = false,
-        py::arg("use_score_outputs") = false);
+        py::arg("use_score_outputs") = false,
+        py::arg("allow_cache_rotation") = false);
 
     m_offline_transformations.def(
         "stateful_to_stateless_transformation",
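The new `allow_cache_rotation` argument defaults to `False`, so existing callers are unaffected. A minimal usage sketch from Python; the `openvino._offline_transformations` import path and the model file name are assumptions, not part of this diff:

    import openvino as ov
    from openvino._offline_transformations import paged_attention_transformation

    core = ov.Core()
    model = core.read_model("model.xml")  # hypothetical model path

    # All three flags default to False; cache rotation must be requested explicitly.
    paged_attention_transformation(
        model,
        use_block_indices_inputs=True,
        use_score_outputs=True,
        allow_cache_rotation=True,
    )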
62 changes: 62 additions & 0 deletions src/core/tests/type_prop/paged_attention.cpp
@@ -0,0 +1,62 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <gtest/gtest.h>

#include "common_test_utils/test_assertions.hpp"
#include "common_test_utils/type_prop.hpp"
#include "openvino/op/paged_attention.hpp"
#include "openvino/openvino.hpp"
#include "openvino/opsets/opset13.hpp"

using namespace ov;
using namespace testing;

TEST(type_prop, paged_attention_static_13_inputs) {
    const auto query = std::make_shared<opset13::Parameter>(element::f32, Shape{3, 4});
    const auto key = std::make_shared<opset13::Parameter>(element::f32, Shape{3, 4});
    const auto value = std::make_shared<opset13::Parameter>(element::f32, Shape{3, 4});
    const auto key_cache = std::make_shared<opset13::Parameter>(element::f32, Shape{6, 2, 5, 4});
    const auto value_cache = std::make_shared<opset13::Parameter>(element::f32, Shape{6, 2, 5, 4});
    const auto past_lens = std::make_shared<opset13::Parameter>(element::i32, Shape{5});
    const auto subsequence_begins = std::make_shared<opset13::Parameter>(element::i32, Shape{5});
    const auto block_indices = std::make_shared<opset13::Parameter>(element::i32, Shape{15});
    const auto block_indices_begins = std::make_shared<opset13::Parameter>(element::i32, Shape{8});
    const auto scale = std::make_shared<opset13::Parameter>(element::f32, Shape{});
    const auto sliding_window = std::make_shared<opset13::Parameter>(element::i32, Shape{});
    const auto alibi_slopes = std::make_shared<opset13::Parameter>(element::f32, Shape{9});
    const auto max_context_len = std::make_shared<opset13::Parameter>(element::i32, Shape{});

    ov::OutputVector args = {query,
                             key,
                             value,
                             key_cache,
                             value_cache,
                             past_lens,
                             subsequence_begins,
                             block_indices,
                             block_indices_begins,
                             scale,
                             sliding_window,
                             alibi_slopes,
                             max_context_len};

    const auto op = std::make_shared<op::PagedAttentionExtension>(args);
    EXPECT_EQ(op->get_output_element_type(0), element::f32);
    EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{3, 4}));
}

TEST(type_prop, paged_attention_static_15_inputs) {
    const auto query = std::make_shared<opset13::Parameter>(element::f32, Shape{3, 4});
    const auto key = std::make_shared<opset13::Parameter>(element::f32, Shape{3, 4});
    const auto value = std::make_shared<opset13::Parameter>(element::f32, Shape{3, 4});
    const auto key_cache = std::make_shared<opset13::Parameter>(element::f32, Shape{6, 2, 5, 4});
    const auto value_cache = std::make_shared<opset13::Parameter>(element::f32, Shape{6, 2, 5, 4});
    const auto past_lens = std::make_shared<opset13::Parameter>(element::i32, Shape{5});
    const auto subsequence_begins = std::make_shared<opset13::Parameter>(element::i32, Shape{5});
    const auto block_indices = std::make_shared<opset13::Parameter>(element::i32, Shape{15});
    const auto block_indices_begins = std::make_shared<opset13::Parameter>(element::i32, Shape{8});
    const auto scale = std::make_shared<opset13::Parameter>(element::f32, Shape{});
    const auto sliding_window = std::make_shared<opset13::Parameter>(element::i32, Shape{});
    const auto alibi_slopes = std::make_shared<opset13::Parameter>(element::f32, Shape{9});
    const auto max_context_len = std::make_shared<opset13::Parameter>(element::i32, Shape{});

    const auto rotation_coefficients = std::make_shared<opset13::Parameter>(element::f32, Shape{12});
    const auto rotated_block_indices = std::make_shared<opset13::Parameter>(element::i32, Shape{3});

    ov::OutputVector args = {query,
                             key,
                             value,
                             key_cache,
                             value_cache,
                             past_lens,
                             subsequence_begins,
                             block_indices,
                             block_indices_begins,
                             scale,
                             sliding_window,
                             alibi_slopes,
                             max_context_len,
                             rotation_coefficients,
                             rotated_block_indices};

    const auto op = std::make_shared<op::PagedAttentionExtension>(args);
    EXPECT_EQ(op->get_output_element_type(0), element::f32);
    EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{3, 4}));
}
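The two cases mirror the transformation's output: 13 inputs in the base configuration, with `rotation_coefficients` and `rotated_block_indices` appended as inputs 14 and 15 when cache rotation is enabled. A hedged sketch of checking this from Python on a transformed model:

    # `model` is an ov.Model already processed by paged_attention_transformation;
    # the type name string is an assumption based on the C++ class name.
    for op in model.get_ops():
        if op.get_type_name() == "PagedAttentionExtension":
            # Expect 13 inputs by default, 15 when allow_cache_rotation=True.
            print(op.get_friendly_name(), op.get_input_size())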

@@ -75,7 +75,7 @@ def main():

         # wrapping in try/catch block to continue printing models even if one has failed
         try:
-            paged_attention_transformation(model.model, use_cache_eviction, use_cache_eviction)
+            paged_attention_transformation(model.model, use_cache_eviction, use_cache_eviction, use_cache_eviction)
         except:
             continue

@@ -85,10 +85,12 @@ def main():
             after_map[op.get_type_name()] = after_map.get(op.get_type_name(), 0) + 1
 
         print(f'\t"{model_id}" : {{', file=file)
-        for op in set(after_map.keys()) | set(before_map.keys()):
+        for op in sorted(set(after_map.keys()) | set(before_map.keys())):
             print(f'\t\t"{op}" : {after_map.get(op, 0) - before_map.get(op, 0)},', file=file)
         print('\t},', file=file)
     print('}', file=file)
 
+    print(f"output written to {OUTPUT_FILE}")
+
 if __name__ == "__main__":
-    main()
+    main()
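Sorting the union of op names before printing makes each per-model block deterministic across runs, so successive report files can be diffed directly; the trailing print records where the report (OUTPUT_FILE) was written.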