From 82b44fab5b538ef9e11ff47fcd245f7885c1a25f Mon Sep 17 00:00:00 2001 From: Ilya Lavrenov Date: Fri, 27 Dec 2024 07:47:50 +0400 Subject: [PATCH] LLM tests restructuring (#1440) - Merged chat scenario tests to test_llm_pipeline.py - Created CB dedicated test_continuous_batching.py file with CB-specific tests (in addition to test_llm_pipeline.py, which cover basic LLM pipeline functionality) CVS-159921 --- .github/labeler.yml | 29 +- .github/workflows/linux.yml | 4 +- .github/workflows/mac.yml | 8 +- .github/workflows/windows.yml | 8 +- src/cpp/src/llm_pipeline.cpp | 12 +- tests/python_tests/common.py | 14 +- tests/python_tests/ov_genai_test_utils.py | 29 +- tests/python_tests/test_chat_generate_api.py | 118 -------- ...emption.py => test_continuous_batching.py} | 165 ++++++++++- ...mizations.py => test_kv_cache_eviction.py} | 4 +- ...t_generate_api.py => test_llm_pipeline.py} | 273 ++++++++++-------- .../python_tests/test_llm_pipeline_static.py | 2 +- tests/python_tests/test_sampling.py | 140 +++------ .../{test_vlm_api.py => test_vlm_pipeline.py} | 0 ...nerate_api.py => test_whisper_pipeline.py} | 0 15 files changed, 418 insertions(+), 388 deletions(-) delete mode 100644 tests/python_tests/test_chat_generate_api.py rename tests/python_tests/{test_preemption.py => test_continuous_batching.py} (62%) rename tests/python_tests/{test_cache_optimizations.py => test_kv_cache_eviction.py} (98%) rename tests/python_tests/{test_generate_api.py => test_llm_pipeline.py} (87%) rename tests/python_tests/{test_vlm_api.py => test_vlm_pipeline.py} (100%) rename tests/python_tests/{test_whisper_generate_api.py => test_whisper_pipeline.py} (100%) diff --git a/.github/labeler.yml b/.github/labeler.yml index c162f6aff4..f618bdb7fc 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -13,17 +13,20 @@ - 'src/python/py_tokenizer.cpp' - 'thirdparty/openvino_tokenizers' - 'tests/python_tests/tokenizer_configs.py' +- 'tests/python_tests/test_tokenizer.py' 'category: LLM': - 'src/cpp/include/openvino/genai/llm_pipeline.hpp' - 'src/cpp/src/llm_pipeline.cpp' +- 'src/cpp/src/lm_encoding.hpp' - 'src/cpp/src/lm_encoding.cpp' - 'src/cpp/src/llm_pipeline_base.hpp' - 'src/cpp/src/llm_pipeline_static.hpp' - 'src/cpp/src/llm_pipeline_static.cpp' +- 'src/cpp/src/text_callback_streamer.cpp' +- 'src/cpp/src/text_callback_streamer.hpp' - 'src/python/py_llm_pipeline.cpp' -- 'tests/python_tests/test_generate_api.py' -- 'tests/python_tests/test_chat_generate_api.py' +- 'tests/python_tests/test_llm_pipeline.py' 'category: sampling': - 'src/cpp/include/openvino/genai/generation_config.hpp' @@ -35,6 +38,7 @@ - 'tests/cpp/logit_filtering.cpp' - 'tests/cpp/generate_config.cpp' - 'tests/cpp/sampler.cpp' +- 'tests/python_tests/test_sampling.py' 'category: LoRA': - 'src/cpp/include/openvino/genai/lora_adapter.hpp' @@ -54,9 +58,12 @@ - 'src/cpp/include/openvino/genai/whisper_pipeline.hpp' - 'src/cpp/src/whisper/**/*' - 'src/cpp/src/whisper_generation_config.cpp' +- 'src/cpp/src/whisper_pipeline_base.hpp' - 'src/cpp/src/whisper_pipeline.cpp' +- 'src/cpp/src/whisper_pipeline_static.cpp' +- 'src/cpp/src/whisper_pipeline_static.hpp' - 'src/python/py_whisper_pipeline.cpp' -- 'tests/python_tests/test_whisper_generate_api.py' +- 'tests/python_tests/test_whisper_pipeline.py' 'category: Python API': - 'src/python/**/*' @@ -65,10 +72,14 @@ - 'src/include/openvino/genai/visual_language/**/*' - 'src/cpp/src/visual_language/**/*' - 'src/python/py_vlm_pipeline.cpp' -- 'tests/python_tests/test_vlm_api.py' +- 'tests/python_tests/test_vlm_pipeline.py' 'category: speculative decoding': - 'src/cpp/src/speculative_decoding/**/*' +- 'tests/cpp/speculative_decoding.cpp' + +'category: prompt lookup': +- 'src/cpp/src/prompt_lookup/**/*' 'category: continuous batching': - 'src/cpp/include/openvino/genai/cache_eviction.hpp' @@ -91,19 +102,19 @@ - 'src/cpp/src/generation_handle.cpp' - 'src/cpp/src/generation_stream.hpp' - 'src/cpp/src/model_runner.hpp' -- 'src/cpp/src/paged_attention_transformations.cpp' -- 'src/cpp/src/paged_attention_transformations.hpp' +- 'src/cpp/src/utils/paged_attention_transformations.cpp' +- 'src/cpp/src/utils/paged_attention_transformations.hpp' - 'src/cpp/src/scheduler.hpp' - 'src/cpp/src/sequence_group.cpp' - 'src/cpp/src/sequence_group.hpp' - 'src/cpp/src/timer.hpp' - 'src/python/py_continuous_batching_pipeline.cpp' -- 'tests/python_tests/test_cache_optimizations.py' -- 'tests/python_tests/test_preemption.py' -- 'tests/python_tests/test_sampling.py' +- 'tests/python_tests/test_continuous_batching.py' +- 'tests/python_tests/test_kv_cache_eviction.py' - 'tests/cpp/block_allocator.cpp' - 'tests/cpp/block_hash_store.cpp' - 'tests/cpp/block_manager.cpp' +- 'tests/cpp/cache_eviction.cpp' - 'tests/cpp/cache_manager.cpp' - 'tests/cpp/device_config.cpp' - 'tests/cpp/scheduler.cpp' diff --git a/.github/workflows/linux.yml b/.github/workflows/linux.yml index 6c94a907ea..9b21491f9b 100644 --- a/.github/workflows/linux.yml +++ b/.github/workflows/linux.yml @@ -268,9 +268,9 @@ jobs: matrix: test: - name: 'Whisper' - cmd: 'tests/python_tests/test_whisper_generate_api.py' + cmd: 'tests/python_tests/test_whisper_pipeline.py' - name: 'LLM & VLM' - cmd: 'tests/python_tests --ignore tests/python_tests/test_whisper_generate_api.py' + cmd: 'tests/python_tests --ignore tests/python_tests/test_whisper_pipeline.py' defaults: run: shell: bash diff --git a/.github/workflows/mac.yml b/.github/workflows/mac.yml index a9af13bc66..4d9b7f032b 100644 --- a/.github/workflows/mac.yml +++ b/.github/workflows/mac.yml @@ -178,7 +178,7 @@ jobs: if: | always() && (needs.openvino_download.outputs.status == 'success' || needs.openvino_build.result == 'success') - timeout-minutes: 90 + timeout-minutes: 120 defaults: run: shell: bash @@ -235,7 +235,7 @@ jobs: python -m pip install . --verbose --find-links ${OV_INSTALL_DIR}/wheels python -c "from openvino_genai import LLMPipeline" python -m pip install ./tools/who_what_benchmark --find-links ${OV_INSTALL_DIR}/wheels - python -m pytest -v ./tests/python_tests/ --ignore ./tests/python_tests/test_whisper_generate_api.py --ignore ./tests/python_tests/test_vlm_api.py -k "not test_set_chat_template" + python -m pytest -v ./tests/python_tests/ --ignore ./tests/python_tests/test_whisper_pipeline.py --ignore ./tests/python_tests/test_vlm_pipeline.py -k "not test_set_chat_template" genai_python_lib_whisper: name: OpenVINO genai extension whisper tests (cmake + wheel) @@ -290,7 +290,7 @@ jobs: run: | source ${OV_INSTALL_DIR}/setupvars.sh python -m pip install ./thirdparty/openvino_tokenizers/[transformers] -r ./tests/python_tests/requirements.txt --find-links ${OV_INSTALL_DIR}/wheels - python -m pytest -v ./tests/python_tests/test_whisper_generate_api.py -k test_smoke + python -m pytest -v ./tests/python_tests/test_whisper_pipeline.py -k test_smoke env: PYTHONPATH: "./build/:$PYTHONPATH" @@ -300,7 +300,7 @@ jobs: python -m pip install . --verbose --find-links ${OV_INSTALL_DIR}/wheels python -c "from openvino_genai import LLMPipeline" python -m pip install ./tools/who_what_benchmark --find-links ${OV_INSTALL_DIR}/wheels - python -m pytest -v ./tests/python_tests/test_whisper_generate_api.py -k "not test_smoke" + python -m pytest -v ./tests/python_tests/test_whisper_pipeline.py -k "not test_smoke" genai_package: name: OpenVINO genai extension (install to OpenVINO package) diff --git a/.github/workflows/windows.yml b/.github/workflows/windows.yml index f88bc4c6f3..fc63129281 100644 --- a/.github/workflows/windows.yml +++ b/.github/workflows/windows.yml @@ -245,7 +245,7 @@ jobs: . "${{ env.OV_INSTALL_DIR }}/setupvars.ps1" python -m pip install . --verbose --find-links ${env:OV_INSTALL_DIR}/wheels python -m pip install ./tools/who_what_benchmark --find-links ${env:OV_INSTALL_DIR}/wheels - python -m pytest -v ./tests/python_tests/ --ignore ./tests/python_tests/test_whisper_generate_api.py --ignore ./tests/python_tests/test_vlm_api.py -k "not test_set_chat_template" + python -m pytest -v ./tests/python_tests/ --ignore ./tests/python_tests/test_whisper_pipeline.py --ignore ./tests/python_tests/test_vlm_pipeline.py -k "not test_set_chat_template" genai_python_lib_whisper: name: OpenVINO genai extension whisper tests (cmake + wheel) @@ -301,7 +301,7 @@ jobs: run: | . "${{ env.OV_INSTALL_DIR }}/setupvars.ps1" python -m pip install ./thirdparty/openvino_tokenizers/[transformers] -r ./tests/python_tests/requirements.txt --find-links ${env:OV_INSTALL_DIR}/wheels - python -m pytest -v ./tests/python_tests/test_whisper_generate_api.py -k test_smoke + python -m pytest -v ./tests/python_tests/test_whisper_pipeline.py -k test_smoke env: PYTHONPATH: "./build/" # cmd evaluates variables in a different way. Setting PYTHONPATH before setupvars.bat instead of doing that after solves that. @@ -310,7 +310,7 @@ jobs: . "${{ env.OV_INSTALL_DIR }}/setupvars.ps1" python -m pip install . --verbose --find-links ${env:OV_INSTALL_DIR}/wheels python -m pip install ./tools/who_what_benchmark --find-links ${env:OV_INSTALL_DIR}/wheels - python -m pytest -v ./tests/python_tests/test_whisper_generate_api.py -k "not test_smoke" + python -m pytest -v ./tests/python_tests/test_whisper_pipeline.py -k "not test_smoke" genai_python_lib_vlm: name: OpenVINO genai VLM tests (cmake + wheel) @@ -366,7 +366,7 @@ jobs: run: | . "${{ env.OV_INSTALL_DIR }}/setupvars.ps1" python -m pip install ./thirdparty/openvino_tokenizers/[transformers] -r ./tests/python_tests/requirements.txt --find-links ${env:OV_INSTALL_DIR}/wheels - python -m pytest -v ./tests/python_tests/test_vlm_api.py + python -m pytest -v ./tests/python_tests/test_vlm_pipeline.py env: PYTHONPATH: "./build/" # cmd evaluates variables in a different way. Setting PYTHONPATH before setupvars.bat instead of doing that after solves that. diff --git a/src/cpp/src/llm_pipeline.cpp b/src/cpp/src/llm_pipeline.cpp index be5ecf17fa..5e448fe88c 100644 --- a/src/cpp/src/llm_pipeline.cpp +++ b/src/cpp/src/llm_pipeline.cpp @@ -703,8 +703,7 @@ std::pair split_model_descr(const ov::An ov::genai::LLMPipeline::LLMPipeline( const ov::InferRequest& request, const ov::genai::Tokenizer& tokenizer, - OptionalGenerationConfig generation_config -) { + OptionalGenerationConfig generation_config) { auto start_time = std::chrono::steady_clock::now(); m_pimpl = std::make_unique(request, tokenizer, generation_config); auto stop_time = std::chrono::steady_clock::now(); @@ -715,8 +714,7 @@ ov::genai::LLMPipeline::LLMPipeline( const std::filesystem::path& models_path, const ov::genai::Tokenizer& tokenizer, const std::string& device, - const ov::AnyMap& properties -){ + const ov::AnyMap& properties) { auto start_time = std::chrono::steady_clock::now(); if (properties.find(ov::genai::scheduler_config.name()) != properties.end() || properties.find(utils::DRAFT_MODEL_ARG_NAME) != properties.end() || @@ -735,8 +733,7 @@ ov::genai::LLMPipeline::LLMPipeline( ov::genai::LLMPipeline::LLMPipeline( const std::filesystem::path& models_path, const std::string& device, - const ov::AnyMap& config -){ + const ov::AnyMap& config) { auto start_time = std::chrono::steady_clock::now(); if (config.find(ov::genai::scheduler_config.name()) != config.end() || @@ -759,8 +756,7 @@ ov::genai::LLMPipeline::LLMPipeline( const ov::genai::Tokenizer& tokenizer, const std::string& device, const ov::AnyMap& config, - const ov::genai::GenerationConfig& generation_config -){ + const ov::genai::GenerationConfig& generation_config) { auto [core_properties, plugin_config] = ov::genai::utils::split_core_compile_config(config); auto start_time = std::chrono::steady_clock::now(); diff --git a/tests/python_tests/common.py b/tests/python_tests/common.py index 7e3c075405..f940d272ed 100644 --- a/tests/python_tests/common.py +++ b/tests/python_tests/common.py @@ -364,18 +364,6 @@ def run_continuous_batching( return output -def read_models_list(file_name: str): - models = [] - with open(file_name) as f: - for model_name in f: - model_name = model_name.strip() - # skip comment in model scope file - if model_name.startswith('#'): - continue - models.append(model_name) - return models - - def compare_results(hf_result: GenerationResult, ov_result: GenerationResult, generation_config: GenerationConfig): if generation_config.is_beam_search(): assert len(hf_result.m_scores) == len(ov_result.m_scores) @@ -447,7 +435,7 @@ def generate_and_compare_with_reference_text(models_path: Path, prompts: List[st assert ref_text == ov_text -def run_test_pipeline(tmp_path: str, model_id: str, scheduler_params: dict = None, generation_config = None): +def run_continuous_batching_pipeline_test(tmp_path: str, model_id: str, scheduler_params: dict = None, generation_config = None): prompts, generation_configs = get_test_dataset() scheduler_config = get_scheduler_config(scheduler_params) diff --git a/tests/python_tests/ov_genai_test_utils.py b/tests/python_tests/ov_genai_test_utils.py index 87b2147bcd..3fc89cb8a7 100644 --- a/tests/python_tests/ov_genai_test_utils.py +++ b/tests/python_tests/ov_genai_test_utils.py @@ -32,7 +32,7 @@ def get_models_list(): "HuggingFaceH4/zephyr-7b-beta", "ikala/redpajama-3b-chat", "mistralai/Mistral-7B-v0.1", - + # "meta-llama/Llama-2-7b-chat-hf", # Cannot be downloaded without access token # "google/gemma-2b-it", # Cannot be downloaded without access token. # "google/gemma-7b-it", # Cannot be downloaded without access token. @@ -49,7 +49,7 @@ def get_models_list(): model_ids = precommit_models else: model_ids = nightly_models - + if pytest.selected_model_ids: model_ids = [model_id for model_id in model_ids if model_id in pytest.selected_model_ids.split(' ')] # pytest.set_trace() @@ -82,30 +82,30 @@ def get_chat_models_list(): @functools.lru_cache(1) def read_model(params, **tokenizer_kwargs): model_id, path = params - + from optimum.intel.openvino import OVModelForCausalLM from transformers import AutoTokenizer hf_tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) if (path / "openvino_model.xml").exists(): - opt_model = OVModelForCausalLM.from_pretrained(path, trust_remote_code=True, + opt_model = OVModelForCausalLM.from_pretrained(path, trust_remote_code=True, compile=False, device='CPU') else: - ov_tokenizer, ov_detokenizer = openvino_tokenizers.convert_tokenizer(hf_tokenizer, + ov_tokenizer, ov_detokenizer = openvino_tokenizers.convert_tokenizer(hf_tokenizer, with_detokenizer=True, **tokenizer_kwargs) openvino.save_model(ov_tokenizer, path / "openvino_tokenizer.xml") openvino.save_model(ov_detokenizer, path / "openvino_detokenizer.xml") - + # to store tokenizer config jsons with special tokens hf_tokenizer.save_pretrained(path) - - opt_model = OVModelForCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True, + + opt_model = OVModelForCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True, compile=False, device='CPU', load_in_8bit=False) opt_model.generation_config.save_pretrained(path) opt_model.config.save_pretrained(path) opt_model.save_pretrained(path) - + return ( model_id, path, @@ -116,11 +116,11 @@ def read_model(params, **tokenizer_kwargs): # in OpenVINO GenAI this parameter is called stop_criteria, -# while in HF it's called early_stopping. +# while in HF it's called early_stopping. # HF values True, False and "never" correspond to OV GenAI values "EARLY", "HEURISTIC" and "NEVER" STOP_CRITERIA_MAP = { - ov_genai.StopCriteria.NEVER: "never", - ov_genai.StopCriteria.EARLY: True, + ov_genai.StopCriteria.NEVER: "never", + ov_genai.StopCriteria.EARLY: True, ov_genai.StopCriteria.HEURISTIC: False } @@ -137,6 +137,7 @@ def model_tmp_path(tmpdir_factory): shutil.copy(src_file, temp_path / src_file.name) yield model_id, Path(temp_path) + @pytest.fixture(scope="module") def model_tokenizers_path_tmp_path(tmpdir_factory): model_id, path, _, _, _ = read_model(get_models_list()[0]) @@ -146,7 +147,7 @@ def model_tokenizers_path_tmp_path(tmpdir_factory): # There was no easy way to add tokens to IR in tests, so we remove them # and set tokens in configs and to check if they are read and validated correctly. import openvino as ov - + # copy openvino converted model and tokenizers for pattern in ['*.xml', '*.bin']: for src_file in path.glob(pattern): @@ -162,7 +163,7 @@ def model_tokenizers_path_tmp_path(tmpdir_factory): ov_model.set_rt_info("eos_token_id", "") ov_model.set_rt_info("chat_template", "") ov.save_model(ov_model, str(temp_path / src_file.name)) - + if src_file in ['openvino_tokenizer.bin', 'openvino_detokenizer.bin']: continue if src_file.is_file(): diff --git a/tests/python_tests/test_chat_generate_api.py b/tests/python_tests/test_chat_generate_api.py deleted file mode 100644 index 07b4f7c15f..0000000000 --- a/tests/python_tests/test_chat_generate_api.py +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright (C) 2023-2024 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -import openvino_genai as ov_genai -import pytest -from typing import Dict, Tuple - -from ov_genai_test_utils import ( - get_chat_models_list, - read_model, - get_continuous_batching, -) - - -generation_configs = [ - dict(do_sample=False, max_new_tokens=20), - dict(do_sample=False, num_beam_groups=3, num_beams=15, num_return_sequences=1, max_new_tokens=10, diversity_penalty=1.0) -] - - -questions = [ - '1+1=', - 'What is the previous answer?', - 'Why is the Sun yellow?', - 'What was my first question?' -] - - -@pytest.mark.parametrize("generation_config", generation_configs) -@pytest.mark.parametrize("model_descr", get_chat_models_list()) -@pytest.mark.precommit -@pytest.mark.nightly -def test_chat_compare_with_HF(model_descr, generation_config: Dict): - chat_history_hf = [] - chat_history_ov = [] - chat_prompt = '' - - # Will set add_special_tokens=False inside pipeline when start_chat() is called. - model_id, path, tokenizer, model_opt, pipe = read_model((model_descr[0], model_descr[1] / '_test_chat')) - - pipe.start_chat() - for prompt in questions: - chat_history_hf.append({'role': 'user', 'content': prompt}) - chat_history_ov.append({'role': 'user', 'content': prompt}) - - chat_prompt = tokenizer.apply_chat_template(chat_history_hf, tokenize=False, add_generation_prompt=True) - tokenized = tokenizer(chat_prompt, return_tensors='pt', add_special_tokens=False) - - answer = model_opt.generate(**tokenized, **generation_config) - answer_str = tokenizer.decode(answer[0, tokenized['input_ids'].numel():], skip_special_tokens=True) - chat_history_hf.append({'role': 'assistant', 'content': answer_str}) - - answer_ov = pipe.generate(prompt, **generation_config) - chat_history_ov.append({'role': 'assistant', 'content': answer_ov}) - - pipe.finish_chat() - - if chat_history_ov != chat_history_hf: - print(f'hf_output: {chat_history_hf}') - print(f'ov_output: {chat_history_ov}') - - assert chat_history_ov == chat_history_hf - - -@pytest.mark.parametrize("generation_config", generation_configs) -@pytest.mark.parametrize("model_descr", get_chat_models_list()) -@pytest.mark.precommit -@pytest.mark.nightly -def test_chat_compare_text_history_with_HF(model_descr, generation_config: Dict): - # compares with HF when history in ov_genai is save as a text - chat_history_hf = [] - chat_history_ov = [] - chat_prompt = '' - - # HF in chat scenario does not add special tokens, but openvino tokenizer by default is converted with add_special_tokens=True. - # Need to regenerate openvino_tokenizer/detokenizer. - model_id, path, hf_tokenizer, model_opt, ov_pipe = read_model((model_descr[0], model_descr[1] / '_test_chat'), add_special_tokens=False) - ov_tokenizer = ov_pipe.get_tokenizer() - - for prompt in questions: - chat_history_hf.append({'role': 'user', 'content': prompt}) - chat_history_ov.append({'role': 'user', 'content': prompt}) - - chat_prompt = hf_tokenizer.apply_chat_template(chat_history_hf, tokenize=False, add_generation_prompt=True) - tokenized = hf_tokenizer(chat_prompt, return_tensors='pt', add_special_tokens=False) - - answer = model_opt.generate(**tokenized, **generation_config) - answer_str = hf_tokenizer.decode(answer[0, tokenized['input_ids'].numel():], skip_special_tokens=True) - chat_history_hf.append({'role': 'assistant', 'content': answer_str}) - - chat_prompt = ov_tokenizer.apply_chat_template(chat_history_ov, add_generation_prompt=True) - answer_ov = ov_pipe.generate(chat_prompt, **generation_config) - chat_history_ov.append({'role': 'assistant', 'content': answer_ov}) - - if chat_history_ov != chat_history_hf: - print(f'hf_output: {chat_history_hf}') - print(f'ov_output: {chat_history_ov}') - - assert chat_history_ov == chat_history_hf - - -@pytest.mark.parametrize("generation_config", generation_configs[1:]) -@pytest.mark.parametrize("model_descr", get_chat_models_list()) -@pytest.mark.precommit -def test_chat_continuous_batching_vs_stateful(model_descr, generation_config: Dict): - model_id, path, hf_tokenizer, opt_model, ov_stateful_pipe = read_model((model_descr[0], model_descr[1] / '_test_chat')) - cb_pipe = get_continuous_batching(path) - - ov_stateful_pipe.start_chat() - cb_pipe.start_chat() - - for question in questions: - generated = cb_pipe.generate(question, **generation_config) - reference = ov_stateful_pipe.generate(question, **generation_config) - assert generated == reference - - # Test that finish_chat() doesn't fail just in case. - cb_pipe.finish_chat() diff --git a/tests/python_tests/test_preemption.py b/tests/python_tests/test_continuous_batching.py similarity index 62% rename from tests/python_tests/test_preemption.py rename to tests/python_tests/test_continuous_batching.py index 7c648e73dc..3a1e9fa092 100644 --- a/tests/python_tests/test_preemption.py +++ b/tests/python_tests/test_continuous_batching.py @@ -1,15 +1,172 @@ # Copyright (C) 2018-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +import os import pytest +import math +from typing import Dict + +from pathlib import Path +from openvino_genai import ContinuousBatchingPipeline, GenerationConfig, Tokenizer -from openvino_genai import GenerationConfig from common import get_hugging_face_model_and_tokenizer, save_ov_model_from_optimum, generate_and_compare_with_reference_text, \ - get_scheduler_config, run_test_pipeline, get_beam_search, get_greedy, \ + get_scheduler_config, get_greedy, run_continuous_batching_pipeline_test, get_beam_search, get_greedy, \ get_multinomial_all_parameters, get_multinomial_temperature_and_num_return_sequence, \ get_multinomial_temperature_and_top_k, get_multinomial_temperature, get_multinomial_temperature_and_top_p from test_sampling import RandomSamplingTestStruct, get_current_platform_ref_texts +from ov_genai_test_utils import ( + get_chat_models_list, + read_model, + get_continuous_batching, +) + +def read_models_list(file_name: str): + models = [] + with open(file_name) as f: + for model_name in f: + model_name = model_name.strip() + # skip comment in model scope file + if model_name.startswith('#'): + continue + models.append(model_name) + return models + +# +# e2e tests on random and real models +# + +@pytest.mark.precommit +@pytest.mark.parametrize("model_id", read_models_list(os.path.join(os.path.dirname(os.path.realpath(__file__)), "models", "precommit"))) +def test_e2e_precommit(tmp_path, model_id): + run_continuous_batching_pipeline_test(tmp_path, model_id) + + +@pytest.mark.nightly +@pytest.mark.parametrize("model_id", read_models_list(os.path.join(os.path.dirname(os.path.realpath(__file__)), "models", "nightly"))) +def test_e2e_nightly(tmp_path, model_id): + run_continuous_batching_pipeline_test(tmp_path, model_id) + + +@pytest.mark.real_models +@pytest.mark.parametrize("model_id", read_models_list(os.path.join(os.path.dirname(os.path.realpath(__file__)), "models", "real_models"))) +def test_e2e_real_models(tmp_path, model_id): + run_continuous_batching_pipeline_test(tmp_path, model_id) + +# +# Comparison with stateful +# TODO: remove these tests once test_llm_pipeline.py are generalized and parametrized to test both Stateful and PA paths +# + +test_configs = [ + dict(max_new_tokens=20), + dict(max_new_tokens=200, ignore_eos=True), + dict(max_new_tokens=20, num_beam_groups=3, num_beams=15, diversity_penalty=1.0) +] +batched_prompts = [ + ['table is made', 'They sky is blue because', 'Difference between Jupiter and Mars is that'], + ['hello', 'Here is the longest nowel ever: '], + ['Alan Turing was a', 'return 0', '你好! 你好嗎?'], + ['table is made', 'table is made [force left pad tokens]'] +] +@pytest.mark.parametrize("generation_config", test_configs) +@pytest.mark.parametrize("prompt", batched_prompts[1:]) # num_beams=15 diverges on the first prompt. +@pytest.mark.precommit +def test_continuous_batching_vs_stateful(prompt, generation_config): + model_id, path, tokenizer, model, stateful = read_model(( + "facebook/opt-125m", + Path("opt-125m") + )) + cb = get_continuous_batching(path) + generated = cb.generate(prompt, **generation_config) + reference = stateful.generate(prompt, **generation_config) + assert generated.texts == reference.texts + if 1 != generation_config.get("num_return_sequences", 1): + # Stateful puts zeroes to generated.scores. Don't compare them. + for gen, ref in zip(generated.scores, reference.scores): + assert math.isclose(gen, ref, abs_tol=0.0003) + + +prompts = ['The Sun is yellow because', 'Difference between Jupiter and Mars is that', 'table is made of'] +@pytest.mark.parametrize("prompt", prompts) +@pytest.mark.precommit +def test_cb_streamer_vs_return_vs_stateful(prompt): + model_id, path, hf_tokenizer, opt_model, ov_pipe = read_model(( + "facebook/opt-125m", + Path("opt-125m") + )) + cb_pipe = get_continuous_batching(path) + streamed = [] + generated = cb_pipe.generate(prompt, max_new_tokens=20, streamer=lambda subword: streamed.append(subword)) + reference = ov_pipe.generate(prompt, max_new_tokens=20) + assert generated == "".join(streamed) + assert "".join(streamed) == reference + + +generation_configs = [ + dict(do_sample=False, max_new_tokens=20), + dict(do_sample=False, num_beam_groups=3, num_beams=15, num_return_sequences=1, max_new_tokens=10, diversity_penalty=1.0) +] +questions = [ + '1+1=', + 'What is the previous answer?', + 'Why is the Sun yellow?', + 'What was my first question?' +] +@pytest.mark.parametrize("generation_config", generation_configs[1:]) +@pytest.mark.parametrize("model_descr", get_chat_models_list()) +@pytest.mark.precommit +def test_chat_scenario_vs_stateful(model_descr, generation_config: Dict): + model_id, path, hf_tokenizer, opt_model, ov_pipe = read_model((model_descr[0], model_descr[1] / '_test_chat')) + cb_pipe = get_continuous_batching(path) + + ov_pipe.start_chat() + cb_pipe.start_chat() + + for question in questions: + generated = cb_pipe.generate(question, **generation_config) + reference = ov_pipe.generate(question, **generation_config) + assert generated == reference + + # Test that finish_chat() doesn't fail just in case. + cb_pipe.finish_chat() + +# +# Stress tests to check OOM case +# + +@pytest.mark.precommit +@pytest.mark.parametrize("sampling_config", [get_greedy(), get_beam_search(), get_multinomial_all_parameters()], + ids=["greedy", "beam_search", "multinomial_all_parameters"]) +def test_post_oom_health(tmp_path, sampling_config): + generation_config = sampling_config + generation_config.ignore_eos = True + generation_config.max_new_tokens = 1000000 + + scheduler_config = get_scheduler_config() + scheduler_config.num_kv_blocks = 10 # Low cache size to trigger OOM quickly + + model_id : str = "facebook/opt-125m" + opt_model, hf_tokenizer = get_hugging_face_model_and_tokenizer(model_id, use_optimum=True) + + models_path : Path = tmp_path / model_id + save_ov_model_from_optimum(opt_model, hf_tokenizer, models_path) + + cb_pipe = ContinuousBatchingPipeline(models_path, Tokenizer(models_path), scheduler_config, "CPU") + + # First run should return incomplete response + output = cb_pipe.generate(["What is OpenVINO?"], [generation_config]) + assert (len(output)) + assert (len(output[0].m_generation_ids)) + + # Same for the second run, here we want to make sure the cleanup works and we have free blocks after recent OOM + output = cb_pipe.generate(["What is OpenVINO?"], [generation_config]) + assert (len(output)) + assert (len(output[0].m_generation_ids)) + +# +# Pre-emption +# def get_greedy_seq_len_300() -> GenerationConfig: generation_config = GenerationConfig() @@ -36,7 +193,7 @@ def get_beam_search_seq_len_300() -> GenerationConfig: @pytest.mark.parametrize("params", scheduler_params_list) @pytest.mark.precommit def test_preemption(tmp_path, params): - run_test_pipeline(tmp_path, "facebook/opt-125m", params[0], params[1]) + run_continuous_batching_pipeline_test(tmp_path, "facebook/opt-125m", scheduler_params=params[0], generation_config=params[1]) multinomial_params = RandomSamplingTestStruct( @@ -175,4 +332,4 @@ def test_preemption_with_multinomial_n_seq(tmp_path, dynamic_split_fuse): # needed kv_blocks - 16 (2 blocks per sequence (30 tokens to generated text + prompt (> 2 tokens)) * (1 + 3 + 4) seq ) scheduler_config = get_scheduler_config({"num_kv_blocks": 8, "dynamic_split_fuse": dynamic_split_fuse, "max_num_batched_tokens": 256, "max_num_seqs": 256}) - generate_and_compare_with_reference_text(models_path, multinomial_params_n_seq.prompts, multinomial_params_n_seq.ref_texts, generation_configs, scheduler_config) \ No newline at end of file + generate_and_compare_with_reference_text(models_path, multinomial_params_n_seq.prompts, multinomial_params_n_seq.ref_texts, generation_configs, scheduler_config) diff --git a/tests/python_tests/test_cache_optimizations.py b/tests/python_tests/test_kv_cache_eviction.py similarity index 98% rename from tests/python_tests/test_cache_optimizations.py rename to tests/python_tests/test_kv_cache_eviction.py index d89697ba42..bbd0da6bb2 100644 --- a/tests/python_tests/test_cache_optimizations.py +++ b/tests/python_tests/test_kv_cache_eviction.py @@ -15,7 +15,7 @@ from openvino import serialize from transformers import AutoTokenizer -from common import TESTS_ROOT, run_test_pipeline +from common import TESTS_ROOT, run_continuous_batching_pipeline_test def load_prompts_dataset(file_name : str) -> Dict[str, List[str]]: @@ -168,5 +168,5 @@ def get_beam_search_seq_len_300() -> GenerationConfig: @pytest.mark.parametrize("params", scheduler_params_list) @pytest.mark.precommit def test_dynamic_memory_allocation(tmp_path, params): - run_test_pipeline(tmp_path, "facebook/opt-125m", params[0], params[1]) + run_continuous_batching_pipeline_test(tmp_path, "facebook/opt-125m", params[0], params[1]) diff --git a/tests/python_tests/test_generate_api.py b/tests/python_tests/test_llm_pipeline.py similarity index 87% rename from tests/python_tests/test_generate_api.py rename to tests/python_tests/test_llm_pipeline.py index 824a3cca26..9f00996a58 100644 --- a/tests/python_tests/test_generate_api.py +++ b/tests/python_tests/test_llm_pipeline.py @@ -12,11 +12,12 @@ import torch import math from ov_genai_test_utils import ( - get_models_list, - read_model, + get_models_list, + read_model, load_genai_pipe_with_configs, - model_tmp_path, - STOP_CRITERIA_MAP, + get_chat_models_list, + model_tmp_path, + STOP_CRITERIA_MAP, get_continuous_batching, ) @@ -26,12 +27,12 @@ def run_hf_ov_genai_comparison_batched(model_descr, generation_config: Dict, pro config = generation_config.copy() # to avoid side effects num_beams = config['num_beams'] if 'num_beams' in config else 1 config['num_return_sequences'] = num_beams - + if not isinstance(prompts, list): prompts = [prompts] if 'do_sample' not in config: - # Some HF models have default do_sample = True, and if we set beam search generation config + # Some HF models have default do_sample = True, and if we set beam search generation config # it conflicts with `diversity_penalty` and/or `num_beam_groups`. # Need to set explicitly to False, but only if test arguments omitted this arg. # Do not apply 'repetition_penalty' if sampling is not used. @@ -72,7 +73,7 @@ def run_hf_ov_genai_comparison_text_inputs(model_descr, generation_config: Dict, config = generation_config.copy() # to avoid side effects if 'do_sample' not in config: - # Some HF models have default do_sample = True, and if we set beam search generation config + # Some HF models have default do_sample = True, and if we set beam search generation config # it conflicts with `diversity_penalty` and/or `num_beam_groups`. # Need to set explicitly to False, but only if test arguments omitted this arg. # Do not apply 'repetition_penalty' if sampling is not used. @@ -101,9 +102,9 @@ def run_hf_ov_genai_comparison_text_inputs(model_descr, generation_config: Dict, def run_hf_ov_genai_comparison_encoded_inputs( - model_descr, - generation_config: Dict, - input_ids: np.ndarray, + model_descr, + generation_config: Dict, + input_ids: np.ndarray, attention_mask: Optional[np.array] = None ): device = 'CPU' @@ -112,18 +113,18 @@ def run_hf_ov_genai_comparison_encoded_inputs( config = generation_config.copy() # to avoid side effects if 'do_sample' not in config: - # Some HF models have default do_sample = True, and if we set beam search generation config + # Some HF models have default do_sample = True, and if we set beam search generation config # it conflicts with `diversity_penalty` and/or `num_beam_groups`. # Need to set explicitly to False, but only if test arguments omitted this arg. # Do not apply 'repetition_penalty' if sampling is not used. config['do_sample'] = False config['repetition_penalty'] = 1.0 # 1.0 means no penalty - + generation_config_hf = config.copy() if generation_config_hf.get('stop_criteria'): generation_config_hf['early_stopping'] = STOP_CRITERIA_MAP[generation_config_hf.pop('stop_criteria')] generation_config_hf.pop('ignore_eos', None) - + if attention_mask is not None: inputs_ov = ov_genai.TokenizedInputs(ov.Tensor(input_ids), ov.Tensor(attention_mask)) inputs_hf = dict(inputs=torch.tensor(input_ids), attention_mask=torch.tensor(attention_mask)) @@ -138,6 +139,9 @@ def run_hf_ov_genai_comparison_encoded_inputs( ov_res = np.array(ov_output.tokens, dtype=np.int64) assert np.all(ov_res == hf_res) +# +# e2e work +# test_cases = [ (dict(max_new_tokens=20), 'table is made of'), @@ -197,14 +201,13 @@ def test_batch_text_input(model_descr, generation_config, prompts): @pytest.mark.parametrize("model_descr", get_models_list()) @pytest.mark.precommit @pytest.mark.nightly -def test_beam_search_decoding(model_descr, num_beam_groups, group_size, - max_new_tokens, diversity_penalty, prompt): +def test_beam_search_decoding(model_descr, num_beam_groups, group_size, max_new_tokens, diversity_penalty, prompt): generation_config = dict( - num_beam_groups=num_beam_groups, - num_beams=num_beam_groups * group_size, - diversity_penalty=diversity_penalty, - num_return_sequences=num_beam_groups * group_size, - max_new_tokens=max_new_tokens, + num_beam_groups=num_beam_groups, + num_beams=num_beam_groups * group_size, + diversity_penalty=diversity_penalty, + num_return_sequences=num_beam_groups * group_size, + max_new_tokens=max_new_tokens, ) run_hf_ov_genai_comparison_text_inputs(read_model(model_descr), generation_config, prompt) @@ -215,17 +218,17 @@ def test_beam_search_decoding(model_descr, num_beam_groups, group_size, @pytest.mark.parametrize("model_descr", get_models_list()) @pytest.mark.precommit @pytest.mark.nightly -def test_stop_criteria(model_descr, stop_criteria, prompt, max_new_tokens): +def test_beam_search_stop_criteria(model_descr, stop_criteria, prompt, max_new_tokens): # todo: with EARLY stop_criteria looks like HF return invalid out with sentence # while genai ends sentence with if (stop_criteria == StopCriteria.EARLY): pytest.skip() generation_config = dict( - num_beam_groups=2, - num_beams=2 * 3, - diversity_penalty=1.0, - num_return_sequences=2 * 3, - max_new_tokens=max_new_tokens, + num_beam_groups=2, + num_beams=2 * 3, + diversity_penalty=1.0, + num_return_sequences=2 * 3, + max_new_tokens=max_new_tokens, stop_criteria=stop_criteria, ) run_hf_ov_genai_comparison_text_inputs(read_model(model_descr), generation_config, prompt) @@ -241,11 +244,11 @@ def test_stop_criteria(model_descr, stop_criteria, prompt, max_new_tokens): def test_beam_search_long_sentences(model_descr, num_beam_groups, group_size, max_new_tokens, prompt): generation_config = dict( - num_beam_groups=num_beam_groups, - num_beams=num_beam_groups * group_size, - diversity_penalty=1.0, - num_return_sequences=num_beam_groups * group_size, - max_new_tokens=max_new_tokens, + num_beam_groups=num_beam_groups, + num_beams=num_beam_groups * group_size, + diversity_penalty=1.0, + num_return_sequences=num_beam_groups * group_size, + max_new_tokens=max_new_tokens, ) run_hf_ov_genai_comparison_text_inputs(read_model(model_descr), generation_config, prompt) @@ -283,6 +286,72 @@ def test_greedy_repetition_penalty(model_descr, prompt): assert(len(set(ov_output.split(' '))) > len(set(ov_output_half_penalty.split(' ')))) +@pytest.mark.precommit +@pytest.mark.nightly +def test_batch_size_switch(): + ov_pipe = read_model(('katuni4ka/tiny-random-phi3', Path('tiny-random-phi3')))[4] + ov_pipe.generate(["a"], max_new_tokens=2) + ov_pipe.generate(["1", "2"], max_new_tokens=2) + ov_pipe.generate(["a"], max_new_tokens=2) + +# +# Chat scenario +# + +generation_configs = [ + dict(do_sample=False, max_new_tokens=20), + dict(do_sample=False, num_beam_groups=3, num_beams=15, num_return_sequences=1, max_new_tokens=10, diversity_penalty=1.0) +] + + +questions = [ + '1+1=', + 'What is the previous answer?', + 'Why is the Sun yellow?', + 'What was my first question?' +] + + +@pytest.mark.parametrize("generation_config", generation_configs) +@pytest.mark.parametrize("model_descr", get_chat_models_list()) +@pytest.mark.precommit +@pytest.mark.nightly +def test_chat_compare_with_HF(model_descr, generation_config: Dict): + chat_history_hf = [] + chat_history_ov = [] + chat_prompt = '' + + # Will set add_special_tokens=False inside pipeline when start_chat() is called. + model_id, path, tokenizer, opt_model, ov_pipe = read_model((model_descr[0], model_descr[1] / '_test_chat')) + + ov_pipe.start_chat() + for prompt in questions: + chat_history_hf.append({'role': 'user', 'content': prompt}) + chat_history_ov.append({'role': 'user', 'content': prompt}) + + chat_prompt = tokenizer.apply_chat_template(chat_history_hf, tokenize=False, add_generation_prompt=True) + tokenized = tokenizer(chat_prompt, return_tensors='pt', add_special_tokens=False) + + answer = opt_model.generate(**tokenized, **generation_config) + answer_str = tokenizer.decode(answer[0, tokenized['input_ids'].numel():], skip_special_tokens=True) + chat_history_hf.append({'role': 'assistant', 'content': answer_str}) + + answer_ov = ov_pipe.generate(prompt, **generation_config) + chat_history_ov.append({'role': 'assistant', 'content': answer_ov}) + + ov_pipe.finish_chat() + + if chat_history_ov != chat_history_hf: + print(f'hf_output: {chat_history_hf}') + print(f'ov_output: {chat_history_ov}') + + assert chat_history_ov == chat_history_hf + + +# +# Streaming with callback +# + def user_defined_callback(subword): print(subword) @@ -422,11 +491,14 @@ def test_operator_with_streamer_kwargs_batch_throws(): with pytest.raises(RuntimeError): ov_pipe('', num_beams=2, streamer=printer) +# +# Tests on generation configs (invalid cases and handling within LLMPipeline) +# invalid_configs = [ dict(num_beam_groups=3, num_beams=15, do_sample=True), # TODO: CVS-158682 eos_token_id is still read from tiny-random-phi3 and we cannot modify RTInfo in tests - # dict(do_sample=True), # no eos_token_id no max_new_tokens, no max_len + # dict(do_sample=True), # no eos_token_id no max_new_tokens, no max_len dict(eos_token_id=42, ignore_eos=True), # no max_new_tokens, no max_len with ignore_eos dict(repetition_penalty=-1.0, eos_token_id=42, max_new_tokens=20), # invalid penalty dict(temperature=-1.0, do_sample=True, eos_token_id=42, max_new_tokens=20), # invalid temp @@ -446,7 +518,7 @@ def test_invalid_generation_configs_throws(model_tmp_path, generation_config): @pytest.mark.precommit @pytest.mark.nightly -def test_valid_configs(model_tmp_path): +def test_eos_token_is_inherited_from_default_generation_config(model_tmp_path): model_id, temp_path = model_tmp_path ov_pipe = load_genai_pipe_with_configs([({"eos_token_id": 37}, "config.json")], temp_path) @@ -454,6 +526,8 @@ def test_valid_configs(model_tmp_path): config.do_sample = True # no eos_token_id but it's loaded from config.json ov_pipe.set_generation_config(config) + assert 37 == ov_pipe.get_generation_config().eos_token_id + invalid_py_configs = [ dict(num_beam_groups=3, num_beams=15, do_sample=True), @@ -478,6 +552,9 @@ def test_python_generation_config_validation_throws(model_tmp_path, generation_c with pytest.raises(return_exception_type): ov_pipe.set_generation_config(ov_genai.GenerationConfig(**generation_config)) +# +# Work with Unicode in Python API +# @pytest.mark.precommit @pytest.mark.nightly @@ -512,69 +589,9 @@ def test_unicode_pybind_decoding_one_string_streamer(): ov_pipe.generate(",", max_new_tokens=4, streamer=lambda x: res_str.append(x)) assert '�' == res_str[-1] - -@pytest.mark.skip(reason="probably both models ov + hf doesn't fit to memory") -@pytest.mark.precommit -@pytest.mark.nightly -@pytest.mark.skipif(sys.platform.startswith("win"), reason="not enough space for this model on Win") -def test_left_pad(): - # test left pad tokenizer post processing implementation - prompts = [ - "The Sun is yellow because", - "The Sun is yellow because [force left pad tokens]" - ] - models = read_model(("microsoft/phi-1_5", Path("phi-1_5/"))) - - config = { - "max_new_tokens": 20, - "num_beam_groups": 2, - "num_beams": 2, - "num_return_sequences": 2, - "do_sample": False, - "diversity_penalty": 1.0, - # phi 1_5 has no eos_token_id in model configuration - # ov genai will detect eos_token_id from tokenizer config - # hf implementation doesn't fetch it from tokenizer config and defaults to None - # align ov genai and hf by setting eos_token_id explicitly - "eos_token_id": 50256, - } - - models[2].pad_token = models[2].eos_token - run_hf_ov_genai_comparison_batched(models, config, prompts) - - -@pytest.mark.parametrize("generation_config", test_configs) -@pytest.mark.parametrize("prompt", batched_prompts[1:]) # num_beams=15 diverges on the first prompt. -@pytest.mark.precommit -def test_continuous_batching_vs_stateful(prompt, generation_config): - model_id, path, tokenizer, model, stateful = read_model(( - "facebook/opt-125m", - Path("opt-125m") - )) - cb = get_continuous_batching(path) - generated = cb.generate(prompt, **generation_config) - reference = stateful.generate(prompt, **generation_config) - assert generated.texts == reference.texts - if 1 != generation_config.get("num_return_sequences", 1): - # Stateful puts zeroes to generated.scores. Don't compare them. - for gen, ref in zip(generated.scores, reference.scores): - assert math.isclose(gen, ref, abs_tol=0.0003) - - -@pytest.mark.parametrize("prompt", prompts) -@pytest.mark.precommit -def test_cb_streamer_vs_return_vs_stateful(prompt): - model_id, path, hf_tokenizer, opt_model, ov_pipe = read_model(( - "facebook/opt-125m", - Path("opt-125m") - )) - cb_pipe = get_continuous_batching(path) - streamed = [] - generated = cb_pipe.generate(prompt, max_new_tokens=20, streamer=lambda subword: streamed.append(subword)) - reference = ov_pipe.generate(prompt, max_new_tokens=20) - assert generated == "".join(streamed) - assert "".join(streamed) == reference - +# +# Perf metrics +# def run_perf_metrics_collection(model_descr, generation_config: Dict, prompt: str) -> ov_genai.PerfMetrics: model_id, path, hf_tokenizer, opt_model, ov_pipe = model_descr @@ -582,12 +599,13 @@ def run_perf_metrics_collection(model_descr, generation_config: Dict, prompt: st config = generation_config.copy() # to avoid side effects if 'do_sample' not in config: - # Some HF models have default do_sample = True, and if we set beam search generation config + # Some HF models have default do_sample = True, and if we set beam search generation config # it conflicts with `diversity_penalty` and/or `num_beam_groups`. # Need to set explicitly to False, but only if test arguments omitted this arg. # Do not apply 'repetition_penalty' if sampling is not used. config['do_sample'] = False config['repetition_penalty'] = 1.0 # 1.0 means no penalty + return ov_pipe.generate([prompt], **config).perf_metrics @@ -598,20 +616,21 @@ def run_perf_metrics_collection(model_descr, generation_config: Dict, prompt: st @pytest.mark.parametrize("model_descr", get_models_list()) @pytest.mark.precommit @pytest.mark.nightly +@pytest.mark.skip(reason="load_time + mean_gen_duration < total_time fails in https://github.com/openvinotoolkit/openvino.genai/actions/runs/12503590506/job/34884840100?pr=1440.") def test_perf_metrics(model_descr, generation_config, prompt): import time start_time = time.perf_counter() perf_metrics = run_perf_metrics_collection(read_model(model_descr), generation_config, prompt) total_time = (time.perf_counter() - start_time) * 1000 - + # Check that load time is adequate. load_time = perf_metrics.get_load_time() - assert load_time > 0 and load_time < 1000.0 - + assert load_time > 0 and load_time < 1000.0 + # Check that num input and generated tokens are adequate. num_generated_tokens = perf_metrics.get_num_generated_tokens() - assert num_generated_tokens > 0 and num_generated_tokens <= generation_config['max_new_tokens'] - + assert num_generated_tokens > 0 and num_generated_tokens <= generation_config['max_new_tokens'] + num_input_tokens = perf_metrics.get_num_input_tokens() assert num_input_tokens > 0 and num_input_tokens <= len(prompt) @@ -622,7 +641,7 @@ def test_perf_metrics(model_descr, generation_config, prompt): raw_metrics = perf_metrics.raw_metrics durations = np.array(raw_metrics.m_durations) / 1000 # Check that prefill is not included in durations for TPOT calculation. - # For the very long prompt prefill is slow and TTFT is much larger than any other token genration duration. + # For the very long prompt prefill is slow and TTFT is much larger than any other token generation duration. assert np.all(mean_ttft > durations * 2) mean_tpot, std_tpot = perf_metrics.get_tpot() @@ -632,7 +651,7 @@ def test_perf_metrics(model_descr, generation_config, prompt): mean_throughput, std_throughput = perf_metrics.get_throughput() assert (mean_throughput, std_throughput) == (perf_metrics.get_throughput().mean, perf_metrics.get_throughput().std) assert mean_throughput > 0 and mean_throughput < 20000.0 - + mean_gen_duration, std_gen_duration = perf_metrics.get_generate_duration() assert (mean_gen_duration, std_gen_duration) == (perf_metrics.get_generate_duration().mean, perf_metrics.get_generate_duration().std) assert mean_gen_duration > 0 and load_time + mean_gen_duration < total_time @@ -647,7 +666,7 @@ def test_perf_metrics(model_descr, generation_config, prompt): assert (mean_detok_duration, std_detok_duration) == (perf_metrics.get_detokenization_duration().mean, perf_metrics.get_detokenization_duration().std) assert mean_detok_duration > 0 and mean_detok_duration < mean_gen_duration assert std_detok_duration == 0 - + # assert that calculating statistics manually from the raw counters we get the same restults as from PerfMetrics assert np.allclose(mean_tpot, np.mean(durations)) assert np.allclose(std_tpot, np.std(durations)) @@ -668,15 +687,11 @@ def test_perf_metrics(model_descr, generation_config, prompt): assert len(raw_metrics.m_batch_sizes) > 0 assert len(raw_metrics.m_durations) > 0 +# +# Misc +# -@pytest.mark.precommit -@pytest.mark.nightly -def test_batch_switch(): - ov_pipe = read_model(('katuni4ka/tiny-random-phi3', Path('tiny-random-phi3')))[4] - ov_pipe.generate(["a"], max_new_tokens=2) - ov_pipe.generate(["1", "2"], max_new_tokens=2) - - +# TODO: move to test_sampling.py @pytest.mark.precommit @pytest.mark.nightly def test_stop_token_ids(): @@ -691,6 +706,7 @@ def test_stop_token_ids(): assert 9935 in res.tokens[0] +# TODO: move to test_sampling.py @pytest.mark.precommit @pytest.mark.nightly def test_stop_strings(): @@ -701,3 +717,34 @@ def test_stop_strings(): stop_strings={"ignored", "боль"} ) assert "боль" not in res + + +# TODO: move this test to test_tokenizer.py +@pytest.mark.skip(reason="probably both models ov + hf doesn't fit to memory") +@pytest.mark.precommit +@pytest.mark.nightly +@pytest.mark.skipif(sys.platform.startswith("win"), reason="not enough space for this model on Win") +def test_left_pad(): + # test left pad tokenizer post processing implementation + prompts = [ + "The Sun is yellow because", + "The Sun is yellow because [force left pad tokens]" + ] + models = read_model(("microsoft/phi-1_5", Path("phi-1_5/"))) + + config = { + "max_new_tokens": 20, + "num_beam_groups": 2, + "num_beams": 2, + "num_return_sequences": 2, + "do_sample": False, + "diversity_penalty": 1.0, + # phi 1_5 has no eos_token_id in model configuration + # ov genai will detect eos_token_id from tokenizer config + # hf implementation doesn't fetch it from tokenizer config and defaults to None + # align ov genai and hf by setting eos_token_id explicitly + "eos_token_id": 50256, + } + + models[2].pad_token = models[2].eos_token + run_hf_ov_genai_comparison_batched(models, config, prompts) diff --git a/tests/python_tests/test_llm_pipeline_static.py b/tests/python_tests/test_llm_pipeline_static.py index cad8b0fea0..c3500d15ac 100644 --- a/tests/python_tests/test_llm_pipeline_static.py +++ b/tests/python_tests/test_llm_pipeline_static.py @@ -145,7 +145,7 @@ def test_chat_generation(model_descr): 'What was my first question?' ] - model_path = get_chat_models_lists()[0][1] + model_path = get_chat_models_list()[0][1] chat_history_stateful = generate_chat_history(model_path, "CPU", { }, questions) chat_history_static = generate_chat_history(model_path, "NPU", common_config, questions) diff --git a/tests/python_tests/test_sampling.py b/tests/python_tests/test_sampling.py index fbcce76bf7..25ae9d8afa 100644 --- a/tests/python_tests/test_sampling.py +++ b/tests/python_tests/test_sampling.py @@ -10,13 +10,13 @@ from openvino_genai import ContinuousBatchingPipeline, GenerationConfig, Tokenizer from typing import List, TypedDict -from common import run_test_pipeline, read_models_list, get_hugging_face_model_and_tokenizer, save_ov_model_from_optimum, \ - generate_and_compare_with_reference_text, get_greedy, get_beam_search, get_multinomial_temperature, \ +from common import get_hugging_face_model_and_tokenizer, save_ov_model_from_optimum, \ + get_greedy, get_beam_search, get_multinomial_temperature, \ get_greedy_with_penalties, get_multinomial_temperature, \ get_multinomial_temperature_and_top_k, get_multinomial_temperature_and_top_p, \ get_multinomial_temperature_top_p_and_top_k, DEFAULT_SCHEDULER_CONFIG, get_greedy_with_repetition_penalty, \ get_multinomial_all_parameters, get_multinomial_temperature_and_num_return_sequence, \ - generate_and_compare_with_reference_text, get_greedy, get_greedy_with_min_and_max_tokens, \ + get_greedy, get_greedy_with_min_and_max_tokens, \ get_greedy_with_single_stop_string, get_greedy_with_multiple_stop_strings, get_greedy_with_multiple_stop_strings_no_match, \ get_beam_search, get_beam_search_min_and_max_tokens, get_beam_search_with_single_stop_string, \ get_beam_search_with_multiple_stop_strings, get_beam_search_with_multiple_stop_strings_no_match, get_multinomial_max_and_min_token, \ @@ -27,25 +27,9 @@ run_continuous_batching +# TODO: currently, this test drops EOS token as both HF and OV use `skip_special_tokens=True`, which should be disabled for samlpling tests @pytest.mark.precommit -@pytest.mark.parametrize("model_id", read_models_list(os.path.join(os.path.dirname(os.path.realpath(__file__)), "models", "precommit"))) -def test_sampling_precommit(tmp_path, model_id): - run_test_pipeline(tmp_path, model_id) - - -@pytest.mark.nightly -@pytest.mark.parametrize("model_id", read_models_list(os.path.join(os.path.dirname(os.path.realpath(__file__)), "models", "nightly"))) -def test_sampling_nightly(tmp_path, model_id): - run_test_pipeline(tmp_path, model_id) - -@pytest.mark.real_models -@pytest.mark.parametrize("model_id", read_models_list(os.path.join(os.path.dirname(os.path.realpath(__file__)), "models", "real_models"))) -def test_real_models(tmp_path, model_id): - run_test_pipeline(tmp_path, model_id) - - -@pytest.mark.precommit -def test_eos_beam_search(tmp_path): +def test_beam_search_has_eos_token_at_end(tmp_path): ''' Current test checks that in case of beam search, some generation results explicitly have EOS token at the end, which is aligned with HF @@ -61,8 +45,9 @@ def test_eos_beam_search(tmp_path): generate_and_compare_with_hf(model_id, prompts, generation_configs, scheduler_config, tmp_path) +# TODO: currently, this test drops EOS token as both HF and OV use `skip_special_tokens=True`, which should be disabled for samlpling tests @pytest.mark.precommit -def test_eos_greedy(tmp_path): +def test_greedy_has_eos_token_at_end(tmp_path): ''' Current test checks that in case of gready, some generation results explicitly have EOS token at the end, which is aligned with HF: @@ -76,55 +61,44 @@ def test_eos_greedy(tmp_path): scheduler_config = get_scheduler_config() generate_and_compare_with_hf(model_id, prompts, generation_configs, scheduler_config, tmp_path) + +# TODO: consider removing all these functions with generation configs and use Dict with properties, which can be converted to generation config @pytest.mark.precommit -@pytest.mark.parametrize("generation_config", [get_greedy(), get_greedy_with_min_and_max_tokens(), get_greedy_with_repetition_penalty(), get_greedy_with_single_stop_string(), - get_greedy_with_multiple_stop_strings(), get_greedy_with_multiple_stop_strings_no_match(), - get_beam_search(), get_beam_search_min_and_max_tokens(), get_beam_search_with_multiple_stop_strings_no_match(), - get_greedy_stop_strings_exclude_from_output(), get_greedy_stop_strings_include_to_output(), - get_greedy_n_stop_strings_exclude_from_output(), get_greedy_n_stop_strings_include_to_output() ], - ids=[ - "greedy", - "greedy_with_min_and_max_tokens", - "greedy_with_repetition_penalty", - "greedy_with_single_stop_string", - "greedy_with_multiple_stop_strings", - "greedy_with_multiple_stop_strings_no_match", - "beam", - "beam_search_min_and_max_tokens", - "beam_search_with_multiple_stop_strings_no_match", - "get_greedy_stop_strings_exclude_from_output", - "get_greedy_stop_strings_include_to_output", - "get_greedy_n_stop_strings_exclude_from_output", - "get_greedy_n_stop_strings_include_to_output" - ]) -def test_individual_generation_configs_deterministic(tmp_path, generation_config): - prompts = [ - "What is OpenVINO?", - ] +@pytest.mark.parametrize("generation_config", + [get_greedy(), get_greedy_with_min_and_max_tokens(), get_greedy_with_repetition_penalty(), get_greedy_with_single_stop_string(), + get_greedy_with_multiple_stop_strings(), get_greedy_with_multiple_stop_strings_no_match(), + get_beam_search(), get_beam_search_min_and_max_tokens(), get_beam_search_with_multiple_stop_strings_no_match(), + get_greedy_stop_strings_exclude_from_output(), get_greedy_stop_strings_include_to_output(), + get_greedy_n_stop_strings_exclude_from_output(), get_greedy_n_stop_strings_include_to_output()], + ids=["greedy", "greedy_with_min_and_max_tokens", "greedy_with_repetition_penalty", "greedy_with_single_stop_string", + "greedy_with_multiple_stop_strings", "greedy_with_multiple_stop_strings_no_match", "beam_search", "beam_search_min_and_max_tokens", + "beam_search_with_multiple_stop_strings_no_match", "greedy_stop_strings_exclude_from_output", "greedy_stop_strings_include_to_output", + "greedy_n_stop_strings_exclude_from_output", "greedy_n_stop_strings_include_to_output"]) +def test_sampling_against_optimum(tmp_path, generation_config): + prompts = [ "What is OpenVINO?" ] generation_configs = [generation_config] model_id : str = "facebook/opt-125m" generate_and_compare_with_hf(model_id, prompts, generation_configs, DEFAULT_SCHEDULER_CONFIG, tmp_path) + @pytest.mark.precommit @pytest.mark.xfail( raises=AssertionError, reason="Stop strings do not seem to work as expected with beam search in HF, so comparison will fail. If it changes, these cases shall be merged to the test above.", strict=True, ) -@pytest.mark.parametrize("generation_config", [get_beam_search_with_single_stop_string(), get_beam_search_with_multiple_stop_strings(),], - ids=[ - "beam_search_with_single_stop_string", - "beam_search_with_multiple_stop_strings", - ]) +@pytest.mark.parametrize("generation_config", [get_beam_search_with_single_stop_string(), get_beam_search_with_multiple_stop_strings()], + ids=["beam_search_with_single_stop_string", "beam_search_with_multiple_stop_strings"]) def test_beam_search_with_stop_string(tmp_path, generation_config): - prompts = [ - "What is OpenVINO?", - ] + prompts = [ "What is OpenVINO?" ] generation_configs = [generation_config] model_id : str = "facebook/opt-125m" generate_and_compare_with_hf(model_id, prompts, generation_configs, DEFAULT_SCHEDULER_CONFIG, tmp_path) +# TODO: remove platform specific reference texts once CVS-159912 is done and use comparison with HF +# and merge this tests with 'test_sampling_against_optimum' by extending a list of generation configs + class PlatformsRefTexts(TypedDict, total=False): linux: List[List[str]] win32: List[List[str]] @@ -306,7 +280,7 @@ class RandomSamplingTestStruct: "multinomial_temperature_and_frequence_penalty", "greedy_with_penalties", "multinomial_max_and_min_token"]) -def test_individual_generation_configs_random(tmp_path, test_struct: RandomSamplingTestStruct): +def test_multinomial_sampling_against_reference(tmp_path, test_struct: RandomSamplingTestStruct): generation_config = test_struct.generation_config prompts = test_struct.prompts @@ -326,9 +300,10 @@ def test_individual_generation_configs_random(tmp_path, test_struct: RandomSampl @pytest.mark.precommit -@pytest.mark.parametrize("get_generation_config", [get_greedy, get_beam_search, get_multinomial_all_parameters]) +@pytest.mark.parametrize("get_generation_config", [get_greedy, get_beam_search, get_multinomial_all_parameters], + ids=["greedy", "beam_search", "multinomial_all_parameters"]) @pytest.mark.parametrize("max_num_batched_tokens", [2, 4, 256]) -def test_echo_without_completion(tmp_path, get_generation_config, max_num_batched_tokens): +def test_echo_prompt_phase_only(tmp_path, get_generation_config, max_num_batched_tokens): generation_config = get_generation_config() generation_config.max_new_tokens = 0 generation_config.echo = True @@ -337,14 +312,14 @@ def test_echo_without_completion(tmp_path, get_generation_config, max_num_batche scheduler_config.max_num_batched_tokens = max_num_batched_tokens generation_configs = [generation_config] model_id : str = "facebook/opt-125m" - model, hf_tokenizer = get_hugging_face_model_and_tokenizer(model_id, use_optimum=True) + opt_model, hf_tokenizer = get_hugging_face_model_and_tokenizer(model_id, use_optimum=True) model_path : Path = tmp_path / model_id - save_ov_model_from_optimum(model, hf_tokenizer, model_path) + save_ov_model_from_optimum(opt_model, hf_tokenizer, model_path) - pipe = ContinuousBatchingPipeline(model_path, Tokenizer(model_path), scheduler_config, "CPU") + cb_pipe = ContinuousBatchingPipeline(model_path, Tokenizer(model_path), scheduler_config, "CPU") - outputs = pipe.generate(["What is OpenVINO?"], generation_configs) + outputs = cb_pipe.generate(["What is OpenVINO?"], generation_configs) assert(len(outputs)) for output in outputs: assert(len(output.m_generation_ids)) @@ -353,9 +328,10 @@ def test_echo_without_completion(tmp_path, get_generation_config, max_num_batche @pytest.mark.precommit -@pytest.mark.parametrize("get_generation_config", [get_greedy, get_beam_search, get_multinomial_all_parameters]) +@pytest.mark.parametrize("get_generation_config", [get_greedy, get_beam_search, get_multinomial_all_parameters], + ids=["greedy", "beam_search", "multinomial_all_parameters"]) @pytest.mark.parametrize("max_num_batched_tokens", [2, 4, 256]) -def test_echo_with_completion(tmp_path, get_generation_config, max_num_batched_tokens): +def test_echo_with_generation_phase(tmp_path, get_generation_config, max_num_batched_tokens): generation_config = get_generation_config() generation_config.max_new_tokens = 10 generation_config.echo = True @@ -364,45 +340,17 @@ def test_echo_with_completion(tmp_path, get_generation_config, max_num_batched_t scheduler_config.max_num_batched_tokens = max_num_batched_tokens generation_configs = [generation_config] model_id : str = "facebook/opt-125m" - model, hf_tokenizer = get_hugging_face_model_and_tokenizer(model_id, use_optimum=True) + opt_model, hf_tokenizer = get_hugging_face_model_and_tokenizer(model_id, use_optimum=True) model_path : Path = tmp_path / model_id - save_ov_model_from_optimum(model, hf_tokenizer, model_path) - - pipe = ContinuousBatchingPipeline(model_path, Tokenizer(model_path), scheduler_config, "CPU") + save_ov_model_from_optimum(opt_model, hf_tokenizer, model_path) - outputs = pipe.generate(["What is OpenVINO?"], generation_configs) + cb_pipe = ContinuousBatchingPipeline(model_path, Tokenizer(model_path), scheduler_config, "CPU") + outputs = cb_pipe.generate(["What is OpenVINO?"], generation_configs) assert(len(outputs)) + for output in outputs: assert(len(output.m_generation_ids)) for sequence in output.m_generation_ids: assert(sequence.startswith("What is OpenVINO?")) assert(len(sequence) > len("What is OpenVINO?")) - - -@pytest.mark.precommit -@pytest.mark.parametrize("sampling_config", [get_greedy(), get_beam_search(), get_multinomial_all_parameters()]) -def test_post_oom_health(tmp_path, sampling_config): - generation_config = sampling_config - generation_config.ignore_eos = True - generation_config.max_new_tokens = 1000000 - - scheduler_config = get_scheduler_config() - # Low cache size to trigger OOM quickly - scheduler_config.num_kv_blocks = 10 - generation_configs = [generation_config] - model_id : str = "facebook/opt-125m" - model, hf_tokenizer = get_hugging_face_model_and_tokenizer(model_id, use_optimum=True) - - models_path : Path = tmp_path / model_id - save_ov_model_from_optimum(model, hf_tokenizer, models_path) - - pipe = ContinuousBatchingPipeline(models_path, Tokenizer(models_path), scheduler_config, "CPU") - # First run should return incomplete response - output = pipe.generate(["What is OpenVINO?"], generation_configs) - assert (len(output)) - assert(len(output[0].m_generation_ids)) - # Same for the second run, here we want to make sure the cleanup works and we have free blocks after recent OOM - output = pipe.generate(["What is OpenVINO?"], generation_configs) - assert (len(output)) - assert(len(output[0].m_generation_ids)) diff --git a/tests/python_tests/test_vlm_api.py b/tests/python_tests/test_vlm_pipeline.py similarity index 100% rename from tests/python_tests/test_vlm_api.py rename to tests/python_tests/test_vlm_pipeline.py diff --git a/tests/python_tests/test_whisper_generate_api.py b/tests/python_tests/test_whisper_pipeline.py similarity index 100% rename from tests/python_tests/test_whisper_generate_api.py rename to tests/python_tests/test_whisper_pipeline.py