Commit

Merge pull request #431 from explosion/main

Sync `develop` with `main`

rmitsch authored Jan 29, 2024
2 parents ad30add + 0fc4633 commit 89c7971
Showing 4 changed files with 18 additions and 22 deletions.
spacy_llm/models/rest/azure/model.py (2 changes: 1 addition & 1 deletion)

@@ -35,7 +35,7 @@ def __init__(
         self._deployment_name = deployment_name
         super().__init__(
             name=name,
-            endpoint=endpoint or endpoint,
+            endpoint=endpoint,
             config=config,
             strict=strict,
             max_tries=max_tries,
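The fix here removes a tautology: `endpoint or endpoint` always evaluates to `endpoint`, so the `or` clause was dead code. A minimal sketch of the fallback pattern this idiom is normally meant to express (the default URL and helper name below are hypothetical, not from this diff):

DEFAULT_ENDPOINT = "https://example.openai.azure.com/v1"  # hypothetical default


def resolve_endpoint(endpoint=None):
    # `endpoint or DEFAULT_ENDPOINT` falls back when `endpoint` is None or "".
    # `endpoint or endpoint` can never change the value, which is why the
    # redundant clause was dropped.
    return endpoint or DEFAULT_ENDPOINT


assert resolve_endpoint() == DEFAULT_ENDPOINT
assert resolve_endpoint("https://custom.host/v1") == "https://custom.host/v1"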
spacy_llm/models/rest/openai/model.py (36 changes: 15 additions & 21 deletions)

@@ -36,12 +36,6 @@ def credentials(self) -> Dict[str, str]:
         if api_org:
             headers["OpenAI-Organization"] = api_org

-        # Ensure endpoint is supported.
-        if self._endpoint not in (Endpoints.NON_CHAT, Endpoints.CHAT):
-            raise ValueError(
-                f"Endpoint {self._endpoint} isn't supported. Please use one of: {Endpoints.CHAT}, {Endpoints.NON_CHAT}."
-            )
-
         return headers

     def _verify_auth(self) -> None:
@@ -115,9 +109,21 @@ def _request(json_data: Dict[str, Any]) -> Dict[str, Any]:

                 return responses

-            if self._endpoint == Endpoints.CHAT:
-                # The OpenAI API doesn't support batching for /chat/completions yet, so we have to send individual
-                # requests.
+            # The OpenAI API doesn't support batching for /chat/completions yet, so we have to send individual requests.
+
+            if self._endpoint == Endpoints.NON_CHAT:
+                responses = _request({"prompt": prompts_for_doc})
+                if "error" in responses:
+                    return responses["error"]
+                assert len(responses["choices"]) == len(prompts_for_doc)
+
+                for response in responses["choices"]:
+                    if "text" in response:
+                        api_responses.append(response["text"])
+                    else:
+                        api_responses.append(srsly.json_dumps(response))
+
+            else:
                 for prompt in prompts_for_doc:
                     responses = _request(
                         {"messages": [{"role": "user", "content": prompt}]}
@@ -134,18 +140,6 @@ def _request(json_data: Dict[str, Any]) -> Dict[str, Any]:
                         )
                     )

-            elif self._endpoint == Endpoints.NON_CHAT:
-                responses = _request({"prompt": prompts_for_doc})
-                if "error" in responses:
-                    return responses["error"]
-                assert len(responses["choices"]) == len(prompts_for_doc)
-
-                for response in responses["choices"]:
-                    if "text" in response:
-                        api_responses.append(response["text"])
-                    else:
-                        api_responses.append(srsly.json_dumps(response))
-
             all_api_responses.append(api_responses)

         return all_api_responses
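As the comment in the hunk above says, the legacy /completions endpoint accepts a list of prompts in a single request, while /chat/completions takes one message list per call, which forces a per-prompt loop. A minimal sketch of the two request shapes using the requests library (model names, prompts, and the placeholder key are illustrative; error handling omitted):

import requests  # assumes the `requests` package is installed

API_KEY = "sk-..."  # placeholder; supply a real key via an env var in practice
HEADERS = {"Authorization": f"Bearer {API_KEY}", "Content-Type": "application/json"}

# /completions accepts a batch: one request can carry several prompts.
batch_response = requests.post(
    "https://api.openai.com/v1/completions",
    headers=HEADERS,
    json={"model": "gpt-3.5-turbo-instruct", "prompt": ["Prompt A", "Prompt B"]},
)

# /chat/completions takes a single message list, so each prompt needs its own request.
chat_responses = [
    requests.post(
        "https://api.openai.com/v1/chat/completions",
        headers=HEADERS,
        json={"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": p}]},
    )
    for p in ["Prompt A", "Prompt B"]
]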
spacy_llm/tests/tasks/legacy/test_ner.py (1 change: 1 addition & 0 deletions)

@@ -832,6 +832,7 @@ def test_ner_to_disk(noop_config, tmp_path: Path):
     assert task1._label_dict == task2._label_dict == labels


+@pytest.mark.filterwarnings("ignore:Task supports sharding")
 def test_label_inconsistency():
     """Test whether inconsistency between specified labels and labels in examples is detected."""
     cfg = f"""
spacy_llm/tests/tasks/test_ner.py (1 change: 1 addition & 0 deletions)

@@ -820,6 +820,7 @@ def test_ner_to_disk(noop_config: str, tmp_path: Path):
     assert task1._label_dict == task2._label_dict == labels


+@pytest.mark.filterwarnings("ignore:Task supports sharding")
 def test_label_inconsistency():
     """Test whether inconsistency between specified labels and labels in examples is detected."""
     cfg = f"""
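Both test files gain the same marker. `@pytest.mark.filterwarnings("ignore:<prefix>")` installs a Python warning filter for the duration of one test; the text after `ignore:` is a regex matched against the start of the warning message. A self-contained sketch (the warning text and function are invented to mirror the one being silenced here):

import warnings

import pytest


def run_sharded_task():
    # Stand-in for task code that emits the warning the marker silences.
    warnings.warn("Task supports sharding, but the model does not.")


@pytest.mark.filterwarnings("ignore:Task supports sharding")
def test_runs_without_warning_noise():
    # The filter matches any warning whose message starts with
    # "Task supports sharding", so it is ignored for this test only.
    run_sharded_task()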
