Unverified commit f5f2feea by Tianqi Chen, committed by GitHub

[ARITH] migrate indexdiv/mod to floordiv/mod (#4008)

parent 2dac17d8
@@ -92,16 +92,13 @@ class ExprOp(object):
         return _generic.divide(other, self)
 
     def __floordiv__(self, other):
-        # return _generic.floordiv(self, other)
-        return _generic.divide(self, other)
+        return _generic.floordiv(self, other)
 
     def __rfloordiv__(self, other):
-        # return _generic.floordiv(other, self)
-        return _generic.divide(other, self)
+        return _generic.floordiv(other, self)
 
     def __mod__(self, other):
-        raise div_ambiguity_error()
-        # return _make._OpMod(self, other)
+        return _make._OpFloorMod(self, other)
 
     def __neg__(self):
         neg_one = _api_internal._const(-1, self.dtype)
...
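For context, the frontend operators are switching from truncating to floor semantics: floor division rounds the quotient toward negative infinity (Python's own // and %), while truncating division rounds toward zero, so the two only disagree when the operands have mixed signs. A minimal plain-Python sketch of the difference (no TVM required; truncdiv/truncmod below are illustrative helpers, not TVM APIs):

    def truncdiv(a, b):
        # C-style division: round the quotient toward zero.
        q = abs(a) // abs(b)
        return q if (a >= 0) == (b >= 0) else -q

    def truncmod(a, b):
        # Remainder paired with truncdiv: a == truncdiv(a, b) * b + truncmod(a, b).
        return a - truncdiv(a, b) * b

    # Floor semantics (Python's // and %) vs. truncation for a negative dividend.
    assert (-7 // 4, -7 % 4) == (-2, 1)
    assert (truncdiv(-7, 4), truncmod(-7, 4)) == (-1, -3)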
@@ -87,6 +87,8 @@ class AttrFunctor<R(const NodeRef& n, Args...)> {
   virtual R VisitAttr_(const ir::Mul* op, Args... args) ATTR_FUNCTOR_DEFAULT;
   virtual R VisitAttr_(const ir::Div* op, Args... args) ATTR_FUNCTOR_DEFAULT;
   virtual R VisitAttr_(const ir::Mod* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::FloorDiv* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::FloorMod* op, Args... args) ATTR_FUNCTOR_DEFAULT;
   virtual R VisitAttr_(const ir::Min* op, Args... args) ATTR_FUNCTOR_DEFAULT;
   virtual R VisitAttr_(const ir::Max* op, Args... args) ATTR_FUNCTOR_DEFAULT;
   virtual R VisitAttr_(const ir::GE* op, Args... args) ATTR_FUNCTOR_DEFAULT;
@@ -119,6 +121,9 @@ class AttrFunctor<R(const NodeRef& n, Args...)> {
     ATTR_FUNCTOR_DISPATCH(Sub);
     ATTR_FUNCTOR_DISPATCH(Mul);
     ATTR_FUNCTOR_DISPATCH(Div);
+    ATTR_FUNCTOR_DISPATCH(Mod);
+    ATTR_FUNCTOR_DISPATCH(FloorDiv);
+    ATTR_FUNCTOR_DISPATCH(FloorMod);
     ATTR_FUNCTOR_DISPATCH(Min);
     ATTR_FUNCTOR_DISPATCH(Max);
     ATTR_FUNCTOR_DISPATCH(GE);
@@ -160,6 +165,8 @@ class AttrsEqualHandler :
   bool VisitAttr_(const ir::Mul* lhs, const NodeRef& other) final;
   bool VisitAttr_(const ir::Div* lhs, const NodeRef& other) final;
   bool VisitAttr_(const ir::Mod* lhs, const NodeRef& other) final;
+  bool VisitAttr_(const ir::FloorDiv* lhs, const NodeRef& other) final;
+  bool VisitAttr_(const ir::FloorMod* lhs, const NodeRef& other) final;
   bool VisitAttr_(const ir::Min* lhs, const NodeRef& other) final;
   bool VisitAttr_(const ir::Max* lhs, const NodeRef& other) final;
   bool VisitAttr_(const ir::GE* lhs, const NodeRef& other) final;
@@ -201,6 +208,8 @@ class AttrsHashHandler :
   size_t VisitAttr_(const ir::Mul* op) final;
   size_t VisitAttr_(const ir::Div* op) final;
   size_t VisitAttr_(const ir::Mod* op) final;
+  size_t VisitAttr_(const ir::FloorDiv* op) final;
+  size_t VisitAttr_(const ir::FloorMod* op) final;
   size_t VisitAttr_(const ir::Min* op) final;
   size_t VisitAttr_(const ir::Max* op) final;
   size_t VisitAttr_(const ir::GE* op) final;
...
@@ -154,6 +154,8 @@ TVM_DEFINE_ATTRS_BINOP_EQUAL(Sub);
 TVM_DEFINE_ATTRS_BINOP_EQUAL(Mul);
 TVM_DEFINE_ATTRS_BINOP_EQUAL(Div);
 TVM_DEFINE_ATTRS_BINOP_EQUAL(Mod);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorDiv);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorMod);
 TVM_DEFINE_ATTRS_BINOP_EQUAL(Max);
 TVM_DEFINE_ATTRS_BINOP_EQUAL(Min);
 TVM_DEFINE_ATTRS_BINOP_EQUAL(GE);
@@ -266,6 +268,8 @@ TVM_DEFINE_ATTRS_BINOP_HASH(Sub);
 TVM_DEFINE_ATTRS_BINOP_HASH(Mul);
 TVM_DEFINE_ATTRS_BINOP_HASH(Div);
 TVM_DEFINE_ATTRS_BINOP_HASH(Mod);
+TVM_DEFINE_ATTRS_BINOP_HASH(FloorDiv);
+TVM_DEFINE_ATTRS_BINOP_HASH(FloorMod);
 TVM_DEFINE_ATTRS_BINOP_HASH(Max);
 TVM_DEFINE_ATTRS_BINOP_HASH(Min);
 TVM_DEFINE_ATTRS_BINOP_HASH(GE);
...
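The attrs changes above are mechanical: structural equality and hashing dispatch on the concrete IR node type, so the new FloorDiv and FloorMod expressions each need their own visitor case and binop macro instantiation, exactly like Div and Mod. A rough Python analogy of that dispatch-by-node-type pattern (illustrative only, not the TVM C++ implementation):

    class FloorDiv:
        def __init__(self, a, b):
            self.a, self.b = a, b

    class FloorMod:
        def __init__(self, a, b):
            self.a, self.b = a, b

    def attrs_equal(lhs, rhs):
        # Dispatch on the node type; binary ops then compare field by field.
        if type(lhs) is not type(rhs):
            return False
        if isinstance(lhs, (FloorDiv, FloorMod)):
            return lhs.a == rhs.a and lhs.b == rhs.b
        return lhs == rhs  # default case for plain values

    assert attrs_equal(FloorDiv(8, 2), FloorDiv(8, 2))
    assert not attrs_equal(FloorDiv(8, 2), FloorMod(8, 2))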
@@ -32,8 +32,8 @@
 namespace tvm {
 // TODO(tqchen): change to floormod/div
-using IndexMod = ir::Mod;
-using IndexDiv = ir::Div;
+using IndexMod = ir::FloorMod;
+using IndexDiv = ir::FloorDiv;
 
 Array<Expr> SimplifyArray(Array<Expr> array) {
   for (size_t i = 0; i < array.size(); ++i) {
...
@@ -208,11 +208,11 @@ Expr operator%(Expr a, Expr b) {
 
 // TODO(tqchen): switch to floordiv
 Expr indexdiv(Expr a, Expr b) {
-  return truncdiv(a, b);
+  return floordiv(a, b);
 }
 
 Expr indexmod(Expr a, Expr b) {
-  return truncmod(a, b);
+  return floormod(a, b);
 }
 
 Expr floordiv(Expr a, Expr b) {
...
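The user-visible effect of indexdiv/indexmod moving from trunc to floor semantics: for a positive divisor the remainder always lies in [0, b) and the quotient/remainder identity still holds, which is what simplification of index expressions relies on. A quick plain-Python check of those identities, using Python's // and % as stand-ins for floordiv/floormod:

    b = 4
    for a in range(-8, 9):
        q, r = a // b, a % b     # floordiv / floormod
        assert a == q * b + r    # quotient/remainder identity
        assert 0 <= r < b        # remainder is non-negative when b > 0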
@@ -46,6 +46,9 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
     patterns_.push_back("tvm.intrin.rule." + starget + ".");
     patterns_.push_back("tvm.intrin.rule.default.");
     fma_ = runtime::Registry::Get(patterns_[0] + "fma");
+    if (target == "stackvm") {
+      support_bitwise_op_ = false;
+    }
   }
 
   Expr Mutate_(const Call* op, const Expr& e) final {
@@ -76,7 +79,8 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
     const DataType& dtype = op->type;
     CHECK(dtype.is_int() || !dtype.is_uint());
-    if (is_const_power_of_two_integer(op->b, &shift)) {
+    if (support_bitwise_op_ &&
+        is_const_power_of_two_integer(op->b, &shift)) {
       // lower to right shift if possible.
       return op->a >> make_const(dtype, shift);
     }
@@ -93,7 +97,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
       // condition on b >= 0.
       // truncmod(a, b) < 0 will implies ceildiv,
       // So we need to correct these cases.
-      if (dtype == Int(32) || dtype == Int(64)) {
+      if ((dtype == Int(32) || dtype == Int(64)) && support_bitwise_op_) {
         // equivalent to rdiv + (rmod >= 0 ? 0: -1);
         return rdiv + (rmod >> make_const(dtype, dtype.bits() - 1));
       } else {
@@ -122,7 +126,8 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
     const DataType& dtype = op->type;
     CHECK(dtype.is_int() || !dtype.is_uint());
-    if (is_const_power_of_two_integer(op->b, &shift)) {
+    if (support_bitwise_op_ &&
+        is_const_power_of_two_integer(op->b, &shift)) {
       // lower to masking if possible.
       int64_t mask = (
           static_cast<int64_t>(1) << static_cast<int64_t>(shift)) - 1;
@@ -140,7 +145,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
       // mod(a, b) < 0 will imply we are doing ceildiv,
      // So we need to correct these cases.
       Expr rmod = truncmod(op->a, op->b);
-      if (dtype == Int(32) || dtype == Int(64)) {
+      if ((dtype == Int(32) || dtype == Int(64)) && support_bitwise_op_) {
         // (rmod >> shift) & b
         // -> (rmod >= 0 ? 0: -1) & b
         // -> rmod >= 0 ? 0 : b
@@ -268,6 +273,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
   // patterns
   std::vector<std::string> patterns_;
   const PackedFunc* fma_{nullptr};
+  bool support_bitwise_op_{true};
 };
 
 Stmt LowerIntrinStmt(Stmt stmt, const std::string& target) {
...
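The lowering rules guarded by support_bitwise_op_ rest on properties of floor semantics on two's-complement integers: division by 2^k is an arithmetic right shift by k, modulo 2^k is a mask with 2^k - 1, and for other positive divisors floordiv can be recovered from truncdiv by subtracting one whenever the truncated remainder is negative (the rmod >> (bits - 1) correction in the diff). The rules are skipped for stackvm, which is treated as not supporting bitwise ops. A plain-Python sanity check of these equivalences (Python's >> is an arithmetic shift, so it models the signed case):

    def floordiv_pow2(a, k):
        return a >> k              # arithmetic shift == floordiv by 2**k

    def floormod_pow2(a, k):
        return a & ((1 << k) - 1)  # masking == floormod by 2**k

    def floordiv_from_trunc(a, b):
        # Sign-correction for b > 0: subtract 1 when truncmod(a, b) < 0.
        q = abs(a) // abs(b)
        q = q if (a >= 0) == (b >= 0) else -q  # truncdiv
        r = a - q * b                          # truncmod
        return q + (-1 if r < 0 else 0)

    for a in range(-32, 33):
        assert floordiv_pow2(a, 3) == a // 8
        assert floormod_pow2(a, 3) == a % 8
        assert floordiv_from_trunc(a, 7) == a // 7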
@@ -48,6 +48,8 @@ def test_add_pipeline():
     stmt = tvm.ir_pass.Simplify(stmt)
     fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Db], 0, True)
     fsplits = [x for x in tvm.ir_pass.SplitHostDevice(fapi)]
+    # lower the floordiv(use stackvm rules so it works for all targets)
+    fsplits = [tvm.ir_pass.LowerIntrin(x, "stackvm") for x in fsplits]
     fsplits[0] = tvm.ir_pass.LowerTVMBuiltin(fsplits[0])
 
     def check_target(device, host="stackvm"):
...
@@ -37,6 +37,7 @@ def test_stack_vm_basic():
     stmt = tvm.make.Evaluate(tvm.call_packed("tvm_call_back_get_shape", Ab.shape[0]))
     fapi = tvm.ir_pass.MakeAPI(stmt, "print_shape", [Ab], 0, True)
     fapi = tvm.ir_pass.LowerTVMBuiltin(fapi)
+    fapi = tvm.ir_pass.LowerIntrin(fapi, "stackvm")
     run_jit(fapi, lambda f: f(a))
...
@@ -185,7 +185,7 @@ def get_valid_counts_scan(data, partial_in, partial):
     ib.scope_attr(bx, "thread_extent", nthread_bx)
     var = tvm.make.node("FloatImm", dtype="float32", value=2)
     new_range = num_anchors // elem_per_thread + 1
-    iteration = log(cast(new_range, "float32")) // math.log(2)
+    iteration = cast(log(cast(new_range, "float32")) / math.log(2), "int32")
     # Scan: Kogge-Stone adder
     with ib.if_scope(tvm.all(bx < batch_size, tx < tvm.min(new_range, num_anchors))):
         with ib.for_range(0, iteration) as k:
...
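The topi change keeps the Kogge-Stone iteration count an integer expression: the log-base-2 is computed with a plain float division and then cast to int32, rather than floor-dividing two float values, presumably so the for_range extent is integral. A plain-Python illustration of the value being computed (the new_range value here is hypothetical):

    import math

    new_range = 33
    iteration = int(math.log(float(new_range)) / math.log(2))  # floor(log2(33)) == 5
    assert iteration == 5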