Question about gatv2 code #228

Open
XiaokangORCA opened this issue Nov 20, 2023 · 1 comment

@XiaokangORCA

Hello, I am a beginner with GAT, and I've been studying your GATv2 code lately. A question came up while I was going through the code in

labml_nn/graphs/gatv2/__init__.py

When calculating g_sum

g_sum = g_l_repeat + g_r_repeat_interleave

You mentioned in the comments: Now we add the two tensors to get

$$ \lbrace\overrightarrow{g_{l1}} + \overrightarrow{g_{r1}}, \overrightarrow{g_{l1}} + \overrightarrow{g_{r2}}, \dots, \overrightarrow{g_{l1}} + \overrightarrow{g_{rN}}, \overrightarrow{g_{l2}} + \overrightarrow{g_{r1}}, \overrightarrow{g_{l2}} + \overrightarrow{g_{r2}}, \dots, \overrightarrow{g_{l2}} + \overrightarrow{g_{rN}}, \dots\rbrace $$

But in the previous code, g_l_repeat gets

$$ \lbrace\overrightarrow{g_{l1}}, \overrightarrow{g_{l2}}, \dots, \overrightarrow{g_{lN}}, \overrightarrow{g_{l1}}, \overrightarrow{g_{l2}}, \dots, \overrightarrow{g_{lN}}, \dots\rbrace $$

and g_r_repeat_interleave gets

$$ \lbrace\overrightarrow{g_{r1}}, \overrightarrow{g_{r1}}, \dots, \overrightarrow{g_{r1}}, \overrightarrow{g_{r2}}, \overrightarrow{g_{r2}}, \dots, \overrightarrow{g_{r2}}, \dots\rbrace $$

So I think the result of adding the two tensors should be

$$ \lbrace\overrightarrow{g_{l1}} + \overrightarrow{g_{r1}}, \overrightarrow{g_{l2}} + \overrightarrow{g_{r1}}, \dots, \overrightarrow{g_{lN}} + \overrightarrow{g_{r1}}, \overrightarrow{g_{l1}} + \overrightarrow{g_{r2}}, \overrightarrow{g_{l2}} + \overrightarrow{g_{r2}}, \dots, \overrightarrow{g_{lN}} + \overrightarrow{g_{r2}}, \dots\rbrace $$

I'm not sure whether I may have overlooked some crucial information or if there's a mismatch between your comments and the code. I would greatly appreciate it if you could help clarify my confusion. Thank you.

@rjavierch

Hello! I am also new to GAT and came across your issue.

To answer your question: I think the implementation on the website is correct, but the comment is only partially right. The reason is that

g_l_repeat

$$ \lbrace\overrightarrow{{g_l}_1}, \overrightarrow{{g_l}_2}, \dots, \overrightarrow{{g_l}_N}, \overrightarrow{{g_l}_1}, \overrightarrow{{g_l}_2}, \dots, \overrightarrow{{g_l}_N}, \dots\rbrace $$

is:

>>> import torch
>>> n_nodes = 3  # or N
>>> tensor = torch.tensor([[1], [2], [3]])  # stand-in for g_l and g_r
>>> tensor.repeat(n_nodes, 1)
tensor([[1],
        [2],
        [3],
        [1],
        [2],
        [3],
        [1],
        [2],
        [3]])

and g_r_repeat_interleave

$$ \lbrace\overrightarrow{{g_r}_1}, \overrightarrow{{g_r}_1}, \dots, \overrightarrow{{g_r}_1}, \overrightarrow{{g_r}_2}, \overrightarrow{{g_r}_2}, \dots, \overrightarrow{{g_r}_2}, \dots\rbrace $$

is instead:

>>> tensor.repeat_interleave(n_nodes, dim=0)
tensor([[1],
        [1],
        [1],
        [2],
        [2],
        [2],
        [3],
        [3],
        [3]])
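In index form, the two patterns can be summarized as follows. This is my own notation, not something from the repository: rows are counted from $k = 0$, and $N$ is the number of nodes.

$$ \text{g\_l\_repeat}[k] = \overrightarrow{{g_l}_{(k \bmod N) + 1}}, \qquad \text{g\_r\_repeat\_interleave}[k] = \overrightarrow{{g_r}_{\lfloor k / N \rfloor + 1}}, \qquad k = 0, 1, \dots, N^2 - 1 $$

For example, with $N = 3$ and $k = 4$, both expressions select entry $2$, which matches row 4 of the two outputs above.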

So, the operation g_l_repeat + g_r_repeat_interleave

$$ \lbrace\overrightarrow{{g_l}_1}, \overrightarrow{{g_l}_2}, \dots, \overrightarrow{{g_l}_N}, \overrightarrow{{g_l}_1}, \overrightarrow{{g_l}_2}, \dots, \overrightarrow{{g_l}_N}, \dots\rbrace $$

$$ + $$

$$ \lbrace\overrightarrow{{g_r}_1}, \overrightarrow{{g_r}_1}, \dots, \overrightarrow{{g_r}_1}, \overrightarrow{{g_r}_2}, \overrightarrow{{g_r}_2}, \dots, \overrightarrow{{g_r}_2}, \dots\rbrace $$

is

>>> tensor.repeat(n_nodes, 1) + tensor.repeat_interleave(n_nodes, dim=0)
tensor([[2],   # [1] + [1]
        [3],   # [2] + [1]
        [4],   # [3] + [1]
        [3],   # [1] + [2]
        [4],   # [2] + [2]
        [5],   # [3] + [2]
        [4],   # [1] + [3]
        [5],   # [2] + [3]
        [6]])  # [3] + [3]

So, this is correct:

$$ \lbrace\overrightarrow{{g_l}_1} + \overrightarrow{{g_r}_1}, \overrightarrow{{g_l}_1} + \overrightarrow{{g_r}_2}, \dots, \overrightarrow{{g_l}_1} + \overrightarrow{{g_r}_N}, \overrightarrow{{g_l}_2} + \overrightarrow{{g_r}_1}, \overrightarrow{{g_l}_2} + \overrightarrow{{g_r}_2}, \dots, \overrightarrow{{g_l}_2} + \overrightarrow{{g_r}_N}, \dots\rbrace $$

However, to match the element order the code actually produces (and avoid confusion), I think the comment should read as follows. The current implementation itself is correct:

$$ \lbrace\overrightarrow{{g_l}_1} + \overrightarrow{{g_r}_1}, \overrightarrow{{g_l}_2} + \overrightarrow{{g_r}_1}, \dots, \overrightarrow{{g_l}_N} + \overrightarrow{{g_r}_1}, \overrightarrow{{g_l}_1} + \overrightarrow{{g_r}_2}, \overrightarrow{{g_l}_2} + \overrightarrow{{g_r}_2}, \dots, \overrightarrow{{g_l}_N} + \overrightarrow{{g_r}_2}, \dots\rbrace $$
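To double-check the ordering symbolically, here is a minimal sketch in plain Python. It uses lists of labels in place of tensors; the helpers `repeat` and `repeat_interleave` are hypothetical stand-ins I wrote to mimic the PyTorch methods along dim 0, not code from the repository.

```python
def repeat(rows, n):
    """Tile the whole list n times, mimicking tensor.repeat(n, 1) along dim 0."""
    return rows * n


def repeat_interleave(rows, n):
    """Repeat each element n times in a row, mimicking repeat_interleave(n, dim=0)."""
    return [row for row in rows for _ in range(n)]


n_nodes = 3  # or N

# Symbolic labels standing in for the vectors g_l1..g_lN and g_r1..g_rN
g_l = [f"g_l{i}" for i in range(1, n_nodes + 1)]
g_r = [f"g_r{i}" for i in range(1, n_nodes + 1)]

# Element-wise "sum" of the two expanded lists, kept as text
g_sum = [
    f"{left} + {right}"
    for left, right in zip(repeat(g_l, n_nodes), repeat_interleave(g_r, n_nodes))
]

# Row k pairs g_l[k % N] with g_r[k // N] (0-based indices)
for k, term in enumerate(g_sum):
    assert term == f"{g_l[k % n_nodes]} + {g_r[k // n_nodes]}"

print(g_sum[:4])
# ['g_l1 + g_r1', 'g_l2 + g_r1', 'g_l3 + g_r1', 'g_l1 + g_r2']
```

The printed prefix follows the second ordering above: the `g_l` index cycles fastest, while the `g_r` index advances once every `n_nodes` rows.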
