Hello,

Thanks very much for such a wonderful product! I am trying to replicate the GAT paper on the CORA dataset, but I am running into some issues using jraph. I started from your example notebook, implementing GAT along with add_self_edges_fn:
def add_self_edges_fn(receivers: jnp.ndarray,
                      senders: jnp.ndarray,
                      total_num_nodes: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
  r"""Adds self edges. Assumes self edges are not in the graph yet."""
  receivers = jnp.concatenate((receivers, jnp.arange(total_num_nodes)), axis=0)
  senders = jnp.concatenate((senders, jnp.arange(total_num_nodes)), axis=0)
  return receivers, senders
def GAT(attention_query_fn: Callable,
        attention_logit_fn: Callable,
        node_update_fn: Optional[Callable] = None,
        add_self_edges: bool = True) -> Callable:
  r"""Main GAT function."""
  # pylint: disable=g-long-lambda
  if node_update_fn is None:
    # By default, apply the leaky relu and then concatenate the heads on the
    # feature axis.
    node_update_fn = lambda x: jnp.reshape(
        jax.nn.leaky_relu(x), (x.shape[0], -1))

  def _ApplyGAT(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
    """Applies a Graph Attention layer."""
    nodes, edges, receivers, senders, _, _, _ = graph
    try:
      sum_n_node = nodes.shape[0]
    except IndexError:
      raise IndexError('GAT requires node features')
    nodes = attention_query_fn(nodes)
    total_num_nodes = tree.tree_leaves(nodes)[0].shape[0]
    if add_self_edges:
      receivers, senders = add_self_edges_fn(receivers, senders,
                                             total_num_nodes)
    sent_attributes = nodes[senders]
    received_attributes = nodes[receivers]
    att_softmax_logits = attention_logit_fn(sent_attributes,
                                            received_attributes, edges)
    att_weights = jraph.segment_softmax(
        att_softmax_logits, segment_ids=receivers, num_segments=sum_n_node)
    messages = sent_attributes * att_weights
    nodes = jax.ops.segment_sum(messages, receivers, num_segments=sum_n_node)
    nodes = node_update_fn(nodes)
    return graph._replace(nodes=nodes)

  return _ApplyGAT
def gat_definition(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
  """Defines the GAT network to run.

  Parameters
  ----------
  graph : jraph.GraphsTuple
      Input graph to be processed.

  Returns
  -------
  jraph.GraphsTuple
      Graph with updated node features.
  """
  def _attention_logit_fn(sender_attr: jnp.ndarray, receiver_attr: jnp.ndarray,
                          edges: jnp.ndarray) -> jnp.ndarray:
    del edges
    x = jnp.concatenate((sender_attr, receiver_attr), axis=-1)
    return jax.nn.leaky_relu(hk.Linear(1)(x))

  gn = GAT(
      attention_query_fn=lambda n: hk.Linear(8)(n),
      attention_logit_fn=_attention_logit_fn,
      node_update_fn=None,
      add_self_edges=True)
  graph = gn(graph)

  gn = GAT(
      attention_query_fn=lambda n: hk.Linear(8)(n),
      attention_logit_fn=_attention_logit_fn,
      node_update_fn=hk.Linear(2),
      add_self_edges=True)
  graph = gn(graph)
  return graph
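For reference, before training I wrap gat_definition with Haiku myself. This is a minimal sketch of my own wrapping step (following what I understood from the example notebooks); without_apply_rng is used because the forward pass takes no RNG at apply time:

import haiku as hk

# My own wrapping step (assumed from the jraph/Haiku examples): hk.transform
# turns gat_definition into an (init, apply) pair, and without_apply_rng drops
# the rng argument from apply since the model is deterministic.
network = hk.without_apply_rng(hk.transform(gat_definition))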
Then, after defining the main GAT, I run the training as:
def run_cora(network: hk.Transformed, num_steps: int) -> jnp.ndarray:
  r"""Runs training on the CORA dataset."""
  cora_graph = cora_ds[0]['input_graph']
  labels = cora_ds[0]['target']
  params = network.init(jax.random.PRNGKey(42), cora_graph)

  @jax.jit
  def predict(params: hk.Params) -> jnp.ndarray:
    decoded_graph = network.apply(params, cora_graph)
    return jnp.argmax(decoded_graph.nodes, axis=1)

  @jax.jit
  def prediction_loss(params: hk.Params) -> jnp.ndarray:
    decoded_graph = network.apply(params, cora_graph)
    preds = jnp.argmax(decoded_graph.nodes, axis=1)
    # We interpret the decoded nodes as a pair of logits for each node.
    loss = compute_bce_with_logits_loss(preds, labels)
    return loss

  opt_init, opt_update = optax.adam(5e-4)
  opt_state = opt_init(params)

  @jax.jit
  def update(params: hk.Params, opt_state) -> Tuple[hk.Params, Any]:
    """Returns updated params and state."""
    g = jax.grad(prediction_loss)(params)
    updates, opt_state = opt_update(g, opt_state)
    return optax.apply_updates(params, updates), opt_state

  @jax.jit
  def accuracy(params: hk.Params) -> jnp.ndarray:
    decoded_graph = network.apply(params, cora_graph)
    return jnp.mean(jnp.argmax(decoded_graph.nodes, axis=1) == labels)

  for step in range(num_steps):
    if step % 100 == 0:
      print(f"step {step} accuracy {accuracy(params).item():.2f}")
    params, opt_state = update(params, opt_state)

  return predict(params)
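For context, cora_ds is something I build myself. Roughly it looks like the sketch below; the actual parsing of the CORA feature/label/edge files is omitted, and the helper name make_cora_dataset is just mine:

import jax.numpy as jnp
import jraph

def make_cora_dataset(node_features, senders, receivers, labels):
  # Sketch of my own loading code: one GraphsTuple holding the whole citation
  # graph, plus the per-node class labels, mirroring the example notebooks.
  node_features = jnp.asarray(node_features)   # [num_nodes, num_features]
  senders = jnp.asarray(senders)               # [num_edges]
  receivers = jnp.asarray(receivers)           # [num_edges]
  graph = jraph.GraphsTuple(
      nodes=node_features,
      edges=None,                              # this GAT ignores edge features
      senders=senders,
      receivers=receivers,
      n_node=jnp.asarray([node_features.shape[0]]),
      n_edge=jnp.asarray([senders.shape[0]]),
      globals=None)
  return [{'input_graph': graph, 'target': jnp.asarray(labels)}]

# (features / senders / receivers / labels come from my CORA loading step,
# which I have omitted here.)
cora_ds = make_cora_dataset(features, senders, receivers, labels)
predictions = run_cora(network, num_steps=1000)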
The problem is that the accuracy stays stuck at the same value throughout all the steps I run (e.g. after 1000 steps, accuracy = 0.13).
Could you give me some pointers on where I am going wrong?
Thank you