diff --git a/spacy_llm/tasks/span/examples.py b/spacy_llm/tasks/span/examples.py index 2e826046..c61f4a19 100644 --- a/spacy_llm/tasks/span/examples.py +++ b/spacy_llm/tasks/span/examples.py @@ -5,11 +5,14 @@ from spacy.tokens import Span from ...compat import Self -from ...ty import FewshotExample -from .task import SpanTaskContraT +from ...ty import FewshotExample, TaskContraT -class SpanExample(FewshotExample[SpanTaskContraT], abc.ABC, Generic[SpanTaskContraT]): +class SpanExample(FewshotExample[TaskContraT], abc.ABC, Generic[TaskContraT]): + """Example for span tasks not using CoT. + Note: this should be SpanTaskContraT instead of TaskContraT, but this would entail a circular import. + """ + text: str entities: Dict[str, List[str]] @@ -68,9 +71,11 @@ def __str__(self) -> str: return self.to_str() -class SpanCoTExample( - FewshotExample[SpanTaskContraT], abc.ABC, Generic[SpanTaskContraT] -): +class SpanCoTExample(FewshotExample[TaskContraT], abc.ABC, Generic[TaskContraT]): + """Example for span tasks using CoT. + Note: this should be SpanTaskContraT instead of TaskContraT, but this would entail a circular import. + """ + text: str spans: List[SpanReason] diff --git a/spacy_llm/tasks/span/task.py b/spacy_llm/tasks/span/task.py index a8556ca1..1df0355c 100644 --- a/spacy_llm/tasks/span/task.py +++ b/spacy_llm/tasks/span/task.py @@ -1,17 +1,18 @@ import abc from typing import Callable, Dict, Iterable, List, Optional, Protocol, Type, TypeVar -from typing import Union +from typing import Union, cast from spacy.tokens import Doc, Span from ...compat import Literal, Self from ...ty import FewshotExample, TaskResponseParser from ..builtin_task import BuiltinTaskWithLabels +from . import SpanExample +from .examples import SpanCoTExample SpanTaskContraT = TypeVar("SpanTaskContraT", bound="SpanTask", contravariant=True) # todo type with spanexample/spancotexample instead of fewshot example -# -> circular import? class SpanTaskLabelCheck(Protocol[SpanTaskContraT]): @@ -27,18 +28,20 @@ class SpanTask(BuiltinTaskWithLabels, abc.ABC): def __init__( self, parse_responses: TaskResponseParser[Self], - prompt_example_type: Type[FewshotExample[Self]], + prompt_example_type: Type[Union[SpanExample[Self], SpanCoTExample[Self]]], labels: List[str], template: str, label_definitions: Optional[Dict[str, str]], - prompt_examples: Optional[List[FewshotExample[Self]]], + prompt_examples: Optional[ + Union[List[SpanExample[Self]], List[SpanCoTExample[Self]]] + ], description: Optional[str], normalizer: Optional[Callable[[str], str]], alignment_mode: Literal["strict", "contract", "expand"], # noqa: F821 case_sensitive_matching: bool, allow_overlap: bool, single_match: bool, - check_label_consistency: SpanTaskLabelCheck, + check_label_consistency: SpanTaskLabelCheck[Self], ): super().__init__( parse_responses=parse_responses, @@ -50,6 +53,10 @@ def __init__( normalizer=normalizer, ) + self._prompt_example_type = cast( + Type[Union[SpanExample[Self], SpanCoTExample[Self]]], + self._prompt_example_type, + ) self._validate_alignment(alignment_mode) self._alignment_mode = alignment_mode self._case_sensitive_matching = case_sensitive_matching @@ -126,7 +133,9 @@ def prompt_examples(self) -> Optional[Iterable[FewshotExample]]: return self._prompt_examples @property - def prompt_example_type(self) -> Union[Type[FewshotExample[Self]]]: + def prompt_example_type( + self, + ) -> Type[Union[SpanExample[Self], SpanCoTExample[Self]]]: return self._prompt_example_type @property