diff --git a/spacy_llm/tasks/builtin_task.py b/spacy_llm/tasks/builtin_task.py index fa565a97..9af03f5c 100644 --- a/spacy_llm/tasks/builtin_task.py +++ b/spacy_llm/tasks/builtin_task.py @@ -4,7 +4,7 @@ import jinja2 import srsly -from spacy import Language, util, Errors +from spacy import Errors, Language, util from spacy.tokens import Doc from spacy.training import Example @@ -45,19 +45,35 @@ def __init__( self._template = template self._prompt_example_type = prompt_example_type - def generate_prompts(self, docs: Iterable[Doc], **kwargs) -> Iterable[Any]: + def generate_prompts(self, docs: Iterable[Doc]) -> Iterable[Any]: """Generate prompts from docs. docs (Iterable[Doc]): Docs to generate prompts from. RETURNS (Iterable[Any]): Iterable with one prompt per doc. """ environment = jinja2.Environment() _template = environment.from_string(self._template) - for doc in docs: + for doc in self._preprocess_docs_for_prompt(docs): prompt = _template.render( - text=doc.text, prompt_examples=self._prompt_examples, **kwargs + text=doc.text, + prompt_examples=self._prompt_examples, + **self._prompt_data, ) yield prompt + @property + def _prompt_data(self) -> Dict[str, Any]: + """Returns data injected into prompt template. No-op if not overridden by inheriting task class. + RETURNS (Dict[str, Any]): Data injected into prompt template. + """ + return {} + + def _preprocess_docs_for_prompt(self, docs: Iterable[Doc]) -> Iterable[Doc]: + """Preprocesses docs before injection into prompt template. No-op if not overridden by inheriting task class. + docs (Iterable[Doc]): Docs to generate prompts from. + RETURNS (Iterable[Doc]): Preprocessed docs. + """ + return docs + @abc.abstractmethod def parse_responses( self, docs: Iterable[Doc], responses: Iterable[Any] diff --git a/spacy_llm/tasks/rel/task.py b/spacy_llm/tasks/rel/task.py index fe9f6e4e..98ff30a9 100644 --- a/spacy_llm/tasks/rel/task.py +++ b/spacy_llm/tasks/rel/task.py @@ -1,4 +1,4 @@ -from typing import Callable, Dict, Iterable, List, Optional, Type, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union from spacy.language import Language from spacy.tokens import Doc @@ -52,15 +52,16 @@ def __init__( self._verbose = verbose self._field = "rel" - def generate_prompts(self, docs: Iterable[Doc], **kwargs) -> Iterable[str]: - return super().generate_prompts( - docs=[ - Doc(doc.vocab, words=RELTask._preannotate(doc).split()) for doc in docs - ], - labels=list(self._label_dict.values()), - label_definitions=self._label_definitions, - preannotate=RELTask._preannotate, - ) + def _preprocess_docs_for_prompt(self, docs: Iterable[Doc]) -> Iterable[Doc]: + return [Doc(doc.vocab, words=RELTask._preannotate(doc).split()) for doc in docs] + + @property + def _prompt_data(self) -> Dict[str, Any]: + return { + "labels": list(self._label_dict.values()), + "label_definitions": self._label_definitions, + "preannotate": RELTask._preannotate, + } @staticmethod def _preannotate(doc: Union[Doc, RELExample]) -> str: diff --git a/spacy_llm/tasks/sentiment/task.py b/spacy_llm/tasks/sentiment/task.py index 54c82572..ffde2368 100644 --- a/spacy_llm/tasks/sentiment/task.py +++ b/spacy_llm/tasks/sentiment/task.py @@ -61,9 +61,6 @@ def initialize( get_examples=get_examples, nlp=nlp, n_prompt_examples=n_prompt_examples ) - def generate_prompts(self, docs: Iterable[Doc], **kwargs) -> Iterable[str]: - return super().generate_prompts(docs=docs) - def parse_responses( self, docs: Iterable[Doc], responses: Iterable[str] ) -> Iterable[Doc]: diff --git a/spacy_llm/tasks/span/task.py b/spacy_llm/tasks/span/task.py index a7495c19..4ba69018 100644 --- a/spacy_llm/tasks/span/task.py +++ b/spacy_llm/tasks/span/task.py @@ -1,6 +1,6 @@ import abc import typing -from typing import Callable, Dict, Iterable, List, Optional, Type, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union from spacy.tokens import Doc, Span @@ -64,16 +64,15 @@ def __init__( if self._prompt_examples: self._prompt_examples = list(self._check_label_consistency(self)) - def generate_prompts(self, docs: Iterable[Doc], **kwargs) -> Iterable[str]: - return super().generate_prompts( - docs=docs, - description=self._description, - labels=list(self._label_dict.values()), - label_definitions=self._label_definitions, - examples=self._prompt_examples, - allow_overlap=self._allow_overlap, - **kwargs, - ) + @property + def _prompt_data(self) -> Dict[str, Any]: + return { + "description": self._description, + "labels": list(self._label_dict.values()), + "label_definitions": self._label_definitions, + "examples": self._prompt_examples, + "allow_overlap": self._allow_overlap, + } @staticmethod def _validate_alignment(alignment_mode: str): diff --git a/spacy_llm/tasks/summarization/task.py b/spacy_llm/tasks/summarization/task.py index d951e59e..dc1f4f1f 100644 --- a/spacy_llm/tasks/summarization/task.py +++ b/spacy_llm/tasks/summarization/task.py @@ -1,5 +1,5 @@ import warnings -from typing import Callable, Iterable, List, Optional, Type +from typing import Any, Callable, Dict, Iterable, List, Optional, Type from spacy.language import Language from spacy.tokens import Doc @@ -78,12 +78,16 @@ def _check_prompt_example_summary_len(self) -> None: f"LLM will likely produce responses that are too long." ) - def generate_prompts(self, docs: Iterable[Doc], **kwargs) -> Iterable[str]: + @property + def _prompt_data(self) -> Dict[str, Any]: + """Returns data injected into prompt template. No-op if not overridden by inheriting task class. + RETURNS (Dict[str, Any]): Data injected into prompt template. + """ if self._check_example_summaries: self._check_prompt_example_summary_len() self._check_example_summaries = False - return super().generate_prompts(docs=docs, max_n_words=self._max_n_words) + return {"max_n_words": self._max_n_words} def parse_responses( self, docs: Iterable[Doc], responses: Iterable[str] diff --git a/spacy_llm/tasks/textcat/task.py b/spacy_llm/tasks/textcat/task.py index b638db17..04e3b54e 100644 --- a/spacy_llm/tasks/textcat/task.py +++ b/spacy_llm/tasks/textcat/task.py @@ -83,14 +83,14 @@ def __init__( ) self._exclusive_classes = True - def generate_prompts(self, docs: Iterable[Doc], **kwargs) -> Iterable[str]: - return super().generate_prompts( - docs=docs, - labels=list(self._label_dict.values()), - label_definitions=self._label_definitions, - exclusive_classes=self._exclusive_classes, - allow_none=self._allow_none, - ) + @property + def _prompt_data(self) -> Dict[str, Any]: + return { + "labels": list(self._label_dict.values()), + "label_definitions": self._label_definitions, + "exclusive_classes": self._exclusive_classes, + "allow_none": self._allow_none, + } def parse_responses( self, docs: Iterable[Doc], responses: Iterable[str] diff --git a/spacy_llm/tasks/textcat/util.py b/spacy_llm/tasks/textcat/util.py index 0b7d8689..bd88c8aa 100644 --- a/spacy_llm/tasks/textcat/util.py +++ b/spacy_llm/tasks/textcat/util.py @@ -3,10 +3,11 @@ from spacy.scorer import Scorer from spacy.training import Example -from ...compat import BaseModel, Self +from ...compat import Self +from ...ty import FewshotExample -class TextCatExample(BaseModel): +class TextCatExample(FewshotExample): text: str answer: str @@ -27,7 +28,7 @@ def generate(cls, example: Example, **kwargs) -> Self: ] ) - return TextCatExample( + return cls( text=example.reference.text, answer=answer, ) diff --git a/spacy_llm/ty.py b/spacy_llm/ty.py index 0673b1d0..d7af596d 100644 --- a/spacy_llm/ty.py +++ b/spacy_llm/ty.py @@ -58,6 +58,12 @@ def from_disk( class FewshotExample(abc.ABC, BaseModel): + """Base fewshot-example. + From Python 3.7 onwards it's possible to make Pydantic models generic, which allows for a clean solution (see + https://github.com/pydantic/pydantic/issues/4171) using the controller pattern and Pydantic's GenericModel + (BaseModel in Pydantic v2). Until then passing **kwargs seems like the sanest option. + """ + @classmethod @abc.abstractmethod def generate(cls, example: Example, **kwargs) -> Self: @@ -94,7 +100,7 @@ def __call__(self, examples: Iterable[Example], **kwargs) -> Dict[str, Any]: @runtime_checkable class LLMTask(Protocol): - def generate_prompts(self, docs: Iterable[Doc], **kwargs) -> Iterable[_PromptType]: + def generate_prompts(self, docs: Iterable[Doc]) -> Iterable[_PromptType]: """Generate prompts from docs. docs (Iterable[Doc]): Docs to generate prompts from. RETURNS (Iterable[_PromptType]): Iterable with one prompt per doc.