Commit 2e9b6b99 by Tianqi Chen

[TOP][COMPILER] Add expand_dims, change graph_compare to not compare input optionally (#25)

parent 40bc10f3
@@ -21,6 +21,17 @@ struct ConcatenateParam : public dmlc::Parameter<ConcatenateParam> {
   }
 };
 
+struct ExpandDimsParam : public dmlc::Parameter<ExpandDimsParam> {
+  int axis;
+  int num_newaxis;
+  DMLC_DECLARE_PARAMETER(ExpandDimsParam) {
+    DMLC_DECLARE_FIELD(axis)
+        .describe("the axis to be expanded.");
+    DMLC_DECLARE_FIELD(num_newaxis).set_lower_bound(1).set_default(1)
+        .describe("Number of new axes to be inserted.");
+  }
+};
+
 struct SplitParam : public dmlc::Parameter<SplitParam> {
   // numpy convention, only support indices, not support list.
   Tuple<int> indices_or_sections;
...
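Note: ``num_newaxis`` controls how many size-1 axes are inserted at ``axis``. A quick numpy sketch of the intended semantics (illustrative only, not part of the change):

    import numpy as np
    x = np.empty((2, 3, 4))
    # expand_dims(x, axis=1, num_newaxis=2) should yield shape (2, 1, 1, 3, 4)
    y = x.reshape(x.shape[:1] + (1, 1) + x.shape[1:])
    assert y.shape == (2, 1, 1, 3, 4)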
@@ -2,6 +2,7 @@
 """Namespace for building operators."""
 from __future__ import absolute_import as _abs
+import logging
 import tvm
 from . import graph_attr, graph_util
 from .. import graph as _graph
@@ -74,6 +75,7 @@ def build_config(**kwargs):
 @tvm.register_func("nnvm.compiler.lower")
 def _lower(sch, inputs, func_name):
     f = tvm.lower(sch, inputs, name=func_name)
+    logging.debug("lower function %s", func_name)
     return f if isinstance(
         f, (tvm.container.Array, tuple, list)) else [f]
...
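The new debug line is only visible when the host application enables Python logging; for example (standard library only, not part of this commit):

    import logging
    logging.basicConfig(level=logging.DEBUG)  # makes the "lower function <name>" messages show up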
@@ -59,7 +59,7 @@ def infer_dtype(graph, **dtype):
 _deep_compare = tvm.get_global_func("nnvm.graph.DeepCompare")
 
-def check_graph_equal(grapha, graphb):
+def check_graph_equal(grapha, graphb, compare_variable_attrs=False):
     """Check if two graphs have equal structure.
 
     Parameters
@@ -70,11 +70,16 @@ def check_graph_equal(grapha, graphb):
     graphb : Graph
         The second graph
 
+    compare_variable_attrs : bool, optional
+        Whether to compare attributes (names) on variables.
+        Usually it is safe to skip this unless the input names
+        must match exactly.
+
     Raises
     ------
     ValueError
         ValueError is raised with error message when graph not equal
     """
-    err = _deep_compare(grapha, graphb)
+    err = _deep_compare(grapha, graphb, compare_variable_attrs)
     if err:
         raise ValueError("Graph compare error: " + err)
"""Utilities for testcase"""
"""Configuration about tests"""
import os
import tvm
def test_ctx_list():
"""Get context list for testcases"""
device_list = os.environ.get("NNVM_TEST_TARGETS", "")
device_list = (device_list.split(",") if device_list
else ["llvm", "cuda"])
device_list = set(device_list)
res = [("llvm", tvm.cpu(0)), ("cuda", tvm.gpu(0))]
return [x for x in res if x[1].exist and x[0] in device_list]
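The helper honours the ``NNVM_TEST_TARGETS`` environment variable and drops contexts that do not exist on the machine; for example (illustrative):

    import os
    os.environ["NNVM_TEST_TARGETS"] = "llvm"       # restrict tests to the LLVM/CPU target
    from nnvm.testing.config import test_ctx_list
    print(test_ctx_list())                         # e.g. [("llvm", cpu(0))] on a CUDA-less machine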
@@ -90,7 +90,7 @@ def compute_conv2d(attrs, inputs, _):
         raise ValueError("not support arbitrary group number for now")
     if attrs.get_bool("use_bias"):
         bias = inputs[2]
-        bias = topi.broadcast_to(bias, (1, bias.shape[0], 1, 1))
+        bias = topi.expand_dims(bias, axis=1, num_newaxis=2)
         out = topi.broadcast_add(out, bias)
     return out
...
@@ -115,6 +115,18 @@ reg.register_compute("__rdiv_scalar__",
 reg.register_pattern("__rdiv_scalar__", OpPattern.ELEM_WISE)
 reg.register_schedule("__rdiv_scalar__", _fschedule_broadcast)
 
+# pow_scalar
+reg.register_compute("__pow_scalar__",
+                     _compute_binary_scalar(tvm.power))
+reg.register_pattern("__pow_scalar__", OpPattern.ELEM_WISE)
+reg.register_schedule("__pow_scalar__", _fschedule_broadcast)
+
+# rpow_scalar
+reg.register_compute("__rpow_scalar__",
+                     _compute_binary_scalar(lambda x, y: tvm.power(y, x)))
+reg.register_pattern("__rpow_scalar__", OpPattern.ELEM_WISE)
+reg.register_schedule("__rpow_scalar__", _fschedule_broadcast)
+
 # elemwise_add
 reg.register_compute("elemwise_add", _compute_binary(topi.broadcast_add))
 reg.register_pattern("elemwise_add", OpPattern.BROADCAST)
...
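``__pow_scalar__`` keeps the tensor as the base (``tvm.power(x, c)``) while ``__rpow_scalar__`` swaps the operands (``tvm.power(c, x)``), hence the lambda with reversed arguments. At the symbol level this roughly corresponds to (a sketch, assuming the usual scalar operator overloads on nnvm symbols):

    import nnvm.symbol as sym
    x = sym.Variable("x")
    y1 = x ** 2.0   # __pow_scalar__  -> tvm.power(x, 2.0)
    y2 = 2.0 ** x   # __rpow_scalar__ -> tvm.power(2.0, x)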
@@ -3,11 +3,21 @@
 from __future__ import absolute_import
 import tvm
+import topi
 from .tensor import _fschedule_broadcast
 from ..compiler import registry as reg
 from ..compiler import OpPattern
 # Need add reshape, transpose
 
+@reg.register_compute("expand_dims")
+def compute_expand_dims(attrs, inputs, out_info):
+    """Compute definition of expand_dims"""
+    return topi.expand_dims(
+        inputs[0], attrs.get_int("axis"),
+        num_newaxis=attrs.get_int("num_newaxis"))
+reg.register_pattern("expand_dims", OpPattern.BROADCAST)
+reg.register_schedule("expand_dims", _fschedule_broadcast)
+
 def _flatten_index(indices, shape):
     """flatten the index to 1D"""
...
@@ -16,7 +16,9 @@ namespace compiler {
 // not considering the graph attributes
 // return non-empty error message if the graph mismatch.
 // the comparator won't match name of intermediate node.
-std::string DeepCompare(Graph a, Graph b) {
+// compare_variable_attr: whether to also compare attributes on variable nodes
+std::string DeepCompare(Graph a, Graph b,
+                        bool compare_variable_attr) {
   const IndexedGraph& idxa = a.indexed_graph();
   const IndexedGraph& idxb = b.indexed_graph();
   std::ostringstream err;
@@ -51,6 +53,10 @@ std::string DeepCompare(Graph a, Graph b) {
       err << "Node mismatch ";
       return err.str();
     }
+    if (anode.source->is_variable()) {
+      CHECK(bnode.source->is_variable());
+      if (!compare_variable_attr) continue;
+    }
     AttrDict adict = GetAttrDict(anode.source->attrs);
     AttrDict bdict = GetAttrDict(bnode.source->attrs);
@@ -107,7 +113,7 @@ std::string DeepCompare(Graph a, Graph b) {
 TVM_REGISTER_GLOBAL("nnvm.graph.DeepCompare")
 .set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue *rv) {
-    *rv = DeepCompare(args[0], args[1]);
+    *rv = DeepCompare(args[0], args[1], args[2]);
   });
 }  // namespace compiler
 }  // namespace nnvm
@@ -58,16 +58,15 @@ BatchNormToInferUnpack(const nnvm::NodeAttrs& attrs,
     shift = MakeNode(
         "elemwise_add", bn_name + "_add_beta", {shift, beta});
   }
-  // use broaodcast to reshape
-  std::ostringstream oshape;
-  for (dim_t i = 0; i < dshape.ndim(); ++i) {
-    dshape[i] = (i != param.axis) ? 1 : -1;
+  // use expand dims to pad lower dims to 1
+  int num_pad_axis = static_cast<int>(dshape.ndim() - param.axis) - 1;
+  if (num_pad_axis != 0) {
+    std::unordered_map<std::string, std::string> kwargs{
+      {"axis", std::to_string(param.axis)},
+      {"num_newaxis", std::to_string(num_pad_axis)}};
+    scale = MakeNode("expand_dims", bn_name + "_sc_expand", {scale}, kwargs);
+    shift = MakeNode("expand_dims", bn_name + "_sh_expand", {shift}, kwargs);
   }
-  oshape << dshape;
-  scale = MakeNode("reshape", bn_name + "_sc_reshape",
-                   {scale}, {{"shape", oshape.str()}});
-  shift = MakeNode("reshape", bn_name + "_sh_reshape",
-                   {shift}, {{"shape", oshape.str()}});
   NodeEntry out = MakeNode("broadcast_mul", bn_name + "_a_mul_data",
                            {data, scale});
   out = MakeNode("broadcast_add", bn_name + "_out",
...
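For the common NCHW case (``axis=1``, 4-D data) ``num_pad_axis`` is 2, so the per-channel scale of shape ``(C,)`` is expanded to ``(C, 1, 1)`` before broadcasting against ``(N, C, H, W)``. A numpy sketch of the resulting arithmetic (shapes are illustrative):

    import numpy as np
    data = np.zeros((8, 16, 32, 32))                           # N, C, H, W
    scale = np.ones(16)                                        # gamma / sqrt(moving_var + eps), shape (C,)
    num_pad_axis = data.ndim - 1 - 1                           # dshape.ndim() - axis - 1 == 2
    scale = scale.reshape(scale.shape + (1,) * num_pad_axis)   # (16, 1, 1)
    out = data * scale                                         # broadcasts over N, H and W
    assert out.shape == (8, 16, 32, 32)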
@@ -142,8 +142,51 @@ Example::
 .set_num_inputs(kVarg)
 .set_support_level(1);
 
+// expand_dims
+DMLC_REGISTER_PARAMETER(ExpandDimsParam);
+
+inline bool ExpandDimsInferShape(const NodeAttrs& attrs,
+                                 std::vector<TShape>* in_shape,
+                                 std::vector<TShape>* out_shape) {
+  const ExpandDimsParam& param = nnvm::get<ExpandDimsParam>(attrs.parsed);
+  CHECK_EQ(in_shape->size(), 1U);
+  const TShape& dshape = in_shape->at(0);
+  int ndim = static_cast<int>(dshape.ndim());
+  CHECK(param.axis >= -ndim - 1 && param.axis <= ndim);
+  int axis = param.axis < 0 ? ndim + param.axis + 1 : param.axis;
+  std::vector<dim_t> oshape;
+  for (int i = 0; i < axis; ++i) {
+    oshape.push_back(dshape[i]);
+  }
+  for (int i = 0; i < param.num_newaxis; ++i) {
+    oshape.push_back(1);
+  }
+  for (int i = axis; i < ndim; ++i) {
+    oshape.push_back(dshape[i]);
+  }
+  NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0,
+                           TShape(oshape.begin(), oshape.end()));
+  return true;
+}
+
+NNVM_REGISTER_OP(expand_dims)
+.describe(R"code(Inserts a new axis of size 1 into the array shape
+
+For example, given ``x`` with shape ``(2,3,4)``, then ``expand_dims(x, axis=1)``
+will return a new array with shape ``(2,1,3,4)``.
+
+)code" NNVM_ADD_FILELINE)
+.add_argument("data", "Tensor", "Input tensor")
+.add_arguments(ExpandDimsParam::__FIELDS__())
+.set_attr_parser(ParamParser<ExpandDimsParam>)
+.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ExpandDimsParam>)
+.set_attr<FInferShape>("FInferShape", ExpandDimsInferShape)
+.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_support_level(1);
+
-// concatenate
+// split
 DMLC_REGISTER_PARAMETER(SplitParam);
 
 inline void SplitParamParser(nnvm::NodeAttrs* attrs) {
...
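A worked example of the negative-axis handling above: for ``dshape=(10, 20)``, ``axis=-1`` and ``num_newaxis=2``, ``ndim`` is 2, so ``axis`` normalizes to 2 and the output shape is ``(10, 20, 1, 1)``, matching the Python-side infer-shape test added below. The same rule in a few lines of Python (a sketch, not the registered implementation):

    def expand_dims_shape(dshape, axis, num_newaxis):
        ndim = len(dshape)
        assert -ndim - 1 <= axis <= ndim
        axis = ndim + axis + 1 if axis < 0 else axis
        return dshape[:axis] + (1,) * num_newaxis + dshape[axis:]

    assert expand_dims_shape((10, 20), -1, 2) == (10, 20, 1, 1)
    assert expand_dims_shape((2, 3, 4), 1, 1) == (2, 1, 3, 4)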
@@ -40,9 +40,6 @@ def test_compile():
     assert graph.index.num_nodes == 4
     verify(graph, lib)
 
 def test_run():
     x = sym.Variable("x")
     y = sym.Variable("y")
...
@@ -12,8 +12,10 @@ def test_simplify_batchnorm():
             sym.elemwise_mul(sym.negative(moving_mean), scale), beta)
         shape = [-1 if i == axis else 1 for i in range(len(shape))]
         # for 2D
-        scale = sym.reshape(scale, shape=shape)
-        shift = sym.reshape(shift, shape=shape)
+        num_newaxis = len(shape) - axis - 1
+        if num_newaxis:
+            scale = sym.expand_dims(scale, axis=axis, num_newaxis=num_newaxis)
+            shift = sym.expand_dims(shift, axis=axis, num_newaxis=num_newaxis)
         return x * scale + shift
@@ -25,7 +27,7 @@ def test_simplify_batchnorm():
     gamma = sym.Variable("gamma")
     moving_var = sym.Variable("moving_var")
     moving_mean = sym.Variable("moving_mean")
-    y1, y2 = x, x
+    y1, y2 = x, sym.Variable("xx") + 1
     ishape = {"x": tuple(10 for i in range(dim))}
     for i in range(nstep):
         y1 = sym.batch_norm(
@@ -44,6 +46,7 @@ def test_simplify_batchnorm():
     check(2, 1, 1)
     check(4, 0, 3)
+    check(4, 1, 2)
 
 if __name__ == "__main__":
     test_simplify_batchnorm()
 import numpy as np
 import tvm
 import topi
 import nnvm.symbol as sym
 import nnvm.compiler
 import nnvm.runtime
+from nnvm.testing.config import test_ctx_list
 
-def ctx_list():
-    res = [("llvm", tvm.cpu(0)), ("cuda", tvm.gpu(0))]
-    return [x for x in res if x[1].exist]
 
 def test_relu():
     x = sym.Variable("x")
@@ -17,7 +12,7 @@ def test_relu():
     dtype = "float32"
     dshape = (1, 3, 32, 32)
     oshape = dshape
-    for target, ctx in ctx_list():
+    for target, ctx in test_ctx_list():
         graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
         m = nnvm.runtime.create(graph, lib, ctx)
         # get member functions
@@ -40,7 +35,7 @@ def test_exp():
     dtype = "float32"
     dshape = (1, 3, 32, 32)
     oshape = dshape
-    for target, ctx in ctx_list():
+    for target, ctx in test_ctx_list():
         graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
         m = nnvm.runtime.create(graph, lib, ctx)
         # get member functions
@@ -63,7 +58,7 @@ def test_log():
     dtype = "float32"
     dshape = (1, 3, 32, 32)
     oshape = dshape
-    for target, ctx in ctx_list():
+    for target, ctx in test_ctx_list():
         with nnvm.compiler.build_config(opt_level=1):
             graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
         m = nnvm.runtime.create(graph, lib, ctx)
@@ -87,7 +82,7 @@ def test_tanh():
     dtype = "float32"
     dshape = (1, 3, 32, 32)
     oshape = dshape
-    for target, ctx in ctx_list():
+    for target, ctx in test_ctx_list():
         with nnvm.compiler.build_config(opt_level=1):
             graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
         m = nnvm.runtime.create(graph, lib, ctx)
@@ -111,7 +106,7 @@ def test_sigmoid():
     dtype = "float32"
     dshape = (1, 3, 32, 32)
     oshape = dshape
-    for target, ctx in ctx_list():
+    for target, ctx in test_ctx_list():
         graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
         m = nnvm.runtime.create(graph, lib, ctx)
         # get member functions
@@ -134,7 +129,7 @@ def test_softmax():
     dtype = "float32"
     dshape = (10, 1000)
     oshape = dshape
-    for target, ctx in ctx_list():
+    for target, ctx in test_ctx_list():
         with nnvm.compiler.build_config(opt_level=1):
             graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
         m = nnvm.runtime.create(graph, lib, ctx)
@@ -187,7 +182,7 @@ def test_batchnorm():
     y = sym.batch_norm(
         x, gamma, beta, moving_mean, moving_var, epsilon=eps)
-    for target, ctx in ctx_list():
+    for target, ctx in test_ctx_list():
         graph, lib, _ = nnvm.compiler.build(y, "llvm", {"x": shape})
         m = nnvm.runtime.create(graph, lib, tvm.cpu(0))
         x_np = np.random.uniform(size=shape).astype(dtype)
...
@@ -5,10 +5,8 @@ import topi
 import nnvm.symbol as sym
 import nnvm.compiler
 import nnvm.runtime
+from nnvm.testing.config import test_ctx_list
 
-def ctx_list():
-    res = [("llvm", tvm.cpu(0)), ("cuda", tvm.gpu(0))]
-    return [x for x in res if x[1].exist]
 
 def test_conv2d():
     x = sym.Variable("x")
@@ -19,7 +17,7 @@ def test_conv2d():
     kshape = (10, 3, 3, 3)
     oshape = (1, 10, 18, 18)
     shape_dict = {"x": dshape}
-    for target, ctx in ctx_list():
+    for target, ctx in test_ctx_list():
         graph, lib, _ = nnvm.compiler.build(y, target, shape_dict)
         m = nnvm.runtime.create(graph, lib, ctx)
         # get member functions
@@ -42,29 +40,25 @@ def test_conv2d():
 def test_grouped_conv2d():
     x = sym.Variable("x")
     y = sym.conv2d(x, channels=32, kernel_size=(3, 3), groups=32,
-                   name="y", use_bias=False, padding=(1,1))
+                   name="y", padding=(1,1))
     dtype = "float32"
     dshape = (1, 32, 18, 18)
     kshape = (32, 1, 3, 3)
     oshape = (1, 32, 18, 18)
     shape_dict = {"x": dshape}
-    for target, ctx in ctx_list():
+    for target, ctx in test_ctx_list():
         graph, lib, _ = nnvm.compiler.build(y, target, shape_dict)
         m = nnvm.runtime.create(graph, lib, ctx)
-        # get member functions
-        set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
         # set input
         data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
         kernel = tvm.nd.array(np.random.uniform(size=kshape).astype(dtype))
-        set_input("x", data)
-        set_input("y_weight", kernel)
-        # execute
-        run()
+        bias = tvm.nd.array(np.random.uniform(size=kshape[0]).astype(dtype))
+        m.run(x=data, y_weight=kernel, y_bias=bias)
         # get output
-        out = tvm.nd.empty(oshape, dtype)
-        get_output(0, out)
+        out = m.get_output(0, tvm.nd.empty(oshape, dtype))
         c_np = topi.testing.depthwise_conv2d_python_nchw(
             data.asnumpy(), kernel.asnumpy(), (1,1), 'SAME')
+        c_np = c_np + bias.asnumpy().reshape(kshape[0], 1, 1)
         np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
...
@@ -34,6 +34,16 @@ def test_concatenate():
     assert(sdict["concat"][0] == [20, 20])
 
+def test_expand_dims():
+    x = sym.Variable("x", shape=(10, 20))
+    y = sym.expand_dims(x, axis=1, name="y")
+    sdict = infer_shape(y)
+    assert(sdict["y"][0] == [10, 1, 20])
+    y = sym.expand_dims(x, axis=-1, name="y", num_newaxis=2)
+    sdict = infer_shape(y)
+    assert(sdict["y"][0] == [10, 20, 1, 1])
+
 def test_split():
     x1 = sym.Variable("x", shape=(10, 20))
     z = sym.split(x1, indices_or_sections=[11], name="y")
@@ -247,6 +257,7 @@ def test_reduce():
 
 if __name__ == "__main__":
+    test_expand_dims()
     test_dense()
     test_concatenate()
     test_split()
...
@@ -19,6 +19,11 @@ def test_concatenate_split():
     z = sym.split(y, indices_or_sections=[10, 20])
     assert len(z.list_output_names()) == 3
 
+def test_expand_dims():
+    x = sym.Variable('x')
+    y = sym.expand_dims(x, axis=1, num_newaxis=2)
+    assert y.list_input_names() == ['x']
+
 def test_unary():
     x = sym.Variable('x')
@@ -39,6 +44,7 @@ def test_batchnorm():
 
 if __name__ == "__main__":
     test_concatenate_split()
+    test_expand_dims()
     test_dense()
     test_unary()
     test_batchnorm()