Unverified Commit d7a09150 by Tianqi Chen Committed by GitHub

[ARITH] Add Lowering rule for FloorDiv/Mod (#3976)

* [ARITH] Add Lowering rule for FloorDiv/Mod

* add comment about constant folding
parent 719d6d47
...@@ -333,6 +333,30 @@ TVM_DLL Expr operator||(Expr a, Expr b);
*/
TVM_DLL Expr operator!(Expr a);
/*!
* \brief compute trunc(a / b)
*
* This is the default integer division behavior in C.
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types (int32, int64) when possible.
*/
TVM_DLL Expr truncdiv(Expr a, Expr b);
/*!
* \brief compute the remainder of truncdiv
*
* This is the default integer division behavior in C.
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types (int32, int64) when possible.
*/
TVM_DLL Expr truncmod(Expr a, Expr b);
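
For reference, a quick illustration of the semantics documented above, in plain Python and not part of this commit: trunc division rounds toward zero as C does, while floor division rounds toward negative infinity.

```python
# Illustration only (plain Python, not TVM): trunc vs floor division.
import math

a, b = -7, 2
trunc_q, trunc_r = math.trunc(a / b), int(math.fmod(a, b))   # -3, -1
floor_q, floor_r = a // b, a % b                             # -4,  1
assert trunc_q * b + trunc_r == a
assert floor_q * b + floor_r == a
```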
/*!
* \brief compute floor(a / b)
*
* \param a left operand
...
...@@ -891,6 +891,52 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
    return reducer
def truncdiv(a, b):
    """Compute the truncdiv of two expressions.

    Parameters
    ----------
    a : Expr
        The left hand operand

    b : Expr
        The right hand operand

    Returns
    -------
    res : Expr
        The result expression.

    Note
    ----
    This is the default integer division behavior in C.
    """
    return _make._OpTruncDiv(a, b)


def truncmod(a, b):
    """Compute the truncmod of two expressions.

    Parameters
    ----------
    a : Expr
        The left hand operand

    b : Expr
        The right hand operand

    Returns
    -------
    res : Expr
        The result expression.

    Note
    ----
    This is the default integer division behavior in C.
    """
    return _make._OpTruncMod(a, b)
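
A hedged usage sketch for the helpers added above; it assumes they are re-exported at the top level as tvm.truncdiv / tvm.truncmod, like the other api.py helpers.

```python
# Sketch only; assumes truncdiv/truncmod are exposed as tvm.truncdiv/tvm.truncmod.
import tvm

x = tvm.var("x", dtype="int32")
y = tvm.var("y", dtype="int32")
q = tvm.truncdiv(x, y)   # builds an ir.Div node (C-style division)
r = tvm.truncmod(x, y)   # builds an ir.Mod node (C-style remainder)
print(q, r)
```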
def floordiv(a, b):
    """Compute the floordiv of two expressions.
...
...@@ -196,6 +196,8 @@ REGISTER_MAKE_BINARY_OP(_OpDiv, operator/);
REGISTER_MAKE_BINARY_OP(_OpMod, operator%);
REGISTER_MAKE_BINARY_OP(_OpFloorDiv, floordiv);
REGISTER_MAKE_BINARY_OP(_OpFloorMod, floormod);
REGISTER_MAKE_BINARY_OP(_OpTruncDiv, truncdiv);
REGISTER_MAKE_BINARY_OP(_OpTruncMod, truncmod);
REGISTER_MAKE_BINARY_OP(_OpPow, pow);
REGISTER_MAKE_BINARY_OP(_OpMin, min);
REGISTER_MAKE_BINARY_OP(_OpMax, max);
...
...@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2019 by Contributors
* \file const_fold.h
* \brief Centralized location for constant folding.
*/
...
...@@ -99,11 +99,12 @@ inline bool WillOverflow<ir::Mod>(int64_t x,
* \return the result.
*/
inline int64_t floordiv(int64_t x, int64_t y) {
-  bool round_down =
-      (x >= 0 && y >= 0) ||
-      (x <= 0 && y <= 0) ||
-      (x % y == 0);
-  return round_down ? (x / y) : (x / y - 1);
+  int64_t rdiv = x / y;
+  int64_t rmod = x % y;
+  bool is_floor_div =
+      (y >= 0 && rmod >= 0) ||
+      (y < 0 && rmod <= 0);
+  return is_floor_div ? rdiv : (rdiv - 1);
}
...@@ -114,11 +115,11 @@ inline int64_t floordiv(int64_t x, int64_t y) {
* \return the result.
*/
inline int64_t floormod(int64_t x, int64_t y) {
-  bool round_down =
-      (x >= 0 && y >= 0) ||
-      (x <= 0 && y <= 0) ||
-      (x % y == 0);
-  return round_down ? (x % y) : (x % y + y);
+  int64_t rmod = x % y;
+  bool is_floor_div =
+      (y >= 0 && rmod >= 0) ||
+      (y < 0 && rmod <= 0);
+  return is_floor_div ? rmod : rmod + y;
}
} // namespace arith
...
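
A small Python mirror of the updated const_fold.h helpers above, checking the is_floor_div condition against Python's built-in floor semantics. Illustration only; math.trunc and math.fmod stand in for C's truncating / and %.

```python
import math

def floordiv_i64(x, y):
    rdiv = math.trunc(x / y)       # C-style truncating division
    rmod = int(math.fmod(x, y))    # C-style remainder
    is_floor_div = (y >= 0 and rmod >= 0) or (y < 0 and rmod <= 0)
    return rdiv if is_floor_div else rdiv - 1

def floormod_i64(x, y):
    rmod = int(math.fmod(x, y))
    is_floor_div = (y >= 0 and rmod >= 0) or (y < 0 and rmod <= 0)
    return rmod if is_floor_div else rmod + y

assert all(floordiv_i64(x, y) == x // y and floormod_i64(x, y) == x % y
           for x in range(-9, 10) for y in range(-9, 10) if y != 0)
```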
...@@ -41,8 +41,9 @@ Mutate_(const LetStmt* op, const Stmt& s) {
  Expr value = this->Mutate(op->value);
  if (!ir::HasSideEffect(value)) {
    analyzer_->Bind(op->var, value);
-    return this->Mutate(op->body);
  }
+  // We keep the let-binding here
+  // as the sub-class may or may not choose to replace it.
  Stmt body = this->Mutate(op->body);
  if (value.same_as(op->value) &&
      body.same_as(op->body)) {
...@@ -152,8 +153,9 @@ Mutate_(const Let* op, const Expr& self) {
  Expr value = this->Mutate(op->value);
  if (!ir::HasSideEffect(value)) {
    analyzer_->Bind(op->var, value);
-    return this->Mutate(op->body);
  }
+  // We keep the let-binding here
+  // as the sub-class may or may not choose to replace it.
  Expr body = this->Mutate(op->body);
  if (value.same_as(op->value) &&
      body.same_as(op->body)) {
...
...@@ -45,6 +45,8 @@ class IRMutatorWithAnalyzer : public ir::IRMutator {
  explicit IRMutatorWithAnalyzer(Analyzer* analyzer)
      : analyzer_(analyzer) {}

  using IRMutator::Mutate_;
  // override functions that need to populate the context information.
  Stmt Mutate_(const ir::For* op, const Stmt& self) override;
  Stmt Mutate_(const ir::LetStmt* op, const Stmt& self) override;
...
...@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2019 by Contributors
* \file tvm/arithmetic/pattern_match.h
*
* \brief Internal tool for expression-template based pattern matching.
...@@ -326,6 +325,8 @@ TVM_PATTERN_BINARY_OP(operator/, ir::Div);
TVM_PATTERN_BINARY_OP(operator%, ir::Mod);
TVM_PATTERN_BINARY_OP(min, ir::Min);
TVM_PATTERN_BINARY_OP(max, ir::Max);
TVM_PATTERN_BINARY_OP(truncdiv, ir::Div);
TVM_PATTERN_BINARY_OP(truncmod, ir::Mod);
TVM_PATTERN_BINARY_OP(floordiv, ir::FloorDiv);
TVM_PATTERN_BINARY_OP(floormod, ir::FloorMod);
...
...@@ -1674,6 +1674,16 @@ Mutate_(const Call* op, const Expr& self) {
  if (op == nullptr) return ret;
  if (op->is_intrinsic(Call::likely) && is_const(op->args[0])) {
    return op->args[0];
  } else if (op->is_intrinsic(Call::shift_right)) {
    if (op->args[0].as<IntImm>() && op->args[1].as<IntImm>()) {
      // the operator overload will eagerly constant fold.
      return op->args[0] >> op->args[1];
    }
  } else if (op->is_intrinsic(Call::bitwise_and)) {
    if (op->args[0].as<IntImm>() && op->args[1].as<IntImm>()) {
      // the operator overload will eagerly constant fold.
      return op->args[0] & op->args[1];
    }
  }
  return ret;
}
...@@ -1695,6 +1705,24 @@ Mutate_(const Cast* op, const Expr& self) {
  return cast(op->type, op->value);
}
Expr RewriteSimplifier::Impl::
Mutate_(const Let* op, const Expr& self) {
  Expr value = this->Mutate(op->value);
  if (!ir::HasSideEffect(value)) {
    // it is fine to discard the let binding
    // because the value will always be inlined in the simplifier.
    analyzer_->Bind(op->var, value);
    return this->Mutate(op->body);
  }
  Expr body = this->Mutate(op->body);
  if (value.same_as(op->value) &&
      body.same_as(op->body)) {
    return self;
  } else {
    return Let::make(op->var, value, body);
  }
}
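
A hedged illustration of the new Let handling: because the simplifier binds the let value and always inlines it, a side-effect-free Let expression can be folded away entirely. This assumes tvm.ir_pass.Simplify accepts an Expr, as the Python API elsewhere in this code base suggests.

```python
# Sketch only: the let-bound variable is expected to be inlined and folded.
import tvm

v = tvm.var("v", dtype="int32")
e = tvm.expr.Let(v, tvm.const(8, "int32"), v + 2)
print(tvm.ir_pass.Simplify(e))  # expected to simplify to the constant 10
```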
Expr RewriteSimplifier::operator()(const Expr& expr) {
  // Run simplification in post order
  Expr res = expr;
...
...@@ -72,6 +72,7 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
  Expr Mutate_(const Call* op, const Expr& self) override;
  Expr Mutate_(const Variable* op, const Expr& self) override;
  Expr Mutate_(const Cast* op, const Expr& self) override;
  Expr Mutate_(const Let* op, const Expr& self) override;

 protected:
  /*! \brief internal structure for comparison. */
...
...@@ -51,6 +51,23 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
    return Mutate(stmt);
  }
  Stmt Mutate_(const LetStmt* op, const Stmt& s) {
    Expr value = this->Mutate(op->value);
    if (!ir::HasSideEffect(value)) {
      // it is fine to discard the let binding
      // because the call to simplify will always inline the var.
      analyzer_->Bind(op->var, value);
      return Mutate(op->body);
    }
    Stmt body = this->Mutate(op->body);
    if (value.same_as(op->value) &&
        body.same_as(op->body)) {
      return s;
    } else {
      return LetStmt::make(op->var, value, body);
    }
  }
  // eliminate useless stores
  Stmt Mutate_(const Store* op, const Stmt& s) final {
    Stmt stmt = IRMutator::Mutate_(op, s);
...
...@@ -178,20 +178,28 @@ Expr operator*(Expr a, Expr b) {
  return ir::Mul::make(a, b);
}

-Expr operator/(Expr a, Expr b) {
+Expr truncdiv(Expr a, Expr b) {
  BinaryOpMatchTypes(a, b);
  Expr ret = arith::TryConstFold<ir::Div>(a, b);
  if (ret.defined()) return ret;
  return ir::Div::make(a, b);
}

-Expr operator%(Expr a, Expr b) {
+Expr truncmod(Expr a, Expr b) {
  BinaryOpMatchTypes(a, b);
  Expr ret = arith::TryConstFold<ir::Mod>(a, b);
  if (ret.defined()) return ret;
  return ir::Mod::make(a, b);
}

Expr operator/(Expr a, Expr b) {
  return truncdiv(a, b);
}

Expr operator%(Expr a, Expr b) {
  return truncmod(a, b);
}

Expr floordiv(Expr a, Expr b) {
  BinaryOpMatchTypes(a, b);
  Expr ret = arith::TryConstFold<ir::FloorDiv>(a, b);
...
...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -18,23 +18,28 @@
*/
/*!
- * Copyright (c) 2017 by Contributors
- * Lower intrinsic calls to device specific ir when possible.
+ * Lower intrinsic calls and ops to device specific ir when possible.
* \file lower_intrin.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <tvm/api_registry.h>
#include <tvm/expr_operator.h>
#include <unordered_set>
#include "ir_util.h"
#include "../arithmetic/pattern_match.h"
#include "../arithmetic/ir_mutator_with_analyzer.h"
namespace tvm {
namespace ir {

-class IntrinInjecter : public IRMutator {
+class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
 public:
-  explicit IntrinInjecter(std::string target) {
+  using IRMutatorWithAnalyzer::Mutate_;
+  IntrinInjecter(arith::Analyzer* analyzer, std::string target)
+      : IRMutatorWithAnalyzer(analyzer) {
    std::istringstream is(target);
    std::string starget;
    is >> starget;
...@@ -61,6 +66,118 @@ class IntrinInjecter : public IRMutator {
    return IRMutator::Mutate_(op, e);
  }
  // We use floordiv for integer analysis,
  // but need to lower it to native truncdiv instructions.
  Expr Mutate_(const FloorDiv* op, const Expr& e) final {
    Expr ret = IRMutatorWithAnalyzer::Mutate_(op, e);
    op = ret.as<FloorDiv>();
    if (op == nullptr) return ret;
    int shift;
    const DataType& dtype = op->type;
    if (dtype.is_float()) {
      return floor(Div::make(op->a, op->b));
    }
    CHECK(dtype.is_int() || !dtype.is_uint());
    if (is_const_power_of_two_integer(op->b, &shift)) {
      // lower to right shift if possible.
      return op->a >> make_const(dtype, shift);
    }
    if (analyzer_->CanProveGreaterEqual(op->b, 0)) {
      // Common path, positive divisor
      if (analyzer_->CanProveGreaterEqual(op->a, 0) ||
          analyzer_->CanProveGreaterEqual(e, 0)) {
        return truncdiv(op->a, op->b);
      } else {
        DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of dividend";
        Expr rdiv = truncdiv(op->a, op->b);
        Expr rmod = truncmod(op->a, op->b);
        // Conditioned on b >= 0:
        // truncmod(a, b) < 0 implies truncdiv rounded up (ceildiv),
        // so we need to correct these cases.
        if (dtype == Int(32) || dtype == Int(64)) {
          // equivalent to rdiv + (rmod >= 0 ? 0 : -1);
          return rdiv + (rmod >> make_const(dtype, dtype.bits() - 1));
        } else {
          return ir::Select::make(rmod >= 0, rdiv, rdiv - make_const(dtype, 1));
        }
      }
    } else {
      // uncommon case
      DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divisor";
      // b >= 0 => (rmod >= 0 ? rdiv : rdiv - 1)
      // b < 0  => (rmod <= 0 ? rdiv : rdiv - 1)
      Expr rdiv = truncdiv(op->a, op->b);
      Expr rmod = truncmod(op->a, op->b);
      return ir::Select::make(
          (op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0),
          rdiv, rdiv - make_const(dtype, 1));
    }
  }
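
A worked check of the sign-bit trick used above for a positive divisor: rmod >> (bits - 1) is -1 exactly when the truncated remainder is negative, which is exactly when truncdiv overshoots floordiv by one. Plain-Python illustration, not part of the commit.

```python
def c_truncdiv(a, b):
    # C-style division truncates toward zero.
    q = abs(a) // abs(b)
    return -q if (a < 0) != (b < 0) else q

def lowered_floordiv(a, b, bits=32):
    # Mirrors the rewrite above for b > 0: rdiv + (rmod >> (bits - 1)).
    rdiv = c_truncdiv(a, b)
    rmod = a - rdiv * b
    return rdiv + (rmod >> (bits - 1))   # adds -1 exactly when rmod < 0

assert all(lowered_floordiv(a, b) == a // b
           for a in range(-20, 20) for b in range(1, 10))
```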
  Expr Mutate_(const FloorMod* op, const Expr& e) final {
    Expr ret = IRMutatorWithAnalyzer::Mutate_(op, e);
    op = ret.as<FloorMod>();
    if (op == nullptr) return ret;
    // Lower floormod to native truncmod.
    int shift;
    const DataType& dtype = op->type;
    CHECK(dtype.is_int() || !dtype.is_uint());
    if (is_const_power_of_two_integer(op->b, &shift)) {
      // lower to masking if possible.
      int64_t mask = (
          static_cast<int64_t>(1) << static_cast<int64_t>(shift)) - 1;
      return op->a & make_const(dtype, mask);
    }
    if (analyzer_->CanProveGreaterEqual(op->b, 0)) {
      // Common path, positive divisor
      if (analyzer_->CanProveGreaterEqual(op->a, 0) ||
          analyzer_->CanProveGreaterEqual(e, 0)) {
        return truncmod(op->a, op->b);
      } else {
        DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of dividend";
        // NOTE: conditioned on b >= 0.
        // mod(a, b) < 0 implies truncdiv rounded up (ceildiv),
        // so we need to correct these cases.
        Expr rmod = truncmod(op->a, op->b);
        if (dtype == Int(32) || dtype == Int(64)) {
          // (rmod >> shift) & b
          // -> (rmod >= 0 ? 0 : -1) & b
          // -> rmod >= 0 ? 0 : b
          return rmod + (op->b & (rmod >> make_const(dtype, dtype.bits() - 1)));
        } else {
          return ir::Select::make(rmod >= 0, rmod, rmod + op->b);
        }
      }
    } else {
      // uncommon case
      DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divisor and dividend";
      Expr rmod = truncmod(op->a, op->b);
      // b > 0 && rmod >= 0 -> rmod
      // b > 0 && rmod < 0  -> rmod + b
      // b < 0 && rmod < 0  -> rmod
      // b < 0 && rmod > 0  -> rmod + b
      return ir::Select::make(
          (op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0),
          rmod, rmod + op->b);
    }
  }
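
The same style of check for the floormod rewrite with a positive divisor: b & (rmod >> (bits - 1)) evaluates to b when rmod < 0 and to 0 otherwise, so adding it converts the truncated remainder into the floor remainder. Plain-Python illustration; Python's & and >> on negative ints follow two's-complement behavior, matching the int32/int64 case.

```python
import math

def lowered_floormod(a, b, bits=32):
    # Mirrors the rewrite above for b > 0: rmod + (b & (rmod >> (bits - 1))).
    rmod = int(math.fmod(a, b))          # C-style truncated remainder
    return rmod + (b & (rmod >> (bits - 1)))

assert all(lowered_floormod(a, b) == a % b
           for a in range(-20, 20) for b in range(1, 10))
```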
  Expr Mutate_(const Max* op, const Expr& e) final {
    using namespace arith;
    PVar<Expr> x, y;
    PVar<Integer> c;
    if (max(floordiv(x, y), c).Match(e) &&
        c.Eval()->value >= 0 &&
        analyzer_->CanProveGreaterEqual(y.Eval(), 0)) {
      return max(Mutate(truncdiv(x, y).Eval()), c.Eval());
    }
    return IRMutatorWithAnalyzer::Mutate_(op, e);
  }
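
A quick sanity check of the max(floordiv(x, y), c) rewrite above: for y > 0 and c >= 0, floordiv and truncdiv only disagree when the quotient is negative, and those values are clipped away by the max, so the two forms agree. Plain-Python illustration.

```python
import math

for x in range(-20, 20):
    for y in range(1, 10):
        for c in range(0, 4):
            assert max(x // y, c) == max(math.trunc(x / y), c)
```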
 private:
  Expr SwapBroadcastCast(const Expr& e) {
    // Try to change broadcast(cast(x)) to cast(broadcast(x))
...@@ -132,17 +249,27 @@ class IntrinInjecter : public IRMutator {
    }
    return Expr();
  }

  // patterns
  std::vector<std::string> patterns_;
  const PackedFunc* fma_{nullptr};
};
Stmt LowerIntrinStmt(Stmt stmt, const std::string& target) {
  arith::Analyzer analyzer;
  return IntrinInjecter(&analyzer, target).Mutate(stmt);
}
LoweredFunc
LowerIntrin(LoweredFunc f, const std::string& target) {
  auto n = make_node<LoweredFuncNode>(*f.operator->());
-  n->body = IntrinInjecter(target).Mutate(n->body);
+  n->body = LowerIntrinStmt(n->body, target);
  return LoweredFunc(n);
}
// Register the api only for test purposes
TVM_REGISTER_API("ir_pass._LowerIntrinStmt")
.set_body_typed(LowerIntrinStmt);
} // namespace ir
} // namespace tvm
...@@ -87,6 +87,7 @@ def test_llvm_lookup_intrin():
    func = tvm.ir_pass.MakeAPI(body, "ctpop", [A], 1, True)
    fcode = tvm.build(func, None, "llvm")

def test_llvm_add_pipeline():
    nn = 1024
    n = tvm.convert(nn)
...
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import numpy as np
def lower_intrin(stmt):
    """wrapper to call transformation in stmt"""
    lower_expr = isinstance(stmt, tvm.expr.Expr)
    stmt = tvm.stmt.Evaluate(stmt) if lower_expr else stmt
    stmt = tvm.ir_pass.CanonicalSimplify(stmt)
    stmt = tvm.ir_pass._LowerIntrinStmt(stmt, "llvm")
    return stmt.value if lower_expr else stmt.body
def check_value(expr, vx, vy, data, fref):
    n = len(data)
    A = tvm.placeholder((n,), name="A", dtype=expr.dtype)
    B = tvm.placeholder((n,), name="B", dtype=expr.dtype)

    def make_binds(i):
        x = expr
        x = tvm.expr.Let(vx, A[i], x)
        x = tvm.expr.Let(vy, B[i], x)
        return x

    C = tvm.compute((n,), make_binds)
    s = tvm.create_schedule([C.op])

    if not tvm.module.enabled("llvm"):
        return

    f = tvm.build(s, [A, B, C], "llvm")
    a = tvm.nd.array(np.array([x for x, y in data], dtype=expr.dtype))
    b = tvm.nd.array(np.array([y for x, y in data], dtype=expr.dtype))
    c = tvm.nd.array(np.zeros(len(data), dtype=expr.dtype))
    f(a, b, c)
    cref = np.array([fref(x, y) for x, y in data])
    np.testing.assert_equal(c.asnumpy(), cref)
def get_ref_data():
    """Get reference data for all pairs"""
    import itertools
    x = range(-10, 10)
    y = list(range(-10, 10))
    y.remove(0)
    return list(itertools.product(x, y))
def test_lower_floordiv():
    data = get_ref_data()
    for dtype in ["int32", "int64", "int16"]:
        x = tvm.var("x", dtype=dtype)
        y = tvm.var("y", dtype=dtype)
        zero = tvm.const(0, dtype)
        # no constraints
        res = lower_intrin(tvm.floordiv(x, y))
        check_value(res, x, y, data, lambda a, b: a // b)
        # rhs >= 0
        res = lower_intrin(tvm.expr.Select(y >= 0, tvm.floordiv(x, y), zero))
        check_value(res, x, y, data, lambda a, b: a // b if b > 0 else 0)
        # involves max
        res = lower_intrin(tvm.expr.Select(y >= 0, tvm.max(tvm.floordiv(x, y), zero), zero))
        check_value(res, x, y, data, lambda a, b: max(a // b, 0) if b > 0 else 0)
        # lhs >= 0
        res = lower_intrin(tvm.expr.Select(tvm.all(y >= 0, x >= 0), tvm.floordiv(x, y), zero))
        check_value(res, x, y, data, lambda a, b: a // b if b > 0 and a >= 0 else 0)
        # const power of two
        res = lower_intrin(tvm.floordiv(x, tvm.const(8, dtype=dtype)))
        check_value(res, x, y, [(a, b) for a, b in data if b == 8], lambda a, b: a // b)
def test_lower_floormod():
    data = get_ref_data()
    for dtype in ["int32", "int64", "int16"]:
        x = tvm.var("x", dtype=dtype)
        y = tvm.var("y", dtype=dtype)
        zero = tvm.const(0, dtype)
        # no constraints
        res = lower_intrin(tvm.floormod(x, y))
        check_value(res, x, y, data, lambda a, b: a % b)
        # rhs >= 0
        res = lower_intrin(tvm.expr.Select(y >= 0, tvm.floormod(x, y), zero))
        check_value(res, x, y, data, lambda a, b: a % b if b > 0 else 0)
        # lhs >= 0
        res = lower_intrin(tvm.expr.Select(tvm.all(y >= 0, x >= 0), tvm.floormod(x, y), zero))
        check_value(res, x, y, data, lambda a, b: a % b if b > 0 and a >= 0 else 0)
        # const power of two
        res = lower_intrin(tvm.floormod(x, tvm.const(8, dtype=dtype)))
        check_value(res, x, y, [(a, b) for a, b in data if b == 8], lambda a, b: a % b)
if __name__ == "__main__":
    test_lower_floordiv()
    test_lower_floormod()