Skip to content

Commit

Permalink
Add fastalignment metric (#456)
Browse files Browse the repository at this point in the history
* Add FastAlignmentDistanceCalculator

Added variant of AlignmentDistanceCalculator which performs faster, but may miss some of the pairwise distances.

* Place parasail import in methods

* Add fastalignment key

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update src/scirpy/ir_dist/metrics.py

Co-authored-by: Gregor Sturm <[email protected]>

* Added unit tests for fastalignment

* Added exception, updated docstrings

* Added fastalignment

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove accidentally committed data objects

* Improve docstring

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Test FastAlignmentDistanceCalculator

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update changelog

* Update tutorial

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Gregor Sturm <[email protected]>
  • Loading branch information
3 people authored Jan 4, 2024
1 parent 7dd3377 commit e1c8028
Show file tree
Hide file tree
Showing 7 changed files with 504 additions and 129 deletions.
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,16 @@ and this project adheres to [Semantic Versioning][].
[keep a changelog]: https://keepachangelog.com/en/1.0.0/
[semantic versioning]: https://semver.org/spec/v2.0.0.html

## Unreleased

### New features

- Speed up alignment distances by pre-filtering. There are two filtering strategies: A (lossless) length-based filter
and a heuristic based on the expected penalty per mismatch. This is implemented in the `FastAlignmentDistanceCalculator`
class which supersedes the `AlignmentDistanceCalculator` class, which is now deprecated. Using the `"alignment"` metric
in `pp.ir_dist` now uses the `FastAlignmentDistanceCalculator` with only the length-based filter activated.
Using the `"fastalignment"` metric activates the heuristic, which is significantly faster, but results in some false-negatives.

## v0.14.0

### Breaking changes
Expand Down
1 change: 1 addition & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -301,3 +301,4 @@ distance metrics
ir_dist.metrics.LevenshteinDistanceCalculator
ir_dist.metrics.HammingDistanceCalculator
ir_dist.metrics.AlignmentDistanceCalculator
ir_dist.metrics.FastAlignmentDistanceCalculator
322 changes: 208 additions & 114 deletions docs/tutorials/tutorial_3k_tcr.ipynb

Large diffs are not rendered by default.

14 changes: 10 additions & 4 deletions src/scirpy/ir_dist/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def IrNeighbors(*args, **kwargs):


MetricType = Union[
Literal["alignment", "identity", "levenshtein", "hamming"],
Literal["alignment", "fastalignment", "identity", "levenshtein", "hamming"],
metrics.DistanceCalculator,
]

Expand All @@ -50,7 +50,11 @@ def IrNeighbors(*args, **kwargs):
See :class:`~scirpy.ir_dist.metrics.HammingDistanceCalculator`.
* `alignment` -- Distance based on pairwise sequence alignments using the
BLOSUM62 matrix. This option is incompatible with nucleotide sequences.
See :class:`~scirpy.ir_dist.metrics.AlignmentDistanceCalculator`.
See :class:`~scirpy.ir_dist.metrics.FastAlignmentDistanceCalculator`.
* `fastalignment` -- Distance based on pairwise sequence alignments using the
BLOSUM62 matrix. Faster implementation of `alignment` with some loss.
This option is incompatible with nucleotide sequences.
See :class:`~scirpy.ir_dist.metrics.FastAlignmentDistanceCalculator`.
* any instance of :class:`~scirpy.ir_dist.metrics.DistanceCalculator`.
"""

Expand All @@ -59,7 +63,7 @@ def IrNeighbors(*args, **kwargs):
All distances `> cutoff` will be replaced by `0` and eliminated from the sparse
matrix. A sensible cutoff depends on the distance metric, you can find
information in the corresponding docs. If set to `None`, the cutoff
will be `10` for the `alignment` metric, and `2` for `levenshtein` and `hamming`.
will be `10` for the `alignment` and `fastalignment` metric, and `2` for `levenshtein` and `hamming`.
For the identity metric, the cutoff is ignored and always set to `0`.
"""

Expand All @@ -81,7 +85,9 @@ def _get_distance_calculator(metric: MetricType, cutoff: Union[int, None], *, n_
if isinstance(metric, metrics.DistanceCalculator):
dist_calc = metric
elif metric == "alignment":
dist_calc = metrics.AlignmentDistanceCalculator(cutoff=cutoff, n_jobs=n_jobs, **kwargs)
dist_calc = metrics.FastAlignmentDistanceCalculator(cutoff=cutoff, n_jobs=n_jobs, estimated_penalty=0, **kwargs)
elif metric == "fastalignment":
dist_calc = metrics.FastAlignmentDistanceCalculator(cutoff=cutoff, n_jobs=n_jobs, **kwargs)
elif metric == "identity":
dist_calc = metrics.IdentityDistanceCalculator(cutoff=cutoff, **kwargs)
elif metric == "levenshtein":
Expand Down
214 changes: 213 additions & 1 deletion src/scirpy/ir_dist/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from scipy.sparse import coo_matrix, csr_matrix
from tqdm.contrib.concurrent import process_map

from scirpy.util import _doc_params, tqdm
from scirpy.util import _doc_params, deprecated, tqdm

_doc_params_parallel_distance_calculator = """\
n_jobs
Expand Down Expand Up @@ -412,6 +412,12 @@ class AlignmentDistanceCalculator(ParallelDistanceCalculator):
Gap extend penatly
"""

@deprecated(
"""\
FastAlignmentDistanceCalculator achieves (depending on the settings) identical results
at a higher speed.
"""
)
def __init__(
self,
cutoff: Union[None, int] = None,
Expand Down Expand Up @@ -478,3 +484,209 @@ def _self_alignment_scores(self, seqs: Sequence) -> dict:
dtype=int,
count=len(seqs),
)


@_doc_params(params=_doc_params_parallel_distance_calculator)
class FastAlignmentDistanceCalculator(ParallelDistanceCalculator):
    """\
    Calculates distance between sequences based on pairwise sequence alignment.

    The distance between two sequences is defined as :math:`S_{{1,2}}^{{max}} - S_{{1,2}}`,
    where :math:`S_{{1,2}}` is the alignment score of sequences 1 and 2 and
    :math:`S_{{1,2}}^{{max}}` is the max. achievable alignment score of sequences 1 and 2.
    :math:`S_{{1,2}}^{{max}}` is defined as :math:`\\min(S_{{1,1}}, S_{{2,2}})`.

    The use of alignment-based distances is heavily inspired by :cite:`TCRdist`.

    High-performance sequence alignments are calculated leveraging
    the `parasail library <https://github.com/jeffdaily/parasail-python>`_ (:cite:`Daily2016`).

    To speed up the computation, we pre-filter sequence pairs based on
        a) differences in sequence length
        b) the number of different characters, based on an estimate of the mismatch penalty (`estimated_penalty`).

    The filtering based on `estimated_penalty` is a *heuristic* and *may lead to false negatives*, i.e. sequence
    pairs that are actually below the distance cutoff, but are removed during pre-filtering. Sensible values for
    `estimated_penalty` depend on the substitution matrix, but higher values lead to a higher false negative rate.

    We provide default values for BLOSUM and PAM matrices. The provided default values were obtained by testing
    different values on the Wu dataset (:cite:`Wu2020`) and selecting those that provided a reasonable balance
    between speedup and loss. Loss stayed well under 10% in all our test cases with the default values, and speedup
    increases with the number of cells.

    While the length-based filter is always active, the filter based on different characters can be disabled by
    setting the estimated penalty to zero. Using length-based filtering only, there won't be any false negatives
    *unless with a substitution matrix in which a substitution results in a higher score than the corresponding match.*
    Using the length-based filter only results in a substantially reduced speedup compared to combining it with
    the `estimated_penalty` heuristic.

    Choosing a cutoff:
        Alignment distances need to be viewed in the light of the substitution matrix.
        The alignment distance is the difference between the actual alignment
        score and the max. achievable alignment score. For instance, a mutation
        from *Leucine* (`L`) to *Isoleucine* (`I`) results in a BLOSUM62 score of `2`.
        An `L` aligned with `L` achieves a score of `4`. The distance is, therefore, `2`.

        On the other hand, a single *Tryptophan* (`W`) mutating into, e.g.
        *Proline* (`P`) already results in a distance of `15`.

        We are still lacking empirical data up to which distance a CDR3 sequence still
        is likely to recognize the same antigen, but reasonable cutoffs are `<15`.

    Choosing an expected penalty:
        The choice of an expected penalty is likely influenced by similar considerations as the
        other parameters. Essentially, this can be thought of as a superficial (dis)similarity
        measure. A higher value more strongly penalizes mismatching characters and is more in line
        with looking for closely related sequence pairs, while a lower value is more forgiving
        and better suited when looking for more distantly related sequence pairs.

    Parameters
    ----------
    cutoff
        Will eliminate distances > cutoff to make efficient
        use of sparse matrices. The default cutoff is `10`.
    {params}
    subst_mat
        Name of parasail substitution matrix
    gap_open
        Gap open penalty
    gap_extend
        Gap extend penalty
    estimated_penalty
        Estimate of the average mismatch penalty. If `None`, a default value
        depending on `subst_mat` is used. Set to `0` to disable the heuristic filter.
    """

    # Default `estimated_penalty` for each supported substitution matrix.
    # Built once at class creation instead of on every instantiation.
    _DEFAULT_PENALTIES = {
        **dict.fromkeys((f"blosum{n}" for n in (30, 35, 40, 45, 50, 55, 60, 62, 65, 70, 75, 80, 85, 90)), 4.0),
        **dict.fromkeys((f"pam{n}" for n in range(10, 60, 10)), 8.0),
        **dict.fromkeys((f"pam{n}" for n in range(60, 110, 10)), 4.0),
        **dict.fromkeys((f"pam{n}" for n in range(110, 210, 10)), 2.0),
    }

    def __init__(
        self,
        cutoff: Union[None, int] = None,
        *,
        n_jobs: Union[int, None] = None,
        block_size: int = 50,
        subst_mat: str = "blosum62",
        gap_open: int = 11,
        gap_extend: int = 11,
        estimated_penalty: Union[None, float] = None,
    ):
        if cutoff is None:
            cutoff = 10
        super().__init__(cutoff, n_jobs=n_jobs, block_size=block_size)
        self.subst_mat = subst_mat
        self.gap_open = gap_open
        self.gap_extend = gap_extend

        # Only matrices with a known default penalty are accepted, even if the
        # caller supplies `estimated_penalty` explicitly (unchanged behavior).
        # ValueError is a subclass of Exception, so existing handlers keep working.
        if subst_mat not in self._DEFAULT_PENALTIES:
            raise ValueError("Invalid substitution matrix.")

        self.estimated_penalty = (
            estimated_penalty if estimated_penalty is not None else self._DEFAULT_PENALTIES[subst_mat]
        )

    def _compute_block(self, seqs1, seqs2, origin):
        """Compute the distances for one block of sequence pairs.

        If `seqs2` is `None`, `seqs1` is compared against itself and only the
        upper triangle (including the diagonal) is computed.

        Returns a list of `(distance + 1, row, col)` tuples for all pairs whose
        distance is `<= cutoff`; `row` and `col` are offset by `origin`.
        """
        import parasail

        subst_mat = parasail.Matrix(self.subst_mat)
        origin_row, origin_col = origin

        square_matrix = seqs2 is None
        if square_matrix:
            seqs2 = seqs1

        self_alignment_scores1 = self._self_alignment_scores(seqs1)
        if square_matrix:
            self_alignment_scores2 = self_alignment_scores1
        else:
            self_alignment_scores2 = self._self_alignment_scores(seqs2)

        # Largest length difference `k` for which the gap term
        # `gap_open + (k - 1) * gap_extend` still fits into the cutoff.
        # The bound is clamped at 0 because sequences of equal length need no gap
        # and must never be dismissed by the length filter; the unclamped formula
        # becomes negative when `gap_open - gap_extend > cutoff`.
        if self.gap_extend:
            max_len_diff = max(0, ((self.cutoff - self.gap_open) / self.gap_extend) + 1)
        else:
            # Without an extension penalty the gap term is `gap_open`
            # regardless of the length difference (avoids a division by zero).
            max_len_diff = float("inf") if self.gap_open <= self.cutoff else 0

        result = []
        for row, s1 in enumerate(seqs1):
            col_start = row if square_matrix else 0
            profile = parasail.profile_create_16(s1, subst_mat)
            len1 = len(s1)

            for col, s2 in enumerate(seqs2[col_start:], start=col_start):
                # Identical sequences have distance 0; no need to align them.
                if s1 == s2:
                    result.append((1, origin_row + row, origin_col + col))
                    continue

                # Dismiss sequences based on length
                len_diff = abs(len1 - len(s2))
                if len_diff > max_len_diff:
                    continue

                # Dismiss sequences that are too different (heuristic).
                # NOTE(review): the gap term here is `len_diff * gap_extend`, which differs from
                # the length filter above when `gap_open != gap_extend` -- confirm this is intended.
                if (
                    self._num_different_characters(s1, s2, len_diff) * self.estimated_penalty
                    + len_diff * self.gap_extend
                    <= self.cutoff
                ):
                    r = parasail.nw_scan_profile_16(profile, s2, self.gap_open, self.gap_extend)
                    max_score = min(self_alignment_scores1[row], self_alignment_scores2[col])
                    d = max_score - r.score

                    if d <= self.cutoff:
                        result.append((d + 1, origin_row + row, origin_col + col))

        return result

    def _self_alignment_scores(self, seqs: Sequence) -> np.ndarray:
        """Calculate self-alignments. We need them as reference values
        to turn scores into dists.

        Returns an integer array with one score per entry of `seqs`.
        """
        import parasail

        # The substitution matrix is the same for every sequence: build it once.
        subst_mat = parasail.Matrix(self.subst_mat)

        return np.fromiter(
            (parasail.nw_scan_16(s, s, self.gap_open, self.gap_extend, subst_mat).score for s in seqs),
            dtype=int,
            count=len(seqs),
        )

    def _num_different_characters(self, s1, s2, len_diff):
        """Count the characters that cannot be paired up between `s1` and `s2`.

        Positions are ignored: each character of the shorter sequence is matched
        against (and removed from) the longer one at most once. The result is the
        number of characters left in the longer sequence minus `len_diff`. When
        `len_diff` is the absolute length difference of the two sequences, this
        equals the number of characters of the shorter sequence without a partner.
        """
        longer, shorter = (s1, s2) if len(s1) >= len(s2) else (s2, s1)

        for c in shorter:
            if c in longer:
                # Remove only the first occurrence so that repeated characters
                # are matched one-to-one.
                longer = longer.replace(c, "", 1)
        return len(longer) - len_diff
Loading

0 comments on commit e1c8028

Please sign in to comment.