Unverified Commit f5f2feea by Tianqi Chen Committed by GitHub

[ARITH] migrate indexdiv/mod to floordiv/mod (#4008)

parent 2dac17d8
......@@ -92,16 +92,13 @@ class ExprOp(object):
return _generic.divide(other, self)
def __floordiv__(self, other):
# return _generic.floordiv(self, other)
return _generic.divide(self, other)
return _generic.floordiv(self, other)
def __rfloordiv__(self, other):
# return _generic.floordiv(other, self)
return _generic.divide(other, self)
return _generic.floordiv(other, self)
def __mod__(self, other):
raise div_ambiguity_error()
# return _make._OpMod(self, other)
return _make._OpFloorMod(self, other)
def __neg__(self):
neg_one = _api_internal._const(-1, self.dtype)
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -87,6 +87,8 @@ class AttrFunctor<R(const NodeRef& n, Args...)> {
virtual R VisitAttr_(const ir::Mul* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Div* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Mod* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::FloorDiv* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::FloorMod* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Min* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Max* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::GE* op, Args... args) ATTR_FUNCTOR_DEFAULT;
......@@ -119,6 +121,9 @@ class AttrFunctor<R(const NodeRef& n, Args...)> {
ATTR_FUNCTOR_DISPATCH(Sub);
ATTR_FUNCTOR_DISPATCH(Mul);
ATTR_FUNCTOR_DISPATCH(Div);
ATTR_FUNCTOR_DISPATCH(Mod);
ATTR_FUNCTOR_DISPATCH(FloorDiv);
ATTR_FUNCTOR_DISPATCH(FloorMod);
ATTR_FUNCTOR_DISPATCH(Min);
ATTR_FUNCTOR_DISPATCH(Max);
ATTR_FUNCTOR_DISPATCH(GE);
......@@ -160,6 +165,8 @@ class AttrsEqualHandler :
bool VisitAttr_(const ir::Mul* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Div* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Mod* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::FloorDiv* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::FloorMod* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Min* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Max* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::GE* lhs, const NodeRef& other) final;
......@@ -201,6 +208,8 @@ class AttrsHashHandler :
size_t VisitAttr_(const ir::Mul* op) final;
size_t VisitAttr_(const ir::Div* op) final;
size_t VisitAttr_(const ir::Mod* op) final;
size_t VisitAttr_(const ir::FloorDiv* op) final;
size_t VisitAttr_(const ir::FloorMod* op) final;
size_t VisitAttr_(const ir::Min* op) final;
size_t VisitAttr_(const ir::Max* op) final;
size_t VisitAttr_(const ir::GE* op) final;
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -154,6 +154,8 @@ TVM_DEFINE_ATTRS_BINOP_EQUAL(Sub);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Mul);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Div);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Mod);
TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorDiv);
TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorMod);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Max);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Min);
TVM_DEFINE_ATTRS_BINOP_EQUAL(GE);
......@@ -266,6 +268,8 @@ TVM_DEFINE_ATTRS_BINOP_HASH(Sub);
TVM_DEFINE_ATTRS_BINOP_HASH(Mul);
TVM_DEFINE_ATTRS_BINOP_HASH(Div);
TVM_DEFINE_ATTRS_BINOP_HASH(Mod);
TVM_DEFINE_ATTRS_BINOP_HASH(FloorDiv);
TVM_DEFINE_ATTRS_BINOP_HASH(FloorMod);
TVM_DEFINE_ATTRS_BINOP_HASH(Max);
TVM_DEFINE_ATTRS_BINOP_HASH(Min);
TVM_DEFINE_ATTRS_BINOP_HASH(GE);
......
......@@ -32,8 +32,8 @@
namespace tvm {
// TODO(tqchen): change to floormod/div
using IndexMod = ir::Mod;
using IndexDiv = ir::Div;
using IndexMod = ir::FloorMod;
using IndexDiv = ir::FloorDiv;
Array<Expr> SimplifyArray(Array<Expr> array) {
for (size_t i = 0; i < array.size(); ++i) {
......
......@@ -208,11 +208,11 @@ Expr operator%(Expr a, Expr b) {
// TODO(tqchen): switch to floordiv
Expr indexdiv(Expr a, Expr b) {
return truncdiv(a, b);
return floordiv(a, b);
}
Expr indexmod(Expr a, Expr b) {
return truncmod(a, b);
return floormod(a, b);
}
Expr floordiv(Expr a, Expr b) {
......
......@@ -46,6 +46,9 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
patterns_.push_back("tvm.intrin.rule." + starget + ".");
patterns_.push_back("tvm.intrin.rule.default.");
fma_ = runtime::Registry::Get(patterns_[0] + "fma");
if (target == "stackvm") {
support_bitwise_op_ = false;
}
}
Expr Mutate_(const Call* op, const Expr& e) final {
......@@ -76,7 +79,8 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
const DataType& dtype = op->type;
CHECK(dtype.is_int() || !dtype.is_uint());
if (is_const_power_of_two_integer(op->b, &shift)) {
if (support_bitwise_op_ &&
is_const_power_of_two_integer(op->b, &shift)) {
// lower to right shift if possible.
return op->a >> make_const(dtype, shift);
}
......@@ -93,7 +97,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
// condition on b >= 0.
// truncmod(a, b) < 0 will implies ceildiv,
// So we need to correct these cases.
if (dtype == Int(32) || dtype == Int(64)) {
if ((dtype == Int(32) || dtype == Int(64)) && support_bitwise_op_) {
// equivalent to rdiv + (rmod >= 0 ? 0: -1);
return rdiv + (rmod >> make_const(dtype, dtype.bits() - 1));
} else {
......@@ -122,7 +126,8 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
const DataType& dtype = op->type;
CHECK(dtype.is_int() || !dtype.is_uint());
if (is_const_power_of_two_integer(op->b, &shift)) {
if (support_bitwise_op_ &&
is_const_power_of_two_integer(op->b, &shift)) {
// lower to masking if possible.
int64_t mask = (
static_cast<int64_t>(1) << static_cast<int64_t>(shift)) - 1;
......@@ -140,7 +145,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
// mod(a, b) < 0 will imply we are doing ceildiv,
// So we need to correct these cases.
Expr rmod = truncmod(op->a, op->b);
if (dtype == Int(32) || dtype == Int(64)) {
if ((dtype == Int(32) || dtype == Int(64)) && support_bitwise_op_) {
// (rmod >> shift) & b
// -> (rmod >= 0 ? 0: -1) & b
// -> rmod >= 0 ? 0 : b
......@@ -268,6 +273,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
// patterns
std::vector<std::string> patterns_;
const PackedFunc* fma_{nullptr};
bool support_bitwise_op_{true};
};
Stmt LowerIntrinStmt(Stmt stmt, const std::string& target) {
......
......@@ -48,6 +48,8 @@ def test_add_pipeline():
stmt = tvm.ir_pass.Simplify(stmt)
fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Db], 0, True)
fsplits = [x for x in tvm.ir_pass.SplitHostDevice(fapi)]
# lower the floordiv(use stackvm rules so it works for all targets)
fsplits = [tvm.ir_pass.LowerIntrin(x, "stackvm") for x in fsplits]
fsplits[0] = tvm.ir_pass.LowerTVMBuiltin(fsplits[0])
def check_target(device, host="stackvm"):
......
......@@ -37,6 +37,7 @@ def test_stack_vm_basic():
stmt = tvm.make.Evaluate(tvm.call_packed("tvm_call_back_get_shape", Ab.shape[0]))
fapi = tvm.ir_pass.MakeAPI(stmt, "print_shape", [Ab], 0, True)
fapi = tvm.ir_pass.LowerTVMBuiltin(fapi)
fapi = tvm.ir_pass.LowerIntrin(fapi, "stackvm")
run_jit(fapi, lambda f: f(a))
......
......@@ -185,7 +185,7 @@ def get_valid_counts_scan(data, partial_in, partial):
ib.scope_attr(bx, "thread_extent", nthread_bx)
var = tvm.make.node("FloatImm", dtype="float32", value=2)
new_range = num_anchors // elem_per_thread + 1
iteration = log(cast(new_range, "float32")) // math.log(2)
iteration = cast(log(cast(new_range, "float32")) / math.log(2), "int32")
# Scan: Kogge-Stone adder
with ib.if_scope(tvm.all(bx < batch_size, tx < tvm.min(new_range, num_anchors))):
with ib.for_range(0, iteration) as k:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment