Commit 34d74282 by Tianqi Chen

[TUTORIAL] Move mobilenet to tutorial, fix precompute_prune (#35)

* [TUTORIAL] Move mobilenet to tutorial, fix precompute_prune

* Some language improvements
parent 12fa9148
doxygen
_build
gen_modules
tutorials
The documentation of nnvm is generated with recommonmark and sphinx.
- pip install sphinx>=1.5.5 sphinx-gallery sphinx_rtd_theme matplotlib Image recommonmark
- Build tvm first in the root folder.
......@@ -15,6 +15,7 @@ import sys
import os, subprocess
import shlex
import recommonmark
import sphinx_gallery
from recommonmark.parser import CommonMarkParser
from recommonmark.transform import AutoStructify
......@@ -50,7 +51,8 @@ extensions = [
'sphinx.ext.autosummary',
'sphinx.ext.intersphinx',
'sphinx.ext.napoleon',
'sphinx.ext.mathjax'
'sphinx.ext.mathjax',
'sphinx_gallery.gen_gallery',
]
# Add any paths that contain templates here, relative to this directory.
......@@ -129,7 +131,7 @@ if not on_rtd and html_theme == 'rtd':
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
# html_static_path = ['_static']
# Output file base name for HTML help builder.
htmlhelp_basename = project + 'doc'
......@@ -164,9 +166,17 @@ intersphinx_mapping = {
'numpy': ('http://docs.scipy.org/doc/numpy/', None),
'scipy': ('http://docs.scipy.org/doc/scipy/reference', None),
'matplotlib': ('http://matplotlib.org/', None),
'tvm': ('http://docs.tvmlang.org/', None),
}
from sphinx_gallery.sorting import ExplicitOrder
examples_dirs = ['../tutorials/']
gallery_dirs = ['tutorials']
subsection_order = ExplicitOrder([])
def generate_doxygen_xml(app):
"""Run the doxygen make commands if we're on the ReadTheDocs server"""
run_doxygen('..')
......@@ -180,3 +190,19 @@ def setup(app):
'auto_doc_ref': True
}, True)
app.add_transform(AutoStructify)
sphinx_gallery_conf = {
'backreferences_dir': 'gen_modules/backreferences',
'doc_module': ('tvm', 'nnvm', 'numpy'),
'reference_url': {
'nnvm': None,
'tvm': 'http://docs.tvmlang.org',
'numpy': 'http://docs.scipy.org/doc/numpy-1.9.1'},
'examples_dirs': examples_dirs,
'gallery_dirs': gallery_dirs,
'subsection_order': subsection_order,
'find_mayavi_figures': False,
'filename_pattern': '.py',
'expected_failing_examples': []
}
NNVM Design Note
================
Design Note
===========
In this part of documentation, we share the rationale for the specific choices made when designing NNVM.
......
......@@ -10,4 +10,5 @@ Contents
self
top
tutorials/index
dev/index
NNVM Core Tensor Operators
==========================
Core Tensor Operators
=====================
This page contains the list of core tensor operator primitives re-defined in NNVM.
The core tensor operator primitives(``nnvm.top``) covers typical workloads in deep learning.
......
"""Forward propagation of MobileNet on GPU."""
import numpy as np
import time
import os
import tvm
import topi
import nnvm.symbol as sym
import nnvm.compiler
import nnvm.runtime
from tvm.contrib import nvcc
TASK="mobilenet"
target = 'cuda'
ctx = tvm.gpu(0)
@tvm.register_func
def tvm_callback_cuda_compile(code):
ptx = nvcc.compile_cuda(code, target="ptx", options=["-arch=sm_60"])
return ptx
def write_code(code, fname):
with open(fname, "w") as f:
f.write(code)
@tvm.register_func
def tvm_callback_cuda_postproc(code):
if not os.path.exists("perf"):
os.mkdir("perf")
write_code(code, "perf/%s_generated.cu" % TASK)
return code
dtype = 'float32'
epsilon = 1e-10 + 1e-5
def conv_block(data, name, channels, kernel_size=(3,3), strides=(1,1), padding=(1,1)):
# convolution + bn + relu
conv = sym.conv2d(data=data, channels=channels, kernel_size=kernel_size, strides=strides,
padding=padding, use_bias=False, layout='NCHW', name=name + '_conv')
bn = sym.batch_norm(data=conv, epsilon=epsilon, name=name + '_bn')
act = sym.relu(data=bn, name=name + '_relu')
return act
def separable_conv_block(data, name, depthwise_channels, pointwise_channels, kernel_size=(3,3), downsample=False, padding=(1,1)):
if downsample:
strides = (2,2)
else:
strides = (1,1)
# depthwise convolution + bn + relu
conv1 = sym.conv2d(data=data, channels=depthwise_channels, groups=depthwise_channels, kernel_size=kernel_size, strides=strides,
padding=padding, use_bias=False, layout='NCHW', name=name + '_conv1')
bn1 = sym.batch_norm(data=conv1, epsilon=epsilon, name=name + '_bn1')
act1 = sym.relu(data=bn1, name=name + '_relu1')
# pointwise convolution + bn + relu
conv2 = sym.conv2d(data=act1, channels=pointwise_channels, kernel_size=(1,1), strides=(1,1),
padding=(0,0), use_bias=False, layout='NCHW', name=name + '_conv2')
bn2 = sym.batch_norm(data=conv2, epsilon=epsilon, name=name + '_bn2')
act2 = sym.relu(data=bn2, name=name + '_relu2')
return act2
def mobile_net(num_classes=1000, alpha=1.0, is_shallow=False):
data = sym.Variable("data")
body = conv_block(data, 'conv_block_1', int(32*alpha), strides=(2,2))
body = separable_conv_block(body, 'separable_conv_block_1', int(32*alpha), int(64*alpha))
body = separable_conv_block(body, 'separable_conv_block_2', int(64*alpha), int(128*alpha), downsample=True)
body = separable_conv_block(body, 'separable_conv_block_3', int(128*alpha), int(128*alpha))
body = separable_conv_block(body, 'separable_conv_block_4', int(128*alpha), int(256*alpha), downsample=True)
body = separable_conv_block(body, 'separable_conv_block_5', int(256*alpha), int(256*alpha))
body = separable_conv_block(body, 'separable_conv_block_6', int(256*alpha), int(512*alpha), downsample=True)
if is_shallow:
body = separable_conv_block(body, 'separable_conv_block_7', int(512*alpha), int(1024*alpha), downsample=True)
body = separable_conv_block(body, 'separable_conv_block_8', int(1024*alpha), int(1024*alpha))
else:
for i in range(7, 12):
body = separable_conv_block(body, 'separable_conv_block_%d' % i, int(512*alpha), int(512*alpha))
body = separable_conv_block(body, 'separable_conv_block_12', int(512*alpha), int(1024*alpha), downsample=True)
body = separable_conv_block(body, 'separable_conv_block_13', int(1024*alpha), int(1024*alpha))
pool = sym.global_avg_pool2d(data=body, name='pool')
flatten = sym.flatten(data=pool, name='flatten')
fc = sym.dense(data=flatten, units=num_classes, use_bias=False, name='fc')
softmax = sym.softmax(data=fc, name='softmax')
return softmax
batch_size = 1
num_classes = 1000
image_shape = (3,224,224)
data_shape = (batch_size,) + image_shape
out_shape = (batch_size, num_classes)
net = mobile_net(num_classes=num_classes, alpha=1.0, is_shallow=False)
# build graph
with nnvm.compiler.build_config(opt_level=2):
graph, lib, _ = nnvm.compiler.build(net, target, {'data': data_shape})
# prepare params
params = {}
names = graph.index.input_names
shapes = [graph.json_attr("shape")[graph.index.entry_id(x)] for x in names]
for i in range(len(names)):
params[names[i]] = tvm.nd.array(np.random.uniform(-0.1, 0.1, size=shapes[i]).astype(dtype), ctx=ctx)
# create runtime module
module = nnvm.runtime.create(graph, lib, ctx)
# set input
module.set_input(**params)
# run
print("run")
module.run()
ctx.sync()
start = time.time()
for i in range(1000):
module.run()
ctx.sync()
print("average time cost of 1000 runs = %g ms" % ((time.time() - start)))
# get output
out = module.get_output(0, tvm.nd.empty(out_shape, dtype))
NNVM Examples
=============
This folder contains example snippets of running NNVM Compilation.
- See also [Tutorials](tutorials) for tutorials with detailed explainations.
"""Utilities for testcase"""
"""Utilities for testing and benchmarks"""
from __future__ import absolute_import as _abs
from .config import ctx_list
from . import mobilenet
"""Configuration about tests"""
from __future__ import absolute_import as _abs
import os
import tvm
......
"""Helper utility to get mobilenet workload for testing."""
# pylint: disable=invalid-name
from __future__ import absolute_import as _abs
import numpy as np
import tvm
from .. compiler import graph_util
from .. import graph
from .. import symbol as sym
def conv_block(data, name, channels,
kernel_size=(3, 3), strides=(1, 1), padding=(1, 1),
epsilon=1e-5):
"""Helper function to construct conv-bn-relu"""
# convolution + bn + relu
conv = sym.conv2d(data=data, channels=channels,
kernel_size=kernel_size, strides=strides,
padding=padding, use_bias=False,
layout="NCHW", name=name + "_conv")
bn = sym.batch_norm(data=conv, epsilon=epsilon, name=name + "_bn")
act = sym.relu(data=bn, name=name + "_relu")
return act
def separable_conv_block(data, name, depthwise_channels,
pointwise_channels, kernel_size=(3, 3),
downsample=False, padding=(1, 1),
epsilon=1e-5):
"""Helper function to get a separable conv block"""
if downsample:
strides = (2, 2)
else:
strides = (1, 1)
# depthwise convolution + bn + relu
conv1 = sym.conv2d(data=data, channels=depthwise_channels,
groups=depthwise_channels, kernel_size=kernel_size, strides=strides,
padding=padding, use_bias=False, layout="NCHW", name=name + "_conv1")
bn1 = sym.batch_norm(data=conv1, epsilon=epsilon, name=name + "_bn1")
act1 = sym.relu(data=bn1, name=name + "_relu1")
# pointwise convolution + bn + relu
conv2 = sym.conv2d(data=act1, channels=pointwise_channels, kernel_size=(1, 1), strides=(1, 1),
padding=(0, 0), use_bias=False, layout="NCHW", name=name + "_conv2")
bn2 = sym.batch_norm(data=conv2, epsilon=epsilon, name=name + "_bn2")
act2 = sym.relu(data=bn2, name=name + "_relu2")
return act2
def mobile_net(num_classes=1000, alpha=1.0, is_shallow=False):
"""Function to construct a MobileNet"""
data = sym.Variable("data")
body = conv_block(data, "conv_block_1", int(32*alpha), strides=(2, 2))
body = separable_conv_block(body, "separable_conv_block_1",
int(32*alpha), int(64*alpha))
body = separable_conv_block(body, "separable_conv_block_2",
int(64*alpha), int(128*alpha), downsample=True)
body = separable_conv_block(body, "separable_conv_block_3",
int(128*alpha), int(128*alpha))
body = separable_conv_block(body, "separable_conv_block_4",
int(128*alpha), int(256*alpha), downsample=True)
body = separable_conv_block(body, "separable_conv_block_5",
int(256*alpha), int(256*alpha))
body = separable_conv_block(body, "separable_conv_block_6",
int(256*alpha), int(512*alpha), downsample=True)
if is_shallow:
body = separable_conv_block(body, "separable_conv_block_7",
int(512*alpha), int(1024*alpha), downsample=True)
body = separable_conv_block(body, "separable_conv_block_8",
int(1024*alpha), int(1024*alpha))
else:
for i in range(7, 12):
body = separable_conv_block(body, "separable_conv_block_%d" % i,
int(512*alpha), int(512*alpha))
body = separable_conv_block(body, "separable_conv_block_12",
int(512*alpha), int(1024*alpha), downsample=True)
body = separable_conv_block(body, "separable_conv_block_13",
int(1024*alpha), int(1024*alpha))
pool = sym.global_avg_pool2d(data=body, name="pool")
flatten = sym.flatten(data=pool, name="flatten")
fc = sym.dense(data=flatten, units=num_classes, use_bias=False, name="fc")
softmax = sym.softmax(data=fc, name="softmax")
return softmax
def get_workload(batch_size, num_classes=1000, image_shape=(3, 224, 224), dtype="float32"):
"""Get benchmark workload for mobilenet
Parameters
----------
batch_size : int
The batch size used in the model
num_classes : int, optional
Number of claseses
image_shape : tuple, optional
The input image shape
dtype : str, optional
The data type
Returns
-------
net : nnvm.Symbol
The computational graph
params : dict of str to NDArray
The parameters.
"""
image_shape = (3, 224, 224)
data_shape = (batch_size,) + image_shape
net = mobile_net(num_classes=num_classes, alpha=1.0, is_shallow=False)
params = {}
g = graph.create(net)
input_shapes, _ = graph_util.infer_shape(g, data=data_shape)
shape_dict = dict(zip(g.index.input_names, input_shapes))
for k, v in shape_dict.items():
if k == "data":
continue
# Specially generate non-negative parameters.
if k.endswith("gamma"):
init = np.random.uniform(0.9, 1, size=v)
elif k.endswith("var"):
init = np.random.uniform(0.9, 1, size=v)
else:
init = np.random.uniform(-0.1, 0.1, size=v)
params[k] = tvm.nd.array(init.astype(dtype), ctx=tvm.cpu(0))
return net, params
......@@ -44,17 +44,17 @@ nnvm::Graph PrecomputePrune(nnvm::Graph src) {
} else {
// scan again to find edge nodes, skip variables
for (auto& e : n->inputs) {
if (!e.node->is_variable() && pruned.count(e.node.get())) {
if (pruned.count(e.node.get())) {
if (!entry_var.count(e)) {
nnvm::NodePtr var = nnvm::Node::Create();
var->attrs.name = e.node->attrs.name + "_output" + std::to_string(e.index);
var->attrs.name = e.node->attrs.name;
if (e.node->num_outputs() != 1) {
var->attrs.name += "_output" + std::to_string(e.index);
}
entry_var.emplace(e, var);
CHECK(!unique_name.count(var->attrs.name));
unique_name.insert(var->attrs.name);
}
// TODO(ziheng): this pass now mutates the original graph structure
// This might not be a good thing, change to copy the structure instead
//
e = nnvm::NodeEntry{entry_var.at(e), 0, 0};
}
}
......@@ -67,7 +67,6 @@ nnvm::Graph PrecomputePrune(nnvm::Graph src) {
output_names.reserve(entry_var.size());
for (auto kv : entry_var) {
if (kv.first.node->is_variable()) continue;
pre_graph.outputs.emplace_back(kv.first);
output_names.emplace_back(kv.second->attrs.name);
}
......
......@@ -55,26 +55,28 @@ def test_run():
def test_precompute_prune():
x = sym.Variable("x") + 1
a = sym.Variable("a")
y = sym.Variable("y")
z = y + x
z = y + x + a
shape = (10, 10)
dtype = tvm.float32
nx = tvm.nd.array(np.random.uniform(size=shape).astype(dtype))
na = tvm.nd.array(np.random.uniform(size=shape).astype(dtype))
ny = tvm.nd.array(np.random.uniform(size=shape).astype(dtype))
params = {"x": nx}
params = {"x": nx, "a": na}
graph, lib, params = nnvm.compiler.build(
z, "llvm", shape={"y": ny.shape}, params=params)
assert graph.index.num_nodes == 3
assert graph.index.num_nodes == 4
m = nnvm.runtime.create(graph, lib, tvm.cpu(0))
params["y"] = ny
res = tvm.nd.empty(shape)
m.run(**params)
out = m.get_output(0, out=res)
np.testing.assert_allclose(
res.asnumpy(), nx.asnumpy() + 1 + ny.asnumpy())
res.asnumpy(), nx.asnumpy() + 1 + ny.asnumpy() + na.asnumpy())
if __name__ == "__main__":
test_precompute_prune()
test_compile()
test_run()
test_precompute_prune()
Tutorials
=========
This page contains the tutorials about NNVM.
"""
Compile MobileNet Inference on GPU
==================================
**Author**: `Yuwei Hu <https://huyuwei.github.io/>`_
This is an example of using NNVM to compile MobileNet model and deploy its inference on GPU.
To begin with, we import nnvm(for compilation) and TVM(for deployment).
"""
import tvm
import nnvm.compiler
import nnvm.runtime
import nnvm.testing
from tvm.contrib import nvcc
######################################################################
# Register the NVCC Compiler Option
# ---------------------------------
# NNVM optimizes the graph and relies on TVM to generate fast
# GPU code, to get the maximum performance, we need to enable
# nvcc's compiler hook. This gives better performance than nvrtc mode.
@tvm.register_func
def tvm_callback_cuda_compile(code):
ptx = nvcc.compile_cuda(code, target="ptx", options=["-arch=sm_52"])
return ptx
######################################################################
# Prepare the Benchmark
# ---------------------
# We construct a standard imagenet inference benchmark.
# We use nnvm's testing utility to produce the model description and random parameters that so the example does not
# depend on a specific front-end framework.
#
# .. note::
#
# In a typical workflow, we can get this pair from :any:`nnvm.frontend`
#
target = "cuda"
ctx = tvm.gpu(0)
batch_size = 1
num_classes = 1000
image_shape = (3, 224, 224)
data_shape = (batch_size,) + image_shape
out_shape = (batch_size, num_classes)
net, params = nnvm.testing.mobilenet.get_workload(
batch_size=1, image_shape=image_shape)
######################################################################
# Compile The Graph
# -----------------
# NNVM needs two things to compile a deep learning model:
#
# - net which is the graph representation of the computation
# - params a dictionary of str to parameters.
#
# To compile the graph, we call the build function with the graph
# configuration and parameters.
# When parameters are provided, NNVM will pre-compute certain part of the graph if possible,
# the new parameter set returned as the third return value.
graph, lib, params = nnvm.compiler.build(
net, target, shape={"data": data_shape}, params=params)
######################################################################
# Run the Compiled Module
# -----------------------
#
# To deploy the module, we call :any:`nnvm.runtime.create` passing in the graph the lib and context.
# Thanks to TVM, we can deploy the compiled module to many platforms and languages.
# The deployment module is designed to contain minimum dependencies.
# This example runs on the same machine.
module = nnvm.runtime.create(graph, lib, ctx)
# set input
module.set_input(**params)
# run
module.run()
# get output
out = module.get_output(0, tvm.nd.empty(out_shape))
# Convert to numpy
out.asnumpy()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment