What are Graph Neural Networks?

Last Updated : 15 Nov, 2024
Graph Neural Networks (GNNs) are a class of neural networks specifically designed to work with data represented as graphs. Traditional neural networks operate on regularly structured data such as images (2D grids of pixels) or text (sequences of tokens). GNNs, by contrast, can model complex, non-Euclidean relationships in data, such as social networks, molecular structures, and knowledge graphs.

This article explores what GNNs are, how they work, their types, and their wide range of applications.

Understanding Graph Structures

A graph is a data structure that represents relationships between pairs of objects. Each object is known as a node (or vertex), and each connection between nodes is an edge.


Nodes and Vertices in Graph Data Structure

Graphs can vary based on the type of edges:

  • Directed Graphs: In these graphs, edges have a direction, indicated by an arrow, showing a one-way relationship from one node to another.
  • Undirected Graphs: In these graphs, edges have no direction. Each edge represents a mutual, two-way relationship between the nodes it connects.
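
The small plain-Python sketch below makes the distinction concrete. It stores edges as (source, target) pairs, and the node names are made up. An undirected graph can be represented by storing every edge in both directions. PyTorch Geometric uses the same convention for undirected graphs such as Cora in its `edge_index` tensor.

```python
# Directed graph: each (u, v) pair is a one-way relationship u -> v
directed_edges = {("alice", "bob"), ("bob", "carol")}

def to_undirected(edges):
    """Store every edge in both directions, so u -> v implies v -> u."""
    return edges | {(v, u) for (u, v) in edges}

def neighbors(edges, node):
    """Nodes reachable from `node` by following a single edge."""
    return sorted(v for (u, v) in edges if u == node)

undirected_edges = to_undirected(directed_edges)

print(neighbors(directed_edges, "bob"))    # ['carol']
print(neighbors(undirected_edges, "bob"))  # ['alice', 'carol']
```

In the directed version "bob" only reaches "carol". Once each edge is mirrored, "alice" becomes a neighbor of "bob" as well.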

Types of Graphs

Graphs can be classified based on their structural properties:

| Type of Graph | Description | Example/Use Case |
|---|---|---|
| Homogeneous Graph | Contains only one type of node and one type of edge, simplifying representation and analysis. | Social network with people as nodes and friendships as edges. |
| Heterogeneous Graph | Consists of multiple types of nodes and edges, suitable for complex relationships. | Knowledge graph with nodes for entities (people, places) and varied edge types. |
| Static Graph | Remains fixed once created, with no changes to nodes or edges. | Scenarios where relationships are constant over time. |
| Dynamic Graph | Allows nodes and edges to be added or deleted, which suits situations with frequently changing relationships. | Real-time communication tracking or evolving social networks. |
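
The sketch below illustrates the homogeneous/heterogeneous distinction by grouping edges under a (source type, relation, target type) triple. This layout is an assumption chosen for readability. It is similar in spirit to how libraries such as PyTorch Geometric key heterogeneous edge types. The entities and relation names are hypothetical.

```python
# Hypothetical knowledge-graph snippet: two node types and two edge types
nodes = {
    "person": ["Ada", "Alan"],
    "place": ["London"],
}
# Each edge type is a (source type, relation, target type) triple
edges = {
    ("person", "knows", "person"): [("Ada", "Alan")],
    ("person", "born_in", "place"): [("Ada", "London"), ("Alan", "London")],
}

def is_homogeneous(nodes, edges):
    """A graph is homogeneous if it has exactly one node type and one edge type."""
    return len(nodes) == 1 and len(edges) == 1

def relations_of(edges, node):
    """List (relation, target) pairs leaving `node`, across all edge types."""
    return sorted((rel, dst) for (_, rel, _), pairs in edges.items()
                  for (src, dst) in pairs if src == node)

print(is_homogeneous(nodes, edges))  # False
print(relations_of(edges, "Ada"))    # [('born_in', 'London'), ('knows', 'Alan')]
```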

Graph Data Representation

There are multiple ways to represent graph data, each serving different use cases:

  • Adjacency Matrix: A square matrix where each cell indicates if an edge exists between two nodes; best for dense graphs.
  • Feature Matrices: Store attributes of nodes (one row per node) and, optionally, of edges, such as a person's age or the strength of a connection. These attributes add context beyond the graph structure itself.
  • Adjacency List: Each node lists its neighboring nodes, making it efficient for sparse graphs.
  • Edge List: Represents edges as pairs of nodes, useful in algorithms that focus on processing edges directly.
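
The following sketch builds an adjacency matrix and an adjacency list from the same edge list for a tiny undirected graph. It then pairs them with a feature matrix whose attribute values are invented. The matrix always has num_nodes × num_nodes entries, while the list grows only with the number of edges. This is why adjacency lists suit sparse graphs.

```python
# A small undirected graph with 4 nodes (0..3), given as an edge list
num_nodes = 4
edge_list = [(0, 1), (0, 2), (2, 3)]

# Adjacency matrix: A[i][j] = 1 if nodes i and j are connected (symmetric for undirected graphs)
adj_matrix = [[0] * num_nodes for _ in range(num_nodes)]
for u, v in edge_list:
    adj_matrix[u][v] = 1
    adj_matrix[v][u] = 1

# Adjacency list: each node maps to the list of its neighbors
adj_list = {n: [] for n in range(num_nodes)}
for u, v in edge_list:
    adj_list[u].append(v)
    adj_list[v].append(u)

# Feature matrix: one row of attributes per node (values are made up)
features = [
    [25, 0.3],  # node 0
    [31, 0.8],  # node 1
    [47, 0.1],  # node 2
    [19, 0.5],  # node 3
]

print(adj_matrix)  # [[0, 1, 1, 0], [1, 0, 0, 0], [1, 0, 0, 1], [0, 0, 1, 0]]
print(adj_list)    # {0: [1, 2], 1: [0], 2: [0, 3], 3: [2]}
```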

Basics of Graph Neural Networks

Graph Neural Networks (GNNs) are a class of neural networks designed specifically to work with graph-structured data. They’re used to learn patterns and relationships between connected entities within a graph, making them ideal for applications like social networks, recommendation systems, and molecular studies.

How Do GNNs Process Graph Data?

GNNs process graph data by passing information between nodes through their connections (edges). Each node gathers information from its neighbors in a process called message passing, which lets it learn a representation based on its local structure and features. Repeating this step lets information travel further: after k rounds, a node's representation can depend on every node within k hops of it.
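
Here is a minimal sketch of repeated message passing on a toy path graph. To keep the mechanics visible, each node's "representation" is simply the set of node IDs it has heard about, and aggregation is set union. A real GNN instead aggregates feature vectors (for example by sum, mean, or max) and applies learned weights. The printout shows node 0's view growing by one hop per round.

```python
# Path graph 0 - 1 - 2 - 3, stored as an adjacency list
adj_list = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}

# Each node starts knowing only about itself
state = {n: {n} for n in adj_list}

def message_passing_round(state, adj_list):
    """Every node merges (aggregates) the sets held by its neighbors into its own.
    All nodes read from the previous `state`, so the update is synchronous."""
    return {n: state[n].union(*(state[m] for m in adj_list[n])) for n in adj_list}

for k in range(1, 4):
    state = message_passing_round(state, adj_list)
    print(k, sorted(state[0]))
# 1 [0, 1]
# 2 [0, 1, 2]
# 3 [0, 1, 2, 3]
```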

Key Concepts of GNNs

  1. Message Passing: In GNNs, each node aggregates information from its neighbors through a process called “message passing.” The node then updates its representation based on this aggregated information.
  2. Node Embeddings: The final result of message passing is an embedding or feature vector for each node. These embeddings can capture complex patterns and dependencies in the graph structure.
  3. Graph-Level Representations: In addition to node embeddings, some GNN architectures can produce representations for the entire graph, making them suitable for graph-level tasks such as graph classification.
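
A graph-level representation is usually produced by a readout (pooling) step that combines all node embeddings into one vector of fixed size. Mean and sum pooling are two common choices. The sketch below uses NumPy and made-up embeddings for two graphs of different sizes. Both graphs end up as 2-dimensional vectors regardless of how many nodes they have.

```python
import numpy as np

# Hypothetical node embeddings for two small graphs (rows = nodes, columns = embedding dims)
graph_a = np.array([[1.0, 0.0], [3.0, 2.0]])              # 2 nodes
graph_b = np.array([[0.0, 1.0], [2.0, 1.0], [1.0, 4.0]])  # 3 nodes

def mean_readout(node_embeddings):
    """Average node embeddings into one fixed-size graph embedding."""
    return node_embeddings.mean(axis=0)

def sum_readout(node_embeddings):
    """Sum node embeddings; unlike the mean, this retains information about graph size."""
    return node_embeddings.sum(axis=0)

print(mean_readout(graph_a))  # [2. 1.]
print(mean_readout(graph_b))  # [1. 2.]
```

The resulting graph vectors can then be fed to an ordinary classifier for tasks such as graph classification.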

How Do Graph Neural Networks Work?

A typical GNN operates in three steps:

  1. Initialization: Each node is initialized with its feature vector. Depending on the application, this could hold properties like age or gender in a social network, or atom type in a molecular graph.
  2. Message Passing: Over several iterations or “layers,” nodes exchange information with their neighbors, aggregating data from connected nodes.
  3. Update: Each node updates its feature vector using aggregated information, often by applying a neural network layer (e.g., a linear transformation followed by a non-linear activation function).
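
The three steps can be sketched as a single NumPy function. The sketch assumes mean aggregation over a node and its neighbors and randomly initialized (untrained) weights. The function name `gnn_layer` is hypothetical, and real layers such as GCN use a degree-based normalization instead of a plain mean. The sketch stacks two layers and applies ReLU only after the first, which mirrors the shape of the GCN model used later in this article.

```python
import numpy as np

rng = np.random.default_rng(0)

# Step 1, initialization: 4 nodes, each with a 3-dimensional feature vector (random here)
X = rng.normal(size=(4, 3))
adj_list = {0: [1, 2], 1: [0], 2: [0, 3], 3: [2]}

def gnn_layer(H, adj_list, W, activation=True):
    """One message-passing layer: mean-aggregate self + neighbors, then a linear map (+ ReLU)."""
    # Step 2, message passing: average each node's vector with its neighbors' vectors
    aggregated = np.stack([H[[n] + adj_list[n]].mean(axis=0) for n in range(H.shape[0])])
    # Step 3, update: linear transformation, optionally followed by a ReLU activation
    out = aggregated @ W
    return np.maximum(out, 0.0) if activation else out

# Two stacked layers, 3 -> 8 -> 2 dimensions; weights are random, not trained
W1 = rng.normal(size=(3, 8))
W2 = rng.normal(size=(8, 2))
H1 = gnn_layer(X, adj_list, W1)
H2 = gnn_layer(H1, adj_list, W2, activation=False)
print(H1.shape, H2.shape)  # (4, 8) (4, 2)
```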

Types of Graph Neural Networks (GNNs)

1. Graph Convolutional Networks (GCN)

GCNs extend the idea of convolution from regular grids to graph data. Each GCN layer updates a node's representation by combining its own features with its neighbors' features, weighted according to node degrees, and then applying a learned linear transformation. Stacking layers widens each node's receptive field: a k-layer GCN combines information from nodes up to k hops away. A few layers therefore capture local structure, and deeper stacks reach more of the graph. GCNs are particularly effective for semi-supervised learning tasks, such as node classification.
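
In the original GCN formulation by Kipf and Welling, one layer computes H' = σ(D̂^(-1/2) Â D̂^(-1/2) H W). Here, Â = A + I is the adjacency matrix with self-loops, D̂ is the diagonal degree matrix of Â, H holds the node features, and W is a learned weight matrix. PyTorch Geometric's GCNConv applies this normalization by default. Below is a small NumPy sketch of the rule, with ReLU standing in for σ. The helper names are illustrative.

```python
import numpy as np

def gcn_norm_adjacency(A):
    """Return D^-1/2 (A + I) D^-1/2, the normalized adjacency used by GCN."""
    A_hat = A + np.eye(A.shape[0])        # add self-loops
    deg = A_hat.sum(axis=1)               # degree of each node, counting its self-loop
    d_inv_sqrt = 1.0 / np.sqrt(deg)
    # Scale row i by d_i^-1/2 and column j by d_j^-1/2
    return A_hat * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]

def gcn_layer(A, H, W):
    """One GCN layer: normalized neighborhood aggregation, linear map, ReLU."""
    return np.maximum(gcn_norm_adjacency(A) @ H @ W, 0.0)

# Path graph 0 - 1 - 2
A = np.array([[0, 1, 0],
              [1, 0, 1],
              [0, 1, 0]], dtype=float)
print(np.round(gcn_norm_adjacency(A), 3))  # row 0 is approx. [0.5, 0.408, 0.0]
```

The symmetric normalization keeps high-degree nodes from dominating the sum. The middle node, with degree 3 including its self-loop, contributes 1/sqrt(2·3) ≈ 0.408 to each endpoint.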


Graph Convolutional Networks (GCN)

2. Graph Attention Networks (GAT)

GATs introduce an attention mechanism to GNNs, allowing the model to weigh the importance of neighboring nodes differently during message passing. This adaptive focus helps GATs capture complex relationships and varying influence levels between nodes, improving performance on tasks where certain connections are more significant than others, such as in social networks or citation networks.
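
A single-head GAT layer scores each neighbor j of node i with e_ij = LeakyReLU(aᵀ[W h_i ‖ W h_j]). It then normalizes these scores with a softmax over the neighborhood (including i itself) and uses the resulting weights α_ij to form a weighted sum of the projected neighbor features. The NumPy sketch below follows that recipe with random parameters. It omits multi-head attention and the final nonlinearity, and the function names are hypothetical.

```python
import numpy as np

def leaky_relu(x, slope=0.2):
    """LeakyReLU with the 0.2 negative slope used in the GAT paper."""
    return np.where(x > 0, x, slope * x)

def gat_attention(H, adj_list, W, a):
    """Single-head attention weights alpha[i][j] over each node i's neighbors plus itself."""
    Z = H @ W                                  # shared linear projection of node features
    d = Z.shape[1]
    alpha = {}
    for i, nbrs in adj_list.items():
        js = [i] + nbrs                        # include a self-loop, as GAT does
        # e_ij = LeakyReLU(a^T [z_i || z_j]) = LeakyReLU(a[:d] . z_i + a[d:] . z_j)
        scores = leaky_relu(Z[i] @ a[:d] + Z[js] @ a[d:])
        exp = np.exp(scores - scores.max())    # numerically stable softmax over the neighborhood
        alpha[i] = dict(zip(js, exp / exp.sum()))
    return alpha

def gat_layer(H, adj_list, W, a):
    """Each node's new vector is the attention-weighted sum of projected neighbor vectors."""
    Z = H @ W
    alpha = gat_attention(H, adj_list, W, a)
    return np.stack([sum(w * Z[j] for j, w in alpha[i].items()) for i in range(len(adj_list))])

rng = np.random.default_rng(0)
H = rng.normal(size=(4, 3))                    # 4 nodes, 3 input features (random)
adj_list = {0: [1, 2], 1: [0], 2: [0, 3], 3: [2]}
W = rng.normal(size=(3, 2))                    # project to 2 output features
a = rng.normal(size=4)                         # attention vector of length 2 * 2
alpha = gat_attention(H, adj_list, W, a)
print(alpha[0])                                # weights over nodes 0, 1, 2 (sum to 1 up to rounding)
print(gat_layer(H, adj_list, W, a).shape)      # (4, 2)
```

Unlike the fixed degree-based weights in the GCN sketch, α_ij depends on the features of both endpoints. This is what lets the model focus on the more informative neighbors.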

3. Graph Recurrent Networks (GRN)

GRNs combine the principles of recurrent neural networks (RNNs) with graph structures. The recurrence can run over message-passing steps, with nodes repeatedly updating a hidden state from their neighbors (as in gated graph neural networks). It can also run over time, which makes GRNs suitable for dynamic graphs whose relationships evolve, such as sequences of social interactions or traffic flow.
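
The sketch below is a deliberately simple, hypothetical recurrent graph model. At every time step, each node's hidden state is computed from its current input and the mean of the previous hidden states in its neighborhood, using a vanilla RNN-style tanh update. Published models typically use gated cells such as GRUs or LSTMs instead. The two alternating snapshots stand in for a dynamic graph whose edges change over time.

```python
import numpy as np

def mean_aggregate(H, adj_list):
    """Average each node's own row with its neighbors' rows."""
    return np.stack([H[[n] + adj_list[n]].mean(axis=0) for n in range(H.shape[0])])

def recurrent_graph_step(X_t, H_prev, adj_list, W_x, W_h):
    """Vanilla-RNN-style update: new state from current input + neighbor-aggregated old state."""
    return np.tanh(X_t @ W_x + mean_aggregate(H_prev, adj_list) @ W_h)

rng = np.random.default_rng(0)
num_nodes, in_dim, hid_dim, steps = 4, 3, 5, 6
W_x = rng.normal(size=(in_dim, hid_dim)) * 0.5
W_h = rng.normal(size=(hid_dim, hid_dim)) * 0.5

# Hypothetical snapshots: the edge set changes over time (a dynamic graph)
snapshots = [
    {0: [1], 1: [0], 2: [3], 3: [2]},
    {0: [1, 2], 1: [0], 2: [0, 3], 3: [2]},
]
H = np.zeros((num_nodes, hid_dim))              # initial hidden state
for t in range(steps):
    adj_t = snapshots[t % len(snapshots)]
    X_t = rng.normal(size=(num_nodes, in_dim))  # node features observed at time t (random)
    H = recurrent_graph_step(X_t, H, adj_t, W_x, W_h)
print(H.shape)  # (4, 5)
```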

4. Spatial and Spectral-based GNNs

  • Spatial-based GNNs operate directly on the graph structure, focusing on the spatial relationships between nodes. They leverage the graph topology to propagate information.
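  • Spectral-based GNNs build on spectral graph theory. They define convolution through the graph Fourier transform: node signals are projected onto the eigenvectors of the graph Laplacian, filtered there, and projected back. Because the eigenvectors depend on the whole graph, these filters can capture global properties. The cost is an eigendecomposition that is expensive for large graphs and tied to one fixed graph structure.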
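
The NumPy sketch below illustrates the spectral view. It computes the Laplacian L = D - A of a 4-node cycle, uses its eigenvectors as a graph Fourier basis, and applies a filter in the spectral domain. The filter exp(-λ) is an arbitrary low-pass choice made for this example. The alternating signal x happens to be the highest-frequency eigenvector of this graph (eigenvalue 4), so filtering simply scales it by exp(-4).

```python
import numpy as np

# Cycle graph on 4 nodes: 0-1-2-3-0
A = np.array([[0, 1, 0, 1],
              [1, 0, 1, 0],
              [0, 1, 0, 1],
              [1, 0, 1, 0]], dtype=float)
L = np.diag(A.sum(axis=1)) - A         # combinatorial graph Laplacian L = D - A

# Eigenvectors of L act as Fourier basis functions; eigenvalues act as frequencies
eigvals, U = np.linalg.eigh(L)         # eigh: L is symmetric; eigenvalues come out ascending

x = np.array([1.0, -1.0, 1.0, -1.0])   # a signal with one value per node
x_hat = U.T @ x                        # graph Fourier transform

# Spectral filtering: scale each frequency component, then transform back
g = np.exp(-eigvals)                   # hypothetical low-pass filter (damps high frequencies)
x_filtered = U @ (g * x_hat)
print(np.round(eigvals, 6))            # eigenvalues 0, 2, 2, 4
```

Spatial methods, such as the neighbor-aggregation sketches earlier in this article, avoid this eigendecomposition by working on the edges directly.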

After multiple layers of message passing, each node’s feature vector (embedding) contains not only its own information but also information about its neighbors and, potentially, the whole graph structure.

Implementing Graph Neural Network for Node Classification on the Cora Dataset

Step 1: Install Required Libraries

The first lines install the required libraries for working with PyTorch and PyTorch Geometric.

!pip install torch torchvision torchaudio
!pip install torch-geometric

Step 2: Import Libraries

Here, essential libraries are imported, including:

  • PyTorch (torch, torch.nn, and torch.optim) for defining neural networks and optimizing them.
  • torch_geometric for graph neural network operations.
  • matplotlib and networkx for visualizing graphs.
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
import matplotlib.pyplot as plt
import networkx as nx

Step 3: Define the GCN Model

The Graph Convolutional Network (GCN) model class is defined. The model has two layers:

  1. conv1: A GCN layer taking input features and transforming them to hidden channels.
  2. conv2: A GCN layer transforming hidden channels to output classes.
# Define the GCN model
class GCN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x

Step 4: Load the Cora Dataset

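The Planetoid dataset loader downloads the Cora dataset, a citation network that is commonly used for graph learning tasks. It contains 2,708 scientific publications (nodes) connected by citation links (edges). Each publication is described by a 1,433-dimensional bag-of-words feature vector and labeled with one of 7 classes. The loaded graph also carries the standard train/validation/test masks (140, 500, and 1,000 nodes).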

# Load the Cora dataset
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]  # Get the first graph object

Step 5: Initialize Model, Optimizer, and Loss Function

An instance of the GCN model is created, along with an optimizer (Adam) and loss function (Cross-Entropy Loss).

# Initialize model, optimizer, and loss function
model = GCN(in_channels=dataset.num_node_features, hidden_channels=16, out_channels=dataset.num_classes)
optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

Step 6: Define the Training Function

The train function performs a forward pass, calculates the loss, backpropagates the loss, and updates the model’s parameters.

# Training loop
def train():
    model.train()  # Training mode (matters once dropout or batch norm is added)
    optimizer.zero_grad()  # Clear gradients
    out = model(data.x, data.edge_index)  # Forward pass
    loss = criterion(out[data.train_mask], data.y[data.train_mask])  # Compute loss
    loss.backward()  # Backpropagation
    optimizer.step()  # Update weights
    return loss.item()
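
Calling `nn.CrossEntropyLoss` on `out[data.train_mask]` averages the loss over the labeled training nodes only. The forward pass, however, still runs on the whole graph, so unlabeled nodes influence the predictions through message passing. This is what makes the setup semi-supervised. The NumPy sketch below reproduces the masked loss computation on made-up logits. `masked_cross_entropy` is a hypothetical helper, not part of PyTorch.

```python
import numpy as np

def masked_cross_entropy(logits, labels, mask):
    """Mean cross-entropy over the rows selected by a boolean mask."""
    z = logits[mask]
    y = labels[mask]
    z = z - z.max(axis=1, keepdims=True)                          # stabilize the softmax
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))  # log-softmax per row
    return -log_probs[np.arange(len(y)), y].mean()                # pick the true-class entries

# Hypothetical logits for 4 nodes and 3 classes; only nodes 0 and 2 are labeled for training
logits = np.array([[2.0, 0.0, 0.0],
                   [0.0, 5.0, 0.0],
                   [0.0, 0.0, 0.0],
                   [1.0, 1.0, 9.0]])
labels = np.array([0, 1, 2, 0])
train_mask = np.array([True, False, True, False])
print(masked_cross_entropy(logits, labels, train_mask))
```

Changing the logits of the unmasked nodes 1 and 3 leaves this loss unchanged, since only the masked rows enter the average.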

Step 7: Define the Evaluation Function

The test function switches the model to evaluation mode and measures accuracy on the test nodes without calculating gradients.

# Evaluation function
def test():
    model.eval()  # Evaluation mode (matters once dropout or batch norm is added)
    with torch.no_grad():  # No gradients needed for evaluation
        out = model(data.x, data.edge_index)
        pred = out.argmax(dim=1)  # Get the predicted classes
        test_accuracy = (pred[data.test_mask] == data.y[data.test_mask]).sum() / data.test_mask.sum()
        return test_accuracy.item()

Step 8: Train the Model

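This loop trains the model for 200 epochs and prints the loss and test accuracy every 10 epochs. The output below is an excerpt. The training loss approaches zero while the test accuracy levels off at about 0.787.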

# Run the training process
for epoch in range(200):  # Number of epochs
    loss = train()
    if epoch % 10 == 0:
        test_acc = test()
        print(f'Epoch: {epoch}, Loss: {loss:.4f}, Test Accuracy: {test_acc:.4f}')

print("Training completed.")


Epoch: 0, Loss: 1.9398, Test Accuracy: 0.3830
Epoch: 10, Loss: 0.5945, Test Accuracy: 0.7890
Epoch: 20, Loss: 0.0978, Test Accuracy: 0.7870
Epoch: 180, Loss: 0.0009, Test Accuracy: 0.7870
Epoch: 190, Loss: 0.0008, Test Accuracy: 0.7870
Training completed.

Step 9: Plot an Interactive Graph of the Cora Dataset

Using NetworkX and Plotly, an interactive graph of the Cora dataset is created with node colors representing different classes.

import numpy as np
import plotly.graph_objs as go
import plotly.offline as pyo
import networkx as nx

# Function to plot the interactive graph
def plot_interactive_graph():
    # Create a NetworkX graph from edge index
    edge_index = data.edge_index.cpu().numpy()  # Shape [2, num_edges]
    G = nx.Graph()
    G.add_nodes_from(range(data.num_nodes))  # Ensure every node gets a layout position

    # Add edges to the graph (tolist() yields plain Python ints as node IDs)
    G.add_edges_from(edge_index.T.tolist())

    # Node labels and colors
    labels = data.y.cpu().numpy()  # Class index of each node
    colors = labels  # Node colors based on class

    # Create position layout for nodes
    pos = nx.spring_layout(G, k=0.5, iterations=50)  # Positioning of nodes

    # Prepare edge coordinates for Plotly
    edge_x = []
    edge_y = []
    for edge in G.edges():
        x0, y0 = pos[edge[0]]
        x1, y1 = pos[edge[1]]
        edge_x.extend([x0, x1, None])  # None breaks the line between edges
        edge_y.extend([y0, y1, None])

    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=0.5, color='#888'),
        hoverinfo='none',
        mode='lines')

    node_x = []
    node_y = []
    node_color = []
    for i in range(data.num_nodes):
        x, y = pos[i]
        node_x.append(x)
        node_y.append(y)
        # Generate monochromatic color based on the class
        node_color.append(f'rgba(0, 0, {255 - (colors[i] * 25)}, 0.7)')  # Blue shades based on class

    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode='markers',
        hoverinfo='text',
        text=[f'Node {i}: {labels[i]}' for i in range(data.num_nodes)],
        marker=dict(size=5, color=node_color))

    # Create the figure
    fig = go.Figure(data=[edge_trace, node_trace],
                    layout=go.Layout(
                        title='Interactive Graph Visualization of the Cora Dataset',
                        showlegend=False,
                        hovermode='closest',
                        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)))

    pyo.iplot(fig)  # Use pyo.plot(fig) if you're using a script outside Jupyter Notebook

# Call the function to plot the interactive graph
plot_interactive_graph()


Interactive Graph Visualization of the Cora Dataset


Zoomed Image of the above graph visualization

Applications of Graph Neural Networks (GNNs)

  • Social Network Analysis: GNNs are used to model and analyze social interactions, helping to identify communities, predict user behavior, and enhance recommendation systems.
  • Recommendation Systems: By understanding user-item interactions as a graph, GNNs improve personalized recommendations, leveraging the relationships between users and items effectively.
  • Chemistry and Biology: GNNs are applied in drug discovery and protein-protein interaction prediction, modeling molecules as graphs to predict properties and interactions based on their structures.
  • Knowledge Graphs: GNNs enhance the processing of knowledge graphs by improving entity representation and relation prediction, facilitating better information retrieval and reasoning tasks.
  • Traffic and Transportation: GNNs help model traffic flow and predict congestion by analyzing road networks as graphs, leading to better route optimization and traffic management.


Graph Neural Networks have emerged as a powerful tool for processing and analyzing graph-structured data, with applications across diverse fields such as social networks, biology, and transportation. As research continues to advance, GNNs are likely to become more effective and more widely applicable in real-world scenarios.

FAQs on Graph Neural Networks

What is Graph Neural Network in Machine Learning?

A Graph Neural Network is a machine learning model that operates directly on graph-structured data. It learns representations of nodes, edges, or whole graphs from both their features and their connections. These representations can then be used for tasks such as node classification, link prediction, and graph classification.

How are Graph Neural Networks different from ordinary neural networks?

Graph Neural Networks use the same building blocks as ordinary neural networks: linear layers, nonlinear activations, and gradient-based training. In addition, each GNN layer aggregates information along the graph's edges. This lets the same model make node-level, edge-level, and graph-level predictions.

What are the types of Graph Neural Networks?

There is no single official taxonomy. One common grouping has three families: recurrent graph neural networks, spatial convolutional networks, and spectral convolutional networks. Attention-based models such as GAT are usually counted among the spatial methods.

