Commit ef666539 by Chien-Yu Lin Committed by Thierry Moreau

Tutorial: update Building a Graph Convolutional Network tutorial (#4060)

* update build_gcn.py tutorial

updates
* support bias in GCN layer
* download pretrained gcn model
* verify model accuracy
* use time_evaluator to measure runtime

* fix adding bias in gcn layer

* remove printing output

* fix small bug

* add DGL-PyTorch comparison into the build_gcn tutorial

* add accuracy testing

* adjust import order

* handle different dgl versions

* update number for dgl version checking
parent d69c6fd8
...@@ -17,7 +17,8 @@ ...@@ -17,7 +17,8 @@
""" """
Building a Graph Convolutional Network Building a Graph Convolutional Network
===================== =====================
**Author**: `Yulun Yao <https://yulunyao.io/>`_ **Author**: `Yulun Yao <https://yulunyao.io/>`_, \
`Chien-Yu Lin <https://homes.cs.washington.edu/~cyulin/>`_
This article is an introductory tutorial to build a Graph Convolutional Network (GCN) with Relay. This article is an introductory tutorial to build a Graph Convolutional Network (GCN) with Relay.
...@@ -27,14 +28,151 @@ Cora dataset is a common benchmark for Graph Neural Networks (GNN) and framework ...@@ -27,14 +28,151 @@ Cora dataset is a common benchmark for Graph Neural Networks (GNN) and framework
We directly load the dataset from DGL library to do the apples to apples comparison against DGL. We directly load the dataset from DGL library to do the apples to apples comparison against DGL.
Please refer to DGL tutorial on installation at Please refer to DGL doc for DGL installation at
https://docs.dgl.ai/install/index.html https://docs.dgl.ai/install/index.html
GPU support and more sparse operators will soon follow. and refer to PyTorch guide for PyTorch installation at
https://pytorch.org/get-started/locally/
""" """
######################################################################
# Define GCN in DGL with PyTorch backend
# ------------------
#
# DGL example: https://github.com/dmlc/dgl/tree/master/examples/pytorch/gcn
# This part reuses the code from the above example
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl.nn.pytorch import GraphConv
class GCN(nn.Module):
    """Graph convolutional network stacked from DGL ``GraphConv`` layers.

    Parameters
    ----------
    g : graph object
        Graph passed to every ``GraphConv`` layer during ``forward``.
    n_infeat : int
        Size of each input feature vector.
    n_hidden : int
        Number of units in each hidden layer.
    n_classes : int
        Size of the output of the last layer.
    n_layers : int
        Number of hidden layers; ``n_layers + 1`` ``GraphConv`` layers are
        created in total (input layer, ``n_layers - 1`` hidden-to-hidden
        layers, and the output layer).
    activation : callable
        Activation given to every layer except the output layer.
    """

    def __init__(self,
                 g,
                 n_infeat,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation):
        super(GCN, self).__init__()
        self.g = g
        self.layers = nn.ModuleList()
        # Input layer
        self.layers.append(GraphConv(n_infeat, n_hidden, activation=activation))
        # Hidden-to-hidden layers
        for _ in range(n_layers - 1):
            self.layers.append(GraphConv(n_hidden, n_hidden, activation=activation))
        # Output layer (no activation)
        self.layers.append(GraphConv(n_hidden, n_classes))

    @staticmethod
    def _release_tuple(version):
        """Return the leading numeric components of ``version`` as a tuple of ints.

        Parsing stops at the first component that does not start with a digit,
        and trailing non-digit suffixes of a component are ignored, e.g.
        ``"0.4.3post2"`` gives ``(0, 4, 3)``.
        """
        release = []
        for piece in version.split("."):
            digits = ""
            for ch in piece:
                if not ch.isdigit():
                    break
                digits += ch
            if not digits:
                break
            release.append(int(digits))
        return tuple(release)

    def forward(self, features):
        """Run ``features`` through every layer in order and return the result."""
        # Handle API changes for different DGL versions: releases newer than 0.3
        # are called as layer(graph, feat), older ones as layer(feat, graph).
        # Compare numerically; a plain string comparison would rank "0.10"
        # below "0.3".
        graph_first = self._release_tuple(dgl.__version__) > (0, 3)
        h = features
        for layer in self.layers:
            if graph_first:
                h = layer(self.g, h)
            else:
                h = layer(h, self.g)
        return h
######################################################################
# Define the functions to load dataset and evaluate accuracy
# ------------------
# You may substitute this part with your own dataset, here we load data from DGL
from dgl.data import load_data
from collections import namedtuple
def load_dataset(dataset="cora"):
    """Load a dataset through DGL and return ``(graph, data)``.

    ``dataset`` is the dataset name handed to ``load_data``. The returned
    graph is ``data.graph`` modified in place so that its self-loops are
    replaced by one self-loop per node.
    """
    # load_data reads the name from an attribute, so wrap it in a namedtuple
    arg_type = namedtuple("args", ["dataset"])
    data = load_data(arg_type(dataset))

    g = data.graph
    # Drop the existing self-loops first so that a node's feature is not
    # passed to itself more than once ...
    existing_loops = g.selfloop_edges()
    g.remove_edges_from(existing_loops)
    # ... then add a single self-loop for every node.
    g.add_edges_from([(node, node) for node in g.nodes])
    return g, data
def evaluate(data, logits):
    """Return the prediction accuracy of ``logits`` over the test set.

    ``data`` supplies ``labels`` and ``test_mask``; the predicted class of a
    node is the index of its largest logit along axis 1.
    """
    # The test set is the part of the data not included in the training phase
    mask = data.test_mask
    predicted = logits.argmax(axis=1)
    # Keep only the correct predictions that fall inside the test set
    correct_in_test = (predicted == data.labels) * mask
    return correct_in_test.sum() / mask.sum()
######################################################################
# Load the data and set up model parameters
# ------------------
"""
Parameters
----------
dataset: str
Name of dataset. You can choose from ['cora', 'citeseer', 'pubmed'].
num_layers: int
number of hidden layers
num_hidden: int
number of the hidden units in the hidden layer
infeat_dim: int
dimension of the input features
num_classes: int
dimension of model output (Number of classes)
"""
# Name of the dataset to load
dataset = "cora"
# g is the graph, data carries the features/labels/masks used below
g, data = load_dataset(dataset)
# Number of hidden layers and hidden units of the GCN
num_layers = 1
num_hidden = 16
# Input feature dimension is taken from the second axis of the feature matrix
infeat_dim = data.features.shape[1]
# Output dimension is the number of labels reported by the dataset
num_classes = data.num_labels
######################################################################
# Set up the DGL-PyTorch model and get the golden results
# ------------------
#
# The weights are trained with https://github.com/dmlc/dgl/blob/master/examples/pytorch/gcn/train.py
from tvm.contrib.download import download_testdata
from dgl import DGLGraph
# Convert the dataset features to a float tensor for the PyTorch model
features = torch.FloatTensor(data.features)
# Wrap the graph returned by load_dataset in a DGLGraph
dgl_g = DGLGraph(g)
# Build the DGL-PyTorch GCN with ReLU as the layer activation
torch_model = GCN(dgl_g,
infeat_dim,
num_hidden,
num_classes,
num_layers,
F.relu)
# Download the pretrained weights
# (the file is fetched from a ".torch" URL and stored locally under a ".pickle" name)
model_url = "https://homes.cs.washington.edu/~cyulin/media/gnn_model/gcn_%s.torch"%(dataset)
model_path = download_testdata(model_url, "gcn_%s.pickle"%(dataset), module='gcn_model')
# Load the weights into the model
# NOTE(review): torch.load deserializes the downloaded file (pickle-based by
# default, as far as we know) -- only use weights from a trusted source; confirm.
torch_model.load_state_dict(torch.load(model_path))
######################################################################
# Run the DGL model and test for accuracy
# ------------------
# Switch the model to evaluation mode before inference
torch_model.eval()
# Gradients are not needed for inference
with torch.no_grad():
logits_torch = torch_model(features)
print("Print the first five outputs from DGL-PyTorch execution\n", logits_torch[:5])
# evaluate() is given a NumPy array, so convert the tensor first
acc = evaluate(data, logits_torch.numpy())
print("Test accuracy of DGL results: {:.2%}".format(acc))
###################################################################### ######################################################################
# Define Graph Convolution Layer # Define Graph Convolution Layer in Relay
# ---------------------------- # ----------------------------
# To run GCN on TVM, we first need to implement Graph Convolution Layer. # To run GCN on TVM, we first need to implement Graph Convolution Layer.
# #
...@@ -49,17 +187,18 @@ GPU support and more sparse operators will soon follow. ...@@ -49,17 +187,18 @@ GPU support and more sparse operators will soon follow.
# = ((H * W)^t * A^t)^t # = ((H * W)^t * A^t)^t
# = ((W^t * H^t) * A^t)^t # = ((W^t * H^t) * A^t)^t
from tvm import relay from tvm import relay
from tvm.contrib import graph_runtime
def GraphConv( import tvm
layer_name,
input_dim, def GraphConv(layer_name,
output_dim, input_dim,
adj, output_dim,
input, adj,
activation=None, input,
norm=None, norm=None,
): bias=True,
r""" activation=None):
"""
Parameters Parameters
---------- ----------
layer_name: str layer_name: str
...@@ -81,10 +220,12 @@ def GraphConv( ...@@ -81,10 +220,12 @@ def GraphConv(
norm: relay.Expr, norm: relay.Expr,
Norm passed to this layer to normalize features before and after Convolution. Norm passed to this layer to normalize features before and after Convolution.
bias: bool
Set bias to True to add bias when doing GCN layer
activation: <function relay.op.nn>, activation: <function relay.op.nn>,
Activation function applies to the output. e.g. relay.nn.{relu, sigmoid, log_softmax, softmax, leaky_relu} Activation function applies to the output. e.g. relay.nn.{relu, sigmoid, log_softmax, softmax, leaky_relu}
Returns Returns
---------- ----------
output: tvm.relay.Expr output: tvm.relay.Expr
...@@ -92,42 +233,35 @@ def GraphConv( ...@@ -92,42 +233,35 @@ def GraphConv(
""" """
if norm is not None: if norm is not None:
input = relay.multiply(input, norm) input = relay.multiply(input, norm)
weight = relay.var(layer_name + "_weight", shape=(input_dim, output_dim))
weight_transposed = relay.transpose(weight) weight = relay.var(layer_name + ".weight", shape=(input_dim, output_dim))
dense = relay.nn.dense(weight_transposed, input) weight_t = relay.transpose(weight)
dense = relay.nn.dense(weight_t, input)
output = relay.nn.sparse_dense(dense, adj) output = relay.nn.sparse_dense(dense, adj)
output_transposed = relay.transpose(output) output_t = relay.transpose(output)
if norm is not None: if norm is not None:
output_transposed = relay.multiply(output_transposed, norm) output_t = relay.multiply(output_t, norm)
if bias is True:
_bias = relay.var(layer_name + ".bias", shape=(output_dim, 1))
output_t = relay.nn.bias_add(output_t, _bias, axis=-1)
if activation is not None: if activation is not None:
output_transposed = activation(output_transposed) output_t = activation(output_t)
return output_transposed return output_t
###################################################################### ######################################################################
# Load the dataset # Prepare the parameters needed in the GraphConv layers
# ------------------ # ------------------
# You may substitute this part with your own dataset, here we load data from DGL to benchmark #
import tvm, dgl, scipy
import numpy as np import numpy as np
import networkx as nx import networkx as nx
from collections import namedtuple
from dgl.data import load_data
def load_dataset(dataset="cora"):
args = namedtuple("args", ["dataset"])
dataset = load_data(args(dataset))
def prepare_params(g, data):
params = {} params = {}
params['infeats'] = dataset.features.astype('float32') # Only support float32 as feature for now params['infeats'] = data.features.astype('float32') # Only support float32 as feature for now
# Remove self-loops to avoid duplicate passing of a node's feature to itself
g = dataset.graph
g.remove_edges_from(g.selfloop_edges())
g.add_edges_from(zip(g.nodes, g.nodes))
# Generate adjacency matrix # Generate adjacency matrix
adjacency = nx.to_scipy_sparse_matrix(g) adjacency = nx.to_scipy_sparse_matrix(g)
params['data'] = adjacency.data.astype('float32') params['g_data'] = adjacency.data.astype('float32')
params['indices'] = adjacency.indices.astype('int32') params['indices'] = adjacency.indices.astype('int32')
params['indptr'] = adjacency.indptr.astype('int32') params['indptr'] = adjacency.indptr.astype('int32')
...@@ -138,145 +272,88 @@ def load_dataset(dataset="cora"): ...@@ -138,145 +272,88 @@ def load_dataset(dataset="cora"):
return params return params
###################################################################### params = prepare_params(g, data)
# Set up model Parameters
# ------------------
r"""
Parameters
----------
num_hidden: int
number of hidden layers
hidden_dim: int
input dimension of hidden layers
num_classes: int
dimension of model output (Number of classes)
target: str
currently only support llvm, GPU support will be added in next few weeks
activation: <function relay.op.nn>,
Activation function applied to the output. e.g. relay.nn.{relu, sigmoid, log_softmax, softmax, leaky_relu}
dataset: str
Name of dataset. You can pick from ['cora', 'citeseer', 'pubmed'] or you can use your own.
"""
num_hidden = 1 # Check shape of features and the validity of adjacency matrix
hidden_dim = 16
num_classes = 7
target = 'llvm'
activation = relay.nn.relu
dataset = "cora"
params = load_dataset(dataset)
# Check shape of features
assert len(params['infeats'].shape) == 2 assert len(params['infeats'].shape) == 2
nnodes, input_dim = params['infeats'].shape assert params['g_data'] is not None and params['indices'] is not None and params['indptr'] is not None
assert params['infeats'].shape[0] == params['indptr'].shape[0] - 1
# Check validity of adjacency matrix
assert params['data'] is not None and params['indices'] is not None and params['indptr'] is not None
assert nnodes == params['indptr'].shape[0] - 1
###################################################################### ######################################################################
# Put layers together # Put layers together
# ------------------ # ------------------
layers = [] # Define input features, norms, adjacency matrix in Relay
infeats = relay.var("infeats", shape=data.features.shape)
# Define input features, norms, adjacency matrix
infeats = relay.var("infeats", shape=(nnodes, input_dim))
norm = relay.Constant(tvm.nd.array(params['norm'])) norm = relay.Constant(tvm.nd.array(params['norm']))
g_data = relay.Constant(tvm.nd.array(params['g_data']))
data = relay.Constant(tvm.nd.array(params['data']))
indices = relay.Constant(tvm.nd.array(params['indices'])) indices = relay.Constant(tvm.nd.array(params['indices']))
indptr = relay.Constant(tvm.nd.array(params['indptr'])) indptr = relay.Constant(tvm.nd.array(params['indptr']))
Adjacency = namedtuple('Adjacency', ['data', 'indices', 'indptr']) Adjacency = namedtuple('Adjacency', ['data', 'indices', 'indptr'])
adj = Adjacency(data, indices, indptr) adj = Adjacency(g_data, indices, indptr)
# Generate Input Layer # Construct the 2-layer GCN
layers = []
layers.append(GraphConv( layers.append(GraphConv(
layer_name= 'in', layer_name="layers.0",
input_dim= input_dim, input_dim=infeat_dim,
output_dim= hidden_dim, output_dim=num_hidden,
adj = adj, adj=adj,
input= infeats, input=infeats,
activation= activation, norm=norm,
norm= norm, activation=relay.nn.relu
)) ))
# Generate Hidden Layers
for i in range(num_hidden):
layers.append(GraphConv(
layer_name= str(i),
input_dim= hidden_dim,
output_dim= hidden_dim,
adj = adj,
input= layers[-1],
activation= activation,
norm= norm,
))
# Generate Output Layer
layers.append(GraphConv( layers.append(GraphConv(
layer_name= 'out', layer_name="layers.1",
input_dim= hidden_dim, input_dim=num_hidden,
output_dim= num_classes, output_dim=num_classes,
adj = adj, adj=adj,
input= layers[-1], input=layers[-1],
activation= activation, norm=norm,
norm= norm, activation=None
)) ))
output = layers[-1]
# Analyze free variables and generate function # Analyze free variables and generate Relay function
output = layers[-1]
func = relay.Function(relay.analysis.free_vars(output), output) func = relay.Function(relay.analysis.free_vars(output), output)
###################################################################### ######################################################################
# Compile and run # Compile and run with TVM
# ------------------ # ------------------
# We achieved 6.5x speedup for this dataset against dgl given the same model parameters. # Export the weights from PyTorch model to Python Dict
# Output numerical difference < 10e-4 %. model_params = {}
# for param_tensor in torch_model.state_dict():
# DGL version: https://github.com/dmlc/dgl/blob/master/examples/mxnet/gcn/gcn.py model_params[param_tensor] = torch_model.state_dict()[param_tensor].numpy()
from tvm.contrib import graph_runtime
import time
# Set up weights. You can modify this part and use your own trained weights. for i in range(num_layers+1):
params['in_weight'] = np.ones((input_dim, hidden_dim), dtype='float32') params["layers.%d.weight"%(i)] = model_params["layers.%d.weight"%(i)]
params['out_weight'] = np.ones((hidden_dim, num_classes), dtype='float32') params["layers.%d.bias"%(i)] = model_params["layers.%d.bias"%(i)]
for i in range(num_hidden):
params["%s_weight"%(str(i))] = np.ones((hidden_dim, hidden_dim), dtype='float32')
# Generate graph and library # Set the TVM build target
target = 'llvm' # Currently only support `llvm` as target
# Build with Relay
with relay.build_config(opt_level=0): # Currently only support opt_level=0 with relay.build_config(opt_level=0): # Currently only support opt_level=0
graph, lib, params = relay.build(func, target, params=params) graph, lib, params = relay.build(func, target, params=params)
lib.save("lib.o")
# Generate module for llvm # Generate graph runtime
ctx = tvm.context(target, 0) ctx = tvm.context(target, 0)
m = graph_runtime.create(graph, lib, ctx) m = graph_runtime.create(graph, lib, ctx)
m.set_input(**params) m.set_input(**params)
print("finished compiling, testing inference time cost") ######################################################################
totaltime = 0 # Run the TVM model, test for accuracy and verify with DGL
for i in range(30): # ------------------
st = time.time() m.run()
# One forward pass on the entire network logits_tvm = m.get_output(0).asnumpy()
m.run() print("Print the first five outputs from TVM execution\n", logits_tvm[:5])
end = time.time()
# Retrieve output Tensor as numpy array labels = data.labels
outval = m.get_output(0).asnumpy() test_mask = data.test_mask
totaltime += (end-st) acc = evaluate(data, logits_tvm)
print("Test accuracy of TVM results: {:.2%}".format(acc))
if i == 0:
print("features of first five nodes \n %s" % outval[:5]) # Verify the results with the DGL model
if i == 4: tvm.testing.assert_allclose(logits_torch, logits_tvm, atol=1e-3)
print("5 Cycle Average Forward Pass Time ", totaltime/5)
print("30 Cycle Average Forward Pass Time ", totaltime/30)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment