From 31589c677bc7cb982467159d275436a35cb69904 Mon Sep 17 00:00:00 2001 From: Gregor Sturm Date: Wed, 1 Nov 2023 18:21:38 +0100 Subject: [PATCH 1/4] Speed up index_chains (#444) * Add MWE for index_chains in numba * Try different approach, but it seems worse * Another (buggy) MWE for a better implementation * Apparently fully working vectorized implementation of index_chains * Tolerate missing sort key * Document code * fix tests * Update changelog * Update changelog --- CHANGELOG.md | 9 ++ docs/tutorials/tutorial_3k_tcr.ipynb | 4 +- src/scirpy/pp/_index_chains.py | 173 ++++++++++++++++++++++++++- 3 files changed, 181 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 51d4cdf8f..ca342e0c5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,15 @@ and this project adheres to [Semantic Versioning][]. ## [Unreleased] +### Breaking changes + +- Reimplement `pp.index_chains` using numba and awkward array functions, achieving a significant speedup. This function + behaves exactly like the previous version _except_ that callback functions passed to the `filter` arguments + must now be vectorized over an awkward array, e.g. to check if a `junction_aa` field is present you could + previously pass `lambda x: x['junction_aa'] is not None`, now an accepted version would be + `lambda x: ~ak.is_none(x["junction_aa"], axis=-1)`. To learn more about native awkward array functions, please + refer to the [awkward array documentation](https://awkward-array.org/doc/main/reference/index.html). ([#444](https://github.com/scverse/scirpy/pull/444)) + ### Fixes - Fix that `define_clonotype_clusters` could not retreive `within_group` columns from MuData ([#459](https://github.com/scverse/scirpy/pull/459)) diff --git a/docs/tutorials/tutorial_3k_tcr.ipynb b/docs/tutorials/tutorial_3k_tcr.ipynb index 9d90bb352..13f61a577 100644 --- a/docs/tutorials/tutorial_3k_tcr.ipynb +++ b/docs/tutorials/tutorial_3k_tcr.ipynb @@ -3537,7 +3537,7 @@ "notebook_metadata_filter": "-kernelspec" }, "kernelspec": { - "display_name": "scirpy_dev2", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -3551,7 +3551,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.10" + "version": "3.11.4" } }, "nbformat": 4, diff --git a/src/scirpy/pp/_index_chains.py b/src/scirpy/pp/_index_chains.py index 143b678e7..d4fc1e4e9 100644 --- a/src/scirpy/pp/_index_chains.py +++ b/src/scirpy/pp/_index_chains.py @@ -1,19 +1,25 @@ +import operator from collections.abc import Mapping, Sequence -from functools import partial +from functools import partial, reduce from types import MappingProxyType from typing import Any, Callable, Union import awkward as ak +import numba as nb +import numpy as np from scanpy import logging from scirpy.io._datastructures import AirrCell from scirpy.util import DataHandler, _is_na2, tqdm SCIRPY_DUAL_IR_MODEL = "scirpy_dual_ir_v0.13" +# make these constants available to numba +_VJ_LOCI = tuple(AirrCell.VJ_LOCI) +_VDJ_LOCI = tuple(AirrCell.VDJ_LOCI) @DataHandler.inject_param_docs() -def index_chains( +def index_chains_legacy( adata: DataHandler.TYPE, *, filter: Union[Callable[[Mapping], bool], Sequence[Union[str, Callable[[Mapping], bool]]]] = ( @@ -135,7 +141,168 @@ def index_chains( } -def _key_sort_chains(chains: list[Mapping], sort_chains_by: Mapping[str, Any], idx: int) -> Sequence: +@DataHandler.inject_param_docs() +def index_chains( + adata: DataHandler.TYPE, + *, + filter: Union[Callable[[ak.Array], bool], Sequence[Union[str, Callable[[ak.Array], bool]]]] = ( + "productive", + "require_junction_aa", + ), + sort_chains_by: Mapping[str, Any] = MappingProxyType( + {"duplicate_count": 0, "consensus_count": 0, "junction": "", "junction_aa": ""} + ), + airr_mod: str = "airr", + airr_key: str = "airr", + key_added: str = "chain_indices", +) -> None: + """\ + Selects primary/secondary VJ/VDJ cells per chain according to the :ref:`receptor-model`. + + This function iterates through all chains stored in the :term:`awkward array` in + `adata.obsm[airr_key]` and + + * labels chains as primary/secondary VJ/VDJ chains + * labels cells as multichain cells + + based on the expression level of the chains and the specified filtering option. + By default, non-productive chains and chains without a valid CDR3 amino acid sequence are filtered out. + + Additionally, chains without a valid IMGT locus are always filtered out. + + For more details, please refer to the :ref:`receptor-model` and the :ref:`data structure `. + + Parameters + ---------- + {adata} + filter + Option to filter chains. Can be either + * a callback function that takes the full awkward array with AIRR chains as input and returns + another awkward array that is a boolean mask which can be used to index the former. + (True to keep, False to discard) + * a list of "filtering presets". Possible values are `"productive"` and `"require_junction_aa"`. + `"productive"` removes non-productive chains and `"require_junction_aa"` removes chains that don't have + a CDR3 amino acid sequence. + * a list with a combination of both. + + Multiple presets/functions are combined using `and`. Filtered chains do not count towards calling "multichain" cells. + sort_chains_by + A list of sort keys used to determine an ordering of chains. The chain with the highest value + of this tuple will be the primary chain, second-highest the secondary chain. If there are more chains, they + will not be indexed, and the cell receives the "multichain" flag. + {airr_mod} + {airr_key} + key_added + Key under which the chain indicies will be stored in `adata.obsm` and metadata will be stored in `adata.uns`. + + Returns + ------- + Nothing, but adds a dataframe to `adata.obsm[chain_indices]` + """ + params = DataHandler(adata, airr_mod, airr_key) + + # prepare filter functions + if isinstance(filter, Callable): + filter = [filter] + filter_presets = { + "productive": lambda x: x["productive"], + "require_junction_aa": lambda x: ~ak.is_none(x["junction_aa"], axis=-1), + } + filter = [filter_presets[f] if isinstance(f, str) else f for f in filter] + + # only warn if those fields are in the key (i.e. this should give a warning if those are missing with + # default settings. If the user specifies their own dictionary, they are on their own) + if "duplicate_count" in sort_chains_by and "consensus_count" in sort_chains_by: + if "duplicate_count" not in params.airr.fields and "consensus_count" not in params.airr.fields: + logging.warning("No expression information available. Cannot rank chains by expression. ") # type: ignore + + if "locus" not in params.airr.fields: + raise ValueError("The scirpy receptor model requires a `locus` field to be specified in the AIRR data.") + + airr = params.airr + logging.info("Filtering chains...") + # Get the numeric indices pre-filtering - these are the indices we need in the final output as + # .obsm["airr"] is and remains unfiltered. + airr_idx = ak.local_index(airr, axis=1) + # Filter out chains that do not match the filter criteria + # we need an initial value that selects all chains in case filter is an empty list + airr_idx = airr_idx[reduce(operator.and_, (f(airr) for f in filter), ak.ones_like(airr_idx, dtype=bool))] + + res = {} + is_multichain = np.zeros(len(airr), dtype=bool) + for chain_type, locus_names in {"VJ": AirrCell.VJ_LOCI, "VDJ": AirrCell.VDJ_LOCI}.items(): + logging.info(f"Indexing {chain_type} chains...") + # get the indices for all VJ / VDJ chains, respectively + idx = airr_idx[_awkward_isin(airr["locus"][airr_idx], locus_names)] + + # Now we need to sort the chains by the keys specified in `sort_chains_by`. + # since `argsort` doesn't support composite keys, we take advantage of the + # fact that the sorting algorithm is stable and sort the same array several times, + # starting with the lowest priority key up to the highest priority key. + for k, default in reversed(sort_chains_by.items()): + # skip this round of sorting altogether if field not present + if k in airr.fields: + logging.debug(f"Sorting chains by {k}") + tmp_idx = ak.argsort(ak.fill_none(airr[k][idx], default), stable=True, axis=-1, ascending=False) + idx = idx[tmp_idx] + else: + logging.debug(f"Skip sorting by {k} because field not present") + + # We want the result to be lists of exactly 2 - clip if longer, pad with None if shorter. + res[chain_type] = ak.pad_none(idx, 2, axis=1, clip=True) + is_multichain |= ak.to_numpy(_awkward_len(idx)) > 2 + + # build results + logging.info("build result array") + res["multichain"] = is_multichain + + params.adata.obsm[key_added] = ak.zip(res, depth_limit=1) # type: ignore + + # store metadata in .uns + params.adata.uns[key_added] = { + "model": SCIRPY_DUAL_IR_MODEL, # can be used to distinguish different receptor models that may be added in the future. + "filter": str(filter), + "airr_key": airr_key, + "sort_chains_by": str(sort_chains_by), + } + + +@nb.njit +def _awkward_len_inner(arr, ab): + for row in arr: + ab.append(len(row)) + return ab + + +def _awkward_len(arr): + return _awkward_len_inner(arr, ak.ArrayBuilder()).snapshot() + + +@nb.njit() +def _awkward_isin_inner(arr, haystack, ab): + for row in arr: + ab.begin_list() + for v in row: + ab.append(v in haystack) + ab.end_list() + return ab + + +def _awkward_isin(arr, haystack): + haystack = tuple(haystack) + return _awkward_isin_inner(arr, haystack, ak.ArrayBuilder()).snapshot() + + +# For future reference, here would be two alternative implementations that are a bit +# slower, but work without the need for numba. +# def _awkward_len(arr): +# return ak.max(ak.local_index(arr, axis=1), axis=1) +# +# def _awkward_isin(arr, haystack): +# return reduce(operator.or_, (arr == el for el in haystack)) + + +def _key_sort_chains(chains, sort_chains_by: Mapping[str, Any], idx: int) -> Sequence: """Get key to sort chains by expression. Parameters From 1a75fdc39e764423b5289565b3d042a3625f0983 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 8 Nov 2023 09:14:55 +0100 Subject: [PATCH 2/4] [pre-commit.ci] pre-commit autoupdate (#462) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.1.3 → v0.1.4](https://github.com/astral-sh/ruff-pre-commit/compare/v0.1.3...v0.1.4) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b46e797de..7b6f7c958 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -26,7 +26,7 @@ repos: language_version: "17.9.1" exclude: '^\.conda' - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.3 + rev: v0.1.4 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] From 3120a9d97729e9cb346de0aa389fd9732aa232b0 Mon Sep 17 00:00:00 2001 From: Gregor Sturm Date: Fri, 10 Nov 2023 18:33:54 +0100 Subject: [PATCH 3/4] Fix #454 (#465) * Fix #454 * Update changelog --- CHANGELOG.md | 1 + src/scirpy/io/_io.py | 6 +++--- src/scirpy/tests/test_io.py | 9 +++++++++ 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ca342e0c5..4cc1900ad 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,7 @@ and this project adheres to [Semantic Versioning][]. ### Fixes - Fix that `define_clonotype_clusters` could not retreive `within_group` columns from MuData ([#459](https://github.com/scverse/scirpy/pull/459)) +- Fix that AIRR Rearrangment fields of integer types could not be written when their value was None ([#465](https://github.com/scverse/scirpy/pull/465)) ## v0.13.1 diff --git a/src/scirpy/io/_io.py b/src/scirpy/io/_io.py index bc36a44cd..f368bc2c5 100644 --- a/src/scirpy/io/_io.py +++ b/src/scirpy/io/_io.py @@ -584,9 +584,9 @@ def write_airr(adata: DataHandler.TYPE, filename: Union[str, Path], **kwargs) -> for tmp_cell in airr_cells: for chain in tmp_cell.to_airr_records(): # workaround for AIRR library writing out int field as floats (if it happens to be a float) - for f in chain: - if RearrangementSchema.type(f) == "integer": - chain[f] = int(chain[f]) + for field, value in chain.items(): + if RearrangementSchema.type(field) == "integer" and value is not None: + chain[field] = int(value) writer.write(chain) writer.close() diff --git a/src/scirpy/tests/test_io.py b/src/scirpy/tests/test_io.py index 6cfb60495..92919fe3c 100644 --- a/src/scirpy/tests/test_io.py +++ b/src/scirpy/tests/test_io.py @@ -236,6 +236,15 @@ def test_airr_roundtrip_conversion(anndata_from_10x_sample, tmp_path): pdt.assert_frame_equal(anndata.obs, anndata2.obs, check_dtype=False, check_categorical=False) +def test_write_airr_none_field_issue_454(tmp_path): + cell = AirrCell("cell1") + chain = cell.empty_chain_dict() + chain["d_sequence_end"] = None + cell.add_chain(chain) + adata = from_airr_cells([cell]) + write_airr(adata, tmp_path / "test.airr.tsv") + + @pytest.mark.extra @pytest.mark.parametrize( "anndata_from_10x_sample", From de6453e21c760967b98dd3b0650628660e05f1c4 Mon Sep 17 00:00:00 2001 From: Gregor Sturm Date: Sat, 11 Nov 2023 17:33:59 +0100 Subject: [PATCH 4/4] Add testcase --- src/scirpy/tests/test_plotting.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/scirpy/tests/test_plotting.py b/src/scirpy/tests/test_plotting.py index 882cc537a..9a0630bbb 100644 --- a/src/scirpy/tests/test_plotting.py +++ b/src/scirpy/tests/test_plotting.py @@ -16,6 +16,13 @@ def test_clonal_expansion(adata_clonotype): assert isinstance(p, plt.Axes) +@pytest.mark.parametrize("adata_clonotype", [True], indirect=["adata_clonotype"], ids=["MuData"]) +def test_clonal_expansion_mudata_prefix(adata_clonotype): + """Regression test for #445""" + p = pl.clonal_expansion(adata_clonotype, groupby="group", target_col="airr:clone_id") + assert isinstance(p, plt.Axes) + + def test_alpha_diversity(adata_diversity): p = pl.alpha_diversity(adata_diversity, groupby="group", target_col="clonotype_") assert isinstance(p, plt.Axes)