Unverified Commit d7a09150 by Tianqi Chen Committed by GitHub

[ARITH] Add Lowering rule for FloorDiv/Mod (#3976)

* [ARITH] Add Lowering rule for FloorDiv/Mod

* add comment about constant folding
parent 719d6d47
......@@ -333,6 +333,30 @@ TVM_DLL Expr operator||(Expr a, Expr b);
*/
TVM_DLL Expr operator!(Expr a);
/*!
* \brief compute trunc(a / b)
*
* This is the default integer division behavior in C.
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr truncdiv(Expr a, Expr b);
/*!
* \brief compute the remainder of truncdiv
*
 * This is the behavior of the C remainder operator (%), i.e. the
 * remainder of truncated division.
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr truncmod(Expr a, Expr b);
/*!
* \brief compute floor(a / b)
*
* \param a left operand
......
......@@ -891,6 +891,52 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
return reducer
def truncdiv(a, b):
    """Compute the truncdiv (C-style integer division) of two expressions.

    The result is rounded toward zero, matching the ``/`` operator on
    integers in C.

    Parameters
    ----------
    a : Expr
        The left hand operand

    b : Expr
        The right hand operand

    Returns
    -------
    res : Expr
        The result expression.

    Note
    ----
    This is the default integer division behavior in C.
    """
    return _make._OpTruncDiv(a, b)
def truncmod(a, b):
    """Compute the truncmod (remainder of truncdiv) of two expressions.

    The sign of the result follows the dividend, matching the ``%``
    operator on integers in C.

    Parameters
    ----------
    a : Expr
        The left hand operand

    b : Expr
        The right hand operand

    Returns
    -------
    res : Expr
        The result expression.

    Note
    ----
    This is the default integer remainder (%) behavior in C.
    """
    return _make._OpTruncMod(a, b)
def floordiv(a, b):
"""Compute the floordiv of two expressions.
......
......@@ -196,6 +196,8 @@ REGISTER_MAKE_BINARY_OP(_OpDiv, operator/);
REGISTER_MAKE_BINARY_OP(_OpMod, operator%);
REGISTER_MAKE_BINARY_OP(_OpFloorDiv, floordiv);
REGISTER_MAKE_BINARY_OP(_OpFloorMod, floormod);
REGISTER_MAKE_BINARY_OP(_OpTruncDiv, truncdiv);
REGISTER_MAKE_BINARY_OP(_OpTruncMod, truncmod);
REGISTER_MAKE_BINARY_OP(_OpPow, pow);
REGISTER_MAKE_BINARY_OP(_OpMin, min);
REGISTER_MAKE_BINARY_OP(_OpMax, max);
......
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2019 by Contributors
* \file const_fold.h
* \brief Centralized location for constant folding.
*/
......
......@@ -99,11 +99,12 @@ inline bool WillOverflow<ir::Mod>(int64_t x,
* \return the result.
*/
/*!
 * \brief Compute floor(x / y) for int64 operands.
 *
 * C++ integer division truncates toward zero, so the truncated quotient
 * must be decremented by one whenever the remainder is nonzero and the
 * operands have opposite signs.
 *
 * \param x The dividend.
 * \param y The divisor; must be nonzero (division by zero is UB).
 * \return floor(x / y).
 */
inline int64_t floordiv(int64_t x, int64_t y) {
  int64_t rdiv = x / y;
  int64_t rmod = x % y;
  // The truncated quotient already equals the floor quotient when the
  // remainder has the same sign as the divisor (or is zero).
  bool is_floor_div =
      (y >= 0 && rmod >= 0) ||
      (y < 0 && rmod <= 0);
  return is_floor_div ? rdiv : (rdiv - 1);
}
......@@ -114,11 +115,11 @@ inline int64_t floordiv(int64_t x, int64_t y) {
* \return the result.
*/
/*!
 * \brief Compute the floor-modulo of int64 operands: x - floordiv(x, y) * y.
 *
 * The C++ % operator returns a remainder with the sign of the dividend;
 * floor-modulo instead takes the sign of the divisor, so the truncated
 * remainder is shifted by y when the signs disagree.
 *
 * \param x The dividend.
 * \param y The divisor; must be nonzero (division by zero is UB).
 * \return The floor-modulo, with the same sign as y (or zero).
 */
inline int64_t floormod(int64_t x, int64_t y) {
  int64_t rmod = x % y;
  // The truncated remainder is already the floor-modulo when it has the
  // same sign as the divisor (or is zero); otherwise add y back.
  bool is_floor_div =
      (y >= 0 && rmod >= 0) ||
      (y < 0 && rmod <= 0);
  return is_floor_div ? rmod : rmod + y;
}
} // namespace arith
......
......@@ -41,8 +41,9 @@ Mutate_(const LetStmt* op, const Stmt& s) {
Expr value = this->Mutate(op->value);
if (!ir::HasSideEffect(value)) {
analyzer_->Bind(op->var, value);
return this->Mutate(op->body);
}
// We keep the let-binding here
// as sub-class may or maynot choose to replace it.
Stmt body = this->Mutate(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
......@@ -152,8 +153,9 @@ Mutate_(const Let* op, const Expr& self) {
Expr value = this->Mutate(op->value);
if (!ir::HasSideEffect(value)) {
analyzer_->Bind(op->var, value);
return this->Mutate(op->body);
}
// We keep the let-binding here
// as sub-class may or maynot choose to replace it.
Expr body = this->Mutate(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
......
......@@ -45,6 +45,8 @@ class IRMutatorWithAnalyzer : public ir::IRMutator {
explicit IRMutatorWithAnalyzer(Analyzer* analyzer)
: analyzer_(analyzer) {}
using IRMutator::Mutate_;
// override functions that need to populate the context information.
Stmt Mutate_(const ir::For* op, const Stmt& self) override;
Stmt Mutate_(const ir::LetStmt* op, const Stmt& self) override;
......
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2019 by Contributors
* \file tvm/arithmetic/pattern_match.h
*
* \brief Internal tool for expression-template based pattern matching.
......@@ -326,6 +325,8 @@ TVM_PATTERN_BINARY_OP(operator/, ir::Div);
TVM_PATTERN_BINARY_OP(operator%, ir::Mod);
TVM_PATTERN_BINARY_OP(min, ir::Min);
TVM_PATTERN_BINARY_OP(max, ir::Max);
TVM_PATTERN_BINARY_OP(truncdiv, ir::Div);
TVM_PATTERN_BINARY_OP(truncmod, ir::Mod);
TVM_PATTERN_BINARY_OP(floordiv, ir::FloorDiv);
TVM_PATTERN_BINARY_OP(floormod, ir::FloorMod);
......
......@@ -1674,6 +1674,16 @@ Mutate_(const Call* op, const Expr& self) {
if (op == nullptr) return ret;
if (op->is_intrinsic(Call::likely) && is_const(op->args[0])) {
return op->args[0];
} else if (op->is_intrinsic(Call::shift_right)) {
if (op->args[0].as<IntImm>() && op->args[1].as<IntImm>()) {
// the operator overload will eagerly constant fold.
return op->args[0] >> op->args[1];
}
} else if (op->is_intrinsic(Call::bitwise_and)) {
if (op->args[0].as<IntImm>() && op->args[1].as<IntImm>()) {
// the operator overload will eagerly constant fold.
return op->args[0] & op->args[1];
}
}
return ret;
}
......@@ -1695,6 +1705,24 @@ Mutate_(const Cast* op, const Expr& self) {
return cast(op->type, op->value);
}
Expr RewriteSimplifier::Impl::
Mutate_(const Let* op, const Expr& self) {
  Expr new_value = this->Mutate(op->value);
  if (!ir::HasSideEffect(new_value)) {
    // The bound value is side-effect free; binding it in the analyzer
    // means every later occurrence of the variable gets inlined, so the
    // let node itself can be discarded.
    analyzer_->Bind(op->var, new_value);
    return this->Mutate(op->body);
  }
  // Value has side effects: the binding must be preserved.
  Expr new_body = this->Mutate(op->body);
  bool unchanged =
      new_value.same_as(op->value) && new_body.same_as(op->body);
  if (!unchanged) {
    return Let::make(op->var, new_value, new_body);
  }
  return self;
}
Expr RewriteSimplifier::operator()(const Expr& expr) {
// Run simplification in post order
Expr res = expr;
......
......@@ -72,6 +72,7 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
Expr Mutate_(const Call* op, const Expr& self) override;
Expr Mutate_(const Variable* op, const Expr& self) override;
Expr Mutate_(const Cast* op, const Expr& self) override;
Expr Mutate_(const Let* op, const Expr& self) override;
protected:
/*! \brief internal structure for comparison. */
......
......@@ -51,6 +51,23 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
return Mutate(stmt);
}
Stmt Mutate_(const LetStmt* op, const Stmt& s) {
Expr value = this->Mutate(op->value);
if (!ir::HasSideEffect(value)) {
// it is fine to discard the let binding
// because the call to simplify will always inline the var.
analyzer_->Bind(op->var, value);
return Mutate(op->body);
}
Stmt body = this->Mutate(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return s;
} else {
return LetStmt::make(op->var, value, body);
}
}
// eliminate useless stores
Stmt Mutate_(const Store* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
......
......@@ -178,20 +178,28 @@ Expr operator*(Expr a, Expr b) {
return ir::Mul::make(a, b);
}
// Make a C-style truncated-division node.
// Eagerly constant folds for index types (int32/int64) when possible.
// (Removed a stray duplicate `Expr operator/` signature line left over
// from the diff; operator/ is defined below and forwards here.)
Expr truncdiv(Expr a, Expr b) {
  BinaryOpMatchTypes(a, b);
  Expr ret = arith::TryConstFold<ir::Div>(a, b);
  if (ret.defined()) return ret;
  return ir::Div::make(a, b);
}
// Make a C-style truncated-remainder node.
// Eagerly constant folds for index types (int32/int64) when possible.
// (Removed a stray duplicate `Expr operator%` signature line left over
// from the diff; operator% is defined below and forwards here.)
Expr truncmod(Expr a, Expr b) {
  BinaryOpMatchTypes(a, b);
  Expr ret = arith::TryConstFold<ir::Mod>(a, b);
  if (ret.defined()) return ret;
  return ir::Mod::make(a, b);
}
// Overload: Expr / Expr keeps the default C semantics (truncated division).
Expr operator/(Expr a, Expr b) {
  return truncdiv(a, b);
}
// Overload: Expr % Expr keeps the default C semantics (truncated remainder).
Expr operator%(Expr a, Expr b) {
  return truncmod(a, b);
}
Expr floordiv(Expr a, Expr b) {
BinaryOpMatchTypes(a, b);
Expr ret = arith::TryConstFold<ir::FloorDiv>(a, b);
......
......@@ -18,23 +18,28 @@
*/
/*!
* Copyright (c) 2017 by Contributors
* Lower intrinsic calls to device specific ir when possible.
* Lower intrinsic calls and ops to device specific ir when possible.
* \file lower_intrin.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <tvm/api_registry.h>
#include <tvm/expr_operator.h>
#include <unordered_set>
#include "ir_util.h"
#include "../arithmetic/pattern_match.h"
#include "../arithmetic/ir_mutator_with_analyzer.h"
namespace tvm {
namespace ir {
class IntrinInjecter : public IRMutator {
class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
public:
explicit IntrinInjecter(std::string target) {
using IRMutatorWithAnalyzer::Mutate_;
IntrinInjecter(arith::Analyzer* analyzer, std::string target)
: IRMutatorWithAnalyzer(analyzer) {
std::istringstream is(target);
std::string starget;
is >> starget;
......@@ -61,6 +66,118 @@ class IntrinInjecter : public IRMutator {
return IRMutator::Mutate_(op, e);
}
// We use floordiv for integer analysis,
// but will need to lower them to native truncdiv instructions
Expr Mutate_(const FloorDiv* op, const Expr& e) final {
Expr ret = IRMutatorWithAnalyzer::Mutate_(op, e);
op = ret.as<FloorDiv>();
if (op == nullptr) return ret;
int shift;
const DataType& dtype = op->type;
if (dtype.is_float()) {
return floor(Div::make(op->a, op->b));
}
CHECK(dtype.is_int() || !dtype.is_uint());
if (is_const_power_of_two_integer(op->b, &shift)) {
// lower to right shift if possible.
return op->a >> make_const(dtype, shift);
}
if (analyzer_->CanProveGreaterEqual(op->b, 0)) {
// Common path, positive divisor
if (analyzer_->CanProveGreaterEqual(op->a, 0) ||
analyzer_->CanProveGreaterEqual(e, 0)) {
return truncdiv(op->a, op->b);
} else {
DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divident";
Expr rdiv = truncdiv(op->a, op->b);
Expr rmod = truncmod(op->a, op->b);
// condition on b >= 0.
// truncmod(a, b) < 0 will implies ceildiv,
// So we need to correct these cases.
if (dtype == Int(32) || dtype == Int(64)) {
// equivalent to rdiv + (rmod >= 0 ? 0: -1);
return rdiv + (rmod >> make_const(dtype, dtype.bits() - 1));
} else {
return ir::Select::make(rmod >= 0 , rdiv, rdiv - make_const(dtype, 1));
}
}
} else {
// uncommon case
DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divisor";
// b >= 0 => (rmod >=0 ? rdiv : rdiv - 1)
// b < 0 => (rmod <= 0 ? rdiv : rdiv - 1)
Expr rdiv = truncdiv(op->a, op->b);
Expr rmod = truncmod(op->a, op->b);
return ir::Select::make(
(op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0),
rdiv, rdiv - make_const(dtype, 1));
}
}
Expr Mutate_(const FloorMod* op, const Expr& e) final {
Expr ret = IRMutatorWithAnalyzer::Mutate_(op, e);
op = ret.as<FloorMod>();
if (op == nullptr) return ret;
// Lower floordiv to native truncdiv.
int shift;
const DataType& dtype = op->type;
CHECK(dtype.is_int() || !dtype.is_uint());
if (is_const_power_of_two_integer(op->b, &shift)) {
// lower to masking if possible.
int64_t mask = (
static_cast<int64_t>(1) << static_cast<int64_t>(shift)) - 1;
return op->a & make_const(dtype, mask);
}
if (analyzer_->CanProveGreaterEqual(op->b, 0)) {
// Common pass, positive divisor
if (analyzer_->CanProveGreaterEqual(op->a, 0) ||
analyzer_->CanProveGreaterEqual(e, 0)) {
return truncmod(op->a, op->b);
} else {
DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divident";
// NOTE:condition on b >= 0.
// mod(a, b) < 0 will imply we are doing ceildiv,
// So we need to correct these cases.
Expr rmod = truncmod(op->a, op->b);
if (dtype == Int(32) || dtype == Int(64)) {
// (rmod >> shift) & b
// -> (rmod >= 0 ? 0: -1) & b
// -> rmod >= 0 ? 0 : b
return rmod + (op->b & (rmod >> make_const(dtype, dtype.bits() - 1)));
} else {
return ir::Select::make(rmod >= 0, rmod, rmod + op->b);
}
}
} else {
// uncommon case
DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divsor and divident";
Expr rmod = truncmod(op->a, op->b);
// b > 0 && rmod >= 0 -> rmod
// b > 0 && rmod < 0 -> rmod + b
// b < 0 && rmod < 0 -> rmod
// b < 0 && rmod > 0 -> rmod + b
return ir::Select::make(
(op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0),
rmod, rmod + op->b);
}
}
  // Rewrite max(floordiv(x, y), c) to use truncdiv when c >= 0 and y >= 0.
  // Rationale: for y >= 0, truncdiv differs from floordiv only when the
  // quotient is negative (truncdiv = floordiv + 1 <= 0), and in that case
  // both quotients are clamped to c >= 0 by the max -- so the results agree.
  Expr Mutate_(const Max* op, const Expr& e) final {
    using namespace arith;
    PVar<Expr> x, y;
    PVar<Integer> c;
    if (max(floordiv(x, y), c).Match(e) &&
        c.Eval()->value >= 0 &&
        analyzer_->CanProveGreaterEqual(y.Eval(), 0)) {
      // Recurse so the truncdiv itself gets further lowering if needed.
      return max(Mutate(truncdiv(x, y).Eval()), c.Eval());
    }
    return IRMutatorWithAnalyzer::Mutate_(op, e);
  }
private:
Expr SwapBroadcastCast(const Expr& e) {
// Try to change broadcast(cast(x)) to cast(broadcast(x))
......@@ -132,17 +249,27 @@ class IntrinInjecter : public IRMutator {
}
return Expr();
}
// patterns
std::vector<std::string> patterns_;
const PackedFunc* fma_{nullptr};
};
// Lower intrinsic calls and floor-div/mod ops inside a statement
// for the given target string.
Stmt LowerIntrinStmt(Stmt stmt, const std::string& target) {
  // Each lowering run owns a fresh analyzer scope.
  arith::Analyzer analyzer;
  IntrinInjecter injecter(&analyzer, target);
  return injecter.Mutate(stmt);
}
// Lower intrinsics in a LoweredFunc body; returns a new function node
// so the input LoweredFunc is left untouched.
LoweredFunc
LowerIntrin(LoweredFunc f, const std::string& target) {
  auto fnode = make_node<LoweredFuncNode>(*f.operator->());
  fnode->body = LowerIntrinStmt(fnode->body, target);
  return LoweredFunc(fnode);
}
// Register the api only for test purposes
// (exposed to Python as tvm.ir_pass._LowerIntrinStmt).
TVM_REGISTER_API("ir_pass._LowerIntrinStmt")
.set_body_typed(LowerIntrinStmt);
} // namespace ir
} // namespace tvm
......@@ -87,6 +87,7 @@ def test_llvm_lookup_intrin():
func = tvm.ir_pass.MakeAPI(body, "ctpop", [A], 1, True)
fcode = tvm.build(func, None, "llvm")
def test_llvm_add_pipeline():
nn = 1024
n = tvm.convert(nn)
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import numpy as np
def lower_intrin(stmt):
    """Wrapper to run CanonicalSimplify then _LowerIntrinStmt (llvm target).

    Accepts either an Expr or a Stmt; an Expr is wrapped in an Evaluate
    statement before lowering and unwrapped again afterwards.
    """
    lower_expr = isinstance(stmt, tvm.expr.Expr)
    stmt = tvm.stmt.Evaluate(stmt) if lower_expr else stmt
    stmt = tvm.ir_pass.CanonicalSimplify(stmt)
    stmt = tvm.ir_pass._LowerIntrinStmt(stmt, "llvm")
    return stmt.value if lower_expr else stmt.body
def check_value(expr, vx, vy, data, fref):
    """Compile expr into an elementwise LLVM kernel and compare with fref.

    Parameters
    ----------
    expr : Expr
        Expression written in terms of the free variables vx and vy.
    vx, vy : Var
        Free variables, bound per element to A[i] and B[i] via Let.
    data : list of (int, int)
        Input pairs fed through the compiled kernel.
    fref : callable
        Python reference function applied pairwise to data.
    """
    n = len(data)
    A = tvm.placeholder((n,), name="A", dtype=expr.dtype)
    B = tvm.placeholder((n,), name="B", dtype=expr.dtype)

    def make_binds(i):
        # Substitute the i-th input elements for the free variables.
        x = expr
        x = tvm.expr.Let(vx, A[i], x)
        x = tvm.expr.Let(vy, B[i], x)
        return x

    C = tvm.compute((n,), make_binds)
    s = tvm.create_schedule([C.op])
    # Silently skip when the LLVM backend is not built in.
    if not tvm.module.enabled("llvm"):
        return
    f = tvm.build(s, [A, B, C], "llvm")
    a = tvm.nd.array(np.array([x for x, y in data], dtype=expr.dtype))
    b = tvm.nd.array(np.array([y for x, y in data], dtype=expr.dtype))
    c = tvm.nd.array(np.zeros(len(data), dtype=expr.dtype))
    f(a, b, c)
    cref = np.array([fref(x, y) for x, y in data])
    np.testing.assert_equal(c.asnumpy(), cref)
def get_ref_data():
    """Return every (x, y) pair with x, y in [-10, 10) and y != 0."""
    return [(a, b)
            for a in range(-10, 10)
            for b in range(-10, 10)
            if b != 0]
def test_lower_floordiv():
    """Check lowered floordiv against Python // under several sign constraints."""
    data = get_ref_data()
    for dtype in ["int32", "int64", "int16"]:
        x = tvm.var("x", dtype=dtype)
        y = tvm.var("y", dtype=dtype)
        zero = tvm.const(0, dtype)
        # no constraints
        res = lower_intrin(tvm.floordiv(x, y))
        check_value(res, x, y, data, lambda a, b: a // b)
        # rhs >= 0
        res = lower_intrin(tvm.expr.Select(y >= 0, tvm.floordiv(x, y), zero))
        check_value(res, x, y, data, lambda a, b: a // b if b > 0 else 0)
        # involves max
        res = lower_intrin(tvm.expr.Select(y >= 0, tvm.max(tvm.floordiv(x, y), zero), zero))
        check_value(res, x, y, data, lambda a, b: max(a // b, 0) if b > 0 else 0)
        # lhs >= 0
        res = lower_intrin(tvm.expr.Select(tvm.all(y >= 0, x >= 0), tvm.floordiv(x, y), zero))
        check_value(res, x, y, data, lambda a, b: a // b if b > 0 and a >= 0 else 0)
        # const power of two
        res = lower_intrin(tvm.floordiv(x, tvm.const(8, dtype=dtype)))
        check_value(res, x, y, [(a, b) for a, b in data if b == 8], lambda a, b: a // b)
def test_lower_floormod():
    """Check lowered floormod against Python % under several sign constraints."""
    data = get_ref_data()
    for dtype in ["int32", "int64", "int16"]:
        x = tvm.var("x", dtype=dtype)
        y = tvm.var("y", dtype=dtype)
        zero = tvm.const(0, dtype)
        # no constraints
        res = lower_intrin(tvm.floormod(x, y))
        check_value(res, x, y, data, lambda a, b: a % b)
        # rhs >= 0
        res = lower_intrin(tvm.expr.Select(y >= 0, tvm.floormod(x, y), zero))
        check_value(res, x, y, data, lambda a, b: a % b if b > 0 else 0)
        # lhs >= 0
        res = lower_intrin(tvm.expr.Select(tvm.all(y >= 0, x >= 0), tvm.floormod(x, y), zero))
        check_value(res, x, y, data, lambda a, b: a % b if b > 0 and a >= 0 else 0)
        # const power of two
        res = lower_intrin(tvm.floormod(x, tvm.const(8, dtype=dtype)))
        check_value(res, x, y, [(a, b) for a, b in data if b == 8], lambda a, b: a % b)
# Allow running the lowering tests directly as a script.
if __name__ == "__main__":
    test_lower_floordiv()
    test_lower_floormod()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment