diff --git a/dynamo/tools/dynamics.py b/dynamo/tools/dynamics.py index 6b0a46913..9a1952e10 100755 --- a/dynamo/tools/dynamics.py +++ b/dynamo/tools/dynamics.py @@ -716,6 +716,7 @@ def dynamics( kin_param_pre, valid_bools_, ind_for_proteins, + cur_cells_bools, ) elif assumption_mRNA.lower() == "kinetic": diff --git a/dynamo/tools/utils.py b/dynamo/tools/utils.py index 6640c3bdd..f39ec0265 100755 --- a/dynamo/tools/utils.py +++ b/dynamo/tools/utils.py @@ -1552,11 +1552,12 @@ def set_param_ss( kin_param_pre, valid_ind, ind_for_proteins, + cur_cells_bools, ): params_df = pd.DataFrame(index=adata.var.index) if experiment_type == "mix_std_stm": if alpha is not None: - if cur_grp == _group[0]: + if kin_param_pre + "alpha" not in adata.varm: adata.varm[kin_param_pre + "alpha"] = np.zeros((adata.shape[1], alpha[1].shape[1])) adata.varm[kin_param_pre + "alpha"][valid_ind, :] = alpha[1] ( @@ -1581,13 +1582,13 @@ def set_param_ss( else: if alpha is not None: if len(alpha.shape) > 1: # for each cell - if cur_grp == _group[0]: + if kin_param_pre + "alpha" not in adata.varm: adata.varm[kin_param_pre + "alpha"] = ( sp.csr_matrix(np.zeros(adata.shape[::-1])) if sp.issparse(alpha) else np.zeros(adata.shape[::-1]) ) # - adata.varm[kin_param_pre + "alpha"][valid_ind, :] = alpha # + adata.varm[kin_param_pre + "alpha"][valid_ind, :][:, cur_cells_bools] = alpha # params_df.loc[valid_ind, kin_param_pre + "alpha"] = alpha.mean(1) elif len(alpha.shape) == 1: if cur_grp == _group[0]: @@ -1753,10 +1754,10 @@ def set_param_kinetic( np.where(cur_cells_bools)[0][:, np.newaxis], np.where(valid_ind)[0], ) - if cur_grp == _group[0]: - adata.layers["cell_wise_alpha"] = sp.csr_matrix((adata.shape), dtype=np.float64) + if kin_param_pre + "cell_wise_alpha" not in adata.layers: + adata.layers[kin_param_pre + "cell_wise_alpha"] = sp.csr_matrix((adata.shape), dtype=np.float64) alpha = alpha.T.tocsr() if sp.issparse(alpha) else sp.csr_matrix(alpha, dtype=np.float64).T - adata.layers["cell_wise_alpha"][cur_cells_ind, valid_ind_] = alpha + adata.layers[kin_param_pre + "cell_wise_alpha"][cur_cells_ind, valid_ind_] = alpha else: params_df.loc[valid_ind, kin_param_pre + "alpha"] = alpha params_df.loc[valid_ind, kin_param_pre + "a"] = a @@ -1818,9 +1819,9 @@ def get_vel_params( for param in params: if param == "alpha": if not skip_cell_wise: - if "cell_wise_alpha" in adata.layers.keys(): - target_params.append(adata[:, genes].layers["cell_wise_alpha"]) - elif "alpha" in adata.varm.keys(): + if kin_param_pre + "cell_wise_alpha" in adata.layers.keys(): + target_params.append(adata[:, genes].layers[kin_param_pre + "cell_wise_alpha"]) + elif kin_param_pre + "alpha" in adata.varm.keys(): target_params.append(adata[:, genes].varm[kin_param_pre + "alpha"]) else: target_params.append(df.loc[genes, kin_param_pre + "alpha"].values)