From 07a4f2b85989fc44da0b155af876b12f38fc6f97 Mon Sep 17 00:00:00 2001 From: rehan Date: Thu, 20 Jul 2023 22:38:20 -0600 Subject: [PATCH 01/33] Adding span_srl task with tests, usage and documentation --- README.md | 27 ++ spacy_llm/tasks/__init__.py | 3 + spacy_llm/tasks/srl_task.py | 371 ++++++++++++++++++ spacy_llm/tasks/templates/span-srl.v1.jinja | 24 ++ spacy_llm/tests/tasks/test_span_srl.py | 83 ++++ usage_examples/span_srl_openai/README.md | 32 ++ usage_examples/span_srl_openai/__init__.py | 3 + .../span_srl_openai/run_pipeline.py | 48 +++ usage_examples/span_srl_openai/zeroshot.cfg | 19 + 9 files changed, 610 insertions(+) create mode 100644 spacy_llm/tasks/srl_task.py create mode 100644 spacy_llm/tasks/templates/span-srl.v1.jinja create mode 100644 spacy_llm/tests/tasks/test_span_srl.py create mode 100644 usage_examples/span_srl_openai/README.md create mode 100644 usage_examples/span_srl_openai/__init__.py create mode 100644 usage_examples/span_srl_openai/run_pipeline.py create mode 100644 usage_examples/span_srl_openai/zeroshot.cfg diff --git a/README.md b/README.md index dc9eab6c..a69b35e4 100644 --- a/README.md +++ b/README.md @@ -785,6 +785,33 @@ Note: the REL task relies on pre-extracted entities to make its prediction. Hence, you'll need to add a component that populates `doc.ents` with recognized spans to your spaCy pipeline and put it _before_ the REL component. +#### spacy.SRL.v1 + +The built-in Semantic Role Labeling (SRL) task supports zero-shot prompting. +The prompt contains two steps: + 1. Predicate Identification +2. Semantic Role Identification for each Predicate + +We only focus on the 4 important semantic roles: +1. ARG-0: Typical Agent +2. ARG-1: Typical Patient or Theme +3. ARG-M-TMP: Temporal Modifier +4. ARG-M-LOC: Location Modifier +```ini +[components.llm.task] +@llm_tasks = "spacy.SRL.v1" +labels = ARG-0,ARG-1,ARG-M-TMP,ARG-M-LOC +``` + +| Argument | Type | Default | Description | +| ------------------- | --------------------------------------- | ---------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------- | +| `labels` | `Union[List[str], str]` | | List of labels or str of comma-separated list of labels. | +| `template` | `str` | [`rel.jinja`](./spacy_llm/tasks/templates/rel.jinja) | Custom prompt template to send to LLM model. Default templates for each task are located in the `spacy_llm/tasks/templates` directory. | +| `label_description` | `Optional[Dict[str, str]]` | `None` | Dictionary providing a description for each relation label. | +| `normalizer` | `Optional[Callable[[str], str]]` | `None` | Function that normalizes the labels as returned by the LLM. If `None`, falls back to `spacy.LowercaseNormalizer.v1`. | +| `verbose` | `bool` | `False` | If set to `True`, warnings will be generated when the LLM returns invalid responses. | + + #### spacy.Lemma.v1 The `Lemma.v1` task lemmatizes the provided text and updates the `lemma_` attribute in the doc's tokens accordingly. diff --git a/spacy_llm/tasks/__init__.py b/spacy_llm/tasks/__init__.py index 24f7609d..61f2de4a 100644 --- a/spacy_llm/tasks/__init__.py +++ b/spacy_llm/tasks/__init__.py @@ -6,6 +6,7 @@ from .spancat import SpanCatTask, make_spancat_task, make_spancat_task_v2 from .summarization import SummarizationTask, make_summarization_task from .textcat import TextCatTask, make_textcat_task +from .srl_task import SRLTask, make_srl_task __all__ = [ "make_lemma_task", @@ -18,6 +19,7 @@ "make_spancat_task_v2", "make_summarization_task", "make_textcat_task", + "make_srl_task", "LemmaTask", "NERTask", "NoopTask", @@ -26,4 +28,5 @@ "SpanCatTask", "SummarizationTask", "TextCatTask", + "SRLTask" ] diff --git a/spacy_llm/tasks/srl_task.py b/spacy_llm/tasks/srl_task.py new file mode 100644 index 00000000..613da88a --- /dev/null +++ b/spacy_llm/tasks/srl_task.py @@ -0,0 +1,371 @@ +from typing import Callable, Dict, Iterable, List, Optional, Tuple, Type, Union, Any, Literal + +import re +import jinja2 +from wasabi import msg +from spacy.tokens import Doc +from spacy.training import Example +from .util.serialization import ExampleType +from ..registry import lowercase_normalizer, registry +from ..util import split_labels +from .util import SerializableTask +from .util.parsing import find_substrings +from collections import defaultdict +from pydantic import BaseModel +from pathlib import Path +from .templates import read_template +from spacy import Language + +_DEFAULT_SPAN_SRL_TEMPLATE_V1 = read_template("span-srl.v1") + + +class SpanItem(BaseModel): + text: str + start_char: int + end_char: int + + def __hash__(self): + return hash((self.text, self.start_char, self.end_char)) + + +class PredicateItem(SpanItem): + roleset_id: str = '' + + def __hash__(self): + return hash((self.text, self.start_char, self.end_char, self.roleset_id)) + + +class ArgRELItem(BaseModel): + predicate: PredicateItem + role: SpanItem + label: str + + def __hash__(self): + return hash((self.predicate, self.role, self.label)) + + +class SRLExample(BaseModel): + text: str + predicates: List[PredicateItem] + relations: List[ArgRELItem] + + +def _preannotate(doc: Union[Doc, SRLExample]) -> str: + """Creates a text version of the document with list of provided predicates.""" + + text = doc.text + preds = ', '.join([s.text for s in doc.predicates]) + + formatted_text = f"{text}\nPredicates: {preds}" + + return formatted_text + + +def score_srl_spans( + examples: Iterable[Example], +) -> Dict[str, Any]: + pred_predicates_spans = set() + gold_predicates_spans = set() + + pred_relation_tuples = set() + gold_relation_tuples = set() + + for i, eg in enumerate(examples): + pred_doc = eg.predicted + gold_doc = eg.reference + + pred_predicates_spans.update([(i, PredicateItem(**dict(p))) for p in pred_doc._.predicates]) + gold_predicates_spans.update([(i, PredicateItem(**dict(p))) for p in gold_doc._.predicates]) + + pred_relation_tuples.update([(i, ArgRELItem(**dict(r))) for r in pred_doc._.relations]) + gold_relation_tuples.update([(i, ArgRELItem(**dict(r))) for r in gold_doc._.relations]) + + def _overlap_prf(gold: set, pred: set): + overlap = gold.intersection(pred) + p = len(overlap)/len(pred) + r = len(overlap)/len(gold) + f = 2*p*r/(p+r) + return p, r, f + + predicates_prf = _overlap_prf(gold_predicates_spans, pred_predicates_spans) + micro_rel_prf = _overlap_prf(gold_relation_tuples, pred_relation_tuples) + + def _get_label2rels(rel_tuples: Iterable[Tuple[int, ArgRELItem]]): + label2rels = defaultdict(set) + for tup in rel_tuples: + label_ = tup[1].label + label2rels[label_].add(tup) + return label2rels + + pred_label2relations = _get_label2rels(pred_relation_tuples) + gold_label2relations = _get_label2rels(gold_relation_tuples) + + all_labels = set.union(set(pred_label2relations.keys()), set(gold_label2relations.keys())) + label2prf = {} + for label in all_labels: + pred_label_rels = pred_label2relations[label] + gold_label_rels = gold_label2relations[label] + label2prf[label] = _overlap_prf(gold_label_rels, pred_label_rels) + + return { + 'Predicates': predicates_prf, + 'ARGs': { + 'Overall': micro_rel_prf, + 'PerLabel': label2prf + } + } + + +@registry.llm_tasks("spacy.SRL.v1") +def make_srl_task( + labels: str, + template: str = _DEFAULT_SPAN_SRL_TEMPLATE_V1, + label_definitions: Optional[Dict[str, str]] = None, + examples: Optional[Callable[[], Iterable[Any]]] = None, + normalizer: Optional[Callable[[str], str]] = None, + alignment_mode: Literal["strict", "contract", "expand"] = "contract", + case_sensitive_matching: bool = False, + single_match: bool = False, + verbose: bool = False, + predicate_key: str = 'Predicate' +): + """SRL.v1 task factory. + + labels (str): Comma-separated list of labels to pass to the template. + Leave empty to populate it at initialization time (only if examples are provided). + template (str): Prompt template passed to the model. + label_definitions (Optional[Dict[str, str]]): Map of label -> description + of the label to help the language model output the entities wanted. + It is usually easier to provide these definitions rather than + full examples, although both can be provided. + examples (Optional[Callable[[], Iterable[Any]]]): Optional callable that + reads a file containing task examples for few-shot learning. If None is + passed, then zero-shot learning will be used. + normalizer (Optional[Callable[[str], str]]): optional normalizer function. + alignment_mode (str): "strict", "contract" or "expand". + case_sensitive_matching: Whether to search without case sensitivity. + single_match (bool): If False, allow one substring to match multiple times in + the text. If True, returns the first hit. + verbose (boole): Verbose ot not + predicate_key: The str of Predicate in the template + """ + labels_list = split_labels(labels) + raw_examples = examples() if callable(examples) else examples + rel_examples = [SRLExample(**eg) for eg in raw_examples] if raw_examples else None + return SRLTask( + labels=labels_list, + template=template, + label_definitions=label_definitions, + examples=rel_examples, + normalizer=normalizer, + verbose=verbose, + alignment_mode=alignment_mode, + case_sensitive_matching=case_sensitive_matching, + single_match=single_match, + predicate_key=predicate_key, + ) + + +class SRLTask(SerializableTask[SRLExample]): + @property + def _Example(self) -> Type[SRLExample]: + return SRLExample + + @property + def _cfg_keys(self) -> List[str]: + return [ + "_label_dict", + "_template", + "_label_definitions", + "_verbose", + "_predicate_key", + "_alignment_mode", + "_case_sensitive_matching", + "_single_match", + ] + + def __init__( + self, + labels: List[str] = [], + template: str = _DEFAULT_SPAN_SRL_TEMPLATE_V1, + label_definitions: Optional[Dict[str, str]] = None, + examples: Optional[List[SRLExample]] = None, + normalizer: Optional[Callable[[str], str]] = None, + verbose: bool = False, + predicate_key: str = 'Predicate', + alignment_mode: Literal[ + "strict", "contract", "expand" # noqa: F821 + ] = "contract", + case_sensitive_matching: bool = False, + single_match: bool = False, + ): + self._normalizer = normalizer if normalizer else lowercase_normalizer() + self._label_dict = {self._normalizer(label): label for label in labels} + self._template = template + self._label_definitions = label_definitions + self._examples = examples + self._verbose = verbose + self._validate_alignment(alignment_mode) + self._alignment_mode = alignment_mode + self._case_sensitive_matching = case_sensitive_matching + self._single_match = single_match + self._predicate_key = predicate_key + self._check_extensions() + + @classmethod + def _check_extensions(cls): + """Add `predicates` extension if need be. + Add `relations` extension if need be.""" + + if not Doc.has_extension("predicates"): + Doc.set_extension("predicates", default=[]) + + if not Doc.has_extension("relations"): + Doc.set_extension("relations", default=[]) + + @staticmethod + def _validate_alignment(alignment_mode: str): + """Raises error if specified alignment_mode is not supported. + alignment_mode (str): Alignment mode to check. + """ + # ideally, this list should be taken from spaCy, but it's not currently exposed from doc.pyx. + alignment_modes = ("strict", "contract", "expand") + if alignment_mode not in alignment_modes: + raise ValueError( + f"Unsupported alignment mode '{alignment_mode}'. Supported modes: {', '.join(alignment_modes)}" + ) + + def initialize( + self, + get_examples: Callable[[], Iterable["Example"]], + nlp: Language, + labels: List[str] = [], + ) -> None: + """Initialize the task, by auto-discovering labels. + + Labels can be set through, by order of precedence: + + - the `[initialize]` section of the pipeline configuration + - the `labels` argument supplied to the task factory + - the labels found in the examples + + get_examples (Callable[[], Iterable["Example"]]): Callable that provides examples + for initialization. + nlp (Language): Language instance. + labels (List[str]): Optional list of labels. + """ + self._check_extensions() + + examples = get_examples() + + if not labels: + labels = list(self._label_dict.values()) + + if not labels: + label_set = set() + + for eg in examples: + rels: List[ArgRELItem] = eg.reference._.relations + for rel in rels: + label_set.add(rel.label) + labels = list(label_set) + + self._label_dict = {self._normalizer(label): label for label in labels} + + @property + def labels(self) -> Tuple[str, ...]: + return tuple(self._label_dict.values()) + + @property + def prompt_template(self) -> str: + return self._template + + def generate_prompts(self, docs: Iterable[Doc]) -> Iterable[str]: + environment = jinja2.Environment() + _template = environment.from_string(self._template) + for doc in docs: + prompt = _template.render( + text=doc.text, + labels=list(self._label_dict.values()), + label_definitions=self._label_definitions, + ) + yield prompt + + def _format_response(self, arg_lines): + """Parse raw string response into a structured format""" + output = [] + # this ensures unique arguments in the sentence for a predicate + found_labels = set() + for line in arg_lines: + if line.strip() and ':' in line: + label, phrase = line.strip().split(':', 1) + + # label is of the form "ARG-n (def)" + label = label.split('(')[0].strip() + + # strip any surrounding quotes + phrase = phrase.strip('\'" ') + + norm_label = self._normalizer(label) + if norm_label in self._label_dict and norm_label not in found_labels: + if phrase.strip(): + _phrase = phrase.strip() + found_labels.add(norm_label) + output.append((self._label_dict[norm_label], _phrase)) + return output + + def parse_responses( + self, docs: Iterable[Doc], responses: Iterable[str] + ) -> Iterable[Doc]: + for doc, prompt_response in zip(docs, responses): + predicates = [] + relations = [] + lines = prompt_response.split('\n') + + # match lines that start with {Predicate:, Predicate 1:, Predicate1:} + # assuming self.predicate_key = 'Predicate' + pred_patt = r'^' + re.escape(self._predicate_key) + r'\b\s*\d*[:\-\s]' + pred_indices, pred_lines = zip(*[(i, line) for i, line in enumerate(lines) if re.search(pred_patt, line)]) + + pred_indices = list(pred_indices) + + # extract the predicate strings + pred_strings = [line.split(":", 1)[1].strip('\'" ') for line in pred_lines] + + # extract the line ranges (s, e) of predicate's content. + # then extract the pred content lines using the ranges + pred_indices.append(len(lines)) + pred_ranges = zip(pred_indices[:-1], pred_indices[1:]) + pred_contents = [lines[s:e] for s, e in pred_ranges] + + # assign the spans of the predicates and args + # then create ArgRELItem from the identified predicates and arguments + for pred_str, pred_content_lines in zip(pred_strings, pred_contents): + pred_offsets = list(find_substrings( + doc.text, + [pred_str], + case_sensitive=True, + single_match=True + )) + + # ignore the args if the predicate is not found + if len(pred_offsets): + p_start_char, p_end_char = pred_offsets[0] + pred_item = PredicateItem(text=pred_str, start_char=p_start_char, end_char=p_end_char) + predicates.append(pred_item.dict()) + + for label, phrase in self._format_response(pred_content_lines): + arg_offsets = find_substrings( + doc.text, + [phrase], + case_sensitive=self._case_sensitive_matching, + single_match=self._single_match + ) + for start, end in arg_offsets: + arg_item = SpanItem(text=phrase, start_char=start, end_char=end).dict() + arg_rel_item = ArgRELItem(predicate=pred_item, role=arg_item, label=label).dict() + relations.append(arg_rel_item) + + doc._.predicates = predicates + doc._.relations = relations + yield doc diff --git a/spacy_llm/tasks/templates/span-srl.v1.jinja b/spacy_llm/tasks/templates/span-srl.v1.jinja new file mode 100644 index 00000000..613d37f2 --- /dev/null +++ b/spacy_llm/tasks/templates/span-srl.v1.jinja @@ -0,0 +1,24 @@ +You are an expert Semantic Role Labeling (SRL) system. Your task is to accept Text as input and extract \ +the Predicates and the Semantic Roles for each Predicate's ARGs in a step-by-step manner. +{# whitespace #} +Here is the text that needs labeling: +{# whitespace #} +Text: +''' +{{text}} +''' +{# whitespace #} +{%- if predicates -%} +Step 1: Use the following Predicates for the Text: +Predicates: +{%- else -%} +Step 1: Extract the Predicates for the Text in the following format : +Predicates: +{%- endif -%} +{# whitespace #} +Step 2: For each Predicate, extract only the following Sematic Roles in '''Text''' in this format : +Predicate: +ARG-0 (Agent): +ARG-1 (Patient or Theme): +ARG-M-LOC (Location Modifier): +ARG-M-TMP (Temporal Modifier): diff --git a/spacy_llm/tests/tasks/test_span_srl.py b/spacy_llm/tests/tasks/test_span_srl.py new file mode 100644 index 00000000..62d4a5b3 --- /dev/null +++ b/spacy_llm/tests/tasks/test_span_srl.py @@ -0,0 +1,83 @@ +from pathlib import Path + +import pytest +from confection import Config +from pytest import FixtureRequest +from spacy_llm.pipeline import LLMWrapper +from spacy_llm.tasks.srl_task import _DEFAULT_SPAN_SRL_TEMPLATE_V1, ArgRELItem, PredicateItem, SpanItem +from spacy_llm.ty import Labeled, LLMTask +from spacy_llm.util import assemble_from_config, split_labels + +from spacy_llm.tests.compat import has_openai_key + +EXAMPLES_DIR = Path(__file__).parent / "examples" + + +@pytest.fixture +def zeroshot_cfg_string(): + return """ + [paths] + examples = null + + [nlp] + lang = "en" + pipeline = ["llm"] + + [components] + + [components.llm] + factory = "llm" + + [components.llm.task] + @llm_tasks = "spacy.SRL.v1" + labels = ARG-0,ARG-1,ARG-M-TMP,ARG-M-LOC + + [components.llm.model] + @llm_models = "spacy.GPT-3-5.v1" + """ + + +@pytest.fixture +def task(): + text = "We love this sentence in Berlin right now ." + gold_relations = [] + return text, gold_relations + + +@pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available") +@pytest.mark.parametrize("cfg_string", ["zeroshot_cfg_string"]) +def test_rel_config(cfg_string, request: FixtureRequest): + """Simple test to check if the config loads properly given different settings""" + cfg_string = request.getfixturevalue(cfg_string) + orig_config = Config().from_str(cfg_string) + nlp = assemble_from_config(orig_config) + assert nlp.pipe_names == ["llm"] + + pipe = nlp.get_pipe("llm") + assert isinstance(pipe, LLMWrapper) + assert isinstance(pipe.task, LLMTask) + + task = pipe.task + labels = orig_config["components"]["llm"]["task"]["labels"] + labels = split_labels(labels) + assert isinstance(task, Labeled) + assert task.labels == tuple(labels) + assert set(pipe.labels) == set(task.labels) + assert nlp.pipe_labels["llm"] == list(task.labels) + + +@pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available") +@pytest.mark.parametrize("cfg_string", ["zeroshot_cfg_string"]) # "zeroshot_cfg_string", +def test_rel_predict(task, cfg_string, request): + """Use OpenAI to get REL results. + Note that this test may fail randomly, as the LLM's output is unguaranteed to be consistent/predictable + """ + cfg_string = request.getfixturevalue(cfg_string) + orig_config = Config().from_str(cfg_string) + nlp = assemble_from_config(orig_config) + + text, _ = task + doc = nlp(text) + + assert doc._.predicates + assert doc._.relations \ No newline at end of file diff --git a/usage_examples/span_srl_openai/README.md b/usage_examples/span_srl_openai/README.md new file mode 100644 index 00000000..ac65ae3e --- /dev/null +++ b/usage_examples/span_srl_openai/README.md @@ -0,0 +1,32 @@ +# Semantic Role Labeling (SRL) using LLMs + +This example shows how you can use a model from OpenAI for SRL in +zero- and few-shot settings. + + +We leverage the OpenAI API to detect the predicates and argument roles in a sentence. +In the example below, we focus on the predicate "bought" and ARG-0, ARG-1, and ARG-M-LOC. + +First, create a new API key from +[openai.com](https://platform.openai.com/account/api-keys) or fetch an existing +one. Record the secret key and make sure this is available as an environmental +variable: + +```sh +export OPENAI_API_KEY="sk-..." +export OPENAI_API_ORG="org-..." +``` + +Then, you can run the pipeline on a sample text via: + +```sh +python run_pipeline.py [TEXT] [PATH TO CONFIG] +``` + +For example: + +```sh +python run_pipeline.py \ + "Laura just bought an apartment in Boston." \ + ./zeroshot.cfg +``` \ No newline at end of file diff --git a/usage_examples/span_srl_openai/__init__.py b/usage_examples/span_srl_openai/__init__.py new file mode 100644 index 00000000..39687454 --- /dev/null +++ b/usage_examples/span_srl_openai/__init__.py @@ -0,0 +1,3 @@ +from .run_pipeline import run_pipeline + +__all__ = ["run_pipeline"] \ No newline at end of file diff --git a/usage_examples/span_srl_openai/run_pipeline.py b/usage_examples/span_srl_openai/run_pipeline.py new file mode 100644 index 00000000..e358d729 --- /dev/null +++ b/usage_examples/span_srl_openai/run_pipeline.py @@ -0,0 +1,48 @@ +import os +from pathlib import Path +from typing import Optional + +import typer +from wasabi import msg + +from spacy_llm.util import assemble + +Arg = typer.Argument +Opt = typer.Option + + +def run_pipeline( + # fmt: off + text: str = Arg("", help="Text to perform text categorization on."), + config_path: Path = Arg(..., help="Path to the configuration file to use."), + examples_path: Optional[Path] = Arg(None, help="Path to the examples file to use (few-shot only)."), + verbose: bool = Opt(False, "--verbose", "-v", help="Show extra information."), + # fmt: on +): + if not os.getenv("OPENAI_API_KEY", None): + msg.fail( + "OPENAI_API_KEY env variable was not found. " + "Set it by running 'export OPENAI_API_KEY=...' and try again.", + exits=1, + ) + + msg.text(f"Loading config from {config_path}", show=verbose) + nlp = assemble( + config_path, + overrides={} + if examples_path is None + else {"paths.examples": str(examples_path)}, + ) + + doc = nlp(text) + + msg.text(f"Text: {doc.text}") + msg.text(f"Predicates: {[p['text'] for p in doc._.predicates]}") + + msg.text("Relations:") + for r in doc._.relations: + msg.text(f" - {r['predicate']['text']} [{r['label']}] {r['role']['text']}") + + +if __name__ == "__main__": + typer.run(run_pipeline) \ No newline at end of file diff --git a/usage_examples/span_srl_openai/zeroshot.cfg b/usage_examples/span_srl_openai/zeroshot.cfg new file mode 100644 index 00000000..570ac1bb --- /dev/null +++ b/usage_examples/span_srl_openai/zeroshot.cfg @@ -0,0 +1,19 @@ +[paths] +examples = null + +[nlp] +lang = "en" +pipeline = ["llm"] + +[components] + +[components.llm] +factory = "llm" + +[components.llm.task] +@llm_tasks = "spacy.SRL.v1" +labels = ARG-0,ARG-1,ARG-M-TMP,ARG-M-LOC + +[components.llm.model] +@llm_models = "spacy.GPT-3-5.v1" + From 22cba55cfc8093a54fa8b4ff5e7aa3883bc384ab Mon Sep 17 00:00:00 2001 From: rehan Date: Tue, 25 Jul 2023 11:30:30 -0600 Subject: [PATCH 02/33] Fixing minor issues --- spacy_llm/tasks/srl_task.py | 66 +++++++++++++-------- spacy_llm/tasks/templates/span-srl.v1.jinja | 20 ++++--- 2 files changed, 51 insertions(+), 35 deletions(-) diff --git a/spacy_llm/tasks/srl_task.py b/spacy_llm/tasks/srl_task.py index 613da88a..f5fff4ac 100644 --- a/spacy_llm/tasks/srl_task.py +++ b/spacy_llm/tasks/srl_task.py @@ -1,20 +1,21 @@ from typing import Callable, Dict, Iterable, List, Optional, Tuple, Type, Union, Any, Literal import re + +from collections import defaultdict import jinja2 -from wasabi import msg +from pydantic import BaseModel, Field, ValidationError, validator +from spacy.language import Language from spacy.tokens import Doc from spacy.training import Example -from .util.serialization import ExampleType +from wasabi import msg + from ..registry import lowercase_normalizer, registry +from ..ty import ExamplesConfigType from ..util import split_labels +from .templates import read_template from .util import SerializableTask from .util.parsing import find_substrings -from collections import defaultdict -from pydantic import BaseModel -from pathlib import Path -from .templates import read_template -from spacy import Language _DEFAULT_SPAN_SRL_TEMPLATE_V1 = read_template("span-srl.v1") @@ -82,9 +83,9 @@ def score_srl_spans( def _overlap_prf(gold: set, pred: set): overlap = gold.intersection(pred) - p = len(overlap)/len(pred) - r = len(overlap)/len(gold) - f = 2*p*r/(p+r) + p = 0. if not len(pred) else len(overlap)/len(pred) + r = 0. if not len(gold) else len(overlap)/len(gold) + f = 0. if not p or not r else 2*p*r/(p+r) return p, r, f predicates_prf = _overlap_prf(gold_predicates_spans, pred_predicates_spans) @@ -121,7 +122,7 @@ def make_srl_task( labels: str, template: str = _DEFAULT_SPAN_SRL_TEMPLATE_V1, label_definitions: Optional[Dict[str, str]] = None, - examples: Optional[Callable[[], Iterable[Any]]] = None, + examples: ExamplesConfigType = None, normalizer: Optional[Callable[[str], str]] = None, alignment_mode: Literal["strict", "contract", "expand"] = "contract", case_sensitive_matching: bool = False, @@ -284,11 +285,17 @@ def generate_prompts(self, docs: Iterable[Doc]) -> Iterable[str]: environment = jinja2.Environment() _template = environment.from_string(self._template) for doc in docs: + predicates = None + if len(doc._.predicates): + predicates = ','.join([p['text'] for p in doc._.predicates]) + prompt = _template.render( text=doc.text, labels=list(self._label_dict.values()), label_definitions=self._label_definitions, + predicates=predicates ) + yield prompt def _format_response(self, arg_lines): @@ -297,21 +304,28 @@ def _format_response(self, arg_lines): # this ensures unique arguments in the sentence for a predicate found_labels = set() for line in arg_lines: - if line.strip() and ':' in line: - label, phrase = line.strip().split(':', 1) - - # label is of the form "ARG-n (def)" - label = label.split('(')[0].strip() - - # strip any surrounding quotes - phrase = phrase.strip('\'" ') - - norm_label = self._normalizer(label) - if norm_label in self._label_dict and norm_label not in found_labels: - if phrase.strip(): - _phrase = phrase.strip() - found_labels.add(norm_label) - output.append((self._label_dict[norm_label], _phrase)) + try: + if line.strip() and ':' in line: + label, phrase = line.strip().split(':', 1) + + # label is of the form "ARG-n (def)" + label = label.split('(')[0].strip() + + # strip any surrounding quotes + phrase = phrase.strip('\'" -') + + norm_label = self._normalizer(label) + if norm_label in self._label_dict and norm_label not in found_labels: + if phrase.strip(): + _phrase = phrase.strip() + found_labels.add(norm_label) + output.append((self._label_dict[norm_label], _phrase)) + except ValidationError: + msg.warn( + "Validation issue", + line, + show=self._verbose, + ) return output def parse_responses( diff --git a/spacy_llm/tasks/templates/span-srl.v1.jinja b/spacy_llm/tasks/templates/span-srl.v1.jinja index 613d37f2..451b3609 100644 --- a/spacy_llm/tasks/templates/span-srl.v1.jinja +++ b/spacy_llm/tasks/templates/span-srl.v1.jinja @@ -1,24 +1,26 @@ You are an expert Semantic Role Labeling (SRL) system. Your task is to accept Text as input and extract \ the Predicates and the Semantic Roles for each Predicate's ARGs in a step-by-step manner. {# whitespace #} -Here is the text that needs labeling: -{# whitespace #} -Text: -''' -{{text}} -''' -{# whitespace #} {%- if predicates -%} Step 1: Use the following Predicates for the Text: -Predicates: +Predicates: {{predicates}} {%- else -%} Step 1: Extract the Predicates for the Text in the following format : Predicates: {%- endif -%} {# whitespace #} -Step 2: For each Predicate, extract only the following Sematic Roles in '''Text''' in this format : +Step 2: For each Predicate, extract only the following Sematic Roles in '''Text''' in this format: +Text: Predicate: ARG-0 (Agent): ARG-1 (Patient or Theme): +ARG-2: ARG-M-LOC (Location Modifier): ARG-M-TMP (Temporal Modifier): +{# whitespace #} +Here is the text that needs labeling: +{# whitespace #} +Text: +''' +{{text}} +''' \ No newline at end of file From 917868b8d169ac31c8581e0826eb2e0cef94c8f3 Mon Sep 17 00:00:00 2001 From: rehan Date: Tue, 25 Jul 2023 11:49:42 -0600 Subject: [PATCH 03/33] adding example usage of SRL --- README.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/README.md b/README.md index a69b35e4..a8b45633 100644 --- a/README.md +++ b/README.md @@ -811,6 +811,16 @@ labels = ARG-0,ARG-1,ARG-M-TMP,ARG-M-LOC | `normalizer` | `Optional[Callable[[str], str]]` | `None` | Function that normalizes the labels as returned by the LLM. If `None`, falls back to `spacy.LowercaseNormalizer.v1`. | | `verbose` | `bool` | `False` | If set to `True`, warnings will be generated when the LLM returns invalid responses. | +Example usage: +```python +from spacy_llm.util import assemble + + +nlp = assemble("config.cfg") +doc = nlp("I love this sentence.") +print(doc._.predicates) +print(doc._.relations) +``` #### spacy.Lemma.v1 From b6f4f526ebecfbd929f8e1b30308cf0e0bc53c54 Mon Sep 17 00:00:00 2001 From: rehan Date: Tue, 25 Jul 2023 12:04:16 -0600 Subject: [PATCH 04/33] Fixing format warnings --- spacy_llm/tasks/srl_task.py | 14 +++++++------- spacy_llm/tests/tasks/test_span_srl.py | 3 +-- usage_examples/span_srl_openai/run_pipeline.py | 2 +- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/spacy_llm/tasks/srl_task.py b/spacy_llm/tasks/srl_task.py index f5fff4ac..7fa87e8e 100644 --- a/spacy_llm/tasks/srl_task.py +++ b/spacy_llm/tasks/srl_task.py @@ -4,7 +4,7 @@ from collections import defaultdict import jinja2 -from pydantic import BaseModel, Field, ValidationError, validator +from pydantic import BaseModel, ValidationError from spacy.language import Language from spacy.tokens import Doc from spacy.training import Example @@ -125,8 +125,8 @@ def make_srl_task( examples: ExamplesConfigType = None, normalizer: Optional[Callable[[str], str]] = None, alignment_mode: Literal["strict", "contract", "expand"] = "contract", - case_sensitive_matching: bool = False, - single_match: bool = False, + case_sensitive_matching: bool = True, + single_match: bool = True, verbose: bool = False, predicate_key: str = 'Predicate' ): @@ -187,7 +187,7 @@ def _cfg_keys(self) -> List[str]: def __init__( self, - labels: List[str] = [], + labels: List[str] = None, template: str = _DEFAULT_SPAN_SRL_TEMPLATE_V1, label_definitions: Optional[Dict[str, str]] = None, examples: Optional[List[SRLExample]] = None, @@ -197,8 +197,8 @@ def __init__( alignment_mode: Literal[ "strict", "contract", "expand" # noqa: F821 ] = "contract", - case_sensitive_matching: bool = False, - single_match: bool = False, + case_sensitive_matching: bool = True, + single_match: bool = True, ): self._normalizer = normalizer if normalizer else lowercase_normalizer() self._label_dict = {self._normalizer(label): label for label in labels} @@ -240,7 +240,7 @@ def initialize( self, get_examples: Callable[[], Iterable["Example"]], nlp: Language, - labels: List[str] = [], + labels: List[str] = None, ) -> None: """Initialize the task, by auto-discovering labels. diff --git a/spacy_llm/tests/tasks/test_span_srl.py b/spacy_llm/tests/tasks/test_span_srl.py index 62d4a5b3..affb50e1 100644 --- a/spacy_llm/tests/tasks/test_span_srl.py +++ b/spacy_llm/tests/tasks/test_span_srl.py @@ -4,7 +4,6 @@ from confection import Config from pytest import FixtureRequest from spacy_llm.pipeline import LLMWrapper -from spacy_llm.tasks.srl_task import _DEFAULT_SPAN_SRL_TEMPLATE_V1, ArgRELItem, PredicateItem, SpanItem from spacy_llm.ty import Labeled, LLMTask from spacy_llm.util import assemble_from_config, split_labels @@ -80,4 +79,4 @@ def test_rel_predict(task, cfg_string, request): doc = nlp(text) assert doc._.predicates - assert doc._.relations \ No newline at end of file + assert doc._.relations diff --git a/usage_examples/span_srl_openai/run_pipeline.py b/usage_examples/span_srl_openai/run_pipeline.py index e358d729..966498dc 100644 --- a/usage_examples/span_srl_openai/run_pipeline.py +++ b/usage_examples/span_srl_openai/run_pipeline.py @@ -45,4 +45,4 @@ def run_pipeline( if __name__ == "__main__": - typer.run(run_pipeline) \ No newline at end of file + typer.run(run_pipeline) From d803b233587598e37bc99f5f0b6b2150c31e8da4 Mon Sep 17 00:00:00 2001 From: rehan Date: Tue, 25 Jul 2023 12:10:52 -0600 Subject: [PATCH 05/33] Fixing format warnings --- spacy_llm/tasks/srl_task.py | 30 +++++++++++----------- usage_examples/span_srl_openai/__init__.py | 2 +- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/spacy_llm/tasks/srl_task.py b/spacy_llm/tasks/srl_task.py index 7fa87e8e..19e0d9ab 100644 --- a/spacy_llm/tasks/srl_task.py +++ b/spacy_llm/tasks/srl_task.py @@ -30,7 +30,7 @@ def __hash__(self): class PredicateItem(SpanItem): - roleset_id: str = '' + roleset_id: str = "" def __hash__(self): return hash((self.text, self.start_char, self.end_char, self.roleset_id)) @@ -55,7 +55,7 @@ def _preannotate(doc: Union[Doc, SRLExample]) -> str: """Creates a text version of the document with list of provided predicates.""" text = doc.text - preds = ', '.join([s.text for s in doc.predicates]) + preds = ", ".join([s.text for s in doc.predicates]) formatted_text = f"{text}\nPredicates: {preds}" @@ -109,10 +109,10 @@ def _get_label2rels(rel_tuples: Iterable[Tuple[int, ArgRELItem]]): label2prf[label] = _overlap_prf(gold_label_rels, pred_label_rels) return { - 'Predicates': predicates_prf, - 'ARGs': { - 'Overall': micro_rel_prf, - 'PerLabel': label2prf + "Predicates": predicates_prf, + "ARGs": { + "Overall": micro_rel_prf, + "PerLabel": label2prf } } @@ -128,7 +128,7 @@ def make_srl_task( case_sensitive_matching: bool = True, single_match: bool = True, verbose: bool = False, - predicate_key: str = 'Predicate' + predicate_key: str = "Predicate" ): """SRL.v1 task factory. @@ -193,7 +193,7 @@ def __init__( examples: Optional[List[SRLExample]] = None, normalizer: Optional[Callable[[str], str]] = None, verbose: bool = False, - predicate_key: str = 'Predicate', + predicate_key: str = "Predicate", alignment_mode: Literal[ "strict", "contract", "expand" # noqa: F821 ] = "contract", @@ -287,7 +287,7 @@ def generate_prompts(self, docs: Iterable[Doc]) -> Iterable[str]: for doc in docs: predicates = None if len(doc._.predicates): - predicates = ','.join([p['text'] for p in doc._.predicates]) + predicates = ", ".join([p["text"] for p in doc._.predicates]) prompt = _template.render( text=doc.text, @@ -305,11 +305,11 @@ def _format_response(self, arg_lines): found_labels = set() for line in arg_lines: try: - if line.strip() and ':' in line: - label, phrase = line.strip().split(':', 1) + if line.strip() and ":" in line: + label, phrase = line.strip().split(":", 1) # label is of the form "ARG-n (def)" - label = label.split('(')[0].strip() + label = label.split("(")[0].strip() # strip any surrounding quotes phrase = phrase.strip('\'" -') @@ -334,11 +334,11 @@ def parse_responses( for doc, prompt_response in zip(docs, responses): predicates = [] relations = [] - lines = prompt_response.split('\n') + lines = prompt_response.split("\n") # match lines that start with {Predicate:, Predicate 1:, Predicate1:} - # assuming self.predicate_key = 'Predicate' - pred_patt = r'^' + re.escape(self._predicate_key) + r'\b\s*\d*[:\-\s]' + # assuming self.predicate_key = "Predicate" + pred_patt = r"^" + re.escape(self._predicate_key) + r"\b\s*\d*[:\-\s]" pred_indices, pred_lines = zip(*[(i, line) for i, line in enumerate(lines) if re.search(pred_patt, line)]) pred_indices = list(pred_indices) diff --git a/usage_examples/span_srl_openai/__init__.py b/usage_examples/span_srl_openai/__init__.py index 39687454..06fab2f6 100644 --- a/usage_examples/span_srl_openai/__init__.py +++ b/usage_examples/span_srl_openai/__init__.py @@ -1,3 +1,3 @@ from .run_pipeline import run_pipeline -__all__ = ["run_pipeline"] \ No newline at end of file +__all__ = ["run_pipeline"] From 53a494c9d8aba14f1569414fc4592911a57c28f5 Mon Sep 17 00:00:00 2001 From: rehan Date: Tue, 25 Jul 2023 12:14:49 -0600 Subject: [PATCH 06/33] Fixing format warnings --- spacy_llm/tests/tasks/test_span_srl.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/spacy_llm/tests/tasks/test_span_srl.py b/spacy_llm/tests/tasks/test_span_srl.py index affb50e1..01a455f0 100644 --- a/spacy_llm/tests/tasks/test_span_srl.py +++ b/spacy_llm/tests/tasks/test_span_srl.py @@ -17,20 +17,20 @@ def zeroshot_cfg_string(): return """ [paths] examples = null - + [nlp] lang = "en" pipeline = ["llm"] - + [components] - + [components.llm] factory = "llm" - + [components.llm.task] @llm_tasks = "spacy.SRL.v1" labels = ARG-0,ARG-1,ARG-M-TMP,ARG-M-LOC - + [components.llm.model] @llm_models = "spacy.GPT-3-5.v1" """ From dd7d9fb7a5caa0db046610138aee268260abbe44 Mon Sep 17 00:00:00 2001 From: rehan Date: Tue, 25 Jul 2023 12:18:56 -0600 Subject: [PATCH 07/33] Fixing format warnings --- spacy_llm/tasks/srl_task.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spacy_llm/tasks/srl_task.py b/spacy_llm/tasks/srl_task.py index 19e0d9ab..db29ac71 100644 --- a/spacy_llm/tasks/srl_task.py +++ b/spacy_llm/tasks/srl_task.py @@ -119,7 +119,7 @@ def _get_label2rels(rel_tuples: Iterable[Tuple[int, ArgRELItem]]): @registry.llm_tasks("spacy.SRL.v1") def make_srl_task( - labels: str, + labels: Union[List[str], str] = [], template: str = _DEFAULT_SPAN_SRL_TEMPLATE_V1, label_definitions: Optional[Dict[str, str]] = None, examples: ExamplesConfigType = None, @@ -187,7 +187,7 @@ def _cfg_keys(self) -> List[str]: def __init__( self, - labels: List[str] = None, + labels: List[str] = [], template: str = _DEFAULT_SPAN_SRL_TEMPLATE_V1, label_definitions: Optional[Dict[str, str]] = None, examples: Optional[List[SRLExample]] = None, From 0ad60635cdd313b3cd1e034e5e8c5c3682200d12 Mon Sep 17 00:00:00 2001 From: rehan Date: Tue, 25 Jul 2023 12:23:40 -0600 Subject: [PATCH 08/33] Fix Literal ImportError --- spacy_llm/tasks/srl_task.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/spacy_llm/tasks/srl_task.py b/spacy_llm/tasks/srl_task.py index db29ac71..082e8652 100644 --- a/spacy_llm/tasks/srl_task.py +++ b/spacy_llm/tasks/srl_task.py @@ -1,4 +1,4 @@ -from typing import Callable, Dict, Iterable, List, Optional, Tuple, Type, Union, Any, Literal +from typing import Callable, Dict, Iterable, List, Optional, Tuple, Type, Union, Any import re @@ -16,6 +16,7 @@ from .templates import read_template from .util import SerializableTask from .util.parsing import find_substrings +from ..compat import Literal _DEFAULT_SPAN_SRL_TEMPLATE_V1 = read_template("span-srl.v1") From 15412b55098080b6fc1cb5ac30a2634b6aeb5c4f Mon Sep 17 00:00:00 2001 From: rehan Date: Tue, 25 Jul 2023 12:26:36 -0600 Subject: [PATCH 09/33] Fix Label assignment --- spacy_llm/tasks/srl_task.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spacy_llm/tasks/srl_task.py b/spacy_llm/tasks/srl_task.py index 082e8652..dd26601e 100644 --- a/spacy_llm/tasks/srl_task.py +++ b/spacy_llm/tasks/srl_task.py @@ -241,7 +241,7 @@ def initialize( self, get_examples: Callable[[], Iterable["Example"]], nlp: Language, - labels: List[str] = None, + labels: List[str] = [], ) -> None: """Initialize the task, by auto-discovering labels. From fd194411032abbd55e7b836bcc9e17d26bf2f897 Mon Sep 17 00:00:00 2001 From: rehan Date: Tue, 25 Jul 2023 12:36:58 -0600 Subject: [PATCH 10/33] Fix the template's preamble --- spacy_llm/tasks/templates/span-srl.v1.jinja | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/spacy_llm/tasks/templates/span-srl.v1.jinja b/spacy_llm/tasks/templates/span-srl.v1.jinja index 451b3609..d8ba0b78 100644 --- a/spacy_llm/tasks/templates/span-srl.v1.jinja +++ b/spacy_llm/tasks/templates/span-srl.v1.jinja @@ -1,5 +1,4 @@ -You are an expert Semantic Role Labeling (SRL) system. Your task is to accept Text as input and extract \ -the Predicates and the Semantic Roles for each Predicate's ARGs in a step-by-step manner. +You are an expert Semantic Role Labeling (SRL) system. Your task is to accept Text as input and extract the Predicates and the Semantic Roles for each Predicate's ARGs in a step-by-step manner. {# whitespace #} {%- if predicates -%} Step 1: Use the following Predicates for the Text: From b56c1d06f23c25ddfcb982d03850c3c7028cd97e Mon Sep 17 00:00:00 2001 From: rehan Date: Tue, 25 Jul 2023 12:47:51 -0600 Subject: [PATCH 11/33] Black formatting --- spacy_llm/tasks/__init__.py | 2 +- spacy_llm/tasks/srl_task.py | 83 ++++++++++++++++---------- spacy_llm/tests/tasks/test_span_srl.py | 4 +- 3 files changed, 56 insertions(+), 33 deletions(-) diff --git a/spacy_llm/tasks/__init__.py b/spacy_llm/tasks/__init__.py index 61f2de4a..8e375412 100644 --- a/spacy_llm/tasks/__init__.py +++ b/spacy_llm/tasks/__init__.py @@ -28,5 +28,5 @@ "SpanCatTask", "SummarizationTask", "TextCatTask", - "SRLTask" + "SRLTask", ] diff --git a/spacy_llm/tasks/srl_task.py b/spacy_llm/tasks/srl_task.py index dd26601e..bb372186 100644 --- a/spacy_llm/tasks/srl_task.py +++ b/spacy_llm/tasks/srl_task.py @@ -76,17 +76,25 @@ def score_srl_spans( pred_doc = eg.predicted gold_doc = eg.reference - pred_predicates_spans.update([(i, PredicateItem(**dict(p))) for p in pred_doc._.predicates]) - gold_predicates_spans.update([(i, PredicateItem(**dict(p))) for p in gold_doc._.predicates]) - - pred_relation_tuples.update([(i, ArgRELItem(**dict(r))) for r in pred_doc._.relations]) - gold_relation_tuples.update([(i, ArgRELItem(**dict(r))) for r in gold_doc._.relations]) + pred_predicates_spans.update( + [(i, PredicateItem(**dict(p))) for p in pred_doc._.predicates] + ) + gold_predicates_spans.update( + [(i, PredicateItem(**dict(p))) for p in gold_doc._.predicates] + ) + + pred_relation_tuples.update( + [(i, ArgRELItem(**dict(r))) for r in pred_doc._.relations] + ) + gold_relation_tuples.update( + [(i, ArgRELItem(**dict(r))) for r in gold_doc._.relations] + ) def _overlap_prf(gold: set, pred: set): overlap = gold.intersection(pred) - p = 0. if not len(pred) else len(overlap)/len(pred) - r = 0. if not len(gold) else len(overlap)/len(gold) - f = 0. if not p or not r else 2*p*r/(p+r) + p = 0.0 if not len(pred) else len(overlap) / len(pred) + r = 0.0 if not len(gold) else len(overlap) / len(gold) + f = 0.0 if not p or not r else 2 * p * r / (p + r) return p, r, f predicates_prf = _overlap_prf(gold_predicates_spans, pred_predicates_spans) @@ -102,7 +110,9 @@ def _get_label2rels(rel_tuples: Iterable[Tuple[int, ArgRELItem]]): pred_label2relations = _get_label2rels(pred_relation_tuples) gold_label2relations = _get_label2rels(gold_relation_tuples) - all_labels = set.union(set(pred_label2relations.keys()), set(gold_label2relations.keys())) + all_labels = set.union( + set(pred_label2relations.keys()), set(gold_label2relations.keys()) + ) label2prf = {} for label in all_labels: pred_label_rels = pred_label2relations[label] @@ -111,10 +121,7 @@ def _get_label2rels(rel_tuples: Iterable[Tuple[int, ArgRELItem]]): return { "Predicates": predicates_prf, - "ARGs": { - "Overall": micro_rel_prf, - "PerLabel": label2prf - } + "ARGs": {"Overall": micro_rel_prf, "PerLabel": label2prf}, } @@ -129,7 +136,7 @@ def make_srl_task( case_sensitive_matching: bool = True, single_match: bool = True, verbose: bool = False, - predicate_key: str = "Predicate" + predicate_key: str = "Predicate", ): """SRL.v1 task factory. @@ -217,7 +224,7 @@ def __init__( @classmethod def _check_extensions(cls): """Add `predicates` extension if need be. - Add `relations` extension if need be.""" + Add `relations` extension if need be.""" if not Doc.has_extension("predicates"): Doc.set_extension("predicates", default=[]) @@ -294,7 +301,7 @@ def generate_prompts(self, docs: Iterable[Doc]) -> Iterable[str]: text=doc.text, labels=list(self._label_dict.values()), label_definitions=self._label_definitions, - predicates=predicates + predicates=predicates, ) yield prompt @@ -313,10 +320,13 @@ def _format_response(self, arg_lines): label = label.split("(")[0].strip() # strip any surrounding quotes - phrase = phrase.strip('\'" -') + phrase = phrase.strip("'\" -") norm_label = self._normalizer(label) - if norm_label in self._label_dict and norm_label not in found_labels: + if ( + norm_label in self._label_dict + and norm_label not in found_labels + ): if phrase.strip(): _phrase = phrase.strip() found_labels.add(norm_label) @@ -330,7 +340,7 @@ def _format_response(self, arg_lines): return output def parse_responses( - self, docs: Iterable[Doc], responses: Iterable[str] + self, docs: Iterable[Doc], responses: Iterable[str] ) -> Iterable[Doc]: for doc, prompt_response in zip(docs, responses): predicates = [] @@ -340,12 +350,18 @@ def parse_responses( # match lines that start with {Predicate:, Predicate 1:, Predicate1:} # assuming self.predicate_key = "Predicate" pred_patt = r"^" + re.escape(self._predicate_key) + r"\b\s*\d*[:\-\s]" - pred_indices, pred_lines = zip(*[(i, line) for i, line in enumerate(lines) if re.search(pred_patt, line)]) + pred_indices, pred_lines = zip( + *[ + (i, line) + for i, line in enumerate(lines) + if re.search(pred_patt, line) + ] + ) pred_indices = list(pred_indices) # extract the predicate strings - pred_strings = [line.split(":", 1)[1].strip('\'" ') for line in pred_lines] + pred_strings = [line.split(":", 1)[1].strip("'\" ") for line in pred_lines] # extract the line ranges (s, e) of predicate's content. # then extract the pred content lines using the ranges @@ -356,17 +372,18 @@ def parse_responses( # assign the spans of the predicates and args # then create ArgRELItem from the identified predicates and arguments for pred_str, pred_content_lines in zip(pred_strings, pred_contents): - pred_offsets = list(find_substrings( - doc.text, - [pred_str], - case_sensitive=True, - single_match=True - )) + pred_offsets = list( + find_substrings( + doc.text, [pred_str], case_sensitive=True, single_match=True + ) + ) # ignore the args if the predicate is not found if len(pred_offsets): p_start_char, p_end_char = pred_offsets[0] - pred_item = PredicateItem(text=pred_str, start_char=p_start_char, end_char=p_end_char) + pred_item = PredicateItem( + text=pred_str, start_char=p_start_char, end_char=p_end_char + ) predicates.append(pred_item.dict()) for label, phrase in self._format_response(pred_content_lines): @@ -374,11 +391,15 @@ def parse_responses( doc.text, [phrase], case_sensitive=self._case_sensitive_matching, - single_match=self._single_match + single_match=self._single_match, ) for start, end in arg_offsets: - arg_item = SpanItem(text=phrase, start_char=start, end_char=end).dict() - arg_rel_item = ArgRELItem(predicate=pred_item, role=arg_item, label=label).dict() + arg_item = SpanItem( + text=phrase, start_char=start, end_char=end + ).dict() + arg_rel_item = ArgRELItem( + predicate=pred_item, role=arg_item, label=label + ).dict() relations.append(arg_rel_item) doc._.predicates = predicates diff --git a/spacy_llm/tests/tasks/test_span_srl.py b/spacy_llm/tests/tasks/test_span_srl.py index 01a455f0..231c682a 100644 --- a/spacy_llm/tests/tasks/test_span_srl.py +++ b/spacy_llm/tests/tasks/test_span_srl.py @@ -66,7 +66,9 @@ def test_rel_config(cfg_string, request: FixtureRequest): @pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available") -@pytest.mark.parametrize("cfg_string", ["zeroshot_cfg_string"]) # "zeroshot_cfg_string", +@pytest.mark.parametrize( + "cfg_string", ["zeroshot_cfg_string"] +) # "zeroshot_cfg_string", def test_rel_predict(task, cfg_string, request): """Use OpenAI to get REL results. Note that this test may fail randomly, as the LLM's output is unguaranteed to be consistent/predictable From d6564f77f12c91b6a93c21977e7561051d57bdbc Mon Sep 17 00:00:00 2001 From: rehan Date: Thu, 27 Jul 2023 10:47:32 -0600 Subject: [PATCH 12/33] imports in alphabetical order --- spacy_llm/tasks/__init__.py | 6 +++--- spacy_llm/tasks/srl_task.py | 2 +- spacy_llm/tests/tasks/test_span_srl.py | 6 +++--- usage_examples/span_srl_openai/run_pipeline.py | 7 +++---- 4 files changed, 10 insertions(+), 11 deletions(-) diff --git a/spacy_llm/tasks/__init__.py b/spacy_llm/tasks/__init__.py index 8e375412..ec6ae8f6 100644 --- a/spacy_llm/tasks/__init__.py +++ b/spacy_llm/tasks/__init__.py @@ -4,9 +4,9 @@ from .rel import RELTask, make_rel_task from .sentiment import SentimentTask, make_sentiment_task from .spancat import SpanCatTask, make_spancat_task, make_spancat_task_v2 +from .srl_task import SRLTask, make_srl_task from .summarization import SummarizationTask, make_summarization_task from .textcat import TextCatTask, make_textcat_task -from .srl_task import SRLTask, make_srl_task __all__ = [ "make_lemma_task", @@ -17,16 +17,16 @@ "make_sentiment_task", "make_spancat_task", "make_spancat_task_v2", + "make_srl_task", "make_summarization_task", "make_textcat_task", - "make_srl_task", "LemmaTask", "NERTask", "NoopTask", "RELTask", "SentimentTask", "SpanCatTask", + "SRLTask", "SummarizationTask", "TextCatTask", - "SRLTask", ] diff --git a/spacy_llm/tasks/srl_task.py b/spacy_llm/tasks/srl_task.py index bb372186..ce741ac5 100644 --- a/spacy_llm/tasks/srl_task.py +++ b/spacy_llm/tasks/srl_task.py @@ -10,13 +10,13 @@ from spacy.training import Example from wasabi import msg +from ..compat import Literal from ..registry import lowercase_normalizer, registry from ..ty import ExamplesConfigType from ..util import split_labels from .templates import read_template from .util import SerializableTask from .util.parsing import find_substrings -from ..compat import Literal _DEFAULT_SPAN_SRL_TEMPLATE_V1 = read_template("span-srl.v1") diff --git a/spacy_llm/tests/tasks/test_span_srl.py b/spacy_llm/tests/tasks/test_span_srl.py index 231c682a..829815a6 100644 --- a/spacy_llm/tests/tasks/test_span_srl.py +++ b/spacy_llm/tests/tasks/test_span_srl.py @@ -1,13 +1,13 @@ -from pathlib import Path - import pytest + from confection import Config +from pathlib import Path from pytest import FixtureRequest from spacy_llm.pipeline import LLMWrapper +from spacy_llm.tests.compat import has_openai_key from spacy_llm.ty import Labeled, LLMTask from spacy_llm.util import assemble_from_config, split_labels -from spacy_llm.tests.compat import has_openai_key EXAMPLES_DIR = Path(__file__).parent / "examples" diff --git a/usage_examples/span_srl_openai/run_pipeline.py b/usage_examples/span_srl_openai/run_pipeline.py index 966498dc..a434cf9f 100644 --- a/usage_examples/span_srl_openai/run_pipeline.py +++ b/usage_examples/span_srl_openai/run_pipeline.py @@ -1,11 +1,10 @@ import os -from pathlib import Path -from typing import Optional - import typer -from wasabi import msg +from pathlib import Path from spacy_llm.util import assemble +from typing import Optional +from wasabi import msg Arg = typer.Argument Opt = typer.Option From de68696b81c39a01ec4bbfdabe6db8e86b279697 Mon Sep 17 00:00:00 2001 From: rehan Date: Thu, 27 Jul 2023 10:49:41 -0600 Subject: [PATCH 13/33] alignment_mode should be a Literal. --- spacy_llm/tasks/srl_task.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spacy_llm/tasks/srl_task.py b/spacy_llm/tasks/srl_task.py index ce741ac5..9c3dde76 100644 --- a/spacy_llm/tasks/srl_task.py +++ b/spacy_llm/tasks/srl_task.py @@ -151,7 +151,7 @@ def make_srl_task( reads a file containing task examples for few-shot learning. If None is passed, then zero-shot learning will be used. normalizer (Optional[Callable[[str], str]]): optional normalizer function. - alignment_mode (str): "strict", "contract" or "expand". + alignment_mode (Literal): "strict", "contract" or "expand". case_sensitive_matching: Whether to search without case sensitivity. single_match (bool): If False, allow one substring to match multiple times in the text. If True, returns the first hit. @@ -235,7 +235,7 @@ def _check_extensions(cls): @staticmethod def _validate_alignment(alignment_mode: str): """Raises error if specified alignment_mode is not supported. - alignment_mode (str): Alignment mode to check. + alignment_mode (Literal): Alignment mode to check. """ # ideally, this list should be taken from spaCy, but it's not currently exposed from doc.pyx. alignment_modes = ("strict", "contract", "expand") From ed07c83e2cbdf303d8cd2a99c6ba0325972c8972 Mon Sep 17 00:00:00 2001 From: Rehan Ahmed Date: Thu, 27 Jul 2023 10:52:45 -0600 Subject: [PATCH 14/33] Update spacy_llm/tasks/srl_task.py Co-authored-by: Raphael Mitsch --- spacy_llm/tasks/srl_task.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spacy_llm/tasks/srl_task.py b/spacy_llm/tasks/srl_task.py index 9c3dde76..e12e28b9 100644 --- a/spacy_llm/tasks/srl_task.py +++ b/spacy_llm/tasks/srl_task.py @@ -155,7 +155,7 @@ def make_srl_task( case_sensitive_matching: Whether to search without case sensitivity. single_match (bool): If False, allow one substring to match multiple times in the text. If True, returns the first hit. - verbose (boole): Verbose ot not + verbose (bool): Verbose or not predicate_key: The str of Predicate in the template """ labels_list = split_labels(labels) From 472d5c7e74334511102780cff363da858b560643 Mon Sep 17 00:00:00 2001 From: Rehan Ahmed Date: Thu, 27 Jul 2023 10:53:38 -0600 Subject: [PATCH 15/33] Update spacy_llm/tasks/templates/span-srl.v1.jinja Co-authored-by: Raphael Mitsch --- spacy_llm/tasks/templates/span-srl.v1.jinja | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spacy_llm/tasks/templates/span-srl.v1.jinja b/spacy_llm/tasks/templates/span-srl.v1.jinja index d8ba0b78..733fed5e 100644 --- a/spacy_llm/tasks/templates/span-srl.v1.jinja +++ b/spacy_llm/tasks/templates/span-srl.v1.jinja @@ -8,7 +8,7 @@ Step 1: Extract the Predicates for the Text in the following format : Predicates: {%- endif -%} {# whitespace #} -Step 2: For each Predicate, extract only the following Sematic Roles in '''Text''' in this format: +Step 2: For each Predicate, extract only the following Semantic Roles in '''Text''' in this format: Text: Predicate: ARG-0 (Agent): From 55a80185bda38dd78630e6c6c0d09e321abd163d Mon Sep 17 00:00:00 2001 From: Rehan Ahmed Date: Thu, 27 Jul 2023 10:54:00 -0600 Subject: [PATCH 16/33] Update spacy_llm/tests/tasks/test_span_srl.py Co-authored-by: Raphael Mitsch --- spacy_llm/tests/tasks/test_span_srl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spacy_llm/tests/tasks/test_span_srl.py b/spacy_llm/tests/tasks/test_span_srl.py index 829815a6..a24811fa 100644 --- a/spacy_llm/tests/tasks/test_span_srl.py +++ b/spacy_llm/tests/tasks/test_span_srl.py @@ -68,7 +68,7 @@ def test_rel_config(cfg_string, request: FixtureRequest): @pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available") @pytest.mark.parametrize( "cfg_string", ["zeroshot_cfg_string"] -) # "zeroshot_cfg_string", +) def test_rel_predict(task, cfg_string, request): """Use OpenAI to get REL results. Note that this test may fail randomly, as the LLM's output is unguaranteed to be consistent/predictable From 355241a19606e42e6a09718b98765e2f4802fbbf Mon Sep 17 00:00:00 2001 From: rehan Date: Thu, 27 Jul 2023 11:46:58 -0600 Subject: [PATCH 17/33] reformatting --- spacy_llm/tests/tasks/test_span_srl.py | 1 - 1 file changed, 1 deletion(-) diff --git a/spacy_llm/tests/tasks/test_span_srl.py b/spacy_llm/tests/tasks/test_span_srl.py index 829815a6..2c437bd7 100644 --- a/spacy_llm/tests/tasks/test_span_srl.py +++ b/spacy_llm/tests/tasks/test_span_srl.py @@ -8,7 +8,6 @@ from spacy_llm.ty import Labeled, LLMTask from spacy_llm.util import assemble_from_config, split_labels - EXAMPLES_DIR = Path(__file__).parent / "examples" From 84d17dfe9fe1596919bdf0636db810f7ab7e4e77 Mon Sep 17 00:00:00 2001 From: rehan Date: Thu, 27 Jul 2023 11:52:50 -0600 Subject: [PATCH 18/33] reformatting --- spacy_llm/tests/tasks/test_span_srl.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/spacy_llm/tests/tasks/test_span_srl.py b/spacy_llm/tests/tasks/test_span_srl.py index bce486d7..a808c47f 100644 --- a/spacy_llm/tests/tasks/test_span_srl.py +++ b/spacy_llm/tests/tasks/test_span_srl.py @@ -65,9 +65,7 @@ def test_rel_config(cfg_string, request: FixtureRequest): @pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available") -@pytest.mark.parametrize( - "cfg_string", ["zeroshot_cfg_string"] -) +@pytest.mark.parametrize("cfg_string", ["zeroshot_cfg_string"]) def test_rel_predict(task, cfg_string, request): """Use OpenAI to get REL results. Note that this test may fail randomly, as the LLM's output is unguaranteed to be consistent/predictable From 666c3eec6273b0cf46443b469218f6d450001f5c Mon Sep 17 00:00:00 2001 From: rehan Date: Thu, 27 Jul 2023 18:51:52 -0600 Subject: [PATCH 19/33] adding test on srl roles --- spacy_llm/tests/tasks/test_span_srl.py | 48 ++++++++++++++++++++++---- 1 file changed, 42 insertions(+), 6 deletions(-) diff --git a/spacy_llm/tests/tasks/test_span_srl.py b/spacy_llm/tests/tasks/test_span_srl.py index a808c47f..080404bc 100644 --- a/spacy_llm/tests/tasks/test_span_srl.py +++ b/spacy_llm/tests/tasks/test_span_srl.py @@ -4,6 +4,7 @@ from pathlib import Path from pytest import FixtureRequest from spacy_llm.pipeline import LLMWrapper +from spacy_llm.tasks.srl_task import SRLExample from spacy_llm.tests.compat import has_openai_key from spacy_llm.ty import Labeled, LLMTask from spacy_llm.util import assemble_from_config, split_labels @@ -28,7 +29,7 @@ def zeroshot_cfg_string(): [components.llm.task] @llm_tasks = "spacy.SRL.v1" - labels = ARG-0,ARG-1,ARG-M-TMP,ARG-M-LOC + labels = ARG-0,ARG-1,ARG-M-LOC,ARG-M-TMP [components.llm.model] @llm_models = "spacy.GPT-3-5.v1" @@ -38,8 +39,36 @@ def zeroshot_cfg_string(): @pytest.fixture def task(): text = "We love this sentence in Berlin right now ." - gold_relations = [] - return text, gold_relations + predicate = {"text": "love", "start_char": 3, "end_char": 7} + srl_example = SRLExample( + **{ + "text": text, + "predicates": [predicate], + "relations": [ + { + "label": "ARG-0", + "predicate": predicate, + "role": {"text": "We", "start_char": 0, "end_char": 2}, + }, + { + "label": "ARG-1", + "predicate": predicate, + "role": {"text": "this sentence", "start_char": 8, "end_char": 21}, + }, + { + "label": "ARG-M-LOC", + "predicate": predicate, + "role": {"text": "in Berlin", "start_char": 22, "end_char": 31}, + }, + { + "label": "ARG-M-TMP", + "predicate": predicate, + "role": {"text": "right now", "start_char": 32, "end_char": 41}, + }, + ], + } + ) + return text, srl_example @pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available") @@ -74,8 +103,15 @@ def test_rel_predict(task, cfg_string, request): orig_config = Config().from_str(cfg_string) nlp = assemble_from_config(orig_config) - text, _ = task + text, gold_example = task doc = nlp(text) - assert doc._.predicates - assert doc._.relations + assert len(doc._.predicates) + assert len(doc._.relations) + + assert doc._.predicates[0]["text"] == gold_example.predicates[0].text + + predicated_roles = tuple(sorted([r["role"]["text"] for r in doc._.relations])) + gold_roles = tuple(sorted([r.role.text for r in gold_example.relations])) + + assert predicated_roles == gold_roles From cb81bdf53f4e020e256af98c2a4be83d8bb2dbaa Mon Sep 17 00:00:00 2001 From: rehan Date: Thu, 27 Jul 2023 18:52:22 -0600 Subject: [PATCH 20/33] SRLTask inherits SpanTask --- spacy_llm/tasks/span.py | 6 +-- spacy_llm/tasks/srl_task.py | 89 +++++++++++++++---------------------- 2 files changed, 38 insertions(+), 57 deletions(-) diff --git a/spacy_llm/tasks/span.py b/spacy_llm/tasks/span.py index 708e1565..a59d73f0 100644 --- a/spacy_llm/tasks/span.py +++ b/spacy_llm/tasks/span.py @@ -1,5 +1,5 @@ import warnings -from typing import Callable, Dict, Iterable, List, Optional, Tuple, Type +from typing import Callable, Dict, Generic, Iterable, List, Optional, Tuple, Type import jinja2 from pydantic import BaseModel @@ -8,7 +8,7 @@ from ..compat import Literal from ..registry import lowercase_normalizer from .util.parsing import find_substrings -from .util.serialization import SerializableTask +from .util.serialization import ExampleType, SerializableTask class SpanExample(BaseModel): @@ -16,7 +16,7 @@ class SpanExample(BaseModel): entities: Dict[str, List[str]] -class SpanTask(SerializableTask[SpanExample]): +class SpanTask(SerializableTask[SpanExample], Generic[ExampleType]): """Base class for Span-related tasks, eg NER and SpanCat.""" def __init__( diff --git a/spacy_llm/tasks/srl_task.py b/spacy_llm/tasks/srl_task.py index e12e28b9..230a89f2 100644 --- a/spacy_llm/tasks/srl_task.py +++ b/spacy_llm/tasks/srl_task.py @@ -11,11 +11,11 @@ from wasabi import msg from ..compat import Literal -from ..registry import lowercase_normalizer, registry +from ..registry import registry from ..ty import ExamplesConfigType from ..util import split_labels +from .span import SpanTask from .templates import read_template -from .util import SerializableTask from .util.parsing import find_substrings _DEFAULT_SPAN_SRL_TEMPLATE_V1 = read_template("span-srl.v1") @@ -151,7 +151,7 @@ def make_srl_task( reads a file containing task examples for few-shot learning. If None is passed, then zero-shot learning will be used. normalizer (Optional[Callable[[str], str]]): optional normalizer function. - alignment_mode (Literal): "strict", "contract" or "expand". + alignment_mode (Literal["strict", "contract", "expand"]): "strict", "contract" or "expand". case_sensitive_matching: Whether to search without case sensitivity. single_match (bool): If False, allow one substring to match multiple times in the text. If True, returns the first hit. @@ -165,7 +165,7 @@ def make_srl_task( labels=labels_list, template=template, label_definitions=label_definitions, - examples=rel_examples, + prompt_examples=rel_examples, normalizer=normalizer, verbose=verbose, alignment_mode=alignment_mode, @@ -175,49 +175,33 @@ def make_srl_task( ) -class SRLTask(SerializableTask[SRLExample]): - @property - def _Example(self) -> Type[SRLExample]: - return SRLExample - - @property - def _cfg_keys(self) -> List[str]: - return [ - "_label_dict", - "_template", - "_label_definitions", - "_verbose", - "_predicate_key", - "_alignment_mode", - "_case_sensitive_matching", - "_single_match", - ] - +class SRLTask(SpanTask[SRLExample]): def __init__( self, labels: List[str] = [], template: str = _DEFAULT_SPAN_SRL_TEMPLATE_V1, label_definitions: Optional[Dict[str, str]] = None, - examples: Optional[List[SRLExample]] = None, + prompt_examples: Optional[List[SRLExample]] = None, normalizer: Optional[Callable[[str], str]] = None, - verbose: bool = False, - predicate_key: str = "Predicate", alignment_mode: Literal[ "strict", "contract", "expand" # noqa: F821 ] = "contract", case_sensitive_matching: bool = True, single_match: bool = True, + verbose: bool = False, + predicate_key: str = "Predicate", ): - self._normalizer = normalizer if normalizer else lowercase_normalizer() - self._label_dict = {self._normalizer(label): label for label in labels} - self._template = template - self._label_definitions = label_definitions - self._examples = examples + super().__init__( + labels, + template, + label_definitions, + prompt_examples, + normalizer, + alignment_mode, + case_sensitive_matching, + single_match, + ) self._verbose = verbose - self._validate_alignment(alignment_mode) - self._alignment_mode = alignment_mode - self._case_sensitive_matching = case_sensitive_matching - self._single_match = single_match self._predicate_key = predicate_key self._check_extensions() @@ -232,18 +216,6 @@ def _check_extensions(cls): if not Doc.has_extension("relations"): Doc.set_extension("relations", default=[]) - @staticmethod - def _validate_alignment(alignment_mode: str): - """Raises error if specified alignment_mode is not supported. - alignment_mode (Literal): Alignment mode to check. - """ - # ideally, this list should be taken from spaCy, but it's not currently exposed from doc.pyx. - alignment_modes = ("strict", "contract", "expand") - if alignment_mode not in alignment_modes: - raise ValueError( - f"Unsupported alignment mode '{alignment_mode}'. Supported modes: {', '.join(alignment_modes)}" - ) - def initialize( self, get_examples: Callable[[], Iterable["Example"]], @@ -281,14 +253,6 @@ def initialize( self._label_dict = {self._normalizer(label): label for label in labels} - @property - def labels(self) -> Tuple[str, ...]: - return tuple(self._label_dict.values()) - - @property - def prompt_template(self) -> str: - return self._template - def generate_prompts(self, docs: Iterable[Doc]) -> Iterable[str]: environment = jinja2.Environment() _template = environment.from_string(self._template) @@ -405,3 +369,20 @@ def parse_responses( doc._.predicates = predicates doc._.relations = relations yield doc + + @property + def _cfg_keys(self) -> List[str]: + return [ + "_label_dict", + "_template", + "_label_definitions", + "_verbose", + "_predicate_key", + "_alignment_mode", + "_case_sensitive_matching", + "_single_match", + ] + + @property + def _Example(self) -> Type[SRLExample]: + return SRLExample From 6ab47237b694b5721e7290e7f01ee5ed7b77e6a9 Mon Sep 17 00:00:00 2001 From: rehan Date: Tue, 1 Aug 2023 11:43:00 -0600 Subject: [PATCH 21/33] Added label definitions rendering in prompt --- spacy_llm/tasks/templates/span-srl.v1.jinja | 12 +++++++----- spacy_llm/tests/tasks/test_span_srl.py | 14 +++++++++++++- usage_examples/span_srl_openai/zeroshot.cfg | 11 +++++++++-- 3 files changed, 29 insertions(+), 8 deletions(-) diff --git a/spacy_llm/tasks/templates/span-srl.v1.jinja b/spacy_llm/tasks/templates/span-srl.v1.jinja index 733fed5e..4fa142df 100644 --- a/spacy_llm/tasks/templates/span-srl.v1.jinja +++ b/spacy_llm/tasks/templates/span-srl.v1.jinja @@ -1,5 +1,6 @@ You are an expert Semantic Role Labeling (SRL) system. Your task is to accept Text as input and extract the Predicates and the Semantic Roles for each Predicate's ARGs in a step-by-step manner. {# whitespace #} +{# whitespace #} {%- if predicates -%} Step 1: Use the following Predicates for the Text: Predicates: {{predicates}} @@ -8,14 +9,15 @@ Step 1: Extract the Predicates for the Text in the following format : Predicates: {%- endif -%} {# whitespace #} +{# whitespace #} Step 2: For each Predicate, extract only the following Semantic Roles in '''Text''' in this format: Text: Predicate: -ARG-0 (Agent): -ARG-1 (Patient or Theme): -ARG-2: -ARG-M-LOC (Location Modifier): -ARG-M-TMP (Temporal Modifier): +{# whitespace #} +{%- for label, definition in label_definitions.items() -%} +{{ label }}: +{# whitespace #} +{%- endfor -%} {# whitespace #} Here is the text that needs labeling: {# whitespace #} diff --git a/spacy_llm/tests/tasks/test_span_srl.py b/spacy_llm/tests/tasks/test_span_srl.py index 080404bc..9adfc34d 100644 --- a/spacy_llm/tests/tasks/test_span_srl.py +++ b/spacy_llm/tests/tasks/test_span_srl.py @@ -29,7 +29,14 @@ def zeroshot_cfg_string(): [components.llm.task] @llm_tasks = "spacy.SRL.v1" - labels = ARG-0,ARG-1,ARG-M-LOC,ARG-M-TMP + labels = ARG-0,ARG-1,ARG-2,ARG-M-LOC,ARG-M-TMP + + [components.llm.task.label_definitions] + ARG-0 = "Agent" + ARG-1 = "Patient or Theme" + ARG-2 = "ARG-2" + ARG-M-TMP = "Temporal Modifier" + ARG-M-LOC = "Location Modifier" [components.llm.model] @llm_models = "spacy.GPT-3-5.v1" @@ -99,6 +106,11 @@ def test_rel_predict(task, cfg_string, request): """Use OpenAI to get REL results. Note that this test may fail randomly, as the LLM's output is unguaranteed to be consistent/predictable """ + import logging + import spacy_llm + spacy_llm.logger.addHandler(logging.StreamHandler()) + spacy_llm.logger.setLevel(logging.DEBUG) + cfg_string = request.getfixturevalue(cfg_string) orig_config = Config().from_str(cfg_string) nlp = assemble_from_config(orig_config) diff --git a/usage_examples/span_srl_openai/zeroshot.cfg b/usage_examples/span_srl_openai/zeroshot.cfg index 570ac1bb..bcfc03ef 100644 --- a/usage_examples/span_srl_openai/zeroshot.cfg +++ b/usage_examples/span_srl_openai/zeroshot.cfg @@ -12,8 +12,15 @@ factory = "llm" [components.llm.task] @llm_tasks = "spacy.SRL.v1" -labels = ARG-0,ARG-1,ARG-M-TMP,ARG-M-LOC +labels = ARG-0,ARG-1,ARG-2,ARG-M-TMP,ARG-M-LOC + +[components.llm.task.label_definitions] +ARG-0 = "Agent" +ARG-1 = "Patient or Theme" +ARG-2 = "ARG-2" +ARG-M-TMP = "Temporal Modifier" +ARG-M-LOC = "Location Modifier" [components.llm.model] @llm_models = "spacy.GPT-3-5.v1" - +config = {"temperature": 1} From 037f36ff360280060339d1371aca0156712c0fd9 Mon Sep 17 00:00:00 2001 From: rehan Date: Tue, 1 Aug 2023 11:45:06 -0600 Subject: [PATCH 22/33] Reformatting --- spacy_llm/tests/tasks/test_span_srl.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/spacy_llm/tests/tasks/test_span_srl.py b/spacy_llm/tests/tasks/test_span_srl.py index 9adfc34d..62ddfc1e 100644 --- a/spacy_llm/tests/tasks/test_span_srl.py +++ b/spacy_llm/tests/tasks/test_span_srl.py @@ -106,11 +106,6 @@ def test_rel_predict(task, cfg_string, request): """Use OpenAI to get REL results. Note that this test may fail randomly, as the LLM's output is unguaranteed to be consistent/predictable """ - import logging - import spacy_llm - spacy_llm.logger.addHandler(logging.StreamHandler()) - spacy_llm.logger.setLevel(logging.DEBUG) - cfg_string = request.getfixturevalue(cfg_string) orig_config = Config().from_str(cfg_string) nlp = assemble_from_config(orig_config) From d6faecd220cd8448d3b3e56e96e3e09af2a344aa Mon Sep 17 00:00:00 2001 From: rehan Date: Tue, 1 Aug 2023 18:03:20 -0600 Subject: [PATCH 23/33] Restructuring SRLExample and ARGRelItem --- spacy_llm/tasks/srl_task.py | 46 +++++++++----- spacy_llm/tests/tasks/test_span_srl.py | 61 ++++++++++++------- usage_examples/span_srl_openai/README.md | 9 ++- .../span_srl_openai/run_pipeline.py | 11 ++-- 4 files changed, 84 insertions(+), 43 deletions(-) diff --git a/spacy_llm/tasks/srl_task.py b/spacy_llm/tasks/srl_task.py index 230a89f2..966d14a0 100644 --- a/spacy_llm/tasks/srl_task.py +++ b/spacy_llm/tasks/srl_task.py @@ -37,19 +37,22 @@ def __hash__(self): return hash((self.text, self.start_char, self.end_char, self.roleset_id)) -class ArgRELItem(BaseModel): - predicate: PredicateItem +class RoleItem(BaseModel): role: SpanItem label: str def __hash__(self): - return hash((self.predicate, self.role, self.label)) + return hash((self.role, self.label)) class SRLExample(BaseModel): text: str predicates: List[PredicateItem] - relations: List[ArgRELItem] + relations: List[Tuple[PredicateItem, List[RoleItem]]] + + def __str__(self): + return f"""Predicates: {', '.join([p.text for p in self.predicates])} +Relations: {str([(p.text, [(r.label, r.role.text) for r in rs]) for p, rs in self.relations])}""" def _preannotate(doc: Union[Doc, SRLExample]) -> str: @@ -84,10 +87,18 @@ def score_srl_spans( ) pred_relation_tuples.update( - [(i, ArgRELItem(**dict(r))) for r in pred_doc._.relations] + [ + (i, PredicateItem(**dict(p)), RoleItem(**dict(r))) + for p, rs in pred_doc._.relations + for r in rs + ] ) gold_relation_tuples.update( - [(i, ArgRELItem(**dict(r))) for r in gold_doc._.relations] + [ + (i, PredicateItem(**dict(p)), RoleItem(**dict(r))) + for p, rs in gold_doc._.relations + for r in rs + ] ) def _overlap_prf(gold: set, pred: set): @@ -100,10 +111,10 @@ def _overlap_prf(gold: set, pred: set): predicates_prf = _overlap_prf(gold_predicates_spans, pred_predicates_spans) micro_rel_prf = _overlap_prf(gold_relation_tuples, pred_relation_tuples) - def _get_label2rels(rel_tuples: Iterable[Tuple[int, ArgRELItem]]): + def _get_label2rels(rel_tuples: Iterable[Tuple[int, PredicateItem, RoleItem]]): label2rels = defaultdict(set) for tup in rel_tuples: - label_ = tup[1].label + label_ = tup[-1].label label2rels[label_].add(tup) return label2rels @@ -201,8 +212,8 @@ def __init__( case_sensitive_matching, single_match, ) - self._verbose = verbose self._predicate_key = predicate_key + self._verbose = verbose self._check_extensions() @classmethod @@ -246,9 +257,10 @@ def initialize( label_set = set() for eg in examples: - rels: List[ArgRELItem] = eg.reference._.relations - for rel in rels: - label_set.add(rel.label) + rels: List[Tuple[PredicateItem, List[RoleItem]]] = eg.relations + for p, rs in rels: + for r in rs: + label_set.add(r.label) labels = list(label_set) self._label_dict = {self._normalizer(label): label for label in labels} @@ -270,7 +282,7 @@ def generate_prompts(self, docs: Iterable[Doc]) -> Iterable[str]: yield prompt - def _format_response(self, arg_lines): + def _format_response(self, arg_lines) -> List[Tuple[str, str]]: """Parse raw string response into a structured format""" output = [] # this ensures unique arguments in the sentence for a predicate @@ -350,6 +362,8 @@ def parse_responses( ) predicates.append(pred_item.dict()) + roles = [] + for label, phrase in self._format_response(pred_content_lines): arg_offsets = find_substrings( doc.text, @@ -361,10 +375,12 @@ def parse_responses( arg_item = SpanItem( text=phrase, start_char=start, end_char=end ).dict() - arg_rel_item = ArgRELItem( + arg_rel_item = RoleItem( predicate=pred_item, role=arg_item, label=label ).dict() - relations.append(arg_rel_item) + roles.append(arg_rel_item) + + relations.append((pred_item, roles)) doc._.predicates = predicates doc._.relations = relations diff --git a/spacy_llm/tests/tasks/test_span_srl.py b/spacy_llm/tests/tasks/test_span_srl.py index 62ddfc1e..c1b432e7 100644 --- a/spacy_llm/tests/tasks/test_span_srl.py +++ b/spacy_llm/tests/tasks/test_span_srl.py @@ -52,26 +52,39 @@ def task(): "text": text, "predicates": [predicate], "relations": [ - { - "label": "ARG-0", - "predicate": predicate, - "role": {"text": "We", "start_char": 0, "end_char": 2}, - }, - { - "label": "ARG-1", - "predicate": predicate, - "role": {"text": "this sentence", "start_char": 8, "end_char": 21}, - }, - { - "label": "ARG-M-LOC", - "predicate": predicate, - "role": {"text": "in Berlin", "start_char": 22, "end_char": 31}, - }, - { - "label": "ARG-M-TMP", - "predicate": predicate, - "role": {"text": "right now", "start_char": 32, "end_char": 41}, - }, + ( + predicate, + [ + { + "label": "ARG-0", + "role": {"text": "We", "start_char": 0, "end_char": 2}, + }, + { + "label": "ARG-1", + "role": { + "text": "this sentence", + "start_char": 8, + "end_char": 21, + }, + }, + { + "label": "ARG-M-LOC", + "role": { + "text": "in Berlin", + "start_char": 22, + "end_char": 31, + }, + }, + { + "label": "ARG-M-TMP", + "role": { + "text": "right now", + "start_char": 32, + "end_char": 41, + }, + }, + ], + ) ], } ) @@ -118,7 +131,11 @@ def test_rel_predict(task, cfg_string, request): assert doc._.predicates[0]["text"] == gold_example.predicates[0].text - predicated_roles = tuple(sorted([r["role"]["text"] for r in doc._.relations])) - gold_roles = tuple(sorted([r.role.text for r in gold_example.relations])) + predicated_roles = tuple( + sorted([r["role"]["text"] for p, rs in doc._.relations for r in rs]) + ) + gold_roles = tuple( + sorted([r.role.text for p, rs in gold_example.relations for r in rs]) + ) assert predicated_roles == gold_roles diff --git a/usage_examples/span_srl_openai/README.md b/usage_examples/span_srl_openai/README.md index ac65ae3e..2dd5f9d8 100644 --- a/usage_examples/span_srl_openai/README.md +++ b/usage_examples/span_srl_openai/README.md @@ -27,6 +27,13 @@ For example: ```sh python run_pipeline.py \ - "Laura just bought an apartment in Boston." \ + "Laura bought an apartment in Boston last month." \ ./zeroshot.cfg +``` +Output: +```shell +Text: Laura bought an apartment last month in Boston. +SRL Output: +Predicates: ['bought'] +Relations: [('bought', [('ARG-0', 'Laura'), ('ARG-1', 'an apartment'), ('ARG-M-TMP', 'last month'), ('ARG-M-LOC', 'in Boston')])] ``` \ No newline at end of file diff --git a/usage_examples/span_srl_openai/run_pipeline.py b/usage_examples/span_srl_openai/run_pipeline.py index a434cf9f..21898345 100644 --- a/usage_examples/span_srl_openai/run_pipeline.py +++ b/usage_examples/span_srl_openai/run_pipeline.py @@ -3,6 +3,7 @@ from pathlib import Path from spacy_llm.util import assemble +from spacy_llm.tasks.srl_task import SRLExample from typing import Optional from wasabi import msg @@ -35,12 +36,12 @@ def run_pipeline( doc = nlp(text) - msg.text(f"Text: {doc.text}") - msg.text(f"Predicates: {[p['text'] for p in doc._.predicates]}") + doc_srl = SRLExample( + text=doc.text, predicates=doc._.predicates, relations=doc._.relations + ) - msg.text("Relations:") - for r in doc._.relations: - msg.text(f" - {r['predicate']['text']} [{r['label']}] {r['role']['text']}") + msg.text(f"Text: {doc_srl.text}") + print(f"SRL Output:\n{str(doc_srl)}\n") if __name__ == "__main__": From 2a4e8621f6e827f73d56b203afa5db8f2cc87f1b Mon Sep 17 00:00:00 2001 From: rehan Date: Tue, 1 Aug 2023 22:02:51 -0600 Subject: [PATCH 24/33] added expected response --- usage_examples/span_srl_openai/README.md | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/usage_examples/span_srl_openai/README.md b/usage_examples/span_srl_openai/README.md index 2dd5f9d8..9822a6e7 100644 --- a/usage_examples/span_srl_openai/README.md +++ b/usage_examples/span_srl_openai/README.md @@ -27,10 +27,26 @@ For example: ```sh python run_pipeline.py \ - "Laura bought an apartment in Boston last month." \ + "Laura bought an apartment last month in Berlin." \ ./zeroshot.cfg ``` -Output: +LLM-response: +```shell +LLM response for doc: Laura bought an apartment last month in Boston. + +Step 1: Extract the Predicates for the Text +Predicates: bought + +Step 2: For each Predicate, extract the Semantic Roles in 'Text' +Text: Laura bought an apartment last month in Boston. +Predicate: bought +ARG-0: Laura +ARG-1: an apartment +ARG-2: +ARG-M-TMP: last month +ARG-M-LOC: in Boston +``` +std output: ```shell Text: Laura bought an apartment last month in Boston. SRL Output: From b38047869842e43a802c171bead41d17c6fd9467 Mon Sep 17 00:00:00 2001 From: rehan Date: Wed, 2 Aug 2023 10:36:27 -0600 Subject: [PATCH 25/33] Removing print statement --- usage_examples/span_srl_openai/run_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/usage_examples/span_srl_openai/run_pipeline.py b/usage_examples/span_srl_openai/run_pipeline.py index 21898345..7e97666c 100644 --- a/usage_examples/span_srl_openai/run_pipeline.py +++ b/usage_examples/span_srl_openai/run_pipeline.py @@ -41,7 +41,7 @@ def run_pipeline( ) msg.text(f"Text: {doc_srl.text}") - print(f"SRL Output:\n{str(doc_srl)}\n") + msg.text(f"SRL Output:\n{str(doc_srl)}\n") if __name__ == "__main__": From 8fc6b8d44c106a794ec0b2743b6b787e1683b37f Mon Sep 17 00:00:00 2001 From: rehan Date: Mon, 7 Aug 2023 11:43:03 -0600 Subject: [PATCH 26/33] Added few-shot span-srl --- spacy_llm/tasks/srl_task.py | 54 ++++++++++++++++++- spacy_llm/tasks/templates/span-srl.v1.jinja | 32 +++++++++++ spacy_llm/tests/tasks/examples/span_srl.jsonl | 1 + spacy_llm/tests/tasks/test_span_srl.py | 51 ++++++++++++++---- usage_examples/span_srl_openai/README.md | 12 ++++- usage_examples/span_srl_openai/examples.jsonl | 1 + usage_examples/span_srl_openai/fewshot.cfg | 28 ++++++++++ usage_examples/span_srl_openai/zeroshot.cfg | 3 -- 8 files changed, 166 insertions(+), 16 deletions(-) create mode 100644 spacy_llm/tests/tasks/examples/span_srl.jsonl create mode 100644 usage_examples/span_srl_openai/examples.jsonl create mode 100644 usage_examples/span_srl_openai/fewshot.cfg diff --git a/spacy_llm/tasks/srl_task.py b/spacy_llm/tasks/srl_task.py index 966d14a0..822c6389 100644 --- a/spacy_llm/tasks/srl_task.py +++ b/spacy_llm/tasks/srl_task.py @@ -1,9 +1,10 @@ from typing import Callable, Dict, Iterable, List, Optional, Tuple, Type, Union, Any +import jinja2 import re +import warnings from collections import defaultdict -import jinja2 from pydantic import BaseModel, ValidationError from spacy.language import Language from spacy.tokens import Doc @@ -216,6 +217,54 @@ def __init__( self._verbose = verbose self._check_extensions() + def _check_label_consistency(self) -> List[SRLExample]: + """Checks consistency of labels between examples and defined labels. Emits warning on inconsistency. + RETURNS (List[SRLExample]): List of SpanExamples with valid labels. + """ + assert self._prompt_examples + srl_examples = [SRLExample(**eg.dict()) for eg in self._prompt_examples] + example_labels = { + self._normalizer(r.label): r.label + for example in srl_examples + for p, rs in example.relations + for r in rs + } + unspecified_labels = { + example_labels[key] + for key in (set(example_labels.keys()) - set(self._label_dict.keys())) + } + if not set(example_labels.keys()) <= set(self._label_dict.keys()): + warnings.warn( + f"Examples contain labels that are not specified in the task configuration. The latter contains the " + f"following labels: {sorted(list(set(self._label_dict.values())))}. Labels in examples missing from " + f"the task configuration: {sorted(list(unspecified_labels))}. Please ensure your label specification " + f"and example labels are consistent." + ) + + # Return examples without non-declared roles. the roles within a predicate that have undeclared role labels + # are discarded. + return [ + example + for example in [ + SRLExample( + text=example.text, + predicates=example.predicates, + relations=[ + ( + p, + [ + r + for r in rs + if self._normalizer(r.label) in self._label_dict + ], + ) + for p, rs in example.relations + ], + ) + for example in srl_examples + ] + ] + @classmethod def _check_extensions(cls): """Add `predicates` extension if need be. @@ -229,7 +278,7 @@ def _check_extensions(cls): def initialize( self, - get_examples: Callable[[], Iterable["Example"]], + get_examples: Callable[[], Iterable["SRLExample"]], nlp: Language, labels: List[str] = [], ) -> None: @@ -278,6 +327,7 @@ def generate_prompts(self, docs: Iterable[Doc]) -> Iterable[str]: labels=list(self._label_dict.values()), label_definitions=self._label_definitions, predicates=predicates, + examples=self._prompt_examples, ) yield prompt diff --git a/spacy_llm/tasks/templates/span-srl.v1.jinja b/spacy_llm/tasks/templates/span-srl.v1.jinja index 4fa142df..f8a5ec79 100644 --- a/spacy_llm/tasks/templates/span-srl.v1.jinja +++ b/spacy_llm/tasks/templates/span-srl.v1.jinja @@ -19,6 +19,38 @@ Predicate: {# whitespace #} {%- endfor -%} {# whitespace #} +{%- if examples -%} +{# whitespace #} +Below are a few similar examples (only use these as a guide): +{# whitespace #} +{# whitespace #} +{%- for example in examples -%} +Example Text: +''' +{{ example.text }} +''' +{# whitespace #} +Step 1: Extract the Predicates in '''Example Text''': +{# whitespace #} +Predicates: {{ example.predicates|map(attribute='text')|join(', ') }} +{# whitespace #} +Step 2: For each Predicate, extract the Sematic Roles in '''Example Text''': +{# whitespace #} +{%- for predicate, relations in example.relations -%} +Predicate: {{predicate.text}} +{# whitespace #} +{%- for relation in relations -%} +{{relation.label}}: {{relation.role.text}} +{# whitespace #} +{%- endfor -%} +{# whitespace #} +{# whitespace #} +{%- endfor -%} +{# whitespace #} +{%- endfor -%} +{# whitespace #} +{%- endif -%} +{# whitespace #} Here is the text that needs labeling: {# whitespace #} Text: diff --git a/spacy_llm/tests/tasks/examples/span_srl.jsonl b/spacy_llm/tests/tasks/examples/span_srl.jsonl new file mode 100644 index 00000000..98f73555 --- /dev/null +++ b/spacy_llm/tests/tasks/examples/span_srl.jsonl @@ -0,0 +1 @@ +{"text": "Ben bought a house last year in Berlin .", "predicates": [{"text": "bought", "start_char": 4, "end_char": 10, "roleset_id": ""}], "relations": [[{"text": "bought", "start_char": 4, "end_char": 10, "roleset_id": ""}, [{"role": {"text": "Ben", "start_char": 0, "end_char": 3}, "label": "ARG-0"}, {"role": {"text": "a house", "start_char": 11, "end_char": 18}, "label": "ARG-1"}, {"role": {"text": "last year", "start_char": 19, "end_char": 28}, "label": "ARG-M-TMP"}, {"role": {"text": "in Berlin", "start_char": 29, "end_char": 38}, "label": "ARG-M-LOC"}]]]} diff --git a/spacy_llm/tests/tasks/test_span_srl.py b/spacy_llm/tests/tasks/test_span_srl.py index c1b432e7..cd260507 100644 --- a/spacy_llm/tests/tasks/test_span_srl.py +++ b/spacy_llm/tests/tasks/test_span_srl.py @@ -29,12 +29,11 @@ def zeroshot_cfg_string(): [components.llm.task] @llm_tasks = "spacy.SRL.v1" - labels = ARG-0,ARG-1,ARG-2,ARG-M-LOC,ARG-M-TMP + labels = ARG-0,ARG-1,ARG-M-TMP,ARG-M-LOC [components.llm.task.label_definitions] ARG-0 = "Agent" ARG-1 = "Patient or Theme" - ARG-2 = "ARG-2" ARG-M-TMP = "Temporal Modifier" ARG-M-LOC = "Location Modifier" @@ -43,9 +42,43 @@ def zeroshot_cfg_string(): """ +@pytest.fixture +def fewshot_cfg_string(): + return f""" + [paths] + examples = null + + [nlp] + lang = "en" + pipeline = ["llm"] + + [components] + + [components.llm] + factory = "llm" + + [components.llm.task] + @llm_tasks = "spacy.SRL.v1" + labels = ARG-0,ARG-1,ARG-M-TMP,ARG-M-LOC + + [components.llm.task.label_definitions] + ARG-0 = "Agent" + ARG-1 = "Patient or Theme" + ARG-M-TMP = "Temporal Modifier" + ARG-M-LOC = "Location Modifier" + + [components.llm.task.examples] + @misc = "spacy.FewShotReader.v1" + path = {str((Path(__file__).parent / "examples" / "span_srl.jsonl"))} + + [components.llm.model] + @llm_models = "spacy.GPT-3-5.v1" + """ + + @pytest.fixture def task(): - text = "We love this sentence in Berlin right now ." + text = "We love this sentence right now in Berlin" predicate = {"text": "love", "start_char": 3, "end_char": 7} srl_example = SRLExample( **{ @@ -68,17 +101,17 @@ def task(): }, }, { - "label": "ARG-M-LOC", + "label": "ARG-M-TMP", "role": { - "text": "in Berlin", + "text": "right now", "start_char": 22, "end_char": 31, }, }, { - "label": "ARG-M-TMP", + "label": "ARG-M-LOC", "role": { - "text": "right now", + "text": "in Berlin", "start_char": 32, "end_char": 41, }, @@ -106,7 +139,7 @@ def test_rel_config(cfg_string, request: FixtureRequest): task = pipe.task labels = orig_config["components"]["llm"]["task"]["labels"] - labels = split_labels(labels) + labels = sorted(split_labels(labels)) assert isinstance(task, Labeled) assert task.labels == tuple(labels) assert set(pipe.labels) == set(task.labels) @@ -114,7 +147,7 @@ def test_rel_config(cfg_string, request: FixtureRequest): @pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available") -@pytest.mark.parametrize("cfg_string", ["zeroshot_cfg_string"]) +@pytest.mark.parametrize("cfg_string", ["zeroshot_cfg_string", "fewshot_cfg_string"]) def test_rel_predict(task, cfg_string, request): """Use OpenAI to get REL results. Note that this test may fail randomly, as the LLM's output is unguaranteed to be consistent/predictable diff --git a/usage_examples/span_srl_openai/README.md b/usage_examples/span_srl_openai/README.md index 9822a6e7..61e22fe3 100644 --- a/usage_examples/span_srl_openai/README.md +++ b/usage_examples/span_srl_openai/README.md @@ -30,8 +30,16 @@ python run_pipeline.py \ "Laura bought an apartment last month in Berlin." \ ./zeroshot.cfg ``` +or, for few-shot: +```sh +python run_pipeline.py \ + "Laura bought an apartment last month in Berlin." \ + ./fewshot.cfg \ + ./examples.jsonl +``` + LLM-response: -```shell +```sh LLM response for doc: Laura bought an apartment last month in Boston. Step 1: Extract the Predicates for the Text @@ -47,7 +55,7 @@ ARG-M-TMP: last month ARG-M-LOC: in Boston ``` std output: -```shell +```sh Text: Laura bought an apartment last month in Boston. SRL Output: Predicates: ['bought'] diff --git a/usage_examples/span_srl_openai/examples.jsonl b/usage_examples/span_srl_openai/examples.jsonl new file mode 100644 index 00000000..98f73555 --- /dev/null +++ b/usage_examples/span_srl_openai/examples.jsonl @@ -0,0 +1 @@ +{"text": "Ben bought a house last year in Berlin .", "predicates": [{"text": "bought", "start_char": 4, "end_char": 10, "roleset_id": ""}], "relations": [[{"text": "bought", "start_char": 4, "end_char": 10, "roleset_id": ""}, [{"role": {"text": "Ben", "start_char": 0, "end_char": 3}, "label": "ARG-0"}, {"role": {"text": "a house", "start_char": 11, "end_char": 18}, "label": "ARG-1"}, {"role": {"text": "last year", "start_char": 19, "end_char": 28}, "label": "ARG-M-TMP"}, {"role": {"text": "in Berlin", "start_char": 29, "end_char": 38}, "label": "ARG-M-LOC"}]]]} diff --git a/usage_examples/span_srl_openai/fewshot.cfg b/usage_examples/span_srl_openai/fewshot.cfg new file mode 100644 index 00000000..ffd1304a --- /dev/null +++ b/usage_examples/span_srl_openai/fewshot.cfg @@ -0,0 +1,28 @@ +[paths] +examples = null + +[nlp] +lang = "en" +pipeline = ["llm"] + +[components] + +[components.llm] +factory = "llm" + +[components.llm.task] +@llm_tasks = "spacy.SRL.v1" +labels = ARG-0,ARG-1,ARG-M-TMP,ARG-M-LOC + +[components.llm.task.label_definitions] +ARG-0 = "Agent" +ARG-1 = "Patient or Theme" +ARG-M-TMP = "Temporal Modifier" +ARG-M-LOC = "Location Modifier" + +[components.llm.task.examples] +@misc = "spacy.FewShotReader.v1" +path = ${paths.examples} + +[components.llm.model] +@llm_models = "spacy.GPT-3-5.v1" \ No newline at end of file diff --git a/usage_examples/span_srl_openai/zeroshot.cfg b/usage_examples/span_srl_openai/zeroshot.cfg index bcfc03ef..cf826712 100644 --- a/usage_examples/span_srl_openai/zeroshot.cfg +++ b/usage_examples/span_srl_openai/zeroshot.cfg @@ -1,6 +1,3 @@ -[paths] -examples = null - [nlp] lang = "en" pipeline = ["llm"] From 73bf0f6eb60fe1312b86a5fc6c9542d679ba0a89 Mon Sep 17 00:00:00 2001 From: rehan Date: Mon, 7 Aug 2023 14:33:07 -0600 Subject: [PATCH 27/33] Add examples path in srl docs --- usage_examples/span_srl_openai/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/usage_examples/span_srl_openai/README.md b/usage_examples/span_srl_openai/README.md index 61e22fe3..52643b22 100644 --- a/usage_examples/span_srl_openai/README.md +++ b/usage_examples/span_srl_openai/README.md @@ -20,7 +20,7 @@ export OPENAI_API_ORG="org-..." Then, you can run the pipeline on a sample text via: ```sh -python run_pipeline.py [TEXT] [PATH TO CONFIG] +python run_pipeline.py [TEXT] [PATH TO CONFIG] [PATH TO FILE WITH EXAMPLES] ``` For example: From 824aa828e92bccdd328e9c2c636f33926ef035eb Mon Sep 17 00:00:00 2001 From: rehan Date: Tue, 8 Aug 2023 12:24:05 -0600 Subject: [PATCH 28/33] removing whitespaces causing commit check failures --- spacy_llm/tests/tasks/test_span_srl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spacy_llm/tests/tasks/test_span_srl.py b/spacy_llm/tests/tasks/test_span_srl.py index cd260507..3b901b38 100644 --- a/spacy_llm/tests/tasks/test_span_srl.py +++ b/spacy_llm/tests/tasks/test_span_srl.py @@ -30,7 +30,7 @@ def zeroshot_cfg_string(): [components.llm.task] @llm_tasks = "spacy.SRL.v1" labels = ARG-0,ARG-1,ARG-M-TMP,ARG-M-LOC - + [components.llm.task.label_definitions] ARG-0 = "Agent" ARG-1 = "Patient or Theme" @@ -66,7 +66,7 @@ def fewshot_cfg_string(): ARG-1 = "Patient or Theme" ARG-M-TMP = "Temporal Modifier" ARG-M-LOC = "Location Modifier" - + [components.llm.task.examples] @misc = "spacy.FewShotReader.v1" path = {str((Path(__file__).parent / "examples" / "span_srl.jsonl"))} From 6d5efc9bd525a0fb48aa3528b2de1298ec4e364e Mon Sep 17 00:00:00 2001 From: rehan Date: Wed, 9 Aug 2023 22:24:55 -0600 Subject: [PATCH 29/33] Make SRLExample hashable to remove duplicate examples --- spacy_llm/tasks/srl_task.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/spacy_llm/tasks/srl_task.py b/spacy_llm/tasks/srl_task.py index 822c6389..5565fe37 100644 --- a/spacy_llm/tasks/srl_task.py +++ b/spacy_llm/tasks/srl_task.py @@ -51,6 +51,9 @@ class SRLExample(BaseModel): predicates: List[PredicateItem] relations: List[Tuple[PredicateItem, List[RoleItem]]] + def __hash__(self): + return hash((self.text,) + tuple(self.predicates)) + def __str__(self): return f"""Predicates: {', '.join([p.text for p in self.predicates])} Relations: {str([(p.text, [(r.label, r.role.text) for r in rs]) for p, rs in self.relations])}""" From 2a9ede58f565ec927ee66ad5d2a4a1c2d501df05 Mon Sep 17 00:00:00 2001 From: rehan Date: Fri, 11 Aug 2023 10:45:18 -0600 Subject: [PATCH 30/33] Add doc-tailored examples in generate_prompts --- spacy_llm/tasks/srl_task.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/spacy_llm/tasks/srl_task.py b/spacy_llm/tasks/srl_task.py index 5565fe37..dab2c403 100644 --- a/spacy_llm/tasks/srl_task.py +++ b/spacy_llm/tasks/srl_task.py @@ -222,7 +222,7 @@ def __init__( def _check_label_consistency(self) -> List[SRLExample]: """Checks consistency of labels between examples and defined labels. Emits warning on inconsistency. - RETURNS (List[SRLExample]): List of SpanExamples with valid labels. + RETURNS (List[SRLExample]): List of SRLExamples with valid labels. """ assert self._prompt_examples srl_examples = [SRLExample(**eg.dict()) for eg in self._prompt_examples] @@ -325,12 +325,18 @@ def generate_prompts(self, docs: Iterable[Doc]) -> Iterable[str]: if len(doc._.predicates): predicates = ", ".join([p["text"] for p in doc._.predicates]) + doc_examples = self._prompt_examples + + # check if there are doc-tailored examples + if doc.has_extension("egs") and doc._.egs is not None and len(doc._.egs): + doc_examples = doc._.egs + prompt = _template.render( text=doc.text, labels=list(self._label_dict.values()), label_definitions=self._label_definitions, predicates=predicates, - examples=self._prompt_examples, + examples=doc_examples, ) yield prompt From be5065570bd235e5b297f85228770cd100c49ff8 Mon Sep 17 00:00:00 2001 From: rehan Date: Wed, 16 Aug 2023 10:40:06 -0600 Subject: [PATCH 31/33] Added defs for alignment modes Added docs for srl response parsing --- spacy_llm/tasks/srl_task.py | 47 ++++++++++++++++++++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/spacy_llm/tasks/srl_task.py b/spacy_llm/tasks/srl_task.py index dab2c403..19ea3249 100644 --- a/spacy_llm/tasks/srl_task.py +++ b/spacy_llm/tasks/srl_task.py @@ -166,7 +166,10 @@ def make_srl_task( reads a file containing task examples for few-shot learning. If None is passed, then zero-shot learning will be used. normalizer (Optional[Callable[[str], str]]): optional normalizer function. - alignment_mode (Literal["strict", "contract", "expand"]): "strict", "contract" or "expand". + alignment_mode (Literal["strict", "contract", "expand"]): How character indices snap to token boundaries. + Options: "strict" (no snapping), "contract" (span of all tokens completely within the character span), + "expand" (span of all tokens at least partially covered by the character span). + Defaults to "strict". case_sensitive_matching: Whether to search without case sensitivity. single_match (bool): If False, allow one substring to match multiple times in the text. If True, returns the first hit. @@ -377,6 +380,48 @@ def _format_response(self, arg_lines) -> List[Tuple[str, str]]: def parse_responses( self, docs: Iterable[Doc], responses: Iterable[str] ) -> Iterable[Doc]: + """ + Parse LLM response by extracting predicate-arguments blocks from the generate response. + For example, + LLM response for doc: "A sentence with multiple predicates (p1, p2)" + + Step 1: Extract the Predicates for the Text + Predicates: p1, p2 + + Step 2: For each Predicate, extract the Semantic Roles in 'Text' + Text: A sentence with multiple predicates (p1, p2) + Predicate: p1 + ARG-0: a0_1 + ARG-1: a1_1 + ARG-M-TMP: a_t_1 + ARG-M-LOC: a_l_1 + + Predicate: p2 + ARG-0: a0_2 + ARG-1: a1_2 + ARG-M-TMP: a_t_2 + + So the steps in the parsing are to first find the text boundaries for the information + of each predicate. This is done by identifying the lines "Predicate: p1" and "Predicate: p2", + which gives us the text for each predicate as follows: + + Predicate: p1 + ARG-0: a0_1 + ARG-1: a1_1 + ARG-M-TMP: a_t_1 + ARG-M-LOC: a_l_1 + + and, + + Predicate: p2 + ARG-0: a0_2 + ARG-1: a1_2 + ARG-M-TMP: a_t_2 + + Once we separate these out, then it is a matter of parsing line by line to extract the predicate + and its args for each predicate block + + """ for doc, prompt_response in zip(docs, responses): predicates = [] relations = [] From 0970e64fc181f023ddc7f5fd5eac111ff72d08d0 Mon Sep 17 00:00:00 2001 From: Rehan Ahmed Date: Wed, 23 Aug 2023 00:11:46 -0400 Subject: [PATCH 32/33] fix serialization issue of pred_item --- spacy_llm/tasks/srl_task.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spacy_llm/tasks/srl_task.py b/spacy_llm/tasks/srl_task.py index 19ea3249..abde93cb 100644 --- a/spacy_llm/tasks/srl_task.py +++ b/spacy_llm/tasks/srl_task.py @@ -463,8 +463,8 @@ def parse_responses( p_start_char, p_end_char = pred_offsets[0] pred_item = PredicateItem( text=pred_str, start_char=p_start_char, end_char=p_end_char - ) - predicates.append(pred_item.dict()) + ).dict() + predicates.append(pred_item) roles = [] From 3e0a50efa0f74bb4179e60a61d252dc15a2c3234 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Mon, 18 Sep 2023 14:41:01 +0200 Subject: [PATCH 33/33] Update spacy_llm/tests/tasks/test_span_srl.py --- spacy_llm/tests/tasks/test_span_srl.py | 1 - 1 file changed, 1 deletion(-) diff --git a/spacy_llm/tests/tasks/test_span_srl.py b/spacy_llm/tests/tasks/test_span_srl.py index 3b901b38..921e9e09 100644 --- a/spacy_llm/tests/tasks/test_span_srl.py +++ b/spacy_llm/tests/tasks/test_span_srl.py @@ -150,7 +150,6 @@ def test_rel_config(cfg_string, request: FixtureRequest): @pytest.mark.parametrize("cfg_string", ["zeroshot_cfg_string", "fewshot_cfg_string"]) def test_rel_predict(task, cfg_string, request): """Use OpenAI to get REL results. - Note that this test may fail randomly, as the LLM's output is unguaranteed to be consistent/predictable """ cfg_string = request.getfixturevalue(cfg_string) orig_config = Config().from_str(cfg_string)