From 59670f7920d965b3e3ecfdb30ac075d5fad69637 Mon Sep 17 00:00:00 2001 From: LukasMahieu Date: Tue, 26 Nov 2024 17:16:46 +0100 Subject: [PATCH 01/16] init tools, move get_embeddings to tooling (with unit tests), input format detection --- src/crested/tl/__init__.py | 2 + src/crested/tl/_crested.py | 3 - src/crested/tl/_tools.py | 148 +++++++++++++++++++++++++++++++++++++ tests/test_tl.py | 111 ++++++++++++++++++++++++++++ 4 files changed, 261 insertions(+), 3 deletions(-) create mode 100644 src/crested/tl/_tools.py create mode 100644 tests/test_tl.py diff --git a/src/crested/tl/__init__.py b/src/crested/tl/__init__.py index a22b6bd..e19b7f5 100644 --- a/src/crested/tl/__init__.py +++ b/src/crested/tl/__init__.py @@ -7,6 +7,7 @@ from . import data, losses, metrics, zoo from ._configs import TaskConfig, default_configs from ._crested import Crested +from ._tools import get_embeddings if find_spec("modiscolite") is not None: MODISCOLITE_AVAILABLE = True @@ -35,6 +36,7 @@ "TaskConfig", "default_configs", "Crested", + "get_embeddings", ] if MODISCOLITE_AVAILABLE: diff --git a/src/crested/tl/_crested.py b/src/crested/tl/_crested.py index ea4b93d..bc1a70d 100644 --- a/src/crested/tl/_crested.py +++ b/src/crested/tl/_crested.py @@ -324,9 +324,6 @@ def fit( if self.model and ( not hasattr(self.model, "optimizer") or self.model.optimizer is None ): - logger.warning( - "Model does not have an optimizer. Please compile the model before training." - ) self.model.compile( optimizer=self.config.optimizer, loss=self.config.loss, diff --git a/src/crested/tl/_tools.py b/src/crested/tl/_tools.py new file mode 100644 index 0000000..036d5e8 --- /dev/null +++ b/src/crested/tl/_tools.py @@ -0,0 +1,148 @@ +"""Tooling kit that handles predictions, contribution scores, enhancer design, ... .""" + +from __future__ import annotations + +import os +import re + +import keras +import numpy as np +from anndata import AnnData + +from crested.utils import ( + fetch_sequences, + one_hot_encode_sequence, +) + + +def detect_input_type(input): + """ + Detect the type of input provided. + + Parameters + ---------- + input : str | list[str] | np.array | AnnData + The input to detect the type of. + + Returns + ------- + str + One of ['sequence', 'region', 'anndata', 'array'], indicating the input type. + """ + dna_pattern = re.compile("^[ACGTNacgtn]+$") + if isinstance(input, AnnData): + return "anndata" + elif isinstance(input, list): + if all(":" in str(item) for item in input): # List of regions + return "region" + elif all( + isinstance(item, str) and dna_pattern.match(item) for item in input + ): # List of sequences + return "sequence" + else: + raise ValueError( + "List input must contain only valid region strings (chrom:var-end) or DNA sequences." + ) + elif isinstance(input, str): + if ":" in input: # Single region + return "region" + elif dna_pattern.match(input): # Single DNA sequence + return "sequence" + else: + raise ValueError( + "String input must be a valid region string (chrom:var-end) or DNA sequence." + ) + elif isinstance(input, np.ndarray): + if input.ndim == 3: + return "array" + else: + raise ValueError("Input one hot array must have shape (N, L, 4).") + else: + raise ValueError( + "Unsupported input type. Must be AnnData, str, list, or np.ndarray." + ) + + +def _transform_input(input, genome: os.PathLike | None = None) -> np.ndarray: + """ + Transform the input into a one-hot encoded matrix based on its type. + + Parameters + ---------- + input : str | list[str] | np.array | AnnData + Input data to preprocess. Can be a sequence, list of sequences, region, list of regions, or an AnnData object. + genome : str | None + Path to the genome file. Required if input is a region or AnnData. + + Returns + ------- + One-hot encoded matrix of shape (N, L, 4), where N is the number of sequences/regions and L is the sequence length. + """ + input_type = detect_input_type(input) + + if input_type == "anndata": + if genome is None: + raise ValueError( + "Genome file is required to fetch sequences for regions in AnnData." + ) + regions = list(input.var_names) + sequences = fetch_sequences(regions, genome) + elif input_type == "region": + if genome is None: + raise ValueError("Genome file is required to fetch sequences for regions.") + sequences = fetch_sequences(input, genome) + elif input_type == "sequence": + sequences = input if isinstance(input, list) else [input] + elif input_type == "array": + assert input.ndim == 3, "Input one hot array must have shape (N, L, 4)." + return input + + one_hot_data = np.array( + [one_hot_encode_sequence(seq, expand_dim=False) for seq in sequences] + ) + + return one_hot_data + + +def get_embeddings( + input: str | list[str] | np.array | AnnData, + model: keras.Model, + layer_name: str, + genome: str | None = None, + **kwargs, +) -> np.ndarray: + """ + Extract embeddings from a specified layer for all inputs. + + Parameters + ---------- + input + Input data to get embeddings for. Can be a (list of) sequence(s), a (list of) region name(s), a matrix of one hot encodings (N, L, 4), or an AnnData object with region names as its var_names. + model + A trained keras model from which to extract the embeddings. + layer_name + The name of the layer from which to extract the embeddings. + genome + Path to the genome file. Required if input is an anndata object or a list of regions. + **kwargs + Additional keyword arguments to pass to the keras.Model.predict method. + + Returns + ------- + Embeddings of shape (N, D), where N is the number of regions in the input and D is the size of the embedding layer. + """ + layer_names = [layer.name for layer in model.layers] + if layer_name not in layer_names: + raise ValueError( + f"Layer '{layer_name}' not found in model. Options (in reverse) are: {layer_names[::-1]}" + ) + embedding_model = keras.models.Model( + inputs=model.input, outputs=model.get_layer(layer_name).output + ) + input = _transform_input(input, genome) + n_predict_steps = ( + input.shape[0] if os.environ["KERAS_BACKEND"] == "tensorflow" else None + ) + embeddings = embedding_model.predict(input, steps=n_predict_steps, **kwargs) + + return embeddings diff --git a/tests/test_tl.py b/tests/test_tl.py new file mode 100644 index 0000000..8b81173 --- /dev/null +++ b/tests/test_tl.py @@ -0,0 +1,111 @@ +"""Test tl module.""" + +import os + +import genomepy +import keras +import numpy as np +import pytest + +import crested +from crested.tl._tools import _transform_input, detect_input_type + +from ._utils import create_anndata_with_regions + + +@pytest.fixture(scope="module") +def keras_model(): + from crested.tl.zoo import simple_convnet + + model = simple_convnet( + seq_len=500, + num_classes=10, + ) + model.compile( + optimizer=keras.optimizers.Adam(0.001), + loss=keras.losses.MeanSquaredError(), + metrics=[keras.metrics.CategoricalAccuracy()], + ) + return model + + +@pytest.fixture(scope="module") +def adata(): + regions = [ + "chr1:194208032-194208532", + "chr1:92202766-92203266", + "chr1:92298990-92299490", + "chr1:3406052-3406552", + "chr1:183669567-183670067", + "chr1:109912183-109912683", + "chr1:92210697-92211197", + "chr1:59100954-59101454", + "chr1:84634055-84634555", + "chr1:48792527-48793027", + ] + return create_anndata_with_regions(regions) + + +@pytest.fixture(scope="module") +def genome(): + if not os.path.exists("tests/data/genomes/hg38.fa"): + genomepy.install_genome( + "hg38", annotation=False, provider="UCSC", genomes_dir="tests/data/genomes" + ) + return "tests/data/genomes/hg38/hg38.fa" + + +def test_input_type(adata): + assert detect_input_type(adata) == "anndata" + assert detect_input_type("chr1:1-100") == "region" + assert detect_input_type("ACGT") == "sequence" + assert detect_input_type(["chr1:1-100", "chr1:101-200"]) == "region" + assert detect_input_type(["ACGT", "ACGT"]) == "sequence" + with pytest.raises(ValueError): + detect_input_type(["chr1:1-100", "ACGT"]) + with pytest.raises(ValueError): + detect_input_type(["chr1:1-100", 1]) + with pytest.raises(ValueError): + detect_input_type([1, 2]) + with pytest.raises(ValueError): + detect_input_type([1, "ACGT"]) + with pytest.raises(ValueError): + detect_input_type(1) + with pytest.raises(ValueError): + detect_input_type(1.0) + with pytest.raises(ValueError): + detect_input_type(None) + + +def test_input_transform(genome): + assert _transform_input("AACGT").shape == (1, 5, 4) + assert np.array_equal( + _transform_input("ACGT"), + np.array([[[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]]), + ) + assert np.array_equal( + _transform_input(["ACGT", "ACGT"]), + np.array( + [ + [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], + [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], + ] + ), + ) + assert np.array_equal( + _transform_input("chr1:1-6", genome), np.array([[[0, 0, 0, 0]] * 5]) + ) + assert _transform_input(np.array([[[1, 0, 0, 0]]])).shape == (1, 1, 4) + + +def test_get_embeddings(keras_model, genome): + input = "ATCGA" * 100 + embeddings = crested.tl.get_embeddings( + input, keras_model, genome=genome, layer_name="denseblock_dense" + ) + assert embeddings.shape == (1, 8) + input = ["ATCGA" * 100, "ATCGA" * 100] + embeddings = crested.tl.get_embeddings( + input, keras_model, genome=genome, layer_name="denseblock_dense" + ) + assert embeddings.shape == (2, 8) From d45809d0daf7d27659202aa7d4d7d3536d097f81 Mon Sep 17 00:00:00 2001 From: LukasMahieu Date: Tue, 26 Nov 2024 17:17:02 +0100 Subject: [PATCH 02/16] shape fix in docstring --- src/crested/utils/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/crested/utils/_utils.py b/src/crested/utils/_utils.py index 5fde197..2b5f8ac 100644 --- a/src/crested/utils/_utils.py +++ b/src/crested/utils/_utils.py @@ -52,7 +52,7 @@ def one_hot_encode_sequence(sequence: str, expand_dim: bool = True) -> np.ndarra """ One hot encode a DNA sequence. - Will return a numpy array with shape (len(sequence), 4) if expand_dim is True, otherwise (4,). + Will return a numpy array with shape (1, len(sequence), 4) if expand_dim is True, otherwise (len(sequence),4). Alphabet is ACGT. Parameters From b741af9c3b2936e9af1e27de9a16f5d43ae5cba6 Mon Sep 17 00:00:00 2001 From: LukasMahieu Date: Tue, 26 Nov 2024 17:35:11 +0100 Subject: [PATCH 03/16] move predict to tooling & unittests --- src/crested/tl/__init__.py | 3 ++- src/crested/tl/_tools.py | 39 ++++++++++++++++++++++++++++++++++++-- tests/test_tl.py | 11 ++++++++++- 3 files changed, 49 insertions(+), 4 deletions(-) diff --git a/src/crested/tl/__init__.py b/src/crested/tl/__init__.py index e19b7f5..b1cabbe 100644 --- a/src/crested/tl/__init__.py +++ b/src/crested/tl/__init__.py @@ -7,7 +7,7 @@ from . import data, losses, metrics, zoo from ._configs import TaskConfig, default_configs from ._crested import Crested -from ._tools import get_embeddings +from ._tools import get_embeddings, predict if find_spec("modiscolite") is not None: MODISCOLITE_AVAILABLE = True @@ -37,6 +37,7 @@ "default_configs", "Crested", "get_embeddings", + "predict", ] if MODISCOLITE_AVAILABLE: diff --git a/src/crested/tl/_tools.py b/src/crested/tl/_tools.py index 036d5e8..2bee366 100644 --- a/src/crested/tl/_tools.py +++ b/src/crested/tl/_tools.py @@ -108,7 +108,7 @@ def get_embeddings( input: str | list[str] | np.array | AnnData, model: keras.Model, layer_name: str, - genome: str | None = None, + genome: os.PathLike | None = None, **kwargs, ) -> np.ndarray: """ @@ -123,7 +123,7 @@ def get_embeddings( layer_name The name of the layer from which to extract the embeddings. genome - Path to the genome file. Required if input is an anndata object or a list of regions. + Path to the genome file. Required if input is an anndata object or region names. **kwargs Additional keyword arguments to pass to the keras.Model.predict method. @@ -146,3 +146,38 @@ def get_embeddings( embeddings = embedding_model.predict(input, steps=n_predict_steps, **kwargs) return embeddings + + +def predict( + input: str | list[str] | np.array | AnnData, + model: keras.Model, + genome: os.PathLike | None = None, +) -> None | np.ndarray: + """ + Make predictions using the model on the full dataset. + + If anndata and model_name are provided, will add the predictions to anndata as a .layers[model_name] attribute. + Else, will return the predictions as a numpy array. + + Parameters + ---------- + input + Input data to make predictions on. Can be a (list of) sequence(s), a (list of) region name(s), a matrix of one hot encodings (N, L, 4), or an AnnData object with region names as its var_names. + model + A trained keras model to make predictions with. + genome + Path to the genome file. Required if input is an anndata object or region names. + + Returns + ------- + Predictions of shape (N, C) + """ + input = _transform_input(input, genome) + + n_predict_steps = ( + input.shape[0] if os.environ["KERAS_BACKEND"] == "tensorflow" else None + ) + + predictions = model.predict(input, steps=n_predict_steps) + + return predictions diff --git a/tests/test_tl.py b/tests/test_tl.py index 8b81173..2bd6fb1 100644 --- a/tests/test_tl.py +++ b/tests/test_tl.py @@ -106,6 +106,15 @@ def test_get_embeddings(keras_model, genome): assert embeddings.shape == (1, 8) input = ["ATCGA" * 100, "ATCGA" * 100] embeddings = crested.tl.get_embeddings( - input, keras_model, genome=genome, layer_name="denseblock_dense" + input, keras_model, layer_name="denseblock_dense" ) assert embeddings.shape == (2, 8) + + +def test_predict(keras_model, adata, genome): + input = "ATCGA" * 100 + predictions = crested.tl.predict(input, keras_model) + assert predictions.shape == (1, 10) + + predictions = crested.tl.predict(input=adata, model=keras_model, genome=genome) + assert predictions.shape == (10, 10) From ee6f335c761b8904dd460f7ba8432f0fb06a6b40 Mon Sep 17 00:00:00 2001 From: LukasMahieu Date: Tue, 26 Nov 2024 17:54:07 +0100 Subject: [PATCH 04/16] allow multiple models as input to predict --- src/crested/tl/_tools.py | 40 ++++++++++++++++++++++++++-------------- tests/test_tl.py | 36 ++++++++++++++++++++---------------- 2 files changed, 46 insertions(+), 30 deletions(-) diff --git a/src/crested/tl/_tools.py b/src/crested/tl/_tools.py index 2bee366..7daaef5 100644 --- a/src/crested/tl/_tools.py +++ b/src/crested/tl/_tools.py @@ -15,7 +15,7 @@ ) -def detect_input_type(input): +def _detect_input_type(input): """ Detect the type of input provided. @@ -78,7 +78,7 @@ def _transform_input(input, genome: os.PathLike | None = None) -> np.ndarray: ------- One-hot encoded matrix of shape (N, L, 4), where N is the number of sequences/regions and L is the sequence length. """ - input_type = detect_input_type(input) + input_type = _detect_input_type(input) if input_type == "anndata": if genome is None: @@ -139,34 +139,32 @@ def get_embeddings( embedding_model = keras.models.Model( inputs=model.input, outputs=model.get_layer(layer_name).output ) - input = _transform_input(input, genome) - n_predict_steps = ( - input.shape[0] if os.environ["KERAS_BACKEND"] == "tensorflow" else None - ) - embeddings = embedding_model.predict(input, steps=n_predict_steps, **kwargs) + embeddings = predict(input, embedding_model, genome, **kwargs) return embeddings def predict( input: str | list[str] | np.array | AnnData, - model: keras.Model, + model: keras.Model | list[keras.Model], genome: os.PathLike | None = None, + **kwargs, ) -> None | np.ndarray: """ - Make predictions using the model on the full dataset. + Make predictions using the model(s) on the full dataset. - If anndata and model_name are provided, will add the predictions to anndata as a .layers[model_name] attribute. - Else, will return the predictions as a numpy array. + If a list of models is provided, the predictions will be averaged across all models. Parameters ---------- input Input data to make predictions on. Can be a (list of) sequence(s), a (list of) region name(s), a matrix of one hot encodings (N, L, 4), or an AnnData object with region names as its var_names. model - A trained keras model to make predictions with. + A (list of) trained keras models to make predictions with. genome Path to the genome file. Required if input is an anndata object or region names. + **kwargs + Additional keyword arguments to pass to the keras.Model.predict method. Returns ------- @@ -178,6 +176,20 @@ def predict( input.shape[0] if os.environ["KERAS_BACKEND"] == "tensorflow" else None ) - predictions = model.predict(input, steps=n_predict_steps) + if isinstance(model, list): + if not all(isinstance(m, keras.Model) for m in model): + raise ValueError("All items in the model list must be Keras models.") + + all_predictions = [] + for m in model: + predictions = m.predict(input, steps=n_predict_steps, **kwargs) + all_predictions.append(predictions) + + averaged_predictions = np.mean(all_predictions, axis=0) + return averaged_predictions + else: + if not isinstance(model, keras.Model): + raise ValueError("Model must be a Keras model or a list of Keras models.") - return predictions + predictions = model.predict(input, steps=n_predict_steps, **kwargs) + return predictions diff --git a/tests/test_tl.py b/tests/test_tl.py index 2bd6fb1..6f33bde 100644 --- a/tests/test_tl.py +++ b/tests/test_tl.py @@ -8,7 +8,7 @@ import pytest import crested -from crested.tl._tools import _transform_input, detect_input_type +from crested.tl._tools import _detect_input_type, _transform_input from ._utils import create_anndata_with_regions @@ -19,7 +19,7 @@ def keras_model(): model = simple_convnet( seq_len=500, - num_classes=10, + num_classes=5, ) model.compile( optimizer=keras.optimizers.Adam(0.001), @@ -56,25 +56,25 @@ def genome(): def test_input_type(adata): - assert detect_input_type(adata) == "anndata" - assert detect_input_type("chr1:1-100") == "region" - assert detect_input_type("ACGT") == "sequence" - assert detect_input_type(["chr1:1-100", "chr1:101-200"]) == "region" - assert detect_input_type(["ACGT", "ACGT"]) == "sequence" + assert _detect_input_type(adata) == "anndata" + assert _detect_input_type("chr1:1-100") == "region" + assert _detect_input_type("ACGT") == "sequence" + assert _detect_input_type(["chr1:1-100", "chr1:101-200"]) == "region" + assert _detect_input_type(["ACGT", "ACGT"]) == "sequence" with pytest.raises(ValueError): - detect_input_type(["chr1:1-100", "ACGT"]) + _detect_input_type(["chr1:1-100", "ACGT"]) with pytest.raises(ValueError): - detect_input_type(["chr1:1-100", 1]) + _detect_input_type(["chr1:1-100", 1]) with pytest.raises(ValueError): - detect_input_type([1, 2]) + _detect_input_type([1, 2]) with pytest.raises(ValueError): - detect_input_type([1, "ACGT"]) + _detect_input_type([1, "ACGT"]) with pytest.raises(ValueError): - detect_input_type(1) + _detect_input_type(1) with pytest.raises(ValueError): - detect_input_type(1.0) + _detect_input_type(1.0) with pytest.raises(ValueError): - detect_input_type(None) + _detect_input_type(None) def test_input_transform(genome): @@ -114,7 +114,11 @@ def test_get_embeddings(keras_model, genome): def test_predict(keras_model, adata, genome): input = "ATCGA" * 100 predictions = crested.tl.predict(input, keras_model) - assert predictions.shape == (1, 10) + assert predictions.shape == (1, 5) predictions = crested.tl.predict(input=adata, model=keras_model, genome=genome) - assert predictions.shape == (10, 10) + assert predictions.shape == (10, 5) + + models = [keras_model, keras_model] + predictions = crested.tl.predict(input=adata, model=models, genome=genome) + assert predictions.shape == (10, 5) From 8647d63e0e9b088c13b8ebbe3296bad058930ddc Mon Sep 17 00:00:00 2001 From: LukasMahieu Date: Wed, 27 Nov 2024 14:36:17 +0100 Subject: [PATCH 05/16] score_gene_locus + unit tests in tooling ml and removed window_size --- src/crested/tl/__init__.py | 3 +- src/crested/tl/_tools.py | 133 +++++++++++++++++++++++++++++++++++++ tests/test_tl.py | 67 ++++++------------- 3 files changed, 154 insertions(+), 49 deletions(-) diff --git a/src/crested/tl/__init__.py b/src/crested/tl/__init__.py index b1cabbe..2562372 100644 --- a/src/crested/tl/__init__.py +++ b/src/crested/tl/__init__.py @@ -7,7 +7,7 @@ from . import data, losses, metrics, zoo from ._configs import TaskConfig, default_configs from ._crested import Crested -from ._tools import get_embeddings, predict +from ._tools import get_embeddings, predict, score_gene_locus if find_spec("modiscolite") is not None: MODISCOLITE_AVAILABLE = True @@ -38,6 +38,7 @@ "Crested", "get_embeddings", "predict", + "score_gene_locus", ] if MODISCOLITE_AVAILABLE: diff --git a/src/crested/tl/_tools.py b/src/crested/tl/_tools.py index 7daaef5..8f6867f 100644 --- a/src/crested/tl/_tools.py +++ b/src/crested/tl/_tools.py @@ -193,3 +193,136 @@ def predict( predictions = model.predict(input, steps=n_predict_steps, **kwargs) return predictions + + +def score_gene_locus( + gene_locus: str, + all_class_names: list[str], + class_name: str, + model: keras.Model | list[keras.Model], + genome: os.PathLike, + strand: str = "+", + upstream: int = 50000, + downstream: int = 10000, + central_size: int = 1000, + step_size: int = 50, + **kwargs, +) -> tuple[np.ndarray, np.ndarray, int, int, int]: + """ + Score regions upstream and downstream of a gene locus using the model's prediction. + + The model predicts a value for the {central_size} of each window. + + Parameters + ---------- + gene_locus + The gene locus to score in the format 'chr:start-end'. + Start is the TSS for + strand and TES for - strand. + all_class_names + List of all class names in the model. Usually obtained with `list(anndata.obs_names)`. + class_name + Output class name to be used for prediction. Required to index the predictions. + model + A (list of) trained keras model(s) to make predictions with. + genome + Path to the genome file. + strand + '+' for positive strand, '-' for negative strand. Default '+'. + upstream + Distance upstream of the gene to score. + downstream + Distance downstream of the gene to score. + central_size + Size of the central region that the model predicts for. + step_size + Distance between consecutive windows. + **kwargs + Additional keyword arguments to pass to the keras.Model.predict method. + + Returns + ------- + scores + An array of prediction scores across the entire genomic range. + coordinates + An array of tuples, each containing the chromosome name and the start and end positions of the sequence for each window. + min_loc + Start position of the entire scored region. + max_loc + End position of the entire scored region. + tss_position + The transcription start site (TSS) position. + + See Also + -------- + crested.tl.predict + """ + chr_name, gene_locus = gene_locus.split(":") + gene_start, gene_end = map(int, gene_locus.split("-")) + + # Detect window size from the model input shape + if isinstance(model, list): + input_shape = model[0].input_shape + else: + input_shape = model.input_shape + window_size = input_shape[1] + + # Adjust upstream and downstream based on the strand + if strand == "+": + start_position = gene_start - upstream + end_position = gene_end + downstream + tss_position = gene_start # TSS is at the gene_start for positive strand + elif strand == "-": + end_position = gene_end + upstream + start_position = gene_start - downstream + tss_position = gene_end # TSS is at the gene_end for negative strand + else: + raise ValueError("Strand must be '+' or '-'.") + + start_position = max(0, start_position) + total_length = abs(end_position - start_position) + + # Ratio to normalize the score contributions + ratio = central_size / step_size + + try: + idx = all_class_names.index(class_name) + except ValueError as e: + raise ValueError( + f"Class name '{class_name}' not found in all_class_names" + ) from e + positions = np.arange(start_position, end_position - window_size + 1, step_size) + + all_regions = [ + f"{chr_name}:{pos}-{pos + window_size}" + for pos in range(start_position, end_position, step_size) + if pos + window_size <= end_position + ] + predictions = predict(input=all_regions, model=model, genome=genome, **kwargs) + predictions_class = predictions[:, idx] + + # Map predictions to the score array + scores = np.zeros(total_length) + for _, (pos, pred) in enumerate(zip(positions, predictions_class)): + central_start = pos + (window_size - central_size) // 2 + central_end = central_start + central_size + + # Compute indices relative to the scores array + relative_start = central_start - start_position + relative_end = central_end - start_position + + # Add the prediction to the scores array + scores[relative_start:relative_end] += pred + + window_starts = positions + window_ends = positions + window_size + coordinates = np.array( + list(zip([chr_name] * len(positions), window_starts, window_ends)) + ) + # Normalize the scores based on the number of times each position is included in the central window + return ( + scores / ratio, + coordinates, + start_position, + end_position, + tss_position, + ) diff --git a/tests/test_tl.py b/tests/test_tl.py index 6f33bde..0a5ae62 100644 --- a/tests/test_tl.py +++ b/tests/test_tl.py @@ -1,59 +1,11 @@ """Test tl module.""" -import os - -import genomepy -import keras import numpy as np import pytest import crested from crested.tl._tools import _detect_input_type, _transform_input -from ._utils import create_anndata_with_regions - - -@pytest.fixture(scope="module") -def keras_model(): - from crested.tl.zoo import simple_convnet - - model = simple_convnet( - seq_len=500, - num_classes=5, - ) - model.compile( - optimizer=keras.optimizers.Adam(0.001), - loss=keras.losses.MeanSquaredError(), - metrics=[keras.metrics.CategoricalAccuracy()], - ) - return model - - -@pytest.fixture(scope="module") -def adata(): - regions = [ - "chr1:194208032-194208532", - "chr1:92202766-92203266", - "chr1:92298990-92299490", - "chr1:3406052-3406552", - "chr1:183669567-183670067", - "chr1:109912183-109912683", - "chr1:92210697-92211197", - "chr1:59100954-59101454", - "chr1:84634055-84634555", - "chr1:48792527-48793027", - ] - return create_anndata_with_regions(regions) - - -@pytest.fixture(scope="module") -def genome(): - if not os.path.exists("tests/data/genomes/hg38.fa"): - genomepy.install_genome( - "hg38", annotation=False, provider="UCSC", genomes_dir="tests/data/genomes" - ) - return "tests/data/genomes/hg38/hg38.fa" - def test_input_type(adata): assert _detect_input_type(adata) == "anndata" @@ -122,3 +74,22 @@ def test_predict(keras_model, adata, genome): models = [keras_model, keras_model] predictions = crested.tl.predict(input=adata, model=models, genome=genome) assert predictions.shape == (10, 5) + + +def test_score_gene_locus(keras_model, adata, genome): + gene_locus = "chr1:200000-200500" + scores, coordinates, min_loc, max_loc, tss_pos = crested.tl.score_gene_locus( + gene_locus=gene_locus, + all_class_names=list(adata.obs_names), + class_name=list(adata.obs_names)[0], + model=keras_model, + genome=genome, + upstream=1000, + downstream=1000, + step_size=500, + ) + assert scores.shape == (2500,), scores.shape + assert coordinates.shape == (int(2500 / 500), 3), coordinates.shape + assert min_loc == 199000 + assert max_loc == 201500 + assert tss_pos == 200000 From c1d2f8ae7a728e59a7ecf6d4c14fea9d1ae2358d Mon Sep 17 00:00:00 2001 From: LukasMahieu Date: Wed, 27 Nov 2024 14:37:39 +0100 Subject: [PATCH 06/16] correctly use test fixtures --- tests/conftest.py | 54 ++++++++++++++++++++++++++++++++++++++++++ tests/test_pipeline.py | 22 +---------------- 2 files changed, 55 insertions(+), 21 deletions(-) create mode 100644 tests/conftest.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..9f065c3 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,54 @@ +"""Fixtures to be used by all unit tests.""" + +import os + +import genomepy +import keras +import pytest + +from ._utils import create_anndata_with_regions + + +@pytest.fixture(scope="module") +def keras_model(): + """Keras model fixture.""" + from crested.tl.zoo import simple_convnet + + model = simple_convnet( + seq_len=500, + num_classes=5, + ) + model.compile( + optimizer=keras.optimizers.Adam(0.001), + loss=keras.losses.MeanSquaredError(), + metrics=[keras.metrics.CategoricalAccuracy()], + ) + return model + + +@pytest.fixture(scope="module") +def adata(): + """Anndata fixture.""" + regions = [ + "chr1:194208032-194208532", + "chr1:92202766-92203266", + "chr1:92298990-92299490", + "chr1:3406052-3406552", + "chr1:183669567-183670067", + "chr1:109912183-109912683", + "chr1:92210697-92211197", + "chr1:59100954-59101454", + "chr1:84634055-84634555", + "chr1:48792527-48793027", + ] + return create_anndata_with_regions(regions) + + +@pytest.fixture(scope="module") +def genome(): + """Genome fixture.""" + if not os.path.exists("tests/data/genomes/hg38.fa"): + genomepy.install_genome( + "hg38", annotation=False, provider="UCSC", genomes_dir="tests/data/genomes" + ) + return "tests/data/genomes/hg38/hg38.fa" diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index a455e9a..92942a0 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -7,25 +7,8 @@ import crested -from ._utils import create_anndata_with_regions - -REGIONS = [ - "chr1:194208032-194208532", - "chr1:92202766-92203266", - "chr1:92298990-92299490", - "chr1:3406052-3406552", - "chr1:183669567-183670067", - "chr1:109912183-109912683", - "chr1:92210697-92211197", - "chr1:59100954-59101454", - "chr1:84634055-84634555", - "chr1:48792527-48793027", -] - - -def test_peak_regression(): - adata = create_anndata_with_regions(REGIONS) +def test_peak_regression(adata): crested.pp.change_regions_width(adata, width=600) crested.pp.train_val_test_split( adata, strategy="region", val_size=0.1, test_size=0.1 @@ -34,9 +17,6 @@ def test_peak_regression(): genomepy.install_genome( "hg38", annotation=False, provider="UCSC", genomes_dir="tests/data/genomes" ) - print(adata) - print(adata.var) - print(adata.var_names) if os.path.exists("tests/data/test_pipeline"): import shutil From 3a43d0c7c428864d86a9253511db183448eece8f Mon Sep 17 00:00:00 2001 From: LukasMahieu Date: Wed, 27 Nov 2024 14:38:06 +0100 Subject: [PATCH 07/16] test units that ensure functionality is the same as before --- tests/test_refactor.py | 118 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 118 insertions(+) create mode 100644 tests/test_refactor.py diff --git a/tests/test_refactor.py b/tests/test_refactor.py new file mode 100644 index 0000000..3c7e192 --- /dev/null +++ b/tests/test_refactor.py @@ -0,0 +1,118 @@ +"""Test that ensures the outputs after the functional refactor are the same as before.""" + +import keras +import numpy as np +import pytest + +from crested.tl import Crested, get_embeddings, predict, score_gene_locus +from crested.tl.data import AnnDataModule + +np.random.seed(42) +keras.utils.set_random_seed(42) + + +@pytest.fixture(scope="module") +def crested_object(keras_model, adata, genome): + anndatamodule = AnnDataModule( + adata, + genome_file=genome, + batch_size=32, + always_reverse_complement=False, + deterministic_shift=False, + shuffle=False, + ) + crested_object = Crested( + data=anndatamodule, + ) + crested_object.model = keras_model + return crested_object + + +def test_predict_adata(adata, crested_object, keras_model, genome): + crested_object_preds = crested_object.predict() + refactored_preds = predict(adata, keras_model, genome) + assert np.allclose( + crested_object_preds, + refactored_preds, + atol=1e-4, + ), "Anndata predictions are not equal." + + +def test_predict_sequence(crested_object, keras_model): + sequence = "ATCGA" * 100 + crested_object_preds = crested_object.predict_sequence(sequence) + refactored_preds = predict(sequence, keras_model) + assert np.allclose( + crested_object_preds, + refactored_preds, + atol=1e-4, + ), "Sequence predictions are not equal" + + +def test_predict_regions(crested_object, keras_model, genome): + regions = ["chr1:1-501", "chr1:101-601"] + crested_object_preds = crested_object.predict_regions(region_idx=regions) + refactored_preds = predict(regions, keras_model, genome) + assert np.allclose( + crested_object_preds, + refactored_preds, + atol=1e-4, + ), "Region predictions are not equal" + + +def test_get_embeddings(adata, crested_object, keras_model, genome): + crested_object_embeddings = crested_object.get_embeddings( + layer_name="denseblock_dense" + ) + refactored_embeddings = get_embeddings( + input=adata, + model=keras_model, + genome=genome, + layer_name="denseblock_dense", + ) + assert np.allclose( + crested_object_embeddings, + refactored_embeddings, + atol=1e-4, + ), "Embeddings are not equal." + + +def test_score_gene_locus(crested_object, adata, keras_model, genome): + chrom_name = "chr1" + gene_start = "2000000" + gene_end = "2002000" + scores, coordinates, min_loc, max_loc, tss_pos = crested_object.score_gene_locus( + chr_name=chrom_name, + gene_start=int(gene_start), + gene_end=int(gene_end), + class_name=list(adata.obs_names)[0], + window_size=500, + downstream=2000, + upstream=2000, + ) + ( + ref_scores, + ref_coordinates, + ref_min_loc, + ref_max_loc, + ref_tss_pos, + ) = score_gene_locus( + gene_locus=f"{chrom_name}:{gene_start}-{gene_end}", + all_class_names=list(adata.obs_names), + class_name=list(adata.obs_names)[0], + model=keras_model, + genome=genome, + downstream=2000, + upstream=2000, + ) + assert np.allclose( + scores, + ref_scores, + atol=1e-4, + ), "Scores are not equal." + assert ( + a == b for a, b in zip(coordinates, ref_coordinates) + ), "Coordinates are not equal." + assert min_loc == ref_min_loc, "Minimum location is not equal." + assert max_loc == ref_max_loc, "Maximum location is not equal." + assert tss_pos == ref_tss_pos, "TSS position is not equal." From 76355d8e5ca4e39dbb298ad9a4e56b1fedf06089 Mon Sep 17 00:00:00 2001 From: LukasMahieu Date: Wed, 27 Nov 2024 16:35:15 +0100 Subject: [PATCH 08/16] contribution scores to tl ml + unit tests --- src/crested/tl/__init__.py | 3 +- src/crested/tl/_tools.py | 120 +++++++++++++++++++++++++++++++++++++ tests/test_refactor.py | 89 ++++++++++++++++++++++++++- tests/test_tl.py | 39 ++++++++++++ 4 files changed, 249 insertions(+), 2 deletions(-) diff --git a/src/crested/tl/__init__.py b/src/crested/tl/__init__.py index 2562372..055668b 100644 --- a/src/crested/tl/__init__.py +++ b/src/crested/tl/__init__.py @@ -7,7 +7,7 @@ from . import data, losses, metrics, zoo from ._configs import TaskConfig, default_configs from ._crested import Crested -from ._tools import get_embeddings, predict, score_gene_locus +from ._tools import contribution_scores, get_embeddings, predict, score_gene_locus if find_spec("modiscolite") is not None: MODISCOLITE_AVAILABLE = True @@ -38,6 +38,7 @@ "Crested", "get_embeddings", "predict", + "contribution_scores", "score_gene_locus", ] diff --git a/src/crested/tl/_tools.py b/src/crested/tl/_tools.py index 8f6867f..dd511c0 100644 --- a/src/crested/tl/_tools.py +++ b/src/crested/tl/_tools.py @@ -8,12 +8,19 @@ import keras import numpy as np from anndata import AnnData +from loguru import logger +from tqdm import tqdm from crested.utils import ( fetch_sequences, one_hot_encode_sequence, ) +if os.environ["KERAS_BACKEND"] == "tensorflow": + from crested.tl._explainer_tf import Explainer +elif os.environ["KERAS_BACKEND"] == "torch": + from crested.tl._explainer_torch import Explainer + def _detect_input_type(input): """ @@ -326,3 +333,116 @@ def score_gene_locus( end_position, tss_position, ) + + +def contribution_scores( + input: str | list[str] | np.array | AnnData, + model: keras.Model | list[keras.Model], + class_names: str | list[str], + all_class_names: list[str], + method: str = "expected_integrated_grad", + genome: os.PathLike | None = None, + disable_tqdm: bool = True, +) -> tuple[np.ndarray, np.ndarray]: + """ + Calculate contribution scores based on given method for the specified inputs. + + If mutliple models are provided, the contribution scores will be averaged across all models. + + These scores can then be plotted to visualize the importance of each base in the sequence + using :func:`~crested.pl.patterns.contribution_scores`. + + Parameters + ---------- + input + Input data to calculate the contribution scores for. Can be a (list of) sequence(s), a (list of) region name(s), a matrix of one hot encodings (N, L, 4), or an AnnData object with region names as its var_names. + model + A (list of) trained keras model(s) to calculate the contribution scores for. + class_names + List of class names to calculate the contribution scores for (should match anndata.obs_names) + If the list is empty, the contribution scores for the 'combined' class will be calculated. + all_class_names + List of all class names in the model. Usually obtained with `list(anndata.obs_names)`. + method + Method to use for calculating the contribution scores. + Options are: 'integrated_grad', 'mutagenesis', 'expected_integrated_grad'. + genome + Path to the genome file. Required if input is an anndata object or region names. + disable_tqdm + Boolean for disabling the plotting progress of calculations using tqdm. + + Returns + ------- + Contribution scores (N, C, L, 4) and one-hot encoded sequences (N, L, 4). + + See Also + -------- + crested.pl.patterns.contribution_scores + """ + if isinstance(class_names, str): + class_names = [class_names] + + if not set(class_names).issubset(all_class_names): + raise ValueError( + "class_names should be a subset of all_class_names. Use list(anndata.obs_names) to get all class names." + ) + + if len(class_names) > 0: + n_classes = len(class_names) + class_indices = [ + all_class_names.index(class_name) for class_name in class_names + ] + else: + logger.warning( + "No class names provided. Calculating contribution scores for the 'combined' class." + ) + n_classes = 1 # 'combined' class + class_indices = [None] + + input_sequences = _transform_input(input, genome) + + logger.info( + f"Calculating contribution scores for {n_classes} class(es) and {input_sequences.shape[0]} region(s)." + ) + if not isinstance(model, list): + model = [model] + N, L, D = input_sequences.shape + + # Initialize list to collect scores from each model + scores_per_model = [] + + # Iterate over models + for m in tqdm(model, desc="Model", disable=disable_tqdm): + # Initialize scores for this model + scores = np.zeros((N, n_classes, L, D)) # Shape: (N, C, L, 4) + + for i, class_index in enumerate(class_indices): + # Initialize the explainer for the current model and class index + explainer = Explainer(m, class_index=class_index) + + # Calculate contribution scores based on the selected method + if method == "integrated_grad": + scores[:, i, :, :] = explainer.integrated_grad( + input_sequences, + baseline_type="zeros", + ) + elif method == "mutagenesis": + scores[:, i, :, :] = explainer.mutagenesis( + input_sequences, + class_index=class_index, + ) + elif method == "expected_integrated_grad": + scores[:, i, :, :] = explainer.expected_integrated_grad( + input_sequences, + num_baseline=25, + ) + else: + raise ValueError(f"Unsupported method: {method}") + + # Collect scores from this model + scores_per_model.append(scores) + + # Average the scores across models + averaged_scores = np.mean(scores_per_model, axis=0) # Shape: (N, C, L, 4) + + return averaged_scores, input_sequences diff --git a/tests/test_refactor.py b/tests/test_refactor.py index 3c7e192..c7cf83e 100644 --- a/tests/test_refactor.py +++ b/tests/test_refactor.py @@ -4,7 +4,13 @@ import numpy as np import pytest -from crested.tl import Crested, get_embeddings, predict, score_gene_locus +from crested.tl import ( + Crested, + contribution_scores, + get_embeddings, + predict, + score_gene_locus, +) from crested.tl.data import AnnDataModule np.random.seed(42) @@ -116,3 +122,84 @@ def test_score_gene_locus(crested_object, adata, keras_model, genome): assert min_loc == ref_min_loc, "Minimum location is not equal." assert max_loc == ref_max_loc, "Maximum location is not equal." assert tss_pos == ref_tss_pos, "TSS position is not equal." + + +def test_contribution_scores_region(crested_object, adata, keras_model, genome): + region = ["chr1:1-501", "chr1:101-601"] + ( + scores, + one_hot_encoded_sequences, + ) = crested_object.calculate_contribution_scores_regions( + region_idx=region, + class_names=list(adata.obs_names)[0:2], + method="integrated_grad", + ) + scores_refactored, one_hot_encoded_sequences_refactored = contribution_scores( + input=region, + model=keras_model, + genome=genome, + class_names=list(adata.obs_names)[0:2], + all_class_names=list(adata.obs_names), + method="integrated_grad", + ) + assert np.allclose( + scores, + scores_refactored, + atol=1e-5, + ), "Scores are not equal." + assert np.array_equal( + one_hot_encoded_sequences, + one_hot_encoded_sequences_refactored, + ), "One-hot encoded sequences are not equal" + + +def test_contribution_scores_sequence(crested_object, keras_model, adata): + sequence = "ATCGA" * 100 + ( + scores, + one_hot_encoded_sequences, + ) = crested_object.calculate_contribution_scores_sequence( + sequence, + class_names=list(adata.obs_names)[0:2], + method="integrated_grad", + ) + scores_refactored, one_hot_encoded_sequences_refactored = contribution_scores( + input=sequence, + model=keras_model, + class_names=list(adata.obs_names)[0:2], + all_class_names=list(adata.obs_names), + method="integrated_grad", + ) + assert np.allclose( + scores, + scores_refactored, + atol=1e-5, + ), "Scores are not equal." + assert np.array_equal( + one_hot_encoded_sequences, + one_hot_encoded_sequences_refactored, + ), "One-hot encoded sequences are not equal" + + +def test_contribution_scores_adata(crested_object, adata, keras_model, genome): + scores, one_hot_encoded_sequences = crested_object.calculate_contribution_scores( + class_names=list(adata.obs_names)[0:2], + method="integrated_grad", + ) + scores_refactored, one_hot_encoded_sequences_refactored = contribution_scores( + input=adata, + model=keras_model, + class_names=list(adata.obs_names)[0:2], + all_class_names=list(adata.obs_names), + method="integrated_grad", + genome=genome, + ) + assert np.allclose( + scores, + scores_refactored, + atol=1e-5, + ), "Scores are not equal." + assert np.array_equal( + one_hot_encoded_sequences, + one_hot_encoded_sequences_refactored, + ), "One-hot encoded sequences are not equal" diff --git a/tests/test_tl.py b/tests/test_tl.py index 0a5ae62..6d2ee77 100644 --- a/tests/test_tl.py +++ b/tests/test_tl.py @@ -93,3 +93,42 @@ def test_score_gene_locus(keras_model, adata, genome): assert min_loc == 199000 assert max_loc == 201500 assert tss_pos == 200000 + + +def test_contribution_scores(keras_model, adata, genome): + sequence = "ATCGA" * 100 + scores, one_hot_encoded_sequences = crested.tl.contribution_scores( + sequence, + model=keras_model, + genome=genome, + class_names=list(adata.obs_names)[0:2], + all_class_names=list(adata.obs_names), + method="integrated_grad", + ) + assert scores.shape == (1, 2, 500, 4) + assert one_hot_encoded_sequences.shape == (1, 500, 4) + + sequences = ["ATCGA" * 100, "ATCGA" * 100] + scores, one_hot_encoded_sequences = crested.tl.contribution_scores( + sequences, + model=keras_model, + genome=genome, + class_names=list(adata.obs_names)[0:2], + all_class_names=list(adata.obs_names), + method="integrated_grad", + ) + assert scores.shape == (2, 2, 500, 4) + assert one_hot_encoded_sequences.shape == (2, 500, 4) + + # multiple models + models = [keras_model, keras_model] + scores, one_hot_encoded_sequences = crested.tl.contribution_scores( + sequence, + model=models, + genome=genome, + class_names=list(adata.obs_names)[0:2], + all_class_names=list(adata.obs_names), + method="integrated_grad", + ) + assert scores.shape == (1, 2, 500, 4) + assert one_hot_encoded_sequences.shape == (1, 500, 4) From e5b3b0e02e7ef9fab9890f4befd642fc5e34982a Mon Sep 17 00:00:00 2001 From: LukasMahieu Date: Thu, 28 Nov 2024 18:00:39 +0100 Subject: [PATCH 09/16] specific contrib scores (for modisco) to tooling + unit tests --- .gitignore | 3 +- src/crested/tl/__init__.py | 9 ++- src/crested/tl/_tools.py | 129 +++++++++++++++++++++++++++++++++++-- tests/conftest.py | 17 +++++ tests/test_refactor.py | 68 +++++++++++++++++++ tests/test_tl.py | 64 ++++++++++++++++++ 6 files changed, 282 insertions(+), 8 deletions(-) diff --git a/.gitignore b/.gitignore index a7df95a..582e38c 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ wandb/ *slurm* tests/data/genomes tests/data/test_pipeline +tests/data/test_contribution_scores/ # Sphinx documentation _build @@ -12,4 +13,4 @@ node_modules .DS_Store ._.DS_Store docs/tutorials/mouse_biccn.ipynb -docs/tutorials/.ipynb_checkpoints \ No newline at end of file +docs/tutorials/.ipynb_checkpoints diff --git a/src/crested/tl/__init__.py b/src/crested/tl/__init__.py index 055668b..2e23056 100644 --- a/src/crested/tl/__init__.py +++ b/src/crested/tl/__init__.py @@ -7,7 +7,13 @@ from . import data, losses, metrics, zoo from ._configs import TaskConfig, default_configs from ._crested import Crested -from ._tools import contribution_scores, get_embeddings, predict, score_gene_locus +from ._tools import ( + contribution_scores, + contribution_scores_specific, + get_embeddings, + predict, + score_gene_locus, +) if find_spec("modiscolite") is not None: MODISCOLITE_AVAILABLE = True @@ -39,6 +45,7 @@ "get_embeddings", "predict", "contribution_scores", + "contribution_scores_specific", "score_gene_locus", ] diff --git a/src/crested/tl/_tools.py b/src/crested/tl/_tools.py index dd511c0..0092a67 100644 --- a/src/crested/tl/_tools.py +++ b/src/crested/tl/_tools.py @@ -342,7 +342,9 @@ def contribution_scores( all_class_names: list[str], method: str = "expected_integrated_grad", genome: os.PathLike | None = None, - disable_tqdm: bool = True, + transpose: bool = False, + output_dir: os.PathLike | None = None, + verbose: bool = True, ) -> tuple[np.ndarray, np.ndarray]: """ Calculate contribution scores based on given method for the specified inputs. @@ -368,7 +370,12 @@ def contribution_scores( Options are: 'integrated_grad', 'mutagenesis', 'expected_integrated_grad'. genome Path to the genome file. Required if input is an anndata object or region names. - disable_tqdm + transpose + Transpose the contribution scores to (N, C, 4, L) and one hots to (N, 4, L) (for compatibility with MoDISco). + output_dir + Path to the output directory to save the contribution scores and one hot seqs. + Will create a separate npz file per class. + verbose Boolean for disabling the plotting progress of calculations using tqdm. Returns @@ -399,11 +406,13 @@ def contribution_scores( n_classes = 1 # 'combined' class class_indices = [None] + # Handle all other input types input_sequences = _transform_input(input, genome) - logger.info( - f"Calculating contribution scores for {n_classes} class(es) and {input_sequences.shape[0]} region(s)." - ) + if verbose: + logger.info( + f"Calculating contribution scores for {n_classes} class(es) and {input_sequences.shape[0]} region(s)." + ) if not isinstance(model, list): model = [model] N, L, D = input_sequences.shape @@ -412,7 +421,7 @@ def contribution_scores( scores_per_model = [] # Iterate over models - for m in tqdm(model, desc="Model", disable=disable_tqdm): + for m in tqdm(model, desc="Model", disable=not verbose): # Initialize scores for this model scores = np.zeros((N, n_classes, L, D)) # Shape: (N, C, L, 4) @@ -445,4 +454,112 @@ def contribution_scores( # Average the scores across models averaged_scores = np.mean(scores_per_model, axis=0) # Shape: (N, C, L, 4) + if transpose: + averaged_scores = np.transpose(averaged_scores, (0, 1, 3, 2)) + input_sequences = np.transpose(input_sequences, (0, 2, 1)) + + if output_dir is not None: + os.makedirs(output_dir, exist_ok=True) + for i, class_name in enumerate(class_names): + np.savez_compressed( + os.path.join(output_dir, f"{class_name}_contrib.npz"), + averaged_scores[:, i, :, :], + ) + np.savez_compressed( + os.path.join(output_dir, f"{class_name}_oh.npz"), + input_sequences, + ) + return averaged_scores, input_sequences + + +def contribution_scores_specific( + input: AnnData, + model: keras.Model | list[keras.Model], + genome: os.PathLike, + class_names: str | list[str] | None = None, + method: str = "expected_integrated_grad", + transpose: bool = True, + output_dir: os.PathLike | None = None, + verbose: bool = True, +) -> tuple[np.ndarray, np.ndarray]: + """ + Calculate contribution scores based on given method only for the most specific regions per class. + + Contrary to :func:`~crested.tl.contribution_scores`, this function will only calculate one set of contribution scores per region per class. + Expects the user to have ran `:func:~crested.pp.sort_and_filter_regions_on_specificity` beforehand. + + If multiple models are provided, the contribution scores will be averaged across all models. + + These scores can then be plotted to visualize the importance of each base in the sequence + using :func:`~crested.pl.patterns.contribution_scores`. + + Parameters + ---------- + input + Input anndata to calculate the contribution scores for. Should have a 'Class name' column in .var. + model + A (list of) trained keras model(s) to calculate the contribution scores for. + class_names + List of class names to calculate the contribution scores for (should match anndata.obs_names) + If None, the contribution scores for all classes will be calculated. + genome + Path to the genome file. + method + Method to use for calculating the contribution scores. + Options are: 'integrated_grad', 'mutagenesis', 'expected_integrated_grad'. + transpose + Transpose the contribution scores to (N, C, 4, L) and one hots to (N, 4, L) (for compatibility with MoDISco). + Defaults to True here since that is what modisco expects. + output_dir + Path to the output directory to save the contribution scores and one hot seqs. + Will create a separate npz file per class. + verbose + Boolean for disabling the plotting progress of calculations using tqdm. + + Returns + ------- + Contribution scores (N, 1, L, 4) and one-hot encoded sequences (N, L, 4). + Since each region is specific to a class, the contribution scores are only calculated for that class. + + See Also + -------- + crested.pl.patterns.contribution_scores + crested.pp.sort_and_filter_regions_on_specificity + """ + assert isinstance(input, AnnData), "Input should be an anndata object." + if "Class name" not in input.var.columns: + raise ValueError( + "Run 'crested.pp.sort_and_filter_regions_on_specificity' first" + ) + all_class_names = list(input.obs_names) + if isinstance(class_names, str): + class_names = [class_names] + if class_names is None: + class_names = all_class_names + if class_names == []: + raise ValueError("Can't calculate 'combined' scores for specific regions.") + all_scores = [] + all_one_hots = [] + + for class_name in class_names: + class_regions = input.var[input.var["Class name"] == class_name].index.tolist() + scores, one_hots = contribution_scores( + input=class_regions, + model=model, + class_names=[class_name], + all_class_names=all_class_names, + method=method, + genome=genome, + verbose=verbose, + output_dir=output_dir, + transpose=transpose, + ) + all_scores.append(scores) + all_one_hots.append(one_hots) + + # Concatenate results across all classes + return ( + np.concatenate(all_scores, axis=0), + np.concatenate(all_one_hots, axis=0), + ) diff --git a/tests/conftest.py b/tests/conftest.py index 9f065c3..083e47a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,10 +4,16 @@ import genomepy import keras +import numpy as np import pytest +import crested + from ._utils import create_anndata_with_regions +np.random.seed(42) +keras.utils.set_random_seed(42) + @pytest.fixture(scope="module") def keras_model(): @@ -52,3 +58,14 @@ def genome(): "hg38", annotation=False, provider="UCSC", genomes_dir="tests/data/genomes" ) return "tests/data/genomes/hg38/hg38.fa" + + +@pytest.fixture(scope="module") +def adata_specific(): + """Specific anndata fixture.""" + ann_data = crested.import_bigwigs( + bigwigs_folder="tests/data/test_bigwigs", + regions_file="tests/data/test_bigwigs/consensus_peaks_subset.bed", + ) + crested.pp.sort_and_filter_regions_on_specificity(ann_data, top_k=3) + return ann_data diff --git a/tests/test_refactor.py b/tests/test_refactor.py index c7cf83e..5224eac 100644 --- a/tests/test_refactor.py +++ b/tests/test_refactor.py @@ -1,5 +1,7 @@ """Test that ensures the outputs after the functional refactor are the same as before.""" +import os + import keras import numpy as np import pytest @@ -7,6 +9,7 @@ from crested.tl import ( Crested, contribution_scores, + contribution_scores_specific, get_embeddings, predict, score_gene_locus, @@ -16,6 +19,11 @@ np.random.seed(42) keras.utils.set_random_seed(42) +if os.environ["KERAS_BACKEND"] == "tensorflow": + import tensorflow as tf + + tf.config.experimental.enable_op_determinism() + @pytest.fixture(scope="module") def crested_object(keras_model, adata, genome): @@ -34,6 +42,23 @@ def crested_object(keras_model, adata, genome): return crested_object +@pytest.fixture(scope="module") +def crested_object_specific(keras_model, adata_specific, genome): + anndatamodule = AnnDataModule( + adata_specific, + genome_file=genome, + batch_size=32, + always_reverse_complement=False, + deterministic_shift=False, + shuffle=False, + ) + crested_object = Crested( + data=anndatamodule, + ) + crested_object.model = keras_model + return crested_object + + def test_predict_adata(adata, crested_object, keras_model, genome): crested_object_preds = crested_object.predict() refactored_preds = predict(adata, keras_model, genome) @@ -203,3 +228,46 @@ def test_contribution_scores_adata(crested_object, adata, keras_model, genome): one_hot_encoded_sequences, one_hot_encoded_sequences_refactored, ), "One-hot encoded sequences are not equal" + + +def test_contribution_scores_modisco( + crested_object_specific, adata_specific, keras_model, genome +): + class_names = list(adata_specific.obs_names)[0:2] + crested_object_specific.tfmodisco_calculate_and_save_contribution_scores( + adata_specific, + class_names=class_names, + method="integrated_grad", + output_dir="tests/data/test_contribution_scores", + ) + # load scores and oh + scores = np.load( + f"tests/data/test_contribution_scores/{class_names[0]}_contrib.npz" + )["arr_0"] + one_hots = np.load(f"tests/data/test_contribution_scores/{class_names[0]}_oh.npz")[ + "arr_0" + ] + _, _ = contribution_scores_specific( + input=adata_specific, + model=keras_model, + class_names=class_names, + method="integrated_grad", + genome=genome, + output_dir="tests/data/test_contribution_scores", + transpose=True, + ) + scores_refactored = np.load( + f"tests/data/test_contribution_scores/{class_names[0]}_contrib.npz" + )["arr_0"] + one_hot_encoded_sequences_refactored = np.load( + f"tests/data/test_contribution_scores/{class_names[0]}_oh.npz" + )["arr_0"] + assert np.allclose( + scores, + scores_refactored, + atol=1e-5, + ), "Scores are not equal." + assert np.array_equal( + one_hots, + one_hot_encoded_sequences_refactored, + ), "One-hot encoded sequences are not equal" diff --git a/tests/test_tl.py b/tests/test_tl.py index 6d2ee77..3287c28 100644 --- a/tests/test_tl.py +++ b/tests/test_tl.py @@ -1,5 +1,7 @@ """Test tl module.""" +import os + import numpy as np import pytest @@ -132,3 +134,65 @@ def test_contribution_scores(keras_model, adata, genome): ) assert scores.shape == (1, 2, 500, 4) assert one_hot_encoded_sequences.shape == (1, 500, 4) + + +def test_contribution_scores_specific(keras_model, adata, adata_specific, genome): + with pytest.raises(ValueError): + # class names can't be empty for specific + crested.tl.contribution_scores_specific( + input=adata_specific, + model=keras_model, + class_names=[], # combined class + genome=genome, + method="integrated_grad", + transpose=True, + verbose=False, + ) + with pytest.raises(ValueError): + # requires a specific anndata + crested.tl.contribution_scores_specific( + input=adata, + model=keras_model, + genome=genome, + method="integrated_grad", + transpose=True, + verbose=False, + ) + scores, one_hots = crested.tl.contribution_scores_specific( + input=adata_specific, + model=keras_model, + genome=genome, + method="integrated_grad", + transpose=False, + verbose=False, + ) + assert scores.shape == (6, 1, 500, 4) + assert one_hots.shape == (6, 500, 4) + + # test multiple models and subsetting class names + scores, one_hots = crested.tl.contribution_scores_specific( + input=adata_specific, + model=[keras_model, keras_model], + genome=genome, + method="integrated_grad", + class_names=list(adata_specific.obs_names)[0], + transpose=True, + verbose=False, + ) + assert scores.shape == (3, 1, 4, 500) + assert one_hots.shape == (3, 4, 500) + + # test saving + class_names = list(adata_specific.obs_names) + scores, one_hots = crested.tl.contribution_scores_specific( + input=adata_specific, + model=keras_model, + genome=genome, + method="integrated_grad", + transpose=False, + verbose=False, + output_dir="tests/data/test_contribution_scores", + ) + assert os.path.exists( + f"tests/data/test_contribution_scores/{class_names[0]}_contrib.npz" + ) From 2ffa38bd75bc332c13db078ae152cb81a81f9370 Mon Sep 17 00:00:00 2001 From: LukasMahieu Date: Tue, 10 Dec 2024 13:25:06 +0100 Subject: [PATCH 10/16] contribution scores specific --- src/crested/tl/__init__.py | 4 +- src/crested/tl/_tools.py | 116 +++++++++++++++++-------------------- tests/test_refactor.py | 32 +++++----- tests/test_tl.py | 29 +++++----- 4 files changed, 87 insertions(+), 94 deletions(-) diff --git a/src/crested/tl/__init__.py b/src/crested/tl/__init__.py index 2e23056..0b55614 100644 --- a/src/crested/tl/__init__.py +++ b/src/crested/tl/__init__.py @@ -10,7 +10,7 @@ from ._tools import ( contribution_scores, contribution_scores_specific, - get_embeddings, + extract_layer_embeddings, predict, score_gene_locus, ) @@ -42,7 +42,7 @@ "TaskConfig", "default_configs", "Crested", - "get_embeddings", + "extract_layer_embeddings", "predict", "contribution_scores", "contribution_scores_specific", diff --git a/src/crested/tl/_tools.py b/src/crested/tl/_tools.py index 0092a67..77031aa 100644 --- a/src/crested/tl/_tools.py +++ b/src/crested/tl/_tools.py @@ -11,10 +11,7 @@ from loguru import logger from tqdm import tqdm -from crested.utils import ( - fetch_sequences, - one_hot_encode_sequence, -) +from crested.utils import fetch_sequences, one_hot_encode_sequence if os.environ["KERAS_BACKEND"] == "tensorflow": from crested.tl._explainer_tf import Explainer @@ -111,8 +108,8 @@ def _transform_input(input, genome: os.PathLike | None = None) -> np.ndarray: return one_hot_data -def get_embeddings( - input: str | list[str] | np.array | AnnData, +def extract_layer_embeddings( + input: str | list[str] | np.ndarray | AnnData, model: keras.Model, layer_name: str, genome: os.PathLike | None = None, @@ -204,8 +201,7 @@ def predict( def score_gene_locus( gene_locus: str, - all_class_names: list[str], - class_name: str, + target_idx: int, model: keras.Model | list[keras.Model], genome: os.PathLike, strand: str = "+", @@ -225,10 +221,9 @@ def score_gene_locus( gene_locus The gene locus to score in the format 'chr:start-end'. Start is the TSS for + strand and TES for - strand. - all_class_names - List of all class names in the model. Usually obtained with `list(anndata.obs_names)`. - class_name - Output class name to be used for prediction. Required to index the predictions. + target_idx + Index of the target class to score. + You can usually get this from running `list(anndata.obs_names).index(class_name)`. model A (list of) trained keras model(s) to make predictions with. genome @@ -267,6 +262,8 @@ def score_gene_locus( gene_start, gene_end = map(int, gene_locus.split("-")) # Detect window size from the model input shape + if not isinstance(target_idx, int): + raise ValueError("Target index must be an integer.") if isinstance(model, list): input_shape = model[0].input_shape else: @@ -291,12 +288,6 @@ def score_gene_locus( # Ratio to normalize the score contributions ratio = central_size / step_size - try: - idx = all_class_names.index(class_name) - except ValueError as e: - raise ValueError( - f"Class name '{class_name}' not found in all_class_names" - ) from e positions = np.arange(start_position, end_position - window_size + 1, step_size) all_regions = [ @@ -305,7 +296,7 @@ def score_gene_locus( if pos + window_size <= end_position ] predictions = predict(input=all_regions, model=model, genome=genome, **kwargs) - predictions_class = predictions[:, idx] + predictions_class = predictions[:, target_idx] # Map predictions to the score array scores = np.zeros(total_length) @@ -337,12 +328,12 @@ def score_gene_locus( def contribution_scores( input: str | list[str] | np.array | AnnData, + target_idx: int | list[int] | None, model: keras.Model | list[keras.Model], - class_names: str | list[str], - all_class_names: list[str], method: str = "expected_integrated_grad", genome: os.PathLike | None = None, transpose: bool = False, + all_class_names: list[str] | None = None, output_dir: os.PathLike | None = None, verbose: bool = True, ) -> tuple[np.ndarray, np.ndarray]: @@ -358,13 +349,13 @@ def contribution_scores( ---------- input Input data to calculate the contribution scores for. Can be a (list of) sequence(s), a (list of) region name(s), a matrix of one hot encodings (N, L, 4), or an AnnData object with region names as its var_names. + target_idx + Index/indices of the target class(es) to calculate the contribution scores for. + If this is an empty list, the contribution scores for the 'combined' class will be calculated. + If this is None, the contribution scores for all classes will be calculated. + You can get these for your classes of interest by running `list(anndata.obs_names).index(class_name)`. model A (list of) trained keras model(s) to calculate the contribution scores for. - class_names - List of class names to calculate the contribution scores for (should match anndata.obs_names) - If the list is empty, the contribution scores for the 'combined' class will be calculated. - all_class_names - List of all class names in the model. Usually obtained with `list(anndata.obs_names)`. method Method to use for calculating the contribution scores. Options are: 'integrated_grad', 'mutagenesis', 'expected_integrated_grad'. @@ -372,11 +363,13 @@ def contribution_scores( Path to the genome file. Required if input is an anndata object or region names. transpose Transpose the contribution scores to (N, C, 4, L) and one hots to (N, 4, L) (for compatibility with MoDISco). + all_class_names + Optional list of all class names in the dataset. If provided and output_dir is not None, will use these to name the output files. output_dir Path to the output directory to save the contribution scores and one hot seqs. Will create a separate npz file per class. verbose - Boolean for disabling the plotting progress of calculations using tqdm. + Boolean for disabling the logs and plotting progress of calculations using tqdm. Returns ------- @@ -386,35 +379,26 @@ def contribution_scores( -------- crested.pl.patterns.contribution_scores """ - if isinstance(class_names, str): - class_names = [class_names] - - if not set(class_names).issubset(all_class_names): - raise ValueError( - "class_names should be a subset of all_class_names. Use list(anndata.obs_names) to get all class names." - ) - - if len(class_names) > 0: - n_classes = len(class_names) - class_indices = [ - all_class_names.index(class_name) for class_name in class_names - ] - else: - logger.warning( - "No class names provided. Calculating contribution scores for the 'combined' class." - ) - n_classes = 1 # 'combined' class - class_indices = [None] + if not isinstance(model, list): + model = [model] + if isinstance(target_idx, int): + target_idx = [target_idx] + elif target_idx is None: + target_idx = list(range(0, model[0].output_shape[-1])) + elif target_idx == []: + if verbose: + logger.info( + "No class indices provided. Calculating contribution scores for the 'combined' class." + ) + target_idx = [None] + n_classes = len(target_idx) - # Handle all other input types input_sequences = _transform_input(input, genome) if verbose: logger.info( f"Calculating contribution scores for {n_classes} class(es) and {input_sequences.shape[0]} region(s)." ) - if not isinstance(model, list): - model = [model] N, L, D = input_sequences.shape # Initialize list to collect scores from each model @@ -425,7 +409,7 @@ def contribution_scores( # Initialize scores for this model scores = np.zeros((N, n_classes, L, D)) # Shape: (N, C, L, 4) - for i, class_index in enumerate(class_indices): + for i, class_index in enumerate(target_idx): # Initialize the explainer for the current model and class index explainer = Explainer(m, class_index=class_index) @@ -460,10 +444,11 @@ def contribution_scores( if output_dir is not None: os.makedirs(output_dir, exist_ok=True) - for i, class_name in enumerate(class_names): + for target_id in target_idx: + class_name = all_class_names[target_id] if all_class_names else target_id np.savez_compressed( os.path.join(output_dir, f"{class_name}_contrib.npz"), - averaged_scores[:, i, :, :], + averaged_scores[:, target_id, :, :], ) np.savez_compressed( os.path.join(output_dir, f"{class_name}_oh.npz"), @@ -475,9 +460,9 @@ def contribution_scores( def contribution_scores_specific( input: AnnData, + target_idx: int | list[int] | None, model: keras.Model | list[keras.Model], genome: os.PathLike, - class_names: str | list[str] | None = None, method: str = "expected_integrated_grad", transpose: bool = True, output_dir: os.PathLike | None = None, @@ -498,11 +483,13 @@ def contribution_scores_specific( ---------- input Input anndata to calculate the contribution scores for. Should have a 'Class name' column in .var. + target_idx + Index/indices of the target class(es) to calculate the contribution scores for. + If this is an empty list, the contribution scores for the 'combined' class will be calculated. + If this is None, the contribution scores for all classes will be calculated. + You can get these for your classes of interest by running `list(anndata.obs_names).index(class_name)`. model A (list of) trained keras model(s) to calculate the contribution scores for. - class_names - List of class names to calculate the contribution scores for (should match anndata.obs_names) - If None, the contribution scores for all classes will be calculated. genome Path to the genome file. method @@ -533,26 +520,27 @@ def contribution_scores_specific( "Run 'crested.pp.sort_and_filter_regions_on_specificity' first" ) all_class_names = list(input.obs_names) - if isinstance(class_names, str): - class_names = [class_names] - if class_names is None: - class_names = all_class_names - if class_names == []: + if target_idx == []: raise ValueError("Can't calculate 'combined' scores for specific regions.") + if target_idx is None: + target_idx = list(range(0, len(all_class_names))) + if not isinstance(target_idx, list): + target_idx = [target_idx] all_scores = [] all_one_hots = [] - for class_name in class_names: + for target_id in target_idx: + class_name = all_class_names[target_id] class_regions = input.var[input.var["Class name"] == class_name].index.tolist() scores, one_hots = contribution_scores( input=class_regions, + target_idx=target_id, model=model, - class_names=[class_name], - all_class_names=all_class_names, method=method, genome=genome, verbose=verbose, output_dir=output_dir, + all_class_names=all_class_names, transpose=transpose, ) all_scores.append(scores) diff --git a/tests/test_refactor.py b/tests/test_refactor.py index 5224eac..fef7d02 100644 --- a/tests/test_refactor.py +++ b/tests/test_refactor.py @@ -10,7 +10,7 @@ Crested, contribution_scores, contribution_scores_specific, - get_embeddings, + extract_layer_embeddings, predict, score_gene_locus, ) @@ -91,11 +91,11 @@ def test_predict_regions(crested_object, keras_model, genome): ), "Region predictions are not equal" -def test_get_embeddings(adata, crested_object, keras_model, genome): +def test_extract_layer_embeddings(adata, crested_object, keras_model, genome): crested_object_embeddings = crested_object.get_embeddings( layer_name="denseblock_dense" ) - refactored_embeddings = get_embeddings( + refactored_embeddings = extract_layer_embeddings( input=adata, model=keras_model, genome=genome, @@ -129,8 +129,7 @@ def test_score_gene_locus(crested_object, adata, keras_model, genome): ref_tss_pos, ) = score_gene_locus( gene_locus=f"{chrom_name}:{gene_start}-{gene_end}", - all_class_names=list(adata.obs_names), - class_name=list(adata.obs_names)[0], + target_idx=0, model=keras_model, genome=genome, downstream=2000, @@ -161,10 +160,9 @@ def test_contribution_scores_region(crested_object, adata, keras_model, genome): ) scores_refactored, one_hot_encoded_sequences_refactored = contribution_scores( input=region, + target_idx=[0, 1], model=keras_model, genome=genome, - class_names=list(adata.obs_names)[0:2], - all_class_names=list(adata.obs_names), method="integrated_grad", ) assert np.allclose( @@ -191,8 +189,7 @@ def test_contribution_scores_sequence(crested_object, keras_model, adata): scores_refactored, one_hot_encoded_sequences_refactored = contribution_scores( input=sequence, model=keras_model, - class_names=list(adata.obs_names)[0:2], - all_class_names=list(adata.obs_names), + target_idx=[0, 1], method="integrated_grad", ) assert np.allclose( @@ -213,9 +210,8 @@ def test_contribution_scores_adata(crested_object, adata, keras_model, genome): ) scores_refactored, one_hot_encoded_sequences_refactored = contribution_scores( input=adata, + target_idx=[0, 1], model=keras_model, - class_names=list(adata.obs_names)[0:2], - all_class_names=list(adata.obs_names), method="integrated_grad", genome=genome, ) @@ -233,6 +229,8 @@ def test_contribution_scores_adata(crested_object, adata, keras_model, genome): def test_contribution_scores_modisco( crested_object_specific, adata_specific, keras_model, genome ): + import shutil + class_names = list(adata_specific.obs_names)[0:2] crested_object_specific.tfmodisco_calculate_and_save_contribution_scores( adata_specific, @@ -247,10 +245,11 @@ def test_contribution_scores_modisco( one_hots = np.load(f"tests/data/test_contribution_scores/{class_names[0]}_oh.npz")[ "arr_0" ] - _, _ = contribution_scores_specific( + shutil.rmtree("tests/data/test_contribution_scores") # ensure different outputs + scores_ref_output, _ = contribution_scores_specific( input=adata_specific, + target_idx=[0, 1], model=keras_model, - class_names=class_names, method="integrated_grad", genome=genome, output_dir="tests/data/test_contribution_scores", @@ -262,6 +261,8 @@ def test_contribution_scores_modisco( one_hot_encoded_sequences_refactored = np.load( f"tests/data/test_contribution_scores/{class_names[0]}_oh.npz" )["arr_0"] + print(scores.shape) + print(scores_refactored.shape) assert np.allclose( scores, scores_refactored, @@ -271,3 +272,8 @@ def test_contribution_scores_modisco( one_hots, one_hot_encoded_sequences_refactored, ), "One-hot encoded sequences are not equal" + assert np.allclose( + scores_ref_output[:3, 0, :, :], + scores, + atol=1e-5, + ), "Scores are not equal." diff --git a/tests/test_tl.py b/tests/test_tl.py index 3287c28..63f3a51 100644 --- a/tests/test_tl.py +++ b/tests/test_tl.py @@ -52,14 +52,14 @@ def test_input_transform(genome): assert _transform_input(np.array([[[1, 0, 0, 0]]])).shape == (1, 1, 4) -def test_get_embeddings(keras_model, genome): +def test_extract_layer_embeddings(keras_model, genome): input = "ATCGA" * 100 - embeddings = crested.tl.get_embeddings( + embeddings = crested.tl.extract_layer_embeddings( input, keras_model, genome=genome, layer_name="denseblock_dense" ) assert embeddings.shape == (1, 8) input = ["ATCGA" * 100, "ATCGA" * 100] - embeddings = crested.tl.get_embeddings( + embeddings = crested.tl.extract_layer_embeddings( input, keras_model, layer_name="denseblock_dense" ) assert embeddings.shape == (2, 8) @@ -82,8 +82,7 @@ def test_score_gene_locus(keras_model, adata, genome): gene_locus = "chr1:200000-200500" scores, coordinates, min_loc, max_loc, tss_pos = crested.tl.score_gene_locus( gene_locus=gene_locus, - all_class_names=list(adata.obs_names), - class_name=list(adata.obs_names)[0], + target_idx=1, model=keras_model, genome=genome, upstream=1000, @@ -97,14 +96,13 @@ def test_score_gene_locus(keras_model, adata, genome): assert tss_pos == 200000 -def test_contribution_scores(keras_model, adata, genome): +def test_contribution_scores(keras_model, genome): sequence = "ATCGA" * 100 scores, one_hot_encoded_sequences = crested.tl.contribution_scores( sequence, + target_idx=[0, 1], model=keras_model, genome=genome, - class_names=list(adata.obs_names)[0:2], - all_class_names=list(adata.obs_names), method="integrated_grad", ) assert scores.shape == (1, 2, 500, 4) @@ -113,23 +111,21 @@ def test_contribution_scores(keras_model, adata, genome): sequences = ["ATCGA" * 100, "ATCGA" * 100] scores, one_hot_encoded_sequences = crested.tl.contribution_scores( sequences, + target_idx=0, model=keras_model, genome=genome, - class_names=list(adata.obs_names)[0:2], - all_class_names=list(adata.obs_names), method="integrated_grad", ) - assert scores.shape == (2, 2, 500, 4) + assert scores.shape == (2, 1, 500, 4) assert one_hot_encoded_sequences.shape == (2, 500, 4) # multiple models models = [keras_model, keras_model] scores, one_hot_encoded_sequences = crested.tl.contribution_scores( sequence, + target_idx=[0, 1], model=models, genome=genome, - class_names=list(adata.obs_names)[0:2], - all_class_names=list(adata.obs_names), method="integrated_grad", ) assert scores.shape == (1, 2, 500, 4) @@ -141,8 +137,8 @@ def test_contribution_scores_specific(keras_model, adata, adata_specific, genome # class names can't be empty for specific crested.tl.contribution_scores_specific( input=adata_specific, + target_idx=[], # combined class model=keras_model, - class_names=[], # combined class genome=genome, method="integrated_grad", transpose=True, @@ -152,6 +148,7 @@ def test_contribution_scores_specific(keras_model, adata, adata_specific, genome # requires a specific anndata crested.tl.contribution_scores_specific( input=adata, + target_idx=None, model=keras_model, genome=genome, method="integrated_grad", @@ -160,6 +157,7 @@ def test_contribution_scores_specific(keras_model, adata, adata_specific, genome ) scores, one_hots = crested.tl.contribution_scores_specific( input=adata_specific, + target_idx=None, model=keras_model, genome=genome, method="integrated_grad", @@ -172,10 +170,10 @@ def test_contribution_scores_specific(keras_model, adata, adata_specific, genome # test multiple models and subsetting class names scores, one_hots = crested.tl.contribution_scores_specific( input=adata_specific, + target_idx=1, model=[keras_model, keras_model], genome=genome, method="integrated_grad", - class_names=list(adata_specific.obs_names)[0], transpose=True, verbose=False, ) @@ -186,6 +184,7 @@ def test_contribution_scores_specific(keras_model, adata, adata_specific, genome class_names = list(adata_specific.obs_names) scores, one_hots = crested.tl.contribution_scores_specific( input=adata_specific, + target_idx=None, model=keras_model, genome=genome, method="integrated_grad", From adc37adc1e9cb5669cdd6a5f5e5c07d1abba1dde Mon Sep 17 00:00:00 2001 From: LukasMahieu Date: Wed, 11 Dec 2024 13:25:17 +0100 Subject: [PATCH 11/16] small docstring fix --- src/crested/tl/_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/crested/tl/_tools.py b/src/crested/tl/_tools.py index 77031aa..dde24ca 100644 --- a/src/crested/tl/_tools.py +++ b/src/crested/tl/_tools.py @@ -164,7 +164,7 @@ def predict( input Input data to make predictions on. Can be a (list of) sequence(s), a (list of) region name(s), a matrix of one hot encodings (N, L, 4), or an AnnData object with region names as its var_names. model - A (list of) trained keras models to make predictions with. + A (list of) trained keras model(s) to make predictions with. genome Path to the genome file. Required if input is an anndata object or region names. **kwargs From 86fb03e92c3e97863bc67e7699df3ddc1dae86c8 Mon Sep 17 00:00:00 2001 From: LukasMahieu Date: Wed, 11 Dec 2024 15:20:19 +0100 Subject: [PATCH 12/16] update all tests and tools to use new Genome --- src/crested/tl/_tools.py | 58 ++++++++++++++++---------------------- tests/conftest.py | 14 ++++++++-- tests/test_data.py | 60 +++++++++++++++++++--------------------- tests/test_refactor.py | 4 +-- tests/test_tl.py | 2 +- 5 files changed, 67 insertions(+), 71 deletions(-) diff --git a/src/crested/tl/_tools.py b/src/crested/tl/_tools.py index dde24ca..d9a2a3b 100644 --- a/src/crested/tl/_tools.py +++ b/src/crested/tl/_tools.py @@ -11,7 +11,8 @@ from loguru import logger from tqdm import tqdm -from crested.utils import fetch_sequences, one_hot_encode_sequence +from crested._genome import Genome, _resolve_genome +from crested.utils import one_hot_encode_sequence if os.environ["KERAS_BACKEND"] == "tensorflow": from crested.tl._explainer_tf import Explainer @@ -19,19 +20,18 @@ from crested.tl._explainer_torch import Explainer -def _detect_input_type(input): +def _detect_input_type(input: str | list[str] | np.array | AnnData) -> str: """ Detect the type of input provided. Parameters ---------- - input : str | list[str] | np.array | AnnData + input The input to detect the type of. Returns ------- - str - One of ['sequence', 'region', 'anndata', 'array'], indicating the input type. + One of ['sequence', 'region', 'anndata', 'array'], indicating the input type. """ dna_pattern = re.compile("^[ACGTNacgtn]+$") if isinstance(input, AnnData): @@ -67,16 +67,16 @@ def _detect_input_type(input): ) -def _transform_input(input, genome: os.PathLike | None = None) -> np.ndarray: +def _transform_input(input, genome: Genome | os.PathLike | None = None) -> np.ndarray: """ Transform the input into a one-hot encoded matrix based on its type. Parameters ---------- - input : str | list[str] | np.array | AnnData + input Input data to preprocess. Can be a sequence, list of sequences, region, list of regions, or an AnnData object. - genome : str | None - Path to the genome file. Required if input is a region or AnnData. + genome + Genome or Path to the genome file. Required if no genome is registered and input is a region or AnnData. Returns ------- @@ -85,16 +85,13 @@ def _transform_input(input, genome: os.PathLike | None = None) -> np.ndarray: input_type = _detect_input_type(input) if input_type == "anndata": - if genome is None: - raise ValueError( - "Genome file is required to fetch sequences for regions in AnnData." - ) + genome = _resolve_genome(genome) regions = list(input.var_names) - sequences = fetch_sequences(regions, genome) + sequences = [genome.fetch(region=region) for region in regions] elif input_type == "region": - if genome is None: - raise ValueError("Genome file is required to fetch sequences for regions.") - sequences = fetch_sequences(input, genome) + genome = _resolve_genome(genome) + regions = input if isinstance(input, list) else [input] + sequences = [genome.fetch(region=region) for region in regions] elif input_type == "sequence": sequences = input if isinstance(input, list) else [input] elif input_type == "array": @@ -112,7 +109,7 @@ def extract_layer_embeddings( input: str | list[str] | np.ndarray | AnnData, model: keras.Model, layer_name: str, - genome: os.PathLike | None = None, + genome: Genome | os.PathLike | None = None, **kwargs, ) -> np.ndarray: """ @@ -127,7 +124,7 @@ def extract_layer_embeddings( layer_name The name of the layer from which to extract the embeddings. genome - Path to the genome file. Required if input is an anndata object or region names. + Genome or path to the genome fasta. Required if no genome is registered and input is an anndata object or region names. **kwargs Additional keyword arguments to pass to the keras.Model.predict method. @@ -151,7 +148,7 @@ def extract_layer_embeddings( def predict( input: str | list[str] | np.array | AnnData, model: keras.Model | list[keras.Model], - genome: os.PathLike | None = None, + genome: Genome | os.PathLike | None = None, **kwargs, ) -> None | np.ndarray: """ @@ -166,7 +163,7 @@ def predict( model A (list of) trained keras model(s) to make predictions with. genome - Path to the genome file. Required if input is an anndata object or region names. + Genome or path to the genome file. Required if no genome is registered and input is an anndata object or region names. **kwargs Additional keyword arguments to pass to the keras.Model.predict method. @@ -203,7 +200,7 @@ def score_gene_locus( gene_locus: str, target_idx: int, model: keras.Model | list[keras.Model], - genome: os.PathLike, + genome: Genome | os.PathLike | None = None, strand: str = "+", upstream: int = 50000, downstream: int = 10000, @@ -227,7 +224,7 @@ def score_gene_locus( model A (list of) trained keras model(s) to make predictions with. genome - Path to the genome file. + Genome or path to the genome file. Required if no genome is registered. strand '+' for positive strand, '-' for negative strand. Default '+'. upstream @@ -331,7 +328,7 @@ def contribution_scores( target_idx: int | list[int] | None, model: keras.Model | list[keras.Model], method: str = "expected_integrated_grad", - genome: os.PathLike | None = None, + genome: Genome | os.PathLike | None = None, transpose: bool = False, all_class_names: list[str] | None = None, output_dir: os.PathLike | None = None, @@ -360,7 +357,7 @@ def contribution_scores( Method to use for calculating the contribution scores. Options are: 'integrated_grad', 'mutagenesis', 'expected_integrated_grad'. genome - Path to the genome file. Required if input is an anndata object or region names. + Genome or path to the genome fasta. Required if no genome is registered and input is an anndata object or region names. transpose Transpose the contribution scores to (N, C, 4, L) and one hots to (N, 4, L) (for compatibility with MoDISco). all_class_names @@ -401,19 +398,13 @@ def contribution_scores( ) N, L, D = input_sequences.shape - # Initialize list to collect scores from each model scores_per_model = [] - - # Iterate over models for m in tqdm(model, desc="Model", disable=not verbose): - # Initialize scores for this model scores = np.zeros((N, n_classes, L, D)) # Shape: (N, C, L, 4) for i, class_index in enumerate(target_idx): - # Initialize the explainer for the current model and class index explainer = Explainer(m, class_index=class_index) - # Calculate contribution scores based on the selected method if method == "integrated_grad": scores[:, i, :, :] = explainer.integrated_grad( input_sequences, @@ -432,7 +423,6 @@ def contribution_scores( else: raise ValueError(f"Unsupported method: {method}") - # Collect scores from this model scores_per_model.append(scores) # Average the scores across models @@ -462,7 +452,7 @@ def contribution_scores_specific( input: AnnData, target_idx: int | list[int] | None, model: keras.Model | list[keras.Model], - genome: os.PathLike, + genome: Genome | os.PathLike | None = None, method: str = "expected_integrated_grad", transpose: bool = True, output_dir: os.PathLike | None = None, @@ -491,7 +481,7 @@ def contribution_scores_specific( model A (list of) trained keras model(s) to calculate the contribution scores for. genome - Path to the genome file. + Genome or Path to the genome file. Required if no genome is registered. method Method to use for calculating the contribution scores. Options are: 'integrated_grad', 'mutagenesis', 'expected_integrated_grad'. diff --git a/tests/conftest.py b/tests/conftest.py index 083e47a..848abfc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -51,8 +51,8 @@ def adata(): @pytest.fixture(scope="module") -def genome(): - """Genome fixture.""" +def genome_path(): + """Genome path fixture.""" if not os.path.exists("tests/data/genomes/hg38.fa"): genomepy.install_genome( "hg38", annotation=False, provider="UCSC", genomes_dir="tests/data/genomes" @@ -60,6 +60,16 @@ def genome(): return "tests/data/genomes/hg38/hg38.fa" +@pytest.fixture(scope="module") +def genome(genome_path): + """Genome fixture.""" + genome = crested.Genome( + fasta=genome_path, + chrom_sizes="tests/data/genomes/hg38/hg38.fa.sizes", + ) + return genome + + @pytest.fixture(scope="module") def adata_specific(): """Specific anndata fixture.""" diff --git a/tests/test_data.py b/tests/test_data.py index 632f1d1..a1962a4 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -16,16 +16,15 @@ def log_capture(level="WARNING"): logger.remove(handler_id) -def test_genome_persistence(genome): +def test_genome_persistence(genome_path): """Test that the genome object is correctly stored.""" import crested - fasta_file = genome # check that does not yet exist assert crested._conf.genome is None genome = crested.Genome( - fasta=fasta_file, + fasta=genome_path, chrom_sizes={"chr1": 1000, "chr2": 2000}, ) crested.register_genome(genome) @@ -44,15 +43,14 @@ def test_no_genome_fasta(): ) -def test_import_beds_with_genome(genome): +def test_import_beds_with_genome(genome_path): """Test that import_beds uses genome chromsizes.""" import crested - fasta_file = genome # Scenario 1: Genome registered with chromsizes provided with log_capture(level="WARNING") as messages: genome = crested.Genome( - fasta=fasta_file, + fasta=genome_path, chrom_sizes="tests/data/test.chrom.sizes", ) crested.register_genome(genome) @@ -106,72 +104,70 @@ def test_import_beds_with_genome(genome): warning_text_filtered not in msg for msg in warning_messages ), "Warning about filtered regions was unexpectedly raised." -def test_genome_fetch(genome): + +def test_genome_fetch(genome_path): """Test reading the genome.""" import crested - fasta_file = genome genome = crested.Genome( - fasta=fasta_file, + fasta=genome_path, chrom_sizes="tests/data/test.chrom.sizes", ) - seq = genome.fetch('chr1', 10000, 10100) + seq = genome.fetch("chr1", 10000, 10100) assert len(seq) == 100 -def test_genome_fetch_region(genome): +def test_genome_fetch_region(genome_path): """Test reading the genome with a region string.""" import crested - fasta_file = genome genome = crested.Genome( - fasta=fasta_file, + fasta=genome_path, chrom_sizes="tests/data/test.chrom.sizes", ) - seq1 = genome.fetch('chr1', 10000, 10100) - seq2 = genome.fetch(region = 'chr1:10000-10100') + seq1 = genome.fetch("chr1", 10000, 10100) + seq2 = genome.fetch(region="chr1:10000-10100") assert seq1 == seq2 -def test_genome_fetch_reverse(genome): + +def test_genome_fetch_reverse(genome_path): """Test reading the genome on the negative strand.""" import crested - fasta_file = genome genome = crested.Genome( - fasta=fasta_file, + fasta=genome_path, chrom_sizes="tests/data/test.chrom.sizes", ) - seq_forward = genome.fetch('chr1', 10000, 10100) - seq_rev = genome.fetch('chr1', 10000, 10100, "-") - seq_rev_region = genome.fetch(region = 'chr1:10000-10100:-') + seq_forward = genome.fetch("chr1", 10000, 10100) + seq_rev = genome.fetch("chr1", 10000, 10100, "-") + seq_rev_region = genome.fetch(region="chr1:10000-10100:-") assert seq_rev == crested.utils.reverse_complement(seq_forward) assert seq_rev_region == seq_rev -def test_genome_fetch_mismatch(genome): + +def test_genome_fetch_mismatch(genome_path): """Test reading the genome when supplying both coordinates and a region.""" import crested - fasta_file = genome genome = crested.Genome( - fasta=fasta_file, + fasta=genome_path, chrom_sizes="tests/data/test.chrom.sizes", ) - seq = genome.fetch('chr1', 10000, 10100, region = 'chr1:10000-10200') + seq = genome.fetch("chr1", 10000, 10100, region="chr1:10000-10200") assert len(seq) == 100 -def test_genome_fetch_missing(genome): + +def test_genome_fetch_missing(genome_path): """Test reading the genome when not supplying all information""" import crested - fasta_file = genome genome = crested.Genome( - fasta=fasta_file, + fasta=genome_path, chrom_sizes="tests/data/test.chrom.sizes", ) with pytest.raises(ValueError): - genome.fetch('chr1', 10000) + genome.fetch("chr1", 10000) with pytest.raises(ValueError): - genome.fetch('chr1', end = 10100) + genome.fetch("chr1", end=10100) with pytest.raises(ValueError): - genome.fetch('chr1', 10000, region = 'chr1:10000-10200') - + genome.fetch("chr1", 10000, region="chr1:10000-10200") diff --git a/tests/test_refactor.py b/tests/test_refactor.py index fef7d02..574da35 100644 --- a/tests/test_refactor.py +++ b/tests/test_refactor.py @@ -29,7 +29,7 @@ def crested_object(keras_model, adata, genome): anndatamodule = AnnDataModule( adata, - genome_file=genome, + genome=genome, batch_size=32, always_reverse_complement=False, deterministic_shift=False, @@ -46,7 +46,7 @@ def crested_object(keras_model, adata, genome): def crested_object_specific(keras_model, adata_specific, genome): anndatamodule = AnnDataModule( adata_specific, - genome_file=genome, + genome=genome, batch_size=32, always_reverse_complement=False, deterministic_shift=False, diff --git a/tests/test_tl.py b/tests/test_tl.py index 63f3a51..f70d092 100644 --- a/tests/test_tl.py +++ b/tests/test_tl.py @@ -78,7 +78,7 @@ def test_predict(keras_model, adata, genome): assert predictions.shape == (10, 5) -def test_score_gene_locus(keras_model, adata, genome): +def test_score_gene_locus(keras_model, genome): gene_locus = "chr1:200000-200500" scores, coordinates, min_loc, max_loc, tss_pos = crested.tl.score_gene_locus( gene_locus=gene_locus, From f8d010581d46d9dd79b8d780bd60862987a707e3 Mon Sep 17 00:00:00 2001 From: LukasMahieu Date: Wed, 11 Dec 2024 15:49:18 +0100 Subject: [PATCH 13/16] fix saving of contrib scores based on id --- src/crested/tl/_tools.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/crested/tl/_tools.py b/src/crested/tl/_tools.py index d9a2a3b..76d6134 100644 --- a/src/crested/tl/_tools.py +++ b/src/crested/tl/_tools.py @@ -434,11 +434,16 @@ def contribution_scores( if output_dir is not None: os.makedirs(output_dir, exist_ok=True) - for target_id in target_idx: - class_name = all_class_names[target_id] if all_class_names else target_id + for i in range(n_classes): + target_id = target_idx[i] + class_name = ( + all_class_names[target_id] + if all_class_names + else f"class_id_{target_id}" + ) np.savez_compressed( os.path.join(output_dir, f"{class_name}_contrib.npz"), - averaged_scores[:, target_id, :, :], + averaged_scores[:, i, :, :], ) np.savez_compressed( os.path.join(output_dir, f"{class_name}_oh.npz"), From ffed21e7375bf1cdabb3662509e46b2fe6e38010 Mon Sep 17 00:00:00 2001 From: LukasMahieu Date: Wed, 11 Dec 2024 17:51:40 +0100 Subject: [PATCH 14/16] add repr and acgt distrib to genome --- src/crested/_genome.py | 45 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/src/crested/_genome.py b/src/crested/_genome.py index 8f4281f..ee80350 100644 --- a/src/crested/_genome.py +++ b/src/crested/_genome.py @@ -4,6 +4,7 @@ import errno import os +import random from pathlib import Path from loguru import logger @@ -88,6 +89,7 @@ def __init__( self._annotation = None self._name = name + self._acgt = None @property def fasta(self) -> FastaFile: @@ -151,6 +153,42 @@ def name(self) -> str: return basename return self._name + @property + def acgt(self) -> list[float]: + """ + The ACGT distribution of the genome. + + Returns + ------- + The ACGT distribution as a list of floats. + """ + if self._acgt is None: + self._acgt = self._get_acgt() + return self._acgt + + def _get_acgt(self, n: int = 10000, region_length: int = 1000) -> list[float]: + """Return the ACGT distribution of the genome based on n random regions.""" + acgt = [0, 0, 0, 0] + chrom_sizes = self.chrom_sizes + # discard small chromosomes + chrom_sizes = {k: v for k, v in chrom_sizes.items() if v > region_length * 10} + chroms = list(chrom_sizes.keys()) + + for _ in range(n): + chrom = random.choice(chroms) + chrom_length = chrom_sizes[chrom] + start = random.randint(0, chrom_length - region_length) + end = start + region_length + seq = self.fasta.fetch(chrom, start, end) + + acgt[0] += seq.count("A") + acgt[1] += seq.count("C") + acgt[2] += seq.count("G") + acgt[3] += seq.count("T") + + total = sum(acgt) + return [x / total for x in acgt] + def fetch(self, chrom=None, start=None, end=None, strand="+", region=None) -> str: """ Fetch a sequence from a genomic region. @@ -196,6 +234,13 @@ def fetch(self, chrom=None, start=None, end=None, strand="+", region=None) -> st else: return seq + def __repr__(self) -> str: + """Return a string representation of the Genome object.""" + fasta_exists = self.fasta is not None + chrom_sizes_exists = self.chrom_sizes is not None + annotations_exists = self.annotation is not None + return f"Genome({self.name}, fasta={fasta_exists}, chrom_sizes={chrom_sizes_exists}, annotation={annotations_exists})" + def register_genome(genome: Genome): """ From f806975845f408e1eb8cffad376b40dbc5427952 Mon Sep 17 00:00:00 2001 From: LukasMahieu Date: Thu, 12 Dec 2024 11:36:16 +0100 Subject: [PATCH 15/16] type int in genome.fetch --- src/crested/_genome.py | 47 +++++++----------------------------------- 1 file changed, 8 insertions(+), 39 deletions(-) diff --git a/src/crested/_genome.py b/src/crested/_genome.py index ee80350..093fdda 100644 --- a/src/crested/_genome.py +++ b/src/crested/_genome.py @@ -4,7 +4,6 @@ import errno import os -import random from pathlib import Path from loguru import logger @@ -89,7 +88,6 @@ def __init__( self._annotation = None self._name = name - self._acgt = None @property def fasta(self) -> FastaFile: @@ -153,43 +151,14 @@ def name(self) -> str: return basename return self._name - @property - def acgt(self) -> list[float]: - """ - The ACGT distribution of the genome. - - Returns - ------- - The ACGT distribution as a list of floats. - """ - if self._acgt is None: - self._acgt = self._get_acgt() - return self._acgt - - def _get_acgt(self, n: int = 10000, region_length: int = 1000) -> list[float]: - """Return the ACGT distribution of the genome based on n random regions.""" - acgt = [0, 0, 0, 0] - chrom_sizes = self.chrom_sizes - # discard small chromosomes - chrom_sizes = {k: v for k, v in chrom_sizes.items() if v > region_length * 10} - chroms = list(chrom_sizes.keys()) - - for _ in range(n): - chrom = random.choice(chroms) - chrom_length = chrom_sizes[chrom] - start = random.randint(0, chrom_length - region_length) - end = start + region_length - seq = self.fasta.fetch(chrom, start, end) - - acgt[0] += seq.count("A") - acgt[1] += seq.count("C") - acgt[2] += seq.count("G") - acgt[3] += seq.count("T") - - total = sum(acgt) - return [x / total for x in acgt] - - def fetch(self, chrom=None, start=None, end=None, strand="+", region=None) -> str: + def fetch( + self, + chrom: str | None = None, + start: int | None = None, + end: int | None = None, + strand: str = "+", + region: str | None = None, + ) -> str: """ Fetch a sequence from a genomic region. From daa2ec54c03633cf34cc2837cd325a277fb90b72 Mon Sep 17 00:00:00 2001 From: LukasMahieu Date: Thu, 12 Dec 2024 11:58:24 +0100 Subject: [PATCH 16/16] calculate gc distribution util function --- docs/api/utils.md | 1 + src/crested/utils/__init__.py | 1 + src/crested/utils/_utils.py | 44 +++++++++++++++++++++++++++++++++-- 3 files changed, 44 insertions(+), 2 deletions(-) diff --git a/docs/api/utils.md b/docs/api/utils.md index 551df3d..07c6d4a 100644 --- a/docs/api/utils.md +++ b/docs/api/utils.md @@ -17,5 +17,6 @@ CREsted provides a few utility function to help with sequence encoding, function fetch_sequences reverse_complement permute_model + calculate_gc_distribution setup_logging ``` diff --git a/src/crested/utils/__init__.py b/src/crested/utils/__init__.py index b62cd5a..d0ba980 100644 --- a/src/crested/utils/__init__.py +++ b/src/crested/utils/__init__.py @@ -9,6 +9,7 @@ ) from ._utils import ( EnhancerOptimizer, + calculate_gc_distribution, extract_bigwig_values_per_bp, fetch_sequences, read_bigwig_region, diff --git a/src/crested/utils/_utils.py b/src/crested/utils/_utils.py index c95fdaa..bc6567c 100644 --- a/src/crested/utils/_utils.py +++ b/src/crested/utils/_utils.py @@ -1,6 +1,7 @@ from __future__ import annotations import os +import random from typing import Any, Callable import numpy as np @@ -125,7 +126,7 @@ class EnhancerOptimizer: """ Class to optimize the mutated sequence based on the original prediction. - Can be passed as the 'enhancer_optimizer' argument to :func:`crested.tl.Crested.enhancer_design_in_silico_evolution` + Can be passed as the 'enhancer_optimizer' argument to :func:`crested.tl.enhancer_design_in_silico_evolution` Parameters ---------- @@ -134,7 +135,7 @@ class EnhancerOptimizer: See Also -------- - crested.tl.Crested.enhancer_design_in_silico_evolution + crested.tl.enhancer_design_in_silico_evolution """ def __init__(self, optimize_func: Callable[..., int]) -> None: @@ -365,3 +366,42 @@ def read_bigwig_region( ).squeeze() return values, positions + + +def calculate_gc_distribution( + genome: Genome | os.PathLike, + regions: list[str], + n_regions: int | None = None, +) -> list[float]: + """ + Calculate the GC content distribution of a genome in a set of regions. + + Parameters + ---------- + genome + The genome object or path to the genome fasta file. + regions + A list of region names in the format "chr:start-end". + n_regions + Randomly sample n_regions from the regions. If None, all regions are used. + + Returns + ------- + The GC content distribution as a list of floats in order A, C, G, T. + """ + genome = _resolve_genome(genome) + + if n_regions is not None: + regions = random.sample(regions, n_regions) + + acgt = [0, 0, 0, 0] + + for region in regions: + seq = genome.fasta.fetch(region=region) + acgt[0] += seq.count("A") + acgt[1] += seq.count("C") + acgt[2] += seq.count("G") + acgt[3] += seq.count("T") + + total = sum(acgt) or 1 + return [round(a / total, 4) for a in acgt]