Skip to content

Commit

Permalink
pinnacle df reformatting
Browse files Browse the repository at this point in the history
  • Loading branch information
amva13 committed Jul 31, 2024
1 parent 4b946a0 commit c408ce3
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 5 deletions.
24 changes: 20 additions & 4 deletions tdc/resource/pinnacle.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,28 @@ def get_keys(self):
def get_embeds(self):
    """Return the embeddings as a single DataFrame with one row per key.

    The raw embeddings (``self.get_embeds_raw()``) are a mapping whose
    values are tensors. Every embedding vector found in those tensors is
    paired, in iteration order, with the next row of ``self.get_keys()``;
    the row label is that key row's fields joined with ``"--"``.

    Returns:
        pandas.DataFrame: indexed by the joined key labels, with one
        column per embedding dimension.

    Raises:
        TypeError: if a raw embedding value is not a ``torch.Tensor``.
        AssertionError: if an individual embedding is not a 1-D vector,
            or if the number of collected rows differs from the number
            of keys.
    """
    prots = self.get_keys()
    emb = self.get_embeds_raw()
    x = {}
    ctr = 0  # position of the next unused row in ``prots``
    for v in emb.values():
        if not isinstance(v, torch.Tensor):
            raise TypeError("encountered non-tensor")
        # A bare 1-D tensor is taken as one embedding vector; anything else
        # is walked along dim 0. This also covers a (1, d) tensor, which
        # must become a single 1-D row rather than a 2-D array (pandas
        # rejects 2-D per-column arrays).
        vectors = [v] if v.dim() == 1 else v
        for t in vectors:
            if t.dim() != 1:
                raise AssertionError(t.size())
            k = "--".join(prots.iloc[ctr])
            ctr += 1
            # ``.cpu()`` is a no-op for CPU tensors and makes the NumPy
            # conversion work for tensors living on another device.
            x[k] = t.detach().cpu().numpy()
    # Raised explicitly (not via ``assert``) so the checks survive ``-O``.
    if len(x) != len(prots):
        raise AssertionError("dict len {} vs keys length {}".format(
            len(x), len(prots)))
    # Keys become columns first; transposing yields one row per key.
    df = pd.DataFrame.from_dict(x)
    df = df.transpose()
    if len(df) != len(x):
        raise AssertionError(
            "dims not mantained when translated to pandas. {} vs {}".format(
                len(df), len(x)))
    return df
10 changes: 10 additions & 0 deletions tdc/test/test_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,16 @@ def test_embeddings(self):
assert len(keys) > 0, "PINNACLE keys is empty"
assert len(keys) == len(embeds), "{} vs {}".format(
len(keys), len(embeds))
num_targets = len(keys["target"].unique())
num_cells = len(keys["cell type"].unique())
all_entries = embeds.index
prots = [x.split("--")[0] for x in all_entries]
cells = [x.split("--")[1] for x in all_entries]
assert len(
set(prots)) == num_targets, "{} vs {} for target proteins".format(
len(prots), num_targets)
assert len(set(cells)) == num_cells, "{} vs {} for cell_types".format(
len(cells), num_cells)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion tdc/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@
# Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer.
# 'X.Y.dev0' is the canonical version of 'X.Y.dev'
#
# Single authoritative version string; the duplicate assignment that was
# immediately overwritten has been removed.
__version__ = "1.0.6"  # pragma: no cover

0 comments on commit c408ce3

Please sign in to comment.