Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replicating GAT with CORA dataset #26

Open
Steboss89 opened this issue Mar 18, 2022 · 0 comments
Open

Replicating GAT with CORA dataset #26

Steboss89 opened this issue Mar 18, 2022 · 0 comments

Comments

@Steboss89
Copy link

Hello,

Thanks very much for such a wonderful product! I am trying to replicate the GAT paper with the CORA dataset, but I am running into some issues using jraph. I started from your example notebook, implementing GAT along with add_self_edges_fn:

def add_self_edges_fn(receivers: jnp.ndarray,
                      senders: jnp.ndarray,
                      total_num_nodes: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
    r"""Appends a self-loop edge (i -> i) for every node.

    Assumes the graph does not already contain self edges; the new edge
    endpoints are appended after the existing ones, in node-index order.
    """
    self_loop_ids = jnp.arange(total_num_nodes)
    receivers_with_loops = jnp.concatenate((receivers, self_loop_ids), axis=0)
    senders_with_loops = jnp.concatenate((senders, self_loop_ids), axis=0)
    return receivers_with_loops, senders_with_loops
  
def GAT(attention_query_fn: Callable,
        attention_logit_fn: Callable,
        node_update_fn: Optional[Callable] = None,
        add_self_edges: bool = True) -> Callable:
    r"""Builds a single Graph Attention (GAT) layer.

    Args:
        attention_query_fn: maps node features to attention-query features.
        attention_logit_fn: maps (sent, received, edge) features to
            unnormalised attention logits.
        node_update_fn: applied to the aggregated node features; defaults to
            leaky-relu followed by flattening the trailing (head) axes.
        add_self_edges: when True, a self loop is appended for every node
            before attention is computed.

    Returns:
        A function applying the layer to a ``jraph.GraphsTuple``.
    """
    # pylint: disable=g-long-lambda
    if node_update_fn is None:
        # Default update: leaky relu, then concatenate the attention heads
        # along the feature axis.
        node_update_fn = lambda x: jnp.reshape(
            jax.nn.leaky_relu(x), (x.shape[0], -1))

    def _ApplyGAT(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
        """Applies a Graph Attention layer."""
        nodes, edges, receivers, senders, _, _, _ = graph

        try:
            sum_n_node = nodes.shape[0]
        except IndexError:
            raise IndexError('GAT requires node features')

        # Project node features into the attention-query space.
        nodes = attention_query_fn(nodes)
        total_num_nodes = tree.tree_leaves(nodes)[0].shape[0]

        if add_self_edges:
            receivers, senders = add_self_edges_fn(
                receivers, senders, total_num_nodes)

        # Gather the projected features at both endpoints of every edge.
        sent_attributes = nodes[senders]
        received_attributes = nodes[receivers]
        att_softmax_logits = attention_logit_fn(
            sent_attributes, received_attributes, edges)

        # Normalise the logits per receiving node.
        att_weights = jraph.segment_softmax(
            att_softmax_logits, segment_ids=receivers, num_segments=sum_n_node)

        # Attention-weighted message passing, aggregated per receiver.
        messages = sent_attributes * att_weights
        aggregated = jax.ops.segment_sum(
            messages, receivers, num_segments=sum_n_node)

        return graph._replace(nodes=node_update_fn(aggregated))

    return _ApplyGAT


def gat_definition(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
    """Applies a two-layer GAT network to ``graph``.

    Parameters
    ----------
    graph: jraph.GraphsTuple, input network to be processed

    Return
    -------
    jraph.GraphsTuple updated node graph
    """

    def _attention_logit_fn(sender_attr: jnp.ndarray, receiver_attr: jnp.ndarray,
                            edges: jnp.ndarray) -> jnp.ndarray:
        # Edge features are not used by this attention mechanism.
        del edges
        pair_features = jnp.concatenate((sender_attr, receiver_attr), axis=-1)
        return jax.nn.leaky_relu(hk.Linear(1)(pair_features))

    # First GAT layer: 8-dim queries, default (leaky-relu) node update.
    first_layer = GAT(
        attention_query_fn=lambda n: hk.Linear(8)(n),
        attention_logit_fn=_attention_logit_fn,
        node_update_fn=None,
        add_self_edges=True)

    # Second GAT layer: projects aggregated features down to 2 classes.
    second_layer = GAT(
        attention_query_fn=lambda n: hk.Linear(8)(n),
        attention_logit_fn=_attention_logit_fn,
        node_update_fn=hk.Linear(2),
        add_self_edges=True)

    return second_layer(first_layer(graph))

Then, after defining the main GAT, I run the training as:


def run_cora(network: hk.Transformed, num_steps: int) -> jnp.ndarray:
  r"""Trains ``network`` on the CORA graph and returns the final predictions.

  Parameters
  ----------
  network: hk.Transformed, the (init, apply) pair for the GAT model
  num_steps: int, number of optimisation steps to run

  Return
  -------
  jnp.ndarray of predicted class indices, one per node
  """
  cora_graph = cora_ds[0]['input_graph']
  labels = cora_ds[0]['target']
  params = network.init(jax.random.PRNGKey(42), cora_graph)

  @jax.jit
  def predict(params: hk.Params) -> jnp.ndarray:
    decoded_graph = network.apply(params, cora_graph)
    return jnp.argmax(decoded_graph.nodes, axis=1)

  @jax.jit
  def prediction_loss(params: hk.Params) -> jnp.ndarray:
    decoded_graph = network.apply(params, cora_graph)
    # BUG FIX: the original computed `jnp.argmax(decoded_graph.nodes, axis=1)`
    # and fed those class indices to the loss. argmax is piecewise constant,
    # so its gradient is zero everywhere — jax.grad(prediction_loss) returns
    # all-zero gradients, the optimiser never updates the parameters, and the
    # accuracy stays frozen (the reported symptom). A "with logits" loss must
    # receive the raw logits; argmax belongs only in predict/accuracy.
    # NOTE(review): assumes compute_bce_with_logits_loss takes (logits,
    # targets) and that `labels` matches the logits' trailing shape — confirm
    # against its definition.
    return compute_bce_with_logits_loss(decoded_graph.nodes, labels)

  opt_init, opt_update = optax.adam(5e-4)
  opt_state = opt_init(params)

  @jax.jit
  def update(params: hk.Params, opt_state) -> Tuple[hk.Params, Any]:
    """Returns updated params and state."""
    g = jax.grad(prediction_loss)(params)
    updates, opt_state = opt_update(g, opt_state)
    return optax.apply_updates(params, updates), opt_state

  @jax.jit
  def accuracy(params: hk.Params) -> jnp.ndarray:
    decoded_graph = network.apply(params, cora_graph)
    return jnp.mean(jnp.argmax(decoded_graph.nodes, axis=1) == labels)

  for step in range(num_steps):
    if step % 100 == 0:
      print(f"step {step} accuracy {accuracy(params).item():.2f}")
    params, opt_state = update(params, opt_state)

  return predict(params)

The problem is that the accuracy sticks to the same value throughout all the steps I am running (e.g. 1000 steps, accuracy = 0.13).
Could I ask you for some guidance to understand where I am going wrong?
Thank you

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant