Commit 215693df by Tianqi Chen

[TOP] Add dense, batchnorm (#22)

* [TOP] Add dense, batchnorm

* update tvm
parent a2ab3d83
...@@ -44,11 +44,14 @@ using TOpPattern = int; ...@@ -44,11 +44,14 @@ using TOpPattern = int;
* \brief Computation description interface * \brief Computation description interface
* \param attrs The attribute of the node. * \param attrs The attribute of the node.
* \param inputs The input tensors(placeholders) * \param inputs The input tensors(placeholders)
* \param out_info Tensors holding shape/type information about output,
& these are always placeholders.
* \return The output description of the tensor. * \return The output description of the tensor.
*/ */
using FTVMCompute = std::function< using FTVMCompute = std::function<
Array<Tensor> Array<Tensor>(const NodeAttrs& attrs,
(const NodeAttrs& attrs, const Array<Tensor>& inputs)>; const Array<Tensor>& inputs,
const Array<Tensor>& out_info)>;
/*! /*!
* \brief Build the computation schedule for * \brief Build the computation schedule for
......
...@@ -115,9 +115,12 @@ def optimize(graph, shape, dtype="float32"): ...@@ -115,9 +115,12 @@ def optimize(graph, shape, dtype="float32"):
""" """
# pylint: disable=unused-argument # pylint: disable=unused-argument
cfg = BuildConfig.current cfg = BuildConfig.current
graph = graph_attr.set_shape_inputs(graph, shape)
graph = graph.apply("InferShape")
if graph.json_attr("shape_num_unknown_nodes"):
raise ValueError("InferShape fails..")
if cfg.opt_level >= OPT_PASS_LEVEL["SimplifyBatchNormInference"]: if cfg.opt_level >= OPT_PASS_LEVEL["SimplifyBatchNormInference"]:
graph = graph_attr.set_shape_inputs(graph, shape) graph = graph.apply("SimplifyBatchNormInference")
graph = graph.apply(["InferShape", "SimplifyBatchNormInference"])
return graph return graph
...@@ -164,6 +167,12 @@ def build(graph, target, shape, dtype="float32", params=None): ...@@ -164,6 +167,12 @@ def build(graph, target, shape, dtype="float32", params=None):
cfg = BuildConfig.current cfg = BuildConfig.current
graph = graph if isinstance(graph, _graph.Graph) else _graph.create(graph) graph = graph if isinstance(graph, _graph.Graph) else _graph.create(graph)
shape, dtype = _update_shape_dtype(shape, dtype, params) shape, dtype = _update_shape_dtype(shape, dtype, params)
# Initial pass do shape type inference
ishape, _ = graph_util.infer_shape(graph, **shape)
shape.update(zip(graph.index.input_names, ishape))
if not isinstance(dtype, str):
idtype, _ = graph_util.infer_dtype(graph, **dtype)
dtype.update(zip(graph.index.input_names, idtype))
# Apply optimization # Apply optimization
graph = optimize(graph, shape, dtype) graph = optimize(graph, shape, dtype)
# Precompute prune # Precompute prune
......
...@@ -5,8 +5,10 @@ import tvm ...@@ -5,8 +5,10 @@ import tvm
class OpPattern(object): class OpPattern(object):
ELEM_WISE = 0 ELEM_WISE = 0
BROADCAST = 1 BROADCAST = 1
# Complex means we can fuse elemwise to it
COMPLEX = 2 COMPLEX = 2
EXTERN = 2 # Extern means the op is not fusable
EXTERN = 3
_register_compute = tvm.get_global_func("nnvm._register_compute") _register_compute = tvm.get_global_func("nnvm._register_compute")
_register_schedule = tvm.get_global_func("nnvm._register_schedule") _register_schedule = tvm.get_global_func("nnvm._register_schedule")
......
...@@ -2,3 +2,4 @@ ...@@ -2,3 +2,4 @@
from .attr_dict import AttrDict from .attr_dict import AttrDict
from . import tensor from . import tensor
from . import nn from . import nn
from . import transform
# pylint: disable=invalid-name, unused-argument
"""Definition of nn ops""" """Definition of nn ops"""
from __future__ import absolute_import from __future__ import absolute_import
import tvm import tvm
import topi import topi
from topi.util import get_const_int from topi.util import get_const_int
from .tensor import schedule_elemwise from .tensor import _fschedule_broadcast
from ..compiler import registry as reg from ..compiler import registry as reg
from ..compiler import OpPattern from ..compiler import OpPattern
# relu # relu
@reg.register_compute("relu") @reg.register_compute("relu")
def compute_relu(_, inputs): def compute_relu(attrs, inputs, _):
"""Compute definition of relu""" """Compute definition of relu"""
return topi.nn.relu(inputs[0]) return topi.nn.relu(inputs[0])
@reg.register_schedule("relu") reg.register_schedule("relu", _fschedule_broadcast)
def schedule_relu(_, outs, target):
"""Schedule definition of relu"""
return schedule_elemwise(_, outs, target)
reg.register_pattern("relu", OpPattern.ELEM_WISE) reg.register_pattern("relu", OpPattern.ELEM_WISE)
# flatten
@reg.register_compute("flatten")
def compute_flatten(attrs, inputs, _):
"""Compute definition of flatten"""
return topi.nn.flatten(inputs[0])
reg.register_schedule("flatten", _fschedule_broadcast)
reg.register_pattern("flatten", OpPattern.COMPLEX)
# softmax # softmax
@reg.register_compute("softmax") @reg.register_compute("softmax")
def compute_softmax(attrs, inputs): def compute_softmax(attrs, inputs, _):
"""Compute definition of softmax""" """Compute definition of softmax"""
axis = attrs.get_int("axis") axis = attrs.get_int("axis")
assert axis == -1, "only support axis == -1 for now" assert axis == -1, "only support axis == -1 for now"
...@@ -38,12 +45,34 @@ def schedule_softmax(_, outs, target): ...@@ -38,12 +45,34 @@ def schedule_softmax(_, outs, target):
# naive schedule # naive schedule
return tvm.create_schedule([x.op for x in outs]) return tvm.create_schedule([x.op for x in outs])
reg.register_pattern("softmax", OpPattern.COMPLEX) # Mark softmax as extern as we do not fuse it in call cases
reg.register_pattern("softmax", OpPattern.EXTERN)
# dense
@reg.register_compute("dense")
def compute_dense(attrs, inputs, _):
"""Compute definition of dense"""
if attrs.get_bool("use_bias"):
return topi.nn.fully_connected_with_bias(
inputs[0], inputs[1], inputs[2])
return topi.nn.fully_connected(inputs[0], inputs[1])
@reg.register_schedule("dense")
def schedule_dense(_, outs, target):
"""Schedule definition of dense"""
if target == "cuda":
raise ValueError("fully_connected not yet implemented")
# naive schedule
return tvm.create_schedule([x.op for x in outs])
# register extern for now, change me when fusion is enabled.
reg.register_pattern("dense", OpPattern.EXTERN)
# conv # conv
@reg.register_compute("conv2d") @reg.register_compute("conv2d")
def compute_conv2d(attrs, inputs): def compute_conv2d(attrs, inputs, _):
"""Compute definition of conv2d""" """Compute definition of conv2d"""
padding = attrs.get_int_tuple("padding") padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides") strides = attrs.get_int_tuple("strides")
......
# pylint: disable=invalid-name # pylint: disable=invalid-name, unused-argument
"""Tensor ops""" """Tensor ops"""
from __future__ import absolute_import from __future__ import absolute_import
...@@ -8,15 +8,6 @@ import topi.cuda ...@@ -8,15 +8,6 @@ import topi.cuda
from ..compiler import registry as reg from ..compiler import registry as reg
from ..compiler import OpPattern from ..compiler import OpPattern
def schedule_elemwise(_, outs, target):
"""Generic schedule for elemwise operation"""
if target == "cuda":
return topi.cuda.schedule_elemwise(outs)
assert target.startswith("llvm")
s = tvm.create_schedule([x.op for x in outs])
tvm.schedule.AutoInlineInjective(s)
return s
def _schedule_broadcast(_, outs, target): def _schedule_broadcast(_, outs, target):
"""Generic schedule for binary bcast""" """Generic schedule for binary bcast"""
if target == "cuda": if target == "cuda":
...@@ -29,7 +20,7 @@ def _schedule_broadcast(_, outs, target): ...@@ -29,7 +20,7 @@ def _schedule_broadcast(_, outs, target):
def _compute_binary_scalar(f): def _compute_binary_scalar(f):
"""auxiliary function""" """auxiliary function"""
@tvm.tag_scope("ewise") @tvm.tag_scope("ewise")
def _compute(attrs, x): def _compute(attrs, x, _):
x = x[0] x = x[0]
scalar = attrs.get_float("scalar") scalar = attrs.get_float("scalar")
scalar = tvm.const(scalar, x.dtype) scalar = tvm.const(scalar, x.dtype)
...@@ -37,58 +28,132 @@ def _compute_binary_scalar(f): ...@@ -37,58 +28,132 @@ def _compute_binary_scalar(f):
return _compute return _compute
def _compute_unary(f):
"""auxiliary function"""
def _compute(attrs, x, _):
return f(x[0])
return _compute
def _compute_binary(f):
"""auxiliary function"""
def _compute(attrs, x, _):
return f(x[0], x[1])
return _compute
_fschedule_broadcast = tvm.convert(_schedule_broadcast) _fschedule_broadcast = tvm.convert(_schedule_broadcast)
# exp # exp
reg.register_compute("exp", reg.register_compute("exp", _compute_unary(topi.exp))
lambda _, x: topi.exp(x[0]))
reg.register_pattern("exp", OpPattern.ELEM_WISE) reg.register_pattern("exp", OpPattern.ELEM_WISE)
reg.register_schedule("exp", _fschedule_broadcast) reg.register_schedule("exp", _fschedule_broadcast)
# sqrt
reg.register_compute("sqrt", _compute_unary(topi.sqrt))
reg.register_pattern("sqrt", OpPattern.ELEM_WISE)
reg.register_schedule("sqrt", _fschedule_broadcast)
# log # log
reg.register_compute("log", reg.register_compute("log", _compute_unary(topi.log))
lambda _, x: topi.log(x[0]))
reg.register_pattern("log", OpPattern.ELEM_WISE) reg.register_pattern("log", OpPattern.ELEM_WISE)
reg.register_schedule("log", _fschedule_broadcast) reg.register_schedule("log", _fschedule_broadcast)
# tanh # tanh
reg.register_compute("tanh", reg.register_compute("tanh", _compute_unary(topi.tanh))
lambda _, x: topi.tanh(x[0]))
reg.register_pattern("tanh", OpPattern.ELEM_WISE) reg.register_pattern("tanh", OpPattern.ELEM_WISE)
reg.register_schedule("tanh", _fschedule_broadcast) reg.register_schedule("tanh", _fschedule_broadcast)
# negative
reg.register_compute("negative", _compute_unary(topi.negative))
reg.register_pattern("negative", OpPattern.ELEM_WISE)
reg.register_schedule("negative", _fschedule_broadcast)
# sigmoid # sigmoid
reg.register_compute("sigmoid", reg.register_compute("sigmoid", _compute_unary(topi.sigmoid))
lambda _, x: topi.sigmoid(x[0]))
reg.register_pattern("sigmoid", OpPattern.ELEM_WISE) reg.register_pattern("sigmoid", OpPattern.ELEM_WISE)
reg.register_schedule("sigmoid", _fschedule_broadcast) reg.register_schedule("sigmoid", _fschedule_broadcast)
# add scalar # add_scalar
reg.register_compute("__add_scalar__", reg.register_compute("__add_scalar__",
_compute_binary_scalar(lambda x, y: x + y)) _compute_binary_scalar(lambda x, y: x + y))
reg.register_pattern("__add_scalar__", OpPattern.ELEM_WISE) reg.register_pattern("__add_scalar__", OpPattern.ELEM_WISE)
reg.register_schedule("__add_scalar__", _fschedule_broadcast) reg.register_schedule("__add_scalar__", _fschedule_broadcast)
# sub_calar
reg.register_compute("__sub_scalar__",
_compute_binary_scalar(lambda x, y: x - y))
reg.register_pattern("__sub_scalar__", OpPattern.ELEM_WISE)
reg.register_schedule("__sub_scalar__", _fschedule_broadcast)
# rsub_scalar
reg.register_compute("__rsub_scalar__",
_compute_binary_scalar(lambda x, y: y - x))
reg.register_pattern("__rsub_scalar__", OpPattern.ELEM_WISE)
reg.register_schedule("__rsub_scalar__", _fschedule_broadcast)
# mul_scalar
reg.register_compute("__mul_scalar__",
_compute_binary_scalar(lambda x, y: x * y))
reg.register_pattern("__mul_scalar__", OpPattern.ELEM_WISE)
reg.register_schedule("__mul_scalar__", _fschedule_broadcast)
# div_scalar
reg.register_compute("__div_scalar__",
_compute_binary_scalar(lambda x, y: x / y))
reg.register_pattern("__div_scalar__", OpPattern.ELEM_WISE)
reg.register_schedule("__div_scalar__", _fschedule_broadcast)
# rdiv_scalar
reg.register_compute("__rdiv_scalar__",
_compute_binary_scalar(lambda x, y: y / x))
reg.register_pattern("__rdiv_scalar__", OpPattern.ELEM_WISE)
reg.register_schedule("__rdiv_scalar__", _fschedule_broadcast)
# elemwise_add
reg.register_compute("elemwise_add", _compute_binary(topi.broadcast_add))
reg.register_pattern("elemwise_add", OpPattern.BROADCAST)
reg.register_schedule("elemwise_add", _fschedule_broadcast)
# elemwise_sub
reg.register_compute("elemwise_sub", _compute_binary(topi.broadcast_sub))
reg.register_pattern("elemwise_sub", OpPattern.BROADCAST)
reg.register_schedule("elemwise_sub", _fschedule_broadcast)
# elemwise_mul
reg.register_compute("elemwise_mul", _compute_binary(topi.broadcast_mul))
reg.register_pattern("elemwise_mul", OpPattern.BROADCAST)
reg.register_schedule("elemwise_mul", _fschedule_broadcast)
# elemwise_div
reg.register_compute("elemwise_div", _compute_binary(topi.broadcast_div))
reg.register_pattern("elemwise_div", OpPattern.BROADCAST)
reg.register_schedule("elemwise_div", _fschedule_broadcast)
# broadcast_add # broadcast_add
reg.register_compute("broadcast_add", reg.register_compute("broadcast_add", _compute_binary(topi.broadcast_add))
lambda _, x: topi.broadcast_add(x[0], x[1]))
reg.register_pattern("broadcast_add", OpPattern.BROADCAST) reg.register_pattern("broadcast_add", OpPattern.BROADCAST)
reg.register_schedule("broadcast_add", _fschedule_broadcast) reg.register_schedule("broadcast_add", _fschedule_broadcast)
# broadcast_sub # broadcast_sub
reg.register_compute("broadcast_sub", reg.register_compute("broadcast_sub", _compute_binary(topi.broadcast_sub))
lambda _, x: topi.broadcast_sub(x[0], x[1]))
reg.register_pattern("broadcast_sub", OpPattern.BROADCAST) reg.register_pattern("broadcast_sub", OpPattern.BROADCAST)
reg.register_schedule("broadcast_sub", _fschedule_broadcast) reg.register_schedule("broadcast_sub", _fschedule_broadcast)
# broadcast_mul # broadcast_mul
reg.register_compute("broadcast_mul", reg.register_compute("broadcast_mul", _compute_binary(topi.broadcast_mul))
lambda _, x: topi.broadcast_mul(x[0], x[1]))
reg.register_pattern("broadcast_mul", OpPattern.BROADCAST) reg.register_pattern("broadcast_mul", OpPattern.BROADCAST)
reg.register_schedule("broadcast_mul", _fschedule_broadcast) reg.register_schedule("broadcast_mul", _fschedule_broadcast)
# broadcast_div # broadcast_div
reg.register_compute("broadcast_div", reg.register_compute("broadcast_div", _compute_binary(topi.broadcast_div))
lambda _, x: topi.broadcast_div(x[0], x[1]))
reg.register_pattern("broadcast_div", OpPattern.BROADCAST) reg.register_pattern("broadcast_div", OpPattern.BROADCAST)
reg.register_schedule("broadcast_div", _fschedule_broadcast) reg.register_schedule("broadcast_div", _fschedule_broadcast)
# broadcast_to
@reg.register_compute("broadcast_to")
def compute_softmax(attrs, inputs, out_info):
"""Compute definition of softmax"""
return topi.broadcast_to(inputs[0], shape=out_info[0].shape)
reg.register_pattern("broadcast_to", OpPattern.BROADCAST)
reg.register_schedule("broadcast_to", _fschedule_broadcast)
# pylint: disable=invalid-name, unused-argument
"""Tensor transformation ops"""
from __future__ import absolute_import
import tvm
from .tensor import _fschedule_broadcast
from ..compiler import registry as reg
from ..compiler import OpPattern
# Need add reshape, transpose
def _flatten_index(indices, shape):
"""flatten the index to 1D"""
idx = 0
for i, value in enumerate(shape):
if i != 0:
idx *= value
idx = idx + indices[i]
return idx
# reshape
@reg.register_compute("reshape")
def compute_reshape(attrs, inputs, out_info):
"""Compute definition of softmax"""
# TODO(sxj) add support for general reshape
assert len(inputs[0].shape) == 1, "Only support 1d input for now"
oshape = out_info[0].shape
x = inputs[0]
return tvm.compute(oshape, lambda *i: x(_flatten_index(i, oshape)))
reg.register_pattern("reshape", OpPattern.COMPLEX)
reg.register_schedule("reshape", _fschedule_broadcast)
...@@ -261,7 +261,7 @@ nnvm::Graph GraphFuse(nnvm::Graph g) { ...@@ -261,7 +261,7 @@ nnvm::Graph GraphFuse(nnvm::Graph g) {
if (inode.source->is_variable()) continue; if (inode.source->is_variable()) continue;
int root_id = group_vec[nid]; int root_id = group_vec[nid];
FuseEntry& fe = fuse_vec[root_id]; FuseEntry& fe = fuse_vec[root_id];
Array<Tensor> inputs; Array<Tensor> inputs, out_info;
// input loading // input loading
for (const auto& e : inode.inputs) { for (const auto& e : inode.inputs) {
if (group_vec[e.node_id] != root_id) { if (group_vec[e.node_id] != root_id) {
...@@ -274,11 +274,21 @@ nnvm::Graph GraphFuse(nnvm::Graph g) { ...@@ -274,11 +274,21 @@ nnvm::Graph GraphFuse(nnvm::Graph g) {
inputs.push_back(t); inputs.push_back(t);
} }
} }
// output hint
for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) {
Array<Expr> shape;
for (int64_t x : shape_vec[idx.entry_id(nid, i)]) {
CHECK_LE(x, static_cast<int64_t>(std::numeric_limits<int>::max()));
shape.push_back(make_const(Int(32), x));
}
out_info.push_back(
placeholder(shape,
TVMType2Type(dltype_vec[idx.entry_id(nid, i)])));
}
// get default // get default
Array<Tensor> out = fcompute[inode.source->op()]( Array<Tensor> out = fcompute[inode.source->op()](
inode.source->attrs, inputs); inode.source->attrs, inputs, out_info);
CHECK_EQ(out.size(), inode.source->num_outputs()); CHECK_EQ(out.size(), inode.source->num_outputs());
// schedule on root node, and use master's schedule // schedule on root node, and use master's schedule
if (nid != root_id) { if (nid != root_id) {
for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) { for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) {
...@@ -312,6 +322,7 @@ nnvm::Graph GraphFuse(nnvm::Graph g) { ...@@ -312,6 +322,7 @@ nnvm::Graph GraphFuse(nnvm::Graph g) {
} }
} }
} }
tvm::runtime::Module module = fbuild(funcs, target); tvm::runtime::Module module = fbuild(funcs, target);
// Final step: Remap the node, with given attribute // Final step: Remap the node, with given attribute
const nnvm::Op* tvm_op = nnvm::Op::Get("tvm_op"); const nnvm::Op* tvm_op = nnvm::Op::Get("tvm_op");
......
...@@ -67,9 +67,11 @@ TVM_REGISTER_GLOBAL("nnvm._register_compute") ...@@ -67,9 +67,11 @@ TVM_REGISTER_GLOBAL("nnvm._register_compute")
// Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown // Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown
PackedFunc* f = new PackedFunc(args[1].operator PackedFunc()); PackedFunc* f = new PackedFunc(args[1].operator PackedFunc());
Op& op = ::dmlc::Registry<nnvm::Op>::Get()->__REGISTER_OR_GET__(args[0]); Op& op = ::dmlc::Registry<nnvm::Op>::Get()->__REGISTER_OR_GET__(args[0]);
auto fcompute = [f](const NodeAttrs& attrs, const Array<Tensor>& inputs) auto fcompute = [f](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info)
-> Array<Tensor> { -> Array<Tensor> {
TVMRetValue ret = (*f)(GetAttrDict(attrs), inputs); TVMRetValue ret = (*f)(GetAttrDict(attrs), inputs, out_info);
if ((*ret.ptr<std::shared_ptr<tvm::Node> >())->derived_from<tvm::TensorNode>()) { if ((*ret.ptr<std::shared_ptr<tvm::Node> >())->derived_from<tvm::TensorNode>()) {
return {ret.operator Tensor()}; return {ret.operator Tensor()};
} else { } else {
......
...@@ -21,7 +21,7 @@ BatchNormToInferUnpack(const nnvm::NodeAttrs& attrs, ...@@ -21,7 +21,7 @@ BatchNormToInferUnpack(const nnvm::NodeAttrs& attrs,
nnvm::NodeEntry beta, nnvm::NodeEntry beta,
nnvm::NodeEntry moving_mean, nnvm::NodeEntry moving_mean,
nnvm::NodeEntry moving_var, nnvm::NodeEntry moving_var,
int data_dim) { TShape dshape) {
CHECK(attrs.op); CHECK(attrs.op);
static const Op* bn_op = Op::Get("batch_norm"); static const Op* bn_op = Op::Get("batch_norm");
CHECK(attrs.op == bn_op); CHECK(attrs.op == bn_op);
...@@ -57,19 +57,12 @@ BatchNormToInferUnpack(const nnvm::NodeAttrs& attrs, ...@@ -57,19 +57,12 @@ BatchNormToInferUnpack(const nnvm::NodeAttrs& attrs,
shift = MakeNode( shift = MakeNode(
"elemwise_add", bn_name + "_add_beta", {shift, beta}); "elemwise_add", bn_name + "_add_beta", {shift, beta});
} }
// reshape to nhwc // use broaodcast to reshape
std::ostringstream oshape; std::ostringstream oshape;
oshape << "("; for (dim_t i = 0; i < dshape.ndim(); ++i) {
for (int i = 0; i < data_dim; ++i) { dshape[i] = (i != param.axis) ? 1 : -1;
if (i != 0) oshape << ", ";
if (i == param.axis) {
oshape << "-1";
} else {
oshape << "1";
}
} }
oshape << ")"; oshape << dshape;
scale = MakeNode("reshape", bn_name + "_sc_reshape", scale = MakeNode("reshape", bn_name + "_sc_reshape",
{scale}, {{"shape", oshape.str()}}); {scale}, {{"shape", oshape.str()}});
shift = MakeNode("reshape", bn_name + "_sh_reshape", shift = MakeNode("reshape", bn_name + "_sh_reshape",
...@@ -98,7 +91,7 @@ Graph SimplifyBatchNormInference(nnvm::Graph src) { ...@@ -98,7 +91,7 @@ Graph SimplifyBatchNormInference(nnvm::Graph src) {
n->inputs[2], n->inputs[2],
n->inputs[3], n->inputs[3],
n->inputs[4], n->inputs[4],
shape_vec[idx.entry_id(nid, 0)].ndim()); shape_vec[idx.entry_id(nid, 0)]);
return true; return true;
} else { } else {
return false; return false;
......
...@@ -73,7 +73,7 @@ void PrintGraphIR_(Graph src, ...@@ -73,7 +73,7 @@ void PrintGraphIR_(Graph src,
AttrPrinter fp = GetVectorPrinter(src, key); AttrPrinter fp = GetVectorPrinter(src, key);
auto fprint = [&idx, key, fp]( auto fprint = [&idx, key, fp](
uint32_t nid, std::ostream& os) { // NOLINT(*) uint32_t nid, std::ostream& os) { // NOLINT(*)
os << key << "="; os << ", " << key << "=";
fp(idx.entry_id(nid, 0), os); fp(idx.entry_id(nid, 0), os);
}; };
trigger.push_back(fprint); trigger.push_back(fprint);
......
...@@ -5,13 +5,13 @@ from nnvm.compiler import graph_util, graph_attr ...@@ -5,13 +5,13 @@ from nnvm.compiler import graph_util, graph_attr
def test_simplify_batchnorm(): def test_simplify_batchnorm():
def simple_bn(x, gamma, beta, moving_mean, moving_var, def simple_bn(x, gamma, beta, moving_mean, moving_var,
axis=1, epsilon=1e-5, dim=2): axis=1, epsilon=1e-5, shape=None):
# expect = (x - moving_mean) / sym.sqrt(moving_var + eps) * gamma + beta # expect = (x - moving_mean) / sym.sqrt(moving_var + eps) * gamma + beta
scale = sym.elemwise_mul(1 / sym.sqrt(moving_var + epsilon), gamma) scale = sym.elemwise_mul(1 / sym.sqrt(moving_var + epsilon), gamma)
shift = sym.elemwise_add( shift = sym.elemwise_add(
sym.elemwise_mul(sym.negative(moving_mean), scale), beta) sym.elemwise_mul(sym.negative(moving_mean), scale), beta)
shape = [-1 if i == axis else 1 for i in range(len(shape))]
# for 2D # for 2D
shape = tuple(1 if i != axis else -1 for i in range(dim))
scale = sym.reshape(scale, shape=shape) scale = sym.reshape(scale, shape=shape)
shift = sym.reshape(shift, shape=shape) shift = sym.reshape(shift, shape=shape)
return x * scale + shift return x * scale + shift
...@@ -26,15 +26,14 @@ def test_simplify_batchnorm(): ...@@ -26,15 +26,14 @@ def test_simplify_batchnorm():
moving_var = sym.Variable("moving_var") moving_var = sym.Variable("moving_var")
moving_mean = sym.Variable("moving_mean") moving_mean = sym.Variable("moving_mean")
y1, y2 = x, x y1, y2 = x, x
ishape = {"x": tuple(10 for i in range(dim))}
for i in range(nstep): for i in range(nstep):
y1 = sym.batch_norm( y1 = sym.batch_norm(
y1 + 1, gamma, beta, moving_mean, moving_var, epsilon=eps, axis=axis) y1 + 1, gamma, beta, moving_mean, moving_var, epsilon=eps, axis=axis)
y2 = simple_bn(y2 + 1, gamma, beta, moving_mean, moving_var, y2 = simple_bn(y2 + 1, gamma, beta, moving_mean, moving_var,
epsilon=eps, axis=axis, dim=dim) epsilon=eps, axis=axis, shape=ishape["x"])
g = nnvm.graph.create(y1) g = nnvm.graph.create(y1)
g2 = nnvm.graph.create(y2) g2 = nnvm.graph.create(y2)
ishape = {"x": tuple(10 for i in range(dim))}
graph_attr.set_shape_inputs(g, ishape) graph_attr.set_shape_inputs(g, ishape)
g1 = g.apply("InferShape").apply("SimplifyBatchNormInference") g1 = g.apply("InferShape").apply("SimplifyBatchNormInference")
# Some prints for debug # Some prints for debug
......
...@@ -6,19 +6,9 @@ import nnvm.symbol as sym ...@@ -6,19 +6,9 @@ import nnvm.symbol as sym
import nnvm.compiler import nnvm.compiler
import nnvm.runtime import nnvm.runtime
USE_GPU=True def ctx_list():
res = [("llvm", tvm.cpu(0)), ("cuda", tvm.gpu(0))]
def default_target(): return [x for x in res if x[1].exist]
if USE_GPU:
return 'cuda'
else:
return 'llvm'
def default_ctx():
if USE_GPU:
return tvm.gpu(0)
else:
return tvm.cpu(0)
def test_conv2d(): def test_conv2d():
x = sym.Variable("x") x = sym.Variable("x")
...@@ -29,23 +19,24 @@ def test_conv2d(): ...@@ -29,23 +19,24 @@ def test_conv2d():
kshape = (10, 3, 3, 3) kshape = (10, 3, 3, 3)
oshape = (1, 10, 18, 18) oshape = (1, 10, 18, 18)
shape_dict = {"x": dshape} shape_dict = {"x": dshape}
graph, lib, _ = nnvm.compiler.build(y, default_target(), shape_dict) for target, ctx in ctx_list():
m = nnvm.runtime.create(graph, lib, default_ctx()) graph, lib, _ = nnvm.compiler.build(y, target, shape_dict)
# get member functions m = nnvm.runtime.create(graph, lib, ctx)
set_input, run, get_output = m["set_input"], m["run"], m["get_output"] # get member functions
# set input set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype)) # set input
kernel = tvm.nd.array(np.random.uniform(size=kshape).astype(dtype)) data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
set_input("x", data) kernel = tvm.nd.array(np.random.uniform(size=kshape).astype(dtype))
set_input("y_weight", kernel) set_input("x", data)
# execute set_input("y_weight", kernel)
run() # execute
# get output run()
out = tvm.nd.empty(oshape, dtype) # get output
get_output(0, out) out = tvm.nd.empty(oshape, dtype)
c_np = topi.testing.conv2d_nchw_python( get_output(0, out)
data.asnumpy(), kernel.asnumpy(), 1, 1) c_np = topi.testing.conv2d_nchw_python(
np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5) data.asnumpy(), kernel.asnumpy(), 1, 1)
np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
def test_grouped_conv2d(): def test_grouped_conv2d():
...@@ -57,23 +48,24 @@ def test_grouped_conv2d(): ...@@ -57,23 +48,24 @@ def test_grouped_conv2d():
kshape = (32, 1, 3, 3) kshape = (32, 1, 3, 3)
oshape = (1, 32, 18, 18) oshape = (1, 32, 18, 18)
shape_dict = {"x": dshape} shape_dict = {"x": dshape}
graph, lib, _ = nnvm.compiler.build(y, default_target(), shape_dict) for target, ctx in ctx_list():
m = nnvm.runtime.create(graph, lib, default_ctx()) graph, lib, _ = nnvm.compiler.build(y, target, shape_dict)
# get member functions m = nnvm.runtime.create(graph, lib, ctx)
set_input, run, get_output = m["set_input"], m["run"], m["get_output"] # get member functions
# set input set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype)) # set input
kernel = tvm.nd.array(np.random.uniform(size=kshape).astype(dtype)) data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
set_input("x", data) kernel = tvm.nd.array(np.random.uniform(size=kshape).astype(dtype))
set_input("y_weight", kernel) set_input("x", data)
# execute set_input("y_weight", kernel)
run() # execute
# get output run()
out = tvm.nd.empty(oshape, dtype) # get output
get_output(0, out) out = tvm.nd.empty(oshape, dtype)
c_np = topi.testing.depthwise_conv2d_python_nchw( get_output(0, out)
data.asnumpy(), kernel.asnumpy(), (1,1), 'SAME') c_np = topi.testing.depthwise_conv2d_python_nchw(
np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5) data.asnumpy(), kernel.asnumpy(), (1,1), 'SAME')
np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment