Commit 69b23d9e by Tianqi Chen Committed by GitHub

[OP] Improve bitwise op type checks (#1415)

parent 84824ae3
......@@ -42,12 +42,14 @@ TVM_DLL Expr max(Expr source, Array<IterVar> axis);
*/
TVM_DLL Expr min(Expr source, Array<IterVar> axis);
// Unary intrinsic operators
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline Expr OpName(Expr x) { \
return ir::Call::make(x.type(), #OpName, {x}, ir::Call::PureIntrinsic); \
} \
TVM_DECLARE_INTRIN_UNARY(exp);
TVM_DECLARE_INTRIN_UNARY(tanh);
TVM_DECLARE_INTRIN_UNARY(sigmoid);
......@@ -58,7 +60,14 @@ TVM_DECLARE_INTRIN_UNARY(ceil);
TVM_DECLARE_INTRIN_UNARY(round);
TVM_DECLARE_INTRIN_UNARY(trunc);
/*!
* \brief Calculate power(x, y)
* \param x The left operand.
* \param y The right operand.
*/
inline Expr pow(Expr x, Expr y) {
match_types(x, y);
CHECK(x.type().is_float()) << "power only applies to float";
return ir::Call::make(x.type(), "pow", { x, y }, ir::Call::PureIntrinsic);
}
......
......@@ -67,19 +67,19 @@ class ExprOp(object):
return self.__mul__(neg_one)
def __lshift__(self, other):
return _make.Call(self.dtype, "shift_left", [self, other], Call.PureIntrinsic, None, 0)
return _make.left_shift(self, other)
def __rshift__(self, other):
return _make.Call(self.dtype, "shift_right", [self, other], Call.PureIntrinsic, None, 0)
return _make.right_shift(self, other)
def __and__(self, other):
return _make.Call(self.dtype, "bitwise_and", [self, other], Call.PureIntrinsic, None, 0)
return _make.bitwise_and(self, other)
def __or__(self, other):
return _make.Call(self.dtype, "bitwise_or", [self, other], Call.PureIntrinsic, None, 0)
return _make.bitwise_or(self, other)
def __xor__(self, other):
return _make.Call(self.dtype, "bitwise_xor", [self, other], Call.PureIntrinsic, None, 0)
return _make.bitwise_xor(self, other)
def __invert__(self):
return _make.Call(self.dtype, "bitwise_not", [self], Call.PureIntrinsic, None, 0)
......
......@@ -111,12 +111,25 @@ TVM_REGISTER_API("make.CommReducer")
*ret = Node::make(args[0], args[1], args[2], args[3], args[4]); \
}) \
#define REGISTER_MAKE_BINARY_OP(Node) \
#define REGISTER_MAKE_BINARY_OP(Node, Func) \
TVM_REGISTER_API("make."#Node) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
Expr a = args[0], b = args[1]; \
match_types(a, b); \
*ret = Node::make(a, b); \
*ret = (Func(a, b)); \
})
#define REGISTER_MAKE_BIT_OP(Node, Func) \
TVM_REGISTER_API("make."#Node) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
bool lhs_is_int = args[0].type_code() == kDLInt; \
bool rhs_is_int = args[1].type_code() == kDLInt; \
if (lhs_is_int) { \
*ret = (Func(args[0].operator int(), args[1].operator Expr())); \
} else if (rhs_is_int) { \
*ret = (Func(args[0].operator Expr(), args[1].operator int())); \
} else { \
*ret = (Func(args[0].operator Expr(), args[1].operator Expr())); \
} \
})
REGISTER_MAKE5(Reduce);
......@@ -126,21 +139,26 @@ REGISTER_MAKE2(IntImm);
REGISTER_MAKE2(UIntImm);
REGISTER_MAKE2(FloatImm);
REGISTER_MAKE1(StringImm);
REGISTER_MAKE_BINARY_OP(Add);
REGISTER_MAKE_BINARY_OP(Sub);
REGISTER_MAKE_BINARY_OP(Mul);
REGISTER_MAKE_BINARY_OP(Div);
REGISTER_MAKE_BINARY_OP(Mod);
REGISTER_MAKE_BINARY_OP(Min);
REGISTER_MAKE_BINARY_OP(Max);
REGISTER_MAKE_BINARY_OP(EQ);
REGISTER_MAKE_BINARY_OP(NE);
REGISTER_MAKE_BINARY_OP(LT);
REGISTER_MAKE_BINARY_OP(LE);
REGISTER_MAKE_BINARY_OP(GT);
REGISTER_MAKE_BINARY_OP(GE);
REGISTER_MAKE_BINARY_OP(And);
REGISTER_MAKE_BINARY_OP(Or);
REGISTER_MAKE_BINARY_OP(Add, operator+);
REGISTER_MAKE_BINARY_OP(Sub, operator-);
REGISTER_MAKE_BINARY_OP(Mul, operator*);
REGISTER_MAKE_BINARY_OP(Div, operator/);
REGISTER_MAKE_BINARY_OP(Mod, operator%);
REGISTER_MAKE_BINARY_OP(Min, min);
REGISTER_MAKE_BINARY_OP(Max, max);
REGISTER_MAKE_BINARY_OP(EQ, operator==);
REGISTER_MAKE_BINARY_OP(NE, operator!=);
REGISTER_MAKE_BINARY_OP(LT, operator<); // NOLINT(*)
REGISTER_MAKE_BINARY_OP(LE, operator<=); // NOLINT(*)
REGISTER_MAKE_BINARY_OP(GT, operator>); // NOLINT(*)
REGISTER_MAKE_BINARY_OP(GE, operator>=);
REGISTER_MAKE_BINARY_OP(And, operator&&);
REGISTER_MAKE_BINARY_OP(Or, operator||);
REGISTER_MAKE_BIT_OP(bitwise_and, operator&);
REGISTER_MAKE_BIT_OP(bitwise_or, operator|);
REGISTER_MAKE_BIT_OP(bitwise_xor, operator^);
REGISTER_MAKE_BIT_OP(left_shift, operator<<); // NOLINT(*)
REGISTER_MAKE_BIT_OP(right_shift, operator>>);
REGISTER_MAKE1(Not);
REGISTER_MAKE3(Select);
REGISTER_MAKE3(Ramp);
......
......@@ -10,6 +10,8 @@ def test_make():
x = tvm.const(1)
y = tvm.make.IntImm('int32', 1)
z = x + y
assert isinstance(tvm.max(x, y), tvm.expr.Max)
assert isinstance(tvm.min(x, y), tvm.expr.Min)
def test_ir():
x = tvm.const(1)
......@@ -132,6 +134,9 @@ def test_bitwise():
assert str(x | y) == 'bitwise_or(x, y)'
assert str(x ^ y) == 'bitwise_xor(x, y)'
assert str(~x) == 'bitwise_not(x)'
assert(tvm.const(1, "int8x2") >> 1).dtype == "int8x2"
assert(x >> tvm.const(1, "int32x2")).dtype == "int32x2"
assert(tvm.var("z", "int8x2") << tvm.const(1, "int8x2")).dtype == "int8x2"
def test_equality():
......
......@@ -64,6 +64,7 @@ def test_single_point_test():
stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
def assert_expr_equal(a, b):
print(a, b)
assert tvm.ir_pass.Simplify(a - b).value == 0
def test_copy_pad_split():
......@@ -87,6 +88,7 @@ def test_copy_pad_split():
def cb(src, dst, pad_before, pad_after, pad_value):
assert(dst.elem_offset.value == 0)
assert_expr_equal(src.elem_offset, tvm.max(xo * 4, 1) - 1)
rpad_before = tvm.max(1 - xo * 4, 0)
rpad_after = tvm.max(xo * 4 - 7, 0)
assert_expr_equal(pad_before[0], rpad_before)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment