diff --git a/spacy_llm/tasks/entity_linker/task.py b/spacy_llm/tasks/entity_linker/task.py index 86426ed0..fd44506d 100644 --- a/spacy_llm/tasks/entity_linker/task.py +++ b/spacy_llm/tasks/entity_linker/task.py @@ -105,7 +105,6 @@ def _preprocess_docs_for_prompt(self, docs: Iterable[Doc]) -> Iterable[Doc]: self._ents_cands_by_shard = [[] * len(self._ents_cands_by_doc)] self._has_ent_cands_by_shard = [[] * len(self._ents_cands_by_doc)] self._n_shards = None - return [ EntityLinkerTask.highlight_ents_in_doc(doc, self._has_ent_cands_by_doc[i]) for i, doc in enumerate(docs) @@ -335,7 +334,11 @@ def unhighlight_ents_in_doc(doc: Doc) -> Doc: for ent in doc.ents if ent.start - 1 > 0 and doc[ent.start - 1].text == "*" } - highlight_end_idx = {ent.end for ent in doc.ents if doc[ent.end].text == "*"} + highlight_end_idx = { + ent.end + for ent in doc.ents + if ent.end < len(doc) and doc[ent.end].text == "*" + } highlight_idx = highlight_start_idx | highlight_end_idx # Compute entity indices with removed highlights. diff --git a/spacy_llm/tests/tasks/test_entity_linker.py b/spacy_llm/tests/tasks/test_entity_linker.py index 6101236b..45f18e7e 100644 --- a/spacy_llm/tests/tasks/test_entity_linker.py +++ b/spacy_llm/tests/tasks/test_entity_linker.py @@ -682,9 +682,38 @@ def test_ent_highlighting(): EntityLinkerTask.highlight_ents_in_doc(doc).text == "Alice goes to *Boston* to see the *Boston Celtics* game." ) + + +@pytest.mark.parametrize( + "text,ents,include_ents", + [ + ( + "Alice goes to Boston to see the Boston Celtics game.", + [ + {"start": 3, "end": 4, "label": "LOC"}, + {"start": 7, "end": 9, "label": "ORG"}, + ], + [True, True], + ), + ( + "I went to see Boston in concert yesterday", + [ + {"start": 4, "end": 5, "label": "GPE"}, + {"start": 7, "end": 8, "label": "DATE"}, + ], + [True, False], + ), + ], +) +def test_ent_unhighlighting(text, ents, include_ents): + """Tests unhighlighting of entities in text.""" + nlp = spacy.blank("en") + doc = nlp.make_doc(text) + doc.ents = [Span(doc=doc, **ents[0]), Span(doc=doc, **ents[1])] + assert ( EntityLinkerTask.unhighlight_ents_in_doc( - EntityLinkerTask.highlight_ents_in_doc(doc) + EntityLinkerTask.highlight_ents_in_doc(doc, include_ents) ).text == doc.text == text