Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor to ditch **kwargs #299

Merged
merged 3 commits
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions spacy_llm/tasks/builtin_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import jinja2
import srsly
from spacy import Language, util, Errors
from spacy import Errors, Language, util
from spacy.tokens import Doc
from spacy.training import Example

Expand Down Expand Up @@ -45,19 +45,35 @@ def __init__(
self._template = template
self._prompt_example_type = prompt_example_type

def generate_prompts(self, docs: Iterable[Doc]) -> Iterable[Any]:
    """Generate prompts from docs.

    Each doc is first run through the `_preprocess_docs_for_prompt` hook
    (identity by default; task subclasses may transform docs), then rendered
    through the task's Jinja2 template. Task-specific template variables are
    supplied by the `_prompt_data` property instead of **kwargs, so subclasses
    customize prompts by overriding those two hooks rather than this method.

    docs (Iterable[Doc]): Docs to generate prompts from.
    RETURNS (Iterable[Any]): Iterable with one prompt per doc.
    """
    environment = jinja2.Environment()
    _template = environment.from_string(self._template)
    for doc in self._preprocess_docs_for_prompt(docs):
        prompt = _template.render(
            text=doc.text,
            prompt_examples=self._prompt_examples,
            **self._prompt_data,
        )
        yield prompt

@property
def _prompt_data(self) -> Dict[str, Any]:
"""Returns data injected into prompt template. No-op if not overridden by inheriting task class.
RETURNS (Dict[str, Any]): Data injected into prompt template.
"""
return {}

def _preprocess_docs_for_prompt(self, docs: Iterable[Doc]) -> Iterable[Doc]:
"""Preprocesses docs before injection into prompt template. No-op if not overridden by inheriting task class.
docs (Iterable[Doc]): Docs to generate prompts from.
RETURNS (Iterable[Doc]): Preprocessed docs.
"""
return docs

@abc.abstractmethod
def parse_responses(
self, docs: Iterable[Doc], responses: Iterable[Any]
Expand Down
21 changes: 11 additions & 10 deletions spacy_llm/tasks/rel/task.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, Dict, Iterable, List, Optional, Type, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union

from spacy.language import Language
from spacy.tokens import Doc
Expand Down Expand Up @@ -52,15 +52,16 @@ def __init__(
self._verbose = verbose
self._field = "rel"

def generate_prompts(self, docs: Iterable[Doc], **kwargs) -> Iterable[str]:
return super().generate_prompts(
docs=[
Doc(doc.vocab, words=RELTask._preannotate(doc).split()) for doc in docs
],
labels=list(self._label_dict.values()),
label_definitions=self._label_definitions,
preannotate=RELTask._preannotate,
)
def _preprocess_docs_for_prompt(self, docs: Iterable[Doc]) -> Iterable[Doc]:
    """Rebuild each doc from the whitespace-split output of `_preannotate`.
    docs (Iterable[Doc]): Docs to generate prompts from.
    RETURNS (Iterable[Doc]): Docs reconstructed from preannotated text.
    """
    preprocessed = []
    for doc in docs:
        tokens = RELTask._preannotate(doc).split()
        preprocessed.append(Doc(doc.vocab, words=tokens))
    return preprocessed

@property
def _prompt_data(self) -> Dict[str, Any]:
    """Template variables for the relation-extraction prompt.
    RETURNS (Dict[str, Any]): Data injected into prompt template.
    """
    data: Dict[str, Any] = {}
    data["labels"] = list(self._label_dict.values())
    data["label_definitions"] = self._label_definitions
    # The template can call this helper to render preannotated text.
    data["preannotate"] = RELTask._preannotate
    return data

@staticmethod
def _preannotate(doc: Union[Doc, RELExample]) -> str:
Expand Down
3 changes: 0 additions & 3 deletions spacy_llm/tasks/sentiment/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,6 @@ def initialize(
get_examples=get_examples, nlp=nlp, n_prompt_examples=n_prompt_examples
)

def generate_prompts(self, docs: Iterable[Doc], **kwargs) -> Iterable[str]:
return super().generate_prompts(docs=docs)

def parse_responses(
self, docs: Iterable[Doc], responses: Iterable[str]
) -> Iterable[Doc]:
Expand Down
21 changes: 10 additions & 11 deletions spacy_llm/tasks/span/task.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import abc
import typing
from typing import Callable, Dict, Iterable, List, Optional, Type, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union

from spacy.tokens import Doc, Span

Expand Down Expand Up @@ -64,16 +64,15 @@ def __init__(
if self._prompt_examples:
self._prompt_examples = list(self._check_label_consistency(self))

def generate_prompts(self, docs: Iterable[Doc], **kwargs) -> Iterable[str]:
return super().generate_prompts(
docs=docs,
description=self._description,
labels=list(self._label_dict.values()),
label_definitions=self._label_definitions,
examples=self._prompt_examples,
allow_overlap=self._allow_overlap,
**kwargs,
)
@property
def _prompt_data(self) -> Dict[str, Any]:
return {
"description": self._description,
"labels": list(self._label_dict.values()),
"label_definitions": self._label_definitions,
"examples": self._prompt_examples,
"allow_overlap": self._allow_overlap,
}

@staticmethod
def _validate_alignment(alignment_mode: str):
Expand Down
10 changes: 7 additions & 3 deletions spacy_llm/tasks/summarization/task.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import Callable, Iterable, List, Optional, Type
from typing import Any, Callable, Dict, Iterable, List, Optional, Type

from spacy.language import Language
from spacy.tokens import Doc
Expand Down Expand Up @@ -78,12 +78,16 @@ def _check_prompt_example_summary_len(self) -> None:
f"LLM will likely produce responses that are too long."
)

def generate_prompts(self, docs: Iterable[Doc], **kwargs) -> Iterable[str]:
@property
def _prompt_data(self) -> Dict[str, Any]:
"""Returns data injected into prompt template. No-op if not overridden by inheriting task class.
RETURNS (Dict[str, Any]): Data injected into prompt template.
"""
if self._check_example_summaries:
self._check_prompt_example_summary_len()
self._check_example_summaries = False

return super().generate_prompts(docs=docs, max_n_words=self._max_n_words)
return {"max_n_words": self._max_n_words}

def parse_responses(
self, docs: Iterable[Doc], responses: Iterable[str]
Expand Down
16 changes: 8 additions & 8 deletions spacy_llm/tasks/textcat/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,14 @@ def __init__(
)
self._exclusive_classes = True

def generate_prompts(self, docs: Iterable[Doc], **kwargs) -> Iterable[str]:
return super().generate_prompts(
docs=docs,
labels=list(self._label_dict.values()),
label_definitions=self._label_definitions,
exclusive_classes=self._exclusive_classes,
allow_none=self._allow_none,
)
@property
def _prompt_data(self) -> Dict[str, Any]:
return {
"labels": list(self._label_dict.values()),
"label_definitions": self._label_definitions,
"exclusive_classes": self._exclusive_classes,
"allow_none": self._allow_none,
}

def parse_responses(
self, docs: Iterable[Doc], responses: Iterable[str]
Expand Down
7 changes: 4 additions & 3 deletions spacy_llm/tasks/textcat/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
from spacy.scorer import Scorer
from spacy.training import Example

from ...compat import BaseModel, Self
from ...compat import Self
from ...ty import FewshotExample


class TextCatExample(BaseModel):
class TextCatExample(FewshotExample):
text: str
answer: str

Expand All @@ -27,7 +28,7 @@ def generate(cls, example: Example, **kwargs) -> Self:
]
)

return TextCatExample(
return cls(
text=example.reference.text,
answer=answer,
)
Expand Down
8 changes: 7 additions & 1 deletion spacy_llm/ty.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ def from_disk(


class FewshotExample(abc.ABC, BaseModel):
"""Base fewshot-example.
From Python 3.7 onwards it's possible to make Pydantic models generic, which allows for a clean solution (see
https://github.com/pydantic/pydantic/issues/4171) using the controller pattern and Pydantic's GenericModel
(BaseModel in Pydantic v2). Until then passing **kwargs seems like the sanest option.
"""

@classmethod
@abc.abstractmethod
def generate(cls, example: Example, **kwargs) -> Self:
Expand Down Expand Up @@ -94,7 +100,7 @@ def __call__(self, examples: Iterable[Example], **kwargs) -> Dict[str, Any]:

@runtime_checkable
class LLMTask(Protocol):
def generate_prompts(self, docs: Iterable[Doc], **kwargs) -> Iterable[_PromptType]:
def generate_prompts(self, docs: Iterable[Doc]) -> Iterable[_PromptType]:
"""Generate prompts from docs.
docs (Iterable[Doc]): Docs to generate prompts from.
RETURNS (Iterable[_PromptType]): Iterable with one prompt per doc.
Expand Down