Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/index error when unhighlighting #434

Merged
merged 2 commits into from
Jan 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions spacy_llm/tasks/entity_linker/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ def _preprocess_docs_for_prompt(self, docs: Iterable[Doc]) -> Iterable[Doc]:
self._ents_cands_by_shard = [[] * len(self._ents_cands_by_doc)]
self._has_ent_cands_by_shard = [[] * len(self._ents_cands_by_doc)]
self._n_shards = None

return [
EntityLinkerTask.highlight_ents_in_doc(doc, self._has_ent_cands_by_doc[i])
for i, doc in enumerate(docs)
Expand Down Expand Up @@ -335,7 +334,11 @@ def unhighlight_ents_in_doc(doc: Doc) -> Doc:
for ent in doc.ents
if ent.start - 1 > 0 and doc[ent.start - 1].text == "*"
}
highlight_end_idx = {ent.end for ent in doc.ents if doc[ent.end].text == "*"}
highlight_end_idx = {
ent.end
for ent in doc.ents
if ent.end < len(doc) and doc[ent.end].text == "*"
}
highlight_idx = highlight_start_idx | highlight_end_idx

# Compute entity indices with removed highlights.
Expand Down
31 changes: 30 additions & 1 deletion spacy_llm/tests/tasks/test_entity_linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,9 +682,38 @@ def test_ent_highlighting():
EntityLinkerTask.highlight_ents_in_doc(doc).text
== "Alice goes to *Boston* to see the *Boston Celtics* game."
)


@pytest.mark.parametrize(
"text,ents,include_ents",
[
(
"Alice goes to Boston to see the Boston Celtics game.",
[
{"start": 3, "end": 4, "label": "LOC"},
{"start": 7, "end": 9, "label": "ORG"},
],
[True, True],
),
(
"I went to see Boston in concert yesterday",
[
{"start": 4, "end": 5, "label": "GPE"},
{"start": 7, "end": 8, "label": "DATE"},
],
[True, False],
),
],
)
def test_ent_unhighlighting(text, ents, include_ents):
"""Tests unhighlighting of entities in text."""
nlp = spacy.blank("en")
doc = nlp.make_doc(text)
doc.ents = [Span(doc=doc, **ents[0]), Span(doc=doc, **ents[1])]

assert (
EntityLinkerTask.unhighlight_ents_in_doc(
EntityLinkerTask.highlight_ents_in_doc(doc)
EntityLinkerTask.highlight_ents_in_doc(doc, include_ents)
).text
== doc.text
== text
Expand Down
Loading