Commit 2e9b6b99 by Tianqi Chen

[TOP][COMPILER] Add expand_dims, change graph_compare to not compare input optionally (#25)

parent 40bc10f3
......@@ -21,6 +21,17 @@ struct ConcatenateParam : public dmlc::Parameter<ConcatenateParam> {
}
};
struct ExpandDimsParam : public dmlc::Parameter<ExpandDimsParam> {
int axis;
int num_newaxis;
DMLC_DECLARE_PARAMETER(ExpandDimsParam) {
DMLC_DECLARE_FIELD(axis)
.describe("the axis to be expanded.");
DMLC_DECLARE_FIELD(num_newaxis).set_lower_bound(1).set_default(1)
.describe("Number of new axis to be inserted.");
}
};
struct SplitParam : public dmlc::Parameter<SplitParam> {
// numpy convention, only support indices, not support list.
Tuple<int> indices_or_sections;
......
......@@ -2,6 +2,7 @@
"""Namespace for building operators."""
from __future__ import absolute_import as _abs
import logging
import tvm
from . import graph_attr, graph_util
from .. import graph as _graph
......@@ -74,6 +75,7 @@ def build_config(**kwargs):
@tvm.register_func("nnvm.compiler.lower")
def _lower(sch, inputs, func_name):
f = tvm.lower(sch, inputs, name=func_name)
logging.debug("lower function %s", func_name)
return f if isinstance(
f, (tvm.container.Array, tuple, list)) else [f]
......
......@@ -59,7 +59,7 @@ def infer_dtype(graph, **dtype):
_deep_compare = tvm.get_global_func("nnvm.graph.DeepCompare")
def check_graph_equal(grapha, graphb):
def check_graph_equal(grapha, graphb, compare_variable_attrs=False):
"""Check if two graphs have equal structure.
Parameters
......@@ -70,11 +70,16 @@ def check_graph_equal(grapha, graphb):
graphb : Graph
The second graph
compare_variable_attrs : bool, optional
Whether we want to compare attributes(names) on variables.
Usually it is safe to skip it unless we want input name
to exactly match
Raises
------
ValueError
ValueError is raised with error message when graph not equal
"""
err = _deep_compare(grapha, graphb)
err = _deep_compare(grapha, graphb, compare_variable_attrs)
if err:
raise ValueError("Graph compare error: " + err)
"""Utilities for testcase"""
"""Configuration about tests"""
import os
import tvm
def test_ctx_list():
"""Get context list for testcases"""
device_list = os.environ.get("NNVM_TEST_TARGETS", "")
device_list = (device_list.split(",") if device_list
else ["llvm", "cuda"])
device_list = set(device_list)
res = [("llvm", tvm.cpu(0)), ("cuda", tvm.gpu(0))]
return [x for x in res if x[1].exist and x[0] in device_list]
......@@ -90,7 +90,7 @@ def compute_conv2d(attrs, inputs, _):
raise ValueError("not support arbitrary group number for now")
if attrs.get_bool("use_bias"):
bias = inputs[2]
bias = topi.broadcast_to(bias, (1, bias.shape[0], 1, 1))
bias = topi.expand_dims(bias, axis=1, num_newaxis=2)
out = topi.broadcast_add(out, bias)
return out
......
......@@ -115,6 +115,18 @@ reg.register_compute("__rdiv_scalar__",
reg.register_pattern("__rdiv_scalar__", OpPattern.ELEM_WISE)
reg.register_schedule("__rdiv_scalar__", _fschedule_broadcast)
# pow_scalar
reg.register_compute("__pow_scalar__",
_compute_binary_scalar(tvm.power))
reg.register_pattern("__pow_scalar__", OpPattern.ELEM_WISE)
reg.register_schedule("__pow_scalar__", _fschedule_broadcast)
# rpow_scalar
reg.register_compute("__rpow_scalar__",
_compute_binary_scalar(lambda x, y: tvm.power(y, x)))
reg.register_pattern("__rpow_scalar__", OpPattern.ELEM_WISE)
reg.register_schedule("__rpow_scalar__", _fschedule_broadcast)
# elemwise_add
reg.register_compute("elemwise_add", _compute_binary(topi.broadcast_add))
reg.register_pattern("elemwise_add", OpPattern.BROADCAST)
......
......@@ -3,11 +3,21 @@
from __future__ import absolute_import
import tvm
import topi
from .tensor import _fschedule_broadcast
from ..compiler import registry as reg
from ..compiler import OpPattern
# Need add reshape, transpose
@reg.register_compute("expand_dims")
def compute_expand_dims(attrs, inputs, out_info):
"""Compute definition of expand_dims"""
return topi.expand_dims(
inputs[0], attrs.get_int("axis"),
num_newaxis=attrs.get_int("num_newaxis"))
reg.register_pattern("expand_dims", OpPattern.BROADCAST)
reg.register_schedule("expand_dims", _fschedule_broadcast)
def _flatten_index(indices, shape):
"""flatten the index to 1D"""
......
......@@ -16,7 +16,9 @@ namespace compiler {
// not considering the graph attributes
// return non-empty error message if the graph mismatch.
// the comparator won't match name of intermediate node.
std::string DeepCompare(Graph a, Graph b) {
// compare_var_attr
std::string DeepCompare(Graph a, Graph b,
bool compare_variable_attr) {
const IndexedGraph& idxa = a.indexed_graph();
const IndexedGraph& idxb = b.indexed_graph();
std::ostringstream err;
......@@ -51,6 +53,10 @@ std::string DeepCompare(Graph a, Graph b) {
err << "Node mismatch ";
return err.str();
}
if (anode.source->is_variable()) {
CHECK(bnode.source->is_variable());
if (!compare_variable_attr) continue;
}
AttrDict adict = GetAttrDict(anode.source->attrs);
AttrDict bdict = GetAttrDict(bnode.source->attrs);
......@@ -107,7 +113,7 @@ std::string DeepCompare(Graph a, Graph b) {
TVM_REGISTER_GLOBAL("nnvm.graph.DeepCompare")
.set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue *rv) {
*rv = DeepCompare(args[0], args[1]);
*rv = DeepCompare(args[0], args[1], args[2]);
});
} // namespace compiler
} // namespace nnvm
......@@ -58,16 +58,15 @@ BatchNormToInferUnpack(const nnvm::NodeAttrs& attrs,
shift = MakeNode(
"elemwise_add", bn_name + "_add_beta", {shift, beta});
}
// use broaodcast to reshape
std::ostringstream oshape;
for (dim_t i = 0; i < dshape.ndim(); ++i) {
dshape[i] = (i != param.axis) ? 1 : -1;
// use expand dims to pad lower dims to 1
int num_pad_axis = static_cast<int>(dshape.ndim() - param.axis) - 1;
if (num_pad_axis != 0) {
std::unordered_map<std::string, std::string> kwargs{
{"axis", std::to_string(param.axis)},
{"num_newaxis", std::to_string(num_pad_axis)}};
scale = MakeNode("expand_dims", bn_name + "_sc_expand", {scale}, kwargs);
shift = MakeNode("expand_dims", bn_name + "_sh_expand", {shift}, kwargs);
}
oshape << dshape;
scale = MakeNode("reshape", bn_name + "_sc_reshape",
{scale}, {{"shape", oshape.str()}});
shift = MakeNode("reshape", bn_name + "_sh_reshape",
{shift}, {{"shape", oshape.str()}});
NodeEntry out = MakeNode("broadcast_mul", bn_name + "_a_mul_data",
{data, scale});
out = MakeNode("broadcast_add", bn_name + "_out",
......
......@@ -142,8 +142,51 @@ Example::
.set_num_inputs(kVarg)
.set_support_level(1);
// expand_dims
DMLC_REGISTER_PARAMETER(ExpandDimsParam);
inline bool ExpandDimsInferShape(const NodeAttrs& attrs,
std::vector<TShape>* in_shape,
std::vector<TShape>* out_shape) {
const ExpandDimsParam& param = nnvm::get<ExpandDimsParam>(attrs.parsed);
CHECK_EQ(in_shape->size(), 1U);
const TShape& dshape = in_shape->at(0);
int ndim = static_cast<int>(dshape.ndim());
CHECK(param.axis >= -ndim - 1 && param.axis <= ndim);
int axis = param.axis < 0 ? ndim + param.axis + 1 : param.axis;
std::vector<dim_t> oshape;
for (int i = 0; i < axis; ++i) {
oshape.push_back(dshape[i]);
}
for (int i = 0; i < param.num_newaxis; ++i) {
oshape.push_back(1);
}
for (int i = axis; i < ndim; ++i) {
oshape.push_back(dshape[i]);
}
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0,
TShape(oshape.begin(), oshape.end()));
return true;
}
// concatenate
NNVM_REGISTER_OP(expand_dims)
.describe(R"code(Inserts a new axis of size 1 into the array shape
For example, given ``x`` with shape ``(2,3,4)``, then ``expand_dims(x, axis=1)``
will return a new array with shape ``(2,1,3,4)``.
)code" NNVM_ADD_FILELINE)
.add_argument("data", "Tensor", "Input tensor")
.add_arguments(ExpandDimsParam::__FIELDS__())
.set_attr_parser(ParamParser<ExpandDimsParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ExpandDimsParam>)
.set_attr<FInferShape>("FInferShape", ExpandDimsInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_num_inputs(1)
.set_num_outputs(1)
.set_support_level(1);
// split
DMLC_REGISTER_PARAMETER(SplitParam);
inline void SplitParamParser(nnvm::NodeAttrs* attrs) {
......
......@@ -40,9 +40,6 @@ def test_compile():
assert graph.index.num_nodes == 4
verify(graph, lib)
def test_run():
x = sym.Variable("x")
y = sym.Variable("y")
......
......@@ -12,8 +12,10 @@ def test_simplify_batchnorm():
sym.elemwise_mul(sym.negative(moving_mean), scale), beta)
shape = [-1 if i == axis else 1 for i in range(len(shape))]
# for 2D
scale = sym.reshape(scale, shape=shape)
shift = sym.reshape(shift, shape=shape)
num_newaxis=len(shape) - axis - 1
if num_newaxis:
scale = sym.expand_dims(scale, axis=axis, num_newaxis=num_newaxis)
shift = sym.expand_dims(shift, axis=axis, num_newaxis=num_newaxis)
return x * scale + shift
......@@ -25,7 +27,7 @@ def test_simplify_batchnorm():
gamma = sym.Variable("gamma")
moving_var = sym.Variable("moving_var")
moving_mean = sym.Variable("moving_mean")
y1, y2 = x, x
y1, y2 = x, sym.Variable("xx") + 1
ishape = {"x": tuple(10 for i in range(dim))}
for i in range(nstep):
y1 = sym.batch_norm(
......@@ -44,6 +46,7 @@ def test_simplify_batchnorm():
check(2, 1, 1)
check(4, 0, 3)
check(4, 1, 2)
if __name__ == "__main__":
test_simplify_batchnorm()
import numpy as np
import tvm
import topi
import nnvm.symbol as sym
import nnvm.compiler
import nnvm.runtime
def ctx_list():
res = [("llvm", tvm.cpu(0)), ("cuda", tvm.gpu(0))]
return [x for x in res if x[1].exist]
from nnvm.testing.config import test_ctx_list
def test_relu():
x = sym.Variable("x")
......@@ -17,7 +12,7 @@ def test_relu():
dtype = "float32"
dshape = (1, 3, 32, 32)
oshape = dshape
for target, ctx in ctx_list():
for target, ctx in test_ctx_list():
graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
m = nnvm.runtime.create(graph, lib, ctx)
# get member functions
......@@ -40,7 +35,7 @@ def test_exp():
dtype = "float32"
dshape = (1, 3, 32, 32)
oshape = dshape
for target, ctx in ctx_list():
for target, ctx in test_ctx_list():
graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
m = nnvm.runtime.create(graph, lib, ctx)
# get member functions
......@@ -63,7 +58,7 @@ def test_log():
dtype = "float32"
dshape = (1, 3, 32, 32)
oshape = dshape
for target, ctx in ctx_list():
for target, ctx in test_ctx_list():
with nnvm.compiler.build_config(opt_level=1):
graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
m = nnvm.runtime.create(graph, lib, ctx)
......@@ -87,7 +82,7 @@ def test_tanh():
dtype = "float32"
dshape = (1, 3, 32, 32)
oshape = dshape
for target, ctx in ctx_list():
for target, ctx in test_ctx_list():
with nnvm.compiler.build_config(opt_level=1):
graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
m = nnvm.runtime.create(graph, lib, ctx)
......@@ -111,7 +106,7 @@ def test_sigmoid():
dtype = "float32"
dshape = (1, 3, 32, 32)
oshape = dshape
for target, ctx in ctx_list():
for target, ctx in test_ctx_list():
graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
m = nnvm.runtime.create(graph, lib, ctx)
# get member functions
......@@ -134,7 +129,7 @@ def test_softmax():
dtype = "float32"
dshape = (10, 1000)
oshape = dshape
for target, ctx in ctx_list():
for target, ctx in test_ctx_list():
with nnvm.compiler.build_config(opt_level=1):
graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
m = nnvm.runtime.create(graph, lib, ctx)
......@@ -187,7 +182,7 @@ def test_batchnorm():
y = sym.batch_norm(
x, gamma, beta, moving_mean, moving_var, epsilon=eps)
for target, ctx in ctx_list():
for target, ctx in test_ctx_list():
graph, lib, _ = nnvm.compiler.build(y, "llvm", {"x": shape})
m = nnvm.runtime.create(graph, lib, tvm.cpu(0))
x_np = np.random.uniform(size=shape).astype(dtype)
......
......@@ -5,10 +5,8 @@ import topi
import nnvm.symbol as sym
import nnvm.compiler
import nnvm.runtime
from nnvm.testing.config import test_ctx_list
def ctx_list():
res = [("llvm", tvm.cpu(0)), ("cuda", tvm.gpu(0))]
return [x for x in res if x[1].exist]
def test_conv2d():
x = sym.Variable("x")
......@@ -19,7 +17,7 @@ def test_conv2d():
kshape = (10, 3, 3, 3)
oshape = (1, 10, 18, 18)
shape_dict = {"x": dshape}
for target, ctx in ctx_list():
for target, ctx in test_ctx_list():
graph, lib, _ = nnvm.compiler.build(y, target, shape_dict)
m = nnvm.runtime.create(graph, lib, ctx)
# get member functions
......@@ -42,29 +40,25 @@ def test_conv2d():
def test_grouped_conv2d():
x = sym.Variable("x")
y = sym.conv2d(x, channels=32, kernel_size=(3, 3), groups=32,
name="y", use_bias=False, padding=(1,1))
name="y", padding=(1,1))
dtype = "float32"
dshape = (1, 32, 18, 18)
kshape = (32, 1, 3, 3)
oshape = (1, 32, 18, 18)
shape_dict = {"x": dshape}
for target, ctx in ctx_list():
for target, ctx in test_ctx_list():
graph, lib, _ = nnvm.compiler.build(y, target, shape_dict)
m = nnvm.runtime.create(graph, lib, ctx)
# get member functions
set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
# set input
data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
kernel = tvm.nd.array(np.random.uniform(size=kshape).astype(dtype))
set_input("x", data)
set_input("y_weight", kernel)
# execute
run()
bias = tvm.nd.array(np.random.uniform(size=kshape[0]).astype(dtype))
m.run(x=data, y_weight=kernel, y_bias=bias)
# get output
out = tvm.nd.empty(oshape, dtype)
get_output(0, out)
out = m.get_output(0, tvm.nd.empty(oshape, dtype))
c_np = topi.testing.depthwise_conv2d_python_nchw(
data.asnumpy(), kernel.asnumpy(), (1,1), 'SAME')
c_np = c_np + bias.asnumpy().reshape(kshape[0], 1, 1)
np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
......
......@@ -34,6 +34,16 @@ def test_concatenate():
assert(sdict["concat"][0] == [20, 20])
def test_expand_dims():
x = sym.Variable("x", shape=(10, 20))
y = sym.expand_dims(x, axis=1, name="y")
sdict = infer_shape(y)
assert(sdict["y"][0] == [10, 1, 20])
y = sym.expand_dims(x, axis=-1, name="y", num_newaxis=2)
sdict = infer_shape(y)
assert(sdict["y"][0] == [10, 20, 1, 1])
def test_split():
x1 = sym.Variable("x", shape=(10, 20))
z = sym.split(x1, indices_or_sections=[11], name="y")
......@@ -247,6 +257,7 @@ def test_reduce():
if __name__ == "__main__":
test_expand_dims()
test_dense()
test_concatenate()
test_split()
......
......@@ -19,6 +19,11 @@ def test_concatenate_split():
z = sym.split(y, indices_or_sections=[10, 20])
assert len(z.list_output_names()) == 3
def test_expand_dims():
x = sym.Variable('x')
y = sym.expand_dims(x, axis=1, num_newaxis=2)
assert y.list_input_names() == ['x']
def test_unary():
x = sym.Variable('x')
......@@ -39,6 +44,7 @@ def test_batchnorm():
if __name__ == "__main__":
test_concatenate_split()
test_expand_dims()
test_dense()
test_unary()
test_batchnorm()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment