Graph Neural Networks for Outfit prediction - Part 1
Learn how Graph Neural Networks (GNNs) can be used to predict a complementary product.
Introduction to Heterogeneous Graph Neural Networks with PyTorch Geometric
Graph Neural Networks (GNNs) have revolutionized the way we analyze graph-structured data. However, many real-world graphs are not simple, homogeneous structures. They are often heterogeneous, containing different types of nodes and edges. In this blog post, we’ll explore how to work with these complex graphs using PyTorch Geometric (PyG), a powerful library for deep learning on graphs.
What are Heterogeneous Graphs?
A homogeneous graph has only one type of node and one type of edge. A social network where “users” are connected by “friendship” is a classic example.
In contrast, a heterogeneous graph can have multiple types of nodes and edges. Consider an e-commerce network: you might have “customers,” “products,” “categories,” and “outfits” as different node types. The relationships between them could be “buys” (customer -> product), “belongs_to” (product -> category), and “part_of” (product -> outfit).
Representing Heterogeneous Graphs in PyG
PyTorch Geometric provides the HeteroData class to represent heterogeneous graphs. It’s a dictionary-like object that stores the nodes and edges for each type.
Let’s look at an example from our project and build a graph using data from the Farfetch competition.
TODO: add build graph code
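We load a pre-built HeteroData object and print it:
import torch
from torch_geometric.data import HeteroData
# Note: on PyTorch >= 2.6, torch.load defaults to weights_only=True, so loading a
# pickled HeteroData object may require weights_only=False (only for trusted files).
graph = torch.load("../data/graph/graph.pt")
print(graph)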
HeteroData(
  category={
    num_nodes=129,
    node_id=[129],
  },
  product={
    num_nodes=398670,
    node_id=[398670],
    category_id=[398670],
    family_id=[398670],
  },
  (product, cat_prod_link, category)={ edge_index=[2, 398670] },
  (product, product_outfit, product)={ edge_index=[2, 3280676] },
  (category, rev_cat_prod_link, product)={ edge_index=[2, 398670] }
)
We can inspect the metadata of our graph to understand its structure:
print(graph.metadata())
This will output the different node types and edge types in our graph, for example:
(['category', 'product'],  # Node types
 [('product', 'cat_prod_link', 'category'),  # Edge types
  ('product', 'product_outfit', 'product'),
  ('category', 'rev_cat_prod_link', 'product')])
This tells us we have two node types (category, product) and three edge types. We also record the number of nodes per type, which is needed later to size the embedding tables:
num_node_dict = {}
for node_type in graph.node_types:
    num_node_dict[node_type] = graph[node_type].num_nodes
num_node_dict
> {'category': 129, 'product': 398670}
Building a GNN for a Heterogeneous Graph
So, how do we build a GNN for a heterogeneous graph? Do we need a special type of GNN layer? Not necessarily! PyG provides a powerful to_hetero function that can convert a homogeneous GNN into a heterogeneous one.
Let’s first define a standard GNN built from layers that are designed for homogeneous graphs, along with a small prediction head.
import torch
from torch import Tensor
from torch_geometric.nn import SAGEConv, GATv2Conv

class GNN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels=256, out_channels=128, heads=5, layer_type='GAT'):
        super().__init__()
        if layer_type == "SAGEConv":
            self.conv1 = SAGEConv(in_channels, hidden_channels)
            self.conv2 = SAGEConv(hidden_channels, out_channels)
        elif layer_type == "GAT":
            # conv1 concatenates its heads, so conv2 receives hidden_channels * heads features
            self.conv1 = GATv2Conv(in_channels, hidden_channels, heads=heads, add_self_loops=False, dropout=0.2)
            self.conv2 = GATv2Conv(hidden_channels*heads, out_channels, heads=1, concat=False, add_self_loops=False, dropout=0.2)
        else:
            raise ValueError(f"Unknown layer_type: {layer_type}")
        self.relu = torch.nn.ReLU()

    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        x = self.conv1(x, edge_index)
        x = self.relu(x)
        x = self.conv2(x, edge_index)
        return x

class Predictor(torch.nn.Module):
    def __init__(self, out_channels):
        super().__init__()
        self.layer_class = torch.nn.Linear(out_channels, 1)
        # Not applied in forward: the loss used below (BCEWithLogitsLoss) expects raw logits
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        # Returns one raw logit per input row
        return self.layer_class(x)
Now we can use to_hetero to wrap this GNN and make it compatible with our HeteroData object. to_hetero creates one copy of each layer per edge type and then combines the results. Our graph has three edge types, so conceptually it works as follows (pseudocode, not runnable):
# One copy of the layer per edge type
message_passing = {
    'cat_prod_link': GAT(),
    'product_outfit': GAT(),
    'rev_cat_prod_link': GAT()
}
# Outputs that target the same node type are combined (for example, summed)
aggregated_message_passing = aggregate_fn(message_passing)
In actual code, the wrapping looks like this:
from typing import List

from torch import Tensor
from torch_geometric.nn import to_hetero

class Model(torch.nn.Module):
    def __init__(self, hidden_channels, num_nodes_dict, graph_metadata):
        super().__init__()
        # ... (embedding layers) ...
        # Instantiate homogeneous GNN:
        self.gnn = GNN(hidden_channels)
        # Convert GNN model into a heterogeneous variant:
        self.gnn = to_hetero(self.gnn, metadata=graph_metadata, aggr="sum")
        # ... (classifier) ...

    def forward(self, data: HeteroData, edge_types_to_predict: List) -> Tensor:
        # ... (forward pass logic) ...
        ...
The to_hetero function takes our GNN model and the metadata of our heterogeneous graph. It automatically handles the message passing between different node types. The aggr argument specifies how to combine the outputs of different edge types that point to the same destination node type. In this case, we’re using a simple “sum” aggregation.
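As a rough illustration of what the “sum” aggregation means, the sketch below uses plain tensors instead of real GNN layers. The two random matrices are hypothetical stand-ins for the outputs that two edge types produce for the same destination node type (product); the sizes are arbitrary.

```python
import torch

torch.manual_seed(0)
num_products, dim = 4, 8  # arbitrary toy sizes

# Stand-ins for per-edge-type layer outputs that both target "product" nodes:
out_from_outfit = torch.randn(num_products, dim)    # ('product', 'product_outfit', 'product')
out_from_category = torch.randn(num_products, dim)  # ('category', 'rev_cat_prod_link', 'product')

# "sum" aggregation: element-wise sum over the edge types sharing a destination type.
product_out = torch.stack([out_from_outfit, out_from_category]).sum(dim=0)
print(product_out.shape)
```

Category nodes, by contrast, are the destination of only one edge type (cat_prod_link) in this graph, so there is nothing to combine for them.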
The Model Class: Putting it all Together
Our final Model class encapsulates the entire process:
- Node Embeddings: We create input features for each node type: a learned embedding table for category nodes and frozen, pretrained text embeddings (projected by a linear layer) for product nodes. This is necessary because our nodes don’t have initial features.
- Heterogeneous GNN: We use the to_hetero-wrapped GNN to perform message passing and generate node embeddings.
- Classifier: A simple classifier takes the final node embeddings and predicts the existence of an edge.
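The input-feature step from the list above can be sketched in isolation with plain PyTorch. All sizes and the random "pretrained" matrix are placeholder assumptions; they only show how a learned table and a frozen table can be brought to the same width before message passing.

```python
import torch

torch.manual_seed(0)
num_categories, num_products = 3, 5  # toy sizes
text_dim, in_channels = 16, 8        # placeholder dimensions

# Learned from scratch during training.
category_embd = torch.nn.Embedding(num_categories, in_channels)

# Frozen lookup table; the random matrix stands in for real text embeddings.
pretrained = torch.randn(num_products, text_dim)
product_emb = torch.nn.Embedding.from_pretrained(pretrained, freeze=True)

# Trainable projection so both node types end up with in_channels features.
project = torch.nn.Linear(text_dim, in_channels)

x_dict = {
    "category": category_embd(torch.arange(num_categories)),
    "product": project(product_emb(torch.arange(num_products))),
}
print({k: tuple(v.shape) for k, v in x_dict.items()})
```

Giving every node type the same feature width keeps the per-edge-type layers interchangeable, since each copy then sees inputs of identical size.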
Here’s a simplified view of the Model class:
from typing import List

import pandas as pd
import torch
from torch import Tensor
from torch_geometric.data import HeteroData
from torch_geometric.nn import to_hetero

class Model(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, heads, num_nodes_dict, graph_metadata, layer_type='GAT', device='cpu'):
        super().__init__()
        # Learned embedding table for category nodes
        self.category_embd = torch.nn.Embedding(num_nodes_dict["category"], in_channels)
        # Frozen, pretrained text embeddings for product nodes
        pembds = pd.read_parquet("product_text_embedding.parquet")
        pembds = torch.from_numpy(pembds.drop("product_node_id", axis=1).values)
        self.product_emb = torch.nn.Embedding.from_pretrained(pembds, freeze=True)
        del pembds
        # Project the 1024-dimensional text embeddings to in_channels
        self.layer = torch.nn.Linear(1024, in_channels)
        # Instantiate homogeneous GNN:
        self.gnn = GNN(in_channels=in_channels, hidden_channels=hidden_channels, out_channels=out_channels, heads=heads, layer_type=layer_type)
        # Convert GNN model into a heterogeneous variant:
        self.gnn = to_hetero(self.gnn, metadata=graph_metadata, aggr="sum")
        self.pred_layer = Predictor(out_channels=out_channels)
        #self.l_func = torch.nn.TripletMarginLoss(margin=1.0, p=2.0)
        self.l_func = torch.nn.BCEWithLogitsLoss()
        self.alpha = 0.5
        self.device = device

    def inference(self, batch: HeteroData):
        x_dict = {
            "category": self.category_embd(batch['category'].node_id),
            "product": self.layer(self.product_emb(batch['product'].node_id))
        }
        # `x_dict` holds feature matrices of all node types
        # `edge_index_dict` holds all edge indices of all edge types
        x_dict = self.gnn(x_dict, batch.edge_index_dict)
        return x_dict

    def forward(self, batch: HeteroData, edge_types_to_predict: List) -> Tensor:
        x_dict = self.inference(batch)
        total_loss = 0.0
        for edge_type in edge_types_to_predict:
            # Get the source and destination node types from the edge type tuple
            src_node_type, _, dst_node_type = edge_type
            u = batch[src_node_type]['src_index']
            v = batch[dst_node_type]['dst_pos_index']
            vn = batch[dst_node_type]['dst_neg_index']
            # Look up embeddings by node type (both are 'product' for outfit edges)
            u_feats = x_dict[src_node_type][u]
            v_feats = x_dict[dst_node_type][v]
            vn_feats = x_dict[dst_node_type][vn]
            #total_loss += self.l_func(u_feats, v_feats, vn_feats)
            # Score each pair from the element-wise product of its two embeddings
            pscore = self.pred_layer(u_feats * v_feats)
            nscore = self.pred_layer(u_feats * vn_feats)
            # Positive pairs get label 1, sampled negatives get label 0
            pl = torch.ones_like(pscore)
            nl = torch.zeros_like(nscore)
            y_true = torch.cat((pl, nl))
            y_pred = torch.cat((pscore, nscore))
            total_loss += self.l_func(y_pred, y_true)
        return total_loss
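The scoring and loss logic in forward can be tried out on its own. The sketch below is a simplified stand-in under stated assumptions: the embeddings are random tensors rather than GNN outputs, and the names scorer, v_pos and v_neg are hypothetical.

```python
import torch

torch.manual_seed(0)
batch_size, dim = 4, 8  # toy sizes

u = torch.randn(batch_size, dim)      # source product embeddings
v_pos = torch.randn(batch_size, dim)  # products that do share an outfit with u
v_neg = torch.randn(batch_size, dim)  # sampled products that do not

scorer = torch.nn.Linear(dim, 1)      # plays the role of the Predictor head
pos_logits = scorer(u * v_pos)        # element-wise product, then a linear score
neg_logits = scorer(u * v_neg)

logits = torch.cat([pos_logits, neg_logits])
labels = torch.cat([torch.ones_like(pos_logits), torch.zeros_like(neg_logits)])

# BCEWithLogitsLoss applies the sigmoid internally, so it takes raw logits.
loss = torch.nn.BCEWithLogitsLoss()(logits, labels)

# To read a score as a probability at inference time, apply the sigmoid explicitly.
probs = torch.sigmoid(pos_logits)
print(loss.item(), probs.squeeze(-1))
```

This is also why the Predictor returns the linear output directly: feeding already-squashed probabilities into BCEWithLogitsLoss would apply the sigmoid twice.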
Printing an instantiated model shows that to_hetero has created one copy of each convolution per edge type (excerpt):
  (layer): Linear(in_features=1024, out_features=256, bias=True)
  (gnn): GraphModule(
    (conv1): ModuleDict(
      (product__cat_prod_link__category): GATv2Conv(256, 128, heads=3)
      (product__product_outfit__product): GATv2Conv(256, 128, heads=3)
      (category__rev_cat_prod_link__product): GATv2Conv(256, 128, heads=3)
    )
    (relu): ModuleDict(
      (category): ReLU()
      (product): ReLU()
    )
    (conv2): ModuleDict(
      (product__cat_prod_link__category): GATv2Conv(384, 128, heads=1)
      (product__product_outfit__product): GATv2Conv(384, 128, heads=1)
      (category__rev_cat_prod_link__product): GATv2Conv(384, 128, heads=1)
    )
  )
  (pred_layer): Predictor(
    (layer_class): Linear(in_features=128, out_features=1, bias=True)
    (sigmoid): Sigmoid()
  )
)
model parameters: 1186433
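A count like the one above can be obtained in several ways; one common idiom is sketched here on a hypothetical toy module (Tiny is not part of the project). The reported 1,186,433 is far smaller than the 398,670 × 1024 entries of the frozen product embedding table, so it presumably covers trainable parameters only.

```python
import torch

class Tiny(torch.nn.Module):
    """Toy module with one frozen and one trainable component."""
    def __init__(self):
        super().__init__()
        # Frozen table: 10 * 4 = 40 values that receive no gradient updates.
        self.frozen = torch.nn.Embedding.from_pretrained(torch.zeros(10, 4), freeze=True)
        # Trainable layer: 4 * 2 weights + 2 biases = 10 values.
        self.layer = torch.nn.Linear(4, 2)

model = Tiny()
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"model parameters: {trainable} (trainable) / {total} (total)")
```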
Conclusion
Heterogeneous graphs are a powerful way to model complex, real-world systems. With PyTorch Geometric’s HeteroData and to_hetero, we can easily adapt our existing GNN architectures to handle this complexity. This allows us to build sophisticated models for tasks like link prediction and node classification in heterogeneous domains.