Unverified Commit d9cecdf5 by Tianqi Chen Committed by GitHub

[ARITH] Remove the legacy Simplify, migrate to Analyzer. (#5385)

The legacy Simplify/CanonicalSimplify are now a thin wrapper around the Analyzer.
This PR removes these functions and migrated every place that requires
simplification to enforce Analyzer creation.
The new API would encourage more Analyzer sharing and potentially enable
context-aware analyzer-based simplification.
parent b8efe27f
......@@ -112,7 +112,7 @@ class ConstIntBoundAnalyzer {
* \param expr The expression of interest.
* \return the result of the analysis.
*/
ConstIntBound operator()(const PrimExpr& expr);
TVM_DLL ConstIntBound operator()(const PrimExpr& expr);
/*!
* \brief analyze the expr with the intermediate memorized to avoid redundant computation
......@@ -120,8 +120,8 @@ class ConstIntBoundAnalyzer {
* \param bound The lookup table to store the intermediate results
* \return the result of the analysis.
*/
ConstIntBound operator()(const PrimExpr& expr,
std::unordered_map<const PrimExprNode*, ConstIntBound>* bound);
TVM_DLL ConstIntBound operator()(const PrimExpr& expr,
std::unordered_map<const PrimExprNode*, ConstIntBound>* bound);
/*!
* \brief Update constant int bound information of var.
......@@ -130,22 +130,22 @@ class ConstIntBoundAnalyzer {
* \param info The bound information.
* \param override Whether do we allow override of existing information.
*/
void Update(const Var& var,
const ConstIntBound& info,
bool override = false);
TVM_DLL void Update(const Var& var,
const ConstIntBound& info,
bool override = false);
/*!
* \brief Bind variable to a range.
*
* \param var The variable.
* \param range The range we bind to.
*/
void Bind(const Var& var, const Range& range);
TVM_DLL void Bind(const Var& var, const Range& range);
private:
friend class Analyzer;
friend class ConstraintContext;
explicit ConstIntBoundAnalyzer(Analyzer* parent);
~ConstIntBoundAnalyzer();
TVM_DLL ~ConstIntBoundAnalyzer();
/*!
* \brief Update the internal state to enter constraint.
* \param constraint A constraint expression.
......@@ -212,7 +212,7 @@ class ModularSetAnalyzer {
* \param expr The expression of interest.
* \return the result of the analysis.
*/
ModularSet operator()(const PrimExpr& expr);
TVM_DLL ModularSet operator()(const PrimExpr& expr);
/*!
* \brief Update constant int bound information of var.
*
......@@ -220,15 +220,15 @@ class ModularSetAnalyzer {
* \param info The bound information.
* \param override Whether do we allow override of existing information.
*/
void Update(const Var& var,
const ModularSet& info,
bool override = false);
TVM_DLL void Update(const Var& var,
const ModularSet& info,
bool override = false);
private:
friend class Analyzer;
friend class ConstraintContext;
explicit ModularSetAnalyzer(Analyzer* parent);
~ModularSetAnalyzer();
TVM_DLL ~ModularSetAnalyzer();
/*!
* \brief Update the internal state to enter constraint.
* \param constraint A constraint expression.
......@@ -252,7 +252,7 @@ class RewriteSimplifier {
* \param expr The expression of interest.
* \return the result of the analysis.
*/
PrimExpr operator()(const PrimExpr& expr);
TVM_DLL PrimExpr operator()(const PrimExpr& expr);
/*!
* \brief Update binding of var to a new expression.
......@@ -261,9 +261,9 @@ class RewriteSimplifier {
* \param new_expr
* \param override Whether do we allow override of existing information.
*/
void Update(const Var& var,
const PrimExpr& new_expr,
bool override = false);
TVM_DLL void Update(const Var& var,
const PrimExpr& new_expr,
bool override = false);
std::function<void()> EnterConstraint(const PrimExpr& constraint);
......@@ -272,7 +272,7 @@ class RewriteSimplifier {
friend class ConstraintContext;
friend class CanonicalSimplifier;
explicit RewriteSimplifier(Analyzer* parent);
~RewriteSimplifier();
TVM_DLL ~RewriteSimplifier();
class Impl;
/*! \brief Internal impl */
Impl* impl_;
......@@ -288,7 +288,7 @@ class CanonicalSimplifier {
* \param expr The expression of interest.
* \return the result of the analysis.
*/
PrimExpr operator()(const PrimExpr& expr);
TVM_DLL PrimExpr operator()(const PrimExpr& expr);
/*!
* \brief Update binding of var to a new expression.
......@@ -297,15 +297,15 @@ class CanonicalSimplifier {
* \param new_expr
* \param override Whether do we allow override of existing information.
*/
void Update(const Var& var,
const PrimExpr& new_expr,
bool override = false);
TVM_DLL void Update(const Var& var,
const PrimExpr& new_expr,
bool override = false);
private:
friend class Analyzer;
friend class ConstraintContext;
explicit CanonicalSimplifier(Analyzer* parent);
~CanonicalSimplifier();
TVM_DLL ~CanonicalSimplifier();
class Impl;
/*! \brief Internal impl */
Impl* impl_;
......@@ -363,12 +363,12 @@ class IntSetAnalyzer {
* \param dom_map The domain map to indicate which variable to relax.
* \return the result of the analysis.
*/
IntSet operator()(const PrimExpr& expr, const Map<Var, IntSet>& dom_map);
TVM_DLL IntSet operator()(const PrimExpr& expr, const Map<Var, IntSet>& dom_map);
private:
friend class Analyzer;
explicit IntSetAnalyzer(Analyzer* parent);
~IntSetAnalyzer();
TVM_DLL ~IntSetAnalyzer();
class Impl;
/*! \brief Internal impl */
Impl* impl_;
......@@ -384,7 +384,7 @@ class IntSetAnalyzer {
* If the analyzer uses memoization, we need to clear the internal
* cache when information about a Var has been overridden.
*/
class Analyzer {
class TVM_DLL Analyzer {
public:
/*
* Disable copy constructor.
......
......@@ -41,39 +41,6 @@
namespace tvm {
namespace tir {
/*!
* \brief Simplify the expression.
* \param expr The expression to be simplifed.
* \param vrange The range information about the variable.
* \return Canonicalized statement.
*/
TVM_DLL PrimExpr Simplify(PrimExpr expr, Map<Var, Range> vrange = Map<Var, Range>());
/*!
* \brief Simplify the statement.
* \param stmt The statement to be simplifed.
* \param vrange The range information about the variable.
* \return Canonicalized statement.
*/
Stmt Simplify(Stmt stmt, Map<Var, Range> vrange = Map<Var, Range>());
/*!
* \brief Simplify by applying canonical form.
* \param stmt The statement to be canonically simplifed.
* \param vrange The range information about the variable.
* \return Canonicalized statement.
*/
Stmt CanonicalSimplify(Stmt stmt,
Map<Var, Range> vrange = Map<Var, Range>());
/*!
* \brief Simplify by applying canonical form.
* \param expr The statement to be canonically simplifed.
* \param vrange The range information about the variable.
* \return Canonicalized expression.
*/
TVM_DLL PrimExpr CanonicalSimplify(PrimExpr expr,
Map<Var, Range> vrange = Map<Var, Range>());
/*!
* \brief verifies whether the IR stmt or Expr is in SSA form.
......
......@@ -23,8 +23,8 @@ import time
from random import randrange
import numpy as np
from tvm.tir import expr, ir_pass
import tvm.arith
from tvm.tir import expr
logger = logging.getLogger('autotvm')
......@@ -156,7 +156,8 @@ def get_const_int(exp):
if isinstance(exp, int):
return exp
if not isinstance(exp, (expr.IntImm,)):
exp = ir_pass.Simplify(exp)
ana = tvm.arith.Analyzer()
exp = ana.simplify(exp)
if not isinstance(exp, (expr.IntImm,)):
raise ValueError("Expect value to be constant int")
return exp.value
......@@ -180,7 +181,8 @@ def get_const_tuple(in_tuple):
if isinstance(elem, expr.Var):
ret.append(elem)
elif not isinstance(elem, (expr.IntImm, int)):
elem = ir_pass.Simplify(elem)
ana = tvm.arith.Analyzer()
elem = ana.simplify(elem)
if not isinstance(elem, (expr.IntImm)):
ret.append(elem)
else:
......
......@@ -287,6 +287,7 @@ def _build_for_device(input_mod, target, target_host):
lambda f: "calling_conv" in f.attrs and
f.attrs["calling_conv"].value == CallingConv.DEVICE_KERNEL_LAUNCH),
tvm.tir.transform.LowerWarpMemory(),
tvm.tir.transform.Simplify(),
tvm.tir.transform.LowerDeviceStorageAccessInfo(),
tvm.tir.transform.LowerIntrin()])
mod_dev = opt_device(mod_mixed)
......
......@@ -29,10 +29,10 @@ import tvm.runtime
import tvm.tir
import tvm.te
import tvm.te._ffi_api
import tvm.arith
from tvm.tir import expr as _expr
from tvm.tir import stmt as _stmt
from tvm.tir import ir_pass as _ir_pass
from tvm.te.tensor import Tensor, Operation
from tvm.tir import all as _all
from tvm.tir import any as _any
......@@ -160,6 +160,7 @@ class HybridParser(ast.NodeVisitor):
self.outputs = [] # Output tensors' name
self.side_effect = set() # Tensors with side effects
self.parsed_body = None # The parsed HalideIR body
self.analyzer = tvm.arith.Analyzer()
self.returned = False # If this function has a valid return
......@@ -326,7 +327,7 @@ class HybridParser(ast.NodeVisitor):
_internal_assert(len(node.targets) == 1, "So far only one-valued assignment is supported!")
lhs = node.targets[0]
if isinstance(rhs, _expr.PrimExpr):
rhs = _ir_pass.Simplify(rhs)
rhs = self.analyzer.simplify(rhs)
if isinstance(lhs, ast.Name):
#TODO: support defined intermediate buffer later
lhs_ = lhs
......@@ -410,7 +411,7 @@ class HybridParser(ast.NodeVisitor):
def visit_If(self, node):
cond = _ir_pass.CanonicalSimplify(self.visit(node.test))
cond = self.analyzer.simplify(self.visit(node.test))
# Return no IfThenElse if proven
if isinstance(cond, _expr.IntImm):
......@@ -501,8 +502,8 @@ class HybridParser(ast.NodeVisitor):
_name = node.target.id
if isinstance(for_type, tuple):
low = _ir_pass.CanonicalSimplify(low)
ext = _ir_pass.CanonicalSimplify(ext)
low = self.analyzer.simplify(low)
ext = self.analyzer.simplify(ext)
_internal_assert(isinstance(low, _expr.ConstExpr) and
isinstance(ext, _expr.ConstExpr), \
"Const range should start from a const " + \
......
......@@ -20,6 +20,8 @@
import logging
import numpy as np
import tvm
import tvm.arith
import tvm.tir
import tvm._ffi
......@@ -168,4 +170,23 @@ def check_numerical_grads(function, input_values, grad_values, function_value=No
x_name, grad.shape, dist, max_diff, avg_diff)
def assert_prim_expr_equal(lhs, rhs):
"""Assert lhs and rhs equals to each iother.
Parameters
----------
lhs : tvm.tir.PrimExpr
The left operand.
rhs : tvm.tir.PrimExpr
The left operand.
"""
ana = tvm.arith.Analyzer()
res = ana.simplify(lhs - rhs)
equal = isinstance(res, tvm.tir.IntImm) and res.value == 0
if not equal:
raise ValueError("{} and {} are not equal".format(lhs, rhs))
tvm._ffi._init_api("testing", __name__)
......@@ -21,7 +21,6 @@ from tvm.ir import container as _container
from . import stmt as _stmt
from . import expr as _expr
from . import ir_pass as _pass
class WithScope(object):
......@@ -212,7 +211,7 @@ class IRBuilder(object):
self.nidx += 1
self._seq_stack.append([])
loop_var = _expr.Var(name, dtype=dtype)
extent = end if begin == 0 else _pass.Simplify(end - begin)
extent = end if begin == 0 else (end - begin)
def _exit_cb():
if for_type == "serial":
for_type_id = 0
......
......@@ -207,8 +207,9 @@ bool DetectClipBound(
return false;
}
LinearEqEntry ret;
Analyzer analyzer;
if (!LinearEqDetector(var).Detect(canonical, &ret)) return false;
ret.coeff = Simplify(ret.coeff);
ret.coeff = analyzer.Simplify(ret.coeff);
IntervalEntry& p = (*bmap)[var.get()];
if (is_const_int(ret.coeff, 1)) {
// var + shift >=0 -> var >= -shift
......@@ -254,14 +255,15 @@ Array<PrimExpr> DetectClipBound(const PrimExpr& e, const Array<Var>& vars) {
for (PrimExpr cond : splits) {
if (!DetectClipBound(cond, &rmap)) return Array<PrimExpr>();
}
Analyzer analyzer;
Array<PrimExpr> ret;
for (Var v : vars) {
IntervalEntry e = rmap[v.get()];
if (e.min_value.defined()) {
e.min_value = Simplify(e.min_value);
e.min_value = analyzer.Simplify(e.min_value);
}
if (e.max_value.defined()) {
e.max_value = Simplify(e.max_value);
e.max_value = analyzer.Simplify(e.max_value);
}
ret.push_back(e.min_value);
ret.push_back(e.max_value);
......
......@@ -570,11 +570,12 @@ IntSet IntSetAnalyzer::operator()(const PrimExpr& expr,
// TODO(tqchen): revisit IntSet interface as well.
Range IntSet::cover_range(Range max_range) const {
IntSet temp;
Analyzer analyzer;
const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
CHECK(s_int != nullptr);
if (s_int->HasUpperBound() && s_int->HasLowerBound()) {
return Range::make_by_min_extent(
s_int->min_value, Simplify(s_int->max_value + 1 - s_int->min_value));
s_int->min_value, analyzer.Simplify(s_int->max_value + 1 - s_int->min_value));
}
return max_range;
}
......@@ -607,26 +608,30 @@ bool IntSet::is_single_point() const {
}
bool IntSet::can_prove_positive() const {
Analyzer analyzer;
const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
return (s_int && is_positive_const(tir::Simplify(s_int->min_value)));
return (s_int && is_positive_const(analyzer.Simplify(s_int->min_value)));
}
bool IntSet::can_prove_negative() const {
Analyzer analyzer;
const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
return (s_int && is_negative_const(tir::Simplify(s_int->max_value)));
return (s_int && is_negative_const(analyzer.Simplify(s_int->max_value)));
}
bool IntSet::can_prove_non_positive() const {
Analyzer analyzer;
if (const auto* s_int = (*this).as<IntervalSetNode>()) {
auto max = tir::Simplify(s_int->max_value);
auto max = analyzer.Simplify(s_int->max_value);
return is_zero(max) || is_negative_const(max);
}
return false;
}
bool IntSet::can_prove_non_negative() const {
Analyzer analyzer;
if (const IntervalSetNode* s_int = (*this).as<IntervalSetNode>()) {
auto min = tir::Simplify(s_int->min_value);
auto min = analyzer.Simplify(s_int->min_value);
return is_zero(min) || is_positive_const(min);
}
return false;
......@@ -669,8 +674,8 @@ IntSet IntSet::interval(PrimExpr min, PrimExpr max) {
}
// Range related code
inline bool ProveEqual(PrimExpr lhs, PrimExpr rhs) {
return is_zero(tir::Simplify(lhs - rhs));
inline bool ProveEqual(Analyzer* analyzer, PrimExpr lhs, PrimExpr rhs) {
return is_zero(analyzer->Simplify(lhs - rhs));
}
IntSet IntSet::range(Range r) {
......@@ -685,8 +690,9 @@ bool IntSet::match_range(const Range& b) const {
const IntSet& a = *this;
const IntervalSetNode* a_int = a.as<IntervalSetNode>();
if (!a_int) return false;
return ProveEqual(a_int->min_value, b->min) &&
ProveEqual(a_int->max_value, b->extent + b->min - 1);
Analyzer ana;
return ProveEqual(&ana, a_int->min_value, b->min) &&
ProveEqual(&ana, a_int->max_value, b->extent + b->min - 1);
}
IntSet Union(const Array<IntSet>& sets) {
......@@ -697,8 +703,8 @@ IntSet Union(const Array<IntSet>& sets) {
for (size_t i = 1; i < sets.size(); ++i) {
x = Union(&ana, x, ToIntervalSet(sets[i]));
}
return IntervalSet(tir::Simplify(x->min_value),
tir::Simplify(x->max_value));
return IntervalSet(ana.Simplify(x->min_value),
ana.Simplify(x->max_value));
}
IntSet Intersect(const Array<IntSet>& sets) {
......@@ -709,8 +715,8 @@ IntSet Intersect(const Array<IntSet>& sets) {
for (size_t i = 1; i < sets.size(); ++i) {
x = Intersect(&ana, x, ToIntervalSet(sets[i]));
}
return IntervalSet(tir::Simplify(x->min_value),
tir::Simplify(x->max_value));
return IntervalSet(ana.Simplify(x->min_value),
ana.Simplify(x->max_value));
}
Map<Var, IntSet> ConvertDomMap(const Map<IterVar, IntSet>& dom_map) {
......@@ -758,7 +764,7 @@ IntSet EvalSet(Range r,
IntervalSetEvaluator m(&ana, dom_map);
// Simplifying first can give tighter bounds if r->min and r->extent share variables
PrimExpr sum = r->min + r->extent - 1;
auto res = m.Eval(IntervalSet(r->min, Simplify(sum)));
auto res = m.Eval(IntervalSet(r->min, ana.Simplify(sum)));
return std::move(res);
}
......
......@@ -236,6 +236,7 @@ split_dev_host_funcs(IRModule mod_mixed,
}),
BindTarget(target),
tir::transform::LowerWarpMemory(),
tir::transform::Simplify(),
tir::transform::LowerIntrin(),
tir::transform::LowerDeviceStorageAccessInfo(),
};
......
......@@ -22,9 +22,10 @@
* \brief A set of utilities and common functionality
* for type relations.
*/
#include <tvm/arith/analyzer.h>
#include <tvm/tir/op.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/tir/ir_pass.h>
#include <numeric>
#include "./type_relations.h"
......@@ -48,7 +49,8 @@ bool EqualCheck(const IndexExpr& lhs,
return pdiff[0] == 0;
}
// symbolic
diff = tvm::tir::CanonicalSimplify(diff);
tvm::arith::Analyzer ana;
diff = ana.Simplify(diff);
if (const int64_t* pdiff = tir::as_const_int(diff)) {
return pdiff[0] == 0;
}
......
......@@ -414,7 +414,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) {
CHECK((me->coeff % ramp->lanes) == 0 &&
(me->base % ramp->lanes) == 0)
<< "Only aligned vector access is allowed in SPIRV";
PrimExpr vec_index = tir::Simplify(
PrimExpr vec_index = analyzer_->Simplify(
ramp->base / make_const(ramp->base.dtype(), ramp->lanes));
spirv::Value ptr = builder_->StructArrayAccess(
ptr_type, buffer, MakeValue(vec_index));
......@@ -492,7 +492,7 @@ void CodeGenSPIRV::VisitStmt_(const StoreNode* op) {
CHECK((me->coeff % ramp->lanes) == 0 &&
(me->base % ramp->lanes) == 0)
<< "Only aligned vector access is allowed in SPIRV";
PrimExpr vec_index = tir::Simplify(
PrimExpr vec_index = analyzer_->Simplify(
ramp->base / make_const(ramp->base.dtype(), ramp->lanes));
spirv::Value ptr = builder_->StructArrayAccess(
ptr_type, buffer, MakeValue(vec_index));
......
......@@ -24,9 +24,11 @@
* The result Jacobian shape will be (Y.shape, X.shape)
*/
#include <tvm/te/autodiff.h>
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/stmt_functor.h>
#include <topi/transform.h>
#include <tvm/tir/ir_pass.h>
#include <memory>
#include "ad_util.h"
......@@ -264,7 +266,7 @@ class JacobianMutator : public ExprMutator {
CommReducer new_combiner = CommReducerNode::make(new_lhs, new_rhs, new_result, new_identity);
// Also simplify the resulting combiner
// (mostly to get rid of unused components, e.g., the original expressions)
return Simplify(
return analyzer_.Simplify(
ReduceNode::make(new_combiner, new_source, new_op->axis,
new_op->condition, new_op->value_index));
}
......@@ -302,6 +304,7 @@ class JacobianMutator : public ExprMutator {
Tensor input_;
Array<PrimExpr> indices_;
Var input_var_;
arith::Analyzer analyzer_;
};
PrimExpr Derivative(const PrimExpr& expr, const Var& var) {
......@@ -341,11 +344,11 @@ Tensor Jacobian(const Tensor& output, const Tensor& input) {
// Differentiate wrt input[input_indices]
input_indices.push_back(new_v);
}
arith::Analyzer analzyer;
// Compute Jacobian
PrimExpr new_body = Jacobian(
Substitute(op->body[output->value_index], vmap), input, input_indices);
new_body = Simplify(new_body);
new_body = analzyer.Simplify(new_body);
int value_index = 0;
Array<PrimExpr> new_bodies;
......
......@@ -39,10 +39,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
});
TVM_REGISTER_NODE_TYPE(ScanOpNode);
inline bool prove_equal(PrimExpr lhs, PrimExpr rhs) {
return is_zero(tir::Simplify(lhs - rhs));
}
int ScanOpNode::num_outputs() const {
return static_cast<int>(update.size());
}
......@@ -77,6 +73,10 @@ Operation ScanOpNode::make(std::string name,
auto n = make_object<ScanOpNode>();
CHECK_EQ(init.size(), update.size());
CHECK_EQ(init.size(), state_placeholder.size());
arith::Analyzer analyzer;
auto prove_equal = [&](PrimExpr lhs, PrimExpr rhs) {
return is_zero(analyzer.Simplify(lhs - rhs));
};
for (size_t i = 0; i < init.size(); ++i) {
CHECK_EQ(init[i]->dtype, state_placeholder[i]->dtype);
......@@ -232,10 +232,11 @@ void ScanOpNode::GatherBound(
time_dom.insert(time_dom.end(), d.data[0].begin(), d.data[0].end());
}
CHECK(!out_dom_map->count(this->scan_axis));
arith::Analyzer analyzer;
Range sdom = this->scan_axis->dom;
Range r = arith::Union(time_dom).cover_range(sdom);
(*out_dom_map)[this->scan_axis] = Range::make_by_min_extent(
sdom->min, tir::Simplify(r->extent + r->min - sdom->min));
sdom->min, analyzer.Simplify(r->extent + r->min - sdom->min));
Map<IterVar, PrimExpr> fix_pt = ScanFixPointAnalysis(self);
// Update for spatial axis.
size_t sp_idx = 0;
......@@ -260,10 +261,11 @@ Stmt ScanOpNode::BuildRealize(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
const Stmt& body) const {
arith::Analyzer analyzer;
CHECK_EQ(stage->op.get(), this);
Range sdom = dom_map.at(this->scan_axis);
Range tdom = Range::make_by_min_extent(
0, tir::Simplify(sdom->extent + sdom->min));
0, analyzer.Simplify(sdom->extent + sdom->min));
Stmt ret = body;
size_t sp_idx = 0;
for (size_t i = 0; i < update.size(); ++i) {
......
......@@ -222,6 +222,7 @@ class TensorIntrinMatcher final : public StmtExprMutator {
compute_intrin_iter_space->Set(iv->var, vrange);
}
}
analyzer_.Bind(*compute_intrin_iter_space);
// input remap.
Array<Tensor> inputs = self->InputTensors();
......@@ -234,7 +235,7 @@ class TensorIntrinMatcher final : public StmtExprMutator {
// Enable fuzzy matching, to match [1, n, m] to [n, m]
e.start = e.region.size() - e.tensor.ndim();
for (size_t j = 0; j < e.start; ++j) {
auto canonical_extent = Simplify(e.region[j]->extent, *compute_intrin_iter_space);
auto canonical_extent = analyzer_.Simplify(e.region[j]->extent);
CHECK(is_one(canonical_extent))
<< "Tensorize " << intrin->name << ":"
<< " Input dimension mismatch with tensor intrin "
......@@ -304,6 +305,8 @@ class TensorIntrinMatcher final : public StmtExprMutator {
std::unordered_map<const VarNode*, PrimExpr> var_remap_;
// IterVar remap.
std::unordered_map<IterVar, IterVar> axis_remap_;
// arith analyzer
arith::Analyzer analyzer_;
};
// Try to match tensor dataflow of the stage with the intrinsic
......@@ -339,11 +342,12 @@ void VerifyTensorizeBody(
CHECK(intrin_compute) << "Only support compute intrinsic for now";
CHECK_EQ(body.size(), intrin_compute->body.size())
<< "Tensorize failed: body size mismatch";
arith::Analyzer ana;
ana.Bind(compute_intrin_iter_space);
for (size_t i = 0; i < body.size(); ++i) {
PrimExpr lhs = Simplify(body[i], compute_intrin_iter_space);
lhs = CanonicalSimplify(lhs, compute_intrin_iter_space);
PrimExpr rhs = Simplify(intrin_compute->body[i], compute_intrin_iter_space);
rhs = CanonicalSimplify(rhs, compute_intrin_iter_space);
PrimExpr lhs = ana.Simplify(body[i]);
PrimExpr rhs = ana.Simplify(intrin_compute->body[i]);
if (lhs.dtype() != rhs.dtype()) {
LOG(FATAL)
<< "Failed to match the data type with TensorIntrin "
......
......@@ -324,6 +324,7 @@ void PassUpDomain(const FuseNode* s,
CHECK(dom_map.count(s->outer));
CHECK(dom_map.count(s->inner));
CHECK(dom_map.count(s->fused));
arith::Analyzer ana;
if (fused.match_range(dom_map.at(s->fused))) {
*outer = IntSet::range(dom_map.at(s->outer));
......@@ -348,15 +349,15 @@ void PassUpDomain(const FuseNode* s,
*outer = IntSet::interval(
outer_min + indexdiv(fused.min(), inner_extent),
outer_min + indexdiv(fused.max(), inner_extent));
if (is_zero(Simplify(indexmod(inner_extent, fused_extent))) &&
is_zero(Simplify(indexmod(fused.min(), fused_extent)))) {
if (is_zero(ana.Simplify(indexmod(inner_extent, fused_extent))) &&
is_zero(ana.Simplify(indexmod(fused.min(), fused_extent)))) {
// fused never spans multiple rows, make a tight bounding box
// there may be other cases when bounding box could be tightened
*inner = IntSet::interval(inner_min + indexmod(fused.min(), inner_extent),
inner_min + indexmod(fused.max(), inner_extent));
} else { // fused may span multiple rows, use full row widths
if (!is_zero(Simplify(indexmod(fused_extent, inner_extent))) ||
!is_zero(Simplify(indexmod(fused.min(), inner_extent)))) {
if (!is_zero(ana.Simplify(indexmod(fused_extent, inner_extent))) ||
!is_zero(ana.Simplify(indexmod(fused.min(), inner_extent)))) {
LOG(WARNING) <<
"fused and original axes are not aligned, this may cause redundant computations";
}
......
......@@ -181,7 +181,7 @@ class SchedulePostProc : public StmtExprMutator {
// delete duplicated thread extent attr
auto it = thread_extent_scope_.find(op->node.get());
if (it != thread_extent_scope_.end()) {
CHECK(is_zero(tir::Simplify(it->second - op->value)));
CHECK(is_zero(analyzer_.Simplify(it->second - op->value)));
return this->VisitStmt(op->body);
} else {
thread_extent_scope_[op->node.get()] = op->value;
......@@ -335,6 +335,8 @@ class SchedulePostProc : public StmtExprMutator {
std::unordered_map<TensorKey, Tensor> replace_realize_;
// replace producer consumer.
std::unordered_map<const Object*, Operation> replace_op_;
// integer analyzer
arith::Analyzer analyzer_;
};
Stmt ScheduleOps(
......
......@@ -508,7 +508,7 @@ class BufferAnalyser : public StmtExprVisitor {
return;
}
auto index = rel_index[i];
auto simplified_index = tir::Simplify(index);
auto simplified_index = analyzer_.Simplify(index);
index_visitor(simplified_index);
}
......@@ -611,7 +611,7 @@ class BufferAnalyser : public StmtExprVisitor {
index_visitor.scaling_factor_ = shape->value;
}
auto index = rel_index[i];
auto simplified_index = tir::Simplify(index);
auto simplified_index = analyzer_.Simplify(index);
index_visitor(simplified_index);
}
}
......@@ -645,7 +645,7 @@ class BufferAnalyser : public StmtExprVisitor {
PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset);
stride = stride + \
indexmod(factor + offset - indexmod(stride, factor), factor);
stride = tir::Simplify(stride);
stride = analyzer_.Simplify(stride);
}
rstrides.push_back(stride);
stride = stride * shape[dim];
......@@ -773,6 +773,7 @@ class BufferAnalyser : public StmtExprVisitor {
IndexVisitor index_visitor;
Tile warp_tile_;
Tile thread_tile_;
arith::Analyzer analyzer_;
int warp_threads_y_{-1};
bool invalid_{false};
};
......@@ -1148,7 +1149,7 @@ class TensorCoreIRMutator : public StmtExprMutator {
buffer_node->strides = strides;
buffer_node->shape = shape;
buffer_node->data_alignment = 1;
buffer_node->elem_offset = Simplify(elem_offset);
buffer_node->elem_offset = analyzer_.Simplify(elem_offset);
buffer_node->offset_factor = 1;
Buffer buffer(buffer_node);
......@@ -1184,6 +1185,7 @@ class TensorCoreIRMutator : public StmtExprMutator {
std::unordered_map<const ProvideNode*, PrimExpr> frag_load_;
std::unordered_map<const ProvideNode*, PrimExpr> frag_store_;
std::unordered_map<TensorKey, Region> bounds_;
arith::Analyzer analyzer_;
Tile warp_tile_;
int warp_threads_y_{-1};
};
......
......@@ -26,6 +26,7 @@
#include <tvm/tir/expr.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/arith/analyzer.h>
#include <iterator>
#include <stack>
......@@ -37,9 +38,9 @@ namespace tir {
using IndexMod = tir::FloorModNode;
using IndexDiv = tir::FloorDivNode;
Array<PrimExpr> SimplifyArray(Array<PrimExpr> array) {
Array<PrimExpr> SimplifyArray(arith::Analyzer* ana, Array<PrimExpr> array) {
for (size_t i = 0; i < array.size(); ++i) {
array.Set(i, tir::Simplify(array[i]));
array.Set(i, ana->Simplify(array[i]));
}
return array;
}
......@@ -185,14 +186,14 @@ inline void MergeMulModInsertElements(const std::vector<const PrimExpr*>& eles,
// The search will be performed repeatively until no pattern is found.
// Return: a pair with (false, Expr()) if cannot be optimized.
// a pair with (true, optimized_expr) if can be optimized
inline PrimExpr MergeMulMod(const PrimExpr &base) {
inline PrimExpr MergeMulMod(arith::Analyzer* analyzer, const PrimExpr &base) {
using namespace tir;
// 1. Prepare the lists.
// We store two lists, a list that contain all the elements that match Mul and
// a list that contain all the elements that match Mod.
// The elements in the Mod will be used to match against the elements in Mul.
// The result will then be split and pushed back to these two lists.
PrimExpr simplified_base = Simplify(base);
PrimExpr simplified_base = analyzer->Simplify(base);
std::vector<const PrimExpr*> eles = ExprSplitAddition(simplified_base);
std::list<PrimExpr> mult_exprs;
std::list<std::pair<PrimExpr, PrimExpr> > mod_exprs;
......@@ -254,6 +255,7 @@ inline PrimExpr MergeMulMod(const PrimExpr &base) {
// We also perform optimization to simplify the indexing expression.
inline PrimExpr ElemOffset(const BufferNode* n, Array<PrimExpr> index) {
PrimExpr base = n->elem_offset;
arith::Analyzer ana;
if (n->strides.size() == 0) {
// Scalar case
if (n->shape.size() == 0 && index.size() == 1) {
......@@ -265,7 +267,7 @@ inline PrimExpr ElemOffset(const BufferNode* n, Array<PrimExpr> index) {
if (index.size() > 0) {
PrimExpr offset = index[0];
for (size_t i = 1; i < index.size(); ++i) {
offset = MergeMulMod(offset * n->shape[i] + index[i]);
offset = MergeMulMod(&ana, offset * n->shape[i] + index[i]);
}
base = base + offset;
}
......@@ -273,12 +275,12 @@ inline PrimExpr ElemOffset(const BufferNode* n, Array<PrimExpr> index) {
} else {
CHECK_EQ(n->strides.size(), index.size());
if (is_zero(base)) {
base = MergeMulMod(index[0] * n->strides[0]);
base = MergeMulMod(&ana, index[0] * n->strides[0]);
} else {
base = MergeMulMod(base + index[0] * n->strides[0]);
base = MergeMulMod(&ana, base + index[0] * n->strides[0]);
}
for (size_t i = 1; i < index.size(); ++i) {
base = MergeMulMod(base + index[i] * n->strides[i]);
base = MergeMulMod(&ana, base + index[i] * n->strides[i]);
}
}
return base;
......@@ -353,8 +355,9 @@ Buffer Buffer::MakeStrideView() const {
Buffer Buffer::MakeSlice(Array<PrimExpr> begins, Array<PrimExpr> extents) const {
const BufferNode* n = operator->();
begins = SimplifyArray(begins);
PrimExpr elem_offset = tir::Simplify(ElemOffset(n, begins));
arith::Analyzer ana;
begins = SimplifyArray(&ana, begins);
PrimExpr elem_offset = ana.Simplify(ElemOffset(n, begins));
Array<PrimExpr> strides = n->strides;
if (strides.size() == 0) {
bool can_relax = true;
......@@ -363,7 +366,7 @@ Buffer Buffer::MakeSlice(Array<PrimExpr> begins, Array<PrimExpr> extents) const
for (size_t i = 0; i < extents.size(); ++i) {
if (!can_relax) {
if (!is_zero(begins[i]) ||
!is_zero(tir::Simplify(extents[i] - n->shape[i]))) {
!is_zero(ana.Simplify(extents[i] - n->shape[i]))) {
need_stride = true;
}
}
......
......@@ -24,6 +24,8 @@
#include <tvm/runtime/registry.h>
#include <tvm/tir/data_layout.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/arith/analyzer.h>
#include <cctype>
namespace tvm {
......@@ -253,15 +255,16 @@ inline bool GetStoreRule(Array<PrimExpr>* rule,
}
inline Array<PrimExpr> TransformIndex(const Array<PrimExpr>& src_index,
const Array<IterVar>& src_axis,
const Array<PrimExpr>& transform_rule) {
const Array<IterVar>& src_axis,
const Array<PrimExpr>& transform_rule) {
arith::Analyzer ana;
Array<PrimExpr> result;
std::unordered_map<const tir::VarNode*, PrimExpr> bind_map;
for (size_t i = 0; i < src_index.size(); ++i) {
bind_map[src_axis[i]->var.get()] = src_index[i];
}
for (PrimExpr rule : transform_rule) {
result.push_back(tir::Simplify(tir::Substitute(rule, bind_map)));
result.push_back(ana.Simplify(tir::Substitute(rule, bind_map)));
}
return result;
}
......@@ -284,9 +287,10 @@ Array<PrimExpr> BijectiveLayout::BackwardIndex(const Array<PrimExpr>& dst_index)
}
inline Array<PrimExpr> TransformShape(const Array<PrimExpr>& src_shape,
const Array<IterVar>& src_axis,
const Array<IterVar>& target_axis,
const Array<PrimExpr>& transform_rule) {
const Array<IterVar>& src_axis,
const Array<IterVar>& target_axis,
const Array<PrimExpr>& transform_rule) {
arith::Analyzer ana;
CHECK_EQ(src_shape.size(), src_axis.size());
// bind variables for original axes
// for major-axis, bind the corresponding size
......@@ -329,7 +333,7 @@ inline Array<PrimExpr> TransformShape(const Array<PrimExpr>& src_shape,
if (symbolic_var_set.count(i)) {
result.push_back(tir::AnyNode::make());
} else {
result.push_back(tir::Simplify(tir::Substitute(rule, bind_map)));
result.push_back(ana.Simplify(tir::Substitute(rule, bind_map)));
}
}
}
......
......@@ -31,10 +31,11 @@
namespace tvm {
namespace tir {
void BinderAddAssert(PrimExpr cond,
void BinderAddAssert(arith::Analyzer* ana,
PrimExpr cond,
const std::string& arg_name,
std::vector<Stmt>* asserts) {
PrimExpr scond = Simplify(cond);
PrimExpr scond = ana->Simplify(cond);
if (is_zero(scond)) {
LOG(FATAL) << "Bind have an unmet assertion: "
<< cond << ", " << " on argument " << arg_name;
......@@ -65,10 +66,10 @@ bool ArgBinder::Bind_(const PrimExpr& arg,
}
return true;
} else {
BinderAddAssert(it->second == value, arg_name, &asserts_);
BinderAddAssert(&analyzer_, it->second == value, arg_name, &asserts_);
}
} else {
BinderAddAssert(arg == value, arg_name, &asserts_);
BinderAddAssert(&analyzer_, arg == value, arg_name, &asserts_);
}
return false;
}
......@@ -121,7 +122,8 @@ void ArgBinder::BindBuffer(const Buffer& arg,
PrimExpr offset = value->elem_offset;
PrimExpr factor = make_const(offset.dtype(), arg->offset_factor);
PrimExpr zero = make_zero(offset.dtype());
BinderAddAssert(truncmod(offset, factor) == zero,
BinderAddAssert(&analyzer_,
truncmod(offset, factor) == zero,
arg_name + ".elem_offset", &asserts_);
}
}
......@@ -130,7 +132,7 @@ void ArgBinder::BindBuffer(const Buffer& arg,
CHECK(fuzzy_match) << "Argument " << arg_name << " size mismatch";
size_t diff = value->shape.size() - arg->shape.size();
for (size_t i = 0; i < diff; ++i) {
CHECK(is_one(Simplify(value->shape[i])))
CHECK(is_one(analyzer_.Simplify(value->shape[i])))
<< "Argument " << arg_name << " shape mismatch"
<< arg->shape << " vs " << value->shape;
}
......@@ -269,7 +271,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
value = tvm::if_then_else(is_null, stride, value);
value = tvm::if_then_else(buffer->shape[k] == 1, 0, value);
Bind_(buffer->strides[k], value, field_name.str(), true);
stride = Simplify(stride * buffer->shape[k]);
stride = analyzer_.Simplify(stride * buffer->shape[k]);
}
} else {
std::ostringstream stride_null_err_msg;
......@@ -304,7 +306,9 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
PrimExpr offset = buffer->elem_offset;
PrimExpr factor = make_const(offset.dtype(), buffer->offset_factor);
PrimExpr zero = make_zero(offset.dtype());
BinderAddAssert(truncmod(offset, factor) == zero, arg_name + ".elem_offset", &asserts_);
BinderAddAssert(&analyzer_,
truncmod(offset, factor) == zero,
arg_name + ".elem_offset", &asserts_);
}
}
}
......
......@@ -26,6 +26,8 @@
#include <tvm/tir/expr.h>
#include <tvm/tir/buffer.h>
#include <tvm/arith/analyzer.h>
#include <string>
#include <vector>
#include <unordered_map>
......@@ -153,6 +155,8 @@ class ArgBinder {
Map<Var, PrimExpr> def_handle_dtype_;
/*! \brief asserts generated */
std::vector<Stmt> asserts_;
/*! \brief internal analyzer. */
arith::Analyzer analyzer_;
};
} // namespace tir
} // namespace tvm
......
......@@ -32,40 +32,6 @@
namespace tvm {
namespace tir {
TVM_REGISTER_GLOBAL("ir_pass.Simplify")
.set_body([](TVMArgs args, TVMRetValue *ret) {
if (args[0].IsObjectRef<Stmt>()) {
if (args.size() > 1) {
*ret = Simplify(args[0].operator Stmt(), args[1]);
} else {
*ret = Simplify(args[0].operator Stmt());
}
} else {
if (args.size() > 1) {
*ret = Simplify(args[0].operator PrimExpr(), args[1]);
} else {
*ret = Simplify(args[0].operator PrimExpr());
}
}
});
TVM_REGISTER_GLOBAL("ir_pass.CanonicalSimplify")
.set_body([](TVMArgs args, TVMRetValue *ret) {
if (args[0].IsObjectRef<Stmt>()) {
if (args.size() > 1) {
*ret = CanonicalSimplify(args[0].operator Stmt(), args[1]);
} else {
*ret = CanonicalSimplify(args[0].operator Stmt());
}
} else {
if (args.size() > 1) {
*ret = CanonicalSimplify(args[0].operator PrimExpr(), args[1]);
} else {
*ret = CanonicalSimplify(args[0].operator PrimExpr());
}
}
});
TVM_REGISTER_GLOBAL("ir_pass.Substitute")
.set_body([](TVMArgs args, TVMRetValue *ret) {
if (args[0].IsObjectRef<Stmt>()) {
......
......@@ -24,6 +24,7 @@
#include <tvm/runtime/registry.h>
#include <tvm/tir/transform.h>
#include <tvm/arith/pattern.h>
#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include "../../arith/pattern_match.h"
......@@ -125,7 +126,7 @@ class CopyIntrinInjector : public StmtMutator {
DataType t = loop_vars[i].dtype();
PrimExpr svalue = src_shape[i];
if (min_value.defined()) {
PrimExpr pbefore = Simplify(MaxNode::make(min_value, make_zero(t)));
PrimExpr pbefore = analyzer_.Simplify(MaxNode::make(min_value, make_zero(t)));
src_elem_offset = src_elem_offset + pbefore * load_strides[i];
svalue = svalue - pbefore;
pad_before.push_back(pbefore);
......@@ -133,16 +134,16 @@ class CopyIntrinInjector : public StmtMutator {
pad_before.push_back(make_zero(t));
}
if (max_value.defined()) {
PrimExpr pafter = Simplify(MaxNode::make(loops[i]->extent - max_value - make_const(t, 1),
make_zero(t)));
PrimExpr pafter = analyzer_.Simplify(
max(loops[i]->extent - max_value - make_const(t, 1), make_zero(t)));
svalue = svalue - pafter;
pad_after.push_back(pafter);
} else {
pad_after.push_back(make_zero(t));
}
src_shape.Set(i, Simplify(svalue));
src_shape.Set(i, analyzer_.Simplify(svalue));
}
src_elem_offset = Simplify(src_elem_offset);
src_elem_offset = analyzer_.Simplify(src_elem_offset);
}
CHECK_EQ(load_strides.size(), store_strides.size());
CHECK_EQ(load_strides.size(), loop_var_size + 1);
......@@ -189,6 +190,8 @@ class CopyIntrinInjector : public StmtMutator {
const PackedFunc& flower_copy_fromto_;
// Storage scope
std::unordered_map<const VarNode*, std::string> storage_scope_;
// arith analyzer
arith::Analyzer analyzer_;
};
Stmt InjectCopyIntrin(Stmt stmt,
......
......@@ -24,11 +24,9 @@
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/buffer.h>
#include <tvm/arith/analyzer.h>
#include <tvm/target/target_info.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/ir_pass.h>
#include "../pass/ir_util.h"
#include "../../runtime/thread_storage_scope.h"
......@@ -123,8 +121,8 @@ class StorageAccessInfoLower : public StmtExprMutator {
int dtype_bits = dtype.bits() * dtype.lanes();
CHECK_EQ(info->unit_bits % dtype_bits, 0);
return cast(ptr_type,
tir::Simplify(offset / make_const(
offset.dtype(), info->unit_bits / dtype_bits)));
analyzer_.Simplify(offset / make_const(
offset.dtype(), info->unit_bits / dtype_bits)));
}
// The storage entry.
struct StorageEntry {
......@@ -137,6 +135,8 @@ class StorageAccessInfoLower : public StmtExprMutator {
};
// The storage scope of each buffer
std::unordered_map<const VarNode*, StorageEntry> storage_info_;
// analyzer
arith::Analyzer analyzer_;
};
Stmt LowerStorageAccessInfo(Stmt stmt) {
......
......@@ -24,7 +24,7 @@
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/arith/analyzer.h>
#include <tvm/target/target.h>
#include <tvm/runtime/registry.h>
......@@ -313,6 +313,14 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
}
return ret;
}
// The local buffer index.
PrimExpr BufIndex(PrimExpr reduce_index, PrimExpr group_index, int reduce_extent) {
if (!is_zero(group_index)) {
return analyzer_.Simplify(group_index * reduce_extent + reduce_index);
} else {
return reduce_index;
}
}
// sync thread op.
static Stmt SyncThread(const std::string& sync) {
return EvaluateNode::make(
......@@ -320,14 +328,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
{StringImmNode::make(sync)},
CallNode::Intrinsic));
}
// The local buffer index.
static PrimExpr BufIndex(PrimExpr reduce_index, PrimExpr group_index, int reduce_extent) {
if (!is_zero(group_index)) {
return tir::Simplify(group_index * reduce_extent + reduce_index);
} else {
return reduce_index;
}
}
// The warp size of the device.
int warp_size_{1};
......@@ -338,6 +338,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
std::unordered_map<const VarNode *, PrimExpr> load_remap_;
// Allocate remap
std::unordered_map<const VarNode *, Stmt> alloc_remap_;
// Internal analyzer
arith::Analyzer analyzer_;
};
namespace transform {
......
......@@ -371,7 +371,6 @@ class WarpMemoryRewriter : private StmtMutator {
BindVarBoundInfo binder(&analyzer_);
binder(stmt);
stmt = operator()(std::move(stmt));
stmt = CanonicalSimplify(stmt);
return stmt;
}
......
......@@ -98,36 +98,6 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
} // namespace arith
namespace tir {
Stmt CanonicalSimplify(Stmt stmt, Map<Var, Range> vrange) {
arith::Analyzer analyzer;
for (auto kv : vrange) {
analyzer.Bind(kv.first, kv.second);
}
return arith::StmtSimplifier(&analyzer).Simplify(std::move(stmt));
}
PrimExpr CanonicalSimplify(PrimExpr expr, Map<Var, Range> vrange) {
arith::Analyzer analyzer;
for (auto kv : vrange) {
analyzer.Bind(kv.first, kv.second);
}
return analyzer.canonical_simplify(expr);
}
PrimExpr Simplify(PrimExpr expr, Map<Var, Range> vrange) {
arith::Analyzer analyzer;
for (auto kv : vrange) {
analyzer.Bind(kv.first, kv.second);
}
expr = analyzer.Simplify(expr);
return expr;
}
Stmt Simplify(Stmt stmt, Map<Var, Range> vrange) {
return CanonicalSimplify(std::move(stmt), vrange);
}
namespace transform {
Pass Simplify() {
......
......@@ -625,7 +625,7 @@ class StoragePlanRewriter : public StmtExprMutator {
if (!divided) {
combo_size = combo_size + make_const(DataType::Int(32), 1);
}
combo_size = tir::Simplify(combo_size);
combo_size = analyzer_.Simplify(combo_size);
e->new_alloc = AllocateNode::make(
e->alloc_var, alloc_type, {combo_size}, const_true(),
EvaluateNode::make(0));
......
......@@ -25,9 +25,10 @@
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/arith/analyzer.h>
#include <unordered_set>
#include <unordered_map>
#include <vector>
......@@ -160,7 +161,7 @@ class LoopUnroller : public StmtExprMutator {
// returns the extent of the loop if it's a constant integer, otherwise return -1
int GetExtent(const ForNode* op) {
// constant folding.
PrimExpr extent = tir::Simplify(op->extent);
PrimExpr extent = analyzer_.Simplify(op->extent);
const IntImmNode *v1 = extent.as<IntImmNode>();
int value = -1;
// integers that do not fit in int32_t are treated as symbolic,
......@@ -184,6 +185,8 @@ class LoopUnroller : public StmtExprMutator {
int unroll_depth_{0};
// Number of total steps unrolled
int step_count_{0};
// analyzer
arith::Analyzer analyzer_;
};
......
......@@ -19,35 +19,38 @@
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/arith/analyzer.h>
#include <tvm/te/operation.h>
TEST(IRSIMPLIFY, MinMax) {
TEST(Simplify, MinMax) {
tvm::arith::Analyzer ana;
auto x = tvm::te::var("x");
auto e1 = (tvm::max(x, 1) - tvm::max(x, 1)) ;
auto e1s = tvm::tir::CanonicalSimplify(e1);
auto e1s = ana.canonical_simplify(e1);
CHECK(tvm::tir::is_zero(e1s));
auto e2 = (x * tvm::min(x, 1)) - (x * tvm::min(x, 1));
auto e2s = tvm::tir::CanonicalSimplify(e2);
auto e2s = ana.canonical_simplify(e2);
CHECK(tvm::tir::is_zero(e2s));
}
TEST(IRSIMPLIFY, Mul) {
TEST(Simplify, Mul) {
tvm::arith::Analyzer ana;
auto x = tvm::te::var("x");
auto e = (x * x) - (x * x) ;
auto es = tvm::tir::CanonicalSimplify(e);
auto es = ana.canonical_simplify(e);
CHECK(tvm::tir::is_zero(es));
}
TEST(IRSIMPLIFY, Mod) {
TEST(Simplify, Mod) {
tvm::arith::Analyzer ana;
auto x = tvm::Integer(10);
auto y = tvm::Integer(12);
// Mod::make is used instead of % to avoid constant folding during
// calling operator%(x,y). Mod::make doesn't try constant folding,
// and therefore, the constant folding will be attempted in CanonicalSimplify
auto mod = tvm::tir::CanonicalSimplify(tvm::tir::ModNode::make(x, y));
auto es = tvm::tir::CanonicalSimplify(mod - x);
auto mod = ana.canonical_simplify(tvm::tir::ModNode::make(x, y));
auto es = ana.canonical_simplify(mod - x);
CHECK(tvm::tir::is_zero(es));
}
int main(int argc, char ** argv) {
......
......@@ -18,13 +18,6 @@ import tvm
from tvm import te
def assert_expr_equal(a, b):
res = tvm.tir.ir_pass.Simplify(a - b)
equal = isinstance(res, tvm.tir.IntImm) and res.value == 0
if not equal:
raise ValueError("{} and {} are not equal".format(a, b))
def test_deduce():
a = te.var('a')
b = te.var('b')
......@@ -41,32 +34,32 @@ def test_deduce():
e0 = (-b)*a+c-d
res0 = tvm.arith.deduce_bound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
ans0 = fdiv(d - c, b*-1)
assert_expr_equal(res0.max_value, ans0)
tvm.testing.assert_prim_expr_equal(res0.max_value, ans0)
# expression containing variable a is on rhs
res0 = tvm.arith.deduce_bound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {})
assert_expr_equal(res0.max_value, ans0)
tvm.testing.assert_prim_expr_equal(res0.max_value, ans0)
e0 = d*a+c-d
res0 = tvm.arith.deduce_bound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
ans0 = fdiv(d-c, d)
assert_expr_equal(res0.max_value, ans0)
tvm.testing.assert_prim_expr_equal(res0.max_value, ans0)
# expression containing variable a is on rhs
res0 = tvm.arith.deduce_bound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {})
assert_expr_equal(res0.max_value, ans0)
tvm.testing.assert_prim_expr_equal(res0.max_value, ans0)
e1 = (a*4+b < c)
res1 = tvm.arith.deduce_bound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
ans1 = fdiv(c-1-b, 4)
assert_expr_equal(res1.max_value, ans1)
tvm.testing.assert_prim_expr_equal(res1.max_value, ans1)
# expression containing variable a is on rhs
e1 = (c > a*4+b)
res1 = tvm.arith.deduce_bound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
assert_expr_equal(res1.max_value, ans1)
tvm.testing.assert_prim_expr_equal(res1.max_value, ans1)
e2 = (tvm.te.max(5, a * 4) < 0)
......@@ -83,15 +76,15 @@ def test_deduce():
e3 = (-b)+a*c-d
res3 = tvm.arith.deduce_bound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
ans3 = fdiv(2,c)+1
assert str(tvm.tir.ir_pass.Simplify(res3.min_value)) == str(ans3)
tvm.testing.assert_prim_expr_equal(res3.min_value, ans3)
res3 = tvm.arith.deduce_bound(a, zero <= e3, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
assert str(tvm.tir.ir_pass.Simplify(res3.min_value)) == str(ans3)
tvm.testing.assert_prim_expr_equal(res3.min_value, ans3)
# tests for `EQ` op
res4 = tvm.arith.deduce_bound(a, a == b, {}, {})
assert_expr_equal(res4.max_value, b)
assert_expr_equal(res4.min_value, b)
tvm.testing.assert_prim_expr_equal(res4.max_value, b)
tvm.testing.assert_prim_expr_equal(res4.min_value, b)
# Unsatisfiable `EQ`, variable as one of the Operand
res5 = tvm.arith.deduce_bound(a, (a == b), {b: b_s}, {b: b_s})
......@@ -100,20 +93,20 @@ def test_deduce():
# variable `a` on the RHS side
res6 = tvm.arith.deduce_bound(a, 10 == a, {}, {})
assert_expr_equal(res6.max_value, 10)
assert_expr_equal(res6.min_value, 10)
tvm.testing.assert_prim_expr_equal(res6.max_value, 10)
tvm.testing.assert_prim_expr_equal(res6.min_value, 10)
# Add, Sub in `EQ`
e4 = ((a - c) == (b + d))
ans4 = (b + d + c)
res7 = tvm.arith.deduce_bound(a, e4, {b: b_s, c: c_s, d: d_s}, {})
assert_expr_equal(res7.max_value, ans4)
assert_expr_equal(res7.min_value, ans4)
tvm.testing.assert_prim_expr_equal(res7.max_value, ans4)
tvm.testing.assert_prim_expr_equal(res7.min_value, ans4)
# Satisfiable Mul in `EQ` with negative sign
res8 = tvm.arith.deduce_bound(a, (5 * a == -10), {}, {})
assert_expr_equal(res8.max_value, -2)
assert_expr_equal(res8.min_value, -2)
tvm.testing.assert_prim_expr_equal(res8.max_value, -2)
tvm.testing.assert_prim_expr_equal(res8.min_value, -2)
# Unsatisfiable Mul in `EQ`
e5 = (4 * a == b)
......@@ -158,21 +151,22 @@ def test_deduce_basic():
res1 = tvm.arith.deduce_bound(a, e0<17, {b: b_s}, {b: b_s})
[x, y] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value]
assert (tvm.tir.ir_pass.Simplify((x * coff + 3 + y) < 17)).value == 1
tvm.testing.assert_prim_expr_equal((x * coff + 3 + y) < 17, True)
# expression containing variable a is on rhs
res1 = tvm.arith.deduce_bound(a, tvm.tir.const(17, "int32") < e0, {b: b_s}, {b: b_s})
[x, y] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value]
assert (tvm.tir.ir_pass.Simplify((x * coff + 3 + y) > 17)).value == 1
tvm.testing.assert_prim_expr_equal((x * coff + 3 + y) > 17, True)
# expression containing variable a is on rhs
res1 = tvm.arith.deduce_bound(a, tvm.tir.const(17, "int32")>= e0, {b: b_s}, {b: b_s})
[x, y] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value]
assert (tvm.tir.ir_pass.Simplify((x * coff + 3 + y) <= 17)).value == 1
tvm.testing.assert_prim_expr_equal((x * coff + 3 + y) <= 17, True)
res1 = tvm.arith.deduce_bound(a, e0>=17, {b: b_s}, {b: b_s})
[x, y] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value]
assert (tvm.tir.ir_pass.Simplify((x * coff + 3 + y) >= 17)).value == 1
tvm.testing.assert_prim_expr_equal((x * coff + 3 + y) >= 17, True)
test_basic(0, 4, 4)
test_basic(1, 5, 4)
......@@ -190,21 +184,21 @@ def test_deduce_complex():
res1 = tvm.arith.deduce_bound(a, e0<63, {b: b_s}, {b: b_s})
[t, x] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value]
assert (tvm.tir.ir_pass.Simplify(((x*3 + t* coff) * 4) < 63)).value == 1
tvm.testing.assert_prim_expr_equal(((x*3 + t* coff) * 4) < 63, True)
# expression containing variable a is on rhs
res1 = tvm.arith.deduce_bound(a, tvm.tir.const(63, "int32")>= e0, {b: b_s}, {b: b_s})
[t, x] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value]
assert (tvm.tir.ir_pass.Simplify(((x*3 + t* coff) * 4) <= 63)).value == 1
tvm.testing.assert_prim_expr_equal(((x*3 + t* coff) * 4) <= 63, True)
res1 = tvm.arith.deduce_bound(a, e0>63, {b: b_s}, {b: b_s})
[t, x] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value]
assert (tvm.tir.ir_pass.Simplify(((x*3 + t* coff) * 4) > 63)).value == 1
tvm.testing.assert_prim_expr_equal(((x*3 + t* coff) * 4) > 63, True)
# expression containing variable a is on rhs
res1 = tvm.arith.deduce_bound(a, tvm.tir.const(63, "int32") <= e0, {b: b_s}, {b: b_s})
[t, x] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value]
assert (tvm.tir.ir_pass.Simplify(((x*3 + t* coff) * 4) >= 63)).value == 1
tvm.testing.assert_prim_expr_equal(((x*3 + t* coff) * 4) >= 63, True)
test_complex(0, 4, 4)
test_complex(0, 4, -4)
......
......@@ -23,15 +23,15 @@ def test_basic():
c = te.var("c")
m = tvm.arith.detect_clip_bound(tvm.tir.all(a * 1 < b * 6,
a - 1 > 0), [a])
assert tvm.tir.ir_pass.Simplify(m[1] - (b * 6 - 1)).value == 0
tvm.testing.assert_prim_expr_equal(m[1], b * 6 - 1)
assert m[0].value == 2
m = tvm.arith.detect_clip_bound(tvm.tir.all(a * 1 < b * 6,
a - 1 > 0), [a, b])
assert len(m) == 0
m = tvm.arith.detect_clip_bound(tvm.tir.all(a + 10 * c <= 20,
b - 1 > 0), [a, b])
assert tvm.tir.ir_pass.Simplify(m[1] - (20 - 10 * c)).value == 0
assert tvm.tir.ir_pass.Simplify(m[2] - 2).value == 0
tvm.testing.assert_prim_expr_equal(m[1], 20 - 10 * c)
tvm.testing.assert_prim_expr_equal(m[2], 2)
if __name__ == "__main__":
......
......@@ -22,14 +22,14 @@ def test_basic():
b = te.var("b")
m = tvm.arith.detect_linear_equation(a * 4 + b * 6 + 7, [a])
assert m[0].value == 4
assert tvm.tir.ir_pass.Simplify(m[1] - (b * 6 + 7)).value == 0
tvm.testing.assert_prim_expr_equal(m[1], b * 6 + 7)
m = tvm.arith.detect_linear_equation(a * 4 * (a+1) + b * 6 + 7, [a])
assert len(m) == 0
m = tvm.arith.detect_linear_equation(a * 4 + (a+1) + b * 6 + 7, [a])
assert m[0].value == 5
assert tvm.tir.ir_pass.Simplify(m[1] - (b * 6 + 7 + 1)).value == 0
tvm.testing.assert_prim_expr_equal(m[1], b * 6 + 7 + 1)
m = tvm.arith.detect_linear_equation(a * b + 7, [a])
assert m[0] == b
......@@ -39,13 +39,15 @@ def test_basic():
m = tvm.arith.detect_linear_equation(b * 7, [])
assert len(m) == 1
assert tvm.tir.ir_pass.Simplify(m[0] - b * 7).value == 0
tvm.testing.assert_prim_expr_equal(m[0], b * 7)
def test_multivariate():
v = [te.var("v%d" % i) for i in range(4)]
b = te.var("b")
m = tvm.arith.detect_linear_equation(v[0] * (b + 4) + v[0] + v[1] * 8, v)
assert(tvm.tir.analysis.expr_deep_equal(tvm.tir.ir_pass.Simplify(m[0]), b + 5))
tvm.testing.assert_prim_expr_equal(m[0], b + 5)
assert(m[1].value == 8)
m = tvm.arith.detect_linear_equation(v[0] * (b + 4) + v[0] + v[1] * 8 * v[2], v)
......@@ -61,11 +63,12 @@ def test_multivariate():
m = tvm.arith.detect_linear_equation((v[0] - v[1]), [v[2]])
assert(m[0].value == 0)
assert(tvm.tir.ir_pass.Simplify(m[1] - (v[0] - v[1])).value == 0)
tvm.testing.assert_prim_expr_equal(m[1], v[0] - v[1])
m = tvm.arith.detect_linear_equation((v[0] - v[1]), [])
assert(len(m) == 1)
assert(tvm.tir.ir_pass.Simplify(m[0] - (v[0] - v[1])).value == 0)
tvm.testing.assert_prim_expr_equal(m[0], v[0] - v[1])
if __name__ == "__main__":
test_basic()
......
......@@ -55,12 +55,13 @@ def check_bruteforce(bool_expr, vranges, cond=None):
counterex = ", ".join([v + " = " + str(i) for v, i in counterex])
raise AssertionError("Expression {}\nis not true on {}\n"
"Counterexample: {}"
.format(tir.ir_pass.CanonicalSimplify(bool_expr), vranges, counterex))
.format(tir.arith.Analyzer().simplify(bool_expr), vranges, counterex))
def check_solution(solution, vranges={}):
"""Check that solution is a bijective transformation"""
def _check_forward(constraints1, constraints2, varmap, backvarmap):
ana = tvm.arith.Analyzer()
all_vranges = vranges.copy()
all_vranges.update({v: r for v, r in constraints1.ranges.items()})
......@@ -68,7 +69,7 @@ def check_solution(solution, vranges={}):
cond_on_vars = tir.const(1, 'bool')
for v in constraints1.variables:
# variable mapping is consistent
v_back = tir.ir_pass.Simplify(tir.ir_pass.Substitute(varmap[v], backvarmap))
v_back = ana.simplify(tir.ir_pass.Substitute(varmap[v], backvarmap))
cond_on_vars = te.all(cond_on_vars, v == v_back)
# Also we have to check that the new relations are true when old relations are true
cond_subst = tir.ir_pass.Substitute(
......@@ -80,7 +81,7 @@ def check_solution(solution, vranges={}):
range_cond = te.all(v >= r.min, v < r.min + r.extent)
range_cond = tir.ir_pass.Substitute(range_cond, backvarmap)
cond_subst = te.all(cond_subst, range_cond)
cond_subst = tir.ir_pass.Simplify(cond_subst)
cond_subst = ana.simplify(cond_subst)
check_bruteforce(te.all(cond_subst, cond_on_vars), all_vranges,
cond=te.all(tir.const(1, 'bool'), *constraints1.relations))
......
......@@ -25,7 +25,7 @@ from tvm.te.hybrid.runtime import HYBRID_GLOBALS
def run_and_check(func, args, var_dict={}, target='llvm', sch=None, outs=None):
def tvm_val_2_py_val(val):
val = tvm.tir.ir_pass.Substitute(val, var_dict)
val = tvm.tir.ir_pass.Simplify(val)
val = tvm.arith.Analyzer().simplify(val)
assert isinstance(val, (tvm.tir.IntImm,))
return val.value
......
......@@ -139,19 +139,20 @@ def test_bound_fusesplit1():
bounds = tvm.te.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
idxdiv = tvm.tir.indexdiv
assert(tvm.tir.ir_pass.Simplify(
bounds[A1.op.axis[0]].min - idxdiv(xo * split1, l)).value == 0)
tvm.testing.assert_prim_expr_equal(
bounds[A1.op.axis[0]].min, idxdiv(xo * split1, l))
expected_extent = (idxdiv((xo + 1) * split1 - 1, l) - idxdiv(xo * split1, l) + 1)
for i in range(1, 6):
for j in range(1, 6):
for k in range(1, 6):
vars = tvm.runtime.convert({split1: tvm.tir.const(i, "int32"), l: tvm.tir.const(j, "int32"), xo.var: tvm.tir.const(k, "int32")})
comp_ext = tvm.tir.ir_pass.Simplify(tvm.tir.ir_pass.Substitute(bounds[A1.op.axis[0]].extent, vars)).value
exp_ext = tvm.tir.ir_pass.Simplify(tvm.tir.ir_pass.Substitute(expected_extent, vars)).value
assert(comp_ext == exp_ext)
tvm.testing.assert_prim_expr_equal(
tvm.tir.ir_pass.Substitute(bounds[A1.op.axis[0]].extent, vars),
tvm.tir.ir_pass.Substitute(expected_extent, vars)
)
assert(tvm.tir.ir_pass.Simplify(bounds[A1.op.axis[1]].extent - l).value == 0)
tvm.testing.assert_prim_expr_equal(bounds[A1.op.axis[1]].extent, l)
def test_bound_fusesplit2():
m = te.var("m")
......@@ -169,10 +170,10 @@ def test_bound_fusesplit2():
bounds = tvm.te.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
vars = tvm.runtime.convert({xo.var: tvm.tir.const(5, "int32")})
assert(tvm.tir.ir_pass.Simplify(tvm.tir.ir_pass.Substitute(bounds[A1.op.axis[0]].min, vars)).value == 2)
assert(tvm.tir.ir_pass.Simplify(tvm.tir.ir_pass.Substitute(bounds[A1.op.axis[1]].min, vars)).value == 3)
assert(tvm.tir.ir_pass.Simplify(tvm.tir.ir_pass.Substitute(bounds[A1.op.axis[0]].extent, vars)).value == 1)
assert(tvm.tir.ir_pass.Simplify(tvm.tir.ir_pass.Substitute(bounds[A1.op.axis[1]].extent, vars)).value == 3)
tvm.testing.assert_prim_expr_equal(tvm.tir.ir_pass.Substitute(bounds[A1.op.axis[0]].min, vars), 2)
tvm.testing.assert_prim_expr_equal(tvm.tir.ir_pass.Substitute(bounds[A1.op.axis[1]].min, vars), 3)
tvm.testing.assert_prim_expr_equal(tvm.tir.ir_pass.Substitute(bounds[A1.op.axis[0]].extent, vars), 1)
tvm.testing.assert_prim_expr_equal(tvm.tir.ir_pass.Substitute(bounds[A1.op.axis[1]].extent, vars), 3)
def test_bound_warp():
......
......@@ -105,9 +105,10 @@ def test_tensorize_vadd():
assert tvm.ir.structural_equal(in_dom.items()[0][1][0].extent, factor)
fmatch = tvm.get_global_func("test.op.MatchTensorizeBody")
body = fmatch(s[z], out_dom, in_dom, vadd)
ana = tvm.arith.Analyzer()
assert tvm.ir.structural_equal(
tvm.tir.ir_pass.CanonicalSimplify(body[0]),
tvm.tir.ir_pass.CanonicalSimplify(vadd.op.body[0]))
ana.simplify(body[0]),
ana.simplify(vadd.op.body[0]))
stmt = tvm.te.schedule.ScheduleOps(s, dom_map)
tvm.lower(s, [x, y, z])
......@@ -139,9 +140,11 @@ def test_tensorize_matmul():
assert tvm.ir.structural_equal(out_dom[y].min, yo * factor)
fmatch = tvm.get_global_func("test.op.MatchTensorizeBody")
body = fmatch(s[C], out_dom, in_dom, gemv)
ana = tvm.arith.Analyzer()
assert tvm.ir.structural_equal(
tvm.tir.ir_pass.CanonicalSimplify(body[0]),
tvm.tir.ir_pass.CanonicalSimplify(gemv.op.body[0]))
ana.simplify(body[0]),
ana.simplify(gemv.op.body[0]))
stmt = tvm.te.schedule.ScheduleOps(s, dom_map)
tvm.lower(s, [A, B, C])
......@@ -164,9 +167,10 @@ def test_tensorize_matmul():
assert tvm.ir.structural_equal(out_dom[y].min, yo * factor)
fmatch = tvm.get_global_func("test.op.MatchTensorizeBody")
body = fmatch(s[C], out_dom, in_dom, gemv)
ana = tvm.arith.Analyzer()
assert tvm.ir.structural_equal(
tvm.tir.ir_pass.CanonicalSimplify(body[0]),
tvm.tir.ir_pass.CanonicalSimplify(gemv.op.body[0]))
ana.simplify(body[0]),
ana.simplify(gemv.op.body[0]))
stmt = tvm.te.schedule.ScheduleOps(s, dom_map)
tvm.lower(s, [A, B, C])
......@@ -188,9 +192,10 @@ def test_tensorize_matmul():
assert tvm.ir.structural_equal(out_dom[y].min, yo * factor)
fmatch = tvm.get_global_func("test.op.MatchTensorizeBody")
body = fmatch(s[C], out_dom, in_dom, gemv)
ana = tvm.arith.Analyzer()
assert tvm.ir.structural_equal(
tvm.tir.ir_pass.CanonicalSimplify(body[0]),
tvm.tir.ir_pass.CanonicalSimplify(gemv.op.body[0]))
ana.simplify(body[0]),
ana.simplify(gemv.op.body[0]))
stmt = tvm.te.schedule.ScheduleOps(s, dom_map)
tvm.lower(s, [A, B, C])
......@@ -213,9 +218,10 @@ def test_tensorize_matmul():
assert tvm.ir.structural_equal(out_dom[y].min, yo * factor)
fmatch = tvm.get_global_func("test.op.MatchTensorizeBody")
body = fmatch(s[C], out_dom, in_dom, gemv)
ana = tvm.arith.Analyzer()
assert tvm.ir.structural_equal(
tvm.tir.ir_pass.CanonicalSimplify(body[0]),
tvm.tir.ir_pass.CanonicalSimplify(gemv.op.body[0]))
ana.simplify(body[0]),
ana.simplify(gemv.op.body[0]))
stmt = tvm.te.schedule.ScheduleOps(s, dom_map)
tvm.lower(s, [A, B, C])
......
......@@ -48,17 +48,14 @@ def test_buffer_access_ptr_offset():
n = te.size_var('n')
Ab = tvm.tir.decl_buffer((m, n), "float32")
aptr = Ab.access_ptr("rw", offset=100)
offset = tvm.tir.ir_pass.Simplify(aptr.args[2])
assert tvm.ir.structural_equal(offset, 100)
tvm.testing.assert_prim_expr_equal(aptr.args[2], 100)
assert aptr.args[4].value == Buffer.READ | Buffer.WRITE
v = te.size_var('int32')
aptr = Ab.access_ptr("rw", offset=100 + 100 + v)
offset = tvm.tir.ir_pass.Simplify(aptr.args[2])
assert tvm.ir.structural_equal(offset, 200 + v)
tvm.testing.assert_prim_expr_equal(aptr.args[2], 200 + v)
assert aptr.args[4].value == Buffer.READ | Buffer.WRITE
aptr = Ab.access_ptr("rw", offset=tvm.tir.call_extern('int32', "test_call", 100 + 100 + v))
offset = tvm.tir.ir_pass.Simplify(aptr.args[2])
assert tvm.ir.structural_equal(offset, tvm.tir.call_extern('int32', "test_call", 200 + v))
tvm.testing.assert_prim_expr_equal(aptr.args[2], tvm.tir.call_extern('int32', "test_call", 200 + v))
assert aptr.args[4].value == Buffer.READ | Buffer.WRITE
......@@ -80,8 +77,7 @@ def test_buffer_vload():
n = te.size_var('n')
Ab = tvm.tir.decl_buffer((m, n), "float32", elem_offset=100)
load = Ab.vload([2, 3])
offset = tvm.tir.ir_pass.Simplify(load.index)
assert tvm.ir.structural_equal(offset, n * 2 + 103)
tvm.testing.assert_prim_expr_equal(load.index, n * 2 + 103)
def test_buffer_index_merge_mult_mod():
......
......@@ -17,16 +17,6 @@
import tvm
from tvm import te
def test_simplify():
tdiv = tvm.tir.truncdiv
tmod = tvm.tir.truncmod
x = te.var('x')
e1 = tvm.tir.ir_pass.Simplify(x + 2 + 1)
assert(tvm.ir.structural_equal(e1, x + 3))
e2 = tvm.tir.ir_pass.Simplify(x * 3 + 5 * x)
assert(tvm.ir.structural_equal(e2, x * 8))
e3 = tvm.tir.ir_pass.Simplify(x - tdiv(x, 3) * 3)
assert(tvm.ir.structural_equal(e3, tmod(x, 3)))
def test_verify_ssa():
......
......@@ -31,8 +31,7 @@ def test_decorate_device():
s[A1].set_scope("shared")
bounds = tvm.te.schedule.InferBound(s)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
stmt1 = tvm.tir.ir_pass.Simplify(stmt)
stmt1 = tvm.te.schedule.ScheduleOps(s, bounds)
stmt2 = tvm.tir.ir_pass.DecorateDeviceScope(stmt1)
assert isinstance(stmt2, tvm.tir.AttrStmt)
assert stmt2.attr_key == "device_scope"
......
......@@ -57,7 +57,7 @@ def test_copy_pad():
mod = tvm.tir.transform.StorageFlatten(64)(mod)
def cb(src, dst, pad_before, pad_after, pad_value):
assert tvm.tir.ir_pass.Simplify(src.elem_offset).value == 0
tvm.testing.assert_prim_expr_equal(src.elem_offset, 0)
assert pad_before[0].value == 1
assert pad_before[1].value == 0
assert pad_after[0].value == 1
......@@ -82,18 +82,15 @@ def test_single_point_test():
mod = tvm.tir.transform.StorageFlatten(64)(mod)
def cb(src, dst, pad_before, pad_after, pad_value):
assert tvm.tir.ir_pass.Simplify(src.elem_offset).value == 0
assert tvm.tir.ir_pass.Simplify(dst.elem_offset).value == 0
assert tvm.tir.ir_pass.Simplify(src.strides[0]).value == 1
assert tvm.tir.ir_pass.Simplify(dst.strides[0]).value == 1
tvm.testing.assert_prim_expr_equal(src.elem_offset, 0)
tvm.testing.assert_prim_expr_equal(dst.elem_offset, 0)
tvm.testing.assert_prim_expr_equal(src.strides[0], 1)
tvm.testing.assert_prim_expr_equal(dst.strides[0], 1)
return tvm.tir.Evaluate(0)
stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body
def assert_expr_equal(a, b):
assert tvm.tir.ir_pass.Simplify(a - b).value == 0
def test_copy_pad_split():
m = 4 * 3
A = te.placeholder((m, ), name="A")
......@@ -115,13 +112,13 @@ def test_copy_pad_split():
def cb(src, dst, pad_before, pad_after, pad_value):
assert(dst.elem_offset.value == 0)
assert_expr_equal(src.elem_offset, tvm.te.max(xo * 4, 1) - 1)
tvm.testing.assert_prim_expr_equal(src.elem_offset, tvm.te.max(xo * 4, 1) - 1)
rpad_before = tvm.te.max(1 - xo * 4, 0)
rpad_after = tvm.te.max(xo * 4 - 7, 0)
assert_expr_equal(pad_before[0], rpad_before)
assert_expr_equal(pad_after[0], rpad_after)
assert_expr_equal(src.shape[0], 6 - rpad_before - rpad_after)
tvm.testing.assert_prim_expr_equal(pad_before[0], rpad_before)
tvm.testing.assert_prim_expr_equal(pad_after[0], rpad_after)
tvm.testing.assert_prim_expr_equal(src.shape[0], 6 - rpad_before - rpad_after)
return tvm.tir.Evaluate(0)
stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body
......
......@@ -22,10 +22,13 @@ def lower_intrin(params, stmt):
"""wrapper to call transformation in stmt"""
lower_expr = isinstance(stmt, tvm.tir.PrimExpr)
stmt = tvm.tir.Evaluate(stmt) if lower_expr else stmt
stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt)
func = tvm.tir.PrimFunc(params, stmt).with_attr(
"target", tvm.target.create("llvm"))
func = tvm.tir.transform.LowerIntrin()(tvm.IRModule.from_expr(func))["main"]
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc(params, stmt).with_attr(
"target", tvm.target.create("llvm")))
mod = tvm.transform.Sequential([
tvm.tir.transform.Simplify(),
tvm.tir.transform.LowerIntrin()
])(mod)
func = mod["main"]
stmt = func.body
return stmt.value if lower_expr else stmt.body
......
......@@ -51,9 +51,10 @@ def test_flatten_prefetch():
[_A], stmt, {A: _A})
mod = tvm.IRModule.from_expr(func)
mod = tvm.tir.transform.StorageFlatten(64)(mod)
mod = tvm.transform.Sequential([
tvm.tir.transform.StorageFlatten(64),
tvm.tir.transform.Simplify()])(mod)
stmt = mod["main"].body
stmt = tvm.tir.ir_pass.Simplify(stmt)
assert stmt.extent.value == 2
assert isinstance(stmt.body, tvm.tir.For)
assert stmt.body.extent.value == 2
......@@ -74,9 +75,11 @@ def test_flatten_storage_align():
func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None)
mod = tvm.IRModule.from_expr(func)
mod = tvm.tir.transform.StorageFlatten(64)(mod)
mod = tvm.transform.Sequential([
tvm.tir.transform.StorageFlatten(64),
tvm.tir.transform.Simplify()])(mod)
stmt = mod["main"].body
stmt = tvm.tir.ir_pass.Simplify(stmt)
assert(stmt.body.extents[0].value == 17 * 8)
......@@ -103,11 +106,12 @@ def test_flatten_double_buffer():
mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([A, C], stmt))
mod = tvm.tir.transform.StorageFlatten(64)(mod)
stmt = mod["main"].body
mod = tvm.transform.Sequential([
tvm.tir.transform.StorageFlatten(64),
tvm.tir.transform.InjectDoubleBuffer(2),
tvm.tir.transform.Simplify()])(mod)
stmt = tvm.tir.ir_pass.InjectDoubleBuffer(stmt, 2)
stmt = tvm.tir.ir_pass.Simplify(stmt)
stmt = mod["main"].body
assert isinstance(stmt.body.body, tvm.tir.Allocate)
assert stmt.body.body.extents[0].value == 2
......
......@@ -25,8 +25,9 @@
#define TOPI_DETAIL_CONSTANT_UTILS_H_
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/arith/analyzer.h>
#include <tvm/tir/analysis.h>
#include <tvm/te/operation.h>
#include <string>
#include <vector>
......@@ -119,7 +120,7 @@ inline bool EqualCheck(PrimExpr lhs, PrimExpr rhs) {
bool result = expr_equal(lhs, rhs);
if (!result) {
PrimExpr zero(0);
result = expr_equal(tvm::tir::CanonicalSimplify(lhs-rhs), zero);
result = expr_equal(tvm::arith::Analyzer().Simplify(lhs-rhs), zero);
}
return result;
}
......
......@@ -26,8 +26,8 @@
#include <topi/tags.h>
#include <topi/detail/constant_utils.h>
#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/op.h>
#include <tvm/te/operation.h>
......@@ -184,6 +184,7 @@ inline tvm::te::Tensor pad(const tvm::te::Tensor& t,
pad_after.push_back(pad_before[i]);
}
}
arith::Analyzer analyzer;
CHECK_GE(pad_before.size(), 1);
CHECK_EQ(pad_before.size(), pad_after.size());
tvm::Array<tvm::PrimExpr> output_shape;
......@@ -200,13 +201,14 @@ inline tvm::te::Tensor pad(const tvm::te::Tensor& t,
output_shape.push_back(t->shape[i]);
} else {
output_shape.push_back(
tvm::tir::Simplify(t->shape[i] + pad_before_int32[i] + pad_after_int32[i]));
analyzer.Simplify(t->shape[i] + pad_before_int32[i] + pad_after_int32[i]));
}
}
if (!pad_value.defined()) {
pad_value = tvm::tir::make_const(t->dtype, 0);
}
auto l = [&](tvm::Array<tvm::tir::Var> ovars) {
tvm::Array<tvm::PrimExpr> indices;
tvm::Array<tvm::PrimExpr> sel;
......@@ -223,7 +225,7 @@ inline tvm::te::Tensor pad(const tvm::te::Tensor& t,
indices.push_back(ovars[i]);
}
if (!topi::detail::EqualCheck(pad_after_int32[i], 0)) {
sel.push_back(tvm::tir::Simplify(ovars[i] < pad_before_int32[i] + t->shape[i]));
sel.push_back(analyzer.Simplify(ovars[i] < pad_before_int32[i] + t->shape[i]));
}
if (pad_mode == "edge") {
pad_idx.push_back(tvm::if_then_else(
......
......@@ -25,7 +25,7 @@
#define TOPI_NN_BNN_H_
#include <tvm/te/operation.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/arith/analyzer.h>
#include <topi/tags.h>
#include <topi/detail/constant_utils.h>
......@@ -55,11 +55,12 @@ inline tvm::te::Tensor binarize_pack(const tvm::te::Tensor& data,
CHECK_EQ(GetConstInt(ishape[axis]) % 32, 0)
<< "binarize_pack: axis size must be a multiple of 32";
arith::Analyzer analyzer;
auto n = ishape.size();
Array<PrimExpr> oshape;
for (size_t i = 0; i < n; ++i) {
oshape.push_back(i == static_cast<size_t>(axis) ?
tvm::tir::Simplify(indexdiv(ishape[i], 32)) :
analyzer.Simplify(indexdiv(ishape[i], 32)) :
ishape[i]);
}
......
......@@ -25,7 +25,7 @@
#define TOPI_NN_DILATE_H_
#include <tvm/te/operation.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/arith/analyzer.h>
#include <topi/tags.h>
#include <string>
......@@ -75,8 +75,9 @@ inline Tensor dilate(const Tensor& x,
<< ") must match dimension of x (" << n << ")";
Array<PrimExpr> out_shape;
arith::Analyzer analyzer;
for (size_t i = 0; i < n; ++i) {
out_shape.push_back(tvm::tir::Simplify(
out_shape.push_back(analyzer.Simplify(
(x->shape[i] - 1) * cast(DataType::Int(32), strides[i] + 1)));
}
......
......@@ -28,7 +28,7 @@
#include <topi/nn.h>
#include <topi/reduction.h>
#include <topi/tags.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/arith/analyzer.h>
#include <algorithm>
#include <string>
......@@ -102,10 +102,10 @@ inline Tensor pool_impl(const Tensor& x,
Array<PrimExpr> pad_after(std::vector<PrimExpr>(x->shape.size(), 0));
pad_after.Set(height_axis, pad_bottom);
pad_after.Set(width_axis, pad_right);
auto out_height = tvm::tir::Simplify(
arith::Analyzer analyzer;
auto out_height = analyzer.Simplify(
indexdiv(height - kernel_height + pad_top + pad_bottom, stride_height) + 1);
auto out_width = tvm::tir::Simplify(
auto out_width = analyzer.Simplify(
indexdiv(width - kernel_width + pad_left + pad_right, stride_width) + 1);
auto dheight = tvm::te::reduce_axis(Range(0, kernel_height));
......@@ -212,11 +212,11 @@ inline Tensor pool_grad_impl(const Tensor& out_grad,
Array<PrimExpr> pad_after(std::vector<PrimExpr>(x->shape.size(), 0));
pad_after.Set(height_axis, pad_bottom);
pad_after.Set(width_axis, pad_right);
arith::Analyzer analyzer;
auto out_height =
tvm::tir::Simplify((height - kernel_height + pad_top + pad_bottom) / stride_height + 1);
analyzer.Simplify((height - kernel_height + pad_top + pad_bottom) / stride_height + 1);
auto out_width =
tvm::tir::Simplify((width - kernel_width + pad_left + pad_right) / stride_width + 1);
analyzer.Simplify((width - kernel_width + pad_left + pad_right) / stride_width + 1);
auto dheight = tvm::te::reduce_axis(Range(0, kernel_height));
auto dwidth = tvm::te::reduce_axis(Range(0, kernel_width));
......@@ -711,7 +711,8 @@ inline Tensor pool_impl_nd(const Tensor& x,
pad_before.Set(ii, pad_head[i]);
pad_after.Set(ii, pad_tail[i]);
auto out_dim = tvm::tir::Simplify(
arith::Analyzer analyzer;
auto out_dim = analyzer.Simplify(
indexdiv(x->shape[ii] - kernel[i] + pad_head[i] + pad_tail[i], stride[i]) + 1);
out_shape.Set(ii, out_dim);
......
......@@ -375,12 +375,12 @@ inline Tensor concatenate(const Array<Tensor>& inputs,
for (auto t : inputs) {
axis_sizes.push_back(t->shape[axis]);
}
arith::Analyzer analyzer;
PrimExpr join_size = axis_sizes[0];
for (size_t i = 1; i < axis_sizes.size(); ++i) {
join_size += axis_sizes[i];
}
join_size = tvm::tir::Simplify(join_size);
join_size = analyzer.Simplify(join_size);
Array<PrimExpr> out_shape;
for (size_t i = 0; i < inputs[0]->shape.size(); ++i) {
out_shape.push_back(i == static_cast<size_t>(axis) ? join_size : inputs[0]->shape[i]);
......
......@@ -167,7 +167,7 @@ def schedule_depthwise_conv2d_nhwc(outs):
b, h, w, c = s[Output].op.axis
# num_thread here could be 728, it is larger than cuda.max_num_threads
num_thread = tvm.tir.ir_pass.Simplify(temp.shape[3]).value
num_thread = tvm.arith.Analyzer().simplify(temp.shape[3]).value
target = tvm.target.Target.current()
if target and (target.target_name not in ["cuda", "nvptx"]):
num_thread = target.max_num_threads
......
......@@ -168,7 +168,7 @@ def schedule_depthwise_conv2d_nhwc(outs):
b, h, w, c = s[Output].op.axis
# num_thread here could be 728, it is larger than cuda.max_num_threads
num_thread = tvm.tir.ir_pass.Simplify(temp.shape[3]).value
num_thread = tvm.arith.Analyzer().simplify(temp.shape[3]).value
target = tvm.target.Target.current()
if target and (target.target_name not in ["cuda", "nvptx"]):
num_thread = target.max_num_threads
......
......@@ -45,9 +45,9 @@ def dilate(data, strides, name="DilatedInput"):
if len(strides) != n:
raise ValueError("data dimension and strides size dismatch : %d vs %d" % (
n, len(strides)))
ana = tvm.arith.Analyzer()
out_shape = tuple(
tvm.tir.ir_pass.Simplify((data.shape[i] - 1) * strides[i] + 1) for i in range(n))
ana.simplify((data.shape[i] - 1) * strides[i] + 1) for i in range(n))
def _dilate(*indices):
not_zero = []
......
......@@ -55,9 +55,9 @@ def pad(data, pad_before, pad_after=None, pad_value=0.0, name="PadInput"):
if len(pad_after) != n:
raise ValueError("Input dimension and pad_after dismatch : %d vs %d" % (
n, len(pad_before)))
ana = tvm.arith.Analyzer()
out_shape = tuple(
tvm.tir.ir_pass.Simplify(
(data.shape[i] + pad_before[i] + pad_after[i])) for i in range(n))
ana.simplify(data.shape[i] + pad_before[i] + pad_after[i]) for i in range(n))
pad_value = (pad_value if isinstance(pad_value, tvm.tir.PrimExpr)
else tvm.tir.const(pad_value, data.dtype))
def _pad(*indices):
......@@ -115,8 +115,9 @@ def mirror_pad(data,
if len(pad_after) != n:
raise ValueError("Input dimension and pad_after dismatch : %d vs %d" %
(n, len(pad_before)))
ana = tvm.arith.Analyzer()
out_shape = tuple(
tvm.tir.ir_pass.Simplify((data.shape[i] + pad_before[i] + pad_after[i]))
ana.simplify(data.shape[i] + pad_before[i] + pad_after[i])
for i in range(n))
assert mode in ('SYMMETRIC', 'REFLECT')
mode = int(mode == 'SYMMETRIC')
......
......@@ -101,7 +101,8 @@ def get_const_int(expr):
if isinstance(expr, Integral):
return expr
if not isinstance(expr, tvm.tir.IntImm):
expr = tvm.tir.ir_pass.Simplify(expr)
ana = tvm.arith.Analyzer()
expr = ana.simplify(expr)
if not isinstance(expr, tvm.tir.IntImm):
raise ValueError("Expect value to be constant int")
return int(expr.value)
......@@ -123,7 +124,8 @@ def get_const_float(expr):
if isinstance(expr, float):
return float(expr)
if not isinstance(expr, tvm.tir.FloatImm):
expr = tvm.tir.ir_pass.Simplify(expr)
ana = tvm.arith.Analyzer()
expr = ana.simplify(expr)
if not isinstance(expr, tvm.tir.FloatImm):
raise ValueError("Expect value to be constant float")
return float(expr.value)
......@@ -145,7 +147,8 @@ def equal_const_int(expr, value):
if isinstance(expr, Integral):
return expr == value
if not isinstance(expr, tvm.tir.IntImm):
expr = tvm.tir.ir_pass.Simplify(expr)
ana = tvm.arith.Analyzer()
expr = ana.simplify(expr)
if not isinstance(expr, tvm.tir.IntImm):
return False
return expr.value == value
......@@ -165,11 +168,13 @@ def get_const_tuple(in_tuple):
The output.
"""
ret = []
ana = None
for elem in in_tuple:
if isinstance(elem, (tvm.tir.Var, tvm.tir.expr.Any)):
ret.append(elem)
elif not isinstance(elem, (tvm.tir.IntImm, int)):
elem = tvm.tir.ir_pass.Simplify(elem)
ana = tvm.arith.Analyzer() if ana is None else ana
elem = ana.simplify(elem)
if not isinstance(elem, tvm.tir.IntImm):
ret.append(elem)
else:
......@@ -208,7 +213,7 @@ def simplify(expr):
out : Expr or int
The simplified output
"""
return tvm.tir.ir_pass.Simplify(expr) if isinstance(expr, tvm.tir.PrimExpr) else expr
return tvm.arith.Analyzer().simplify(expr) if isinstance(expr, tvm.tir.PrimExpr) else expr
def ravel_index(indices, shape):
......
......@@ -364,6 +364,7 @@ def inject_dma_intrin(stmt_in):
shape.append(1)
strides.append(elem_block)
analyzer = tvm.arith.Analyzer()
while base < ndim + 1:
x_size = 1
x_stride = buf.strides[ndim - base]
......@@ -378,7 +379,7 @@ def inject_dma_intrin(stmt_in):
break
x_size = x_size * buf.shape[k]
next_base = i + 1
shape.append(tvm.tir.ir_pass.Simplify(x_size))
shape.append(analyzer.simplify(x_size))
strides.append(x_stride)
assert next_base != base
base = next_base
......@@ -769,10 +770,11 @@ def inject_alu_intrin(stmt_in):
"""
env = get_env()
idxm = tvm.tir.indexmod
analyzer = tvm.arith.Analyzer()
def _do_fold(stmt):
def _equal(x, y):
return tvm.ir.structural_equal(tvm.tir.ir_pass.Simplify(x - y), 0)
return tvm.ir.structural_equal(analyzer.simplify(x - y), 0)
def _flatten_loop(src_coeff, dst_coeff, extents):
src_coeff = list(src_coeff)
......@@ -791,7 +793,7 @@ def inject_alu_intrin(stmt_in):
next_ext = extents.pop()
if _equal(next_src, vsrc * vext) and _equal(next_dst, vdst * vext):
vext = tvm.tir.ir_pass.Simplify(vext * next_ext)
vext = analyzer.simplify(vext * next_ext)
else:
rev_src_coeff.append(vsrc)
rev_dst_coeff.append(vdst)
......@@ -851,7 +853,7 @@ def inject_alu_intrin(stmt_in):
if loop_body.value.name == 'shift_left':
alu_opcode = env.dev.ALU_OPCODE_SHR
lhs = loop_body.value.args[0]
rhs = tvm.tir.ir_pass.Simplify(-loop_body.value.args[1])
rhs = analyzer.simplify(-loop_body.value.args[1])
elif loop_body.value.name == 'shift_right':
alu_opcode = env.dev.ALU_OPCODE_SHR
lhs = loop_body.value.args[0]
......@@ -914,10 +916,10 @@ def inject_alu_intrin(stmt_in):
assert len(dst_coeff) > 1
assert len(extents) != 0
assert tvm.ir.structural_equal(
tvm.tir.ir_pass.Simplify(
analyzer.simplify(
idxm(src_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0)
assert tvm.ir.structural_equal(
tvm.tir.ir_pass.Simplify(
analyzer.simplify(
idxm(dst_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0)
assert tvm.ir.structural_equal(src_coeff[-2], 1)
assert tvm.ir.structural_equal(dst_coeff[-2], 1)
......@@ -942,9 +944,9 @@ def inject_alu_intrin(stmt_in):
src_coeff.append(src_offset)
dst_coeff.append(dst_offset)
src_coeff = [
tvm.tir.ir_pass.Simplify(c // (env.BATCH * env.BLOCK_OUT)) for c in src_coeff]
analyzer.simplify(c // (env.BATCH * env.BLOCK_OUT)) for c in src_coeff]
dst_coeff = [
tvm.tir.ir_pass.Simplify(c // (env.BATCH * env.BLOCK_OUT)) for c in dst_coeff]
analyzer.simplify(c // (env.BATCH * env.BLOCK_OUT)) for c in dst_coeff]
# Flatten the outer loops
if extents:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment