Commit cba957e0 by Yuwei HU, committed by Tianqi Chen

[TOP][Example] register pool, global_pool; add mobilenet example (#32)

* register pool, global_pool; add mobilenet example

* tests of pool and global_pool

* use new API of runtime module

* small fix
parent 942d8b0e
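As context for the "use new API of runtime module" bullet, here is a minimal sketch of how the calling convention changes (assuming an already built `graph` and `lib`, a context `ctx`, an input `data`, and an output shape `oshape`); the old member-function style is shown as comments:
# old style: fetch packed functions from the runtime module
#   set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
#   set_input("x", data); run(); get_output(0, out)
# new style: call methods on the module directly, passing inputs as keywords
m = nnvm.runtime.create(graph, lib, ctx)
m.run(x=data)
out = m.get_output(0, tvm.nd.empty(oshape, dtype))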
"""Forward propagation of MobileNet on GPU."""
import numpy as np
import time
import os
import tvm
import topi
import nnvm.symbol as sym
import nnvm.compiler
import nnvm.runtime
from tvm.contrib import nvcc
TASK="mobilenet"
target = 'cuda'
ctx = tvm.gpu(0)
@tvm.register_func
def tvm_callback_cuda_compile(code):
    ptx = nvcc.compile_cuda(code, target="ptx", options=["-arch=sm_60"])
    return ptx
def write_code(code, fname):
    with open(fname, "w") as f:
        f.write(code)
@tvm.register_func
def tvm_callback_cuda_postproc(code):
    if not os.path.exists("perf"):
        os.mkdir("perf")
    write_code(code, "perf/%s_generated.cu" % TASK)
    return code
dtype = 'float32'
epsilon = 1e-10 + 1e-5
def conv_block(data, name, channels, kernel_size=(3,3), strides=(1,1), padding=(1,1)):
    # convolution + bn + relu
    conv = sym.conv2d(data=data, channels=channels, kernel_size=kernel_size, strides=strides,
                      padding=padding, use_bias=False, layout='NCHW', name=name + '_conv')
    bn = sym.batch_norm(data=conv, epsilon=epsilon, name=name + '_bn')
    act = sym.relu(data=bn, name=name + '_relu')
    return act
def separable_conv_block(data, name, depthwise_channels, pointwise_channels, kernel_size=(3,3), downsample=False, padding=(1,1)):
    if downsample:
        strides = (2,2)
    else:
        strides = (1,1)
    # depthwise convolution + bn + relu
    conv1 = sym.conv2d(data=data, channels=depthwise_channels, groups=depthwise_channels, kernel_size=kernel_size, strides=strides,
                       padding=padding, use_bias=False, layout='NCHW', name=name + '_conv1')
    bn1 = sym.batch_norm(data=conv1, epsilon=epsilon, name=name + '_bn1')
    act1 = sym.relu(data=bn1, name=name + '_relu1')
    # pointwise convolution + bn + relu
    conv2 = sym.conv2d(data=act1, channels=pointwise_channels, kernel_size=(1,1), strides=(1,1),
                       padding=(0,0), use_bias=False, layout='NCHW', name=name + '_conv2')
    bn2 = sym.batch_norm(data=conv2, epsilon=epsilon, name=name + '_bn2')
    act2 = sym.relu(data=bn2, name=name + '_relu2')
    return act2
def mobile_net(num_classes=1000, alpha=1.0, is_shallow=False):
    data = sym.Variable("data")
    body = conv_block(data, 'conv_block_1', int(32*alpha), strides=(2,2))
    body = separable_conv_block(body, 'separable_conv_block_1', int(32*alpha), int(64*alpha))
    body = separable_conv_block(body, 'separable_conv_block_2', int(64*alpha), int(128*alpha), downsample=True)
    body = separable_conv_block(body, 'separable_conv_block_3', int(128*alpha), int(128*alpha))
    body = separable_conv_block(body, 'separable_conv_block_4', int(128*alpha), int(256*alpha), downsample=True)
    body = separable_conv_block(body, 'separable_conv_block_5', int(256*alpha), int(256*alpha))
    body = separable_conv_block(body, 'separable_conv_block_6', int(256*alpha), int(512*alpha), downsample=True)
    if is_shallow:
        body = separable_conv_block(body, 'separable_conv_block_7', int(512*alpha), int(1024*alpha), downsample=True)
        body = separable_conv_block(body, 'separable_conv_block_8', int(1024*alpha), int(1024*alpha))
    else:
        for i in range(7, 12):
            body = separable_conv_block(body, 'separable_conv_block_%d' % i, int(512*alpha), int(512*alpha))
        body = separable_conv_block(body, 'separable_conv_block_12', int(512*alpha), int(1024*alpha), downsample=True)
        body = separable_conv_block(body, 'separable_conv_block_13', int(1024*alpha), int(1024*alpha))
    pool = sym.global_avg_pool2d(data=body, name='pool')
    flatten = sym.flatten(data=pool, name='flatten')
    fc = sym.dense(data=flatten, units=num_classes, use_bias=False, name='fc')
    softmax = sym.softmax(data=fc, name='softmax')
    return softmax
batch_size = 1
num_classes = 1000
image_shape = (3,224,224)
data_shape = (batch_size,) + image_shape
out_shape = (batch_size, num_classes)
net = mobile_net(num_classes=num_classes, alpha=1.0, is_shallow=False)
# build graph
with nnvm.compiler.build_config(opt_level=2):
    graph, lib, _ = nnvm.compiler.build(net, target, {'data': data_shape})
# prepare params
params = {}
names = graph.index.input_names
shapes = [graph.json_attr("shape")[graph.index.entry_id(x)] for x in names]
for i in range(len(names)):
    params[names[i]] = tvm.nd.array(np.random.uniform(-0.1, 0.1, size=shapes[i]).astype(dtype), ctx=ctx)
# create runtime module
module = nnvm.runtime.create(graph, lib, ctx)
# set input
module.set_input(**params)
# run
print("run")
module.run()
ctx.sync()
start = time.time()
for i in range(1000):
    module.run()
ctx.sync()
print("average time cost of 1000 runs = %g ms" % ((time.time() - start)))
# get output
out = module.get_output(0, tvm.nd.empty(out_shape, dtype))
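A hypothetical follow-up (not part of the commit) showing how the output NDArray can be inspected on the host; with randomly initialized weights the predicted class is meaningless, but it demonstrates copying the result back to numpy:
out_np = out.asnumpy()          # copy the GPU NDArray into a host numpy array
print("top-1 class index:", np.argmax(out_np[0]))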
......@@ -202,7 +202,6 @@ struct Pool2DParam : public dmlc::Parameter<Pool2DParam> {
  TShape pool_size;
  TShape strides;
  TShape padding;
  int groups;
  int layout;
  bool ceil_mode;
......@@ -214,12 +213,6 @@ struct Pool2DParam : public dmlc::Parameter<Pool2DParam> {
  DMLC_DECLARE_FIELD(padding).set_default(TShape({0, 0}))
      .describe("If padding is non-zero, then the input is implicitly zero-padded "
                "on both sides for padding number of points");
  DMLC_DECLARE_FIELD(groups).set_default(1)
      .describe("Controls the connections between inputs and outputs. "
                "At groups=1, all inputs are convolved to all outputs. "
                "At groups=2, the operation becomes equivalent to having two convolution "
                "layers side by side, each seeing half the input channels, and producing "
                "half the output channels, and both subsequently concatenated.");
  DMLC_DECLARE_FIELD(layout)
      .add_enum("NCHW", kNCHW)
      .add_enum("NHWC", kNHWC)
......@@ -18,6 +18,7 @@ def compute_relu(attrs, inputs, _):
reg.register_schedule("relu", _fschedule_broadcast)
reg.register_pattern("relu", OpPattern.ELEMWISE)
# leaky_relu
@reg.register_compute("leaky_relu")
def compute_leaky_relu(attrs, inputs, _):
......@@ -27,6 +28,7 @@ def compute_leaky_relu(attrs, inputs, _):
reg.register_schedule("leaky_relu", _fschedule_broadcast)
reg.register_pattern("leaky_relu", OpPattern.ELEMWISE)
# flatten
@reg.register_compute("flatten")
def compute_flatten(attrs, inputs, _):
......@@ -73,11 +75,10 @@ def schedule_dense(_, outs, target):
    # naive schedule
    return tvm.create_schedule([x.op for x in outs])
# register extern for now, change me when fusion is enabled.
reg.register_pattern("dense", OpPattern.OUT_ELEMWISE_FUSABLE)
# conv
# conv2d
@reg.register_compute("conv2d")
def compute_conv2d(attrs, inputs, _):
"""Compute definition of conv2d"""
......@@ -113,3 +114,89 @@ def schedule_conv2d(attrs, outs, target):
    return tvm.create_schedule([x.op for x in outs])
reg.register_pattern("conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
# max_pool2d
@reg.register_compute("max_pool2d")
def compute_max_pool2d(attrs, inputs, _):
"""Compute definition of max_pool2d"""
pool_size = attrs.get_int_tuple("pool_size")
strides = attrs.get_int_tuple("strides")
padding = attrs.get_int_tuple("padding")
layout = attrs["layout"]
ceil_mode = attrs["ceil_mode"]
assert layout == "NCHW", "only support nchw for now"
assert ceil_mode == "False", "not support ceil_mode now"
return topi.nn.pool(inputs[0], pool_size, strides, padding, pool_type='max')
@reg.register_schedule("max_pool2d")
def schedule_max_pool2d(_, outs, target):
"""Schedule definition of max_pool2d"""
if target == "cuda":
return topi.cuda.schedule_pool(outs)
# naive schedule
return tvm.create_schedule([x.op for x in outs])
reg.register_pattern("max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
# avg_pool2d
@reg.register_compute("avg_pool2d")
def compute_avg_pool2d(attrs, inputs, _):
"""Compute definition of avg_pool2d"""
pool_size = attrs.get_int_tuple("pool_size")
strides = attrs.get_int_tuple("strides")
padding = attrs.get_int_tuple("padding")
layout = attrs["layout"]
ceil_mode = attrs["ceil_mode"]
assert layout == "NCHW", "only support nchw for now"
assert ceil_mode == "False", "not support ceil_mode now"
return topi.nn.pool(inputs[0], pool_size, strides, padding, pool_type='avg')
@reg.register_schedule("avg_pool2d")
def schedule_avg_pool2d(_, outs, target):
"""Schedule definition of avg_pool2d"""
if target == "cuda":
return topi.cuda.schedule_pool(outs)
# naive schedule
return tvm.create_schedule([x.op for x in outs])
reg.register_pattern("avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
# global_max_pool2d
@reg.register_compute("global_max_pool2d")
def compute_global_max_pool2d(attrs, inputs, _):
"""Compute definition of global_max_pool2d"""
layout = attrs["layout"]
assert layout == "NCHW", "only support nchw for now"
return topi.nn.global_pool(inputs[0], pool_type='max')
@reg.register_schedule("global_max_pool2d")
def schedule_global_max_pool2d(_, outs, target):
"""Schedule definition of global_max_pool2d"""
if target == "cuda":
return topi.cuda.schedule_global_pool(outs)
# naive schedule
return tvm.create_schedule([x.op for x in outs])
reg.register_pattern("global_max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
# global_avg_pool2d
@reg.register_compute("global_avg_pool2d")
def compute_global_avg_pool2d(attrs, inputs, _):
"""Compute definition of global_avg_pool2d"""
layout = attrs["layout"]
assert layout == "NCHW", "only support nchw for now"
return topi.nn.global_pool(inputs[0], pool_type='avg')
@reg.register_schedule("global_avg_pool2d")
def schedule_global_avg_pool2d(_, outs, target):
"""Schedule definition of global_avg_pool2d"""
if target == "cuda":
return topi.cuda.schedule_global_pool(outs)
# naive schedule
return tvm.create_schedule([x.op for x in outs])
reg.register_pattern("global_avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
......@@ -16,7 +16,6 @@ def test_relu():
    for target, ctx in ctx_list():
        graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
        m = nnvm.runtime.create(graph, lib, ctx)
        # get member functions
        data = np.random.uniform(size=dshape).astype(dtype)
        m.run(x=data)
        data = (data < 0) * data * 0.3 + (data > 0) * data - 0.2
......@@ -34,17 +33,10 @@ def test_exp():
    for target, ctx in ctx_list():
        graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
        m = nnvm.runtime.create(graph, lib, ctx)
        # get member functions
        set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
        # set input
        data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
        set_input("x", data)
        # execute
        run()
        # get output
        out = tvm.nd.empty(oshape, dtype)
        get_output(0, out)
        y_np = np.exp(data.asnumpy())
        data = np.random.uniform(size=dshape).astype(dtype)
        m.run(x=data)
        out = m.get_output(0, tvm.nd.empty(oshape, dtype))
        y_np = np.exp(data)
        np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
......@@ -58,17 +50,10 @@ def test_log():
        with nnvm.compiler.build_config(opt_level=1):
            graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
        m = nnvm.runtime.create(graph, lib, ctx)
        # get member functions
        set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
        # set input
        data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
        set_input("x", data)
        # execute
        run()
        # get output
        out = tvm.nd.empty(oshape, dtype)
        get_output(0, out)
        y_np = np.log(data.asnumpy())
        data = np.random.uniform(size=dshape).astype(dtype)
        m.run(x=data)
        out = m.get_output(0, tvm.nd.empty(oshape, dtype))
        y_np = np.log(data)
        np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
......@@ -82,17 +67,10 @@ def test_tanh():
        with nnvm.compiler.build_config(opt_level=1):
            graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
        m = nnvm.runtime.create(graph, lib, ctx)
        # get member functions
        set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
        # set input
        data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
        set_input("x", data)
        # execute
        run()
        # get output
        out = tvm.nd.empty(oshape, dtype)
        get_output(0, out)
        y_np = np.sinh(data.asnumpy()) / np.cosh(data.asnumpy())
        data = np.random.uniform(size=dshape).astype(dtype)
        m.run(x=data)
        out = m.get_output(0, tvm.nd.empty(oshape, dtype))
        y_np = np.sinh(data) / np.cosh(data)
        np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
......@@ -105,17 +83,10 @@ def test_sigmoid():
    for target, ctx in ctx_list():
        graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
        m = nnvm.runtime.create(graph, lib, ctx)
        # get member functions
        set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
        # set input
        data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
        set_input("x", data)
        # execute
        run()
        # get output
        out = tvm.nd.empty(oshape, dtype)
        get_output(0, out)
        y_np = 1.0 / (1.0 + np.exp(-data.asnumpy()))
        data = np.random.uniform(size=dshape).astype(dtype)
        m.run(x=data)
        out = m.get_output(0, tvm.nd.empty(oshape, dtype))
        y_np = 1.0 / (1.0 + np.exp(-data))
        np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
......@@ -129,17 +100,10 @@ def test_softmax():
        with nnvm.compiler.build_config(opt_level=1):
            graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
        m = nnvm.runtime.create(graph, lib, ctx)
        # get member functions
        set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
        # set input
        data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
        set_input("x", data)
        # execute
        run()
        # get output
        out = tvm.nd.empty(oshape, dtype)
        get_output(0, out)
        y_np = topi.testing.softmax_python(data.asnumpy())
        data = np.random.uniform(size=dshape).astype(dtype)
        m.run(x=data)
        out = m.get_output(0, tvm.nd.empty(oshape, dtype))
        y_np = topi.testing.softmax_python(data)
        np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
......@@ -10,8 +10,8 @@ from nnvm.testing.config import ctx_list
def test_conv2d():
    x = sym.Variable("x")
    y = sym.conv2d(x, channels=10, kernel_size=(3, 3),
                   name="y", use_bias=False, padding=(1,1))
    y = sym.conv2d(x, channels=10, kernel_size=(3,3),
                   name="y", padding=(1,1))
    dtype = "float32"
    dshape = (1, 3, 18, 18)
    kshape = (10, 3, 3, 3)
......@@ -20,26 +20,20 @@ def test_conv2d():
    for target, ctx in ctx_list():
        graph, lib, _ = nnvm.compiler.build(y, target, shape_dict)
        m = nnvm.runtime.create(graph, lib, ctx)
        # get member functions
        set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
        # set input
        data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
        kernel = tvm.nd.array(np.random.uniform(size=kshape).astype(dtype))
        set_input("x", data)
        set_input("y_weight", kernel)
        # execute
        run()
        # get output
        out = tvm.nd.empty(oshape, dtype)
        get_output(0, out)
        bias = tvm.nd.array(np.random.uniform(size=kshape[0]).astype(dtype))
        m.run(x=data, y_weight=kernel, y_bias=bias)
        out = m.get_output(0, tvm.nd.empty(oshape, dtype))
        c_np = topi.testing.conv2d_nchw_python(
            data.asnumpy(), kernel.asnumpy(), 1, 1)
        c_np = c_np + bias.asnumpy().reshape(kshape[0], 1, 1)
        np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
def test_grouped_conv2d():
    x = sym.Variable("x")
    y = sym.conv2d(x, channels=32, kernel_size=(3, 3), groups=32,
    y = sym.conv2d(x, channels=32, kernel_size=(3,3), groups=32,
                   name="y", padding=(1,1))
    dtype = "float32"
    dshape = (1, 32, 18, 18)
......@@ -49,12 +43,10 @@ def test_grouped_conv2d():
    for target, ctx in ctx_list():
        graph, lib, _ = nnvm.compiler.build(y, target, shape_dict)
        m = nnvm.runtime.create(graph, lib, ctx)
        # set input
        data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
        kernel = tvm.nd.array(np.random.uniform(size=kshape).astype(dtype))
        bias = tvm.nd.array(np.random.uniform(size=kshape[0]).astype(dtype))
        m.run(x=data, y_weight=kernel, y_bias=bias)
        # get output
        out = m.get_output(0, tvm.nd.empty(oshape, dtype))
        c_np = topi.testing.depthwise_conv2d_python_nchw(
            data.asnumpy(), kernel.asnumpy(), (1,1), 'SAME')
......@@ -62,6 +54,78 @@ def test_grouped_conv2d():
        np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
def test_max_pool2d():
    x = sym.Variable("x")
    y = sym.max_pool2d(x, pool_size=(2,2), strides=(2,2), padding=(0,0), name="y")
    dtype = "float32"
    dshape = (1, 3, 28, 28)
    oshape = (1, 3, 14, 14)
    shape_dict = {"x": dshape}
    for target, ctx in ctx_list():
        graph, lib, _ = nnvm.compiler.build(y, target, shape_dict)
        m = nnvm.runtime.create(graph, lib, ctx)
        data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
        m.run(x=data)
        out = m.get_output(0, tvm.nd.empty(oshape, dtype))
        b_np = np.max(data.asnumpy().reshape(1,3,14,2,14,2), axis=(3,5))
        np.testing.assert_allclose(out.asnumpy(), b_np, rtol=1e-5)
def test_avg_pool2d():
    x = sym.Variable("x")
    y = sym.avg_pool2d(x, pool_size=(2,2), strides=(2,2), padding=(0,0), name="y")
    dtype = "float32"
    dshape = (1, 3, 28, 28)
    oshape = (1, 3, 14, 14)
    shape_dict = {"x": dshape}
    for target, ctx in ctx_list():
        graph, lib, _ = nnvm.compiler.build(y, target, shape_dict)
        m = nnvm.runtime.create(graph, lib, ctx)
        data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
        m.run(x=data)
        out = m.get_output(0, tvm.nd.empty(oshape, dtype))
        b_np = np.mean(data.asnumpy().reshape(1,3,14,2,14,2), axis=(3,5))
        np.testing.assert_allclose(out.asnumpy(), b_np, rtol=1e-5)
def test_global_max_pool2d():
    x = sym.Variable("x")
    y = sym.global_max_pool2d(x, name="y")
    dtype = "float32"
    dshape = (1, 1024, 7, 7)
    oshape = (1, 1024, 1, 1)
    shape_dict = {"x": dshape}
    for target, ctx in ctx_list():
        graph, lib, _ = nnvm.compiler.build(y, target, shape_dict)
        m = nnvm.runtime.create(graph, lib, ctx)
        data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
        m.run(x=data)
        out = m.get_output(0, tvm.nd.empty(oshape, dtype))
        b_np = np.max(data.asnumpy(), axis=(2,3), keepdims=True)
        np.testing.assert_allclose(out.asnumpy(), b_np, rtol=1e-5)
def test_global_avg_pool2d():
    x = sym.Variable("x")
    y = sym.global_avg_pool2d(x, name="y")
    dtype = "float32"
    dshape = (1, 1024, 7, 7)
    oshape = (1, 1024, 1, 1)
    shape_dict = {"x": dshape}
    for target, ctx in ctx_list():
        graph, lib, _ = nnvm.compiler.build(y, target, shape_dict)
        m = nnvm.runtime.create(graph, lib, ctx)
        data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
        m.run(x=data)
        out = m.get_output(0, tvm.nd.empty(oshape, dtype))
        b_np = np.mean(data.asnumpy(), axis=(2,3), keepdims=True)
        np.testing.assert_allclose(out.asnumpy(), b_np, rtol=1e-5)
if __name__ == "__main__":
test_conv2d()
test_grouped_conv2d()
test_max_pool2d()
test_avg_pool2d()
test_global_max_pool2d()
test_global_avg_pool2d()
......@@ -23,7 +23,7 @@ def default_ctx():
    else:
        return tvm.cpu(0)
def test_mxnet_frontend_impl(mx_symbol, data_shape=(2, 3, 224, 224), out_shape=(2, 1000)):
def test_mxnet_frontend_impl(mx_symbol, data_shape=(1, 3, 224, 224), out_shape=(1, 1000)):
    def get_mxnet_output(symbol, x, dtype='float32'):
        from collections import namedtuple
        Batch = namedtuple('Batch', ['data'])
......@@ -83,6 +83,5 @@ def test_forward_resnet():
if __name__ == '__main__':
    test_forward_mlp()
    # waiting for max_pool2d
    # test_forward_vgg()
    # test_forward_resnet()
    test_forward_vgg()
    test_forward_resnet()