diff --git a/align_pangenome.nf b/align_pangenome.nf index 5b2e10a..ce5a069 100644 --- a/align_pangenome.nf +++ b/align_pangenome.nf @@ -8,7 +8,7 @@ GroovyShell shell = new GroovyShell() def helpers = shell.parse(new File("${workflow.projectDir}/helpers.gvy")) // Import sub-workflows -include { bin_metagenomes } from './modules/processes/bin_metagenomes' +include { bin_metagenomes } from './modules/bin_metagenomes' include { align_reads } from './modules/align_reads' include { find_reads } from './modules/find_reads' diff --git a/bin/bin_metagenomes.py b/bin/collect_metagenomes.py similarity index 53% rename from bin/bin_metagenomes.py rename to bin/collect_metagenomes.py index 1701ac6..128ba32 100755 --- a/bin/bin_metagenomes.py +++ b/bin/collect_metagenomes.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 from collections.abc import Mapping import json +import os import anndata as ad import click import logging @@ -520,90 +521,21 @@ def calc_rel_abund(self, mods=None): ).T self.data.mod[mod].layers["prop"] = rel_abund - def compare_groups(self, category, mods=None): - assert category in self.data.obs.columns.values - - # Get the metadata values - meta: pd.Series = ( - self.data - .obs[category] - .reindex(index=self.filtered_samples) - .dropna() - ) - logger.info(f"Samples with {category} defined: {meta.shape[0]:,}") - - # We can only support the comparison of two groups (at this point) - logger.info(f"Unique values: {', '.join([str(v) for v in meta.unique()])}") - assert meta.unique().shape[0] == 2, f"{category} must contain 2 groups" - - # Convert meta to 1/0 - meta = (meta == max(meta.unique())).apply(int) - - if mods is None: - mods = ["bins", "genomes"] - - for mod in mods: - self.compare_groups_mod(meta, mod) - self.log_content() - - def compare_groups_mod(self, meta: pd.Series, mod: str): - logger.info(f"Comparing samples on the basis of {mod}") - - # Use the relative abundance table computed for this modality - logger.info(f"Comparing samples using {mod} content") - abund: pd.DataFrame = self.data.mod[mod].to_df("prop") - - self.data.mod[mod].varm[f"~{meta.name}"] = ( - self.compare_groups_single(meta, abund) - ) - - def compare_groups_single( - self, - meta: pd.Series, - abund: pd.DataFrame - ): - # At this point meta should just be 0 or 1 - assert meta.isin([0, 1]).all() - comparison_obs = meta.index.values[meta == 0] - control_obs = meta.index.values[meta == 1] - - res = ( - pd.DataFrame([ - self.mannwhitneyu(var, vals, control_obs, comparison_obs) - for var, vals in abund.items() - ]) - .set_index("index") - .reindex(index=abund.columns.values) - ) - - res = res.assign( - qvalue=multipletests(res["pvalue"].values, method=self.fdr_method)[1], - neg_log10_pvalue=lambda d: d['pvalue'].apply(np.log10) * -1, - neg_log10_qvalue=lambda d: d['qvalue'].apply(np.log10) * -1 - ) - return res - - @staticmethod - def mannwhitneyu(var: str, vals: pd.Series, control_obs, comparison_obs): - res = stats.mannwhitneyu( - vals.reindex(index=control_obs).values, - vals.reindex(index=comparison_obs).values - ) - return dict( - pvalue=res.pvalue, - statistic=res.statistic, - index=var - ) - def to_h5ad(self, output_folder): + os.makedirs(output_folder, exist_ok=True) if "filtered_out" in self.data.uns: del self.data.uns["filtered_out"] for mod in self.data.mod: - self.data.mod[mod].write_h5ad(f"{output_folder}/metagenome.{mod}.h5ad") + adata: ad.AnnData = self.data.mod[mod] + adata.obs = self.data.obs + for kw, val in self.data.uns.items(): + adata.uns[kw] = val + adata.write_h5ad(f"{output_folder}/metagenome.{mod}.h5ad") def to_csv(self, output_folder): + os.makedirs(output_folder, exist_ok=True) if "filtered_out" in self.data.uns: del self.data.uns["filtered_out"] self.to_csv_obj(self.data, f"{output_folder}/metagenome") @@ -653,28 +585,6 @@ def write_json(val, path): with open(path, "w") as handle: json.dump(val, handle, indent=4) - def plot(self, output_folder): - - # The central heatmap is the comparison of bins and genomes - # Bins extend vertically from that heatmap, - # and genomes extend horizontally. - - for bins_varm in self.data.mod["bins"].varm: - assert bins_varm.startswith("~"), f"Expected ~meta, not {bins_varm}" - bins_meta = bins_varm[1:] - - for genomes_varm in self.data.mod["genomes"].varm: - assert genomes_varm.startswith("~"), f"Expected ~meta, not {genomes_varm}" - genomes_meta = genomes_varm[1:] - - if bins_meta == genomes_meta: - output_fn = genomes_meta - output_fp = f"{output_folder}/{output_fn}" - self.write_image( - bins_meta, - output_fp, - ) - @staticmethod def log_scale(df: pd.DataFrame): @@ -697,483 +607,6 @@ def sort_index(self, df: pd.DataFrame, metric="cosine", method="average"): self.log_df(df) raise e - def write_image( - self, - meta_cname, - output_fp - ): - """" - Figure layout: - Genomes vs. Bins in the center, with bins extending - vertically and genomes extending horizontally. - - | | (2, 7): Bins Boxplot | - | | (2, 6): Bins -log10(p) | - | | (2, 5): Bins Silhouette | - | (1, 3): Sample Metadata | (2, 4): Bins Heatmap | - | | (2, 3): # Genes per Bin | | (4, 3): NNLS Residual | - | | (2, 2): Central Heatmap | (3, 2): # of Genomes | (4, 2): Genomes Heatmap | (5, 2): Genomes -log10(p) | (6, 2): Genomes Boxplot | - | | | | (4, 1): Sample Metadata | - """ - - # Relative size of rows and columns - heatmap_size = 3 - column_widths = np.array([0.5, heatmap_size, 1, heatmap_size, 1, 1]) - column_widths = list(column_widths / column_widths.sum()) - row_heights = np.array([0.5, heatmap_size, 1, heatmap_size, 1, 1, 1]) - row_heights = list(row_heights / row_heights.sum()) - - # Genomes across samples - genomes_df: pd.DataFrame = ( - self.data - .mod["genomes"] - .to_df("prop") - ) - - # Bins across samples - bins_df: pd.DataFrame = ( - self.data - .mod["bins"] - .to_df("prop") - ) - - # Sort the bins and genomes - bin_order = self.sort_index( - bins_df.T, - metric="euclidean", - method="ward" - ) - - genome_order = self.sort_index( - genomes_df.T, - metric="euclidean", - method="ward" - ) - - cols = 6 - rows = 7 - fig = make_subplots( - rows=rows, - cols=cols, - shared_xaxes=True, - shared_yaxes=True, - start_cell="bottom-left", - column_widths=column_widths, - row_heights=row_heights - ) - # Bins across Genomes - self.heatmap( - ( - self.data - .uns["group_profile"] - .T - .reindex( - index=genome_order, - columns=bin_order - ) - ), - fig, - value_label="Gene Bin Presence in Genome", - row=2, - col=2, - showscale=True, - coloraxis="coloraxis2", - colorbar_x=0.0, - colorbar_xpad=0, - colorbar_title_side="right", - colorbar_xanchor="center", - colorbar_yanchor="top", - colorbar_len=0.25, - colorbar_y=0.3 - ) - - # Number of genomes per group - self.bar( - ( - self.data - .mod["genomes"] - .var["n_genomes"] - .reindex(index=genome_order) - ), - "Number of Genomes per Group", - fig, - row=2, - col=3, - orient="h" - ) - - # Genomes across samples - sample_order = self.sort_index(genomes_df) - genomes_df = genomes_df.reindex( - columns=genome_order, - index=sample_order - ).fillna(0) - - self.heatmap( - self.log_scale(genomes_df.T), - fig, - value_label="Proportion of Aligned Reads (log10)", - row=2, - col=4, - showscale=True, - coloraxis="coloraxis3", - colorbar_x=0.5, - colorbar_xanchor="left", - colorbar_xpad=0, - colorbar_orientation="h", - colorbar_title_side="top", - colorbar_title_text="Genome Group Abundance
Proportion of Aligned Reads (log10)", - colorbar_y=0.45, - colorbar_len=0.25 - ) - - # Annotate metadata on the genomes heatmap - metadata = ( - self.data - .obs - .reindex( - columns=[meta_cname], - index=sample_order - ) - ) - - self.heatmap( - metadata.reindex(index=sample_order).T, - fig, - value_label=meta_cname, - row=1, - col=4, - colorscale=["blue", "orange"], - showscale=False - ) - - # NNLS Residual on genome abundance predictions from samples - self.bar( - ( - self.data - .mod["genomes"] - .obs["nnls_residual"] - .reindex(index=sample_order) - ), - "NNLS Residual", - fig, - row=3, - col=4 - ) - fig.for_each_yaxis( - lambda axis: axis.update(matches=None, showticklabels=True), - row=3, - col=4 - ) - - # Bars showing the -log10(qvalue) for each genome group - self.bar( - ( - self.data - .mod["genomes"] - .varm[f"~{meta_cname}"] - ["neg_log10_qvalue"] - ), - "FDR-adjusted p-value (-log10)", - fig, - row=2, - col=5, - orient="h" - ) - - # Boxplot comparing genome abundances by sample group - self.box( - genomes_df, - self.data.obs[meta_cname].reindex(index=sample_order), - meta_cname, - "Genome Group Abundance", - fig, - row=2, - col=6, - orient="h", - showlegend=False, - log=True - ) - - # Bins Information (extending vertically from the central heatmap) - - # Number of genes per bin - self.bar( - pd.Series( - self.data.uns["bin_size"] - ).reindex( - index=bin_order - ), - "Number of Genes per Bin", - fig, - row=3, - col=2 - ) - - # Bins across samples - sample_order = self.sort_index(bins_df) - bins_df = bins_df.reindex( - columns=bin_order, - index=sample_order - ).fillna(0) - - self.heatmap( - self.log_scale(bins_df), - fig, - value_label="Proportion of Aligned Reads (log10)", - row=4, - col=2, - coloraxis="coloraxis4", - showscale=True, - colorbar_x=0.35, - colorbar_xanchor="left", - colorbar_y=0.55, - colorbar_len=0.25, - colorbar_title_side="right", - colorbar_title_text="Gene Bin Abundance
Proportion of Aligned Reads (log10)", - colorbar_xpad=0 - ) - - # Metadata data on the bins heatmap - self.heatmap( - metadata.reindex(index=sample_order), - fig, - value_label=meta_cname, - row=4, - col=1, - colorscale=["blue", "orange"], - coloraxis="coloraxis", - showscale=False - ) - - # Silhouette score on bin assignments - self.bar( - ( - self.data - .mod["bins"] - .var["silhouette_score"] - .reindex(index=bin_order) - ), - "Silhouette Score", - fig, - row=5, - col=2 - ) - - # Bars showing the -log10(qvalue) for each gene bin - self.bar( - ( - self.data - .mod["bins"] - .varm[f"~{meta_cname}"] - ["neg_log10_qvalue"] - ), - "FDR-adjusted p-value (-log10)", - fig, - row=6, - col=2 - ) - - # Boxplot comparing gene bin abundances by sample group - self.box( - bins_df, - self.data.obs[meta_cname].reindex(index=sample_order), - meta_cname, - "Gene Bin Abundance", - fig, - row=7, - col=2, - log=True - ) - - logger.info(fig.layout) - title_text = f"Association with {meta_cname}" - fig.update_layout( - title_text=title_text, - title_xanchor="center", - title_x=0.5, - plot_bgcolor='rgba(255, 255, 255, 1.0)', - paper_bgcolor='rgba(255, 255, 255, 1.0)', - legend=dict( - yanchor="top", - xanchor="left", - x=1.0, - y=0.3 - ), - **{ - f"xaxis{i + cols}": dict(showticklabels=True) - for i in [2, 3, 5, 6] - }, - **{ - f"yaxis{2 + (cols * i)}": dict(showticklabels=True) - for i in [1, 2, 4, 5, 6] - } - ) - fig.write_html(f"{output_fp}.html", include_plotlyjs="cdn") - for ext in ['png', 'pdf']: - fig.write_image( - f"{output_fp}.{ext}", - width=2880, - height=1800 - ) - - @staticmethod - def heatmap( - df: pd.DataFrame, - fig, - value_label: str, - row=None, - col=None, - coloraxis="coloraxis", - colorscale=px.colors.sequential.Blues, - showlegend=False, - showscale=False, - **kwargs - ): - - if "colorbar_title_text" not in kwargs: - kwargs["colorbar_title_text"] = value_label - - fig.layout[coloraxis] = dict( - colorscale=colorscale, - showscale=showscale, - **{ - kw: val - for kw, val in kwargs.items() - if kw.startswith("colorbar_") - } - ) - - fig.add_trace( - go.Heatmap( - z=df.values, - x=df.columns.values, - y=df.index.values, - hovertext=df.applymap( - lambda v: f"{value_label}: {v}" - ), - coloraxis=coloraxis, - showlegend=showlegend, - showscale=showscale - ), - row=row, - col=col - ) - - def bar(self, vals: pd.Series, label: str, fig, row, col, orient="v", log=False): - assert orient in ["v", "h"] - kwargs = ( - dict( - y=vals.values, - x=vals.index.values - ) - if orient == "v" else - dict( - x=vals.values, - y=vals.index.values, - orientation='h' - ) - ) - fig.add_trace( - go.Bar( - showlegend=False, - hovertext=vals.apply( - lambda v: f"{label}: {v}" - ), - marker=dict(color="blue"), - **kwargs - ), - row=row, - col=col - ) - label = label.replace(" ", "
") - self.label_axis(fig, row, col, orient, label) - if log: - self.logscale_axis(fig, row, col, orient) - - def label_axis(self, fig, row, col, orient, title): - self.update_axis( - fig, row, col, orient, title=title - ) - - def logscale_axis(self, fig, row, col, orient): - self.update_axis( - fig, row, col, orient, type="log" - ) - - @staticmethod - def update_axis(fig, row, col, orient, **kwargs): - if orient == "v": - fig.for_each_yaxis( - lambda axis: axis.update(**kwargs), - row=row, - col=col - ) - else: - fig.for_each_xaxis( - lambda axis: axis.update(**kwargs), - row=row, - col=col - ) - - def box( - self, - df: pd.DataFrame, - huevals: pd.Series, - huelabel: str, - label: str, - fig, - row, - col, - orient="v", - log=False, - **box_kwargs - ): - assert orient in ["v", "h"] - colors = ["blue", "orange", "red", "green", "black"] - for ix, (hueval, hue_df) in enumerate(df.groupby(huevals)): - trace_df = ( - hue_df - .reset_index() - .melt( - id_vars=[( - "index" - if hue_df.index.name is None - else hue_df.index.name - )], - var_name="variable" - ) - ) - - kwargs = ( - dict( - y=trace_df["value"], - x=trace_df["variable"] - ) - if orient == "v" else - dict( - orientation="h", - x=trace_df["value"], - y=trace_df["variable"] - ) - ) - - fig.add_trace( - go.Box( - name=f"{huelabel}: {hueval}", - legendgroup=f"{huelabel}: {hueval}", - marker=dict(color=colors[ix % len(colors)]), - **kwargs, - **box_kwargs - ), - row=row, - col=col - ) - label = label.replace(" ", "
") - self.label_axis(fig, row, col, orient, label) - if log: - self.logscale_axis(fig, row, col, orient) - def bin_metagenomes( read_alignments, @@ -1181,7 +614,6 @@ def bin_metagenomes( genome_groups, group_profile, metadata, - category, min_n_reads, min_n_genes ): @@ -1194,7 +626,6 @@ def bin_metagenomes( min_n_genes ) mdata.add_metadata(metadata) - mdata.compare_groups(category) return mdata @@ -1204,7 +635,6 @@ def bin_metagenomes( @click.option('--genome_groups', type=click.Path(exists=True)) @click.option('--group_profile', type=click.Path(exists=True)) @click.option('--metadata', type=click.Path(exists=True)) -@click.option('--category', type=str) @click.option('--min_n_reads', type=int) @click.option('--min_n_genes', type=int) @click.option('--output_folder', type=click.Path()) @@ -1214,7 +644,6 @@ def main( genome_groups, group_profile, metadata, - category, min_n_reads, min_n_genes, output_folder @@ -1225,7 +654,6 @@ def main( logger.info(f"genome_groups: {genome_groups}") logger.info(f"group_profile: {group_profile}") logger.info(f"metadata: {metadata}") - logger.info(f"category: {category}") logger.info(f"min_n_reads: {min_n_reads}") logger.info(f"min_n_genes: {min_n_genes}") @@ -1235,15 +663,13 @@ def main( genome_groups, group_profile, metadata, - category, min_n_reads, min_n_genes ) - mdata.to_csv(output_folder) - mdata.to_h5ad(output_folder) + mdata.to_csv(f"{output_folder}/csv") + mdata.to_h5ad(f"{output_folder}/h5ad") mdata.data.write_h5mu(f"{output_folder}/metagenome.h5mu") - mdata.plot(output_folder) if __name__ == "__main__": diff --git a/bin/plot_metagenomes.py b/bin/plot_metagenomes.py new file mode 100755 index 0000000..af86c34 --- /dev/null +++ b/bin/plot_metagenomes.py @@ -0,0 +1,484 @@ +#!/usr/bin/env python3 +from typing import List +import anndata as ad +import click +import logging +import numpy as np +import pandas as pd +import plotly.graph_objects as go +import plotly.express as px +from plotly.subplots import make_subplots +from scipy.cluster import hierarchy +from statsmodels.stats.multitest import multipletests + +# Set the level of the logger to INFO +logFormatter = logging.Formatter( + '%(asctime)s %(levelname)-8s [plot_metagenomes.py] %(message)s' +) +logger = logging.getLogger('plot_metagenomes.py') +logger.setLevel(logging.INFO) + +# Write to STDOUT +consoleHandler = logging.StreamHandler() +consoleHandler.setFormatter(logFormatter) +logger.addHandler(consoleHandler) + +# Write to file +fileHandler = logging.FileHandler("plot_metagenomes.log") +fileHandler.setFormatter(logFormatter) +fileHandler.setLevel(logging.INFO) +logger.addHandler(fileHandler) + + +def fix_name(n: str, options: List[str]): + if n in options: + return n + for m in options: + if m.replace(" ", ".") == n: + return m + return n + + +class Metagenome: + + adata: ad.AnnData + + def __init__( + self, + stats, + h5ad + ): + + stats = self._read_csv(stats).set_index("feature") + self.adata = ad.read_h5ad(h5ad) + + # Restore the variable names that may have been mangled by R + stats = stats.rename( + index=lambda n: fix_name(n, self.adata.var_names) + ) + self.adata.var = self.adata.var.merge( + stats, + left_index=True, + right_index=True + ) + self.adata.var["qvalue"] = multipletests(self.adata.var["p_value"], 0.1, "fdr_bh")[1] + self.adata.var["neg_log10_qvalue"] = -np.log10(self.adata.var["qvalue"]) + for line in str(self.adata).split("\n"): + logger.info(line) + + def _read_csv(self, fp, **kwargs): + logger.info(f"Reading in {fp}") + df: pd.DataFrame = pd.read_csv(fp, **kwargs) + logger.info(f"Read in {df.shape[0]:,} rows and {df.shape[1]:,} columns") + self.log_df(df.head().T.head().T) + return df + + @staticmethod + def log_df(df: pd.DataFrame, **kwargs): + for line in df.to_csv(**kwargs).split("\n"): + logger.info(line) + + @staticmethod + def log_scale(df: pd.DataFrame): + + lowest = df.apply(lambda c: c[c > 0].min()).min() + return df.clip(lower=lowest).apply(np.log10) + + def sort_index(self, df: pd.DataFrame, metric="cosine", method="average"): + try: + return df.index.values[ + hierarchy.leaves_list( + hierarchy.linkage( + df.values, + metric=metric, + method=method + ) + ) + ] + except Exception as e: + logger.info("Error encountered while sorting table:") + self.log_df(df) + raise e + + def plot(self, meta_cname, output_folder): + + output_fp = f"{output_folder}/{meta_cname}" + + # Sort the bins and samples + bin_order = self.sort_index( + self.adata.to_df().T, + metric="cosine", + method="average" + ) + + sample_order = self.sort_index( + self.adata.to_df(), + metric="euclidean", + method="ward" + ) + + genome_order = self.sort_index( + self.adata.uns["group_profile"].T, + metric="cosine", + method="average" + ) + + # Relative size of rows and columns + heatmap_size = 3 + column_widths = np.array([heatmap_size, 1, heatmap_size, 1, 1]) + column_widths = list(column_widths / column_widths.sum()) + row_heights = np.array([0.5, heatmap_size, 0.5]) + row_heights = list(row_heights / row_heights.sum()) + horizontal_spacing = 0.05 + + cols = 5 + rows = 3 + fig = make_subplots( + rows=rows, + cols=cols, + shared_xaxes=True, + shared_yaxes=True, + start_cell="bottom-left", + column_widths=column_widths, + row_heights=row_heights, + horizontal_spacing=horizontal_spacing + ) + # Bins across Genomes + self.heatmap( + ( + self + .adata + .uns["group_profile"] + .reindex( + columns=genome_order, + index=bin_order + ) + ), + fig, + value_label="Gene Bin Presence in Genome", + row=2, + col=1, + showscale=True, + coloraxis="coloraxis2", + colorbar_x=(horizontal_spacing / 4), + colorbar_xanchor="left", + colorbar_orientation="h", + colorbar_y=0.85, + colorbar_len=column_widths[0] - (horizontal_spacing * 2), + colorbar_title_side="top" + ) + + # # Bins Information + + # Number of genes per bin + self.bar( + pd.Series( + self.adata.uns["bin_size"] + ).reindex( + index=bin_order + ), + "Genes per Bin (#)", + fig, + row=2, + col=2, + orient="h" + ) + + # Bins across samples + self.heatmap( + self.log_scale(( + self + .adata + .to_df("prop") + .reindex( + index=sample_order, + columns=bin_order + ) + .T + )), + fig, + value_label="Proportion of Aligned Reads (log10)", + row=2, + col=3, + coloraxis="coloraxis3", + showscale=True, + colorbar_x=sum(column_widths[:2]) + (horizontal_spacing / 2), + colorbar_xanchor="left", + colorbar_orientation="h", + colorbar_y=0.85, + colorbar_len=column_widths[2] - (horizontal_spacing * 1.5), + colorbar_title_side="top", + colorbar_title_text="Gene Bin Abundance
Proportion of Aligned Reads (log10)", + colorbar_xpad=0, + ) + + # Metadata data on the bins heatmap + self.heatmap( + self.adata.obs.reindex( + index=sample_order, + columns=[meta_cname] + ).T, + fig, + value_label=meta_cname, + col=3, + row=1, + colorscale=["blue", "orange"], + coloraxis="coloraxis", + showscale=False + ) + + # Bars showing the -log10(qvalue) for each gene bin + self.bar( + ( + self + .adata + .var + ["neg_log10_qvalue"] + ), + "FDR-adjusted
p-value (-log10)", + fig, + orient="h", + col=4, + row=2 + ) + + # Boxplot comparing gene bin abundances by sample group + self.box( + self.adata.to_df().reindex(index=sample_order), + self.adata.obs[meta_cname].reindex(index=sample_order), + meta_cname, + "Gene Bin Abundance", + fig, + col=5, + row=2, + log=True, + orient="h" + ) + + logger.info(fig.layout) + title_text = f"Association with {meta_cname}" + fig.update_layout( + title_text=title_text, + title_xanchor="center", + title_x=0.5, + plot_bgcolor='rgba(255, 255, 255, 1.0)', + paper_bgcolor='rgba(255, 255, 255, 1.0)', + legend=dict( + yanchor="top", + xanchor="left", + x=1.0, + y=0.3 + ), + **{ + f"xaxis{i + cols}": dict(showticklabels=True) + for i in [1, 2, 4, 5] + }, + **{ + f"yaxis{2 + (cols * i)}": dict(showticklabels=True) + for i in [1, 2, 4, 5, 6] + } + ) + fig.write_html(f"{output_fp}.html") + for ext in ['png', 'pdf']: + fig.write_image( + f"{output_fp}.{ext}", + width=2880, + height=1800 + ) + + @staticmethod + def heatmap( + df: pd.DataFrame, + fig, + value_label: str, + row=None, + col=None, + coloraxis="coloraxis", + colorscale=px.colors.sequential.Blues, + showlegend=False, + showscale=False, + **kwargs + ): + + assert df.dropna().shape[0] == df.shape[0], df + + if "colorbar_title_text" not in kwargs: + kwargs["colorbar_title_text"] = value_label + + fig.layout[coloraxis] = dict( + colorscale=colorscale, + showscale=showscale, + **{ + kw: val + for kw, val in kwargs.items() + if kw.startswith("colorbar_") + } + ) + + fig.add_trace( + go.Heatmap( + z=df.values, + x=df.columns.values, + y=df.index.values, + hovertext=df.applymap( + lambda v: f"{value_label}: {v}" + ), + coloraxis=coloraxis, + showlegend=showlegend, + showscale=showscale + ), + row=row, + col=col + ) + + def bar(self, vals: pd.Series, label: str, fig, row, col, orient="v", log=False): + assert orient in ["v", "h"] + kwargs = ( + dict( + y=vals.values, + x=vals.index.values + ) + if orient == "v" else + dict( + x=vals.values, + y=vals.index.values, + orientation='h' + ) + ) + fig.add_trace( + go.Bar( + showlegend=False, + hovertext=vals.apply( + lambda v: f"{label}: {v}" + ), + marker=dict(color="blue"), + **kwargs + ), + row=row, + col=col + ) + self.label_axis(fig, row, col, orient, label) + if log: + self.logscale_axis(fig, row, col, orient) + + def label_axis(self, fig, row, col, orient, title): + self.update_axis( + fig, row, col, orient, title=title + ) + + def logscale_axis(self, fig, row, col, orient): + self.update_axis( + fig, row, col, orient, type="log" + ) + + @staticmethod + def update_axis(fig, row, col, orient, **kwargs): + if orient == "v": + fig.for_each_yaxis( + lambda axis: axis.update(**kwargs), + row=row, + col=col + ) + else: + fig.for_each_xaxis( + lambda axis: axis.update(**kwargs), + row=row, + col=col + ) + + def box( + self, + df: pd.DataFrame, + huevals: pd.Series, + huelabel: str, + label: str, + fig, + row, + col, + orient="v", + log=False, + **box_kwargs + ): + assert orient in ["v", "h"] + colors = ["blue", "orange", "red", "green", "black"] + for ix, (hueval, hue_df) in enumerate(df.groupby(huevals)): + trace_df = ( + hue_df + .reset_index() + .melt( + id_vars=[( + "index" + if hue_df.index.name is None + else hue_df.index.name + )], + var_name="variable" + ) + ) + + kwargs = ( + dict( + y=trace_df["value"], + x=trace_df["variable"] + ) + if orient == "v" else + dict( + orientation="h", + x=trace_df["value"], + y=trace_df["variable"] + ) + ) + + fig.add_trace( + go.Box( + name=f"{huelabel}: {hueval}", + legendgroup=f"{huelabel}: {hueval}", + marker=dict(color=colors[ix % len(colors)]), + **kwargs, + **box_kwargs + ), + row=row, + col=col + ) + label = label.replace(" ", "
") + self.label_axis(fig, row, col, orient, label) + if log: + self.logscale_axis(fig, row, col, orient) + + +def plot_metagenomes( + param, + stats, + h5ad, + output_folder +): + mdata = Metagenome(stats, h5ad) + mdata.plot(param, output_folder) + + +@click.command +@click.option('--param', type=str) +@click.option('--stats', type=click.Path(exists=True)) +@click.option('--h5ad', type=click.Path(exists=True)) +@click.option('--output_folder', type=click.Path()) +def main( + param, + stats, + h5ad, + output_folder +): + + logger.info(f"param: {param}") + logger.info(f"stats: {stats}") + logger.info(f"counts: {h5ad}") + logger.info(f"output_folder: {output_folder}") + + plot_metagenomes( + param, + stats, + h5ad, + output_folder + ) + + +if __name__ == "__main__": + main() diff --git a/bin_metagenomes.nf b/bin_metagenomes.nf index 209444e..ed19b91 100644 --- a/bin_metagenomes.nf +++ b/bin_metagenomes.nf @@ -8,7 +8,7 @@ GroovyShell shell = new GroovyShell() def helpers = shell.parse(new File("${workflow.projectDir}/helpers.gvy")) // Import sub-workflows -include { bin_metagenomes } from './modules/processes/bin_metagenomes' +include { bin_metagenomes } from './modules/bin_metagenomes' // Standalone entrypoint workflow { diff --git a/modules/bin_metagenomes.nf b/modules/bin_metagenomes.nf new file mode 100644 index 0000000..0eefab4 --- /dev/null +++ b/modules/bin_metagenomes.nf @@ -0,0 +1,46 @@ +include { + collect; + split; + plot +} from "./processes/bin_metagenomes" + +include { corncob } from "./test_reads" + +workflow bin_metagenomes { + take: + read_alignments + gene_bins + genome_groups + group_profile + metadata + + main: + + collect( + read_alignments, + gene_bins, + genome_groups, + group_profile, + metadata + ) + + corncob( + collect.out.metadata, + collect.out.bin_counts + ) + + split( + corncob.out + ) + + plot( + split + .out + .flatten() + .map { + it -> [it.name.replaceAll(".results.csv", ""), it] + } + .combine(collect.out.bins_h5ad) + ) + +} \ No newline at end of file diff --git a/modules/processes/align_reads.nf b/modules/processes/align_reads.nf index bbe804b..644a850 100644 --- a/modules/processes/align_reads.nf +++ b/modules/processes/align_reads.nf @@ -62,7 +62,7 @@ process diamond { process diamond_logs { container "${params.container__pandas}" label 'io_limited' - publishDir "${params.output}", mode: 'copy', overwrite: true + publishDir "${params.output}/logs/", mode: 'copy', overwrite: true input: path "*" @@ -114,7 +114,7 @@ process famli { process famli_logs { container "${params.container__pandas}" label 'io_limited' - publishDir "${params.output}", mode: "copy", overwrite: true + publishDir "${params.output}/logs/", mode: "copy", overwrite: true input: path "*" diff --git a/modules/processes/bin_metagenomes.nf b/modules/processes/bin_metagenomes.nf index 14aeff1..977db8a 100644 --- a/modules/processes/bin_metagenomes.nf +++ b/modules/processes/bin_metagenomes.nf @@ -1,4 +1,4 @@ -process bin_metagenomes { +process collect { container "${params.container__pandas}" label 'io_limited' publishDir "${params.output}", mode: 'copy', overwrite: true @@ -11,21 +11,76 @@ process bin_metagenomes { path metadata output: - path "*" + path "*", emit: all + path "csv/metagenome.obs.csv.gz", emit: metadata + path "csv/metagenome.bins.X.csv.gz", emit: bin_counts + path "h5ad/metagenome.bins.h5ad", emit: bins_h5ad + path "h5ad/metagenome.genes.h5ad", emit: genes_h5ad + path "h5ad/metagenome.genomes.h5ad", emit: genomes_h5ad """#!/bin/bash set -e -bin_metagenomes.py \ +collect_metagenomes.py \ --read_alignments "${read_alignments}" \ --gene_bins "${gene_bins}" \ --genome_groups "${genome_groups}" \ --group_profile "${group_profile}" \ --metadata "${metadata}" \ - --category ${params.category} \ --min_n_reads ${params.min_n_reads} \ --min_n_genes ${params.min_n_genes} \ --output_folder ./ """ +} + +// Split up the corncob results as inputs for plotting +process split { + container "${params.container__pandas}" + label 'io_limited' + + input: + path "corncob.results.csv" + + output: + path "*.results.csv" + + """#!/usr/bin/env python3 +import pandas as pd +df = pd.read_csv("corncob.results.csv") + +for param, param_df in df.groupby("parameter"): + if not param.startswith("mu.") or param.endswith("(Intercept)"): + continue + + param = param[3:] + ( + param_df + .set_index("feature") + .drop(columns=["parameter"]) + .to_csv(f"{param}.results.csv") + ) + """ +} + +process plot { + container "${params.container__pandas}" + label 'io_limited' + publishDir "${params.output}/association/${param}/", mode: 'copy', overwrite: true + + input: + tuple val(param), path(stats), path(h5ad) + + output: + path "*" + + """#!/bin/bash +set -e + +plot_metagenomes.py \ + --param "${param}" \ + --stats "${stats}" \ + --h5ad "${h5ad}" \ + --output_folder ./ +""" } \ No newline at end of file diff --git a/modules/test_reads.nf b/modules/test_reads.nf index 02f3672..f605c6f 100644 --- a/modules/test_reads.nf +++ b/modules/test_reads.nf @@ -41,7 +41,7 @@ process shard_genes { // Test for differences between samples process corncob { container "${params.container__corncob}" - label "mem_medium" + label "io_limited" input: file metadata_csv diff --git a/nextflow.config b/nextflow.config index 93c5575..53f337c 100644 --- a/nextflow.config +++ b/nextflow.config @@ -195,7 +195,7 @@ params { container__raxml = "quay.io/biocontainers/raxml-ng:1.0.3--h32fcf60_0" container__gigmap = "quay.io/hdc-workflows/gig-map:c08eac9" container__mash = "staphb/mashtree:0.52.0" - container__corncob = "quay.io/fhcrc-microbiome/corncob" + container__corncob = "quay.io/fhcrc-microbiome/corncob:84c8354" container__datasets = "biocontainers/ncbi-datasets-cli:15.12.0_cv23.1.0-4" } diff --git a/templates/corncob.Rscript b/templates/corncob.Rscript index 179dbeb..dab1734 100644 --- a/templates/corncob.Rscript +++ b/templates/corncob.Rscript @@ -14,76 +14,60 @@ Sys.setenv("VROOM_CONNECTION_SIZE" = format(connectionSize, scientific=F)) numCores = ${task.cpus} -## READCOUNTS CSV should have columns `specimen` (first col) and `total` (last column). -## METADATA CSV should have columns `specimen` (which matches up with `specimen` from +## READCOUNTS CSV should have sample IDs in the first col +## METADATA CSV should have a column `specimen` (which matches up with the first column from ## the recounts file), and additional columns with covariates matching `formula` ## corncob analysis (coefficients and p-values) are written to OUTPUT CSV on completion print("Reading in ${metadata_csv}") -metadata <- vroom::vroom("${metadata_csv}", delim=",") +metadata <- read.csv("${metadata_csv}", sep=",") +metadata <- tibble::column_to_rownames(metadata, names(metadata)[1]) print(metadata) print("Reading in ${readcounts_csv_gz}") -counts <- vroom::vroom("${readcounts_csv_gz}", delim=",") -total_counts <- counts[,c("specimen", "total")] +counts <- read.csv("${readcounts_csv_gz}", sep=",") +counts <- tibble::column_to_rownames(counts, "specimen") +print(counts) -print("Adding total counts to manifest") -print(head(total_counts)) +#### Run the differentialAbundance analysis +da <- differentialTest( + data = counts, + formula = ~ ${params.formula}, + phi.formula = ~ 1, + formula_null = ~ 1, + phi.formula_null = ~ 1, + sample_data = metadata, + taxa_are_rows = FALSE, + test = "Wald", + full_output = TRUE +) -print("Merging total counts with metadata") -total_and_meta <- metadata %>% - left_join(total_counts, by = c("specimen" = "specimen")) - -#### Run the analysis for every individual gene (in this shard) -print(sprintf("Starting to process %s columns", length(c(2:(dim(counts)[2] - 1))))) -corn_tib <- do.call(rbind, mclapply( - c(2:(dim(counts)[2] - 1)), - function(i){ - try_bbdml <- try( - counts[,c(1, i)] %>% - rename(W = 2) %>% - right_join( - total_and_meta, - by = c("specimen" = "specimen") - ) %>% - corncob::bbdml( - formula = cbind(W, total - W) ~ ${params.formula}, - phi.formula = ~ 1, - data = . - ) - ) - - if (class(try_bbdml) == "bbdml") { - return( - summary( - try_bbdml - )\$coef %>% - as_tibble %>% - mutate("parameter" = summary(try_bbdml)\$coef %>% row.names) %>% - rename( - "estimate" = Estimate, - "std_error" = `Std. Error`, - "p_value" = `Pr(>|t|)` - ) %>% - select(-`t value`) %>% - gather(key = type, ...=estimate:p_value) %>% - mutate("gene_id" = names(counts)[i]) +# Rename the outputs as a table +output <- do.call( + rbind, + lapply( + seq_along(da\$all_models), + function(i){ + coef <- da\$all_models[[i]]\$coef + return( + coef + %>% as_tibble + %>% mutate( + "parameter" = coef %>% row.names, + "feature" = colnames(counts)[i] + ) + %>% rename( + "estimate" = Estimate, + "std_error" = `Std. Error`, + "p_value" = `Pr(>|t|)` + ) + %>% select(-`t value`) + ) + } ) - } else { - return( - tibble( - "parameter" = "all", - "type" = "failed", - "value" = NA, - "gene_id" = names(counts)[i] - ) - ) - } - }, - mc.cores = numCores - )) + ) -print(sprintf("Writing out %s rows to corncob.results.csv", nrow(corn_tib))) -write_csv(corn_tib, "corncob.results.csv") +print(sprintf("Writing out %s rows to corncob.results.csv", nrow(output))) +write_csv(output, "corncob.results.csv") print("Done")