Unverified Commit 92439166 by Tianqi Chen Committed by GitHub

[ARITH] Introduce base-class IRMutatorWithAnalyzer for scope dependent analysis (#3969)

parent cdbf4d85
......@@ -52,10 +52,10 @@ void Analyzer::Bind(const VarExpr& v, const Expr& expr) {
void Analyzer::Bind(const VarExpr& v, const Range& range) {
CHECK(range.defined());
Var var(v.node_);
this->const_int_bound.Bind(var, range);
if (is_one(range->extent)) {
this->rewrite_simplify.Update(var, range->min);
this->canonical_simplify.Update(var, range->min);
this->Bind(var, range->min);
} else {
this->const_int_bound.Bind(var, range);
}
// skip modular_set
// skip rewrite simplify
......
......@@ -744,8 +744,8 @@ Mutate_(const Div* op, const Expr& self) {
return std::move(lhs);
}
// both lhs and extra are non-negative
if (parent_->CanProveGreaterEqual(lhs->Normalize(), 0) &&
parent_->CanProveGreaterEqual(extra->Normalize(), 0)) {
if (analyzer_->CanProveGreaterEqual(lhs->Normalize(), 0) &&
analyzer_->CanProveGreaterEqual(extra->Normalize(), 0)) {
lhs.CopyOnWrite()->DivideBy(cval);
Expr temp = Normalize(extra);
if (const auto* pconst = temp.as<IntImm>()) {
......@@ -761,7 +761,7 @@ Mutate_(const Div* op, const Expr& self) {
}
} else {
// if a >= 0 && a < cval, then result == 0
auto cbound = parent_->const_int_bound(Normalize(a));
auto cbound = analyzer_->const_int_bound(Normalize(a));
if (cbound->min_value >= 0 && cbound->max_value < cval) {
return make_zero(a.type());
}
......@@ -809,7 +809,7 @@ Mutate_(const FloorDiv* op, const Expr& self) {
lhs.CopyOnWrite()->AddToSelf(floordiv(pconst->value, cval));
} else {
// if 0 <= extra < cval, it means the extra can be eliminated.
if (!(TryCompare(temp, cval) == kLT && parent_->CanProveGreaterEqual(temp, 0))) {
if (!(TryCompare(temp, cval) == kLT && analyzer_->CanProveGreaterEqual(temp, 0))) {
lhs.CopyOnWrite()->AddToSelf(
SplitDivConst(ToSplitExpr(temp), cval, kFloorDiv), 1);
}
......@@ -817,7 +817,7 @@ Mutate_(const FloorDiv* op, const Expr& self) {
return std::move(lhs);
} else {
// if a >= 0 && a < cval, then result == 0
auto cbound = parent_->const_int_bound(Normalize(a));
auto cbound = analyzer_->const_int_bound(Normalize(a));
if (cbound->min_value >= 0 && cbound->max_value < cval) {
return make_zero(a.type());
}
......@@ -908,8 +908,8 @@ Mutate_(const Mod* op, const Expr& self) {
return make_zero(a.type());
}
// both lhs and extra are non-negative
if (parent_->CanProveGreaterEqual(lhs->Normalize(), 0) &&
parent_->CanProveGreaterEqual(extra->Normalize(), 0)) {
if (analyzer_->CanProveGreaterEqual(lhs->Normalize(), 0) &&
analyzer_->CanProveGreaterEqual(extra->Normalize(), 0)) {
Expr temp = Normalize(extra);
if (temp.as<IntImm>()) {
return temp % c1.Eval();
......@@ -927,7 +927,7 @@ Mutate_(const Mod* op, const Expr& self) {
}
// Simplify the offset constant if necessary.
// (x - 5) % 3 => (x - 2) % 3 if x - 5 >= 0
auto cbound = parent_->const_int_bound(Normalize(a));
auto cbound = analyzer_->const_int_bound(Normalize(a));
int64_t new_base = psum->base % cval;
if (cbound->min_value >= 0 &&
cbound->min_value - psum->base + new_base >= 0) {
......@@ -937,7 +937,7 @@ Mutate_(const Mod* op, const Expr& self) {
}
} else {
// if a >= 0 && a < cval, then result == 0
auto cbound = parent_->const_int_bound(Normalize(a));
auto cbound = analyzer_->const_int_bound(Normalize(a));
if (cbound->min_value >= 0 && cbound->max_value < cval) {
return a;
}
......@@ -980,7 +980,7 @@ Mutate_(const FloorMod* op, const Expr& self) {
} else {
// If temp < cval && temp >=0 then can remove the mod.
if (TryCompare(temp, cval) == kLT &&
parent_->CanProveGreaterEqual(temp, 0)) {
analyzer_->CanProveGreaterEqual(temp, 0)) {
return temp;
} else {
// contonue to use logic below.
......@@ -997,7 +997,7 @@ Mutate_(const FloorMod* op, const Expr& self) {
return SplitModConst(ToSplitExpr(std::move(sum_expr)), cval, kFloorDiv);
} else {
// if a >= 0 && a < cval, then result == a
auto cbound = parent_->const_int_bound(Normalize(a));
auto cbound = analyzer_->const_int_bound(Normalize(a));
if (cbound->min_value >= 0 && cbound->max_value < cval) {
return a;
}
......@@ -1087,12 +1087,8 @@ SimplifyReduceCombiner(const Reduce* op) {
Expr CanonicalSimplifier::Impl::
Mutate_(const Reduce* op, const Expr& self) {
// Setup the domain information before simplification.
for (const IterVar& iv : op->axis) {
parent_->Bind(iv->var, iv->dom);
}
// Recursively call simplification when necessary.
Expr ret = IRMutator::Mutate_(op, self);
Expr ret = RewriteSimplifier::Impl::Mutate_(op, self);
op = ret.as<Reduce>();
// already been simplified by const reduction axis removal
if (op == nullptr) return ret;
......@@ -1121,7 +1117,6 @@ void CanonicalSimplifier::Update(const Var& var,
impl_->Update(var, info, override);
}
CanonicalSimplifier::CanonicalSimplifier(Analyzer* parent)
: impl_(new Impl(parent)) {
}
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/arithmetic/ir_mutator_with_analyzer.cc
*/
#include <tvm/ir_pass.h>
#include <tvm/expr_operator.h>
#include "ir_mutator_with_analyzer.h"
namespace tvm {
namespace arith {
using namespace ir;
Stmt IRMutatorWithAnalyzer::
Mutate_(const For* op, const Stmt& s) {
analyzer_->Bind(op->loop_var,
Range::make_by_min_extent(op->min, op->extent));
return IRMutator::Mutate_(op, s);
}
Stmt IRMutatorWithAnalyzer::
Mutate_(const LetStmt* op, const Stmt& s) {
Expr value = this->Mutate(op->value);
if (!ir::HasSideEffect(value)) {
analyzer_->Bind(op->var, value);
return this->Mutate(op->body);
}
Stmt body = this->Mutate(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return s;
} else {
return LetStmt::make(op->var, value, body);
}
}
Stmt IRMutatorWithAnalyzer::
Mutate_(const IfThenElse* op, const Stmt& s) {
Expr condition = this->Mutate(op->condition);
Stmt then_case, else_case;
{
With<ConstraintContext> ctx(analyzer_, condition);
then_case = this->Mutate(op->then_case);
}
if (op->else_case.defined()) {
With<ConstraintContext> ctx(analyzer_,
analyzer_->rewrite_simplify(Not::make(condition)));
else_case = this->Mutate(op->else_case);
}
if (is_one(condition)) return then_case;
if (is_zero(condition)) {
if (else_case.defined()) {
return else_case;
}
return Evaluate::make(0);
}
if (condition.same_as(op->condition) &&
then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
return s;
} else {
return IfThenElse::make(condition, then_case, else_case);
}
}
Stmt IRMutatorWithAnalyzer::
Mutate_(const AttrStmt* op, const Stmt& s) {
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::virtual_thread) {
IterVar iv(op->node.node_);
CHECK_NE(iv->thread_tag.length(), 0U);
analyzer_->Bind(iv->var,
Range::make_by_min_extent(0, op->value));
Stmt stmt = IRMutator::Mutate_(op, s);
return stmt;
} else {
return IRMutator::Mutate_(op, s);
}
}
Stmt IRMutatorWithAnalyzer::
Mutate_(const AssertStmt* op, const Stmt& s) {
Expr condition = this->Mutate(op->condition);
Expr message = this->Mutate(op->message);
With<ConstraintContext> ctx(analyzer_, condition);
Stmt body = this->Mutate(op->body);
if (condition.same_as(op->condition) &&
message.same_as(op->message) &&
body.same_as(op->body)) {
return s;
} else {
return AssertStmt::make(condition, message, body);
}
}
Expr IRMutatorWithAnalyzer::
Mutate_(const Call* op, const Expr& self) {
// add condition context to if_then_else
if (op->is_intrinsic(ir::intrinsic::tvm_if_then_else)) {
Expr cond = Mutate(op->args[0]);
Expr true_value, false_value;
{
With<ConstraintContext> constraint(analyzer_, cond);
true_value = Mutate(op->args[1]);
}
{
With<ConstraintContext> constraint(analyzer_,
analyzer_->rewrite_simplify(Not::make(cond)));
false_value = Mutate(op->args[2]);
}
if (is_zero(cond)) {
return false_value;
}
if (is_one(cond)) {
return true_value;
}
if (cond.same_as(op->args[0]) &&
true_value.same_as(op->args[1]) &&
false_value.same_as(op->args[2])) {
return self;
} else {
return Call::make(op->type, op->name,
{cond, true_value, false_value},
op->call_type);
}
}
return IRMutator::Mutate_(op, self);
}
Expr IRMutatorWithAnalyzer::
Mutate_(const Let* op, const Expr& self) {
Expr value = this->Mutate(op->value);
if (!ir::HasSideEffect(value)) {
analyzer_->Bind(op->var, value);
return this->Mutate(op->body);
}
Expr body = this->Mutate(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return self;
} else {
return Let::make(op->var, value, body);
}
}
Expr IRMutatorWithAnalyzer::
Mutate_(const Select* op, const Expr& self) {
Expr cond = Mutate(op->condition);
Expr true_value, false_value;
{
With<ConstraintContext> constraint(analyzer_, cond);
true_value = Mutate(op->true_value);
}
{
With<ConstraintContext> constraint(analyzer_,
analyzer_->rewrite_simplify(Not::make(cond)));
false_value = Mutate(op->false_value);
}
if (is_zero(cond)) {
return false_value;
}
if (is_one(cond)) {
return true_value;
}
// normal path
if (cond.same_as(op->condition) &&
true_value.same_as(op->true_value) &&
false_value.same_as(op->false_value)) {
return self;
} else {
return Select::make(cond, true_value, false_value);
}
}
Expr IRMutatorWithAnalyzer::
Mutate_(const Reduce* op, const Expr& self) {
// Setup the domain information before simplification.
for (const IterVar& iv : op->axis) {
analyzer_->Bind(iv->var, iv->dom);
}
// Recursively call simplification when necessary.
return IRMutator::Mutate_(op, self);
}
} // namespace arith
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/arithmetic/ir_mutator_with_analyzer.h
* \brief IR mutator base-class with an analyzer context.
*/
#ifndef TVM_ARITHMETIC_IR_MUTATOR_WITH_ANALYZER_H_
#define TVM_ARITHMETIC_IR_MUTATOR_WITH_ANALYZER_H_
#include <tvm/ir_mutator.h>
#include <tvm/arithmetic.h>
namespace tvm {
namespace arith {
/*!
* \brief IRMutator with an analyzer context.
*
* This class can sub-classed by ir mutators that need an analyzer.
* It will populates scope-related info such as bounds of loop-variables and constraints
* for the analyzer, so that the child class can do accurate context-dependent analysis.
*
* \sa src/arithmetic/ir_mutator_with_analyzer.cc
*/
class IRMutatorWithAnalyzer : public ir::IRMutator {
public:
explicit IRMutatorWithAnalyzer(Analyzer* analyzer)
: analyzer_(analyzer) {}
// override functions that need to populate the context information.
Stmt Mutate_(const ir::For* op, const Stmt& self) override;
Stmt Mutate_(const ir::LetStmt* op, const Stmt& self) override;
Stmt Mutate_(const ir::IfThenElse* op, const Stmt& self) override;
Stmt Mutate_(const ir::AttrStmt* op, const Stmt& self) override;
Stmt Mutate_(const ir::AssertStmt* op, const Stmt& self) override;
Expr Mutate_(const ir::Let* op, const Expr& self) override;
Expr Mutate_(const ir::Select* op, const Expr& self) override;
Expr Mutate_(const ir::Call* op, const Expr& self) override;
Expr Mutate_(const ir::Reduce* op, const Expr& self) override;
protected:
/*! \brief internal analyzer field. */
Analyzer* analyzer_;
};
} // namespace arith
} // namespace tvm
#endif // TVM_ARITHMETIC_IR_MUTATOR_WITH_ANALYZER_H_
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2019 by Contributors
* \file rewrite_simplify.cc
* \brief Rewrite-rule based simplification.
*/
......@@ -80,7 +79,7 @@ TryCompare(const Expr& x, int64_t val) {
return kLT;
}
}
ConstIntBound dbound = parent_->const_int_bound(diff);
ConstIntBound dbound = analyzer_->const_int_bound(diff);
if (dbound->min_value > val) {
return kGT;
}
......@@ -94,7 +93,7 @@ TryCompare(const Expr& x, int64_t val) {
return kLE;
}
if (val == 0) {
ModularSet dmod = parent_->modular_set(diff);
ModularSet dmod = analyzer_->modular_set(diff);
if (dmod->base != 0) {
return kNE;
}
......@@ -490,7 +489,7 @@ Mutate_(const Div* op, const Expr& self) {
}
// If all possible indices in ramp are the same.
if (CanProveGreaterEqual(b1.Eval(), 0)) {
ModularSet bmod = parent_->modular_set(b1.Eval());
ModularSet bmod = analyzer_->modular_set(b1.Eval());
int64_t ramp_min = bmod->base / c2val;
int64_t ramp_max = (bmod->base + (lanes.Eval() - 1) * c1val) / c2val;
if (bmod->coeff % c2val == 0 && ramp_min == ramp_max) {
......@@ -692,7 +691,7 @@ Mutate_(const Mod* op, const Expr& self) {
}
// If all possible indices in ramp are the same.
if (CanProveGreaterEqual(b1.Eval(), 0)) {
ModularSet bmod = parent_->modular_set(b1.Eval());
ModularSet bmod = analyzer_->modular_set(b1.Eval());
int64_t ramp_min = bmod->base / c2val;
int64_t ramp_max = (bmod->base + (lanes.Eval() - 1) * c1val) / c2val;
if (bmod->coeff % c2val == 0) {
......@@ -740,7 +739,7 @@ Mutate_(const Mod* op, const Expr& self) {
// try modular analysis
if ((x % c1).Match(ret)) {
ModularSet mod = parent_->modular_set(x.Eval());
ModularSet mod = analyzer_->modular_set(x.Eval());
int64_t c1val = c1.Eval()->value;
if (mod->coeff % c1val == 0 &&
c1val > 0 &&
......@@ -777,7 +776,7 @@ Mutate_(const FloorDiv* op, const Expr& self) {
return ramp(floordiv(b1, c2), floordiv(c1, c2), lanes).Eval();
}
// If all possible indices in ramp are the same.
ModularSet bmod = parent_->modular_set(b1.Eval());
ModularSet bmod = analyzer_->modular_set(b1.Eval());
int64_t ramp_min = floordiv(bmod->base, c2val);
int64_t ramp_max = floordiv(bmod->base + (lanes.Eval() - 1) * c1val, c2val);
if (bmod->coeff % c2val == 0 && ramp_min == ramp_max) {
......@@ -923,7 +922,7 @@ Mutate_(const FloorMod* op, const Expr& self) {
return broadcast(floormod(b1, c2), lanes).Eval();
}
// If all possible indices in ramp are the same.
ModularSet bmod = parent_->modular_set(b1.Eval());
ModularSet bmod = analyzer_->modular_set(b1.Eval());
int64_t ramp_min = floordiv(bmod->base, c2val);
int64_t ramp_max = floordiv(bmod->base + (lanes.Eval() - 1) * c1val, c2val);
if (bmod->coeff % c2val == 0) {
......@@ -956,7 +955,7 @@ Mutate_(const FloorMod* op, const Expr& self) {
// try modular analysis
if (floormod(x, c1).Match(ret)) {
ModularSet mod = parent_->modular_set(x.Eval());
ModularSet mod = analyzer_->modular_set(x.Eval());
int64_t c1val = c1.Eval()->value;
if (mod->coeff % c1val == 0 && c1val > 0) {
return floormod(mod->base, c1).Eval();
......@@ -990,8 +989,8 @@ Mutate_(const Min* op, const Expr& self) {
TVM_TRY_REWRITE(min(x, x), x);
// constant int bound
ConstIntBound a_bound = parent_->const_int_bound(op->a);
ConstIntBound b_bound = parent_->const_int_bound(op->b);
ConstIntBound a_bound = analyzer_->const_int_bound(op->a);
ConstIntBound b_bound = analyzer_->const_int_bound(op->b);
if (a_bound->max_value <= b_bound->min_value) {
return op->a;
}
......@@ -1175,8 +1174,8 @@ Mutate_(const Max* op, const Expr& self) {
TVM_TRY_REWRITE(max(x, x), x);
// constant int bound
ConstIntBound a_bound = parent_->const_int_bound(op->a);
ConstIntBound b_bound = parent_->const_int_bound(op->b);
ConstIntBound a_bound = analyzer_->const_int_bound(op->a);
ConstIntBound b_bound = analyzer_->const_int_bound(op->b);
if (a_bound->min_value >= b_bound->max_value) {
return op->a;
}
......@@ -1658,32 +1657,9 @@ Mutate_(const Or* op, const Expr& self) {
Expr RewriteSimplifier::Impl::
Mutate_(const Select* op, const Expr& self) {
Expr cond = Mutate(op->condition);
Expr true_value, false_value;
{
With<ConstraintContext> constraint(parent_, cond);
true_value = Mutate(op->true_value);
}
{
With<ConstraintContext> constraint(parent_, Mutate(Not::make(cond)));
false_value = Mutate(op->false_value);
}
if (is_zero(cond)) {
return false_value;
}
if (is_one(cond)) {
return true_value;
}
// normal path
Expr ret;
if (cond.same_as(op->condition) &&
true_value.same_as(op->true_value) &&
false_value.same_as(op->false_value)) {
ret = self;
} else {
ret = Select::make(cond, true_value, false_value);
}
Expr ret = IRMutatorWithAnalyzer::Mutate_(op, self);
op = ret.as<Select>();
if (op == nullptr) return ret;
// Pattern var to match any expression
PVar<Expr> x, y;
TVM_TRY_REWRITE(select(x, y, y), y);
......@@ -1693,37 +1669,9 @@ Mutate_(const Select* op, const Expr& self) {
Expr RewriteSimplifier::Impl::
Mutate_(const Call* op, const Expr& self) {
// add condition context to if_then_else
Expr ret;
if (op->is_intrinsic(ir::intrinsic::tvm_if_then_else)) {
Expr cond = Mutate(op->args[0]);
Expr true_value, false_value;
{
With<ConstraintContext> constraint(parent_, cond);
true_value = Mutate(op->args[1]);
}
{
With<ConstraintContext> constraint(parent_, Mutate(Not::make(cond)));
false_value = Mutate(op->args[2]);
}
if (is_zero(cond)) {
return false_value;
}
if (is_one(cond)) {
return true_value;
}
if (cond.same_as(op->args[0]) &&
true_value.same_as(op->args[1]) &&
false_value.same_as(op->args[2])) {
ret = self;
} else {
ret = Call::make(op->type, op->name,
{cond, true_value, false_value},
op->call_type);
}
} else {
ret = IRMutator::Mutate_(op, self);
}
Expr ret = IRMutatorWithAnalyzer::Mutate_(op, self);
op = ret.as<Call>();
if (op == nullptr) return ret;
if (op->is_intrinsic(Call::likely) && is_const(op->args[0])) {
return op->args[0];
}
......@@ -1731,23 +1679,6 @@ Mutate_(const Call* op, const Expr& self) {
}
Expr RewriteSimplifier::Impl::
Mutate_(const Let* op, const Expr& self) {
// For now assume value does not has side-effect.
Expr value = this->Mutate(op->value);
if (!ir::HasSideEffect(value)) {
parent_->Bind(op->var, value);
return this->Mutate(op->body);
}
Expr body = this->Mutate(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return self;
} else {
return Let::make(op->var, value, body);
}
}
Expr RewriteSimplifier::Impl::
Mutate_(const Variable* op, const Expr& self) {
Var var = GetRef<Var>(op);
auto it = var_map_.find(var);
......
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2019 by Contributors
* \file rewrite_simplify.h
* \brief Rewrite-rule based simplification.
*/
......@@ -31,6 +30,7 @@
#include <unordered_map>
#include "const_fold.h"
#include "pattern_match.h"
#include "ir_mutator_with_analyzer.h"
namespace tvm {
namespace arith {
......@@ -42,10 +42,12 @@ using namespace ir;
*
* This class can be inheritated for other simplifiers.
*/
class RewriteSimplifier::Impl : public IRMutator {
class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
public:
using IRMutatorWithAnalyzer::Mutate_;
explicit Impl(Analyzer* parent)
: parent_(parent) {}
: IRMutatorWithAnalyzer(parent) {}
void Update(const Var& var, const Expr& info, bool override);
Expr Mutate_(const Add* op, const Expr& self) override;
......@@ -68,7 +70,6 @@ class RewriteSimplifier::Impl : public IRMutator {
Expr Mutate_(const Not* op, const Expr& self) override;
Expr Mutate_(const Select* op, const Expr& self) override;
Expr Mutate_(const Call* op, const Expr& self) override;
Expr Mutate_(const Let* op, const Expr& self) override;
Expr Mutate_(const Variable* op, const Expr& self) override;
Expr Mutate_(const Cast* op, const Expr& self) override;
......@@ -83,8 +84,6 @@ class RewriteSimplifier::Impl : public IRMutator {
kLE,
kNE
};
// reference to the main analyzer
Analyzer* parent_;
// counter to record recursive rewrite depth.
int recur_depth_{0};
// internal variable map
......@@ -103,7 +102,7 @@ class RewriteSimplifier::Impl : public IRMutator {
private:
// Whether x >= val
bool CanProveGreaterEqual(const Expr& x, int64_t val) {
return parent_->CanProveGreaterEqual(x, val);
return analyzer_->CanProveGreaterEqual(x, val);
}
// Whether x == val
bool CanProveEqual(const Expr& x, int64_t val) {
......
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2019 by Contributors
* \file stmt_simplify.cc
* \brief Statement simplifier based on analyzer
*/
......@@ -28,113 +27,32 @@
#include <tvm/ir_mutator.h>
#include <tvm/expr_operator.h>
#include <tvm/arithmetic.h>
#include "ir_mutator_with_analyzer.h"
namespace tvm {
namespace arith {
using namespace ir;
class StmtSimplifier : public IRMutator {
class StmtSimplifier : public IRMutatorWithAnalyzer {
public:
using IRMutator::Mutate;
explicit StmtSimplifier(Analyzer* analyzer)
: IRMutatorWithAnalyzer(analyzer) {}
using Parent = IRMutatorWithAnalyzer;
using Parent::Mutate;
using Parent::Mutate_;
Expr Mutate(Expr expr) final {
return analyzer_.Simplify(expr);
return analyzer_->Simplify(expr);
}
Stmt Simplify(Stmt stmt, Map<Var, Range> vrange) {
for (auto kv : vrange) {
analyzer_.Bind(kv.first, kv.second);
}
Stmt Simplify(Stmt stmt) {
return Mutate(stmt);
}
Stmt Mutate_(const For* op, const Stmt& s) final {
analyzer_.Bind(op->loop_var,
Range::make_by_min_extent(op->min, op->extent));
return IRMutator::Mutate_(op, s);
}
Stmt Mutate_(const LetStmt* op, const Stmt& s) final {
Expr value = this->Mutate(op->value);
if (!ir::HasSideEffect(value)) {
analyzer_.Bind(op->var, value);
return this->Mutate(op->body);
}
Stmt body = this->Mutate(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return s;
} else {
return LetStmt::make(op->var, value, body);
}
}
// IfThenElse
Stmt Mutate_(const IfThenElse* op, const Stmt& s) {
Expr condition = this->Mutate(op->condition);
Stmt then_case, else_case;
{
With<ConstraintContext> ctx(&analyzer_, condition);
then_case = this->Mutate(op->then_case);
}
if (op->else_case.defined()) {
With<ConstraintContext> ctx(&analyzer_, Mutate(Not::make(condition)));
else_case = this->Mutate(op->else_case);
}
if (is_one(condition)) return then_case;
if (is_zero(condition)) {
if (else_case.defined()) {
return else_case;
}
return Evaluate::make(0);
}
if (condition.same_as(op->condition) &&
then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
return s;
} else {
return IfThenElse::make(condition, then_case, else_case);
}
}
// AttrStmt
Stmt Mutate_(const AttrStmt* op, const Stmt& s) {
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::virtual_thread) {
IterVar iv(op->node.node_);
CHECK_NE(iv->thread_tag.length(), 0U);
if (!var_dom_.count(iv->var.get())) {
Range dom = Range::make_by_min_extent(0, op->value);
var_dom_[iv->var.get()] = dom;
analyzer_.Bind(iv->var, dom);
}
Stmt stmt = IRMutator::Mutate_(op, s);
return stmt;
} else {
return IRMutator::Mutate_(op, s);
}
}
// AssertStmt
Stmt Mutate_(const AssertStmt* op, const Stmt& s) final {
Expr condition = this->Mutate(op->condition);
Expr message = this->Mutate(op->message);
With<ConstraintContext> ctx(&analyzer_, condition);
Stmt body = this->Mutate(op->body);
if (condition.same_as(op->condition) &&
message.same_as(op->message) &&
body.same_as(op->body)) {
return s;
} else {
return AssertStmt::make(condition, message, body);
}
}
// eliminate useless stores
Stmt Mutate_(const Store* op, const Stmt& s) {
Stmt Mutate_(const Store* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Store>();
if (const Load* load = op->value.as<Load>()) {
......@@ -145,11 +63,6 @@ class StmtSimplifier : public IRMutator {
}
return stmt;
}
protected:
Analyzer analyzer_;
// variable domain
std::unordered_map<const Variable*, Range> var_dom_;
};
} // namespace arith
......@@ -157,8 +70,11 @@ class StmtSimplifier : public IRMutator {
namespace ir {
Stmt CanonicalSimplify(Stmt stmt, Map<Var, Range> vrange) {
return arith::StmtSimplifier().Simplify(
stmt, vrange);
arith::Analyzer analyzer;
for (auto kv : vrange) {
analyzer.Bind(kv.first, kv.second);
}
return arith::StmtSimplifier(&analyzer).Simplify(stmt);
}
Expr CanonicalSimplify(Expr expr, Map<Var, Range> vrange) {
......@@ -179,8 +95,7 @@ Expr Simplify(Expr expr, Map<Var, Range> vrange) {
}
Stmt Simplify(Stmt stmt, Map<Var, Range> vrange) {
return arith::StmtSimplifier().Simplify(
stmt, vrange);
return CanonicalSimplify(stmt, vrange);
}
} // namespace ir
} // namespace tvm
......@@ -30,5 +30,23 @@ def test_stmt_simplify():
assert isinstance(body.body, tvm.stmt.Store)
def test_thread_extent_simplify():
ib = tvm.ir_builder.create()
A = ib.pointer("float32", name="A")
C = ib.pointer("float32", name="C")
n = tvm.var("n")
tx = tvm.thread_axis("threadIdx.x")
ty = tvm.thread_axis("threadIdx.y")
ib.scope_attr(tx, "thread_extent", n)
ib.scope_attr(tx, "thread_extent", n)
ib.scope_attr(ty, "thread_extent", 1)
with ib.if_scope(tx + ty < 12):
A[tx] = C[tx + ty]
body = tvm.stmt.LetStmt(n, 10, ib.get())
body = tvm.ir_pass.CanonicalSimplify(body)
assert isinstance(body.body.body.body, tvm.stmt.Store)
if __name__ == "__main__":
test_stmt_simplify()
test_thread_extent_simplify()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment