diff --git a/CHANGELOG.md b/CHANGELOG.md index 60133a85b..51d4cdf8f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,10 @@ and this project adheres to [Semantic Versioning][]. ## [Unreleased] +### Fixes + +- Fix that `define_clonotype_clusters` could not retrieve `within_group` columns from MuData ([#459](https://github.com/scverse/scirpy/pull/459)) + ## v0.13.1 ### Fixes diff --git a/docs/conf.py b/docs/conf.py index fa8235efa..eea484b22 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -138,6 +138,7 @@ "repository_url": repository_url, "use_repository_button": True, "path_to_docs": "docs/", + "navigation_with_keys": False, } pygments_style = "default" diff --git a/src/scirpy/ir_dist/_clonotype_neighbors.py b/src/scirpy/ir_dist/_clonotype_neighbors.py index b044e99bd..5176e35ff 100644 --- a/src/scirpy/ir_dist/_clonotype_neighbors.py +++ b/src/scirpy/ir_dist/_clonotype_neighbors.py @@ -77,7 +77,7 @@ def _make_clonotype_table(self, params: DataHandler) -> tuple[Mapping, pd.DataFr obs = obs.loc[_has_ir(params) & np.any(~pd.isnull(obs), axis=1), :] if self.match_columns is not None: obs = obs.join( - params.adata.obs.loc[:, self.match_columns], + params.get_obs(self.match_columns), validate="one_to_one", how="inner", ) diff --git a/src/scirpy/tl/_clonotypes.py b/src/scirpy/tl/_clonotypes.py index 304c27ba0..398c3a119 100644 --- a/src/scirpy/tl/_clonotypes.py +++ b/src/scirpy/tl/_clonotypes.py @@ -123,11 +123,11 @@ def _validate_parameters( - adata, - reference, + params: DataHandler, + reference: DataHandler, receptor_arms, dual_ir, - match_columns, + within_group, distance_key, sequence, metric, @@ -137,7 +137,7 @@ def _validate_parameters( def _get_db_name(): try: - return reference.uns["DB"]["name"] + return reference.adata.uns["DB"]["name"] except KeyError: raise ValueError( 'If reference does not contain a `.uns["DB"]["name"]` entry, ' @@ -153,22 +153,24 @@ def _get_db_name(): if dual_ir not in ["primary_only", "all", "any"]: raise ValueError("Invalid
value for `dual_ir") - if match_columns is not None: - if isinstance(match_columns, str): - match_columns = [match_columns] - for group_col in match_columns: - if group_col not in adata.obs.columns: - msg = f"column `{match_columns}` not found in `adata.obs`. " + if within_group is not None: + if isinstance(within_group, str): + within_group = [within_group] + for group_col in within_group: + try: + params.get_obs(group_col) + except KeyError: + msg = f"column `{group_col}` not found in `obs`. " if group_col in ("receptor_type", "receptor_subtype"): msg += "Did you run `tl.chain_qc`? " - raise ValueError(msg) + raise ValueError(msg) from None if distance_key is None: if reference is not None: distance_key = f"ir_dist_{_get_db_name()}_{sequence}_{_get_metric_key(metric)}" else: distance_key = f"ir_dist_{sequence}_{_get_metric_key(metric)}" - if distance_key not in adata.uns: + if distance_key not in params.adata.uns: raise ValueError("Sequence distances were not found in `adata.uns`. Did you run `pp.ir_dist`?") if key_added is None: @@ -177,7 +179,7 @@ def _get_db_name(): else: key_added = f"cc_{sequence}_{_get_metric_key(metric)}" - return match_columns, distance_key, key_added + return within_group, distance_key, key_added @DataHandler.inject_param_docs( @@ -271,7 +273,7 @@ def define_clonotype_clusters( """ params = DataHandler(adata, airr_mod, airr_key, chain_idx_key) within_group, distance_key, key_added = _validate_parameters( - params.adata, + params, None, receptor_arms, dual_ir, diff --git a/src/scirpy/tl/_ir_query.py b/src/scirpy/tl/_ir_query.py index 11fd167f3..29fd7e1b7 100644 --- a/src/scirpy/tl/_ir_query.py +++ b/src/scirpy/tl/_ir_query.py @@ -178,8 +178,8 @@ def ir_query( DataHandler(reference, airr_mod_ref, airr_key_ref, chain_idx_key_ref) if reference is not None else None ) match_columns, distance_key, key_added = _validate_parameters( - params.adata, - params_ref.adata if params_ref is not None else None, + params, + params_ref if params_ref is 
not None else None, receptor_arms, dual_ir, match_columns,