Skip to content

Commit

Permalink
Fix TextCatExample.
Browse files Browse the repository at this point in the history
  • Loading branch information
rmitsch committed Sep 21, 2023
1 parent 04675a7 commit 52e712a
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions spacy_llm/tasks/textcat/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,20 @@
from spacy.scorer import Scorer
from spacy.training import Example

from ...compat import BaseModel, Self
from ...compat import Self
from ...ty import FewshotExample


class TextCatExample(BaseModel):
class TextCatExample(FewshotExample):
text: str
answer: str

@classmethod
def generate(
cls, example: Example, use_binary: bool, label_dict: Dict[str, str], **kwargs
) -> Self:
if use_binary:
def generate(cls, example: Example, **kwargs) -> Self:
if kwargs["use_binary"]:
answer = (
"POS"
if example.reference.cats[list(label_dict.values())[0]] == 1.0
if example.reference.cats[list(kwargs["label_dict"].values())[0]] == 1.0
else "NEG"
)
else:
Expand All @@ -29,7 +28,7 @@ def generate(
]
)

return TextCatExample(
return cls(
text=example.reference.text,
answer=answer,
)
Expand Down

0 comments on commit 52e712a

Please sign in to comment.