Skip to content

Commit

Permalink
Use span-specific example types in SpanTask.
Browse files Browse the repository at this point in the history
  • Loading branch information
rmitsch committed Sep 26, 2023
1 parent 35a7952 commit d8e0ff1
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 12 deletions.
17 changes: 11 additions & 6 deletions spacy_llm/tasks/span/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@
from spacy.tokens import Span

from ...compat import Self
from ...ty import FewshotExample
from .task import SpanTaskContraT
from ...ty import FewshotExample, TaskContraT


class SpanExample(FewshotExample[SpanTaskContraT], abc.ABC, Generic[SpanTaskContraT]):
class SpanExample(FewshotExample[TaskContraT], abc.ABC, Generic[TaskContraT]):
"""Example for span tasks not using CoT.
Note: this should be SpanTaskContraT instead of TaskContraT, but this would entail a circular import.
"""

text: str
entities: Dict[str, List[str]]

Expand Down Expand Up @@ -68,9 +71,11 @@ def __str__(self) -> str:
return self.to_str()


class SpanCoTExample(
FewshotExample[SpanTaskContraT], abc.ABC, Generic[SpanTaskContraT]
):
class SpanCoTExample(FewshotExample[TaskContraT], abc.ABC, Generic[TaskContraT]):
"""Example for span tasks using CoT.
Note: this should be SpanTaskContraT instead of TaskContraT, but this would entail a circular import.
"""

text: str
spans: List[SpanReason]

Expand Down
21 changes: 15 additions & 6 deletions spacy_llm/tasks/span/task.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import abc
from typing import Callable, Dict, Iterable, List, Optional, Protocol, Type, TypeVar
from typing import Union
from typing import Union, cast

from spacy.tokens import Doc, Span

from ...compat import Literal, Self
from ...ty import FewshotExample, TaskResponseParser
from ..builtin_task import BuiltinTaskWithLabels
from . import SpanExample
from .examples import SpanCoTExample

SpanTaskContraT = TypeVar("SpanTaskContraT", bound="SpanTask", contravariant=True)

# todo type with spanexample/spancotexample instead of fewshot example
# -> circular import?


class SpanTaskLabelCheck(Protocol[SpanTaskContraT]):
Expand All @@ -27,18 +28,20 @@ class SpanTask(BuiltinTaskWithLabels, abc.ABC):
def __init__(
self,
parse_responses: TaskResponseParser[Self],
prompt_example_type: Type[FewshotExample[Self]],
prompt_example_type: Type[Union[SpanExample[Self], SpanCoTExample[Self]]],
labels: List[str],
template: str,
label_definitions: Optional[Dict[str, str]],
prompt_examples: Optional[List[FewshotExample[Self]]],
prompt_examples: Optional[
Union[List[SpanExample[Self]], List[SpanCoTExample[Self]]]
],
description: Optional[str],
normalizer: Optional[Callable[[str], str]],
alignment_mode: Literal["strict", "contract", "expand"], # noqa: F821
case_sensitive_matching: bool,
allow_overlap: bool,
single_match: bool,
check_label_consistency: SpanTaskLabelCheck,
check_label_consistency: SpanTaskLabelCheck[Self],
):
super().__init__(
parse_responses=parse_responses,
Expand All @@ -50,6 +53,10 @@ def __init__(
normalizer=normalizer,
)

self._prompt_example_type = cast(
Type[Union[SpanExample[Self], SpanCoTExample[Self]]],
self._prompt_example_type,
)
self._validate_alignment(alignment_mode)
self._alignment_mode = alignment_mode
self._case_sensitive_matching = case_sensitive_matching
Expand Down Expand Up @@ -126,7 +133,9 @@ def prompt_examples(self) -> Optional[Iterable[FewshotExample]]:
return self._prompt_examples

@property
def prompt_example_type(self) -> Union[Type[FewshotExample[Self]]]:
def prompt_example_type(
self,
) -> Type[Union[SpanExample[Self], SpanCoTExample[Self]]]:
return self._prompt_example_type

@property
Expand Down

0 comments on commit d8e0ff1

Please sign in to comment.