Skip to content

Commit

Permalink
Fix/index error when unhighlighting (#434)
Browse files (browse the repository at this point in the history)
* Make sure non-highlighted entities (e.g. an entity ending at the last token of the doc) don't cause an IndexError when unhighlighting

* linting
  • Loading branch information
magdaaniol authored Jan 29, 2024
1 parent 0fc4633 commit 2e88594
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 3 deletions.
7 changes: 5 additions & 2 deletions spacy_llm/tasks/entity_linker/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ def _preprocess_docs_for_prompt(self, docs: Iterable[Doc]) -> Iterable[Doc]:
self._ents_cands_by_shard = [[] * len(self._ents_cands_by_doc)]
self._has_ent_cands_by_shard = [[] * len(self._ents_cands_by_doc)]
self._n_shards = None

return [
EntityLinkerTask.highlight_ents_in_doc(doc, self._has_ent_cands_by_doc[i])
for i, doc in enumerate(docs)
Expand Down Expand Up @@ -335,7 +334,11 @@ def unhighlight_ents_in_doc(doc: Doc) -> Doc:
for ent in doc.ents
if ent.start - 1 > 0 and doc[ent.start - 1].text == "*"
}
highlight_end_idx = {ent.end for ent in doc.ents if doc[ent.end].text == "*"}
highlight_end_idx = {
ent.end
for ent in doc.ents
if ent.end < len(doc) and doc[ent.end].text == "*"
}
highlight_idx = highlight_start_idx | highlight_end_idx

# Compute entity indices with removed highlights.
Expand Down
31 changes: 30 additions & 1 deletion spacy_llm/tests/tasks/test_entity_linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,9 +682,38 @@ def test_ent_highlighting():
EntityLinkerTask.highlight_ents_in_doc(doc).text
== "Alice goes to *Boston* to see the *Boston Celtics* game."
)


@pytest.mark.parametrize(
"text,ents,include_ents",
[
(
"Alice goes to Boston to see the Boston Celtics game.",
[
{"start": 3, "end": 4, "label": "LOC"},
{"start": 7, "end": 9, "label": "ORG"},
],
[True, True],
),
(
"I went to see Boston in concert yesterday",
[
{"start": 4, "end": 5, "label": "GPE"},
{"start": 7, "end": 8, "label": "DATE"},
],
[True, False],
),
],
)
def test_ent_unhighlighting(text, ents, include_ents):
"""Tests unhighlighting of entities in text."""
nlp = spacy.blank("en")
doc = nlp.make_doc(text)
doc.ents = [Span(doc=doc, **ents[0]), Span(doc=doc, **ents[1])]

assert (
EntityLinkerTask.unhighlight_ents_in_doc(
EntityLinkerTask.highlight_ents_in_doc(doc)
EntityLinkerTask.highlight_ents_in_doc(doc, include_ents)
).text
== doc.text
== text
Expand Down

0 comments on commit 2e88594

Please sign in to comment.