Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jan 7, 2024
1 parent 6dd0e0b commit c90ac96
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 56 deletions.
128 changes: 76 additions & 52 deletions src/scirpy/ir_dist/_clonotype_neighbors.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
import itertools
from collections.abc import Mapping, Sequence
from multiprocessing import cpu_count
from typing import Literal, Optional, Union

import numpy as np
import pandas as pd
import scipy.sparse as sp
from scanpy import logging
from tqdm.contrib.concurrent import process_map

from scirpy.get import _has_ir
from scirpy.get import airr as get_airr
from scirpy.util import DataHandler, tqdm
from scirpy.util import DataHandler

from ._util import DoubleLookupNeighborFinder, merge_coo_matrices, reduce_and, reduce_or
from ._util import DoubleLookupNeighborFinder


class ClonotypeNeighbors:
Expand Down Expand Up @@ -205,151 +203,177 @@ def compute_distances(self) -> sp.csr_matrix:
start = logging.info("Computing clonotype x clonotype distances.") # type: ignore
n_clonotypes = self.clonotypes.shape[0]
clonotype_ids = np.arange(n_clonotypes)

dist = self._dist_for_clonotype(clonotype_ids)

dist.eliminate_zeros()
logging.hint("Done computing clonotype x clonotype distances. ", time=start)
return dist # type: ignore

def _dist_for_clonotype(self, ct_ids: np.ndarray[int]) -> sp.csr_matrix:
    """Compute the sparse clonotype x clonotype distance matrix for ``ct_ids``.

    For every receptor arm in ``self._receptor_arm_cols`` and every chain
    pairing selected by ``self.dual_ir``, a sparse matrix is fetched from
    ``self.neighbor_finder``. The matrices are optionally masked by V-gene
    and ``match_columns`` lookups, then reduced first across chains
    (according to ``self.dual_ir``) and then across arms (according to
    ``self.receptor_arms``).

    Parameters
    ----------
    ct_ids
        Clonotype indices to compute distances for. They are used both as
        lookup keys and as row selectors on the per-chain results.

    Returns
    -------
    A sparse matrix of shape ``(len(ct_ids), len(ct_ids))``.
    NOTE(review): implicit zeros appear to mean "no distance available" --
    confirm against callers.

    Raises
    ------
    ValueError
        If the forward and reverse V-gene lookup tables refer to different
        distance matrices.
    AssertionError
        If ``self.dual_ir`` is not one of ``"primary_only"``, ``"any"`` or
        ``"all"``.
    """
    chain_ids = [(1, 1)] if self.dual_ir == "primary_only" else [(1, 1), (2, 2), (1, 2), (2, 1)]
    lookup = {
        (tmp_arm, c1, c2): self.neighbor_finder.lookup(
            ct_ids,
            f"{tmp_arm}_{c1}",
            f"{tmp_arm}_{c2}",
        )
        for tmp_arm in self._receptor_arm_cols
        for c1, c2 in chain_ids
    }
    id_len = len(ct_ids)

    # Sparsity pattern of "any lookup has a stored entry here": sum all
    # per-chain matrices, then overwrite the stored values with ones.
    has_distance_mask = sp.csr_matrix((id_len, id_len))
    for value in lookup.values():
        has_distance_mask += value
    has_distance_mask.data = np.ones_like(has_distance_mask.data)

    def csr_min(a, b):
        """Element-wise minimum over stored entries of two sparse matrices.

        Stored values are shifted below zero before comparing, so a stored
        entry always compares smaller than an implicit zero. Where only one
        operand has a stored entry, that entry is the result.
        """
        max_value_a = np.max(a.data, initial=0)
        max_value_b = np.max(b.data, initial=0)
        max_value = np.max([max_value_a, max_value_b]) + 1
        min_mat_a = a.copy()
        min_mat_a.data -= max_value
        min_mat_b = b.copy()
        min_mat_b.data -= max_value
        a_smaller_b = min_mat_a < min_mat_b
        return b + (a - b).multiply(a_smaller_b)

    def csr_max(a, b):
        """Element-wise maximum over stored entries of two sparse matrices.

        Uses the same shift as ``csr_min``, so an implicit zero compares
        greater than any stored entry. Where either operand has no stored
        entry, the result is zero.
        """
        # ``initial=0`` keeps np.max from raising on matrices without any
        # stored entries, matching ``csr_min``.
        max_value_a = np.max(a.data, initial=0)
        max_value_b = np.max(b.data, initial=0)
        max_value = np.max([max_value_a, max_value_b]) + 1
        max_mat_a = a.copy()
        max_mat_a.data -= max_value
        max_mat_b = b.copy()
        max_mat_b.data -= max_value
        a_greater_b = max_mat_a > max_mat_b
        return b + (a - b).multiply(a_greater_b)

    if self.match_columns is not None:
        # Forward and reverse tables come from the same lookup entry, so they
        # necessarily refer to the same distance matrix.
        _, forward, reverse = self.neighbor_finder.lookups["match_columns"]
        reverse_keys = list(reverse.lookup.keys())
        reverse_lookup_values = np.vstack(list(reverse.lookup.values()))
        reverse_lookup_keys = np.zeros(id_len, dtype=np.int32)
        reverse_lookup_keys[reverse_keys] = np.arange(len(reverse_keys))
        # Same sparsity pattern as ``has_distance_mask``; the data is filled
        # in below, so it may start uninitialised.
        match_column_mask = sp.csr_matrix(
            (np.empty(len(has_distance_mask.indices)), has_distance_mask.indices, has_distance_mask.indptr),
            shape=has_distance_mask.shape,
        )
        has_distance_mask_coo = match_column_mask.tocoo()
        indices_in_dist_mat = forward[has_distance_mask_coo.row]
        match_column_mask.data = reverse_lookup_values[
            reverse_lookup_keys[indices_in_dist_mat], has_distance_mask_coo.col
        ]

    tmp_arm_res = {}
    chain_res = {}

    def filter_chain_count_data(
        matrix_coo_data_chain_filtered, matrix_coo_data, matrix_coo_row, matrix_coo_col, chain_count_array
    ):
        """Scatter entries into one of three rows, selected by chain count.

        An entry is kept only if its row and column have the same chain
        count; it is written to the row of ``matrix_coo_data_chain_filtered``
        indexed by that count. All other positions stay as passed in.
        """
        data_indices = np.arange(len(matrix_coo_data))
        chain_counts1 = chain_count_array[matrix_coo_row]
        chain_counts2 = chain_count_array[matrix_coo_col]
        chain_counts_equal = chain_counts1 == chain_counts2
        matrix_coo_data_chain_filtered[
            chain_counts1[chain_counts_equal], data_indices[chain_counts_equal]
        ] = matrix_coo_data[chain_counts_equal]
        return (
            matrix_coo_data_chain_filtered[0],
            matrix_coo_data_chain_filtered[1],
            matrix_coo_data_chain_filtered[2],
        )

    def filter_chain_count(matrix: sp.csr_matrix, col: str) -> tuple[sp.csr_matrix, sp.csr_matrix, sp.csr_matrix]:
        """Split ``matrix`` into three matrices for chain counts 0, 1 and 2.

        Each result shares the sparsity pattern of ``matrix``; entries that
        do not belong to the respective chain count are stored as zeros.
        """
        chain_count = self._chain_count[col]
        matrix_coo = matrix.tocoo()
        matrix_coo_data_chain_filtered = np.array(
            [np.zeros_like(matrix_coo.data), np.zeros_like(matrix_coo.data), np.zeros_like(matrix_coo.data)]
        )
        csr_filtered1, csr_filtered2, csr_filtered3 = matrix.copy(), matrix.copy(), matrix.copy()
        csr_filtered1.data, csr_filtered2.data, csr_filtered3.data = filter_chain_count_data(
            matrix_coo_data_chain_filtered, matrix_coo.data, matrix_coo.row, matrix_coo.col, chain_count
        )
        return csr_filtered1, csr_filtered2, csr_filtered3

    for tmp_arm in self._receptor_arm_cols:
        for c1, c2 in chain_ids:
            tmp_arrays = lookup[(tmp_arm, c1, c2)][ct_ids]

            if not (self.same_v_gene or self.match_columns):
                tmp_arrays = tmp_arrays.multiply(has_distance_mask)

            if self.same_v_gene:
                distance_matrix_name, forward, _ = self.neighbor_finder.lookups[f"{tmp_arm}_{c1}_v_call"]
                distance_matrix_name_reverse, _, reverse = self.neighbor_finder.lookups[f"{tmp_arm}_{c2}_v_call"]
                if distance_matrix_name != distance_matrix_name_reverse:
                    raise ValueError(
                        "Forward and reverse lookup tablese must be defined " "on the same distance matrices."
                    )
                # Keys without a reverse entry map to -1, which selects the
                # all-False row appended at the end.
                empty_row = np.array([np.zeros(len(ct_ids), dtype=bool)])
                reverse_lookup_values = np.vstack((*reverse.lookup.values(), empty_row))
                reverse_lookup_keys = np.full(id_len, -1, dtype=np.int32)
                keys_array = np.fromiter(reverse.lookup.keys(), dtype=int, count=len(reverse.lookup))
                reverse_lookup_keys[keys_array] = np.arange(len(keys_array))
                v_gene_mask = sp.csr_matrix(
                    (np.empty(len(has_distance_mask.indices)), has_distance_mask.indices, has_distance_mask.indptr),
                    shape=has_distance_mask.shape,
                )
                has_distance_mask_coo = v_gene_mask.tocoo()
                indices_in_dist_mat = forward[has_distance_mask_coo.row]
                v_gene_mask.data = reverse_lookup_values[
                    reverse_lookup_keys[indices_in_dist_mat], has_distance_mask_coo.col
                ]
                tmp_arrays = tmp_arrays.multiply(v_gene_mask)

            if self.match_columns is not None:
                tmp_arrays = tmp_arrays.multiply(match_column_mask)

            if self.dual_ir == "all":
                filtered1, filtered2, filtered3 = filter_chain_count(tmp_arrays, tmp_arm)
                chain_res[(tmp_arm, c1, c2, 0)] = filtered1
                chain_res[(tmp_arm, c1, c2, 1)] = filtered2
                chain_res[(tmp_arm, c1, c2, 2)] = filtered3
            else:
                chain_res[(tmp_arm, c1, c2)] = tmp_arrays

        # Reduce across chains once per arm. The result does not depend on an
        # individual (c1, c2) pairing, so no loop over ``chain_ids`` is needed.
        if self.dual_ir == "primary_only":
            tmp_arm_res[tmp_arm] = chain_res[(tmp_arm, 1, 1)]
        elif self.dual_ir == "any":
            tmp_arm_res[tmp_arm] = csr_min(
                csr_min(chain_res[(tmp_arm, 1, 1)], chain_res[(tmp_arm, 1, 2)]),
                csr_min(chain_res[(tmp_arm, 2, 1)], chain_res[(tmp_arm, 2, 2)]),
            )
        elif self.dual_ir == "all":
            tmp_arm_res[tmp_arm] = csr_min(
                csr_max(chain_res[(tmp_arm, 1, 1, 2)], chain_res[(tmp_arm, 2, 2, 2)]),
                csr_max(chain_res[(tmp_arm, 2, 1, 2)], chain_res[(tmp_arm, 1, 2, 2)]),
            )
            tmp_arm_res[tmp_arm] += chain_res[(tmp_arm, 1, 1, 1)] + chain_res[(tmp_arm, 1, 1, 0)]
        else:
            # Raised explicitly so the check survives ``python -O``.
            raise AssertionError("self.dual_ir method " + self.dual_ir + "not implemented")

    if len(tmp_arm_res) == 1:
        final_res = tmp_arm_res[self._receptor_arm_cols[0]]
    elif self.receptor_arms == "all":
        arm_res = {}
        arm_res[("VJ", 0)], arm_res[("VJ", 1)], arm_res[("VJ", 2)] = filter_chain_count(
            tmp_arm_res["VJ"], "arms"
        )
        arm_res[("VDJ", 0)], arm_res[("VDJ", 1)], arm_res[("VDJ", 2)] = filter_chain_count(
            tmp_arm_res["VDJ"], "arms"
        )
        final_res = csr_max(arm_res[("VJ", 2)], arm_res[("VDJ", 2)])
        final_res += arm_res[("VJ", 0)] + arm_res[("VJ", 1)] + arm_res[("VDJ", 0)] + arm_res[("VDJ", 1)]
    else:
        final_res = csr_min(tmp_arm_res["VJ"], tmp_arm_res["VDJ"])
    return final_res
7 changes: 3 additions & 4 deletions src/scirpy/ir_dist/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,19 +237,18 @@ def lookup(
forward_lookup_table: str,
reverse_lookup_table: Union[str, None] = None,
) -> Union[list[coo_matrix], list[np.ndarray]]:

distance_matrix_name, forward, _ = self.lookups[forward_lookup_table]

if reverse_lookup_table is not None:
distance_matrix_name_reverse, _, reverse = self.lookups[reverse_lookup_table]
if distance_matrix_name != distance_matrix_name_reverse:
raise ValueError("Forward and reverse lookup tablese must be defined " "on the same distance matrices.")

distance_matrix = self.distance_matrices[distance_matrix_name]
indices_in_dist_mat = forward[object_ids]
indices_in_dist_mat = indices_in_dist_mat + 1
empty_row = sp.csr_matrix((1, distance_matrix.shape[1]), dtype=distance_matrix.dtype)
distance_matrix_new = sp.vstack([empty_row, distance_matrix], format='csr')
distance_matrix_new = sp.vstack([empty_row, distance_matrix], format="csr")
rows = distance_matrix_new[indices_in_dist_mat, :]

reverse_empty_row_col = np.array([], dtype=np.int32)
Expand All @@ -262,7 +261,7 @@ def lookup(
reverse_table_data[key] = value.data
reverse_table_col[key] = value.col
nnz_array[key] = value.nnz

data = np.concatenate(reverse_table_data)
col = np.concatenate(reverse_table_col)
indptr = np.concatenate([np.array([0], dtype=np.int32), np.cumsum(nnz_array)])
Expand Down

0 comments on commit c90ac96

Please sign in to comment.