Getting Started with PyTorch Geometric: Graph Neural Networks in PyTorch
Most machine learning works on grids (images) or sequences (text), but a lot of real-world data is naturally a graph: people connected to friends, atoms bonded into molecules, transactions linking accounts. PyTorch Geometric (PyG) is a library built on top of PyTorch that makes it practical to train neural networks directly on this kind of connected data. This tutorial assumes you are comfortable with basic PyTorch (tensors, nn.Module, a training loop) but have not worked with graph neural networks before.
What Is Graph Machine Learning?
A graph is a set of nodes (also called vertices) connected by edges. Both nodes and edges can carry features. The structure itself — who is connected to whom — carries information that a plain feed-forward network would ignore if you flattened everything into a table.
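To make this concrete, here is a small sketch of how a graph is usually stored as tensors. PyG uses a node feature matrix `x` of shape `[num_nodes, num_node_features]` and a COO-style `edge_index` of shape `[2, num_edges]`, where row 0 holds source nodes and row 1 holds target nodes. Its `torch_geometric.data.Data` object keeps these under the names `x`, `edge_index`, and `edge_attr`. The example uses plain PyTorch, and the four-node graph and its feature values are made up for illustration.

```python
import torch

# A hypothetical 4-node graph with undirected edges 0-1, 1-2, 2-3.
# Each undirected edge is stored twice (once per direction), which is
# the convention PyG uses for undirected graphs.
edge_index = torch.tensor(
    [[0, 1, 1, 2, 2, 3],   # source nodes
     [1, 0, 2, 1, 3, 2]],  # target nodes
    dtype=torch.long,
)

# Node features: 4 nodes, 2 made-up features each.
x = torch.tensor([[1.0, 0.0],
                  [0.0, 1.0],
                  [1.0, 1.0],
                  [0.5, 0.5]])

# Edge features: one scalar per directed edge (also made up).
edge_attr = torch.ones(edge_index.size(1), 1)

num_nodes = x.size(0)
# In-degree of every node: how many edges point at it.
in_degree = torch.bincount(edge_index[1], minlength=num_nodes)


def neighbors(node: int) -> list[int]:
    """Return the source nodes of all edges that end at `node`."""
    mask = edge_index[1] == node
    return edge_index[0, mask].tolist()


print(in_degree.tolist())  # [1, 2, 2, 1]
print(neighbors(2))        # [1, 3]
```

Flattening `x` into a table would keep the per-node features but drop `edge_index`, which is exactly the structural information a GNN makes use of.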
Graph machine learning shows up in many production settings:
- Social and recommendation graphs. Users, items, and their interactions form a graph. Predicting which item a user will click is a link prediction problem; classifying a user as a bot is node classification.
- Molecules and materials. Atoms are nodes, chemical bonds are edges. Predicting a molecule's solubility or toxicity is a graph-level regression or classification task.
- Fraud and risk. Accounts, devices, and transactions form a graph. Fraud rings are dense subgraphs that are nearly invisible when each transaction is scored independently.
- Knowledge graphs. Entities and typed relations (for example the triple (Paris, capital_of, France)) support reasoning and retrieval. These are typically heterogeneous graphs with many node and edge types.
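All three kinds of task in this list can sit on top of the same per-node embeddings that a GNN produces. The sketch below uses random stand-in embeddings and untrained linear heads (both assumptions, not a real model) to show how node-level, link-level, and graph-level predictions are typically read out. The dot-product link score and mean pooling are common choices, not the only ones; PyG mini-batches carry a similar `batch` vector that maps each node to its graph, and `torch_geometric.nn.global_mean_pool` performs the pooling step.

```python
import torch

torch.manual_seed(0)

# Stand-in for GNN output: 5 nodes with 8-dimensional embeddings.
h = torch.randn(5, 8)

# Node-level (e.g. bot vs. not bot): a linear head gives per-node logits.
node_head = torch.nn.Linear(8, 2)
node_logits = node_head(h)  # shape [5, 2]

# Link-level (e.g. will user 0 click item 3?): score node pairs by dot product.
pairs = torch.tensor([[0, 1],   # source nodes of candidate links
                      [3, 4]])  # target nodes of candidate links
link_scores = (h[pairs[0]] * h[pairs[1]]).sum(dim=-1)  # shape [2]

# Graph-level (e.g. molecule toxicity): nodes 0-2 form graph 0, nodes 3-4 graph 1.
batch = torch.tensor([0, 0, 0, 1, 1])
num_graphs = int(batch.max()) + 1
sums = torch.zeros(num_graphs, h.size(1)).index_add_(0, batch, h)
counts = torch.bincount(batch, minlength=num_graphs).unsqueeze(-1)
graph_emb = sums / counts          # mean pooling, shape [2, 8]
graph_head = torch.nn.Linear(8, 1)
graph_pred = graph_head(graph_emb)  # one prediction per graph, shape [2, 1]

print(node_logits.shape, link_scores.shape, graph_pred.shape)
```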
The common theme is that a prediction about one element depends on its neighborhood, not just its own features.
The Intuition: Message Passing
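Almost every modern GNN is a form of message passing (sometimes called neighborhood aggregation or graph convolution). The idea is simple and repeats over several layers:
- Message. Each node receives a message from each of its neighbors, computed from the neighbor's current representation (and optionally the features of the connecting edge).
- Aggregate. The node combines all incoming messages with an order-independent function such as sum, mean, or max.
- Update. The node merges the aggregated message with its own current representation to produce its new representation.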
After one layer, every node's representation reflects its immediate neighbors. After two layers, it reflects neighbors-of-neighbors, and so on. Stacking k layers lets information flow across a k-hop neighborhood.
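The k-hop claim is easy to check numerically. This sketch uses plain Python and a hypothetical path graph 0-1-2-3-4. Every node starts at 0 except node 0, which starts at 1, and each "layer" sets a node's value to its own value plus the sum of its neighbors' values, an untrained stand-in for a real learned layer. After k layers, exactly the nodes within k hops of node 0 carry a nonzero value.

```python
# Path graph 0-1-2-3-4 as an adjacency list (hypothetical example).
neighbors = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}


def propagate(values: dict[int, float]) -> dict[int, float]:
    """One untrained message-passing step: own value plus sum of neighbor values."""
    return {i: values[i] + sum(values[j] for j in neighbors[i]) for i in neighbors}


def reached_after(k: int) -> list[int]:
    """Nodes that the signal starting at node 0 has reached after k layers."""
    values = {i: 0.0 for i in neighbors}
    values[0] = 1.0  # the signal starts only at node 0
    for _ in range(k):
        values = propagate(values)
    return [i for i, v in values.items() if v != 0.0]


for k in range(4):
    print(k, reached_after(k))
# 0 [0]
# 1 [0, 1]
# 2 [0, 1, 2]
# 3 [0, 1, 2, 3]
```

The same reasoning explains why GNNs are usually shallow: each extra layer widens the neighborhood, and on dense graphs a few hops can already cover most of the graph.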
Formally, a generic message passing layer updates node i as:
$$h_i' = \mathrm{UPDATE}\Big(h_i,\ \mathrm{AGGREGATE}\big(\{\, \mathrm{MESSAGE}(h_i, h_j, e_{ij}) : j \in \mathcal{N}(i) \,\}\big)\Big)$$
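where $h_i$ is the current representation of node $i$, $e_{ij}$ holds the features of the edge between $i$ and $j$ (if any), and $\mathcal{N}(i)$ is the set of neighbors of $i$. MESSAGE, AGGREGATE, and UPDATE are the parts each GNN variant defines differently: Graph Convolutional Networks (GCN) use a degree-normalized sum, GraphSAGE samples a fixed-size set of neighbors, aggregates them, and concatenates the result with the node's own representation, and Graph Attention Networks (GAT) learn an attention weight for each neighbor. We will see these shortly.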
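The three functions map directly onto tensor operations. Below is a minimal sketch in plain PyTorch, not PyG's own implementation; PyG's `MessagePassing` base class exposes the same steps as overridable `message`, `aggregate`, and `update` methods. The layer here assumes a simple design chosen for illustration: MESSAGE is a linear map of the neighbor's features $h_j$ (edge features are ignored), AGGREGATE is a sum or mean over incoming edges computed with `index_add_`, and UPDATE adds a linear map of $h_i$ and applies a ReLU. The class name `SimpleMessagePassing` is hypothetical.

```python
import torch
from torch import nn


class SimpleMessagePassing(nn.Module):
    """Illustrative layer: h_i' = ReLU(W_self h_i + AGG_{j in N(i)} W_msg h_j)."""

    def __init__(self, in_dim: int, out_dim: int, aggr: str = "sum"):
        super().__init__()
        if aggr not in ("sum", "mean"):
            raise ValueError("aggr must be 'sum' or 'mean'")
        self.aggr = aggr
        self.msg = nn.Linear(in_dim, out_dim, bias=False)  # MESSAGE
        self.self_lin = nn.Linear(in_dim, out_dim)         # part of UPDATE

    def forward(self, h: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        src, dst = edge_index        # messages flow from src (j) to dst (i)
        messages = self.msg(h[src])  # one message per edge, shape [E, out_dim]

        # AGGREGATE: sum all messages that arrive at the same destination node.
        agg = messages.new_zeros((h.size(0), messages.size(1)))
        agg.index_add_(0, dst, messages)
        if self.aggr == "mean":
            # Divide by in-degree; clamp so isolated nodes do not divide by zero.
            deg = torch.bincount(dst, minlength=h.size(0)).clamp(min=1)
            agg = agg / deg.unsqueeze(-1).to(agg.dtype)

        # UPDATE: combine the node's own features with the aggregated messages.
        return torch.relu(self.self_lin(h) + agg)


# Usage on an undirected path 0-1-2 (each edge stored in both directions).
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
h = torch.randn(3, 4)
layer = SimpleMessagePassing(4, 8, aggr="mean")
out = layer(h, edge_index)
print(out.shape)  # torch.Size([3, 8])
```

Stacking two such layers gives each node a 2-hop view, matching the intuition above. A GCN-style layer would change only the AGGREGATE step, replacing the plain sum or mean with a degree-normalized sum.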
Installation
PyG runs on top of PyTorch, so install a matching PyTorch build first. For a CPU-only setup the basics are:
pip install torch
pip install torch_geometric
For most node- and graph-level tutorials this is enough — recent PyG versions implement the core operations in pure PyTorch.
PyG also offers optional compiled extension wheels (pyg_lib, torch_scatter, torch_sparse, torch_cluster, torch_spline_conv). These accelerate scatter/gather and sparse operations and are required for some advanced features (certain samplers and cluster routines). They must match both your PyTorch version and your CUDA build, so install them from the PyG wheel index rather than from PyPI:
# Replace 2.4.0 and cu121 with your own torch version and CUDA tag.
# Use cpu instead of cu121 for a CPU-only machine.
pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv \
  -f https://data.pyg.org/whl/torch-2.4.0+cu121.html
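Once the install finishes, a short standard-library script can report which of the optional extensions Python can find. It uses `importlib.util.find_spec`, which locates a top-level module without importing it. Finding a module only means it is on the import path; a wheel built for a different PyTorch or CUDA version can still fail when actually imported, so a successful `import` remains the real test. The helper name `installed_extensions` is hypothetical.

```python
import importlib.util

# Import names of the optional compiled extensions listed above.
OPTIONAL_EXTENSIONS = (
    "pyg_lib",
    "torch_scatter",
    "torch_sparse",
    "torch_cluster",
    "torch_spline_conv",
)


def installed_extensions(names=OPTIONAL_EXTENSIONS) -> dict[str, bool]:
    """Map each module name to whether Python can locate it (without importing it)."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


for name, found in installed_extensions().items():
    print(f"{name:18} {'found' if found else 'missing'}")
```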
To find the right tags, check your installed versions:
import torch
print(torch.__version__) # e.g. 2.4.0, possibly with a suffix such as +cu121
print(torch.version.cuda) # e.g. 12.1, or None for CPU
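These two values are enough to build the `-f` URL for the pip command. The sketch below assumes the index page follows the `torch-<version>+<tag>.html` pattern used in the command above, where the tag is `cu` plus the CUDA version without dots, or `cpu`. It does not handle ROCm builds and does not check that the page exists, so confirm the result against the PyG installation instructions. The function name `pyg_wheel_url` is hypothetical.

```python
from typing import Optional


def pyg_wheel_url(torch_version: str, cuda_version: Optional[str]) -> str:
    """Build a PyG wheel-index URL from torch.__version__ and torch.version.cuda.

    Assumes the torch-<version>+<tag>.html naming used in the pip command above.
    """
    base = torch_version.split("+")[0]  # "2.4.0+cu121" -> "2.4.0"
    if cuda_version is None:
        tag = "cpu"
    else:
        tag = "cu" + cuda_version.replace(".", "")  # "12.1" -> "cu121"
    return f"https://data.pyg.org/whl/torch-{base}+{tag}.html"


print(pyg_wheel_url("2.4.0+cu121", "12.1"))
# https://data.pyg.org/whl/torch-2.4.0+cu121.html
print(pyg_wheel_url("2.4.0", None))
# https://data.pyg.org/whl/torch-2.4.0+cpu.html

# With PyTorch installed, derive the URL from the running environment.
try:
    import torch

    print(pyg_wheel_url(torch.__version__, torch.version.cuda))
except ImportError:
    print("PyTorch is not installed yet")
```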