Commit 69b23d9e by Tianqi Chen Committed by GitHub

[OP] Improve bitwise op type checks (#1415)

parent 84824ae3
...@@ -42,12 +42,14 @@ TVM_DLL Expr max(Expr source, Array<IterVar> axis); ...@@ -42,12 +42,14 @@ TVM_DLL Expr max(Expr source, Array<IterVar> axis);
*/ */
TVM_DLL Expr min(Expr source, Array<IterVar> axis); TVM_DLL Expr min(Expr source, Array<IterVar> axis);
// Unary intrinsic operators // Unary intrinsic operators
#define TVM_DECLARE_INTRIN_UNARY(OpName) \ #define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline Expr OpName(Expr x) { \ inline Expr OpName(Expr x) { \
return ir::Call::make(x.type(), #OpName, {x}, ir::Call::PureIntrinsic); \ return ir::Call::make(x.type(), #OpName, {x}, ir::Call::PureIntrinsic); \
} \ } \
TVM_DECLARE_INTRIN_UNARY(exp); TVM_DECLARE_INTRIN_UNARY(exp);
TVM_DECLARE_INTRIN_UNARY(tanh); TVM_DECLARE_INTRIN_UNARY(tanh);
TVM_DECLARE_INTRIN_UNARY(sigmoid); TVM_DECLARE_INTRIN_UNARY(sigmoid);
...@@ -58,7 +60,14 @@ TVM_DECLARE_INTRIN_UNARY(ceil); ...@@ -58,7 +60,14 @@ TVM_DECLARE_INTRIN_UNARY(ceil);
TVM_DECLARE_INTRIN_UNARY(round); TVM_DECLARE_INTRIN_UNARY(round);
TVM_DECLARE_INTRIN_UNARY(trunc); TVM_DECLARE_INTRIN_UNARY(trunc);
/*!
* \brief Calculate power(x, y)
* \param x The left operand.
* \param y The right operand.
*/
inline Expr pow(Expr x, Expr y) { inline Expr pow(Expr x, Expr y) {
match_types(x, y);
CHECK(x.type().is_float()) << "power only applies to float";
return ir::Call::make(x.type(), "pow", { x, y }, ir::Call::PureIntrinsic); return ir::Call::make(x.type(), "pow", { x, y }, ir::Call::PureIntrinsic);
} }
......
...@@ -67,19 +67,19 @@ class ExprOp(object): ...@@ -67,19 +67,19 @@ class ExprOp(object):
return self.__mul__(neg_one) return self.__mul__(neg_one)
def __lshift__(self, other): def __lshift__(self, other):
return _make.Call(self.dtype, "shift_left", [self, other], Call.PureIntrinsic, None, 0) return _make.left_shift(self, other)
def __rshift__(self, other): def __rshift__(self, other):
return _make.Call(self.dtype, "shift_right", [self, other], Call.PureIntrinsic, None, 0) return _make.right_shift(self, other)
def __and__(self, other): def __and__(self, other):
return _make.Call(self.dtype, "bitwise_and", [self, other], Call.PureIntrinsic, None, 0) return _make.bitwise_and(self, other)
def __or__(self, other): def __or__(self, other):
return _make.Call(self.dtype, "bitwise_or", [self, other], Call.PureIntrinsic, None, 0) return _make.bitwise_or(self, other)
def __xor__(self, other): def __xor__(self, other):
return _make.Call(self.dtype, "bitwise_xor", [self, other], Call.PureIntrinsic, None, 0) return _make.bitwise_xor(self, other)
def __invert__(self): def __invert__(self):
return _make.Call(self.dtype, "bitwise_not", [self], Call.PureIntrinsic, None, 0) return _make.Call(self.dtype, "bitwise_not", [self], Call.PureIntrinsic, None, 0)
......
...@@ -111,12 +111,25 @@ TVM_REGISTER_API("make.CommReducer") ...@@ -111,12 +111,25 @@ TVM_REGISTER_API("make.CommReducer")
*ret = Node::make(args[0], args[1], args[2], args[3], args[4]); \ *ret = Node::make(args[0], args[1], args[2], args[3], args[4]); \
}) \ }) \
#define REGISTER_MAKE_BINARY_OP(Node) \ #define REGISTER_MAKE_BINARY_OP(Node, Func) \
TVM_REGISTER_API("make."#Node) \ TVM_REGISTER_API("make."#Node) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \ .set_body([](TVMArgs args, TVMRetValue *ret) { \
Expr a = args[0], b = args[1]; \ Expr a = args[0], b = args[1]; \
match_types(a, b); \ *ret = (Func(a, b)); \
*ret = Node::make(a, b); \ })
#define REGISTER_MAKE_BIT_OP(Node, Func) \
TVM_REGISTER_API("make."#Node) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
bool lhs_is_int = args[0].type_code() == kDLInt; \
bool rhs_is_int = args[1].type_code() == kDLInt; \
if (lhs_is_int) { \
*ret = (Func(args[0].operator int(), args[1].operator Expr())); \
} else if (rhs_is_int) { \
*ret = (Func(args[0].operator Expr(), args[1].operator int())); \
} else { \
*ret = (Func(args[0].operator Expr(), args[1].operator Expr())); \
} \
}) })
REGISTER_MAKE5(Reduce); REGISTER_MAKE5(Reduce);
...@@ -126,21 +139,26 @@ REGISTER_MAKE2(IntImm); ...@@ -126,21 +139,26 @@ REGISTER_MAKE2(IntImm);
REGISTER_MAKE2(UIntImm); REGISTER_MAKE2(UIntImm);
REGISTER_MAKE2(FloatImm); REGISTER_MAKE2(FloatImm);
REGISTER_MAKE1(StringImm); REGISTER_MAKE1(StringImm);
REGISTER_MAKE_BINARY_OP(Add); REGISTER_MAKE_BINARY_OP(Add, operator+);
REGISTER_MAKE_BINARY_OP(Sub); REGISTER_MAKE_BINARY_OP(Sub, operator-);
REGISTER_MAKE_BINARY_OP(Mul); REGISTER_MAKE_BINARY_OP(Mul, operator*);
REGISTER_MAKE_BINARY_OP(Div); REGISTER_MAKE_BINARY_OP(Div, operator/);
REGISTER_MAKE_BINARY_OP(Mod); REGISTER_MAKE_BINARY_OP(Mod, operator%);
REGISTER_MAKE_BINARY_OP(Min); REGISTER_MAKE_BINARY_OP(Min, min);
REGISTER_MAKE_BINARY_OP(Max); REGISTER_MAKE_BINARY_OP(Max, max);
REGISTER_MAKE_BINARY_OP(EQ); REGISTER_MAKE_BINARY_OP(EQ, operator==);
REGISTER_MAKE_BINARY_OP(NE); REGISTER_MAKE_BINARY_OP(NE, operator!=);
REGISTER_MAKE_BINARY_OP(LT); REGISTER_MAKE_BINARY_OP(LT, operator<); // NOLINT(*)
REGISTER_MAKE_BINARY_OP(LE); REGISTER_MAKE_BINARY_OP(LE, operator<=); // NOLINT(*)
REGISTER_MAKE_BINARY_OP(GT); REGISTER_MAKE_BINARY_OP(GT, operator>); // NOLINT(*)
REGISTER_MAKE_BINARY_OP(GE); REGISTER_MAKE_BINARY_OP(GE, operator>=);
REGISTER_MAKE_BINARY_OP(And); REGISTER_MAKE_BINARY_OP(And, operator&&);
REGISTER_MAKE_BINARY_OP(Or); REGISTER_MAKE_BINARY_OP(Or, operator||);
REGISTER_MAKE_BIT_OP(bitwise_and, operator&);
REGISTER_MAKE_BIT_OP(bitwise_or, operator|);
REGISTER_MAKE_BIT_OP(bitwise_xor, operator^);
REGISTER_MAKE_BIT_OP(left_shift, operator<<); // NOLINT(*)
REGISTER_MAKE_BIT_OP(right_shift, operator>>);
REGISTER_MAKE1(Not); REGISTER_MAKE1(Not);
REGISTER_MAKE3(Select); REGISTER_MAKE3(Select);
REGISTER_MAKE3(Ramp); REGISTER_MAKE3(Ramp);
......
...@@ -10,6 +10,8 @@ def test_make(): ...@@ -10,6 +10,8 @@ def test_make():
x = tvm.const(1) x = tvm.const(1)
y = tvm.make.IntImm('int32', 1) y = tvm.make.IntImm('int32', 1)
z = x + y z = x + y
assert isinstance(tvm.max(x, y), tvm.expr.Max)
assert isinstance(tvm.min(x, y), tvm.expr.Min)
def test_ir(): def test_ir():
x = tvm.const(1) x = tvm.const(1)
...@@ -132,6 +134,9 @@ def test_bitwise(): ...@@ -132,6 +134,9 @@ def test_bitwise():
assert str(x | y) == 'bitwise_or(x, y)' assert str(x | y) == 'bitwise_or(x, y)'
assert str(x ^ y) == 'bitwise_xor(x, y)' assert str(x ^ y) == 'bitwise_xor(x, y)'
assert str(~x) == 'bitwise_not(x)' assert str(~x) == 'bitwise_not(x)'
assert(tvm.const(1, "int8x2") >> 1).dtype == "int8x2"
assert(x >> tvm.const(1, "int32x2")).dtype == "int32x2"
assert(tvm.var("z", "int8x2") << tvm.const(1, "int8x2")).dtype == "int8x2"
def test_equality(): def test_equality():
......
...@@ -64,6 +64,7 @@ def test_single_point_test(): ...@@ -64,6 +64,7 @@ def test_single_point_test():
stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb) stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
def assert_expr_equal(a, b): def assert_expr_equal(a, b):
print(a, b)
assert tvm.ir_pass.Simplify(a - b).value == 0 assert tvm.ir_pass.Simplify(a - b).value == 0
def test_copy_pad_split(): def test_copy_pad_split():
...@@ -87,6 +88,7 @@ def test_copy_pad_split(): ...@@ -87,6 +88,7 @@ def test_copy_pad_split():
def cb(src, dst, pad_before, pad_after, pad_value): def cb(src, dst, pad_before, pad_after, pad_value):
assert(dst.elem_offset.value == 0) assert(dst.elem_offset.value == 0)
assert_expr_equal(src.elem_offset, tvm.max(xo * 4, 1) - 1) assert_expr_equal(src.elem_offset, tvm.max(xo * 4, 1) - 1)
rpad_before = tvm.max(1 - xo * 4, 0) rpad_before = tvm.max(1 - xo * 4, 0)
rpad_after = tvm.max(xo * 4 - 7, 0) rpad_after = tvm.max(xo * 4 - 7, 0)
assert_expr_equal(pad_before[0], rpad_before) assert_expr_equal(pad_before[0], rpad_before)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment