Skip to content

Commit

Permalink
Reverse control flow, ditch kwargs for generate_prompts().
Browse files Browse the repository at this point in the history
  • Loading branch information
rmitsch committed Sep 20, 2023
1 parent bf3a000 commit 75fa706
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 40 deletions.
24 changes: 20 additions & 4 deletions spacy_llm/tasks/builtin_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import jinja2
import srsly
from spacy import Language, util, Errors
from spacy import Errors, Language, util
from spacy.tokens import Doc
from spacy.training import Example

Expand Down Expand Up @@ -45,19 +45,35 @@ def __init__(
self._template = template
self._prompt_example_type = prompt_example_type

def generate_prompts(self, docs: Iterable[Doc], **kwargs) -> Iterable[Any]:
def generate_prompts(self, docs: Iterable[Doc]) -> Iterable[Any]:
"""Generate prompts from docs.
docs (Iterable[Doc]): Docs to generate prompts from.
RETURNS (Iterable[Any]): Iterable with one prompt per doc.
"""
environment = jinja2.Environment()
_template = environment.from_string(self._template)
for doc in docs:
for doc in self._preprocess_docs_for_prompt(docs):
prompt = _template.render(
text=doc.text, prompt_examples=self._prompt_examples, **kwargs
text=doc.text,
prompt_examples=self._prompt_examples,
**self._prompt_data,
)
yield prompt

@property
def _prompt_data(self) -> Dict[str, Any]:
"""Returns data injected into prompt template. No-op if not overridden by inheriting task class.
RETURNS (Dict[str, Any]): Data injected into prompt template.
"""
return {}

def _preprocess_docs_for_prompt(self, docs: Iterable[Doc]) -> Iterable[Doc]:
"""Preprocesses docs before injection into prompt template. No-op if not overridden by inheriting task class.
docs (Iterable[Doc]): Docs to generate prompts from.
RETURNS (Iterable[Doc]): Preprocessed docs.
"""
return docs

@abc.abstractmethod
def parse_responses(
self, docs: Iterable[Doc], responses: Iterable[Any]
Expand Down
21 changes: 11 additions & 10 deletions spacy_llm/tasks/rel/task.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, Dict, Iterable, List, Optional, Type, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union

from spacy.language import Language
from spacy.tokens import Doc
Expand Down Expand Up @@ -52,15 +52,16 @@ def __init__(
self._verbose = verbose
self._field = "rel"

def generate_prompts(self, docs: Iterable[Doc], **kwargs) -> Iterable[str]:
return super().generate_prompts(
docs=[
Doc(doc.vocab, words=RELTask._preannotate(doc).split()) for doc in docs
],
labels=list(self._label_dict.values()),
label_definitions=self._label_definitions,
preannotate=RELTask._preannotate,
)
def _preprocess_docs_for_prompt(self, docs: Iterable[Doc]) -> Iterable[Doc]:
return [Doc(doc.vocab, words=RELTask._preannotate(doc).split()) for doc in docs]

@property
def _prompt_data(self) -> Dict[str, Any]:
return {
"labels": list(self._label_dict.values()),
"label_definitions": self._label_definitions,
"preannotate": RELTask._preannotate,
}

@staticmethod
def _preannotate(doc: Union[Doc, RELExample]) -> str:
Expand Down
3 changes: 0 additions & 3 deletions spacy_llm/tasks/sentiment/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,6 @@ def initialize(
get_examples=get_examples, nlp=nlp, n_prompt_examples=n_prompt_examples
)

def generate_prompts(self, docs: Iterable[Doc], **kwargs) -> Iterable[str]:
return super().generate_prompts(docs=docs)

def parse_responses(
self, docs: Iterable[Doc], responses: Iterable[str]
) -> Iterable[Doc]:
Expand Down
21 changes: 10 additions & 11 deletions spacy_llm/tasks/span/task.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import abc
import typing
from typing import Callable, Dict, Iterable, List, Optional, Type, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union

from spacy.tokens import Doc, Span

Expand Down Expand Up @@ -64,16 +64,15 @@ def __init__(
if self._prompt_examples:
self._prompt_examples = list(self._check_label_consistency(self))

def generate_prompts(self, docs: Iterable[Doc], **kwargs) -> Iterable[str]:
return super().generate_prompts(
docs=docs,
description=self._description,
labels=list(self._label_dict.values()),
label_definitions=self._label_definitions,
examples=self._prompt_examples,
allow_overlap=self._allow_overlap,
**kwargs,
)
@property
def _prompt_data(self) -> Dict[str, Any]:
return {
"description": self._description,
"labels": list(self._label_dict.values()),
"label_definitions": self._label_definitions,
"examples": self._prompt_examples,
"allow_overlap": self._allow_overlap,
}

@staticmethod
def _validate_alignment(alignment_mode: str):
Expand Down
10 changes: 7 additions & 3 deletions spacy_llm/tasks/summarization/task.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import Callable, Iterable, List, Optional, Type
from typing import Any, Callable, Dict, Iterable, List, Optional, Type

from spacy.language import Language
from spacy.tokens import Doc
Expand Down Expand Up @@ -78,12 +78,16 @@ def _check_prompt_example_summary_len(self) -> None:
f"LLM will likely produce responses that are too long."
)

def generate_prompts(self, docs: Iterable[Doc], **kwargs) -> Iterable[str]:
@property
def _prompt_data(self) -> Dict[str, Any]:
"""Returns data injected into prompt template. No-op if not overridden by inheriting task class.
RETURNS (Dict[str, Any]): Data injected into prompt template.
"""
if self._check_example_summaries:
self._check_prompt_example_summary_len()
self._check_example_summaries = False

return super().generate_prompts(docs=docs, max_n_words=self._max_n_words)
return {"max_n_words": self._max_n_words}

def parse_responses(
self, docs: Iterable[Doc], responses: Iterable[str]
Expand Down
16 changes: 8 additions & 8 deletions spacy_llm/tasks/textcat/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,14 @@ def __init__(
)
self._exclusive_classes = True

def generate_prompts(self, docs: Iterable[Doc], **kwargs) -> Iterable[str]:
return super().generate_prompts(
docs=docs,
labels=list(self._label_dict.values()),
label_definitions=self._label_definitions,
exclusive_classes=self._exclusive_classes,
allow_none=self._allow_none,
)
@property
def _prompt_data(self) -> Dict[str, Any]:
return {
"labels": list(self._label_dict.values()),
"label_definitions": self._label_definitions,
"exclusive_classes": self._exclusive_classes,
"allow_none": self._allow_none,
}

def parse_responses(
self, docs: Iterable[Doc], responses: Iterable[str]
Expand Down
2 changes: 1 addition & 1 deletion spacy_llm/ty.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __call__(self, examples: Iterable[Example], **kwargs) -> Dict[str, Any]:

@runtime_checkable
class LLMTask(Protocol):
def generate_prompts(self, docs: Iterable[Doc], **kwargs) -> Iterable[_PromptType]:
def generate_prompts(self, docs: Iterable[Doc]) -> Iterable[_PromptType]:
"""Generate prompts from docs.
docs (Iterable[Doc]): Docs to generate prompts from.
RETURNS (Iterable[_PromptType]): Iterable with one prompt per doc.
Expand Down

0 comments on commit 75fa706

Please sign in to comment.