Commit bdd4bab0 by Tianqi Chen Committed by GitHub

[NNVM] Introduce const shift ops (#1325)

parent 531bb7c4
- Subproject commit 0b7e25275138768bb05edb9b9db2c86d0fb09c9a
+ Subproject commit 9204453ae8de77e7dfc32c4d80f58dd788ad75ff
...@@ -88,6 +88,8 @@ This level enables typical convnet models.
nnvm.symbol.__rdiv_scalar__
nnvm.symbol.__pow_scalar__
nnvm.symbol.__rpow_scalar__
nnvm.symbol.__lshift_scalar__
nnvm.symbol.__rshift_scalar__
**Level 4: Broadcast and Reductions**
...@@ -164,6 +166,8 @@ Detailed Definitions
.. autofunction:: nnvm.symbol.__rdiv_scalar__
.. autofunction:: nnvm.symbol.__pow_scalar__
.. autofunction:: nnvm.symbol.__rpow_scalar__
.. autofunction:: nnvm.symbol.__lshift_scalar__
.. autofunction:: nnvm.symbol.__rshift_scalar__
.. autofunction:: nnvm.symbol.transpose
.. autofunction:: nnvm.symbol.broadcast_to
...
...@@ -100,6 +100,20 @@ class Symbol(SymbolBase):
else:
raise TypeError('type %s not supported' % str(type(other)))
def __lshift__(self, other):
"""x.__lshift__(y) <=> x << y"""
if isinstance(other, _Number):
return __lshift_scalar__(self, scalar=other)
else:
raise TypeError('type %s not supported' % str(type(other)))
def __rshift__(self, other):
"""x.__rshift__(y) <=> x >> y"""
if isinstance(other, _Number):
return __rshift_scalar__(self, scalar=other)
else:
raise TypeError('type %s not supported' % str(type(other)))
def __truediv__(self, other):
return self.__div__(other)
...
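For reference, a minimal usage sketch of the new operator overloads; the variable name and shift amounts are illustrative, not part of the commit:

```python
import nnvm.symbol as sym

x = sym.Variable("x")
# A Python number on the right-hand side routes through the new scalar ops.
y_left = x << 4    # lowers to nnvm.symbol.__lshift_scalar__(x, scalar=4)
y_right = x >> 2   # lowers to nnvm.symbol.__rshift_scalar__(x, scalar=2)
```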
...@@ -133,6 +133,14 @@ reg.register_schedule("__pow_scalar__", _fschedule_broadcast)
reg.register_pattern("__rpow_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__rpow_scalar__", _fschedule_broadcast)
# lshift_scalar
reg.register_pattern("__lshift_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__lshift_scalar__", _fschedule_broadcast)
# rshift_scalar
reg.register_pattern("__rshift_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__rshift_scalar__", _fschedule_broadcast)
# elemwise_add
reg.register_pattern("elemwise_add", OpPattern.BROADCAST)
reg.register_schedule("elemwise_add", _fschedule_broadcast)
...
...@@ -512,6 +512,39 @@ NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__rsub_scalar__)
};
});
NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__lshift_scalar__)
.describe(R"code(Tensor left shift by scalar
)code" NNVM_ADD_FILELINE)
.set_support_level(3)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const ScalarParam& param = nnvm::get<ScalarParam>(attrs.parsed);
int scalar_val = static_cast<int>(param.scalar);
return Array<Tensor>{
topi::left_shift(inputs[0],
make_const(inputs[0]->dtype, scalar_val))};
});
NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__rshift_scalar__)
.describe(R"code(Tensor right shift by scalar
)code" NNVM_ADD_FILELINE)
.set_support_level(3)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const ScalarParam& param = nnvm::get<ScalarParam>(attrs.parsed);
int scalar_val = static_cast<int>(param.scalar);
return Array<Tensor>{
topi::right_shift(inputs[0],
make_const(inputs[0]->dtype, scalar_val))};
});
NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__mul_scalar__)
.describe(R"code(Tensor multiplies scalar
...
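As a rough illustration of what these FTVMCompute bodies emit, here is a hedged sketch at the TVM/TOPI Python level; the placeholder shape, dtype, and shift amount below are made up for the example:

```python
import tvm
import topi

# Stand-in for the op's single tensor input.
x = tvm.placeholder((1, 3, 32, 32), dtype="int32", name="x")

# Mirrors the C++ body: the scalar attribute is truncated to int and
# materialized as a constant of the input dtype before the shift.
y = topi.left_shift(x, tvm.const(3, x.dtype))
```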
...@@ -7,17 +7,21 @@ import nnvm.compiler
from nnvm.testing.config import ctx_list
def helper(symbol, inputs, dtype,
- np_forward, np_backward=None, need_input=True, need_head_grads=True):
+ np_forward, np_backward=None,
need_input=True, need_head_grads=True,
rnd_min=-1, rnd_max=1):
ishapes = {}
itypes = {}
input_syms = []
np_inputs = {}
for (name, shape, s) in inputs:
ishapes.update({name: shape})
- np_inputs.update({name: np.random.uniform(size=shape).astype(dtype)})
+ itypes.update({name: dtype})
np_inputs.update({name: np.random.uniform(rnd_min, rnd_max, size=shape).astype(dtype)})
input_syms.append(s)
for target, ctx in ctx_list():
- graph, lib, _ = nnvm.compiler.build(symbol, target, ishapes)
+ graph, lib, _ = nnvm.compiler.build(symbol, target, ishapes, itypes)
m = graph_runtime.create(graph, lib, ctx)
m.run(**np_inputs)
y_np = np_forward(**np_inputs)
...@@ -164,7 +168,7 @@ def test_log():
dtype = "float32"
dshape = (1, 3, 32, 32)
inputs = [('x', dshape, x)]
- helper(y, inputs, dtype, forward, backward)
+ helper(y, inputs, dtype, forward, backward, rnd_min=0.001)
def test_tanh():
...@@ -277,7 +281,7 @@ def test_batchnorm():
('moving_var', (20,), moving_mean)
]
- helper(y, inputs, dtype, forward)
+ helper(y, inputs, dtype, forward, rnd_min=0.001)
def verify_concatenate(ishape, axis):
...
...@@ -7,13 +7,13 @@ import nnvm.compiler
from nnvm.testing.config import ctx_list
from test_top_level1 import helper
- def check_map(symfunc, np_func, np_backward=None):
+ def check_map(symfunc, np_func, np_backward=None, dtype="float32", rnd_min=-1, rnd_max=1):
x = sym.Variable("x")
y = symfunc(x)
- dtype = "float32"
dshape = (1, 3, 32, 32)
inputs = [('x', dshape, x)]
- helper(y, inputs, dtype, lambda x: np_func(x), np_backward)
+ helper(y, inputs, dtype, lambda x: np_func(x), np_backward,
rnd_min=rnd_min, rnd_max=rnd_max)
def test_floor():
...@@ -29,7 +29,14 @@ def test_round():
check_map(sym.round, np.round)
def test_shift():
n = 3
for dtype in ["int32", "int8"]:
check_map(lambda x : x >> n, lambda x: x >> n, dtype=dtype, rnd_min=-100, rnd_max=100)
check_map(lambda x : x << n, lambda x: x << n, dtype=dtype, rnd_min=-100, rnd_max=100)
if __name__ == "__main__":
test_shift()
test_floor()
test_ceil()
test_round()
...
...@@ -210,7 +210,7 @@ def right_shift(lhs, rhs):
Returns Expr if both operands are Expr.
Otherwise returns Tensor.
"""
- return _cpp.left_shift(lhs, rhs)
+ return _cpp.right_shift(lhs, rhs)
def greater(lhs, rhs):
...
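A quick, hedged way to sanity-check the dispatch fix against NumPy; the shapes, values, and the llvm target here are assumptions for the example, not part of the commit:

```python
import numpy as np
import tvm
import topi

a = tvm.placeholder((4,), dtype="int32", name="a")
b = tvm.placeholder((4,), dtype="int32", name="b")
# With the fix, the Python wrapper dispatches to the C++ right_shift
# implementation instead of left_shift.
c = topi.right_shift(a, b)

s = tvm.create_schedule(c.op)
f = tvm.build(s, [a, b, c], "llvm")

a_np = np.array([8, 16, 32, 64], dtype="int32")
b_np = np.array([1, 2, 3, 4], dtype="int32")
c_nd = tvm.nd.array(np.zeros(4, dtype="int32"))
f(tvm.nd.array(a_np), tvm.nd.array(b_np), c_nd)
np.testing.assert_allclose(c_nd.asnumpy(), np.right_shift(a_np, b_np))
```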
...@@ -68,7 +68,7 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape,
if rhs_shape is None:
rhs_npy = float(np.random.uniform(low=rhs_min, high=rhs_max))
if dtype.startswith('int'):
- lhs_npy = int(lhs_npy)
+ rhs_npy = int(rhs_npy)
rhs_nd = rhs_npy
else:
rhs_npy = np.random.uniform(low=rhs_min, high=rhs_max,
...@@ -77,8 +77,7 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape,
out_npy = fnumpy(lhs_npy, rhs_npy)
out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(C.dtype), ctx)
- for _ in range(1):
- foo(lhs_nd, rhs_nd, out_nd)
+ foo(lhs_nd, rhs_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1E-4, atol=1E-4)
check_device("opencl")
...@@ -142,8 +141,23 @@ def test_cmp():
verify_broadcast_binary_ele(
(2, 1, 2), (2, 3, 1), less, np.less)
def test_shift():
# explicitly specify the output type
verify_broadcast_binary_ele(
(2, 1, 2), None, topi.right_shift, np.right_shift,
dtype="int32", rhs_min=0, rhs_max=32)
verify_broadcast_binary_ele(
(1, 2, 2), (2,), topi.left_shift, np.left_shift,
dtype="int32", rhs_min=0, rhs_max=32)
verify_broadcast_binary_ele(
(1, 2, 2), (2,), topi.left_shift, np.left_shift,
dtype="int8", rhs_min=0, rhs_max=32)
if __name__ == "__main__":
test_shift()
test_cmp()
test_mod()
test_add()
...