Skip to content

Commit

Permalink
Add TSI score (#1166)
Browse files Browse the repository at this point in the history
* Add `GPCCA::tsi`

Adds function for computing terminal state identification score

* Add `GPCCA::plot_tsi`

Add class method to plot terminal state identification.

* Add `test_gpcca.py::TestGPCCA::test_tsi`

* Add `test_plotting.py::TestGPCCA::test_plot_tsi`

---------

Co-authored-by: Michal Klein <[email protected]>
  • Loading branch information
WeilerP and michalk8 authored Mar 4, 2024
1 parent 721c59f commit f5a9566
Show file tree
Hide file tree
Showing 5 changed files with 190 additions and 1 deletion.
165 changes: 164 additions & 1 deletion src/cellrank/estimators/terminal_states/_gpcca.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
import collections
import datetime
import enum
import pathlib
import types
from typing import Any, Dict, Literal, Mapping, Optional, Sequence, Tuple, Union
from pathlib import Path
from typing import Any, Dict, List, Literal, Mapping, Optional, Sequence, Tuple, Union

import numpy as np
import pandas as pd
import scipy.sparse as sp
from pandas.api.types import infer_dtype

import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.axes import Axes
from matplotlib.colorbar import ColorbarBase
from matplotlib.colors import ListedColormap, Normalize
Expand Down Expand Up @@ -86,6 +89,7 @@ def __init__(
self._coarse_init_dist: Optional[pd.Series] = None
self._coarse_stat_dist: Optional[pd.Series] = None
self._coarse_tmat: Optional[pd.DataFrame] = None
self._tsi: Optional[AnnData] = None

@property
@d.get_summary(base="gpcca_macro")
Expand Down Expand Up @@ -532,6 +536,165 @@ def set_initial_states(
)
return self

# TODO: Add definition/link to paper.
def tsi(
    self,
    n_macrostates: int,
    terminal_states: Optional[List[str]] = None,
    cluster_key: Optional[str] = None,
    **kwargs: Any,
) -> float:
    """Compute terminal state identification (TSI) score.

    For every number of macrostates from ``n_macrostates`` down to 1, the macrostates are computed and the
    number of distinct ``terminal_states`` found among them is counted. The score is the sum of these counts
    divided by the sum of the best attainable counts, ``min(number of macrostates, number of terminal states)``.

    The underlying table is cached on the estimator and reused by later calls that ask for at most as many
    macrostates and, if given, the same ``terminal_states`` and ``cluster_key``.

    Parameters
    ----------
    n_macrostates
        Maximum number of macrostates to consider. Must be at least 1.
    terminal_states
        List of terminal states. Required unless a matching cached result exists.
    cluster_key
        Key in :attr:`~anndata.AnnData.obs` defining cluster labels including terminal states.
        Required unless a matching cached result exists.
    kwargs
        Keyword arguments passed to :meth:`compute_macrostates` function.

    Returns
    -------
    Returns TSI score.

    Raises
    ------
    ValueError
        If ``n_macrostates`` is smaller than 1 or ``terminal_states`` is empty.
    RuntimeError
        If the score has to be (re)computed and ``terminal_states`` or ``cluster_key`` is not specified.
    """
    # Without at least one macrostate both sums below are zero and the ratio is undefined.
    if n_macrostates < 1:
        raise ValueError(f"Expected `n_macrostates` to be at least 1, found `{n_macrostates}`.")

    tsi_precomputed = (self._tsi is not None) and (self._tsi[:, "number_of_macrostates"].X.max() >= n_macrostates)
    if terminal_states is not None:
        tsi_precomputed = tsi_precomputed and (set(self._tsi.uns["terminal_states"]) == set(terminal_states))
    if cluster_key is not None:
        tsi_precomputed = tsi_precomputed and (self._tsi.uns["cluster_key"] == cluster_key)

    if not tsi_precomputed:
        if terminal_states is None:
            raise RuntimeError("`terminal_states` needs to be specified to compute TSI.")
        if cluster_key is None:
            raise RuntimeError("`cluster_key` needs to be specified to compute TSI.")
        # An empty collection makes the optimal score zero, i.e. the TSI would be 0 / 0.
        if len(terminal_states) == 0:
            raise ValueError("Expected `terminal_states` to contain at least one terminal state.")

        # NOTE: this recomputes the estimator's macrostates; the last call uses `n_states=1`.
        macrostates = {}
        for n_states in range(n_macrostates, 0, -1):
            self.compute_macrostates(n_states=n_states, cluster_key=cluster_key, **kwargs)
            macrostates[n_states] = self.macrostates.cat.categories

        # Identified states are counted after `drop_duplicates`, so the attainable optimum
        # must be based on distinct terminal states as well.
        max_terminal_states = len(set(terminal_states))

        tsi_df = collections.defaultdict(list)
        for n_states, states in macrostates.items():
            # Everything from the first underscore on is dropped before matching.
            # NOTE(review): presumably this strips a `_<index>` suffix from macrostate names; cluster
            # names that themselves contain `_` would be truncated and never match - confirm.
            n_terminal_states = (
                states.str.replace(r"(_).*", "", regex=True).drop_duplicates().isin(terminal_states).sum()
            )
            tsi_df["number_of_macrostates"].append(n_states)
            tsi_df["identified_terminal_states"].append(n_terminal_states)

            tsi_df["optimal_identification"].append(min(n_states, max_terminal_states))

        tsi_df = AnnData(pd.DataFrame(tsi_df), uns={"terminal_states": terminal_states, "cluster_key": cluster_key})
        self._tsi = tsi_df

    # The cache may cover more macrostates than requested; restrict to the requested range.
    tsi_df = self._tsi.to_df()
    row_mask = tsi_df["number_of_macrostates"] <= n_macrostates
    optimal_score = tsi_df.loc[row_mask, "optimal_identification"].sum()

    return tsi_df.loc[row_mask, "identified_terminal_states"].sum() / optimal_score

@d.dedent
def plot_tsi(
    self,
    n_macrostates: Optional[int] = None,
    x_offset: Tuple[float, float] = (0.2, 0.2),
    y_offset: Tuple[float, float] = (0.1, 0.1),
    figsize: Tuple[float, float] = (6, 4),
    dpi: Optional[int] = None,
    save: Optional[Union[str, Path]] = None,
    **kwargs: Any,
) -> Axes:
    """Plot terminal state identification (TSI).

    Requires computing TSI with :meth:`tsi` first.

    Parameters
    ----------
    n_macrostates
        Maximum number of macrostates to consider. Defaults to using all.
    x_offset
        Margins subtracted from the smallest and added to the largest x-value to get the x-axis limits.
    y_offset
        Margins subtracted from the smallest and added to the largest y-value to get the y-axis limits.
    %(plotting)s
    kwargs
        Keyword arguments for :func:`~seaborn.lineplot`.

    Returns
    -------
    Axes showing the identified terminal states of the kernel and of an optimal identification strategy.

    Raises
    ------
    RuntimeError
        If TSI has not been computed with :meth:`tsi` yet.
    """
    if self._tsi is None:
        raise RuntimeError("Compute TSI with `tsi` first as `.tsi()`.")

    tsi_df = self._tsi.to_df()
    if n_macrostates is not None:
        tsi_df = tsi_df.loc[tsi_df["number_of_macrostates"] <= n_macrostates, :]

    # `rename` and `assign` return new frames, so no column is ever set on a slice of `tsi_df`
    # (the pattern that triggers pandas' SettingWithCopyWarning).
    optimal_identification = (
        tsi_df[["number_of_macrostates", "optimal_identification"]]
        .rename(columns={"optimal_identification": "identified_terminal_states"})
        .assign(method="Optimal identification", line_style="--")
    )
    df = tsi_df[["number_of_macrostates", "identified_terminal_states"]].assign(
        method=self.kernel.__class__.__name__,
        line_style="-",
    )

    # Long format: one row per (number of macrostates, method).
    df = pd.concat([df, optimal_identification])

    fig, ax = plt.subplots(figsize=figsize, dpi=dpi, tight_layout=True)
    sns.lineplot(
        data=df,
        x="number_of_macrostates",
        y="identified_terminal_states",
        hue="method",
        style="line_style",
        drawstyle="steps-post",
        ax=ax,
        **kwargs,
    )

    ax.set_xticks(df["number_of_macrostates"].unique().astype(int))
    # Plot is generated from large to small values on the x-axis
    # Keep the first label of the reversed list and every fifth one; hide the rest.
    for label_id, label in enumerate(ax.xaxis.get_ticklabels()[::-1]):
        if ((label_id + 1) % 5 != 0) and label_id != 0:
            label.set_visible(False)
    ax.set_yticks(df["identified_terminal_states"].unique())

    x_min = df["number_of_macrostates"].min() - x_offset[0]
    x_max = df["number_of_macrostates"].max() + x_offset[1]
    y_min = df["identified_terminal_states"].min() - y_offset[0]
    y_max = df["identified_terminal_states"].max() + y_offset[1]
    ax.set(
        xlim=[x_min, x_max],
        ylim=[y_min, y_max],
        xlabel="Number of macrostates",
        ylabel="Identified terminal states",
    )

    # Replace the axes legend with a figure-level legend placed below the plot.
    ax.get_legend().remove()

    n_methods = df["method"].nunique()
    handles, labels = ax.get_legend_handles_labels()
    # NOTE(review): assumes the first entry is the `hue` title followed by one handle per method,
    # so index `n_methods` is the last method ("Optimal identification") - confirm against seaborn.
    handles[n_methods].set_linestyle("--")
    handles = handles[: (n_methods + 1)]
    labels = labels[: (n_methods + 1)]
    labels[0] = "Method"
    fig.legend(handles=handles, labels=labels, loc="lower center", ncol=(n_methods + 1), bbox_to_anchor=(0.5, -0.1))

    if save is not None:
        save_fig(fig=fig, path=save)

    return ax

@d.dedent
def fit(
self,
Expand Down
Binary file modified tests/_ground_truth_adatas/adata_200.h5ad
Binary file not shown.
Binary file added tests/_ground_truth_figures/plot_tsi.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
18 changes: 18 additions & 0 deletions tests/test_gpcca.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,6 +828,24 @@ def test_plot_lineage_drivers_normal_run(self, adata_large: AnnData):

mc.plot_lineage_drivers("0", use_raw=False)

def test_tsi(self, adata_large: AnnData):
    """Check the TSI score and the cached TSI table against the ground truth stored in ``adata_large.uns["tsi"]``."""
    expected = adata_large.uns["tsi"].copy()

    kernel = VelocityKernel(adata_large).compute_transition_matrix()
    gpcca = cr.estimators.GPCCA(kernel)
    gpcca.compute_schur(n_components=5)

    score = gpcca.tsi(
        n_macrostates=3,
        terminal_states=["Neuroblast", "Astrocyte", "Granule mature"],
        cluster_key="clusters",
        n_cells=10,
    )

    # Returned score.
    np.testing.assert_almost_equal(score, expected.uns["score"])

    # Metadata cached alongside the table.
    cached_terminal_states = gpcca._tsi.uns["terminal_states"]
    expected_terminal_states = expected.uns["terminal_states"]
    assert isinstance(cached_terminal_states, list)
    assert len(cached_terminal_states) == len(expected_terminal_states)
    assert (cached_terminal_states == expected_terminal_states).all()
    assert gpcca._tsi.uns["cluster_key"] == expected.uns["cluster_key"]

    # Cached table itself.
    pd.testing.assert_frame_equal(gpcca._tsi.to_df(), expected.to_df())

def test_compute_priming_clusters(self, adata_large: AnnData):
vk = VelocityKernel(adata_large).compute_transition_matrix(softmax_scale=4)
ck = ConnectivityKernel(adata_large).compute_transition_matrix()
Expand Down
8 changes: 8 additions & 0 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2257,6 +2257,14 @@ def test_scvelo_transition_matrix_projection(self, mc: GPCCA, fpath: str):
save=fpath,
)

@compare(kind="gpcca")
def test_plot_tsi(self, mc: GPCCA, fpath: str):
    """Compute TSI on a deep copy of the estimator and save the resulting plot to ``fpath``."""
    estimator = mc.copy(deep=True)
    # The score itself is not needed here; the call only populates the table that `plot_tsi` reads.
    estimator.tsi(
        n_macrostates=3,
        terminal_states=["Neuroblast", "Astrocyte", "Granule mature"],
        cluster_key="clusters",
        n_cells=10,
    )
    estimator.plot_tsi(dpi=DPI, save=fpath)


class TestLineage:
@compare(kind="lineage")
Expand Down

0 comments on commit f5a9566

Please sign in to comment.