Improve parallelism of ir_dist #473

Merged · 9 commits · Jan 19, 2024
CHANGELOG.md (6 changes: 5 additions & 1 deletion)
@@ -13,14 +13,18 @@ and this project adheres to [Semantic Versioning][].
### Fixes

- Fix incompatibility with `adjustText` 1.0 ([#477](https://github.com/scverse/scirpy/pull/477))
+ - Reduce overall import time by deferring the import of the `airr` package until it is actually used. ([#473](https://github.com/scverse/scirpy/pull/473))

### New features

- Speed up alignment distances by pre-filtering. There are two filtering strategies: A (lossless) length-based filter
and a heuristic based on the expected penalty per mismatch. This is implemented in the `FastAlignmentDistanceCalculator`
class which supersedes the `AlignmentDistanceCalculator` class, which is now deprecated. Using the `"alignment"` metric
in `pp.ir_dist` now uses the `FastAlignmentDistanceCalculator` with only the length-based filter activated.
- Using the `"fastalignment"` activates the heuristic, which is significantly faster, but results in some false-negatives.
+ Using the `"fastalignment"` activates the heuristic, which is significantly faster, but results in some false-negatives. ([#456](https://github.com/scverse/scirpy/pull/456))
+ - Switch to [joblib/loky](https://joblib.readthedocs.io/en/latest/generated/joblib.Parallel.html) as a backend for parallel
+ processing in `pp.ir_dist`. Joblib makes it possible to switch to alternative backends that support out-of-machine computing
+ (e.g. `dask`, `ray`) via the `parallel_config` context manager. Additionally, chunk sizes are now adjusted dynamically based on the problem size. ([#473](https://github.com/scverse/scirpy/pull/473))

### Documentation

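As a usage illustration of the joblib entry above, switching `pp.ir_dist` to another backend could look like the following minimal sketch (the example dataset and parameter values are illustrative, not part of this PR):

```python
import joblib
import scirpy as ir

adata = ir.datasets.wu2020_3k()  # small example dataset shipped with scirpy

# Default behavior: loky-based multiprocessing on all CPUs (n_jobs=-1).
ir.pp.ir_dist(adata, metric="identity", sequence="aa")

# Route the same computation through a different joblib backend:
with joblib.parallel_config(backend="threading", n_jobs=4):
    ir.pp.ir_dist(adata, metric="levenshtein", cutoff=2, sequence="aa")
```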
docs/conf.py (1 change: 1 addition & 0 deletions)
@@ -115,6 +115,7 @@
"mudata": ("https://mudata.readthedocs.io/en/latest/", None),
"awkward": ("https://awkward-array.org/doc/main/", None),
"pooch": ("https://www.fatiando.org/pooch/latest/", None),
"joblib": ("https://joblib.readthedocs.io/en/latest", None),
}

# List of patterns, relative to source directory, that match files and
pyproject.toml (1 change: 1 addition & 0 deletions)
@@ -41,6 +41,7 @@ dependencies = [
'numba>=0.41.0',
'pooch>=1.7.0',
'pycairo>=1.20; sys_platform == "win32"',
+ 'joblib>=1.3.1',
]

[project.optional-dependencies]
src/scirpy/io/_datastructures.py (9 changes: 5 additions & 4 deletions)
@@ -9,10 +9,11 @@
from typing import Any

import scanpy
- from airr import RearrangementSchema

from scirpy.util import _is_na2

+ from ._util import get_rearrangement_schema


class AirrCell(MutableMapping):
"""Data structure for a Cell with immune receptors. Represents one row of `adata.obsm["airr"]`.
@@ -126,8 +127,8 @@ def add_chain(self, chain: Mapping) -> None:
# sanitize NA values
chain = {k: None if _is_na2(v) else v for k, v in chain.items()}

- RearrangementSchema.validate_header(chain.keys())
- RearrangementSchema.validate_row(chain)
+ get_rearrangement_schema().validate_header(chain.keys())
+ get_rearrangement_schema().validate_row(chain)

for tmp_field in self._cell_attribute_fields:
# It is ok if a field specified as cell attribute is not present in the chain
@@ -182,4 +183,4 @@ def empty_chain_dict() -> dict:
"""Generate an empty chain dictionary, containing all required AIRR
columns, but set to `None`
"""
- return {field: None for field in RearrangementSchema.required}
+ return {field: None for field in get_rearrangement_schema().required}
src/scirpy/io/_io.py (16 changes: 10 additions & 6 deletions)
@@ -8,18 +8,16 @@
from pathlib import Path
from typing import Any, Literal, Union

- import airr
import numpy as np
import pandas as pd
- from airr import RearrangementSchema
from anndata import AnnData

from scirpy.util import DataHandler, _doc_params, _is_na2, _is_true, _is_true2, _translate_dna_to_protein

from . import _tracerlib
from ._convert_anndata import from_airr_cells, to_airr_cells
from ._datastructures import AirrCell
- from ._util import _IOLogger, _read_airr_rearrangement_df, doc_working_model
+ from ._util import _IOLogger, _read_airr_rearrangement_df, doc_working_model, get_rearrangement_schema

# patch sys.modules to enable pickle import.
# see https://stackoverflow.com/questions/2121874/python-pckling-after-changing-a-modules-directory
@@ -402,6 +400,9 @@ def read_airr(
AnnData object with :term:`AIRR` data in `obsm["airr"]` for each cell. For more details see
:ref:`data-structure`.
"""
+ # defer import, as this is very slow
+ import airr

airr_cells = {}
logger = _IOLogger()

@@ -426,7 +427,7 @@ def _decide_use_umi_count_col(chain_dict):

for chain_dict in iterator:
cell_id = chain_dict.pop("cell_id")
- chain_dict.update({req: None for req in RearrangementSchema.required if req not in chain_dict})
+ chain_dict.update({req: None for req in get_rearrangement_schema().required if req not in chain_dict})
try:
tmp_cell = airr_cells[cell_id]
except KeyError:
@@ -438,7 +439,7 @@ def _decide_use_umi_count_col(chain_dict):
airr_cells[cell_id] = tmp_cell

if _decide_use_umi_count_col(chain_dict):
chain_dict["duplicate_count"] = RearrangementSchema.to_int(chain_dict.pop("umi_count"))
chain_dict["duplicate_count"] = get_rearrangement_schema().to_int(chain_dict.pop("umi_count"))

if infer_locus and "locus" not in chain_dict:
logger.warning(
@@ -571,6 +572,9 @@ def write_airr(adata: DataHandler.TYPE, filename: Union[str, Path], **kwargs) ->
**kwargs
additional arguments passed to :func:`~scirpy.io.to_airr_cells`
"""
+ # defer import, as this is very slow
+ import airr

airr_cells = to_airr_cells(adata, **kwargs)
try:
fields = airr_cells[0].fields
@@ -585,7 +589,7 @@ def write_airr(adata: DataHandler.TYPE, filename: Union[str, Path], **kwargs) ->
for chain in tmp_cell.to_airr_records():
# workaround for AIRR library writing out int field as floats (if it happens to be a float)
for field, value in chain.items():
- if RearrangementSchema.type(field) == "integer" and value is not None:
+ if airr.RearrangementSchema.type(field) == "integer" and value is not None:
chain[field] = int(value)
writer.write(chain)
writer.close()
src/scirpy/io/_util.py (17 changes: 15 additions & 2 deletions)
@@ -3,7 +3,6 @@
from collections import Counter

import pandas as pd
- from airr import RearrangementReader
from scanpy import logging

doc_working_model = """\
@@ -18,6 +17,20 @@
"""


+ def get_rearrangement_reader():
+     """Defer importing from airr package until it is used, since this is very slow"""
+     from airr import RearrangementReader
+
+     return RearrangementReader
+
+
+ def get_rearrangement_schema():
+     """Defer importing from airr package until it is used, since this is very slow"""
+     from airr import RearrangementSchema
+
+     return RearrangementSchema


class _IOLogger:
"""Logger wrapper that prints identical messages only once"""

@@ -50,7 +63,7 @@ def fieldnames(self):
def __next__(self):
return next(self.reader)

- class PdRearrangementReader(RearrangementReader):
+ class PdRearrangementReader(get_rearrangement_reader()):
def __init__(self, df, *args, **kwargs):
super().__init__(os.devnull, *args, **kwargs)
self.dict_reader = PdDictReader(df)
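A quick way to see the effect of these deferred imports is to time the imports directly; a minimal sketch (absolute timings vary by machine and are not from this PR):

```python
import time

start = time.perf_counter()
import scirpy  # with this PR, importing scirpy no longer pulls in `airr`
print(f"import scirpy: {time.perf_counter() - start:.2f}s")

start = time.perf_counter()
import airr  # the slow import now happens here, on first use, instead
print(f"import airr: {time.perf_counter() - start:.2f}s")
```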
src/scirpy/ir_dist/__init__.py (10 changes: 6 additions & 4 deletions)
@@ -72,7 +72,7 @@ def _get_metric_key(metric: MetricType) -> str:
return "custom" if isinstance(metric, metrics.DistanceCalculator) else metric # type: ignore


- def _get_distance_calculator(metric: MetricType, cutoff: Union[int, None], *, n_jobs=None, **kwargs):
+ def _get_distance_calculator(metric: MetricType, cutoff: Union[int, None], *, n_jobs=-1, **kwargs):
"""Returns an instance of :class:`~scirpy.ir_dist.metrics.DistanceCalculator`
given a metric.

@@ -114,7 +114,7 @@ def _ir_dist(
sequence: Literal["aa", "nt"] = "nt",
key_added: Union[str, None] = None,
inplace: bool = True,
- n_jobs: Union[int, None] = None,
+ n_jobs: Union[int, None] = -1,
airr_mod: str = "airr",
airr_key: str = "airr",
chain_idx_key: str = "chain_indices",
@@ -158,7 +158,9 @@
with the results.
n_jobs
Number of cores to use for distance calculation. Passed on to
- :class:`scirpy.ir_dist.metrics.DistanceCalculator`.
+ :class:`scirpy.ir_dist.metrics.DistanceCalculator`. :class:`joblib.Parallel` is
+ used internally. Via the :class:`joblib.parallel_config` context manager, you can set another
+ backend (e.g. `dask`) and adjust other configuration options.
{airr_mod}
{airr_key}
{chain_idx_key}
@@ -245,7 +247,7 @@ def sequence_dist(
*,
metric: MetricType = "identity",
cutoff: Union[None, int] = None,
- n_jobs: Union[None, int] = None,
+ n_jobs: Union[None, int] = -1,
**kwargs,
) -> csr_matrix:
"""
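With the new default of `n_jobs=-1`, joblib uses all available CPUs unless configured otherwise. A minimal sketch with made-up CDR3 sequences:

```python
import numpy as np
import scirpy as ir

seqs = np.array(["CASSLGQGAEAFF", "CASSLGQGSEAFF", "CASSIRSSYEQYF"])

# Returns a sparse pairwise distance matrix; n_jobs=-1 is now the default.
dist = ir.ir_dist.sequence_dist(seqs, metric="levenshtein", cutoff=2, n_jobs=-1)
print(dist.toarray())
```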
src/scirpy/ir_dist/metrics.py (72 changes: 49 additions & 23 deletions)
@@ -1,26 +1,28 @@
import abc
import itertools
+ import warnings
from collections.abc import Sequence
- from multiprocessing import cpu_count
from typing import Optional, Union

+ import joblib
import numpy as np
import scipy.sparse
import scipy.spatial
from Levenshtein import distance as levenshtein_dist
from Levenshtein import hamming as hamming_dist
from scanpy import logging
from scipy.sparse import coo_matrix, csr_matrix
- from tqdm.contrib.concurrent import process_map

- from scirpy.util import _doc_params, deprecated, tqdm
+ from scirpy.util import _doc_params, _parallelize_with_joblib, deprecated

_doc_params_parallel_distance_calculator = """\
n_jobs
-     Number of jobs to use for the pairwise distance calculation.
-     If None, use all jobs (only for ParallelDistanceCalculators).
+     Number of jobs to use for the pairwise distance calculation, passed to
+     :class:`joblib.Parallel`. If -1, use all CPUs (only for ParallelDistanceCalculators).
+     Via the :class:`joblib.parallel_config` context manager, another backend (e.g. `dask`)
+     can be selected.
block_size
-     The width of a block of the matrix that will be delegated to a worker
-     process. The block contains `block_size ** 2` elements.
+     Deprecated. This is now set in `calc_dist_mat`.
"""


@@ -113,12 +115,16 @@
self,
cutoff: int,
*,
-     n_jobs: Optional[int] = None,
-     block_size: Optional[int] = 50,
+     n_jobs: Optional[int] = -1,
+     block_size: Optional[int] = None,
):
super().__init__(cutoff)
self.n_jobs = n_jobs
-     self.block_size = block_size
+     if block_size is not None:
+         warnings.warn(
"The `block_size` parameter is now set in the `calc_dist_mat` function instead of the object level. It is ignored here.",
category=FutureWarning,
)

@abc.abstractmethod
def _compute_block(
@@ -189,21 +195,41 @@
else:
yield seqs1[row : row + block_size], seqs2[col : col + block_size], (row, col)

- def calc_dist_mat(self, seqs: Sequence[str], seqs2: Optional[Sequence[str]] = None) -> csr_matrix:
+ def calc_dist_mat(
+     self, seqs: Sequence[str], seqs2: Optional[Sequence[str]] = None, *, block_size: Optional[int] = None
+ ) -> csr_matrix:
"""Calculate the distance matrix.

See :meth:`DistanceCalculator.calc_dist_mat`.

+     Parameters
+     ----------
+     seqs
+         array containing CDR3 sequences. Must not contain duplicates.
+     seqs2
+         second array containing CDR3 sequences. Must not contain
+         duplicates either.
+     block_size
+         The width of a block that's sent to a worker. A block contains
+         `block_size ** 2` elements. If `None` the block
+         size is determined automatically based on the problem size.
+
Returns
-------
Sparse pairwise distance matrix.
"""
+     problem_size = len(seqs) * len(seqs2) if seqs2 is not None else len(seqs) ** 2
+     # dynamically adjust the block size such that there are ~1000 blocks within a range of 50 and 5000
+     block_size = int(np.ceil(min(max(np.sqrt(problem_size / 1000), 50), 5000)))
+     logging.info(f"block size set to {block_size}")

# precompute blocks as list to have total number of blocks for progressbar
-     blocks = list(self._block_iter(seqs, seqs2, self.block_size))
-
-     block_results = process_map(
-         self._compute_block,
-         *zip(*blocks),
-         max_workers=self.n_jobs if self.n_jobs is not None else cpu_count(),
-         chunksize=50,
-         tqdm_class=tqdm,
-         total=len(blocks),
-     )
+     blocks = list(self._block_iter(seqs, seqs2, block_size=block_size))
+
+     block_results = _parallelize_with_joblib(
+         (joblib.delayed(self._compute_block)(*block) for block in blocks), total=len(blocks), n_jobs=self.n_jobs
+     )

try:
@@ -418,8 +444,8 @@
self,
cutoff: int = 10,
*,
-     n_jobs: Union[int, None] = None,
-     block_size: int = 50,
+     n_jobs: Union[int, None] = -1,
+     block_size: Optional[int] = None,
subst_mat: str = "blosum62",
gap_open: int = 11,
gap_extend: int = 11,
@@ -555,7 +581,7 @@
cutoff: int = 10,
*,
n_jobs: Union[int, None] = None,
-     block_size: int = 50,
+     block_size: Optional[int] = None,
subst_mat: str = "blosum62",
gap_open: int = 11,
gap_extend: int = 11,
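The block-size heuristic introduced in `calc_dist_mat` can be reproduced in isolation. In this sketch, `dynamic_block_size` is a made-up name that mirrors the formula above:

```python
import numpy as np

def dynamic_block_size(n_seqs1: int, n_seqs2: int) -> int:
    """Aim for ~1000 blocks, clamping the block width to [50, 5000]."""
    problem_size = n_seqs1 * n_seqs2
    return int(np.ceil(min(max(np.sqrt(problem_size / 1000), 50), 5000)))

print(dynamic_block_size(100, 100))        # 50: small problems hit the lower bound
print(dynamic_block_size(50_000, 50_000))  # 1582: ~1000 blocks of 1582**2 elements
print(dynamic_block_size(10**7, 10**7))    # 5000: huge problems hit the upper bound
```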
src/scirpy/tests/test_ir_dist.py (6 changes: 5 additions & 1 deletion)
@@ -1,3 +1,4 @@
+ import joblib
import numpy as np
import numpy.testing as npt
import pandas as pd
@@ -163,6 +164,7 @@ def test_ir_dist(


@pytest.mark.parametrize("with_adata2", [False, True])
@pytest.mark.parametrize("joblib_backend", ["loky", "multiprocessing", "threading"])
@pytest.mark.parametrize("n_jobs", [1, 2])
@pytest.mark.parametrize(
"comment,metric,ctn_kwargs,expected_clonotype_df,expected_dist",
@@ -340,13 +342,15 @@ def test_compute_distances(
expected_dist,
n_jobs,
with_adata2,
+ joblib_backend,
):
"""Test that distances are calculated correctly with different settings"""
distance_key = f"ir_dist_aa_{metric}"
metric = adata_cdr3_mock_distance_calculator if metric == "custom" else metric
adata2 = adata_cdr3 if with_adata2 else None
expected_dist = np.array(expected_dist)
-     ir.pp.ir_dist(adata_cdr3, adata2, metric=metric, sequence="aa", key_added=distance_key)
+     with joblib.parallel_config(backend=joblib_backend):
+         ir.pp.ir_dist(adata_cdr3, adata2, metric=metric, sequence="aa", key_added=distance_key)
cn = ClonotypeNeighbors(
DataHandler.default(adata_cdr3),
DataHandler.default(adata2),
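The parametrized backends above are routed through `_parallelize_with_joblib`, a new helper in `scirpy.util` whose body is not shown in this diff. A minimal sketch of what such a helper could look like (an illustration, not the actual implementation; `return_as="generator"` requires joblib >= 1.3, matching the pyproject pin):

```python
import joblib
from tqdm.auto import tqdm

def _parallelize_with_joblib(delayed_objects, *, total=None, n_jobs=-1):
    """Execute joblib.delayed objects with a progress bar (sketch only)."""
    # return_as="generator" yields results as they complete, so tqdm can
    # show real progress; the backend is whatever joblib is configured with.
    results = joblib.Parallel(n_jobs=n_jobs, return_as="generator")(delayed_objects)
    return list(tqdm(results, total=total))
```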
src/scirpy/tests/test_ir_dist_metrics.py (12 changes: 6 additions & 6 deletions)
@@ -158,13 +158,13 @@ def test_levenshtein_compute_block():


def test_levensthein_dist():
- levenshtein10 = LevenshteinDistanceCalculator(10, block_size=50)
- levenshtein10_2 = LevenshteinDistanceCalculator(10, block_size=2)
- levenshtein1 = LevenshteinDistanceCalculator(1, n_jobs=1, block_size=1)
+ levenshtein10 = LevenshteinDistanceCalculator(10)
+ levenshtein10_2 = LevenshteinDistanceCalculator(10)
+ levenshtein1 = LevenshteinDistanceCalculator(1, n_jobs=1)

- res10 = levenshtein10.calc_dist_mat(np.array(["A", "AA", "AAA", "AAR"]))
- res10_2 = levenshtein10_2.calc_dist_mat(np.array(["A", "AA", "AAA", "AAR"]))
- res1 = levenshtein1.calc_dist_mat(np.array(["A", "AA", "AAA", "AAR"]))
+ res10 = levenshtein10.calc_dist_mat(np.array(["A", "AA", "AAA", "AAR"]), block_size=50)
+ res10_2 = levenshtein10_2.calc_dist_mat(np.array(["A", "AA", "AAA", "AAR"]), block_size=2)
+ res1 = levenshtein1.calc_dist_mat(np.array(["A", "AA", "AAA", "AAR"]), block_size=1)

assert isinstance(res10, scipy.sparse.csr_matrix)
assert isinstance(res10_2, scipy.sparse.csr_matrix)
src/scirpy/tl/_ir_query.py (2 changes: 1 addition & 1 deletion)
@@ -90,7 +90,7 @@ def ir_query(
key_added: Optional[str] = None,
distance_key: Optional[str] = None,
inplace: bool = True,
-     n_jobs: Optional[int] = None,
+     n_jobs: Optional[int] = -1,
chunksize: int = 2000,
airr_mod: str = "airr",
airr_key: str = "airr",