From 09fcad1a6ebdb737e3daece08df19f54c1dcd531 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Fri, 26 Jan 2024 15:43:33 +0100 Subject: [PATCH 1/2] Remove check for fixed endpoint for OpenAI models. (#429) --- spacy_llm/models/rest/azure/model.py | 2 +- spacy_llm/models/rest/openai/model.py | 36 +++++++++++---------------- 2 files changed, 16 insertions(+), 22 deletions(-) diff --git a/spacy_llm/models/rest/azure/model.py b/spacy_llm/models/rest/azure/model.py index 5a2d0fef..32adc0bb 100644 --- a/spacy_llm/models/rest/azure/model.py +++ b/spacy_llm/models/rest/azure/model.py @@ -35,7 +35,7 @@ def __init__( self._deployment_name = deployment_name super().__init__( name=name, - endpoint=endpoint or endpoint, + endpoint=endpoint, config=config, strict=strict, max_tries=max_tries, diff --git a/spacy_llm/models/rest/openai/model.py b/spacy_llm/models/rest/openai/model.py index b8bbdae3..7715f12c 100644 --- a/spacy_llm/models/rest/openai/model.py +++ b/spacy_llm/models/rest/openai/model.py @@ -36,12 +36,6 @@ def credentials(self) -> Dict[str, str]: if api_org: headers["OpenAI-Organization"] = api_org - # Ensure endpoint is supported. - if self._endpoint not in (Endpoints.NON_CHAT, Endpoints.CHAT): - raise ValueError( - f"Endpoint {self._endpoint} isn't supported. Please use one of: {Endpoints.CHAT}, {Endpoints.NON_CHAT}." - ) - return headers def _verify_auth(self) -> None: @@ -115,9 +109,21 @@ def _request(json_data: Dict[str, Any]) -> Dict[str, Any]: return responses - if self._endpoint == Endpoints.CHAT: - # The OpenAI API doesn't support batching for /chat/completions yet, so we have to send individual - # requests. + # The OpenAI API doesn't support batching for /chat/completions yet, so we have to send individual requests. + + if self._endpoint == Endpoints.NON_CHAT: + responses = _request({"prompt": prompts_for_doc}) + if "error" in responses: + return responses["error"] + assert len(responses["choices"]) == len(prompts_for_doc) + + for response in responses["choices"]: + if "text" in response: + api_responses.append(response["text"]) + else: + api_responses.append(srsly.json_dumps(response)) + + else: for prompt in prompts_for_doc: responses = _request( {"messages": [{"role": "user", "content": prompt}]} @@ -134,18 +140,6 @@ def _request(json_data: Dict[str, Any]) -> Dict[str, Any]: ) ) - elif self._endpoint == Endpoints.NON_CHAT: - responses = _request({"prompt": prompts_for_doc}) - if "error" in responses: - return responses["error"] - assert len(responses["choices"]) == len(prompts_for_doc) - - for response in responses["choices"]: - if "text" in response: - api_responses.append(response["text"]) - else: - api_responses.append(srsly.json_dumps(response)) - all_api_responses.append(api_responses) return all_api_responses From 0fc46336ab63d02b77e347b766ec021d07347bfb Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Mon, 29 Jan 2024 10:02:07 +0100 Subject: [PATCH 2/2] Fix legacy NER test warning in CI (#430) * Fix legacy NER test warning in CI. * Add warning filter for NER inconsistency test. --- spacy_llm/tests/tasks/legacy/test_ner.py | 1 + spacy_llm/tests/tasks/test_ner.py | 1 + 2 files changed, 2 insertions(+) diff --git a/spacy_llm/tests/tasks/legacy/test_ner.py b/spacy_llm/tests/tasks/legacy/test_ner.py index 3d9c133a..551e3dba 100644 --- a/spacy_llm/tests/tasks/legacy/test_ner.py +++ b/spacy_llm/tests/tasks/legacy/test_ner.py @@ -832,6 +832,7 @@ def test_ner_to_disk(noop_config, tmp_path: Path): assert task1._label_dict == task2._label_dict == labels +@pytest.mark.filterwarnings("ignore:Task supports sharding") def test_label_inconsistency(): """Test whether inconsistency between specified labels and labels in examples is detected.""" cfg = f""" diff --git a/spacy_llm/tests/tasks/test_ner.py b/spacy_llm/tests/tasks/test_ner.py index e8782d08..5fe4b178 100644 --- a/spacy_llm/tests/tasks/test_ner.py +++ b/spacy_llm/tests/tasks/test_ner.py @@ -820,6 +820,7 @@ def test_ner_to_disk(noop_config: str, tmp_path: Path): assert task1._label_dict == task2._label_dict == labels +@pytest.mark.filterwarnings("ignore:Task supports sharding") def test_label_inconsistency(): """Test whether inconsistency between specified labels and labels in examples is detected.""" cfg = f"""