Skip to content

Commit

Permalink
Fixed greedy_with_penalties
Browse files Browse the repository at this point in the history
  • Loading branch information
ilya-lavrenov committed Dec 26, 2024
1 parent 4191b31 commit 6b89a12
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions tests/python_tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,12 @@ def test_greedy_has_eos_token_at_end(tmp_path):
# TODO: consider removing all these functions with generation configs and use Dict with properties, which can be converted to generation config
@pytest.mark.precommit
@pytest.mark.parametrize("generation_config",
[get_greedy(), get_greedy_with_min_and_max_tokens(), get_greedy_with_repetition_penalty(), get_greedy_with_penalties(), get_greedy_with_single_stop_string(),
[get_greedy(), get_greedy_with_min_and_max_tokens(), get_greedy_with_repetition_penalty(), get_greedy_with_single_stop_string(),
get_greedy_with_multiple_stop_strings(), get_greedy_with_multiple_stop_strings_no_match(),
get_beam_search(), get_beam_search_min_and_max_tokens(), get_beam_search_with_multiple_stop_strings_no_match(),
get_greedy_stop_strings_exclude_from_output(), get_greedy_stop_strings_include_to_output(),
get_greedy_n_stop_strings_exclude_from_output(), get_greedy_n_stop_strings_include_to_output()],
ids=["greedy", "greedy_with_min_and_max_tokens", "greedy_with_repetition_penalty", "greedy_with_penalties", "greedy_with_single_stop_string",
ids=["greedy", "greedy_with_min_and_max_tokens", "greedy_with_repetition_penalty", "greedy_with_single_stop_string",
"greedy_with_multiple_stop_strings", "greedy_with_multiple_stop_strings_no_match", "beam_search", "beam_search_min_and_max_tokens",
"beam_search_with_multiple_stop_strings_no_match", "greedy_stop_strings_exclude_from_output", "greedy_stop_strings_include_to_output",
"greedy_n_stop_strings_exclude_from_output", "greedy_n_stop_strings_include_to_output"])
Expand Down Expand Up @@ -235,6 +235,15 @@ class RandomSamplingTestStruct:
]
],
),
RandomSamplingTestStruct(
generation_config=get_greedy_with_penalties(),
prompts=["What is OpenVINO?"],
ref_texts=[
[
"\nOpenVINO is a software that allows users to create and manage their own virtual machines. It's designed for use with Windows, Mac OS X"
]
],
),
RandomSamplingTestStruct(
generation_config=get_multinomial_max_and_min_token(),
prompts=["What is OpenVINO?"],
Expand Down Expand Up @@ -269,6 +278,7 @@ class RandomSamplingTestStruct:
"multinomial_all_parameters",
"multinomial_temperature_and_presence_penalty",
"multinomial_temperature_and_frequence_penalty",
"greedy_with_penalties",
"multinomial_max_and_min_token"])
def test_multinomial_sampling_against_reference(tmp_path, test_struct: RandomSamplingTestStruct):
generation_config = test_struct.generation_config
Expand Down

0 comments on commit 6b89a12

Please sign in to comment.