From c408ce3993a3c6e0790c15dc5644777412801314 Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Wed, 31 Jul 2024 13:02:53 -0400 Subject: [PATCH] pinnacle df reformatting --- tdc/resource/pinnacle.py | 24 ++++++++++++++++++++---- tdc/test/test_resources.py | 10 ++++++++++ tdc/version.py | 2 +- 3 files changed, 31 insertions(+), 5 deletions(-) diff --git a/tdc/resource/pinnacle.py b/tdc/resource/pinnacle.py index f2f99067..fefcab94 100644 --- a/tdc/resource/pinnacle.py +++ b/tdc/resource/pinnacle.py @@ -61,12 +61,28 @@ def get_keys(self): def get_embeds(self): prots = self.get_keys() emb = self.get_embeds_raw() - nemb = {'--'.join(prots.iloc[k]): v for k, v in emb.items()} + # nemb = {'--'.join(prots.iloc[k]): v for k, v in emb.items()} x = {} - for k, v in nemb.items(): + ctr = 0 + for _, v in emb.items(): if isinstance(v, torch.Tensor): - x[k] = pd.DataFrame(v.detach().numpy()) + if v.size()[0] == 1: + k = "--".join(prots.iloc[ctr]) + ctr += 1 + x[k] = v.detach().numpy() + continue + for t in v: + assert len(t.size()) == 1, t.size() + k = "--".join(prots.iloc[ctr]) + ctr += 1 + x[k] = t.detach().numpy() else: raise Exception("encountered non-tensor") - df = pd.concat(x, axis=0) + assert len(x) == len(prots), "dict len {} vs keys length {}".format( + len(x), len(prots)) + df = pd.DataFrame.from_dict(x) + df = df.transpose() + assert len(df) == len( + x), "dims not mantained when translated to pandas. {} vs {}".format( + len(df), len(x)) return df diff --git a/tdc/test/test_resources.py b/tdc/test/test_resources.py index 25e4cad0..c30e821f 100644 --- a/tdc/test/test_resources.py +++ b/tdc/test/test_resources.py @@ -97,6 +97,16 @@ def test_embeddings(self): assert len(keys) > 0, "PINNACLE keys is empty" assert len(keys) == len(embeds), "{} vs {}".format( len(keys), len(embeds)) + num_targets = len(keys["target"].unique()) + num_cells = len(keys["cell type"].unique()) + all_entries = embeds.index + prots = [x.split("--")[0] for x in all_entries] + cells = [x.split("--")[1] for x in all_entries] + assert len( + set(prots)) == num_targets, "{} vs {} for target proteins".format( + len(prots), num_targets) + assert len(set(cells)) == num_cells, "{} vs {} for cell_types".format( + len(cells), num_cells) if __name__ == "__main__": diff --git a/tdc/version.py b/tdc/version.py index f0fbf9d1..353520a4 100644 --- a/tdc/version.py +++ b/tdc/version.py @@ -19,4 +19,4 @@ # Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer. # 'X.Y.dev0' is the canonical version of 'X.Y.dev' # -__version__ = "1.0.0" # pragma: no cover +__version__ = "1.0.6" # pragma: no cover