Skip to content

Commit

Permalink
Merge branch 'fix-445' into clonal-expansion
Browse files Browse the repository at this point in the history
  • Loading branch information
grst committed Nov 11, 2023
2 parents b972171 + de6453e commit f19e216
Show file tree
Hide file tree
Showing 7 changed files with 203 additions and 10 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ repos:
language_version: "17.9.1"
exclude: '^\.conda'
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.3
rev: v0.1.4
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
Expand Down
12 changes: 11 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,24 @@ and this project adheres to [Semantic Versioning][].

## [Unreleased]

### Breaking changes

- Reimplement `pp.index_chains` using numba and awkward array functions, achieving a significant speedup. This function
behaves exactly like the previous version _except_ that callback functions passed to the `filter` arguments
must now be vectorized over an awkward array, e.g. to check if a `junction_aa` field is present you could
previously pass `lambda x: x['junction_aa'] is not None`, now an accepted version would be
`lambda x: ~ak.is_none(x["junction_aa"], axis=-1)`. To learn more about native awkward array functions, please
refer to the [awkward array documentation](https://awkward-array.org/doc/main/reference/index.html). ([#444](https://github.com/scverse/scirpy/pull/444))

### Additions

- The `clonal_expansion` function now supports a `breakpoints` argument for more flexible "expansion categories".
The `breakpoints` argument supersedes the `clip_at` parameter, which is now deprecated. ([#439](https://github.com/scverse/scirpy/pull/439))

### Fixes

- Fix that `define_clonotype_clusters` could not retrieve `within_group` columns from MuData ([#459](https://github.com/scverse/scirpy/pull/459))
- Fix that `define_clonotype_clusters` could not retrieve `within_group` columns from MuData ([#459](https://github.com/scverse/scirpy/pull/459))
- Fix that AIRR Rearrangement fields of integer types could not be written when their value was None ([#465](https://github.com/scverse/scirpy/pull/465))

## v0.13.1

Expand Down
4 changes: 2 additions & 2 deletions docs/tutorials/tutorial_3k_tcr.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -3537,7 +3537,7 @@
"notebook_metadata_filter": "-kernelspec"
},
"kernelspec": {
"display_name": "scirpy_dev2",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -3551,7 +3551,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
"version": "3.11.4"
}
},
"nbformat": 4,
Expand Down
6 changes: 3 additions & 3 deletions src/scirpy/io/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,9 +584,9 @@ def write_airr(adata: DataHandler.TYPE, filename: Union[str, Path], **kwargs) ->
for tmp_cell in airr_cells:
for chain in tmp_cell.to_airr_records():
# workaround for AIRR library writing out int field as floats (if it happens to be a float)
for f in chain:
if RearrangementSchema.type(f) == "integer":
chain[f] = int(chain[f])
for field, value in chain.items():
if RearrangementSchema.type(field) == "integer" and value is not None:
chain[field] = int(value)
writer.write(chain)
writer.close()

Expand Down
173 changes: 170 additions & 3 deletions src/scirpy/pp/_index_chains.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,25 @@
import operator
from collections.abc import Mapping, Sequence
from functools import partial
from functools import partial, reduce
from types import MappingProxyType
from typing import Any, Callable, Union

import awkward as ak
import numba as nb
import numpy as np
from scanpy import logging

from scirpy.io._datastructures import AirrCell
from scirpy.util import DataHandler, _is_na2, tqdm

SCIRPY_DUAL_IR_MODEL = "scirpy_dual_ir_v0.13"
# make these constants available to numba
_VJ_LOCI = tuple(AirrCell.VJ_LOCI)
_VDJ_LOCI = tuple(AirrCell.VDJ_LOCI)


@DataHandler.inject_param_docs()
def index_chains(
def index_chains_legacy(
adata: DataHandler.TYPE,
*,
filter: Union[Callable[[Mapping], bool], Sequence[Union[str, Callable[[Mapping], bool]]]] = (
Expand Down Expand Up @@ -135,7 +141,168 @@ def index_chains(
}


def _key_sort_chains(chains: list[Mapping], sort_chains_by: Mapping[str, Any], idx: int) -> Sequence:
@DataHandler.inject_param_docs()
def index_chains(
    adata: DataHandler.TYPE,
    *,
    filter: Union[Callable[[ak.Array], bool], Sequence[Union[str, Callable[[ak.Array], bool]]]] = (
        "productive",
        "require_junction_aa",
    ),
    sort_chains_by: Mapping[str, Any] = MappingProxyType(
        {"duplicate_count": 0, "consensus_count": 0, "junction": "", "junction_aa": ""}
    ),
    airr_mod: str = "airr",
    airr_key: str = "airr",
    key_added: str = "chain_indices",
) -> None:
    """\
    Selects primary/secondary VJ/VDJ chains per cell according to the :ref:`receptor-model`.

    This function iterates through all chains stored in the :term:`awkward array` in
    `adata.obsm[airr_key]` and

        * labels chains as primary/secondary VJ/VDJ chains
        * labels cells as multichain cells

    based on the expression level of the chains and the specified filtering option.
    By default, non-productive chains and chains without a valid CDR3 amino acid sequence are filtered out.

    Additionally, chains without a valid IMGT locus are always filtered out.

    For more details, please refer to the :ref:`receptor-model` and the :ref:`data structure <data-structure>`.

    Parameters
    ----------
    {adata}
    filter
        Option to filter chains. Can be either

          * a callback function that takes the full awkward array with AIRR chains as input and returns
            another awkward array that is a boolean mask which can be used to index the former.
            (True to keep, False to discard)
          * a list of "filtering presets". Possible values are `"productive"` and `"require_junction_aa"`.
            `"productive"` removes non-productive chains and `"require_junction_aa"` removes chains that don't have
            a CDR3 amino acid sequence. A single preset name may also be passed as a plain string.
          * a list with a combination of both.

        Multiple presets/functions are combined using `and`. Filtered chains do not count towards calling "multichain" cells.
    sort_chains_by
        A mapping from AIRR field names to the default value that replaces missing entries of that field.
        Chains are ranked in descending order by the tuple of these fields, in the order given (the first key
        has the highest priority). Fields that are not present in the AIRR data are skipped.
        The chain with the highest value of this tuple will be the primary chain, second-highest the
        secondary chain. If there are more chains, they will not be indexed, and the cell receives the
        "multichain" flag.
    {airr_mod}
    {airr_key}
    key_added
        Key under which the chain indices will be stored in `adata.obsm` and metadata will be stored in `adata.uns`.

    Returns
    -------
    Nothing, but adds an awkward array with the chain indices to `adata.obsm[key_added]`
    and stores metadata about the indexing in `adata.uns[key_added]`.
    """
    params = DataHandler(adata, airr_mod, airr_key)

    # prepare filter functions
    # A bare string is itself a sequence; wrap it so that a single preset name is not
    # iterated character by character.
    if callable(filter) or isinstance(filter, str):
        filter = [filter]
    filter_presets = {
        "productive": lambda x: x["productive"],
        "require_junction_aa": lambda x: ~ak.is_none(x["junction_aa"], axis=-1),
    }
    filter = [filter_presets[f] if isinstance(f, str) else f for f in filter]

    # only warn if those fields are in the key (i.e. this should give a warning if those are missing with
    # default settings. If the user specifies their own dictionary, they are on their own)
    if "duplicate_count" in sort_chains_by and "consensus_count" in sort_chains_by:
        if "duplicate_count" not in params.airr.fields and "consensus_count" not in params.airr.fields:
            logging.warning("No expression information available. Cannot rank chains by expression. ")  # type: ignore

    if "locus" not in params.airr.fields:
        raise ValueError("The scirpy receptor model requires a `locus` field to be specified in the AIRR data.")

    airr = params.airr
    logging.info("Filtering chains...")
    # Get the numeric indices pre-filtering - these are the indices we need in the final output as
    # .obsm["airr"] is and remains unfiltered.
    airr_idx = ak.local_index(airr, axis=1)
    # Filter out chains that do not match the filter criteria
    # we need an initial value that selects all chains in case filter is an empty list
    airr_idx = airr_idx[reduce(operator.and_, (f(airr) for f in filter), ak.ones_like(airr_idx, dtype=bool))]

    # `reversed` only works on reversible views (e.g. those of `dict`); materialize the items
    # first so that any `Mapping` implementation is accepted.
    sort_keys_low_to_high_priority = list(reversed(list(sort_chains_by.items())))

    res = {}
    is_multichain = np.zeros(len(airr), dtype=bool)
    for chain_type, locus_names in {"VJ": AirrCell.VJ_LOCI, "VDJ": AirrCell.VDJ_LOCI}.items():
        logging.info(f"Indexing {chain_type} chains...")
        # get the indices for all VJ / VDJ chains, respectively
        idx = airr_idx[_awkward_isin(airr["locus"][airr_idx], locus_names)]

        # Now we need to sort the chains by the keys specified in `sort_chains_by`.
        # since `argsort` doesn't support composite keys, we take advantage of the
        # fact that the sorting algorithm is stable and sort the same array several times,
        # starting with the lowest priority key up to the highest priority key.
        for k, default in sort_keys_low_to_high_priority:
            # skip this round of sorting altogether if field not present
            if k in airr.fields:
                logging.debug(f"Sorting chains by {k}")
                tmp_idx = ak.argsort(ak.fill_none(airr[k][idx], default), stable=True, axis=-1, ascending=False)
                idx = idx[tmp_idx]
            else:
                logging.debug(f"Skip sorting by {k} because field not present")

        # We want the result to be lists of exactly 2 - clip if longer, pad with None if shorter.
        res[chain_type] = ak.pad_none(idx, 2, axis=1, clip=True)
        # More than two chains of one type surviving the filter flags the cell as multichain.
        is_multichain |= ak.to_numpy(_awkward_len(idx)) > 2

    # build results
    logging.info("build result array")
    res["multichain"] = is_multichain

    params.adata.obsm[key_added] = ak.zip(res, depth_limit=1)  # type: ignore

    # store metadata in .uns
    params.adata.uns[key_added] = {
        "model": SCIRPY_DUAL_IR_MODEL,  # can be used to distinguish different receptor models that may be added in the future.
        "filter": str(filter),
        "airr_key": airr_key,
        "sort_chains_by": str(sort_chains_by),
    }


@nb.njit
def _awkward_len_inner(arr, ab):
    """Append the length of every entry of `arr` to the builder `ab` and return the builder.

    Compiled with numba. `ab` is the `ak.ArrayBuilder` passed in by `_awkward_len`;
    `arr` is presumably a jagged awkward array with one list per cell (TODO confirm against callers).
    """
    for row in arr:
        ab.append(len(row))
    return ab


def _awkward_len(arr):
    """Return an awkward array holding `len(row)` for every row of `arr`.

    The counting itself happens in the numba-compiled `_awkward_len_inner`, which
    fills a fresh `ak.ArrayBuilder`; the builder is then frozen into an array.
    """
    filled_builder = _awkward_len_inner(arr, ak.ArrayBuilder())
    lengths = filled_builder.snapshot()
    return lengths


@nb.njit()
def _awkward_isin_inner(arr, haystack, ab):
    """Build a nested boolean mask telling, for each value in each row of `arr`, whether it is in `haystack`.

    Compiled with numba. For every row of `arr` one list is opened in the builder `ab`,
    one boolean per value is appended, and the list is closed again, so the result
    mirrors the row structure of `arr`. Returns the builder `ab`.
    """
    for row in arr:
        ab.begin_list()
        for v in row:
            ab.append(v in haystack)
        ab.end_list()
    return ab


def _awkward_isin(arr, haystack):
    """Return a boolean awkward array marking which values of `arr` occur in `haystack`.

    `haystack` is converted to a tuple before it is handed to the numba-compiled
    `_awkward_isin_inner` (presumably because numba needs a container type it can
    handle - TODO confirm). The returned mask is the snapshot of the filled builder.
    """
    candidates = tuple(haystack)
    filled_builder = _awkward_isin_inner(arr, candidates, ak.ArrayBuilder())
    return filled_builder.snapshot()


# For future reference, here would be two alternative implementations that are a bit
# slower, but work without the need for numba.
# def _awkward_len(arr):
# return ak.max(ak.local_index(arr, axis=1), axis=1)
#
# def _awkward_isin(arr, haystack):
# return reduce(operator.or_, (arr == el for el in haystack))


def _key_sort_chains(chains, sort_chains_by: Mapping[str, Any], idx: int) -> Sequence:
"""Get key to sort chains by expression.
Parameters
Expand Down
9 changes: 9 additions & 0 deletions src/scirpy/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,15 @@ def test_airr_roundtrip_conversion(anndata_from_10x_sample, tmp_path):
pdt.assert_frame_equal(anndata.obs, anndata2.obs, check_dtype=False, check_categorical=False)


def test_write_airr_none_field_issue_454(tmp_path):
    """Writing an AIRR file must not fail when the `d_sequence_end` field of a chain is None.

    Regression test (issue 454 according to the test name): a single cell with one
    empty chain whose `d_sequence_end` is explicitly None is converted to an
    AnnData object and written out; the test passes if no exception is raised.
    """
    airr_cell = AirrCell("cell1")
    record = airr_cell.empty_chain_dict()
    record["d_sequence_end"] = None
    airr_cell.add_chain(record)

    out_file = tmp_path / "test.airr.tsv"
    write_airr(from_airr_cells([airr_cell]), out_file)


@pytest.mark.extra
@pytest.mark.parametrize(
"anndata_from_10x_sample",
Expand Down
7 changes: 7 additions & 0 deletions src/scirpy/tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@ def test_clonal_expansion(adata_clonotype):
assert isinstance(p, plt.Axes)


@pytest.mark.parametrize("adata_clonotype", [True], indirect=["adata_clonotype"], ids=["MuData"])
def test_clonal_expansion_mudata_prefix(adata_clonotype):
    """Regression test for #445

    Plots clonal expansion with a `target_col` that carries the `airr:` prefix and
    checks that a matplotlib `Axes` object is returned.
    """
    axes = pl.clonal_expansion(
        adata_clonotype,
        groupby="group",
        target_col="airr:clone_id",
    )
    assert isinstance(axes, plt.Axes)


def test_alpha_diversity(adata_diversity):
p = pl.alpha_diversity(adata_diversity, groupby="group", target_col="clonotype_")
assert isinstance(p, plt.Axes)
Expand Down

0 comments on commit f19e216

Please sign in to comment.