Commit 3a0b757c by Siva Committed by Tianqi Chen

[NNVM][TOP] broadcast versions corresponding to topi: mod, max, min, pow,…

[NNVM][TOP] broadcast versions corresponding to topi: mod, max, min, pow, left_shift, right_shift greater, less, equal, not_equal, greater_equal and less_equal. (#1383)
parent 0d673a9d
......@@ -544,7 +544,7 @@ def _get_convert_map(opset):
'Exp': Renamer('exp'),
'Log': Renamer('log'),
'Tanh': Renamer('tanh'),
# 'Pow'
'Pow': Renamer('broadcast_pow'),
'PRelu': Prelu.get_converter(opset),
'Sigmoid': Renamer('sigmoid'),
# 'HardSigmoid'
......
......@@ -168,6 +168,54 @@ reg.register_schedule("broadcast_mul", _fschedule_broadcast)
reg.register_pattern("broadcast_div", OpPattern.BROADCAST)
reg.register_schedule("broadcast_div", _fschedule_broadcast)
# broadcast mod
reg.register_pattern("broadcast_mod", OpPattern.BROADCAST)
reg.register_schedule("broadcast_mod", _fschedule_broadcast)
# broadcast max
reg.register_pattern("broadcast_max", OpPattern.BROADCAST)
reg.register_schedule("broadcast_max", _fschedule_broadcast)
# broadcast min
reg.register_pattern("broadcast_min", OpPattern.BROADCAST)
reg.register_schedule("broadcast_min", _fschedule_broadcast)
# broadcast pow
reg.register_pattern("broadcast_pow", OpPattern.BROADCAST)
reg.register_schedule("broadcast_pow", _fschedule_broadcast)
# broadcast left_shift
reg.register_pattern("broadcast_left_shift", OpPattern.BROADCAST)
reg.register_schedule("broadcast_left_shift", _fschedule_broadcast)
# broadcast right_shift
reg.register_pattern("broadcast_right_shift", OpPattern.BROADCAST)
reg.register_schedule("broadcast_right_shift", _fschedule_broadcast)
# broadcast greater
reg.register_pattern("broadcast_greater", OpPattern.BROADCAST)
reg.register_schedule("broadcast_greater", _fschedule_broadcast)
# broadcast less
reg.register_pattern("broadcast_less", OpPattern.BROADCAST)
reg.register_schedule("broadcast_less", _fschedule_broadcast)
# broadcast equal
reg.register_pattern("broadcast_equal", OpPattern.BROADCAST)
reg.register_schedule("broadcast_equal", _fschedule_broadcast)
# broadcast not_equal
reg.register_pattern("broadcast_not_equal", OpPattern.BROADCAST)
reg.register_schedule("broadcast_not_equal", _fschedule_broadcast)
# broadcast greater_equal
reg.register_pattern("broadcast_greater_equal", OpPattern.BROADCAST)
reg.register_schedule("broadcast_greater_equal", _fschedule_broadcast)
# broadcast less_equal
reg.register_pattern("broadcast_less_equal", OpPattern.BROADCAST)
reg.register_schedule("broadcast_less_equal", _fschedule_broadcast)
# broadcast_to
reg.register_pattern("broadcast_to", OpPattern.BROADCAST)
reg.register_schedule("broadcast_to", _fschedule_broadcast)
......
......@@ -15,6 +15,7 @@
#include "../op_common.h"
#include "../elemwise_op_common.h"
#include "topi/broadcast.h"
#include "topi/elemwise.h"
namespace nnvm {
namespace top {
......@@ -346,5 +347,251 @@ Example::
return std::vector<NodeEntry>{ dlhs, drhs };
});
NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_mod, mod)
.add_alias("__mod_symbol__")
.describe(R"code(Returns element-wise mod of the input arrays with broadcasting.
Example::
x = [[ 1., 2., 3.],
[ 4., 5., 6.]]
y = [[ 2.],
[ 3.]]
broadcast_mod(x, y) = [[ 1., 0., 1.],
[ 1., 2., 0.]]
)code" NNVM_ADD_FILELINE);
NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_max, maximum)
.add_alias("__max_symbol__")
.describe(R"code(Returns element-wise max of the input arrays with broadcasting.
Example::
x = [[ 1., 2., 3.],
[ 4., 5., 6.]]
y = [[ 2.],
[ 3.]]
broadcast_max(x, y) = [[ 2., 2., 3.],
[ 4., 5., 6.]]
)code" NNVM_ADD_FILELINE);
NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_min, minimum)
.add_alias("__min_symbol__")
.describe(R"code(Returns element-wise minimum of the input arrays with broadcasting.
Example::
x = [[ 1., 2., 3.],
[ 4., 5., 6.]]
y = [[ 2.],
[ 3.]]
broadcast_min(x, y) = [[ 1., 2., 2.],
[ 3., 3., 3.]]
)code" NNVM_ADD_FILELINE);
NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_pow, power)
.add_alias("__pow_symbol__")
.describe(R"code(Returns element-wise x^y of the input arrays with broadcasting.
Example::
x = [[ 1., 2., 3.],
[ 4., 5., 6.]]
y = [[ 1.],
[ 2.]]
broadcast_pow(x, y) = [[ 1., 2., 3. ],
[ 16., 25., 36.]]
)code" NNVM_ADD_FILELINE);
NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_left_shift, left_shift)
.add_alias("__left_shift_symbol__")
.describe(R"code(Returns element-wise x << y of the input arrays with broadcasting.
Example::
x = [[ 1., 2., 3.],
[ 4., 5., 6.]]
y = [[ 2.],
[ 1.]]
broadcast_left_shift(x, y) = [[ 4., 8., 12.],
[ 8., 10., 12.]]
)code" NNVM_ADD_FILELINE);
NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_right_shift, right_shift)
.add_alias("__right_shift_symbol__")
.describe(R"code(Returns element-wise x >> y of the input arrays with broadcasting.
Example::
x = [[ 4., 8., 12.],
[ 8., 10., 12.]]
y = [[ 2.],
[ 1.]]
broadcast_right_shift(x, y) = [[ 1., 2., 3.],
[ 4., 5., 6.]]
)code" NNVM_ADD_FILELINE);
NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_greater, greater)
.add_alias("__greater_symbol__")
.describe(R"code(Returns element-wise x > y of the input arrays with broadcasting.
Example::
x = [[ 1., 2., 3.],
[ 4., 5., 6.]]
y = [[ 2.],
[ 3.]]
broadcast_greater(x, y) = [[ 0., 0., 1.],
[ 1., 1., 1.]]
)code" NNVM_ADD_FILELINE)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{ topi::cast(topi::greater(inputs[0], inputs[1]), out_info[0]->dtype) };
}, 11);
NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_less, less)
.add_alias("__less_symbol__")
.describe(R"code(Returns element-wise x < y of the input arrays with broadcasting.
Example::
x = [[ 1., 2., 3.],
[ 4., 5., 6.]]
y = [[ 2.],
[ 3.]]
broadcast_less(x, y) = [[ 1., 0., 0.],
[ 0., 0., 0.]]
)code" NNVM_ADD_FILELINE)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{ topi::cast(topi::less(inputs[0], inputs[1]), out_info[0]->dtype) };
}, 11);
NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_equal, equal)
.add_alias("__equal_symbol__")
.describe(R"code(Returns element-wise x == y of the input arrays with broadcasting.
Example::
x = [[ 1., 2., 3.],
[ 4., 5., 6.]]
y = [[ 2.],
[ 5.]]
broadcast_equal(x, y) = [[ 0., 1., 0.],
[ 0., 1., 0.]]
)code" NNVM_ADD_FILELINE)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{ topi::cast(topi::equal(inputs[0], inputs[1]), out_info[0]->dtype) };
}, 11);
NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_not_equal, not_equal)
.add_alias("__not_equal_symbol__")
.describe(R"code(Returns element-wise x != y of the input arrays with broadcasting.
Example::
x = [[ 1., 2., 3.],
[ 4., 5., 6.]]
y = [[ 2.],
[ 4.]]
broadcast_not_equal(x, y) = [[ 1., 0., 1.],
[ 0., 1., 1.]]
)code" NNVM_ADD_FILELINE)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{ topi::cast(topi::not_equal(inputs[0],
inputs[1]),
out_info[0]->dtype) };
}, 11);
NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_greater_equal, greater_equal)
.add_alias("__greater_equal_symbol__")
.describe(R"code(Returns element-wise x >= y of the input arrays with broadcasting.
Example::
x = [[ 1., 2., 3.],
[ 4., 5., 6.]]
y = [[ 2.],
[ 6.]]
broadcast_greater_equal(x, y) = [[ 0., 1., 1.],
[ 0., 0., 1.]]
)code" NNVM_ADD_FILELINE)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{ topi::cast(topi::greater_equal(inputs[0],
inputs[1]),
out_info[0]->dtype) };
}, 11);
NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_less_equal, less_equal)
.add_alias("__less_equal_symbol__")
.describe(R"code(Returns element-wise x <= y of the input arrays with broadcasting.
Example::
x = [[ 1., 2., 3.],
[ 4., 5., 6.]]
y = [[ 1.],
[ 5.]]
broadcast_less_equal(x, y) = [[ 1., 0., 0.],
[ 1., 1., 0.]]
)code" NNVM_ADD_FILELINE)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{ topi::cast(topi::less_equal(inputs[0],
inputs[1]),
out_info[0]->dtype) };
}, 11);
} // namespace top
} // namespace nnvm
......@@ -9,17 +9,23 @@ from nnvm.testing.config import ctx_list
def helper(symbol, inputs, dtype,
np_forward, np_backward=None, need_input=True, need_head_grads=True):
np_forward, np_backward=None,
need_input=True, need_head_grads=True, in_range={}):
ishapes = {}
input_syms = []
np_inputs = {}
for (name, shape, s) in inputs:
ishapes.update({name: shape})
np_inputs.update({name: np.random.uniform(size=shape).astype(dtype)})
if name in in_range:
np_inputs.update({name: np.random.uniform(size=shape,
low=in_range[name][0],
high=in_range[name][1]).astype(dtype)})
else:
np_inputs.update({name: np.random.uniform(size=shape).astype(dtype)})
input_syms.append(s)
for target, ctx in ctx_list():
graph, lib, _ = nnvm.compiler.build(symbol, target, ishapes)
graph, lib, _ = nnvm.compiler.build(symbol, target, ishapes, dtype=dtype)
m = graph_runtime.create(graph, lib, ctx)
m.run(**np_inputs)
y_np = np_forward(**np_inputs)
......@@ -228,6 +234,49 @@ def test_broadcast():
return da, db
helper(y, inputs, dtype, lambda a, b: a / b, _backward_div)
y = sym.broadcast_mod(a, b)
helper(y, inputs, 'int32',
lambda a, b: np.mod(a, b),
in_range={'a': (0.001, 100), 'b': (1, 100)})
y = sym.broadcast_max(a, b)
helper(y, inputs, dtype, lambda a, b: np.maximum(a, b))
y = sym.broadcast_min(a, b)
helper(y, inputs, dtype, lambda a, b: np.minimum(a, b))
y = sym.broadcast_pow(a, b)
helper(y, inputs, dtype,
lambda a, b: np.power(a, b),
in_range={'a': (0.001, 100), 'b': (0.001, 2)})
y = sym.broadcast_left_shift(a, b)
helper(y, inputs, 'int32', lambda a, b: a << b)
y = sym.broadcast_right_shift(a, b)
helper(y, inputs, 'int32', lambda a, b: a >> b)
y = sym.broadcast_greater(a, b)
helper(y, inputs, dtype, lambda a, b: np.greater(a, b))
y = sym.broadcast_less(a, b)
helper(y, inputs, dtype, lambda a, b: np.less(a, b))
y = sym.broadcast_equal(a, b)
helper(y, inputs, 'int32', lambda a, b: np.equal(a, b),
in_range={'a': (-2, 2), 'b': (-2, 2)})
y = sym.broadcast_not_equal(a, b)
helper(y, inputs, 'int32', lambda a, b: np.not_equal(a, b),
in_range={'a': (-2, 2), 'b': (-2, 2)})
y = sym.broadcast_greater_equal(a, b)
helper(y, inputs, 'int32', lambda a, b: np.greater_equal(a, b),
in_range={'a': (-3, 3), 'b': (-3, 3)})
y = sym.broadcast_less_equal(a, b)
helper(y, inputs, 'int32', lambda a, b: np.less_equal(a, b),
in_range={'a': (-3, 3), 'b': (-3, 3)})
def test_greater():
l = sym.Variable("l")
......
......@@ -108,6 +108,50 @@ def test_reshape_like():
np.testing.assert_allclose(ref_shape, tvm_out.shape)
def _test_power_iteration(x_shape, y_shape):
if isinstance(y_shape, int):
y_shape = [y_shape]
x = np.random.uniform(size=x_shape).astype(np.float32)
y = np.random.uniform(size=y_shape).astype(np.float32)
np_res = np.power(x, y).astype(np.float32)
res = helper.make_node("Pow", ['x', 'y'], ['out'])
graph = helper.make_graph([res],
'power_test',
inputs = [helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape)),
helper.make_tensor_value_info("y", TensorProto.FLOAT, list(y_shape))],
outputs = [helper.make_tensor_value_info("out", TensorProto.FLOAT, list(np_res.shape))])
model = helper.make_model(graph, producer_name='power_test')
for target, ctx in ctx_list():
new_sym, params = nnvm.frontend.from_onnx(model)
input_name = model.graph.input[0].name
input_name1 = model.graph.input[1].name
shape_dict = {input_name: x.shape, input_name1: y.shape}
dtype_dict = {input_name: x.dtype, input_name1: y.dtype}
graph, lib, params = nnvm.compiler.build(new_sym, target, shape_dict, dtype_dict, params=params)
m = graph_runtime.create(graph, lib, ctx)
# set inputs
m.set_input(input_name, tvm.nd.array(x))
m.set_input(input_name1, tvm.nd.array(y))
m.set_input(**params)
m.run()
# get outputs
tvm_out = m.get_output(0, tvm.nd.empty(np_res.shape, np_res.dtype))
np.testing.assert_allclose(np_res, tvm_out.asnumpy(), rtol=1e-5, atol=1e-5)
def test_power():
_test_power_iteration((1, 3), (1))
_test_power_iteration((2, 3), (2, 3))
_test_power_iteration((2, 3), (1, 3))
def test_squeeze():
in_shape = (1, 3, 1, 3, 1, 1)
out_shape = (3, 3)
......@@ -247,6 +291,7 @@ if __name__ == '__main__':
verify_resnet18()
test_reshape()
test_reshape_like()
test_power()
test_squeeze()
test_unsqueeze()
test_slice()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment