Commit 948f6898 by Yuwei HU Committed by Tianqi Chen

register depthconv, elemwise (#17)

* register depthconv, elemwise

* use global elemwise schedule for relu
parent 1bc5d0ad
......@@ -3,9 +3,44 @@ from __future__ import absolute_import
import tvm
import topi
from topi.util import get_const_int
from .tensor import schedule_elemwise
from ..compiler import registry as reg
from ..compiler import OpPattern
# relu
@reg.register_compute("relu")
def compute_relu(attrs, inputs):
"""Compute definition of relu"""
return topi.nn.relu(inputs[0])
@reg.register_schedule("relu")
def schedule_relu(_, outs, target):
"""Schedule definition of relu"""
return schedule_elemwise(_, outs, target)
reg.register_pattern("relu", OpPattern.ELEM_WISE)
# softmax
@reg.register_compute("softmax")
def compute_softmax(attrs, inputs):
"""Compute definition of softmax"""
axis = attrs.get_int("axis")
assert axis == -1, "only support axis == -1 for now"
return topi.nn.softmax(inputs[0])
@reg.register_schedule("softmax")
def schedule_softmax(_, outs, target):
"""Schedule definition of softmax"""
if target == "cuda":
return topi.cuda.schedule_softmax(outs)
# naive schedule
return tvm.create_schedule([x.op for x in outs])
reg.register_pattern("softmax", OpPattern.COMPLEX)
# conv
@reg.register_compute("conv2d")
def compute_conv2d(attrs, inputs):
......@@ -13,10 +48,17 @@ def compute_conv2d(attrs, inputs):
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation")
groups = attrs.get_int("groups")
channels = attrs.get_int("channels")
layout = attrs["layout"]
assert layout == "NCHW", "only support nchw for now"
assert dilation == (1, 1), "not support dilate now"
out = topi.nn.conv2d_nchw(inputs[0], inputs[1], strides, padding)
if groups == 1:
out = topi.nn.conv2d_nchw(inputs[0], inputs[1], strides, padding)
elif groups == get_const_int(inputs[0].shape[1]) and groups == channels:
out = topi.nn.depthwise_conv2d_nchw(inputs[0], inputs[1], strides, padding)
else:
raise ValueError("not support arbitrary group number for now")
if attrs.get_bool("use_bias"):
bias = inputs[2]
bias = topi.broadcast_to(bias, (1, bias.shape[0], 1, 1))
......@@ -24,30 +66,15 @@ def compute_conv2d(attrs, inputs):
return out
@reg.register_schedule("conv2d")
def schedule_conv2d(_, outs, target):
def schedule_conv2d(attrs, outs, target):
"""Schedule definition of conv2d"""
groups = attrs.get_int("groups")
if target == "cuda":
return topi.cuda.schedule_conv2d_nchw(outs)
if groups == 1:
return topi.cuda.schedule_conv2d_nchw(outs)
else:
return topi.cuda.schedule_depthwise_conv2d_nchw(outs)
# naive schedule
return tvm.create_schedule([x.op for x in outs])
reg.register_pattern("conv2d", OpPattern.COMPLEX)
# softmax
@reg.register_compute("softmax")
def compute_softmax(attrs, inputs):
"""Compute definition of softmax"""
axis = attrs.get_int("axis")
assert axis == -1, "only support axis == -1 for now"
return topi.nn.softmax(inputs[0])
@reg.register_schedule("softmax")
def schedule_softmax(_, outs, target):
"""Schedule definition of softmax"""
if target == "cuda":
return topi.cuda.schedule_softmax(outs)
# naive schedule
return tvm.create_schedule([x.op for x in outs])
reg.register_pattern("softmax", OpPattern.COMPLEX)
......@@ -8,6 +8,15 @@ import topi.cuda
from ..compiler import registry as reg
from ..compiler import OpPattern
def schedule_elemwise(_, outs, target):
"""Generic schedule for elemwise operation"""
if target == "cuda":
return topi.cuda.schedule_elemwise(outs)
assert target.startswith("llvm")
s = tvm.create_schedule([x.op for x in outs])
tvm.schedule.AutoInlineInjective(s)
return s
def _schedule_broadcast(_, outs, target):
"""Generic schedule for binary bcast"""
if target == "cuda":
......@@ -36,6 +45,24 @@ reg.register_compute("exp",
reg.register_pattern("exp", OpPattern.ELEM_WISE)
reg.register_schedule("exp", _fschedule_broadcast)
# log
reg.register_compute("log",
lambda _, x: topi.log(x[0]))
reg.register_pattern("log", OpPattern.ELEM_WISE)
reg.register_schedule("log", _fschedule_broadcast)
# tanh
reg.register_compute("tanh",
lambda _, x: topi.tanh(x[0]))
reg.register_pattern("tanh", OpPattern.ELEM_WISE)
reg.register_schedule("tanh", _fschedule_broadcast)
# sigmoid
reg.register_compute("sigmoid",
lambda _, x: topi.sigmoid(x[0]))
reg.register_pattern("sigmoid", OpPattern.ELEM_WISE)
reg.register_schedule("sigmoid", _fschedule_broadcast)
# add scalar
reg.register_compute("__add_scalar__",
_compute_binary_scalar(lambda x, y: x + y))
......
......@@ -20,6 +20,116 @@ def default_ctx():
else:
return tvm.cpu(0)
def test_relu():
x = sym.Variable("x")
y = sym.relu(x)
dtype = "float32"
dshape = (1, 3, 32, 32)
oshape = dshape
graph, lib = nnvm.compiler.build(y, default_target(), {"x": dshape})
m = nnvm.runtime.create(graph, lib, default_ctx())
# get member functions
set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
# set input
data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
set_input("x", data)
# execute
run()
# get output
out = tvm.nd.empty(oshape, dtype)
get_output(0, out)
y_np = np.maximum(data.asnumpy(), 0.0)
np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
def test_exp():
x = sym.Variable("x")
y = sym.exp(x)
dtype = "float32"
dshape = (1, 3, 32, 32)
oshape = dshape
graph, lib = nnvm.compiler.build(y, default_target(), {"x": dshape})
m = nnvm.runtime.create(graph, lib, default_ctx())
# get member functions
set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
# set input
data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
set_input("x", data)
# execute
run()
# get output
out = tvm.nd.empty(oshape, dtype)
get_output(0, out)
y_np = np.exp(data.asnumpy())
np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
def test_log():
x = sym.Variable("x")
y = sym.log(x)
dtype = "float32"
dshape = (1, 3, 32, 32)
oshape = dshape
graph, lib = nnvm.compiler.build(y, default_target(), {"x": dshape})
m = nnvm.runtime.create(graph, lib, default_ctx())
# get member functions
set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
# set input
data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
set_input("x", data)
# execute
run()
# get output
out = tvm.nd.empty(oshape, dtype)
get_output(0, out)
y_np = np.log(data.asnumpy())
np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
def test_tanh():
x = sym.Variable("x")
y = sym.tanh(x)
dtype = "float32"
dshape = (1, 3, 32, 32)
oshape = dshape
graph, lib = nnvm.compiler.build(y, default_target(), {"x": dshape})
m = nnvm.runtime.create(graph, lib, default_ctx())
# get member functions
set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
# set input
data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
set_input("x", data)
# execute
run()
# get output
out = tvm.nd.empty(oshape, dtype)
get_output(0, out)
y_np = np.sinh(data.asnumpy()) / np.cosh(data.asnumpy())
np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
def test_sigmoid():
x = sym.Variable("x")
y = sym.sigmoid(x)
dtype = "float32"
dshape = (1, 3, 32, 32)
oshape = dshape
graph, lib = nnvm.compiler.build(y, default_target(), {"x": dshape})
m = nnvm.runtime.create(graph, lib, default_ctx())
# get member functions
set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
# set input
data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
set_input("x", data)
# execute
run()
# get output
out = tvm.nd.empty(oshape, dtype)
get_output(0, out)
y_np = 1.0 / (1.0 + np.exp(-data.asnumpy()))
np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
def test_softmax():
x = sym.Variable("x")
y = sym.softmax(x)
......@@ -35,12 +145,17 @@ def test_softmax():
set_input("x", data)
# execute
run()
# get outputs
# get output
out = tvm.nd.empty(oshape, dtype)
get_output(0, out)
y_np = topi.testing.softmax_python(data.asnumpy())
np.testing.assert_allclose(out.asnumpy(), y_np, rtol=1e-5)
np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
if __name__ == "__main__":
test_relu()
test_exp()
test_log()
test_tanh()
test_sigmoid()
test_softmax()
......@@ -6,6 +6,20 @@ import nnvm.symbol as sym
import nnvm.compiler
import nnvm.runtime
USE_GPU=True
def default_target():
if USE_GPU:
return 'cuda'
else:
return 'llvm'
def default_ctx():
if USE_GPU:
return tvm.gpu(0)
else:
return tvm.cpu(0)
def test_conv2d():
x = sym.Variable("x")
y = sym.conv2d(x, channels=10, kernel_size=(3, 3),
......@@ -15,25 +29,53 @@ def test_conv2d():
kshape = (10, 3, 3, 3)
oshape = (1, 10, 18, 18)
shape_dict = {"x": dshape}
graph, lib = nnvm.compiler.build(y, "llvm", shape_dict)
m = nnvm.runtime.create(graph, lib, tvm.cpu(0))
graph, lib = nnvm.compiler.build(y, default_target(), shape_dict)
m = nnvm.runtime.create(graph, lib, default_ctx())
# get member functions
set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
# set input
data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
kernel = tvm.nd.array(np.random.uniform(size=kshape).astype(dtype))
set_input("x", data)
set_input("y_weight", kernel)
# execute
run()
# get output
out = tvm.nd.empty(oshape, dtype)
get_output(0, out)
c_np = topi.testing.conv2d_nchw_python(
data.asnumpy(), kernel.asnumpy(), 1, 1)
np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
def test_grouped_conv2d():
x = sym.Variable("x")
y = sym.conv2d(x, channels=32, kernel_size=(3, 3), groups=32,
name="y", use_bias=False, padding=(1,1))
dtype = "float32"
dshape = (1, 32, 18, 18)
kshape = (32, 1, 3, 3)
oshape = (1, 32, 18, 18)
shape_dict = {"x": dshape}
graph, lib = nnvm.compiler.build(y, default_target(), shape_dict)
m = nnvm.runtime.create(graph, lib, default_ctx())
# get member functions
set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
# set input
data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
kernel = tvm.nd.array(np.random.uniform(size=kshape).astype(dtype))
set_input("x", data)
set_input("y_weight", kernel)
# execute
run()
# get outputs
# get output
out = tvm.nd.empty(oshape, dtype)
get_output(0, out)
c_np = topi.testing.conv2d_nchw_python(
data.asnumpy(), kernel.asnumpy(), 1, 1)
c_np = topi.testing.depthwise_conv2d_python_nchw(
data.asnumpy(), kernel.asnumpy(), (1,1), 'SAME')
np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
if __name__ == "__main__":
test_conv2d()
test_grouped_conv2d()
......@@ -25,6 +25,7 @@ def test_unary():
x = sym.log(x)
x = sym.sigmoid(x)
x = sym.tanh(x)
x = sym.relu(x)
assert x.list_input_names() == ['x']
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment