Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace clip_at with breakpoints in clonal_expansion. #439

Merged
merged 12 commits into from
Nov 11, 2023
7 changes: 6 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ and this project adheres to [Semantic Versioning][].
[keep a changelog]: https://keepachangelog.com/en/1.0.0/
[semantic versioning]: https://semver.org/spec/v2.0.0.html

## [Unreleased]
## v0.14.0

### Breaking changes

Expand All @@ -19,6 +19,11 @@ and this project adheres to [Semantic Versioning][].
`lambda x: ~ak.is_none(x["junction_aa"], axis=-1)`. To learn more about native awkward array functions, please
refer to the [awkward array documentation](https://awkward-array.org/doc/main/reference/index.html). ([#444](https://github.com/scverse/scirpy/pull/444))

### Additions

- The `clonal_expansion` function now supports a `breakpoints` argument for more flexible "expansion categories".
The `breakpoints` argument supersedes the `clip_at` parameter, which is now deprecated. ([#439](https://github.com/scverse/scirpy/pull/439))

### Fixes

- Fix that `define_clonotype_clusters` could not retreive `within_group` columns from MuData ([#459](https://github.com/scverse/scirpy/pull/459))
Expand Down
368 changes: 196 additions & 172 deletions docs/tutorials/tutorial_3k_tcr.ipynb

Large diffs are not rendered by default.

24 changes: 18 additions & 6 deletions src/scirpy/pl/_clonal_expansion.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Literal, Union
from collections.abc import Sequence
from typing import Literal, Optional, Union

from scirpy import tl
from scirpy.util import DataHandler
Expand All @@ -12,8 +13,9 @@ def clonal_expansion(
groupby: str,
*,
target_col: str = "clone_id",
clip_at: int = 3,
expanded_in: Union[str, None] = None,
breakpoints: Sequence[int] = (1, 2),
clip_at: Optional[int] = None,
summarize_by: Literal["cell", "clone_id"] = "cell",
normalize: bool = True,
show_nonexpanded: bool = True,
Expand All @@ -39,14 +41,23 @@ def clonal_expansion(
Group by this categorical variable in `adata.obs`.
target_col
Column in `adata.obs` containing the clonotype information.
clip_at
All entries in `target_col` with more copies than `clip_at`
will be summarized into a single group.
expanded_in
Calculate clonal expansion within groups. To calculate expansion
within patients, set this to the column containing patient annotation.
If set to None, a clonotype counts as expanded if there's any cell of the
same clonotype across the entire dataset. See also :term:`Public clonotype`.
breakpoints
summarize clonotypes with a size smaller or equal than the specified numbers
into groups. For instance, if this is (1, 2, 5), there will be four categories:

* all clonotypes with a size of 1 (singletons)
* all clonotypes with a size of 2
* all clonotypes with a size between 3 and 5 (inclusive)
* all clonotypes with a size > 5
clip_at
This argument is superseded by `breakpoints` and is only kept for backwards-compatibility.
Specifying a value of `clip_at = N` equals to specifying `breakpoints = (1, 2, 3, ..., N)`
Specifying both `clip_at` overrides `breakpoints`.
summarize_by
Can be either `cell` to count cells belonging to a clonotype (the default),
or `clone_id` to count clonotypes. The former leads to a over-representation
Expand All @@ -70,9 +81,10 @@ def clonal_expansion(
summarize_by=summarize_by,
normalize=normalize,
expanded_in=expanded_in,
breakpoints=breakpoints,
clip_at=clip_at,
)
if not show_nonexpanded:
plot_df.drop("1", axis="columns", inplace=True)
plot_df.drop("<= 1", axis="columns", inplace=True)

return {"bar": base.bar, "barh": base.barh}[viztype](plot_df, **kwargs)
Binary file not shown.
Binary file not shown.
7 changes: 7 additions & 0 deletions src/scirpy/tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@ def test_clonal_expansion(adata_clonotype):
assert isinstance(p, plt.Axes)


@pytest.mark.parametrize("adata_clonotype", [True], indirect=["adata_clonotype"], ids=["MuData"])
def test_clonal_expansion_mudata_prefix(adata_clonotype):
"""Regression test for #445"""
p = pl.clonal_expansion(adata_clonotype, groupby="group", target_col="airr:clone_id")
assert isinstance(p, plt.Axes)


def test_alpha_diversity(adata_diversity):
p = pl.alpha_diversity(adata_diversity, groupby="group", target_col="clonotype_")
assert isinstance(p, plt.Axes)
Expand Down
77 changes: 28 additions & 49 deletions src/scirpy/tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pytest
import scanpy as sc
from mudata import MuData
from pytest import approx

import scirpy as ir
from scirpy.util import DataHandler
Expand Down Expand Up @@ -106,13 +107,13 @@ def test_clip_and_count_clonotypes(adata_clonotype):
adata = adata_clonotype

res = ir.tl._clonal_expansion._clip_and_count(
adata, groupby="group", target_col="clone_id", clip_at=2, inplace=False
adata, groupby="group", target_col="clone_id", breakpoints=(1,), inplace=False
)
npt.assert_equal(res, np.array([">= 2"] * 3 + ["nan"] * 2 + ["1"] * 3 + [">= 2"] * 2))
npt.assert_equal(res, np.array(["> 1"] * 3 + ["nan"] * 2 + ["<= 1"] * 3 + ["> 1"] * 2))

# check without group
res = ir.tl._clonal_expansion._clip_and_count(adata, target_col="clone_id", clip_at=5, inplace=False)
npt.assert_equal(res, np.array(["4"] * 3 + ["nan"] * 2 + ["4"] + ["1"] * 2 + ["2"] * 2))
res = ir.tl._clonal_expansion._clip_and_count(adata, target_col="clone_id", breakpoints=(1, 2, 4), inplace=False)
npt.assert_equal(res, np.array(["<= 4"] * 3 + ["nan"] * 2 + ["<= 4"] + ["<= 1"] * 2 + ["<= 2"] * 2))

# check if target_col works
params = DataHandler.default(adata)
Expand All @@ -123,45 +124,35 @@ def test_clip_and_count_clonotypes(adata_clonotype):
adata,
groupby="group",
target_col="new_col",
clip_at=2,
breakpoints=(1,),
)
npt.assert_equal(
params.adata.obs["new_col_clipped_count"],
np.array([">= 2"] * 3 + ["nan"] * 2 + ["1"] * 3 + [">= 2"] * 2),
np.array(["> 1"] * 3 + ["nan"] * 2 + ["<= 1"] * 3 + ["> 1"] * 2),
)

# check if it raises value error if target_col does not exist
with pytest.raises(ValueError):
ir.tl._clonal_expansion._clip_and_count(
adata,
groupby="group",
target_col="clone_id",
clip_at=2,
fraction=False,
)


@pytest.mark.parametrize(
"expanded_in,expected",
[
("group", [">= 2"] * 3 + ["nan"] * 2 + ["1"] * 3 + [">= 2"] * 2),
(None, [">= 2"] * 3 + ["nan"] * 2 + [">= 2"] + ["1"] * 2 + [">= 2"] * 2),
("group", ["> 1"] * 3 + ["nan"] * 2 + ["<= 1"] * 3 + ["> 1"] * 2),
(None, ["> 1"] * 3 + ["nan"] * 2 + ["> 1"] + ["<= 1"] * 2 + ["> 1"] * 2),
],
)
def test_clonal_expansion(adata_clonotype, expanded_in, expected):
res = ir.tl.clonal_expansion(adata_clonotype, expanded_in=expanded_in, clip_at=2, inplace=False)
res = ir.tl.clonal_expansion(adata_clonotype, expanded_in=expanded_in, breakpoints=(1,), inplace=False)
npt.assert_equal(res, np.array(expected))


def test_clonal_expansion_summary(adata_clonotype):
res = ir.tl.summarize_clonal_expansion(adata_clonotype, "group", target_col="clone_id", clip_at=2, normalize=True)
pdt.assert_frame_equal(
res,
pd.DataFrame.from_dict({"group": ["A", "B"], "1": [0, 2 / 5], ">= 2": [1.0, 3 / 5]}).set_index("group"),
check_names=False,
check_index_type=False,
check_categorical=False,
res = ir.tl.summarize_clonal_expansion(
adata_clonotype, "group", target_col="clone_id", breakpoints=(1,), normalize=True
)
assert res.reset_index().to_dict(orient="list") == {
"group": ["A", "B"],
"<= 1": [0, approx(0.4)],
"> 1": [1.0, approx(0.6)],
}

# test the `expanded_in` parameter.
res = ir.tl.summarize_clonal_expansion(
Expand All @@ -172,13 +163,11 @@ def test_clonal_expansion_summary(adata_clonotype):
normalize=True,
expanded_in="group",
)
pdt.assert_frame_equal(
res,
pd.DataFrame.from_dict({"group": ["A", "B"], "1": [0, 3 / 5], ">= 2": [1.0, 2 / 5]}).set_index("group"),
check_names=False,
check_index_type=False,
check_categorical=False,
)
assert res.reset_index().to_dict(orient="list") == {
"group": ["A", "B"],
"<= 1": [0, approx(0.6)],
"> 1": [1.0, approx(0.4)],
}

# test the `summarize_by` parameter.
res = ir.tl.summarize_clonal_expansion(
Expand All @@ -189,26 +178,16 @@ def test_clonal_expansion_summary(adata_clonotype):
normalize=True,
summarize_by="clone_id",
)
pdt.assert_frame_equal(
res,
pd.DataFrame.from_dict({"group": ["A", "B"], "1": [0, 2 / 4], ">= 2": [1.0, 2 / 4]}).set_index("group"),
check_names=False,
check_index_type=False,
check_categorical=False,
)
assert res.reset_index().to_dict(orient="list") == {
"group": ["A", "B"],
"<= 1": [0, approx(0.5)],
"> 1": [1.0, approx(0.5)],
}

res_counts = ir.tl.summarize_clonal_expansion(
adata_clonotype, "group", target_col="clone_id", clip_at=2, normalize=False
)
print(res_counts)
pdt.assert_frame_equal(
res_counts,
pd.DataFrame.from_dict({"group": ["A", "B"], "1": [0, 2], ">= 2": [3, 3]}).set_index("group"),
check_names=False,
check_dtype=False,
check_index_type=False,
check_categorical=False,
)
assert res_counts.reset_index().to_dict(orient="list") == {"group": ["A", "B"], "<= 1": [0, 2], "> 1": [3, 3]}


@pytest.mark.extra
Expand Down
59 changes: 42 additions & 17 deletions src/scirpy/tl/_clonal_expansion.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from typing import Literal, Union
import warnings
from collections.abc import Sequence
from typing import Literal, Optional, Union

import numpy as np
import pandas as pd

from scirpy.util import DataHandler, _is_na, _normalize_counts
Expand All @@ -10,10 +13,9 @@
target_col: str,
*,
groupby: Union[str, None, list[str]] = None,
clip_at: int = 3,
breakpoints: Sequence[int] = (1, 2, 3),
inplace: bool = True,
key_added: Union[str, None] = None,
fraction: bool = True,
airr_mod="airr",
) -> Union[None, pd.Series]:
"""Counts the number of identical entries in `target_col`
Expand All @@ -22,22 +24,32 @@
`nan`s in the input remain `nan` in the output.
"""
params = DataHandler(adata, airr_mod)
if target_col not in params.adata.obs.columns:
raise ValueError("`target_col` not found in obs.")
if not len(breakpoints):
raise ValueError("Need to specify at least one breakpoint.")

Check warning on line 28 in src/scirpy/tl/_clonal_expansion.py

View check run for this annotation

Codecov / codecov/patch

src/scirpy/tl/_clonal_expansion.py#L28

Added line #L28 was not covered by tests

categories = [f"<= {b}" for b in breakpoints] + [f"> {breakpoints[-1]}", "nan"]

@np.vectorize
def _get_interval(value: int) -> str:
"""Return the interval of `value`, given breakpoints."""
for b in breakpoints:
if value <= b:
return f"<= {b}"
return f"> {b}"

groupby = [groupby] if isinstance(groupby, str) else groupby
groupby_cols = [target_col] if groupby is None else groupby + [target_col]
obs = params.get_obs(groupby_cols)

clonotype_counts = (
params.adata.obs.groupby(groupby_cols, observed=True)
obs.groupby(groupby_cols, observed=True)
.size()
.reset_index(name="tmp_count")
.assign(
tmp_count=lambda X: [f">= {min(n, clip_at)}" if n >= clip_at else str(n) for n in X["tmp_count"].values]
)
.assign(tmp_count=lambda X: pd.Categorical(_get_interval(X["tmp_count"].values), categories=categories))
)
clipped_count = params.adata.obs.merge(clonotype_counts, how="left", on=groupby_cols)["tmp_count"]
clipped_count[_is_na(params.adata.obs[target_col])] = "nan"
clipped_count.index = params.adata.obs.index
clipped_count = obs.merge(clonotype_counts, how="left", on=groupby_cols)["tmp_count"]
clipped_count[_is_na(obs[target_col])] = "nan"
clipped_count.index = obs.index

if inplace:
key_added = f"{target_col}_clipped_count" if key_added is None else key_added
Expand All @@ -52,7 +64,8 @@
*,
target_col: str = "clone_id",
expanded_in: Union[str, None] = None,
clip_at: int = 3,
breakpoints: Sequence[int] = (1, 2),
clip_at: Optional[int] = None,
key_added: str = "clonal_expansion",
inplace: bool = True,
**kwargs,
Expand All @@ -72,9 +85,18 @@
this to the column containing sample annotation. If set to None,
a clonotype counts as expanded if there's any cell of the same clonotype
across the entire dataset.
clip_at:
All clonotypes with more than `clip_at` clones will be summarized into
a single category
breakpoints
summarize clonotypes with a size smaller or equal than the specified numbers
into groups. For instance, if this is (1, 2, 5), there will be four categories:

* all clonotypes with a size of 1 (singletons)
* all clonotypes with a size of 2
* all clonotypes with a size between 3 and 5 (inclusive)
* all clonotypes with a size > 5
clip_at
This argument is superseded by `breakpoints` and is only kept for backwards-compatibility.
Specifying a value of `clip_at = N` equals to specifying `breakpoints = (1, 2, 3, ..., N)`
Specifying both `clip_at` overrides `breakpoints`.
{key_added}
{inplace}
{airr_mod}
Expand All @@ -84,11 +106,14 @@
Depending on the value of inplace, adds a column to adata or returns
a Series with the clipped count per cell.
"""
if clip_at is not None:
breakpoints = list(range(1, clip_at))
warnings.warn("The argument `clip_at` is deprecated. Please use `brekpoints` instead.", category=FutureWarning)
return _clip_and_count(
adata,
target_col,
groupby=expanded_in,
clip_at=clip_at,
breakpoints=breakpoints,
key_added=key_added,
inplace=inplace,
**kwargs,
Expand Down
Loading