Commit bdd4bab0 by Tianqi Chen Committed by GitHub

[NNVM] Introduce const shift ops (#1325)

parent 531bb7c4
Subproject commit 0b7e25275138768bb05edb9b9db2c86d0fb09c9a
Subproject commit 9204453ae8de77e7dfc32c4d80f58dd788ad75ff
......@@ -88,6 +88,8 @@ This level enables typical convnet models.
nnvm.symbol.__rdiv_scalar__
nnvm.symbol.__pow_scalar__
nnvm.symbol.__rpow_scalar__
nnvm.symbol.__lshift_scalar__
nnvm.symbol.__rshift_scalar__
**Level 4: Broadcast and Reductions**
......@@ -164,6 +166,8 @@ Detailed Definitions
.. autofunction:: nnvm.symbol.__rdiv_scalar__
.. autofunction:: nnvm.symbol.__pow_scalar__
.. autofunction:: nnvm.symbol.__rpow_scalar__
.. autofunction:: nnvm.symbol.__lshift_scalar__
.. autofunction:: nnvm.symbol.__rshift_scalar__
.. autofunction:: nnvm.symbol.transpose
.. autofunction:: nnvm.symbol.broadcast_to
......
......@@ -100,6 +100,20 @@ class Symbol(SymbolBase):
else:
raise TypeError('type %s not supported' % str(type(other)))
def __lshift__(self, other):
"""x.__lshift__(y) <=> x << y"""
if isinstance(other, _Number):
return __lshift_scalar__(self, scalar=other)
else:
raise TypeError('type %s not supported' % str(type(other)))
def __rshift__(self, other):
"""x.__rshift__(y) <=> x >> y"""
if isinstance(other, _Number):
return __rshift_scalar__(self, scalar=other)
else:
raise TypeError('type %s not supported' % str(type(other)))
def __truediv__(self, other):
return self.__div__(other)
......
......@@ -133,6 +133,14 @@ reg.register_schedule("__pow_scalar__", _fschedule_broadcast)
reg.register_pattern("__rpow_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__rpow_scalar__", _fschedule_broadcast)
# lshift_scalar
reg.register_pattern("__lshift_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__lshift_scalar__", _fschedule_broadcast)
# rshift_scalar
reg.register_pattern("__rshift_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__rshift_scalar__", _fschedule_broadcast)
# elemwise_add
reg.register_pattern("elemwise_add", OpPattern.BROADCAST)
reg.register_schedule("elemwise_add", _fschedule_broadcast)
......
......@@ -512,6 +512,39 @@ NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__rsub_scalar__)
};
});
NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__lshift_scalar__)
.describe(R"code(Tensor left shift by scalar
)code" NNVM_ADD_FILELINE)
.set_support_level(3)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const ScalarParam& param = nnvm::get<ScalarParam>(attrs.parsed);
int scalar_val = static_cast<int>(param.scalar);
return Array<Tensor>{
topi::left_shift(inputs[0],
make_const(inputs[0]->dtype, scalar_val))};
});
NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__rshift_scalar__)
.describe(R"code(Tensor right shift by scalar
)code" NNVM_ADD_FILELINE)
.set_support_level(3)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const ScalarParam& param = nnvm::get<ScalarParam>(attrs.parsed);
int scalar_val = static_cast<int>(param.scalar);
return Array<Tensor>{
topi::right_shift(inputs[0],
make_const(inputs[0]->dtype, scalar_val))};
});
NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__mul_scalar__)
.describe(R"code(Tensor multiplies scalar
......
......@@ -7,17 +7,21 @@ import nnvm.compiler
from nnvm.testing.config import ctx_list
def helper(symbol, inputs, dtype,
np_forward, np_backward=None, need_input=True, need_head_grads=True):
np_forward, np_backward=None,
need_input=True, need_head_grads=True,
rnd_min=-1, rnd_max=1):
ishapes = {}
itypes = {}
input_syms = []
np_inputs = {}
for (name, shape, s) in inputs:
ishapes.update({name: shape})
np_inputs.update({name: np.random.uniform(size=shape).astype(dtype)})
itypes.update({name: dtype})
np_inputs.update({name: np.random.uniform(rnd_min, rnd_max, size=shape).astype(dtype)})
input_syms.append(s)
for target, ctx in ctx_list():
graph, lib, _ = nnvm.compiler.build(symbol, target, ishapes)
graph, lib, _ = nnvm.compiler.build(symbol, target, ishapes, itypes)
m = graph_runtime.create(graph, lib, ctx)
m.run(**np_inputs)
y_np = np_forward(**np_inputs)
......@@ -164,7 +168,7 @@ def test_log():
dtype = "float32"
dshape = (1, 3, 32, 32)
inputs = [('x', dshape, x)]
helper(y, inputs, dtype, forward, backward)
helper(y, inputs, dtype, forward, backward, rnd_min=0.001)
def test_tanh():
......@@ -277,7 +281,7 @@ def test_batchnorm():
('moving_var', (20,), moving_mean)
]
helper(y, inputs, dtype, forward)
helper(y, inputs, dtype, forward, rnd_min=0.001)
def verify_concatenate(ishape, axis):
......
......@@ -7,13 +7,13 @@ import nnvm.compiler
from nnvm.testing.config import ctx_list
from test_top_level1 import helper
def check_map(symfunc, np_func, np_backward=None):
def check_map(symfunc, np_func, np_backward=None, dtype="float32", rnd_min=-1, rnd_max=1):
x = sym.Variable("x")
y = symfunc(x)
dtype = "float32"
dshape = (1, 3, 32, 32)
inputs = [('x', dshape, x)]
helper(y, inputs, dtype, lambda x: np_func(x), np_backward)
helper(y, inputs, dtype, lambda x: np_func(x), np_backward,
rnd_min=rnd_min, rnd_max=rnd_max)
def test_floor():
......@@ -29,7 +29,14 @@ def test_round():
check_map(sym.round, np.round)
def test_shift():
n = 3
for dtype in ["int32", "int8"]:
check_map(lambda x : x >> n, lambda x: x >> n, dtype=dtype, rnd_min=-100, rnd_max=100)
check_map(lambda x : x << n, lambda x: x << n, dtype=dtype, rnd_min=-100, rnd_max=100)
if __name__ == "__main__":
test_shift()
test_floor()
test_ceil()
test_round()
......
......@@ -210,7 +210,7 @@ def right_shift(lhs, rhs):
Returns Expr if both operands are Expr.
Otherwise returns Tensor.
"""
return _cpp.left_shift(lhs, rhs)
return _cpp.right_shift(lhs, rhs)
def greater(lhs, rhs):
......
......@@ -68,7 +68,7 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape,
if rhs_shape is None:
rhs_npy = float(np.random.uniform(low=rhs_min, high=rhs_max))
if dtype.startswith('int'):
lhs_npy = int(lhs_npy)
rhs_npy = int(rhs_npy)
rhs_nd = rhs_npy
else:
rhs_npy = np.random.uniform(low=rhs_min, high=rhs_max,
......@@ -77,8 +77,7 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape,
out_npy = fnumpy(lhs_npy, rhs_npy)
out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(C.dtype), ctx)
for _ in range(1):
foo(lhs_nd, rhs_nd, out_nd)
foo(lhs_nd, rhs_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1E-4, atol=1E-4)
check_device("opencl")
......@@ -142,8 +141,23 @@ def test_cmp():
verify_broadcast_binary_ele(
(2, 1, 2), (2, 3, 1), less, np.less)
def test_shift():
# explicit specify the output type
verify_broadcast_binary_ele(
(2, 1, 2), None, topi.right_shift, np.right_shift,
dtype="int32", rhs_min=0, rhs_max=32)
verify_broadcast_binary_ele(
(1, 2, 2), (2,), topi.left_shift, np.left_shift,
dtype="int32", rhs_min=0, rhs_max=32)
verify_broadcast_binary_ele(
(1, 2, 2), (2,), topi.left_shift, np.left_shift,
dtype="int8", rhs_min=0, rhs_max=32)
if __name__ == "__main__":
test_shift()
test_cmp()
test_mod()
test_add()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment