Unverified Commit 203ca7a0 by Tianqi Chen Committed by GitHub

[REFACTOR] Migrate Low-level IR Passes into the New Stmt/Expr Mutator (#4607)

* CombineContextCall

* Migrate BoundChecker

* Migrate CoprocSync

* Migrate detect_device

* Migrate loop_partition

* Migrate infer_fragement

* Migrate inject_copy_intrin

* Migrate inject double buffer

* Migrate lower_intrin and simplify

* Migrate storage flatten

* Migrate inject prefetch

* Migrate inject_virtual_thread

* migrate inline

* Migrate lift attr scope

* Migrate custom datatypes

* migrate lower_thread_all_reduce

* Migrate lower_tvm_builtin

* migrate lower_warp memory

* Migrate make_api.cc

* Migrate remap_thread_axis

* Migrate remove_no_op

* migrate rewrite_unsafe_select

* Migrate skip_assert simple_passes

* Migrate split_host_device

* Migrate ssa

* Migrate storage_access

* Migrate storage_rewrite

* Migrate tensor_core

* Migrate unroll_loop

* Migrate vectorize

* Migrate verify compact_buffer gpu_code

* Migrate verify_memory

* Migrate storage_sync

* Remove unused refs to mutator

* Migrate hybrid_op

* Migrate tensorize

* Migrate schedule ops

* Migrate schedule_dataflow_rewrite

* Migrate auto_inline_elemwise

* Remove unnecessary ref to visitor

* remove unnecessary ref

* Migrate bound_deducer

* Migrate domain_touched

* Migrate autotvm feature touch extractor

* Add annotations
parent 3f43bee0
......@@ -546,6 +546,35 @@ class StmtExprMutator :
}
};
/*!
* \brief recursively visit the ir in post DFS order node, and transform it
*
* \param node The ir to be transformed.
* \param preorder The function called before recursive mutation.
* If preorder returns None, then the transform will proceed to recursive call.
If preorder returns a non-None Stmt/Expr, the transformer will simply return it and
* won't do further recursion.
* \param postorder The function called after recursive mutation.
* The recursive mutation result is passed to postorder for further mutation.
* \param only_enable List of StringImm.
* If it is empty, all IRNode will call preorder/postorder
* If it is not empty, preorder/postorder will only be called
* when the IRNode's type key is in the list.
*/
TVM_DLL Stmt IRTransform(Stmt node,
const runtime::PackedFunc& preorder,
const runtime::PackedFunc& postorder,
const Array<Expr>& only_enable = {});
/*!
* \brief recursively visit the ir in post DFS order node, apply fvisit
* Each node is guaranteed to be visited only once.
* \param node The ir to be visited.
* \param fvisit The visitor function to be applied.
*/
TVM_DLL void PostOrderVisit(const ObjectRef& node, std::function<void(const ObjectRef&)> fvisit);
} // namespace ir
} // namespace tvm
#endif // TVM_IR_FUNCTOR_EXT_H_
......@@ -122,27 +122,6 @@ class TVM_DLL IRMutator {
virtual Expr Mutate_(const StringImm* op, const Expr& e);
virtual Expr Mutate_(const Shuffle* op, const Expr& e);
};
/*!
* \brief recursively visit the ir in post DFS order node, and transform it
*
* \param node The ir to be transformed.
* \param preorder The function called in before recursive mutation
* If preorder returns None, then the transform will proceed to recursive call.
* If preorder returns a not None Stmt/Expr, the transformer will simply return it and
* won't do further recursion.
* \param postorder The function called after recursive mutation.
* The recursive mutation result is passed to postorder for further mutation.
* \param only_enable List of StringImm.
* If it is empty, all IRNode will call preorder/postorder
* If it is not empty, preorder/postorder will only be called
* when the IRNode's type key is in the list.
*/
Stmt IRTransform(Stmt node,
const runtime::PackedFunc& preorder,
const runtime::PackedFunc& postorder,
const Array<Expr>& only_enable = {});
} // namespace ir
} // namespace tvm
#endif // TVM_IR_MUTATOR_H_
......@@ -145,15 +145,6 @@ class TVM_DLL IRVisitor {
virtual void Visit_(const FloatImm* op);
virtual void Visit_(const StringImm* op);
};
/*!
* \brief recursively visit the ir in post DFS order node, apply fvisit
* Each node is guaranteed to be visited only once.
* \param node The ir to be visited.
* \param fvisit The visitor function to be applied.
*/
TVM_DLL void PostOrderVisit(const ObjectRef& node, std::function<void(const ObjectRef&)> fvisit);
} // namespace ir
} // namespace tvm
......
......@@ -25,8 +25,7 @@
#include <tvm/ir.h>
#include <tvm/attrs.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/api_registry.h>
namespace tvm {
......
......@@ -23,7 +23,7 @@
*/
#include <tvm/expr.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/arithmetic.h>
#include <tvm/api_registry.h>
......@@ -38,17 +38,17 @@ using namespace ir;
// a visitor to find the path to the target variable
// from a expression.
class VariablePathFinder: public IRVisitor {
class VariablePathFinder: public ExprVisitor {
public:
explicit VariablePathFinder(Expr target) : target_(target) {}
void Visit(const ObjectRef& node) final {
void VisitExpr(const Expr& node) final {
if (visited_.count(node.get()) != 0) return;
visited_.insert(node.get());
if (!found_) path_.push_back(node.get());
if (node.same_as(target_)) found_ = true;
IRVisitor::Visit(node);
ExprVisitor::VisitExpr(node);
if (!found_) path_.pop_back();
}
......@@ -64,14 +64,14 @@ class VariablePathFinder: public IRVisitor {
// return empty vector to represent failure
std::vector<const Object*> GetPath(Expr target, Expr expr) {
VariablePathFinder v(target);
v.Visit(expr);
v(expr);
return v.path_;
}
enum CompareOp {kGreater, kLess, kEqual};
// a visitor to deduce the bound of a variable from a expression
class BoundDeducer: public IRVisitor {
class BoundDeducer: public ExprVisitor {
public:
friend class BoundDeduceInputChecker;
friend class Converter;
......@@ -82,39 +82,39 @@ class BoundDeducer: public IRVisitor {
void Deduce();
void Visit(const ObjectRef& e) final {
void VisitExpr(const Expr& e) final {
if (!success_) return;
if (e.get() == path_[iter_++]) {
IRVisitor::Visit(e);
ExprVisitor::VisitExpr(e);
} else {
success_ = false;
return;
}
}
void Visit_(const LT* op) final {
void VisitExpr_(const LT* op) final {
LOG(FATAL) << "unable to deduce due to multiple comparison operator";
}
void Visit_(const LE* op) final {
void VisitExpr_(const LE* op) final {
LOG(FATAL) << "unable to deduce due to multiple comparison operator";
}
void Visit_(const GT* op) final {
void VisitExpr_(const GT* op) final {
LOG(FATAL) << "unable to deduce due to multiple comparison operator";
}
void Visit_(const GE* op) final {
void VisitExpr_(const GE* op) final {
LOG(FATAL) << "unable to deduce due to multiple comparison operator";
}
void Visit_(const Add* op) final {
void VisitExpr_(const Add* op) final {
bool left = op->a.get() == path_[iter_];
result_ -= left ? op->b : op->a;
Visit(left ? op->a : op->b);
this->VisitExpr(left ? op->a : op->b);
}
void Visit_(const Sub* op) final {
void VisitExpr_(const Sub* op) final {
bool left = op->a.get() == path_[iter_];
if (left) {
result_ += op->b;
......@@ -123,10 +123,10 @@ class BoundDeducer: public IRVisitor {
result_ = - result_;
comp_op = ReverseOp(comp_op);
}
Visit(left ? op->a : op->b);
this->VisitExpr(left ? op->a : op->b);
}
void Visit_(const Mul* op) final {
void VisitExpr_(const Mul* op) final {
bool left = op->a.get() == path_[iter_];
Expr operand = left ? op->b : op->a;
Expr target_var = left ? op->a : op->b;
......@@ -171,7 +171,7 @@ class BoundDeducer: public IRVisitor {
// ( x <= -3/-2 --> x <= 1)
}
}
Visit(left ? op->a : op->b);
this->VisitExpr(left ? op->a : op->b);
}
Expr result_;
......@@ -194,17 +194,17 @@ class BoundDeducer: public IRVisitor {
Analyzer analyzer_;
};
class BoundDeduceInputChecker: public IRVisitor {
class BoundDeduceInputChecker: public ExprVisitor {
public:
bool Check(BoundDeducer* deducer) {
deducer_ = deducer;
Visit(deducer_->expr_);
this->VisitExpr(deducer_->expr_);
return target_count == 1;
}
void Visit(const ObjectRef& e) final {
void VisitExpr(const Expr& e) final {
if (e.same_as(deducer_->target_)) ++target_count;
IRVisitor::Visit(e);
ExprVisitor::VisitExpr(e);
}
private:
......@@ -305,7 +305,7 @@ void BoundDeducer::Deduce() {
}
expr_map_ = EvalSetForEachSubExpr(expr_, hint_map_);
Visit(expr_);
this->VisitExpr(expr_);
}
void BoundDeducer::Relax() {
......
......@@ -23,7 +23,6 @@
*/
#include <tvm/arithmetic.h>
#include <tvm/expr_operator.h>
#include <tvm/ir_mutator.h>
#include "const_fold.h"
#include "pattern_match.h"
#include "rewrite_simplify.h"
......@@ -435,30 +434,30 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
Expr CanonicalSimplify(Expr expr) {
expr = Mutate(expr);
expr = operator()(expr);
return expr;
}
// override the original mutate function.
Expr Mutate(Expr expr) final {
expr = IRMutator::Mutate(expr);
Expr VisitExpr(const Expr& input_expr) final {
auto expr = Rewriter::VisitExpr(input_expr);
return Normalize(expr);
}
// Normal mutation without normalization.
Expr CanonicalMutate(Expr expr) {
return IRMutator::Mutate(expr);
return Rewriter::VisitExpr(expr);
}
using Rewriter::Mutate_;
Expr Mutate_(const Add* op, const Expr& self) final;
Expr Mutate_(const Sub* op, const Expr& self) final;
Expr Mutate_(const Mul* op, const Expr& self) final;
Expr Mutate_(const Div* op, const Expr& self) final;
Expr Mutate_(const Mod* op, const Expr& self) final;
Expr Mutate_(const FloorDiv* op, const Expr& self) final;
Expr Mutate_(const FloorMod* op, const Expr& self) final;
Expr Mutate_(const Reduce* op, const Expr& self) final;
using Rewriter::VisitExpr_;
Expr VisitExpr_(const Add* op) final;
Expr VisitExpr_(const Sub* op) final;
Expr VisitExpr_(const Mul* op) final;
Expr VisitExpr_(const Div* op) final;
Expr VisitExpr_(const Mod* op) final;
Expr VisitExpr_(const FloorDiv* op) final;
Expr VisitExpr_(const FloorMod* op) final;
Expr VisitExpr_(const Reduce* op) final;
private:
/*!
......@@ -567,9 +566,9 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
};
Expr CanonicalSimplifier::Impl::
Mutate_(const Add* op, const Expr& self) {
VisitExpr_(const Add* op) {
if (!IsIndexType(op->dtype)) {
return Rewriter::Mutate_(op, self);
return Rewriter::VisitExpr_(op);
}
// normalize
Expr a = this->CanonicalMutate(op->a);
......@@ -593,9 +592,9 @@ Mutate_(const Add* op, const Expr& self) {
}
Expr CanonicalSimplifier::Impl::
Mutate_(const Sub* op, const Expr& self) {
VisitExpr_(const Sub* op) {
if (!IsIndexType(op->dtype)) {
return Rewriter::Mutate_(op, self);
return Rewriter::VisitExpr_(op);
}
// normalize
Expr a = this->CanonicalMutate(op->a);
......@@ -620,9 +619,9 @@ Mutate_(const Sub* op, const Expr& self) {
Expr CanonicalSimplifier::Impl::
Mutate_(const Mul* op, const Expr& self) {
VisitExpr_(const Mul* op) {
if (!IsIndexType(op->dtype)) {
return Rewriter::Mutate_(op, self);
return Rewriter::VisitExpr_(op);
}
// normalize
Expr a = this->CanonicalMutate(op->a);
......@@ -652,7 +651,7 @@ Mutate_(const Mul* op, const Expr& self) {
a = Normalize(a);
b = Normalize(b);
if (op->a.same_as(a) && op->b.same_as(b)) {
return self;
return GetRef<Expr>(op);
} else {
return Mul::make(a, b);
}
......@@ -727,9 +726,9 @@ SplitDivConst(SplitExpr lhs, int64_t cval, DivMode div_mode) {
}
Expr CanonicalSimplifier::Impl::
Mutate_(const Div* op, const Expr& self) {
VisitExpr_(const Div* op) {
if (!IsIndexType(op->dtype)) {
return Rewriter::Mutate_(op, self);
return Rewriter::VisitExpr_(op);
}
Expr a = this->CanonicalMutate(op->a);
......@@ -781,16 +780,16 @@ Mutate_(const Div* op, const Expr& self) {
a = Normalize(a);
b = Normalize(b);
if (op->a.same_as(a) && op->b.same_as(b)) {
return self;
return GetRef<Expr>(op);
} else {
return Div::make(a, b);
}
}
Expr CanonicalSimplifier::Impl::
Mutate_(const FloorDiv* op, const Expr& self) {
VisitExpr_(const FloorDiv* op) {
if (!IsIndexType(op->dtype)) {
return Rewriter::Mutate_(op, self);
return Rewriter::VisitExpr_(op);
}
Expr a = this->CanonicalMutate(op->a);
Expr b = this->CanonicalMutate(op->b);
......@@ -837,7 +836,7 @@ Mutate_(const FloorDiv* op, const Expr& self) {
a = Normalize(a);
b = Normalize(b);
if (op->a.same_as(a) && op->b.same_as(b)) {
return self;
return GetRef<Expr>(op);
} else {
return FloorDiv::make(a, b);
}
......@@ -866,7 +865,7 @@ SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode) {
// Do a recursive call to simplify the mod with the new factor.
if (new_upper_factor < lhs->upper_factor &&
lhs->upper_factor != SplitExprNode::kPosInf) {
auto updated = ToSplitExpr(Mutate(ModImpl(
auto updated = ToSplitExpr(this->VisitExpr(ModImpl(
lhs->index, make_const(lhs.dtype(), new_upper_factor), div_mode)));
// re-apply the lower_factor
if (lhs->lower_factor != 1) {
......@@ -894,9 +893,9 @@ SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode) {
}
Expr CanonicalSimplifier::Impl::
Mutate_(const Mod* op, const Expr& self) {
VisitExpr_(const Mod* op) {
if (!IsIndexType(op->dtype)) {
return Rewriter::Mutate_(op, self);
return Rewriter::VisitExpr_(op);
}
// normalize
Expr a = this->CanonicalMutate(op->a);
......@@ -957,16 +956,16 @@ Mutate_(const Mod* op, const Expr& self) {
a = Normalize(a);
b = Normalize(b);
if (op->a.same_as(a) && op->b.same_as(b)) {
return self;
return GetRef<Expr>(op);
} else {
return Mod::make(a, b);
}
}
Expr CanonicalSimplifier::Impl::
Mutate_(const FloorMod* op, const Expr& self) {
VisitExpr_(const FloorMod* op) {
if (!IsIndexType(op->dtype)) {
return Rewriter::Mutate_(op, self);
return Rewriter::VisitExpr_(op);
}
// normalize
Expr a = this->CanonicalMutate(op->a);
......@@ -1017,7 +1016,7 @@ Mutate_(const FloorMod* op, const Expr& self) {
a = Normalize(a);
b = Normalize(b);
if (op->a.same_as(a) && op->b.same_as(b)) {
return self;
return GetRef<Expr>(op);
} else {
return FloorMod::make(a, b);
}
......@@ -1029,7 +1028,7 @@ SimplifyReduceCombiner(const Reduce* op) {
// First simplify the results
Array<Expr> simplified_result;
for (const auto& res : op->combiner->result) {
Expr new_res = Mutate(res);
Expr new_res = this->VisitExpr(res);
simplified_result.push_back(new_res);
}
......@@ -1078,7 +1077,7 @@ SimplifyReduceCombiner(const Reduce* op) {
if (used[i]) {
// We simplify the result and identity, but not the source
new_result.push_back(simplified_result[i]);
new_identity.push_back(Mutate(op->combiner->identity_element[i]));
new_identity.push_back(this->VisitExpr(op->combiner->identity_element[i]));
new_lhs.push_back(op->combiner->lhs[i]);
new_rhs.push_back(op->combiner->rhs[i]);
new_source.push_back(op->source[i]);
......@@ -1095,9 +1094,9 @@ SimplifyReduceCombiner(const Reduce* op) {
}
Expr CanonicalSimplifier::Impl::
Mutate_(const Reduce* op, const Expr& self) {
VisitExpr_(const Reduce* op) {
// Recursively call simplification when necessary.
Expr ret = RewriteSimplifier::Impl::Mutate_(op, self);
Expr ret = RewriteSimplifier::Impl::VisitExpr_(op);
op = ret.as<Reduce>();
// already been simplified by const reduction axis removal
if (op == nullptr) return ret;
......@@ -1106,7 +1105,7 @@ Mutate_(const Reduce* op, const Expr& self) {
// assumption we would have to perform a single iteration of the loop, i.e. use
// `(*op->combiner.get())(op->combineop->identity_element, op->source)[op->value_index]`
// instead of `op->source[op->value_index]`. The former may be more difficult to simplify.
return Mutate(
return this->VisitExpr(
Select::make(op->condition,
op->source[op->value_index],
op->combiner->identity_element[op->value_index]));
......
......@@ -25,7 +25,6 @@
#define TVM_ARITHMETIC_CONST_FOLD_H_
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/expr_operator.h>
#include <algorithm>
#include <cmath>
......
......@@ -23,7 +23,6 @@
*/
#include <tvm/expr.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/arithmetic.h>
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -23,7 +23,7 @@
*/
#include <tvm/expr.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/tensor.h>
#include <tvm/api_registry.h>
......@@ -36,13 +36,13 @@ namespace arith {
using namespace ir;
// Find Read region of the tensor in the stmt.
class FuncTouchedDomain final : public IRVisitor {
class FuncTouchedDomain final : public StmtExprVisitor {
public:
FuncTouchedDomain(const Tensor &tensor, bool consider_calls, bool consider_provides)
: tensor_(tensor), consider_calls_(consider_calls), consider_provides_(consider_provides) {}
Domain Find(const Stmt& stmt) {
this->Visit(stmt);
operator()(stmt);
Domain ret;
Range none;
for (size_t i = 0; i < bounds_.size(); ++i) {
......@@ -51,49 +51,49 @@ class FuncTouchedDomain final : public IRVisitor {
return ret;
}
void Visit_(const For *op) final {
void VisitStmt_(const For *op) final {
const Variable* var = op->loop_var.get();
dom_map_[var] = IntSet::range(
Range::make_by_min_extent(op->min, op->extent));
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
dom_map_.erase(var);
}
void Visit_(const LetStmt* op) final {
void VisitStmt_(const LetStmt* op) final {
dom_map_[op->var.get()] =
arith::EvalSet(op->value, dom_map_);
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
dom_map_.erase(op->var.get());
}
/* TODO: Thread extent unit test not generated. */
void Visit_(const AttrStmt* op) final {
void VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::thread_extent) {
const IterVarNode* thread_axis = op->node.as<IterVarNode>();
CHECK(thread_axis);
const Variable* var = thread_axis->var.get();
dom_map_[var] = IntSet::range(Range(make_zero(op->value.dtype()), op->value));
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
dom_map_.erase(var);
} else {
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
}
}
void Visit_(const Call* op) final {
void VisitExpr_(const Call* op) final {
if (consider_calls_ && tensor_->op.same_as(op->func)
&& tensor_->value_index == op->value_index) {
Touch(op->args);
}
IRVisitor::Visit_(op);
StmtExprVisitor::VisitExpr_(op);
}
void Visit_(const Provide* op) final {
void VisitStmt_(const Provide* op) final {
if (consider_provides_ && tensor_->op.same_as(op->func)
&& tensor_->value_index == op->value_index) {
Touch(op->args);
}
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
}
private:
......
......@@ -30,41 +30,44 @@ namespace arith {
using namespace ir;
Stmt IRMutatorWithAnalyzer::
Mutate_(const For* op, const Stmt& s) {
VisitStmt_(const For* op) {
analyzer_->Bind(op->loop_var,
Range::make_by_min_extent(op->min, op->extent));
return IRMutator::Mutate_(op, s);
Range::make_by_min_extent(op->min, op->extent));
return StmtExprMutator::VisitStmt_(op);
}
Stmt IRMutatorWithAnalyzer::
Mutate_(const LetStmt* op, const Stmt& s) {
Expr value = this->Mutate(op->value);
VisitStmt_(const LetStmt* op) {
Expr value = this->VisitExpr(op->value);
if (!ir::HasSideEffect(value)) {
analyzer_->Bind(op->var, value);
}
// We keep the let-binding here
// as a sub-class may or may not choose to replace it.
Stmt body = this->Mutate(op->body);
Stmt body = this->VisitStmt(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return s;
return GetRef<Stmt>(op);
} else {
return LetStmt::make(op->var, value, body);
auto n = this->CopyOnWrite(op);
n->value = std::move(value);
n->body = std::move(body);
return Stmt(n);
}
}
Stmt IRMutatorWithAnalyzer::
Mutate_(const IfThenElse* op, const Stmt& s) {
Expr condition = this->Mutate(op->condition);
VisitStmt_(const IfThenElse* op) {
Expr condition = this->VisitExpr(op->condition);
Stmt then_case, else_case;
{
With<ConstraintContext> ctx(analyzer_, condition);
then_case = this->Mutate(op->then_case);
then_case = this->VisitStmt(op->then_case);
}
if (op->else_case.defined()) {
With<ConstraintContext> ctx(analyzer_,
analyzer_->rewrite_simplify(Not::make(condition)));
else_case = this->Mutate(op->else_case);
else_case = this->VisitStmt(op->else_case);
}
if (is_one(condition)) return then_case;
if (is_zero(condition)) {
......@@ -77,57 +80,65 @@ Mutate_(const IfThenElse* op, const Stmt& s) {
if (condition.same_as(op->condition) &&
then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
return s;
return GetRef<Stmt>(op);
} else {
return IfThenElse::make(condition, then_case, else_case);
auto n = this->CopyOnWrite(op);
n->condition = std::move(condition);
n->then_case = std::move(then_case);
n->else_case = std::move(else_case);
return Stmt(n);
}
}
Stmt IRMutatorWithAnalyzer::
Mutate_(const AttrStmt* op, const Stmt& s) {
VisitStmt_(const AttrStmt* op) {
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::virtual_thread) {
IterVar iv = Downcast<IterVar>(op->node);
CHECK_NE(iv->thread_tag.length(), 0U);
analyzer_->Bind(iv->var,
Range::make_by_min_extent(0, op->value));
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt stmt = StmtExprMutator::VisitStmt_(op);
return stmt;
} else {
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
}
Stmt IRMutatorWithAnalyzer::
Mutate_(const AssertStmt* op, const Stmt& s) {
Expr condition = this->Mutate(op->condition);
Expr message = this->Mutate(op->message);
VisitStmt_(const AssertStmt* op) {
Expr condition = this->VisitExpr(op->condition);
Expr message = this->VisitExpr(op->message);
With<ConstraintContext> ctx(analyzer_, condition);
Stmt body = this->Mutate(op->body);
Stmt body = this->VisitStmt(op->body);
if (condition.same_as(op->condition) &&
message.same_as(op->message) &&
body.same_as(op->body)) {
return s;
return GetRef<Stmt>(op);
} else {
return AssertStmt::make(condition, message, body);
auto n = this->CopyOnWrite(op);
n->condition = std::move(condition);
n->message = std::move(message);
n->body = std::move(body);
return Stmt(n);
}
}
Expr IRMutatorWithAnalyzer::
Mutate_(const Call* op, const Expr& self) {
VisitExpr_(const Call* op) {
// add condition context to if_then_else
if (op->is_intrinsic(ir::intrinsic::tvm_if_then_else)) {
Expr cond = Mutate(op->args[0]);
Expr cond = this->VisitExpr(op->args[0]);
Expr true_value, false_value;
{
With<ConstraintContext> constraint(analyzer_, cond);
true_value = Mutate(op->args[1]);
true_value = this->VisitExpr(op->args[1]);
}
{
With<ConstraintContext> constraint(analyzer_,
analyzer_->rewrite_simplify(Not::make(cond)));
false_value = Mutate(op->args[2]);
false_value = this->VisitExpr(op->args[2]);
}
if (is_zero(cond)) {
return false_value;
......@@ -138,45 +149,45 @@ Mutate_(const Call* op, const Expr& self) {
if (cond.same_as(op->args[0]) &&
true_value.same_as(op->args[1]) &&
false_value.same_as(op->args[2])) {
return self;
return GetRef<Expr>(op);
} else {
return Call::make(op->dtype, op->name,
{cond, true_value, false_value},
op->call_type);
}
}
return IRMutator::Mutate_(op, self);
return StmtExprMutator::VisitExpr_(op);
}
Expr IRMutatorWithAnalyzer::
Mutate_(const Let* op, const Expr& self) {
Expr value = this->Mutate(op->value);
VisitExpr_(const Let* op) {
Expr value = this->VisitExpr(op->value);
if (!ir::HasSideEffect(value)) {
analyzer_->Bind(op->var, value);
}
// We keep the let-binding here
// as a sub-class may or may not choose to replace it.
Expr body = this->Mutate(op->body);
Expr body = this->VisitExpr(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return self;
return GetRef<Expr>(op);
} else {
return Let::make(op->var, value, body);
}
}
Expr IRMutatorWithAnalyzer::
Mutate_(const Select* op, const Expr& self) {
Expr cond = Mutate(op->condition);
VisitExpr_(const Select* op) {
Expr cond = this->VisitExpr(op->condition);
Expr true_value, false_value;
{
With<ConstraintContext> constraint(analyzer_, cond);
true_value = Mutate(op->true_value);
true_value = VisitExpr(op->true_value);
}
{
With<ConstraintContext> constraint(analyzer_,
analyzer_->rewrite_simplify(Not::make(cond)));
false_value = Mutate(op->false_value);
false_value = VisitExpr(op->false_value);
}
if (is_zero(cond)) {
return false_value;
......@@ -188,20 +199,20 @@ Mutate_(const Select* op, const Expr& self) {
if (cond.same_as(op->condition) &&
true_value.same_as(op->true_value) &&
false_value.same_as(op->false_value)) {
return self;
return GetRef<Expr>(op);
} else {
return Select::make(cond, true_value, false_value);
}
}
Expr IRMutatorWithAnalyzer::
Mutate_(const Reduce* op, const Expr& self) {
VisitExpr_(const Reduce* op) {
// Setup the domain information before simplification.
for (const IterVar& iv : op->axis) {
analyzer_->Bind(iv->var, iv->dom);
}
// Recursively call simplification when necessary.
return IRMutator::Mutate_(op, self);
return StmtExprMutator::VisitExpr_(op);
}
} // namespace arith
......
......@@ -24,9 +24,9 @@
#ifndef TVM_ARITHMETIC_IR_MUTATOR_WITH_ANALYZER_H_
#define TVM_ARITHMETIC_IR_MUTATOR_WITH_ANALYZER_H_
#include <tvm/ir_mutator.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/arithmetic.h>
#include <utility>
namespace tvm {
namespace arith {
......@@ -40,23 +40,24 @@ namespace arith {
*
* \sa src/arithmetic/ir_mutator_with_analyzer.cc
*/
class IRMutatorWithAnalyzer : public ir::IRMutator {
class IRMutatorWithAnalyzer : public ir::StmtExprMutator {
public:
explicit IRMutatorWithAnalyzer(Analyzer* analyzer)
: analyzer_(analyzer) {}
using IRMutator::Mutate_;
using StmtExprMutator::VisitStmt_;
using StmtExprMutator::VisitExpr_;
// override functions that need to populate the context information.
Stmt Mutate_(const ir::For* op, const Stmt& self) override;
Stmt Mutate_(const ir::LetStmt* op, const Stmt& self) override;
Stmt Mutate_(const ir::IfThenElse* op, const Stmt& self) override;
Stmt Mutate_(const ir::AttrStmt* op, const Stmt& self) override;
Stmt Mutate_(const ir::AssertStmt* op, const Stmt& self) override;
Expr Mutate_(const ir::Let* op, const Expr& self) override;
Expr Mutate_(const ir::Select* op, const Expr& self) override;
Expr Mutate_(const ir::Call* op, const Expr& self) override;
Expr Mutate_(const ir::Reduce* op, const Expr& self) override;
Stmt VisitStmt_(const ir::For* op) override;
Stmt VisitStmt_(const ir::LetStmt* op) override;
Stmt VisitStmt_(const ir::IfThenElse* op) override;
Stmt VisitStmt_(const ir::AttrStmt* op) override;
Stmt VisitStmt_(const ir::AssertStmt* op) override;
Expr VisitExpr_(const ir::Let* op) override;
Expr VisitExpr_(const ir::Select* op) override;
Expr VisitExpr_(const ir::Call* op) override;
Expr VisitExpr_(const ir::Reduce* op) override;
protected:
/*! \brief internal analyzer field. */
......
......@@ -27,43 +27,43 @@
#include <tvm/arithmetic.h>
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
namespace tvm {
namespace ir {
class IRVisitorWithAnalyzer final : public IRVisitor {
class IRVisitorWithAnalyzer final : public StmtExprVisitor {
public:
Expr Simplify(const Expr& expr) {
return analyzer_.Simplify(expr);
}
void Visit_(const For* op) {
void VisitStmt_(const For* op) {
analyzer_.Bind(op->loop_var,
Range::make_by_min_extent(op->min, op->extent));
return IRVisitor::Visit_(op);
return StmtExprVisitor::VisitStmt_(op);
}
void Visit_(const AttrStmt* op) {
void VisitStmt_(const AttrStmt* op) {
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::virtual_thread) {
IterVar iv = Downcast<IterVar>(op->node);
CHECK_NE(iv->thread_tag.length(), 0U);
analyzer_.Bind(iv->var,
Range::make_by_min_extent(0, op->value));
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
} else {
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
}
}
void Visit_(const Reduce* op) {
void VisitExpr_(const Reduce* op) {
// Setup the domain information before simplification.
for (const IterVar& iv : op->axis) {
analyzer_.Bind(iv->var, iv->dom);
}
// Recursively call simplification when necessary.
IRVisitor::Visit_(op);
StmtExprVisitor::VisitExpr_(op);
}
protected:
......
......@@ -24,7 +24,6 @@
// Acknowledgement: Most rewrite-rules are from Halide.
#include <tvm/arithmetic.h>
#include <tvm/expr_operator.h>
#include <tvm/ir_mutator.h>
#include <algorithm>
#include "const_fold.h"
#include "pattern_match.h"
......@@ -69,7 +68,7 @@ using namespace ir;
// try to prove x equals val
RewriteSimplifier::Impl::CompareResult RewriteSimplifier::Impl::
TryCompare(const Expr& x, int64_t val) {
Expr diff = Mutate(x);
Expr diff = this->VisitExpr(x);
if (const auto* ptr = diff.as<IntImm>()) {
if (ptr->value == val) {
return kEQ;
......@@ -117,8 +116,8 @@ Update(const Var& var, const Expr& info, bool override) {
}
Expr RewriteSimplifier::Impl::
Mutate_(const Add* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
VisitExpr_(const Add* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<Add>();
Expr const_res = TryConstFold<Add>(op->a, op->b);
if (const_res.defined()) return const_res;
......@@ -232,8 +231,8 @@ std::function<void()> RewriteSimplifier::Impl::EnterConstraint(const Expr& const
}
Expr RewriteSimplifier::Impl::
Mutate_(const Sub* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
VisitExpr_(const Sub* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<Sub>();
Expr const_res = TryConstFold<Sub>(op->a, op->b);
if (const_res.defined()) return const_res;
......@@ -431,8 +430,8 @@ Mutate_(const Sub* op, const Expr& self) {
}
Expr RewriteSimplifier::Impl::
Mutate_(const Mul* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
VisitExpr_(const Mul* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<Mul>();
Expr const_res = TryConstFold<Mul>(op->a, op->b);
if (const_res.defined()) return const_res;
......@@ -470,8 +469,8 @@ Mutate_(const Mul* op, const Expr& self) {
}
Expr RewriteSimplifier::Impl::
Mutate_(const Div* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
VisitExpr_(const Div* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<Div>();
Expr const_res = TryConstFold<Div>(op->a, op->b);
if (const_res.defined()) return const_res;
......@@ -692,8 +691,8 @@ Mutate_(const Div* op, const Expr& self) {
}
Expr RewriteSimplifier::Impl::
Mutate_(const Mod* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
VisitExpr_(const Mod* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<Mod>();
Expr const_res = TryConstFold<Mod>(op->a, op->b);
if (const_res.defined()) return const_res;
......@@ -782,8 +781,8 @@ Mutate_(const Mod* op, const Expr& self) {
}
Expr RewriteSimplifier::Impl::
Mutate_(const FloorDiv* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
VisitExpr_(const FloorDiv* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<FloorDiv>();
Expr const_res = TryConstFold<FloorDiv>(op->a, op->b);
if (const_res.defined()) return const_res;
......@@ -926,8 +925,8 @@ Mutate_(const FloorDiv* op, const Expr& self) {
}
Expr RewriteSimplifier::Impl::
Mutate_(const FloorMod* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
VisitExpr_(const FloorMod* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<FloorMod>();
Expr const_res = TryConstFold<FloorMod>(op->a, op->b);
if (const_res.defined()) return const_res;
......@@ -996,8 +995,8 @@ Mutate_(const FloorMod* op, const Expr& self) {
}
Expr RewriteSimplifier::Impl::
Mutate_(const Min* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
VisitExpr_(const Min* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<Min>();
Expr const_res = TryConstFold<Min>(op->a, op->b);
if (const_res.defined()) return const_res;
......@@ -1181,8 +1180,8 @@ Mutate_(const Min* op, const Expr& self) {
}
Expr RewriteSimplifier::Impl::
Mutate_(const Max* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
VisitExpr_(const Max* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<Max>();
Expr const_res = TryConstFold<Max>(op->a, op->b);
if (const_res.defined()) return const_res;
......@@ -1354,8 +1353,8 @@ Mutate_(const Max* op, const Expr& self) {
}
Expr RewriteSimplifier::Impl::
Mutate_(const EQ* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
VisitExpr_(const EQ* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<EQ>();
Expr const_res = TryConstFold<EQ>(op->a, op->b);
if (const_res.defined()) return const_res;
......@@ -1388,28 +1387,28 @@ Mutate_(const EQ* op, const Expr& self) {
}
Expr RewriteSimplifier::Impl::
Mutate_(const NE* op, const Expr& self) {
return Mutate(Not::make(op->a == op->b));
VisitExpr_(const NE* op) {
return this->VisitExpr(Not::make(op->a == op->b));
}
Expr RewriteSimplifier::Impl::
Mutate_(const LE* op, const Expr& self) {
return Mutate(Not::make(op->b < op->a));
VisitExpr_(const LE* op) {
return this->VisitExpr(Not::make(op->b < op->a));
}
Expr RewriteSimplifier::Impl::
Mutate_(const GT* op, const Expr& self) {
return Mutate(op->b < op->a);
VisitExpr_(const GT* op) {
return this->VisitExpr(op->b < op->a);
}
Expr RewriteSimplifier::Impl::
Mutate_(const GE* op, const Expr& self) {
return Mutate(Not::make(op->a < op->b));
VisitExpr_(const GE* op) {
return this->VisitExpr(Not::make(op->a < op->b));
}
Expr RewriteSimplifier::Impl::
Mutate_(const LT* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
VisitExpr_(const LT* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<LT>();
Expr const_res = TryConstFold<LT>(op->a, op->b);
if (const_res.defined()) return const_res;
......@@ -1564,8 +1563,8 @@ Mutate_(const LT* op, const Expr& self) {
}
Expr RewriteSimplifier::Impl::
Mutate_(const Not* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
VisitExpr_(const Not* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<Not>();
Expr const_res = TryConstFold<Not>(op->a);
if (const_res.defined()) return const_res;
......@@ -1589,8 +1588,8 @@ Mutate_(const Not* op, const Expr& self) {
}
Expr RewriteSimplifier::Impl::
Mutate_(const And* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
VisitExpr_(const And* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<And>();
Expr const_res = TryConstFold<And>(op->a, op->b);
if (const_res.defined()) return const_res;
......@@ -1638,8 +1637,8 @@ Mutate_(const And* op, const Expr& self) {
}
Expr RewriteSimplifier::Impl::
Mutate_(const Or* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
VisitExpr_(const Or* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<Or>();
Expr const_res = TryConstFold<Or>(op->a, op->b);
if (const_res.defined()) return const_res;
......@@ -1688,8 +1687,8 @@ Mutate_(const Or* op, const Expr& self) {
}
Expr RewriteSimplifier::Impl::
Mutate_(const Select* op, const Expr& self) {
Expr ret = IRMutatorWithAnalyzer::Mutate_(op, self);
VisitExpr_(const Select* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<Select>();
if (op == nullptr) return ret;
// Pattern var to match any expression
......@@ -1699,9 +1698,9 @@ Mutate_(const Select* op, const Expr& self) {
}
Expr RewriteSimplifier::Impl::
Mutate_(const Call* op, const Expr& self) {
VisitExpr_(const Call* op) {
// add condition context to if_then_else
Expr ret = IRMutatorWithAnalyzer::Mutate_(op, self);
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<Call>();
if (op == nullptr) return ret;
if (op->is_intrinsic(Call::likely) && is_const(op->args[0])) {
......@@ -1729,35 +1728,35 @@ Mutate_(const Call* op, const Expr& self) {
}
Expr RewriteSimplifier::Impl::
Mutate_(const Variable* op, const Expr& self) {
VisitExpr_(const Variable* op) {
Var var = GetRef<Var>(op);
auto it = var_map_.find(var);
if (it != var_map_.end()) {
return it->second;
}
return self;
return GetRef<Expr>(op);
}
Expr RewriteSimplifier::Impl::
Mutate_(const Cast* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
VisitExpr_(const Cast* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<Cast>();
return cast(op->dtype, op->value);
}
Expr RewriteSimplifier::Impl::
Mutate_(const Let* op, const Expr& self) {
Expr value = this->Mutate(op->value);
VisitExpr_(const Let* op) {
Expr value = this->VisitExpr(op->value);
if (!ir::HasSideEffect(value)) {
// it is fine to discard the let binding
// because the value will always be inlined in the simplifier.
analyzer_->Bind(op->var, value);
return this->Mutate(op->body);
return this->VisitExpr(op->body);
}
Expr body = this->Mutate(op->body);
Expr body = this->VisitExpr(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return self;
return GetRef<Expr>(op);
} else {
return Let::make(op->var, value, body);
}
......@@ -1768,7 +1767,7 @@ Expr RewriteSimplifier::operator()(const Expr& expr) {
Expr res = expr;
int max_iter = 2;
for (int i = 0; i < max_iter; ++i) {
Expr new_expr = impl_->Mutate(res);
Expr new_expr = impl_->operator()(res);
if (new_expr.same_as(res)) return res;
res = new_expr;
}
......
......@@ -26,7 +26,6 @@
#include <tvm/arithmetic.h>
#include <tvm/expr_operator.h>
#include <tvm/ir_mutator.h>
#include <unordered_map>
#include <vector>
#include "const_fold.h"
......@@ -45,35 +44,35 @@ using namespace ir;
*/
class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
public:
using IRMutatorWithAnalyzer::Mutate_;
using IRMutatorWithAnalyzer::VisitExpr_;
explicit Impl(Analyzer* parent)
: IRMutatorWithAnalyzer(parent) {}
void Update(const Var& var, const Expr& info, bool override);
Expr Mutate_(const Add* op, const Expr& self) override;
Expr Mutate_(const Sub* op, const Expr& self) override;
Expr Mutate_(const Mul* op, const Expr& self) override;
Expr Mutate_(const Div* op, const Expr& self) override;
Expr Mutate_(const Mod* op, const Expr& self) override;
Expr Mutate_(const FloorDiv* op, const Expr& self) override;
Expr Mutate_(const FloorMod* op, const Expr& self) override;
Expr Mutate_(const Min* op, const Expr& self) override;
Expr Mutate_(const Max* op, const Expr& self) override;
Expr Mutate_(const EQ* op, const Expr& self) override;
Expr Mutate_(const NE* op, const Expr& self) override;
Expr Mutate_(const LT* op, const Expr& self) override;
Expr Mutate_(const LE* op, const Expr& self) override;
Expr Mutate_(const GT* op, const Expr& self) override;
Expr Mutate_(const GE* op, const Expr& self) override;
Expr Mutate_(const And* op, const Expr& self) override;
Expr Mutate_(const Or* op, const Expr& self) override;
Expr Mutate_(const Not* op, const Expr& self) override;
Expr Mutate_(const Select* op, const Expr& self) override;
Expr Mutate_(const Call* op, const Expr& self) override;
Expr Mutate_(const Variable* op, const Expr& self) override;
Expr Mutate_(const Cast* op, const Expr& self) override;
Expr Mutate_(const Let* op, const Expr& self) override;
void Update(const Var& var, const Expr& info, bool override_info);
Expr VisitExpr_(const Add* op) override;
Expr VisitExpr_(const Sub* op) override;
Expr VisitExpr_(const Mul* op) override;
Expr VisitExpr_(const Div* op) override;
Expr VisitExpr_(const Mod* op) override;
Expr VisitExpr_(const FloorDiv* op) override;
Expr VisitExpr_(const FloorMod* op) override;
Expr VisitExpr_(const Min* op) override;
Expr VisitExpr_(const Max* op) override;
Expr VisitExpr_(const EQ* op) override;
Expr VisitExpr_(const NE* op) override;
Expr VisitExpr_(const LT* op) override;
Expr VisitExpr_(const LE* op) override;
Expr VisitExpr_(const GT* op) override;
Expr VisitExpr_(const GE* op) override;
Expr VisitExpr_(const And* op) override;
Expr VisitExpr_(const Or* op) override;
Expr VisitExpr_(const Not* op) override;
Expr VisitExpr_(const Select* op) override;
Expr VisitExpr_(const Call* op) override;
Expr VisitExpr_(const Variable* op) override;
Expr VisitExpr_(const Cast* op) override;
Expr VisitExpr_(const Let* op) override;
std::function<void()> EnterConstraint(const Expr& constraint);
......@@ -123,7 +122,7 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
Expr RecursiveRewrite(const Expr& x) {
if (recur_depth_ >= kMaxRecurDepth) return x;
++recur_depth_;
Expr res = Mutate(x);
Expr res = this->VisitExpr(x);
--recur_depth_;
return res;
}
......
......@@ -24,7 +24,6 @@
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/arithmetic.h>
#include <tvm/ir_mutator.h>
#include <tvm/expr_operator.h>
#include <tvm/arithmetic.h>
#include "ir_mutator_with_analyzer.h"
......@@ -40,44 +39,47 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
: IRMutatorWithAnalyzer(analyzer) {}
using Parent = IRMutatorWithAnalyzer;
using Parent::Mutate;
using Parent::Mutate_;
using Parent::VisitStmt;
using Parent::VisitStmt_;
Expr Mutate(Expr expr) final {
Expr VisitExpr(const Expr& expr) final {
return analyzer_->Simplify(expr);
}
Stmt Simplify(Stmt stmt) {
return Mutate(stmt);
return operator()(std::move(stmt));
}
Stmt Mutate_(const For* op, const Stmt& s) final {
Stmt VisitStmt_(const For* op) final {
analyzer_->Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent));
With<ConstraintContext> ctx1(analyzer_, op->loop_var >= op->min);
With<ConstraintContext> ctx2(analyzer_, op->loop_var < op->min + op->extent);
return IRMutator::Mutate_(op, s);
return Parent::VisitStmt_(op);
}
Stmt Mutate_(const LetStmt* op, const Stmt& s) {
Expr value = this->Mutate(op->value);
Stmt VisitStmt_(const LetStmt* op) {
Expr value = this->VisitExpr(op->value);
if (!ir::HasSideEffect(value)) {
// it is fine to discard the let binding
// because the call to simplify will always inline the var.
analyzer_->Bind(op->var, value);
return Mutate(op->body);
return this->VisitStmt(op->body);
}
Stmt body = this->Mutate(op->body);
Stmt body = this->VisitStmt(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return s;
return GetRef<Stmt>(op);
} else {
return LetStmt::make(op->var, value, body);
auto n = this->CopyOnWrite(op);
n->value = std::move(value);
n->body = std::move(body);
return Stmt(n);
}
}
// eliminate useless stores
Stmt Mutate_(const Store* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt VisitStmt_(const Store* op) final {
Stmt stmt = Parent::VisitStmt_(op);
op = stmt.as<Store>();
if (const Load* load = op->value.as<Load>()) {
if (load->buffer_var.same_as(op->buffer_var) &&
......@@ -85,7 +87,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
return Evaluate::make(0);
}
}
return stmt;
return GetRef<Stmt>(op);
}
};
......@@ -98,7 +100,7 @@ Stmt CanonicalSimplify(Stmt stmt, Map<Var, Range> vrange) {
for (auto kv : vrange) {
analyzer.Bind(kv.first, kv.second);
}
return arith::StmtSimplifier(&analyzer).Simplify(stmt);
return arith::StmtSimplifier(&analyzer).Simplify(std::move(stmt));
}
Expr CanonicalSimplify(Expr expr, Map<Var, Range> vrange) {
......@@ -119,7 +121,7 @@ Expr Simplify(Expr expr, Map<Var, Range> vrange) {
}
Stmt Simplify(Stmt stmt, Map<Var, Range> vrange) {
return CanonicalSimplify(stmt, vrange);
return CanonicalSimplify(std::move(stmt), vrange);
}
} // namespace ir
} // namespace tvm
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -29,7 +29,7 @@ namespace tvm {
namespace autotvm {
// for loop
void FeatureVisitor::Visit_(const For *op) {
void FeatureVisitor::VisitStmt_(const For* op) {
const auto *extent = op->extent.as<IntImm>();
int64_t loop_extent = -1;
if (extent != nullptr)
......@@ -51,13 +51,13 @@ void FeatureVisitor::Visit_(const For *op) {
}
if (EnterItervar_(op->loop_var, loop_extent, ann)) {
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
ExitItervar_();
}
}
// parallel axis, virtual thread
void FeatureVisitor::Visit_(const AttrStmt *op) {
void FeatureVisitor::VisitStmt_(const AttrStmt* op) {
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::virtual_thread) {
VarExpr var = op->node.as<tvm::IterVarNode>()->var;
......@@ -86,24 +86,24 @@ void FeatureVisitor::Visit_(const AttrStmt *op) {
}
if (EnterItervar_(var, extent->value, ann)) {
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
ExitItervar_();
}
} else {
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
}
}
// memory access
void FeatureVisitor::Visit_(const Load *op) {
void FeatureVisitor::VisitExpr_(const Load* op) {
EnterMem_(op->buffer_var, op->index);
IRVisitor::Visit_(op);
StmtExprVisitor::VisitExpr_(op);
ExitMem_();
}
void FeatureVisitor::Visit_(const Store *op) {
void FeatureVisitor::VisitStmt_(const Store* op) {
EnterMem_(op->buffer_var, op->index);
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
ExitMem_();
}
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -27,7 +27,7 @@
#define TVM_AUTOTVM_FEATURE_VISITOR_H_
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
#include <string>
namespace tvm {
......@@ -48,15 +48,18 @@ enum AnnotationType {
* \brief A base class for feature extractor, used for processing
* for loop and memory access in the IR
*/
class FeatureVisitor : public IRVisitor {
class FeatureVisitor : public StmtExprVisitor {
public:
// for loop
void Visit_(const For *op);
void Visit_(const AttrStmt *op);
void VisitStmt_(const For *op);
void VisitStmt_(const AttrStmt *op);
// memory access
void Visit_(const Load *op);
void Visit_(const Store *op);
void VisitExpr_(const Load *op);
void VisitStmt_(const Store *op);
using StmtExprVisitor::VisitStmt_;
using StmtExprVisitor::VisitExpr_;
protected:
/*!
......
......@@ -44,14 +44,14 @@ int ParallelLevel(AnnotationType ann) {
}
// get touch pattern from index expression
class IndexParser: public IRVisitor {
class IndexParser: public ExprVisitor {
public:
void Parse(Expr expr) {
pattern_map.clear();
this->Visit(expr);
this->VisitExpr(expr);
}
void Visit_(const Variable *op) {
void VisitExpr_(const Variable *op) {
// TODO(lmzheng): handle more index types (multiple occurrence)
if (pattern_map.count(op) == 0) {
pattern_map[op] = TouchPattern();
......@@ -60,13 +60,13 @@ class IndexParser: public IRVisitor {
}
}
void Visit_(const Mul *op) {
void VisitExpr_(const Mul *op) {
if (op->a.as<Variable>()) {
if (const auto stride = op->b.as<IntImm>()) {
next_stride_ = stride->value;
}
}
IRVisitor::Visit_(op);
ExprVisitor::VisitExpr_(op);
}
std::unordered_map<const Variable*, TouchPattern> pattern_map;
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -26,7 +26,7 @@
#define TVM_AUTOTVM_TOUCH_EXTRACTOR_H_
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/api_registry.h>
#include <stack>
#include <vector>
......@@ -85,39 +85,39 @@ struct ItervarFeature {
// extract iter vars and their touch pattern from ir
class TouchExtractor : public FeatureVisitor {
public:
void Analyze(Stmt stmt) {
this->Visit(stmt);
void Analyze(const Stmt& stmt) {
operator()(stmt);
}
// arithmetic stats
void Visit_(const Add *op) {
void VisitExpr_(const Add *op) {
if (op->dtype.is_float())
itervar_map[itervar_stack_.back()].add_ct++;
IRVisitor::Visit_(op);
FeatureVisitor::VisitExpr_(op);
}
void Visit_(const Sub *op) {
void VisitExpr_(const Sub *op) {
if (op->dtype.is_float())
itervar_map[itervar_stack_.back()].add_ct++;
IRVisitor::Visit_(op);
FeatureVisitor::VisitExpr_(op);
}
void Visit_(const Mul *op) {
void VisitExpr_(const Mul *op) {
if (op->dtype.is_float())
itervar_map[itervar_stack_.back()].mul_ct++;
IRVisitor::Visit_(op);
FeatureVisitor::VisitExpr_(op);
}
void Visit_(const Div *op) {
void VisitExpr_(const Div *op) {
if (op->dtype.is_float())
itervar_map[itervar_stack_.back()].div_ct++;
IRVisitor::Visit_(op);
FeatureVisitor::VisitExpr_(op);
}
void Visit_(const Mod *op) {
void VisitExpr_(const Mod *op) {
if (op->dtype.is_float())
itervar_map[itervar_stack_.back()].div_ct++;
IRVisitor::Visit_(op);
FeatureVisitor::VisitExpr_(op);
}
std::unordered_map<VarExpr, ItervarFeature, tvm::ExprHash, tvm::ExprEqual> itervar_map;
......@@ -134,7 +134,7 @@ class TouchExtractor : public FeatureVisitor {
std::deque<VarExpr> itervar_stack_; // use deque instead of stack for indexing
std::deque<size_t> skip_stack_size_;
using IRVisitor::Visit_;
using FeatureVisitor::VisitExpr_;
};
} // namespace autotvm
......
......@@ -24,8 +24,8 @@
#include <tvm/operation.h>
#include <tvm/arithmetic.h>
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_functor_ext.h>
#include <unordered_set>
#include <string>
#include <utility>
......@@ -538,7 +538,7 @@ namespace {
* must be Reduce as well; and their inputs should have the
* same attribute except value_index.
*/
class ComputeVerifier final : protected ir::IRVisitor {
class ComputeVerifier final : protected ir::ExprVisitor {
public:
/// Special member functions
//@{
......@@ -567,20 +567,20 @@ class ComputeVerifier final : protected ir::IRVisitor {
}
level_ = 0;
ir::IRVisitor::Visit(e);
ExprVisitor::VisitExpr(e);
}
}
protected:
/// Visitor implementation
//@{
void Visit(const ObjectRef& n) final {
void VisitExpr(const Expr& n) final {
++level_;
ir::IRVisitor::Visit(n);
ExprVisitor::VisitExpr(n);
--level_;
}
void Visit_(const ir::Reduce* op) final {
void VisitExpr_(const ir::Reduce* op) final {
// Check for non top level reductions
CHECK(0 == level_)
<< "Reductions are only allowed at the top level of compute. "
......
......@@ -24,7 +24,7 @@
#include <tvm/operation.h>
#include <tvm/arithmetic.h>
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_pass.h>
#include <tvm/expr_operator.h>
#include <unordered_set>
......@@ -221,7 +221,7 @@ namespace op {
Stmt ApplyLoopShapes(const Stage &stage,
const std::unordered_map<IterVar, Range> &dom_map, Stmt stmt) {
class LoopSpliter : public IRMutator {
class LoopSpliter : public StmtExprMutator {
Expr factor;
const Variable *parent;
IterVar inner, outer;
......@@ -247,7 +247,7 @@ Stmt ApplyLoopShapes(const Stage &stage,
outer = IterVarNode::make(outer_dom, outer_->var, outer_->iter_type);
}
Stmt Mutate_(const For *op, const Stmt &stmt) {
Stmt VisitStmt_(const For *op) final {
if (op->loop_var.get() == parent) {
std::unordered_map<const Variable *, Expr> rmap;
rmap[op->loop_var.get()] = inner + outer * factor;
......@@ -261,11 +261,11 @@ Stmt ApplyLoopShapes(const Stage &stage,
splitted = true;
return ret;
}
return IRMutator::Mutate_(op, stmt);
return StmtExprMutator::VisitStmt_(op);
}
};
class LoopFuser : public IRMutator {
class LoopFuser : public StmtExprMutator {
const IterVar &parent;
const Variable *inner;
const Variable *outer;
......@@ -280,8 +280,7 @@ Stmt ApplyLoopShapes(const Stage &stage,
extent(0), fused(false) {}
// TODO(@were): Handle imperfect loops
Stmt Mutate_(const For *op, const Stmt &stmt) {
Stmt VisitStmt_(const For* op) final {
if (op->loop_var.get() == inner) {
CHECK(under_outer);
std::unordered_map<const Variable *, Expr> rmap;
......@@ -291,7 +290,7 @@ Stmt ApplyLoopShapes(const Stage &stage,
return ir::Substitute(op->body, rmap);
} else if (op->loop_var.get() == outer) {
under_outer = true;
Stmt body = IRMutator::Mutate(op->body);
Stmt body = this->VisitStmt(op->body);
std::unordered_map<const Variable *, Expr> rmap;
rmap[op->loop_var.get()] = indexdiv(parent, extent);
body = ir::Substitute(body, rmap);
......@@ -299,25 +298,25 @@ Stmt ApplyLoopShapes(const Stage &stage,
return For::make(parent->var, Expr(0), extent * op->extent,
op->for_type, op->device_api, body);
} else if (under_outer) {
Stmt body = IRMutator::Mutate(op->body);
Stmt body = this->VisitStmt(op->body);
std::unordered_map<const Variable *, Expr> rmap;
rmap[op->loop_var.get()] = indexmod(indexdiv(parent, extent), op->extent);
body = ir::Substitute(body, rmap);
extent = extent * op->extent;
return body;
}
return IRMutator::Mutate(stmt);
return StmtExprMutator::VisitStmt_(op);
}
};
for (auto &rel : stage->relations) {
if (const SplitNode *split = rel.as<SplitNode>()) {
LoopSpliter Spliter(split, dom_map);
stmt = Spliter.Mutate(stmt);
stmt = Spliter(stmt);
CHECK(Spliter.splitted);
} else if (const FuseNode *fuse = rel.as<FuseNode>()) {
LoopFuser Fuser(fuse);
stmt = Fuser.Mutate(stmt);
stmt = Fuser(stmt);
CHECK(Fuser.fused);
}
}
......@@ -327,14 +326,14 @@ Stmt ApplyLoopShapes(const Stage &stage,
Stmt ApplyLoopAnnotations(const Stage &stage,
const std::unordered_map<IterVar, IterVar> &rebased, Stmt stmt) {
class LoopAnnotator : public IRMutator {
class LoopAnnotator : public StmtMutator {
const Variable *var;
const IterVarAttr &attr;
public:
LoopAnnotator(const Variable *var_, const IterVarAttr &attr_) : var(var_), attr(attr_) {}
Stmt Mutate_(const For *op, const Stmt &stmt) {
Stmt VisitStmt_(const For *op) final {
if (op->loop_var.get() == var) {
if (attr->bind_thread.defined()) {
const auto &iter_var = attr->bind_thread;
......@@ -352,7 +351,7 @@ Stmt ApplyLoopAnnotations(const Stage &stage,
IterVarTypeToForType(attr->iter_type), op->device_api, op->body);
}
}
return IRMutator::Mutate_(op, stmt);
return StmtMutator::VisitStmt_(op);
}
};
......@@ -381,7 +380,7 @@ Stmt ApplyLoopAnnotations(const Stage &stage,
CHECK_EQ(found, 1) << " iter var should be found exactly once!";
if (need_change) {
stmt = LoopAnnotator(var, attr).Mutate(stmt);
stmt = LoopAnnotator(var, attr)(std::move(stmt));
}
}
return stmt;
......@@ -411,7 +410,7 @@ Stmt ApplyLoopOrder(const Stage &stage,
}
}
class LoopReorder : public IRMutator {
class LoopReorder : public StmtMutator {
const Stage &stage;
const std::unordered_map<IterVar, Range> &dom_map;
const std::unordered_map<const Variable *, IterVar> &reorder;
......@@ -422,13 +421,13 @@ Stmt ApplyLoopOrder(const Stage &stage,
const std::unordered_map<const Variable*, IterVar> &reorder)
: stage(stage), dom_map(dom_map), reorder(reorder) {}
Stmt Mutate_(const For *op, const Stmt &stmt) {
Stmt VisitStmt_(const For* op) final {
// Reorder from in to out
Stmt body_ = IRMutator::Mutate(op->body);
Stmt body_ = this->VisitStmt(op->body);
CHECK(reorder.count(op->loop_var.get()));
auto target = reorder.find(op->loop_var.get())->second;
if (body_.same_as(op->body) && op->loop_var.get() == target->var.get())
return stmt;
return GetRef<Stmt>(op);
const Stmt &body = op->body.same_as(body_) ? op->body : body_;
ForType for_type = IterVarTypeToForType(target->iter_type);
if (stage->iter_var_attrs.count(target)) {
......@@ -441,7 +440,7 @@ Stmt ApplyLoopOrder(const Stage &stage,
};
if (need_reorder)
return LoopReorder(stage, dom_map, reorder).Mutate(stmt);
return LoopReorder(stage, dom_map, reorder)(stmt);
return stmt;
}
......@@ -479,21 +478,21 @@ std::vector<IterVar> GatherLoopVars(Stmt stmt) {
}
// replacer to replace tensors' usage in Provide
class ProviderReplacer : public ir::IRMutator {
class ProviderReplacer : public ir::StmtMutator {
public:
explicit ProviderReplacer(const std::unordered_map<Tensor, Tensor> &vmap)
: vmap_(vmap) {}
Stmt Mutate_(const ir::Provide* op, const Stmt &s) {
Stmt VisitStmt_(const ir::Provide* op) final {
Tensor t = Downcast<Operation>(op->func).output(op->value_index);
auto it = vmap_.find(t);
if (it != vmap_.end()) {
Stmt ret = ir::Provide::make(
it->second->op, it->second->value_index, op->value, op->args);
found = true;
return IRMutator::Mutate_(ret.as<ir::Provide>(), ret);
return this->VisitStmt(ret);
}
return IRMutator::Mutate_(op, s);
return StmtMutator::VisitStmt_(op);
}
// whether it is found.
......@@ -506,7 +505,7 @@ class ProviderReplacer : public ir::IRMutator {
Stmt ReplaceProvideTensor(Stmt stmt,
const std::unordered_map<Tensor, Tensor> &replace) {
ProviderReplacer repl(replace);
Stmt ret = repl.Mutate(stmt);
Stmt ret = repl(stmt);
return repl.found ? ret : stmt;
}
} // namespace op
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -25,8 +25,6 @@
#define TVM_OP_HYBRID_OP_H_
#include <tvm/expr.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_visitor.h>
#include <tvm/schedule.h>
#include <unordered_map>
#include <unordered_set>
......
......@@ -23,8 +23,8 @@
*/
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/operation.h>
#include <tvm/ir_mutator.h>
#include <string>
#include "op_util.h"
#include "../schedule/message_passing.h"
......@@ -186,12 +186,12 @@ std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates) {
}
// replacer to replace tensors
class TensorReplacer : public ir::IRMutator {
class TensorReplacer : public ir::StmtExprMutator {
public:
explicit TensorReplacer(const std::unordered_map<Tensor, Tensor>& vmap)
: vmap_(vmap) {}
Expr Mutate_(const ir::Call* op, const Expr& e) {
Expr VisitExpr_(const ir::Call* op) final {
if (op->call_type == ir::Call::Halide) {
Tensor t = Downcast<Operation>(op->func).output(op->value_index);
auto it = vmap_.find(t);
......@@ -200,10 +200,10 @@ class TensorReplacer : public ir::IRMutator {
op->dtype, it->second->op->name, op->args,
op->call_type, it->second->op, it->second->value_index);
found = true;
return IRMutator::Mutate_(ret.as<ir::Call>(), ret);
return this->VisitExpr(ret);
}
}
return IRMutator::Mutate_(op, e);
return StmtExprMutator::VisitExpr_(op);
}
// whether it is found.
......@@ -216,13 +216,13 @@ class TensorReplacer : public ir::IRMutator {
Stmt ReplaceTensor(Stmt stmt,
const std::unordered_map<Tensor, Tensor>& replace) {
TensorReplacer repl(replace);
Stmt ret = repl.Mutate(stmt);
Stmt ret = repl(stmt);
return repl.found ? ret : stmt;
}
Expr ReplaceTensor(Expr expr,
const std::unordered_map<Tensor, Tensor>& replace) {
TensorReplacer repl(replace);
Expr ret = repl.Mutate(expr);
Expr ret = repl(expr);
return repl.found ? ret : expr;
}
......
......@@ -24,7 +24,6 @@
#include <tvm/operation.h>
#include <tvm/arithmetic.h>
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_pass.h>
#include <unordered_set>
#include "./op_util.h"
......
......@@ -22,7 +22,7 @@
* \file tensorize.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_pass.h>
#include <tvm/api_registry.h>
#include "op_util.h"
......@@ -157,10 +157,10 @@ void VerifyTensorizeLoopNest(const ComputeOpNode* self,
}
// Remap the tensor placeholder, index and inline things.
class TensorIntrinMatcher final : public IRMutator {
class TensorIntrinMatcher final : public StmtExprMutator {
public:
Expr Mutate_(const Call* op, const Expr& e) final {
Expr expr = IRMutator::Mutate_(op, e);
Expr VisitExpr_(const Call* op) final {
Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Call>();
if (op->call_type == Call::Halide) {
Tensor t = Downcast<Operation>(op->func).output(op->value_index);
......@@ -180,17 +180,17 @@ class TensorIntrinMatcher final : public IRMutator {
return expr;
}
Expr Mutate_(const Variable* op, const Expr& e) final {
Expr VisitExpr_(const Variable* op) final {
auto it = var_remap_.find(op);
if (it != var_remap_.end()) {
return it->second;
} else {
return e;
return GetRef<Expr>(op);
}
}
Expr Mutate_(const Reduce* op, const Expr& e) final {
Expr expr = IRMutator::Mutate_(op, e);
Expr VisitExpr_(const Reduce* op) final {
Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Reduce>();
Array<IterVar> axis;
for (size_t i = 0; i < op->axis.size(); ++i) {
......@@ -317,7 +317,7 @@ Array<Expr> MatchTensorizeBody(
matcher.Init(self, stage, dom_map, out_dom, in_region, intrin, compute_intrin_iter_space);
Array<Expr> ret;
for (Expr expr : self->body) {
ret.push_back(matcher.Mutate(expr));
ret.push_back(matcher(expr));
}
return ret;
}
......
......@@ -23,9 +23,8 @@
// Instrument checkers for out of the bounds access.
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
#include <vector>
#include <unordered_map>
#include <utility>
......@@ -33,48 +32,48 @@
namespace tvm {
namespace ir {
class BoundCollector : public IRVisitor {
class BoundCollector : public StmtVisitor {
public:
BoundCollector() {}
void Visit_(const AttrStmt *op) {
void VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == ir::attr::buffer_bound) {
if (const Variable *key = op->node.as<Variable>()) {
mem_to_shape[key] = op->value;
}
}
IRVisitor::Visit_(op);
StmtVisitor::VisitStmt_(op);
}
// Hashtable which maps buffer_var to shape.
std::unordered_map<const Variable *, Expr> mem_to_shape;
};
class BoundChecker : public IRMutator {
class BoundChecker : public StmtExprMutator {
public:
explicit BoundChecker(
const std::unordered_map<const Variable *, Expr> &mem_to_shape)
: mem_to_shape_(mem_to_shape) {}
Stmt Mutate_(const Allocate *op, const Stmt &s) final {
Stmt VisitStmt_(const Allocate* op) final {
// If the shape was updated we should update the hashtable.
if (UpdateIsNeeded(op->buffer_var)) {
Update(op->buffer_var, op->extents, op->dtype);
}
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
Expr Mutate_(const Call *op, const Expr &ex) final {
Expr VisitExpr_(const Call* op) final {
if (process_store_ && op->is_intrinsic(intrinsic::tvm_if_then_else)) {
unsafe_rewritten_ = true;
}
return IRMutator::Mutate_(op, ex);
return StmtExprMutator::VisitExpr_(op);
}
Stmt Mutate_(const Store *op, const Stmt &s) final {
Stmt VisitStmt_(const Store* op) final {
store_scope_bound_collector_.clear();
process_store_ = true;
unsafe_rewritten_ = false;
IRMutator::Mutate_(op, s);
StmtExprMutator::VisitStmt_(op);
process_store_ = false;
if (CanInstrument(op->index, op->buffer_var)) {
Collect(op->index, op->buffer_var);
......@@ -92,23 +91,24 @@ class BoundChecker : public IRMutator {
return body;
}
}
return s;
return GetRef<Stmt>(op);
}
Expr Mutate_(const Load *op, const Expr &ex) final {
Expr VisitExpr_(const Load* op) final {
if (CanInstrument(op->index, op->buffer_var)) {
Collect(op->index, op->buffer_var);
}
return IRMutator::Mutate_(op, ex);
return StmtExprMutator::VisitExpr_(op);
}
private:
bool UpdateIsNeeded(const VarExpr &buffer_var) const {
bool UpdateIsNeeded(const VarExpr& buffer_var) const {
return (buffer_var.defined() && mem_to_shape_.count(buffer_var.get()));
}
void Update(const VarExpr &buffer_var, const Array<Expr> &new_shape,
const DataType &type) {
void Update(const VarExpr& buffer_var,
const Array<Expr>& new_shape,
const DataType& type) {
// Sanity check at first.
if (!new_shape.size()) {
return;
......@@ -132,7 +132,7 @@ class BoundChecker : public IRMutator {
mem_to_shape_[buffer_var.get()] = shape;
}
bool IndexIsValid(const Expr &index) const {
bool IndexIsValid(const Expr& index) const {
if (!index.defined()) {
return false;
}
......@@ -146,7 +146,7 @@ class BoundChecker : public IRMutator {
return true;
}
bool CanInstrument(const Expr &index, const VarExpr &buffer_var) const {
bool CanInstrument(const Expr& index, const VarExpr& buffer_var) const {
return buffer_var.defined() && mem_to_shape_.count(buffer_var.get()) &&
IndexIsValid(index) && !unsafe_rewritten_;
}
......@@ -206,8 +206,8 @@ class BoundChecker : public IRMutator {
Stmt InstrumentBoundCheckers(Stmt stmt) {
BoundCollector bound_collector;
// At first walk recursively and collect bound attributes.
bound_collector.Visit(stmt);
return BoundChecker(bound_collector.mem_to_shape).Mutate(stmt);
bound_collector(stmt);
return BoundChecker(bound_collector.mem_to_shape)(std::move(stmt));
}
} // namespace ir
} // namespace tvm
......@@ -23,7 +23,7 @@
* \file combine_context_call.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_pass.h>
#include <map>
......@@ -32,7 +32,7 @@ namespace ir {
// Calculate the statistics of packed function.
// These information are needed during codegen.
class ContextCallCombiner final : public IRMutator {
class ContextCallCombiner final : public StmtExprMutator {
public:
struct CompareExpr {
bool operator()(const Expr& lhs, const Expr& rhs) const {
......@@ -40,7 +40,7 @@ class ContextCallCombiner final : public IRMutator {
}
};
Expr Mutate_(const Call* op, const Expr& e) final {
Expr VisitExpr_(const Call* op) final {
if (op->is_intrinsic(intrinsic::tvm_thread_context)) {
CHECK_EQ(op->args.size(), 1U);
Expr ctx = op->args[0];
......@@ -60,39 +60,39 @@ class ContextCallCombiner final : public IRMutator {
return std::move(ctx_var);
}
} else {
return IRMutator::Mutate_(op, e);
return StmtExprMutator::VisitExpr_(op);
}
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
Stmt VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::coproc_uop_scope) {
// Map of comparison expression to variable
std::map<Expr, Var, CompareExpr> temp;
std::swap(temp, ctx_map_);
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt stmt = StmtExprMutator::VisitStmt_(op);
std::swap(temp, ctx_map_);
return BuildContext(temp, stmt);
} else {
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
}
Stmt Mutate_(const For* op, const Stmt& s) final {
Stmt VisitStmt_(const For* op) final {
if (op->for_type == ForType::Parallel) {
// Map of comparison expression to variable
std::map<Expr, Var, CompareExpr> temp;
std::swap(temp, ctx_map_);
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt stmt = StmtExprMutator::VisitStmt_(op);
std::swap(temp, ctx_map_);
return BuildContext(temp, stmt);
} else {
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
}
Stmt Combine(Stmt stmt) {
return BuildContext(ctx_map_, this->Mutate(stmt));
return BuildContext(ctx_map_, this->VisitStmt(stmt));
}
private:
......
......@@ -22,8 +22,7 @@
*/
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
#include <unordered_map>
#include <unordered_set>
#include "ir_util.h"
......@@ -33,25 +32,25 @@ namespace tvm {
namespace ir {
// Visitor to find touched set by co-processor scope.
class CoProcTouchedBuffer : public IRVisitor {
class CoProcTouchedBuffer : public StmtExprVisitor {
public:
void Visit_(const Load* op) final {
void VisitExpr_(const Load* op) final {
if (in_scope_) {
touched_[op->buffer_var.get()].coproc = true;
} else {
touched_[op->buffer_var.get()].normal = true;
}
IRVisitor::Visit_(op);
StmtExprVisitor::VisitExpr_(op);
}
void Visit_(const Store* op) final {
void VisitStmt_(const Store* op) final {
if (in_scope_) {
touched_[op->buffer_var.get()].coproc = true;
} else {
touched_[op->buffer_var.get()].normal = true;
}
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
}
void Visit_(const Call* op) final {
void VisitExpr_(const Call* op) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
const Variable* buffer = op->args[1].as<Variable>();
if (in_scope_) {
......@@ -60,17 +59,17 @@ class CoProcTouchedBuffer : public IRVisitor {
touched_[buffer].normal = true;
}
}
IRVisitor::Visit_(op);
StmtExprVisitor::VisitExpr_(op);
}
void Visit_(const AttrStmt* op) final {
void VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::coproc_scope && !in_scope_) {
in_scope_ = true;
IterVar iv = Downcast<IterVar>(op->node);
coproc_.insert(iv);
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
in_scope_ = false;
} else {
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
}
}
......@@ -96,7 +95,7 @@ class CoProcSyncPlanner : public StorageAccessVisitor {
}
void Plan(const Stmt& stmt) {
this->Visit(stmt);
this->VisitStmt(stmt);
PlanSync(scope_.back(), nullptr, true);
if (sync_.size() == 0) {
sync_[stmt.get()] = GetSync(coproc_name_ + ".coproc_sync");
......@@ -218,14 +217,14 @@ class CoProcBarrierDetector : public StorageAccessVisitor {
write_barrier_name_ = coproc_name + ".coproc_write_barrier";
}
void PlanReadBarrier(Stmt stmt) {
void PlanReadBarrier(const Stmt& stmt) {
read_barrier_ = true;
this->Visit(stmt);
this->VisitStmt(stmt);
PlanReadBarrier(scope_.back(), nullptr);
}
void PlanWriteBarrier(Stmt stmt) {
void PlanWriteBarrier(const Stmt& stmt) {
read_barrier_ = false;
this->Visit(stmt);
this->VisitStmt(stmt);
PlanWriteBarrier(scope_.back(), nullptr);
}
......@@ -356,7 +355,7 @@ class CoProcBarrierDetector : public StorageAccessVisitor {
};
class CoProcInstDepDetector : public IRVisitor {
class CoProcInstDepDetector : public StmtVisitor {
public:
explicit CoProcInstDepDetector(
const IterVar& coproc_axis,
......@@ -366,15 +365,15 @@ class CoProcInstDepDetector : public IRVisitor {
sync_pop_name_ = coproc_name + ".coproc_dep_pop";
}
void Plan(Stmt stmt) {
this->Visit(stmt);
void Plan(const Stmt& stmt) {
this->VisitStmt(stmt);
if (last_state_.node != nullptr) {
MatchFixEnterPop(first_state_);
MatchFixExitPush(last_state_);
}
}
void Visit_(const AttrStmt* op) final {
void VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::coproc_scope &&
op->node.same_as(coproc_axis_)) {
const IntImm* ctx_id = op->value.as<IntImm>();
......@@ -385,15 +384,15 @@ class CoProcInstDepDetector : public IRVisitor {
curr_state_.exit_ctx.insert(ctx_id->value);
UpdateState();
} else {
IRVisitor::Visit_(op);
StmtVisitor::VisitStmt_(op);
}
}
void Visit_(const For* op) final {
void VisitStmt_(const For* op) final {
SyncState temp_first, temp_last;
std::swap(first_state_, temp_first);
std::swap(last_state_, temp_last);
this->Visit(op->body);
this->VisitStmt(op->body);
curr_state_.clear();
if (last_state_.node != nullptr) {
curr_state_.node = op;
......@@ -412,13 +411,13 @@ class CoProcInstDepDetector : public IRVisitor {
}
}
void Visit_(const IfThenElse* op) final {
void VisitStmt_(const IfThenElse* op) final {
SyncState temp_first, temp_last, curr_state;
std::swap(first_state_, temp_first);
std::swap(last_state_, temp_last);
{
// then stmt
this->Visit(op->then_case);
this->VisitStmt(op->then_case);
if (last_state_.node != nullptr) {
curr_state.node = op;
MatchFixEnterPop(first_state_);
......@@ -434,7 +433,7 @@ class CoProcInstDepDetector : public IRVisitor {
last_state_.clear();
}
if (op->else_case.defined()) {
this->Visit(op->else_case);
this->VisitStmt(op->else_case);
if (last_state_.node != nullptr) {
curr_state.node = op;
MatchFixEnterPop(first_state_);
......@@ -606,11 +605,11 @@ class CoProcInstDepDetector : public IRVisitor {
};
class CoProcSyncInserter : public IRMutator {
class CoProcSyncInserter : public StmtMutator {
public:
Stmt Insert(Stmt stmt) {
CoProcTouchedBuffer visitor;
visitor.Visit(stmt);
visitor(stmt);
if (visitor.coproc_.size() == 0) return stmt;
std::unordered_set<const Variable*> touched;
......@@ -652,10 +651,10 @@ class CoProcSyncInserter : public IRMutator {
auto& vec = insert_after_[kv.first];
vec.insert(vec.end(), kv.second.begin(), kv.second.end());
}
return Mutate(stmt);
return operator()(std::move(stmt));
}
Stmt Mutate(Stmt stmt) final {
Stmt VisitStmt(const Stmt& stmt) final {
Stmt before, after;
auto it = insert_before_.find(stmt.get());
if (it != insert_before_.end()) {
......@@ -666,14 +665,14 @@ class CoProcSyncInserter : public IRMutator {
if (it != insert_after_.end()) {
after = MergeSeq(it->second);
}
stmt = IRMutator::Mutate(stmt);
Stmt new_stmt = StmtMutator::VisitStmt(stmt);
if (before.defined()) {
stmt = Block::make(before, stmt);
new_stmt = Block::make(before, new_stmt);
}
if (after.defined()) {
stmt = Block::make(stmt, after);
new_stmt = Block::make(new_stmt, after);
}
return stmt;
return new_stmt;
}
private:
......@@ -685,7 +684,7 @@ class CoProcSyncInserter : public IRMutator {
Stmt CoProcSync(Stmt stmt) {
return CoProcSyncInserter().Insert(stmt);
return CoProcSyncInserter().Insert(std::move(stmt));
}
} // namespace ir
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -22,7 +22,6 @@
*/
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include "../pass/ir_util.h"
namespace tvm {
......
......@@ -21,9 +21,7 @@
* \file hoist_if_then_else.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/arithmetic.h>
#include <tvm/api_registry.h>
#include <unordered_map>
......
......@@ -23,8 +23,7 @@
*/
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
#include <unordered_map>
#include <unordered_set>
#include "ir_util.h"
......@@ -35,7 +34,7 @@ namespace tvm {
namespace ir {
// Get fragment information from tensor intrinsics
class FragmentGetter : public IRVisitor {
class FragmentGetter : public StmtExprVisitor {
public:
// fragment metadata
struct FragmentInfo {
......@@ -48,8 +47,8 @@ class FragmentGetter : public IRVisitor {
: m(_m), n(_n), k(_k), layout(_layout) {}
};
void Visit_(const Call* op) final {
IRVisitor::Visit_(op);
void VisitExpr_(const Call* op) final {
StmtExprVisitor::VisitExpr_(op);
if (op->is_intrinsic(intrinsic::tvm_load_matrix_sync) ||
op->is_intrinsic(intrinsic::tvm_store_matrix_sync)) {
......@@ -116,13 +115,13 @@ class FragmentGetter : public IRVisitor {
}
// Get memory scope
void Visit_(const AttrStmt* op) final {
void VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::storage_scope) {
const Variable* buffer = op->node.as<Variable>();
CHECK(buffer);
scopes[buffer] = op->value.as<StringImm>()->value;
}
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
}
// Memory scope for allocations
......@@ -132,11 +131,12 @@ class FragmentGetter : public IRVisitor {
};
// Check shape of fragment making sure it is a valid shape for tvm_mma_sync
class FragmentChecker : public IRVisitor {
class FragmentChecker : public StmtExprVisitor {
public:
explicit FragmentChecker(const FragmentGetter &getter) : fragment_getter(getter) {}
void Visit_(const Call* op) final {
void VisitExpr_(const Call* op) final {
StmtExprVisitor::VisitExpr_(op);
// Check shape when calling tvm_mma_sync
if (op->is_intrinsic(intrinsic::tvm_mma_sync)) {
CHECK_EQ(op->args.size(), 8U);
......@@ -170,12 +170,12 @@ class FragmentChecker : public IRVisitor {
};
// Store the metadata into attributes
class InferFragmenter : public IRMutator {
class InferFragmenter : public StmtMutator {
public:
explicit InferFragmenter(const FragmentGetter &getter) : fragment_getter(getter) {}
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt VisitStmt_(const Allocate* op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
const Variable* buffer = op->buffer_var.get();
if (fragment_getter.fragments.count(buffer)) {
// Add attribute to fragments allocation
......@@ -206,9 +206,10 @@ class InferFragmenter : public IRMutator {
Stmt InferFragment(Stmt stmt) {
FragmentGetter getter;
getter.Visit(stmt);
FragmentChecker(getter).Visit(stmt);
stmt = InferFragmenter(getter).Mutate(stmt);
getter(stmt);
FragmentChecker checker(getter);
checker(stmt);
stmt = InferFragmenter(getter)(std::move(stmt));
return stmt;
}
......
......@@ -23,7 +23,7 @@
*/
#include <tvm/ir.h>
#include <tvm/packed_func_ext.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_pass.h>
#include "../arithmetic/pattern_match.h"
......@@ -32,7 +32,7 @@ namespace ir {
using runtime::PackedFunc;
class CopyIntrinInjector : public IRMutator {
class CopyIntrinInjector : public StmtMutator {
public:
CopyIntrinInjector(const std::string& pragma_key,
const PackedFunc& flower_copy_fromto)
......@@ -40,7 +40,7 @@ class CopyIntrinInjector : public IRMutator {
flower_copy_fromto_(flower_copy_fromto) {
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
Stmt VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::storage_scope) {
const Variable* buf = op->node.as<Variable>();
storage_scope_[buf] = op->value.as<StringImm>()->value;
......@@ -50,7 +50,7 @@ class CopyIntrinInjector : public IRMutator {
<< "Cannot match copy pattern of " << op->body;
return ret;
}
return IRMutator::Mutate_(op, s);
return StmtMutator::VisitStmt_(op);
}
private:
......@@ -193,8 +193,7 @@ class CopyIntrinInjector : public IRMutator {
Stmt InjectCopyIntrin(Stmt stmt,
const std::string& pragma_key,
const PackedFunc& flower_copy_fromto) {
return CopyIntrinInjector(pragma_key, flower_copy_fromto)
.Mutate(stmt);
return CopyIntrinInjector(pragma_key, flower_copy_fromto)(std::move(stmt));
}
} // namespace ir
......
......@@ -22,8 +22,7 @@
* \file inject_double_buffer.cc
*/
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/expr_operator.h>
#include "ir_util.h"
#include "../arithmetic/compute_expr.h"
......@@ -32,18 +31,18 @@ namespace tvm {
namespace ir {
// Detect double buffer variables.
class DoubleBufferDetector : public IRVisitor {
class DoubleBufferDetector : public StmtExprVisitor {
public:
void Visit_(const AttrStmt* op) final {
void VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::double_buffer_scope) {
touched_.insert(op->node.as<Variable>());
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
} else {
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
}
}
void Visit_(const Variable* op) final {
void VisitExpr_(const Variable* op) final {
if (touched_.count(op)) {
touched_.erase(op);
}
......@@ -53,55 +52,55 @@ class DoubleBufferDetector : public IRVisitor {
};
class StripDoubleBufferWrite : public IRMutator {
class StripDoubleBufferWrite : public StmtMutator {
public:
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
Stmt VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::double_buffer_write) {
return Mutate(op->body);
return VisitStmt(op->body);
} else {
return IRMutator::Mutate_(op, s);
return StmtMutator::VisitStmt_(op);
}
}
};
class DoubleBufferInjector : public IRMutator {
class DoubleBufferInjector : public StmtExprMutator {
public:
explicit DoubleBufferInjector(int split_loop)
: split_loop_(split_loop) {}
Stmt Inject(const Stmt& stmt) {
Stmt Inject(Stmt stmt) {
DoubleBufferDetector detector;
detector.Visit(stmt);
detector(stmt);
if (detector.touched_.empty()) return stmt;
for (const Variable* v : detector.touched_) {
dbuffer_info_[v] = StorageEntry();
}
return ConvertSSA(this->Mutate(stmt));
return ConvertSSA(operator()(std::move(stmt)));
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
Stmt VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::storage_scope) {
const Variable* buf = op->node.as<Variable>();
auto it = dbuffer_info_.find(buf);
if (it != dbuffer_info_.end()) {
it->second.scope = op->value.as<StringImm>()->value;
return Mutate(op->body);
return this->VisitStmt(op->body);
} else {
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
} else if (op->attr_key == attr::double_buffer_scope) {
return MakeProducer(op, s);
return MakeProducer(op);
} else {
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
}
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
Stmt VisitStmt_(const Allocate* op) final {
auto it = dbuffer_info_.find(op->buffer_var.get());
if (it != dbuffer_info_.end()) {
it->second.stride = arith::ComputeReduce<Mul>(
op->extents, Expr()) * op->dtype.lanes();
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<Allocate>();
Array<Expr> new_extents{make_const(op->extents[0].dtype(), 2)};
for (Expr e : op->extents) {
......@@ -118,13 +117,13 @@ class DoubleBufferInjector : public IRMutator {
Evaluate::make(0)));
return op->body;
} else {
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
}
Stmt Mutate_(const For* op, const Stmt& s) final {
Stmt VisitStmt_(const For* op) final {
loop_nest_.push_back(op);
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt stmt = StmtExprMutator::VisitStmt_(op);
auto it = loop_pre_.find(op);
if (it != loop_pre_.end()) {
const For* old_loop = stmt.as<For>();
......@@ -151,7 +150,7 @@ class DoubleBufferInjector : public IRMutator {
MergeSeq(loop_seq));
// tail
std::vector<Stmt> tail_seq;
Stmt tail_body = StripDoubleBufferWrite().Mutate(old_loop->body);
Stmt tail_body = StripDoubleBufferWrite()(old_loop->body);
for (int32_t i = 0; i < split_loop_; ++i) {
Expr idx = tail_base + make_const(tail_base.dtype(), i);
vmap[old_loop->loop_var.get()] = idx;
......@@ -171,8 +170,8 @@ class DoubleBufferInjector : public IRMutator {
return stmt;
}
Stmt Mutate_(const Store* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt VisitStmt_(const Store* op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<Store>();
auto it = dbuffer_info_.find(op->buffer_var.get());
if (it != dbuffer_info_.end()) {
......@@ -188,8 +187,8 @@ class DoubleBufferInjector : public IRMutator {
}
}
Expr Mutate_(const Load* op, const Expr& e) final {
Expr expr = IRMutator::Mutate_(op, e);
Expr VisitExpr_(const Load* op) final {
Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Load>();
auto it = dbuffer_info_.find(op->buffer_var.get());
if (it != dbuffer_info_.end()) {
......@@ -205,20 +204,20 @@ class DoubleBufferInjector : public IRMutator {
}
}
Expr Mutate_(const Variable* op, const Expr& e) final {
Expr VisitExpr_(const Variable* op) final {
CHECK(!dbuffer_info_.count(op));
return e;
return GetRef<Expr>(op);
}
private:
Stmt MakeProducer(const AttrStmt* op, const Stmt& s) {
Stmt MakeProducer(const AttrStmt* op) {
const VarExpr buffer = Downcast<VarExpr>(op->node);
CHECK_NE(loop_nest_.size(), 0U)
<< "Double buffer scope must be inside a loop";
auto it = dbuffer_info_.find(buffer.get());
if (it == dbuffer_info_.end()) {
LOG(WARNING) << "Skip double buffer scope " << op->node;
return Mutate(op->body);
return this->VisitStmt(op->body);
}
StorageEntry& e = it->second;
e.loop = loop_nest_.back();
......@@ -230,7 +229,7 @@ class DoubleBufferInjector : public IRMutator {
e.loop->loop_var.dtype());
e.switch_read_var = indexmod(e.loop->loop_var, two);
in_double_buffer_scope_ = true;
Stmt body = Mutate(op->body);
Stmt body = this->VisitStmt(op->body);
in_double_buffer_scope_ = false;
std::unordered_map<const Variable*, Expr> vmap;
vmap[e.switch_write_var.get()] = zero;
......
......@@ -22,8 +22,7 @@
*/
// Inject prefetch op in HalideIR
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_pass.h>
#include <tvm/arithmetic.h>
#include <unordered_set>
......@@ -34,10 +33,10 @@ namespace ir {
using arith::IntSet;
using arith::DomainTouched;
class PrefetchInjector : public IRMutator {
class PrefetchInjector : public StmtMutator {
public:
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
Stmt ret = IRMutator::Mutate_(op, s);
Stmt VisitStmt_(const AttrStmt* op) final {
Stmt ret = StmtMutator::VisitStmt_(op);
op = ret.as<AttrStmt>();
if (op && op->attr_key == attr::prefetch_scope) {
Tensor ts = Downcast<Tensor>(op->node);
......@@ -65,13 +64,13 @@ class PrefetchInjector : public IRMutator {
return ret;
}
Stmt Mutate_(const For* op, const Stmt& s) final {
Stmt VisitStmt_(const For* op) final {
auto &var = op->loop_var;
loop_nest_.push_back(var);
if (op->for_type == ForType::Vectorized) {
vectorized_[var.get()] = IntSet::interval(op->min, (op->min + op->extent) - 1);
}
Stmt ret = IRMutator::Mutate_(op, s);
Stmt ret = StmtMutator::VisitStmt_(op);
if (op->for_type == ForType::Vectorized) {
vectorized_.erase(var.get());
}
......@@ -88,7 +87,7 @@ class PrefetchInjector : public IRMutator {
const Range PrefetchInjector::none;
Stmt InjectPrefetch(Stmt stmt) {
return PrefetchInjector().Mutate(stmt);
return PrefetchInjector()(std::move(stmt));
}
} // namespace ir
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -21,8 +21,8 @@
* \file inline.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_functor_ext.h>
namespace tvm {
namespace ir {
......@@ -30,13 +30,13 @@ namespace ir {
// inliner to inline a function
// the result may not be SSA,
// ConvertSSA need to be applied after this pass
class IRInline final : public IRMutator {
class IRInline final : public StmtExprMutator {
public:
IRInline(FunctionRef f, Array<Var> args, Expr body)
: f_(f), args_(args), body_(body) {}
Expr Mutate_(const Call* op, const Expr& e) final {
Expr expr = IRMutator::Mutate_(op, e);
Expr VisitExpr_(const Call* op) final {
Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Call>();
if (op->func == f_) {
......@@ -78,7 +78,7 @@ Stmt Inline(Stmt stmt,
Expr body) {
CHECK_EQ(f->num_outputs(), 1)
<< "can only inline output single value operation";
Stmt ret = IRInline(f, args, body).Mutate(stmt);
Stmt ret = IRInline(f, args, body)(std::move(stmt));
if (ret.same_as(stmt)) return ret;
return ConvertSSA(ret);
}
......
......@@ -21,7 +21,6 @@
*/
#include <tvm/ir_functor_ext.h>
#include <tvm/packed_func_ext.h>
#include <tvm/ir_visitor.h>
namespace tvm {
namespace ir {
......
......@@ -24,7 +24,7 @@
* \file lift_attr_scope.cc
*/
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_functor_ext.h>
#include "ir_util.h"
namespace tvm {
......@@ -32,13 +32,13 @@ namespace ir {
// NOTE: this optimization can only be applied
// to a few specified attr keys
class AttrScopeLifter : public IRMutator {
class AttrScopeLifter : public StmtMutator {
public:
explicit AttrScopeLifter(std::string attr_key)
: attr_key_(attr_key) {}
Stmt Lift(Stmt stmt) {
stmt = Mutate(stmt);
stmt = operator()(std::move(stmt));
if (attr_node_.defined()) {
stmt = AttrStmt::make(
attr_node_, attr_key_, attr_value_, stmt);
......@@ -47,8 +47,8 @@ class AttrScopeLifter : public IRMutator {
}
// do not go beyond
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt VisitStmt_(const Allocate* op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<Allocate>();
if (attr_node_.defined()) {
Stmt body = AttrStmt::make(
......@@ -65,17 +65,17 @@ class AttrScopeLifter : public IRMutator {
}
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
Stmt VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr_key_) {
attr_node_ = op->node;
attr_value_ = op->value;
return op->body;
} else {
return IRMutator::Mutate_(op, s);
return StmtMutator::VisitStmt_(op);
}
}
Stmt Mutate_(const Block* op, const Stmt& s) final {
Stmt VisitStmt_(const Block* op) final {
std::vector<Stmt> seq;
FlattenSeq(op->first, &seq);
FlattenSeq(op->rest, &seq);
......@@ -83,21 +83,21 @@ class AttrScopeLifter : public IRMutator {
if (seq.size() == 2 &&
seq[0].same_as(op->first) &&
seq[1].same_as(op->rest)) {
return s;
return GetRef<Stmt>(op);
}
return MergeSeq(seq);
}
Stmt Mutate_(const IfThenElse* op, const Stmt& s) final {
Stmt VisitStmt_(const IfThenElse* op) final {
if (!op->else_case.defined()) {
return IRMutator::Mutate_(op, s);
return StmtMutator::VisitStmt_(op);
}
Stmt then_case = this->Mutate(op->then_case);
Stmt then_case = this->VisitStmt(op->then_case);
ObjectRef first_node;
Expr first_value;
std::swap(first_node, attr_node_);
std::swap(first_value, attr_value_);
Stmt else_case = this->Mutate(op->else_case);
Stmt else_case = this->VisitStmt(op->else_case);
if (attr_node_.defined() &&
attr_value_.defined() &&
first_node.defined() &&
......@@ -106,7 +106,7 @@ class AttrScopeLifter : public IRMutator {
ValueSame(attr_value_, first_value)) {
if (then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
return s;
return GetRef<Stmt>(op);
} else {
return IfThenElse::make(op->condition, then_case, else_case);
}
......@@ -124,7 +124,7 @@ class AttrScopeLifter : public IRMutator {
}
if (then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
return s;
return GetRef<Stmt>(op);
} else {
return IfThenElse::make(op->condition, then_case, else_case);
}
......@@ -155,7 +155,7 @@ class AttrScopeLifter : public IRMutator {
for (const Stmt & stmt : seq) {
attr_node_ = ObjectRef();
attr_value_ = Expr();
Stmt rest = this->Mutate(stmt);
Stmt rest = this->VisitStmt(stmt);
if (attr_node_.defined() &&
attr_value_.defined() &&
curr_node.defined() &&
......@@ -214,7 +214,7 @@ class AttrScopeLifter : public IRMutator {
};
Stmt LiftAttrScope(Stmt stmt, std::string attr_key) {
return AttrScopeLifter(attr_key).Lift(stmt);
return AttrScopeLifter(attr_key).Lift(std::move(stmt));
}
} // namespace ir
......
......@@ -21,7 +21,7 @@
* \brief Pass for lowering custom datatypes
*/
#include <tvm/ir_mutator.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_pass.h>
#include <tvm/packed_func_ext.h>
#include "../codegen/datatype/registry.h"
......@@ -37,17 +37,17 @@ namespace ir {
* datatype) for lowering this type of expression, and uses it to lower the
* expression.
*/
class CustomDatatypesLowerer : public IRMutator {
class CustomDatatypesLowerer : public StmtExprMutator {
public:
explicit CustomDatatypesLowerer(const std::string& target) : target_(target) {}
inline Expr Mutate_(const Cast* op, const Expr& e) final {
inline Expr VisitExpr_(const Cast* op) final {
auto type_code = op->dtype.code();
auto src_type_code = op->value.dtype().code();
// If either datatype is a registered custom datatype, we must lower.
bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code) ||
datatype::Registry::Global()->GetTypeRegistered(src_type_code);
Expr expr = IRMutator::Mutate_(op, e);
Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Cast>();
if (toBeLowered) {
auto lower = datatype::GetCastLowerFunc(target_, type_code, src_type_code);
......@@ -59,8 +59,9 @@ class CustomDatatypesLowerer : public IRMutator {
return expr;
}
inline Expr Mutate_(const FloatImm* imm, const Expr& e) final {
inline Expr VisitExpr_(const FloatImm* imm) final {
auto type_code = imm->dtype.code();
auto e = GetRef<Expr>(imm);
if (datatype::Registry::Global()->GetTypeRegistered(type_code)) {
auto lower = datatype::GetFloatImmLowerFunc(target_, type_code);
CHECK(lower) << "FloatImm lowering function for target " << target_ << " type "
......@@ -70,9 +71,9 @@ class CustomDatatypesLowerer : public IRMutator {
return e;
}
inline Stmt Mutate_(const Allocate* allocate, const Stmt& s) final {
inline Stmt VisitStmt_(const Allocate* allocate) final {
bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(allocate->dtype.code());
Stmt stmt = IRMutator::Mutate_(allocate, s);
Stmt stmt = StmtExprMutator::VisitStmt_(allocate);
allocate = stmt.as<Allocate>();
if (toBeLowered) {
......@@ -84,9 +85,9 @@ class CustomDatatypesLowerer : public IRMutator {
return stmt;
}
inline Expr Mutate_(const Load* load, const Expr& e) final {
inline Expr VisitExpr_(const Load* load) final {
bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(load->dtype.code());
Expr expr = IRMutator::Mutate_(load, e);
Expr expr = StmtExprMutator::VisitExpr_(load);
load = expr.as<Load>();
if (toBeLowered) {
auto new_load_type = DataType::UInt(load->dtype.bits());
......@@ -96,10 +97,10 @@ class CustomDatatypesLowerer : public IRMutator {
}
#define DEFINE_MUTATE__(OP) \
inline Expr Mutate_(const OP* op, const Expr& e) final { \
inline Expr VisitExpr_(const OP* op) final { \
auto type_code = op->dtype.code(); \
bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code); \
Expr expr = IRMutator::Mutate_(op, e); \
Expr expr = StmtExprMutator::VisitExpr_(op); \
op = expr.as<OP>(); \
if (toBeLowered) { \
auto lower = datatype::Get##OP##LowerFunc(target_, type_code); \
......@@ -131,7 +132,7 @@ class CustomDatatypesLowerer : public IRMutator {
LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target) {
auto n = make_object<LoweredFuncNode>(*f.operator->());
n->body = CustomDatatypesLowerer(target).Mutate(n->body);
n->body = CustomDatatypesLowerer(target)(n->body);
return LoweredFunc(n);
}
......
......@@ -22,7 +22,6 @@
* \file lower_intrin.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <tvm/api_registry.h>
#include <tvm/expr_operator.h>
......@@ -34,9 +33,10 @@
namespace tvm {
namespace ir {
class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
public:
using IRMutatorWithAnalyzer::Mutate_;
using IRMutatorWithAnalyzer::VisitStmt_;
using IRMutatorWithAnalyzer::VisitExpr_;
IntrinInjecter(arith::Analyzer* analyzer, std::string target)
: IRMutatorWithAnalyzer(analyzer) {
......@@ -51,28 +51,29 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
}
}
Expr Mutate_(const Call* op, const Expr& e) final {
Expr VisitExpr_(const Call* op) final {
if (op->call_type == Call::Intrinsic ||
op->call_type == Call::PureIntrinsic) {
Expr r = ApplyPattern(op->name, e);
Expr r = ApplyPattern(op->name, GetRef<Expr>(op));
if (r.defined()) return r;
}
return IRMutator::Mutate_(op, e);
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
Expr Mutate_(const Add* op, const Expr& e) final {
Expr VisitExpr_(const Add* op) final {
if (const Mul* mb = op->b.as<Mul>()) {
return MakeFMA(mb->a, mb->b, op->a, op, e);
return MakeFMA(mb->a, mb->b, op->a, op);
} else if (const Mul* ma = op->a.as<Mul>()) {
return MakeFMA(ma->a, ma->b, op->b, op, e);
return MakeFMA(ma->a, ma->b, op->b, op);
}
return IRMutator::Mutate_(op, e);
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
// We use floordiv for integer analysis,
// but will need to lower them to native truncdiv instructions
Expr Mutate_(const FloorDiv* op, const Expr& e) final {
Expr ret = IRMutatorWithAnalyzer::Mutate_(op, e);
Expr VisitExpr_(const FloorDiv* op) final {
auto e = GetRef<Expr>(op);
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<FloorDiv>();
if (op == nullptr) return ret;
int shift;
......@@ -117,8 +118,8 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
}
}
Expr Mutate_(const FloorMod* op, const Expr& e) final {
Expr ret = IRMutatorWithAnalyzer::Mutate_(op, e);
Expr VisitExpr_(const FloorMod* op) final {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<FloorMod>();
if (op == nullptr) return ret;
// Lower floordiv to native truncdiv.
......@@ -167,34 +168,37 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
}
}
Expr Mutate_(const Max* op, const Expr& e) final {
Expr VisitExpr_(const Max* op) final {
using namespace arith;
PVar<Expr> x, y;
PVar<Integer> c;
auto e = GetRef<Expr>(op);
if (max(floordiv(x, y), c).Match(e) &&
c.Eval()->value >= 0 &&
analyzer_->CanProveGreaterEqual(y.Eval(), 0)) {
return max(Mutate(truncdiv(x, y).Eval()), c.Eval());
return max(VisitExpr(truncdiv(x, y).Eval()), c.Eval());
}
return IRMutatorWithAnalyzer::Mutate_(op, e);
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
Expr Mutate_(const EQ* op, const Expr& e) final {
Expr VisitExpr_(const EQ* op) final {
using namespace arith;
PVar<Expr> x, y;
auto e = GetRef<Expr>(op);
if ((floormod(x, y) == 0).Match(e)) {
return Mutate((truncmod(x, y) == 0).Eval());
return VisitExpr((truncmod(x, y) == 0).Eval());
}
return IRMutatorWithAnalyzer::Mutate_(op, e);
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
Expr Mutate_(const NE* op, const Expr& e) final {
Expr VisitExpr_(const NE* op) final {
using namespace arith;
PVar<Expr> x, y;
auto e = GetRef<Expr>(op);
if ((floormod(x, y) != 0).Match(e)) {
return Mutate((truncmod(x, y) != 0).Eval());
return VisitExpr((truncmod(x, y) != 0).Eval());
}
return IRMutatorWithAnalyzer::Mutate_(op, e);
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
private:
......@@ -231,7 +235,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
}
Expr MakeFMA(const Expr& a, const Expr& b, const Expr& c,
const Add* op, const Expr& e) {
const Add* op) {
// emit fma instruction: a * b + c
Expr lhs = SwapBroadcastCast(a);
Expr rhs = SwapBroadcastCast(b);
......@@ -239,14 +243,14 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
if (fma_ != nullptr && op->dtype.is_float()) {
Expr r = (*fma_)(Call::make(
op->dtype, "fma", {lhs, rhs, c}, Call::PureIntrinsic));
if (r.defined()) return this->Mutate(r);
if (r.defined()) return this->VisitExpr(r);
} else {
if (!lhs.same_as(a) || !rhs.same_as(b)) {
Expr mul = this->Mutate(Mul::make(lhs, rhs));
return Add::make(mul, this->Mutate(c));
Expr mul = this->VisitExpr(Mul::make(lhs, rhs));
return Add::make(mul, this->VisitExpr(c));
}
}
return IRMutator::Mutate_(op, e);
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
Expr ApplyPattern(const std::string& name, const Expr& e) {
......@@ -262,7 +266,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
Expr r = (*f)(e);
CHECK(r.defined()) << "intrinsic rule must always return valid Expr";
if (!r.same_as(e)) {
return this->Mutate(r);
return this->VisitExpr(r);
}
}
}
......@@ -277,7 +281,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
Stmt LowerIntrinStmt(Stmt stmt, const std::string& target) {
arith::Analyzer analyzer;
return IntrinInjecter(&analyzer, target).Mutate(stmt);
return IntrinInjecter(&analyzer, target)(std::move(stmt));
}
LoweredFunc
......
......@@ -22,7 +22,7 @@
* \file lower_thread_allreduce.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_pass.h>
#include <unordered_set>
#include "ir_util.h"
......@@ -32,19 +32,19 @@
namespace tvm {
namespace ir {
class ThreadAllreduceBuilder final : public IRMutator {
class ThreadAllreduceBuilder final : public StmtExprMutator {
public:
explicit ThreadAllreduceBuilder(int warp_size)
: warp_size_(warp_size) {}
Stmt Mutate_(const AttrStmt *op, const Stmt& s) final {
Stmt VisitStmt_(const AttrStmt *op) final {
if (op->attr_key == attr::thread_extent) {
thread_extents_.push_back(op);
Stmt ret = IRMutator::Mutate_(op, s);
Stmt ret = StmtExprMutator::VisitStmt_(op);
thread_extents_.pop_back();
return ret;
} else if (op->attr_key == attr::storage_scope) {
Stmt ret = IRMutator::Mutate_(op, s);
Stmt ret = StmtExprMutator::VisitStmt_(op);
op = ret.as<AttrStmt>();
const Variable* v = op->node.as<Variable>();
if (alloc_remap_.count(v)) {
......@@ -56,15 +56,15 @@ class ThreadAllreduceBuilder final : public IRMutator {
const CommReducerNode *combiner = op->node.as<CommReducerNode>();
CHECK(combiner);
reduce_combiner_.push_back(combiner);
Stmt ret = IRMutator::Mutate_(op, s);
Stmt ret = StmtExprMutator::VisitStmt_(op);
reduce_combiner_.pop_back();
return ret;
} else {
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
}
Stmt Mutate_(const Evaluate* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt VisitStmt_(const Evaluate* op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<Evaluate>();
const Call* call = op->value.as<Call>();
if (call && call->is_intrinsic(intrinsic::tvm_thread_allreduce)) {
......@@ -73,8 +73,8 @@ class ThreadAllreduceBuilder final : public IRMutator {
return stmt;
}
}
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt VisitStmt_(const Allocate* op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<Allocate>();
auto it = alloc_remap_.find(op->buffer_var.get());
if (it != alloc_remap_.end()) {
......@@ -93,13 +93,13 @@ class ThreadAllreduceBuilder final : public IRMutator {
return stmt;
}
}
Expr Mutate_(const Load* op, const Expr& e) final {
Expr VisitExpr_(const Load* op) final {
auto it = load_remap_.find(op->buffer_var.get());
if (it != load_remap_.end()) {
CHECK(is_zero(op->index)) << e;
CHECK(is_zero(op->index));
return it->second;
} else {
return IRMutator::Mutate_(op, e);
return StmtExprMutator::VisitExpr_(op);
}
}
......@@ -339,7 +339,7 @@ LoweredFunc
LowerThreadAllreduce(LoweredFunc f, int warp_size) {
CHECK_NE(f->func_type, kHostFunc);
auto n = make_object<LoweredFuncNode>(*f.operator->());
n->body = ThreadAllreduceBuilder(warp_size).Mutate(n->body);
n->body = ThreadAllreduceBuilder(warp_size)(n->body);
return LoweredFunc(n);
}
} // namespace ir
......
......@@ -22,7 +22,7 @@
* \file lower_tvm_buildin.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_pass.h>
#include <unordered_set>
#include "ir_util.h"
......@@ -43,14 +43,14 @@ inline Expr StackAlloca(std::string type, size_t num) {
// Calculate the statistics of packed function.
// These information are needed during codegen.
class BuiltinLower : public IRMutator {
class BuiltinLower : public StmtExprMutator {
public:
Stmt Build(Stmt stmt) {
stack_shape_ = Var("stack_shape", DataType::Handle());
stack_array_ = Var("stack_array", DataType::Handle());
stack_value_ = Var("stack_value", DataType::Handle());
stack_tcode_ = Var("stack_tcode", DataType::Handle());
stmt = this->Mutate(stmt);
stmt = this->VisitStmt(stmt);
if (max_shape_stack_ != 0) {
stmt = LetStmt::make(
stack_shape_, StackAlloca("shape", max_shape_stack_), stmt);
......@@ -68,8 +68,8 @@ class BuiltinLower : public IRMutator {
return stmt;
}
Stmt Mutate(Stmt stmt) final {
stmt = IRMutator::Mutate(stmt);
Stmt VisitStmt(const Stmt& s) final {
auto stmt = StmtExprMutator::VisitStmt(s);
CHECK_EQ(run_shape_stack_, 0);
CHECK_EQ(run_array_stack_, 0);
while (prep_seq_.size() != 0) {
......@@ -79,9 +79,9 @@ class BuiltinLower : public IRMutator {
return stmt;
}
Stmt Mutate_(const Allocate* op, const Stmt& s) {
Stmt VisitStmt_(const Allocate* op) {
// Lower allocate to device allocate when needed.
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<Allocate>();
if (op->new_expr.defined()) return stmt;
// Get constant allocation bound.
......@@ -141,39 +141,39 @@ class BuiltinLower : public IRMutator {
return body;
}
Stmt Mutate_(const AttrStmt* op, const Stmt &s) final {
Stmt VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::device_context_id) {
CHECK(!device_id_.defined());
device_id_ = op->value;
return Mutate(op->body);
return this->VisitStmt(op->body);
} else if (op->attr_key == attr::device_context_type) {
CHECK(!device_type_.defined());
device_type_ = op->value;
return Mutate(op->body);
return this->VisitStmt(op->body);
} else {
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
}
Expr Mutate_(const Call* op, const Expr &e) final {
Expr VisitExpr_(const Call* op) final {
if (op->is_intrinsic(intrinsic::tvm_call_packed)) {
return MakeCallPacked(op, e);
return MakeCallPacked(op);
} else if (op->is_intrinsic(intrinsic::tvm_call_trace_packed)) {
return MakeCallTracePacked(op, e);
return MakeCallTracePacked(op);
} else if (op->is_intrinsic(intrinsic::tvm_stack_make_shape)) {
return MakeShape(op, e);
return MakeShape(op);
} else if (op->is_intrinsic(intrinsic::tvm_stack_make_array)) {
return MakeArray(op, e);
return MakeArray(op);
} else if (op->is_intrinsic(intrinsic::tvm_context_id)) {
return make_zero(op->dtype);
} else {
return IRMutator::Mutate_(op, e);
return StmtExprMutator::VisitExpr_(op);
}
}
// call shape
Expr MakeShape(const Call* op, const Expr& e) {
Expr MakeShape(const Call* op) {
size_t stack_begin = run_shape_stack_;
run_shape_stack_ += op->args.size();
Expr expr = IRMutator::Mutate_(op, e);
Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Call>();
for (size_t i = 0; i < op->args.size(); ++i) {
prep_seq_.emplace_back(
......@@ -183,10 +183,10 @@ class BuiltinLower : public IRMutator {
return AddressOffset(stack_shape_, DataType::Int(64), stack_begin);
}
// make array
Expr MakeArray(const Call* op, const Expr& e) {
Expr MakeArray(const Call* op) {
size_t idx = run_array_stack_;
run_array_stack_ += 1;
Expr expr = IRMutator::Mutate_(op, e);
Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Call>();
prep_seq_.emplace_back(
TVMStructSet(stack_array_, idx, intrinsic::kArrData, op->args[0]));
......@@ -230,13 +230,13 @@ class BuiltinLower : public IRMutator {
return TVMStructGet(DataType::Handle(), stack_array_, idx, intrinsic::kArrAddr);
}
// call packed.
Expr MakeCallPacked(const Call* op, const Expr& e) {
Expr MakeCallPacked(const Call* op) {
size_t restore_shape_stack = run_shape_stack_;
size_t restore_array_stack = run_array_stack_;
size_t arg_stack_begin = run_arg_stack_;
run_arg_stack_ += op->args.size();
// Specially handle the buffer packed intrinsic
Expr expr = IRMutator::Mutate_(op, e);
Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Call>();
for (size_t i = 1; i < op->args.size(); ++i) {
Expr stack_index = ConstInt32(arg_stack_begin + i - 1);
......@@ -278,14 +278,14 @@ class BuiltinLower : public IRMutator {
packed_args, Call::Intrinsic);
}
Expr MakeCallTracePacked(const Call *op, const Expr &e) {
Expr MakeCallTracePacked(const Call *op) {
size_t restore_shape_stack = run_shape_stack_;
size_t restore_array_stack = run_array_stack_;
size_t arg_stack_begin = run_arg_stack_;
run_arg_stack_ += op->args.size();
size_t args_size = op->args.size();
CHECK_GT(args_size, 0);
Expr expr = IRMutator::Mutate_(op, e);
Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Call>();
for (size_t i = 1; i < op->args.size(); ++i) {
Expr stack_index = ConstInt32(arg_stack_begin + i - 1);
......
......@@ -26,8 +26,7 @@
// Thanks to Andrew Adams and Vinod Grover for
// explaining the concept of warp shuffle.
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_pass.h>
#include <unordered_set>
#include "ir_util.h"
......@@ -75,7 +74,7 @@ namespace ir {
// Visitor to find m in pattern
// store warp_mem[m * warp_index + (warp_size * m) * y + x]
class WarpStoreCoeffFinder : private IRVisitor {
class WarpStoreCoeffFinder : private StmtVisitor {
public:
WarpStoreCoeffFinder(const Variable* buffer,
Var warp_index,
......@@ -86,13 +85,13 @@ class WarpStoreCoeffFinder : private IRVisitor {
}
// find the warp co-efficient in the statement given the warp size
int Find(const Stmt& stmt) {
this->Visit(stmt);
this->VisitStmt(stmt);
return warp_coeff_;
}
private:
/// Visitor implementation
void Visit_(const Store *op) final {
void VisitStmt_(const Store *op) final {
if (op->buffer_var.get() == buffer_) {
if (op->value.dtype().lanes() == 1) {
UpdatePattern(op->index);
......@@ -104,7 +103,7 @@ class WarpStoreCoeffFinder : private IRVisitor {
UpdatePattern(base);
}
} else {
IRVisitor::Visit_(op);
StmtVisitor::VisitStmt_(op);
}
}
......@@ -141,14 +140,14 @@ class WarpStoreCoeffFinder : private IRVisitor {
// Visitor to find the warp index
class WarpIndexFinder : private IRVisitor {
class WarpIndexFinder : private StmtVisitor {
public:
explicit WarpIndexFinder(int warp_size)
: warp_size_(warp_size) {
}
// find the warp co-efficient in the statement given the warp size
IterVar Find(const Stmt& stmt) {
this->Visit(stmt);
this->VisitStmt(stmt);
CHECK(warp_index_.defined())
<< "Cannot find warp index(threadIdx.x) within the scope of warp memory";
return warp_index_;
......@@ -156,7 +155,7 @@ class WarpIndexFinder : private IRVisitor {
private:
/// Visitor implementation
void Visit_(const AttrStmt *op) final {
void VisitStmt_(const AttrStmt *op) final {
if (op->attr_key == attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
if (iv->thread_tag == "threadIdx.x") {
......@@ -177,7 +176,7 @@ class WarpIndexFinder : private IRVisitor {
}
}
}
IRVisitor::Visit_(op);
StmtVisitor::VisitStmt_(op);
}
// warp size
int warp_size_{0};
......@@ -185,13 +184,13 @@ class WarpIndexFinder : private IRVisitor {
IterVar warp_index_{nullptr};
};
// Mutator to change the read pattern
class WarpAccessRewriter : protected IRMutator {
class WarpAccessRewriter : protected StmtExprMutator {
public:
explicit WarpAccessRewriter(int warp_size, arith::Analyzer* analyzer)
: warp_size_(warp_size), analyzer_(analyzer) {}
// Rewrite the allocate statement which transforms
// warp memory to local memory.
Stmt Rewrite(const Allocate* op, const Stmt& stmt) {
Stmt Rewrite(const Allocate* op) {
buffer_ = op->buffer_var.get();
int alloc_size = op->constant_allocation_size();
CHECK_GT(alloc_size, 0)
......@@ -208,27 +207,27 @@ class WarpAccessRewriter : protected IRMutator {
op->dtype,
{make_const(DataType::Int(32), alloc_size / warp_size_)},
op->condition,
this->Mutate(op->body));
this->VisitStmt(op->body));
}
protected:
Expr Mutate_(const Variable* op, const Expr& expr) {
Expr Mutate_(const Variable* op) {
CHECK(op != buffer_)
<< "Cannot access address of warp memory directly";
return IRMutator::Mutate_(op, expr);
return StmtExprMutator::VisitExpr_(op);
}
Stmt Mutate_(const Store* op, const Stmt& stmt) {
Stmt VisitStmt_(const Store* op) {
if (op->buffer_var.get() == buffer_) {
Expr local_index, group;
std::tie(local_index, group) = SplitIndexByGroup(op->index);
return Store::make(op->buffer_var, op->value, local_index, op->predicate);
} else {
return IRMutator::Mutate_(op, stmt);
return StmtExprMutator::VisitStmt_(op);
}
}
Expr Mutate_(const Load* op, const Expr& expr) {
Expr Mutate_(const Load* op) {
if (op->buffer_var.get() == buffer_) {
Expr local_index, group;
std::tie(local_index, group) = SplitIndexByGroup(op->index);
......@@ -243,7 +242,7 @@ class WarpAccessRewriter : protected IRMutator {
{load_value, group},
Call::Intrinsic);
} else {
return IRMutator::Mutate_(op, expr);
return StmtExprMutator::VisitExpr_(op);
}
}
// Split the index to the two component
......@@ -297,18 +296,18 @@ class WarpAccessRewriter : protected IRMutator {
// Bind bound information of variables to make analyzer more effective
// TODO(tqchen): consider a pass to inline the bound info into the expr
// so analysis can be context independent.
class BindVarBoundInfo : public IRVisitor {
class BindVarBoundInfo : public StmtVisitor {
public:
explicit BindVarBoundInfo(arith::Analyzer* analyzer)
: analyzer_(analyzer) {}
void Visit_(const For* op) final {
void VisitStmt_(const For* op) final {
const Var& loop_var = op->loop_var;
analyzer_->Bind(loop_var, Range::make_by_min_extent(op->min, op->extent));
IRVisitor::Visit_(op);
StmtVisitor::VisitStmt_(op);
}
void Visit_(const AttrStmt* op) {
void VisitStmt_(const AttrStmt* op) {
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::virtual_thread) {
IterVar iv = Downcast<IterVar>(op->node);
......@@ -319,7 +318,7 @@ class BindVarBoundInfo : public IRVisitor {
analyzer_->Bind(iv->var, dom);
}
}
IRVisitor::Visit_(op);
StmtVisitor::VisitStmt_(op);
}
protected:
......@@ -330,7 +329,7 @@ class BindVarBoundInfo : public IRVisitor {
};
// Mutator to change the read pattern
class WarpMemoryRewriter : private IRMutator {
class WarpMemoryRewriter : private StmtMutator {
public:
explicit WarpMemoryRewriter(int warp_size)
: warp_size_(warp_size) {
......@@ -338,36 +337,37 @@ class WarpMemoryRewriter : private IRMutator {
Stmt Rewrite(Stmt stmt) {
if (warp_size_ == 1) return stmt;
BindVarBoundInfo(&analyzer_).Visit(stmt);
stmt = this->Mutate(stmt);
BindVarBoundInfo binder(&analyzer_);
binder(stmt);
stmt = operator()(std::move(stmt));
stmt = CanonicalSimplify(stmt);
return stmt;
}
private:
Stmt Mutate_(const Allocate* op, const Stmt& stmt) {
Stmt VisitStmt_(const Allocate* op) {
if (warp_buffer_.count(op->buffer_var.get())) {
WarpAccessRewriter rewriter(warp_size_, &analyzer_);
return rewriter.Rewrite(op, stmt);
return rewriter.Rewrite(op);
} else {
return IRMutator::Mutate_(op, stmt);
return StmtMutator::VisitStmt_(op);
}
}
Stmt Mutate_(const AttrStmt* op, const Stmt& stmt) {
Stmt VisitStmt_(const AttrStmt* op) {
using runtime::StorageScope;
if (op->attr_key == attr::storage_scope) {
const Variable* buf = op->node.as<Variable>();
StorageScope scope = StorageScope::make(op->value.as<StringImm>()->value);
if (scope.rank == runtime::StorageRank::kWarp) {
warp_buffer_.insert(buf);
Stmt ret = IRMutator::Mutate_(op, stmt);
Stmt ret = StmtMutator::VisitStmt_(op);
op = ret.as<AttrStmt>();
return AttrStmt::make(
op->node, op->attr_key, StringImm::make("local"), op->body);
}
}
return IRMutator::Mutate_(op, stmt);
return StmtMutator::VisitStmt_(op);
}
int warp_size_{0};
......
......@@ -22,8 +22,7 @@
*/
#include <tvm/ir_pass.h>
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/buffer.h>
#include <tvm/runtime/device_api.h>
#include <vector>
......@@ -207,29 +206,29 @@ LoweredFunc MakeAPI(Stmt body,
return f;
}
class DeviceTypeBinder: public IRMutator {
class DeviceTypeBinder: public StmtExprMutator {
public:
explicit DeviceTypeBinder(int device_type)
: device_type_(device_type) {}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
Stmt VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::device_context_type) {
if (const Variable* var = op->value.as<Variable>()) {
var_ = var;
Expr value = make_const(op->value.dtype(), device_type_);
Stmt body = IRMutator::Mutate_(op, s);
Stmt body = StmtExprMutator::VisitStmt_(op);
var_ = nullptr;
std::ostringstream os;
os << "device_type need to be " << device_type_;
return AssertStmt::make(op->value == value, os.str(), body);
}
}
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
Stmt Mutate_(const IfThenElse* op, const Stmt& s) final {
Stmt VisitStmt_(const IfThenElse* op) final {
// eager simplify if guard.
Stmt res = IRMutator::Mutate_(op, s);
Stmt res = StmtExprMutator::VisitStmt_(op);
op = res.as<IfThenElse>();
if (is_zero(op->condition)) {
if (op->else_case.defined()) return op->else_case;
......@@ -241,9 +240,9 @@ class DeviceTypeBinder: public IRMutator {
return res;
}
Expr Mutate_(const NE* op, const Expr& e) final {
Expr VisitExpr_(const NE* op) final {
// eager check NE for device check
Expr res = IRMutator::Mutate_(op, e);
Expr res = StmtExprMutator::VisitExpr_(op);
op = res.as<NE>();
if (ir::Equal(op->a, op->b)) {
return make_const(op->dtype, false);
......@@ -251,11 +250,11 @@ class DeviceTypeBinder: public IRMutator {
return res;
}
Expr Mutate_(const Variable* op, const Expr& e) final {
Expr VisitExpr_(const Variable* op) final {
if (op == var_) {
return make_const(op->dtype, device_type_);
} else {
return e;
return GetRef<Expr>(op);
}
}
......@@ -267,7 +266,7 @@ class DeviceTypeBinder: public IRMutator {
LoweredFunc BindDeviceType(LoweredFunc f,
int device_type) {
auto n = make_object<LoweredFuncNode>(*f.operator->());
n->body = DeviceTypeBinder(device_type).Mutate(n->body);
n->body = DeviceTypeBinder(device_type)(n->body);
return LoweredFunc(n);
}
......
......@@ -21,8 +21,7 @@
* \file remap_thread_axis.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_pass.h>
#include <unordered_map>
......@@ -31,7 +30,7 @@ namespace tvm {
namespace ir {
// Mutator to change the read pattern
class ThreadAxisRewriter : private IRMutator {
class ThreadAxisRewriter : private StmtExprMutator {
public:
explicit ThreadAxisRewriter(
const std::unordered_map<std::string, IterVar>& tmap)
......@@ -39,11 +38,11 @@ class ThreadAxisRewriter : private IRMutator {
}
Stmt Rewrite(Stmt stmt) {
return Mutate(stmt);
return operator()(std::move(stmt));
}
private:
Stmt Mutate_(const AttrStmt* op, const Stmt& stmt) final {
Stmt VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
CHECK_NE(iv->thread_tag.length(), 0U);
......@@ -56,18 +55,18 @@ class ThreadAxisRewriter : private IRMutator {
} else {
CHECK(vmap_[v].same_as(new_iv->var));
}
Stmt body = this->Mutate(op->body);
Stmt body = this->VisitStmt(op->body);
return AttrStmt::make(
new_iv, op->attr_key, op->value, body);
}
}
return IRMutator::Mutate_(op, stmt);
return StmtExprMutator::VisitStmt_(op);
}
Expr Mutate_(const Variable* op, const Expr& expr) final {
Expr VisitExpr_(const Variable* op) final {
auto it = vmap_.find(op);
if (it != vmap_.end()) return it->second;
return IRMutator::Mutate_(op, expr);
return StmtExprMutator::VisitExpr_(op);
}
// The thread map
const std::unordered_map<std::string, IterVar>& tmap_;
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -23,30 +23,30 @@
*/
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_functor_ext.h>
#include <unordered_map>
namespace tvm {
namespace ir {
// Mark the statment of each stage.
class NoOpRemover : public IRMutator {
class NoOpRemover : public StmtMutator {
public:
Stmt Mutate_(const LetStmt* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt VisitStmt_(const LetStmt* op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<LetStmt>();
return is_no_op(op->body) ? MakeEvaluate(op->value) : stmt;
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
Stmt VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == "pragma_debug_skip_region") {
return MakeEvaluate(0);
}
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<AttrStmt>();
return is_no_op(op->body) ? MakeEvaluate(op->value) : stmt;
}
Stmt Mutate_(const IfThenElse* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt VisitStmt_(const IfThenElse* op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<IfThenElse>();
if (op->else_case.defined()) {
if (is_no_op(op->else_case)) {
......@@ -66,35 +66,35 @@ class NoOpRemover : public IRMutator {
}
}
}
Stmt Mutate_(const For* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt VisitStmt_(const For* op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<For>();
if (is_zero(op->extent)) {
return Evaluate::make(0);
}
return is_no_op(op->body) ? MakeEvaluate({op->min, op->extent}) : stmt;
}
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt VisitStmt_(const Allocate* op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<Allocate>();
return is_no_op(op->body) ? MakeEvaluate(op->extents) : stmt;
}
Stmt Mutate_(const ProducerConsumer* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt VisitStmt_(const ProducerConsumer* op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<ProducerConsumer>();
return is_no_op(op->body) ? op->body : stmt;
}
Stmt Mutate_(const Realize* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt VisitStmt_(const Realize* op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<Realize>();
return is_no_op(op->body) ? op->body : stmt;
}
Stmt Mutate_(const Evaluate* op, const Stmt& s) final {
if (HasSideEffect(op->value)) return s;
Stmt VisitStmt_(const Evaluate* op) final {
if (HasSideEffect(op->value)) return GetRef<Stmt>(op);
return Evaluate::make(0);
}
Stmt Mutate_(const Block* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt VisitStmt_(const Block* op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<Block>();
if (is_no_op(op->first)) {
return op->rest;
......@@ -129,7 +129,7 @@ class NoOpRemover : public IRMutator {
};
Stmt RemoveNoOp(Stmt stmt) {
return NoOpRemover().Mutate(stmt);
return NoOpRemover()(std::move(stmt));
}
} // namespace ir
} // namespace tvm
......@@ -23,7 +23,6 @@
*/
#include <tvm/ir.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
namespace tvm {
......@@ -109,10 +108,10 @@ class UnsafeExprDetector : public ExprFunctor<bool(const Expr& n)> {
}
};
class UnsafeSelectRewriter : public IRMutator {
class UnsafeSelectRewriter : public StmtExprMutator {
public:
Expr Mutate_(const Select* op, const Expr& e) {
Expr expr = IRMutator::Mutate_(op, e);
Expr VisitExpr_(const Select* op) {
Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Select>();
UnsafeExprDetector unsafe;
bool cond_is_scalar_bool = op->condition.dtype().is_bool() && op->condition.dtype().is_scalar();
......@@ -131,7 +130,7 @@ class UnsafeSelectRewriter : public IRMutator {
};
Stmt RewriteUnsafeSelect(Stmt stmt) {
return UnsafeSelectRewriter().Mutate(stmt);
return UnsafeSelectRewriter()(std::move(stmt));
}
} // namespace ir
......
......@@ -22,25 +22,24 @@
* \brief Implementation of simple passes
*/
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_pass.h>
namespace tvm {
namespace ir {
class IRSideEffect : public IRVisitor {
class IRSideEffect : public ExprVisitor {
public:
void Visit(const ObjectRef& e) final {
void VisitExpr(const Expr& e) final {
if (has_side_effect_) return;
IRVisitor::Visit(e);
ExprVisitor::VisitExpr(e);
}
void Visit_(const Call* op) final {
void VisitExpr_(const Call* op) final {
if (!op->is_pure()) {
has_side_effect_ = true; return;
} else {
IRVisitor::Visit_(op);
ExprVisitor::VisitExpr_(op);
}
}
......@@ -49,23 +48,23 @@ class IRSideEffect : public IRVisitor {
bool HasSideEffect(const Expr& e) {
IRSideEffect v;
v.Visit(e);
v(e);
return v.has_side_effect_;
}
class IRSubstitue : public IRMutator {
class IRSubstitue : public StmtExprMutator {
public:
explicit IRSubstitue(
const std::unordered_map<const Variable*, Expr>& smap)
: smap_(smap) {
}
Expr Mutate_(const Variable* op, const Expr& e) final {
Expr VisitExpr_(const Variable* op) final {
auto it = smap_.find(op);
if (it != smap_.end()) {
return it->second;
} else {
return e;
return GetRef<Expr>(op);
}
}
......@@ -76,13 +75,13 @@ class IRSubstitue : public IRMutator {
Stmt Substitute(Stmt stmt,
const std::unordered_map<const Variable*, Expr>& value_map) {
if (value_map.size() == 0) return stmt;
return IRSubstitue(value_map).Mutate(stmt);
return IRSubstitue(value_map)(std::move(stmt));
}
Expr Substitute(Expr expr,
const std::unordered_map<const Variable*, Expr>& value_map) {
if (value_map.size() == 0) return expr;
return IRSubstitue(value_map).Mutate(expr);
return IRSubstitue(value_map)(std::move(expr));
}
Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map) {
......@@ -101,20 +100,20 @@ Expr Substitute(Expr expr, const Map<Var, Expr>& value_map) {
return Substitute(expr, vmap);
}
class VarTouchVisitor : public IRVisitor {
class VarTouchVisitor : public ExprVisitor {
public:
void Visit(const ObjectRef& e) final {
void VisitExpr(const Expr& e) final {
if (use_var_) return;
IRVisitor::Visit(e);
ExprVisitor::VisitExpr(e);
}
void Visit_(const Variable* op) final {
void VisitExpr_(const Variable* op) final {
Handle(op);
}
void Visit_(const Load* op) final {
void VisitExpr_(const Load* op) final {
Handle(op->buffer_var.get());
IRVisitor::Visit_(op);
ExprVisitor::VisitExpr_(op);
}
virtual void Handle(const Variable* var) = 0;
......@@ -149,14 +148,14 @@ class ExprUseVSetVisitor : public VarTouchVisitor {
bool ExprUseVar(const Expr& e, const Var& v) {
ExprUseVarVisitor visitor(v.get());
visitor.Visit(e);
visitor(e);
return visitor.use_var_;
}
bool ExprUseVar(const Expr& e,
const std::unordered_set<const Variable*>& vset) {
ExprUseVSetVisitor visitor(vset);
visitor.Visit(e);
visitor(e);
return visitor.use_var_;
}
......
......@@ -19,22 +19,22 @@
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_functor_ext.h>
namespace tvm {
namespace ir {
class AssertSkipper : public IRMutator {
class AssertSkipper : public StmtMutator {
public:
Stmt Mutate_(const AssertStmt* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt VisitStmt_(const AssertStmt* op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<AssertStmt>();
return op->body;
}
};
Stmt SkipAssert(Stmt stmt) {
return AssertSkipper().Mutate(stmt);
return AssertSkipper()(std::move(stmt));
}
LoweredFunc SkipAssert(LoweredFunc f) {
......
......@@ -24,7 +24,7 @@
#include <tvm/ir.h>
#include <tvm/lowered_func.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/runtime/module.h>
#include <unordered_map>
......@@ -32,9 +32,9 @@ namespace tvm {
namespace ir {
// use/def analysis, also delete unreferenced lets
class IRUseDefAnalysis : public IRMutator {
class IRUseDefAnalysis : public StmtExprMutator {
public:
Stmt Mutate_(const AttrStmt *op, const Stmt& s) final {
Stmt VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
CHECK_NE(iv->thread_tag.length(), 0U);
......@@ -48,75 +48,77 @@ class IRUseDefAnalysis : public IRMutator {
Expr value = op->value;
if (visit_thread_extent_) {
value = this->Mutate(value);
value = this->VisitExpr(value);
}
Stmt body = this->VisitStmt(op->body);
if (value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
}
Stmt body = this->Mutate(op->body);
if (value.same_as(op->value) && body.same_as(op->body)) return s;
return AttrStmt::make(op->node, op->attr_key, value, body);
} else {
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
}
Stmt Mutate_(const LetStmt *op, const Stmt& s) final {
Stmt VisitStmt_(const LetStmt* op) final {
this->HandleDef(op->var.get());
Stmt body = this->Mutate(op->body);
Stmt body = this->VisitStmt(op->body);
// eliminate unreferenced let
if (use_count_.at(op->var.get()) == 0 &&
!HasSideEffect(op->value)) {
return body;
} else {
Expr value = this->Mutate(op->value);
Expr value = this->VisitExpr(op->value);
if (body.same_as(op->body) &&
value.same_as(op->value)) {
return s;
return GetRef<Stmt>(op);
} else {
return LetStmt::make(op->var, value, body);
}
}
}
Stmt Mutate_(const For *op, const Stmt& s) final {
Stmt VisitStmt_(const For* op) final {
this->HandleDef(op->loop_var.get());
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
Stmt Mutate_(const Allocate *op, const Stmt& s) final {
Stmt VisitStmt_(const Allocate* op) final {
this->HandleDef(op->buffer_var.get());
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
Stmt Mutate_(const Store *op, const Stmt& s) final {
Stmt VisitStmt_(const Store* op) final {
this->HandleUse(op->buffer_var);
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
Expr Mutate_(const Let *op, const Expr& e) final {
Expr VisitExpr_(const Let* op) final {
this->HandleDef(op->var.get());
Expr body = this->Mutate(op->body);
Expr body = this->VisitExpr(op->body);
// eliminate unreferenced let
if (use_count_.at(op->var.get()) == 0 &&
!HasSideEffect(op->value)) {
return body;
} else {
Expr value = this->Mutate(op->value);
Expr value = this->VisitExpr(op->value);
if (body.same_as(op->body) &&
value.same_as(op->value)) {
return e;
return GetRef<Expr>(op);
} else {
return Let::make(op->var, value, body);
}
}
}
Expr Mutate_(const Variable *op, const Expr& e) final {
this->HandleUse(e);
return IRMutator::Mutate_(op, e);
Expr VisitExpr_(const Variable* op) final {
this->HandleUse(GetRef<Expr>(op));
return StmtExprMutator::VisitExpr_(op);
}
Expr Mutate_(const Load *op, const Expr& e) final {
Expr VisitExpr_(const Load* op) final {
this->HandleUse(op->buffer_var);
return IRMutator::Mutate_(op, e);
return StmtExprMutator::VisitExpr_(op);
}
void HandleDef(const Variable* v) {
......@@ -154,20 +156,20 @@ class IRUseDefAnalysis : public IRMutator {
std::unordered_map<const Variable*, int> def_count_;
};
class HostDeviceSplitter : public IRMutator {
class HostDeviceSplitter : public StmtMutator {
public:
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
Stmt VisitStmt_(const Allocate* op) final {
handle_data_type_[op->buffer_var.get()] = make_const(op->dtype, 0);
return IRMutator::Mutate_(op, s);
return StmtMutator::VisitStmt_(op);
}
Stmt Mutate_(const AttrStmt *op, const Stmt& s) final {
Stmt VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::pipeline_exec_scope ||
op->attr_key == attr::device_scope) {
return SplitDeviceFunc(s);
return SplitDeviceFunc(GetRef<Stmt>(op));
}
return IRMutator::Mutate_(op, s);
return StmtMutator::VisitStmt_(op);
}
Array<LoweredFunc> Split(LoweredFunc f) {
......@@ -178,7 +180,7 @@ class HostDeviceSplitter : public IRMutator {
name_ = f->name;
ObjectPtr<LoweredFuncNode> n =
make_object<LoweredFuncNode>(*f.operator->());
n->body = this->Mutate(f->body);
n->body = operator()(f->body);
n->func_type = kHostFunc;
Array<LoweredFunc> ret{LoweredFunc(n)};
for (LoweredFunc x : device_funcs_) {
......@@ -195,7 +197,7 @@ class HostDeviceSplitter : public IRMutator {
// isolate the device function.
IRUseDefAnalysis m;
m.visit_thread_extent_ = false;
n->body = m.Mutate(body);
n->body = m(std::move(body));
n->name = os.str();
n->func_type = kDeviceFunc;
n->thread_axis = m.thread_axis_;
......@@ -243,7 +245,7 @@ Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& args) {
for (Var arg : args) {
m.use_count_[arg.get()] = 0;
}
m.Mutate(stmt);
m(stmt);
return m.undefined_;
}
......
......@@ -24,8 +24,7 @@
* \file ssa.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_pass.h>
#include <unordered_set>
#include <unordered_map>
......@@ -34,29 +33,33 @@
namespace tvm {
namespace ir {
namespace {
class IRVerifySSA final : public IRVisitor {
class IRVerifySSA final : public StmtExprVisitor {
public:
bool is_ssa{true};
void Visit(const ObjectRef& n) final {
void VisitExpr(const Expr& n) final {
if (!is_ssa) return;
IRVisitor::Visit(n);
StmtExprVisitor::VisitExpr(n);
}
void Visit_(const Let* op) final {
void VisitStmt(const Stmt& n) final {
if (!is_ssa) return;
StmtExprVisitor::VisitStmt(n);
}
void VisitExpr_(const Let* op) final {
MarkDef(op->var.get());
IRVisitor::Visit_(op);
StmtExprVisitor::VisitExpr_(op);
}
void Visit_(const LetStmt* op) final {
void VisitStmt_(const LetStmt* op) final {
MarkDef(op->var.get());
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
}
void Visit_(const For* op) final {
void VisitStmt_(const For* op) final {
MarkDef(op->loop_var.get());
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
}
void Visit_(const Allocate* op) final {
void VisitStmt_(const Allocate* op) final {
MarkDef(op->buffer_var.get());
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
}
private:
......@@ -70,31 +73,32 @@ class IRVerifySSA final : public IRVisitor {
std::unordered_map<const Variable*, int> defined_;
};
class IRConvertSSA final : public IRMutator {
class IRConvertSSA final : public StmtExprMutator {
public:
Expr Mutate_(const Variable* op, const Expr& e) final {
Expr VisitExpr_(const Variable* op) final {
if (scope_.count(op)) {
return scope_[op].back();
} else {
return e;
return GetRef<Expr>(op);
}
}
Expr Mutate_(const Let* op, const Expr& e) final {
Expr VisitExpr_(const Let* op) final {
const VarExpr& v = op->var;
if (defined_.count(v.get())) {
Expr value = IRMutator::Mutate(op->value);
Expr value = this->VisitExpr(op->value);
VarExpr new_var = Variable::make(v.dtype(), v->name_hint);
scope_[v.get()].push_back(new_var);
Expr body = IRMutator::Mutate(op->body);
Expr body = this->VisitExpr(op->body);
scope_[v.get()].pop_back();
return Let::make(new_var, value, body);
} else {
defined_.insert(v.get());
return IRMutator::Mutate_(op, e);
return StmtExprMutator::VisitExpr_(op);
}
}
Expr Mutate_(const Load* op, const Expr& e) final {
Expr expr = IRMutator::Mutate_(op, e);
Expr VisitExpr_(const Load* op) final {
Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Load>();
if (scope_.count(op->buffer_var.get())) {
return Load::make(
......@@ -104,8 +108,8 @@ class IRConvertSSA final : public IRMutator {
return expr;
}
}
Stmt Mutate_(const Store* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt VisitStmt_(const Store* op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<Store>();
if (scope_.count(op->buffer_var.get())) {
return Store::make(
......@@ -115,41 +119,41 @@ class IRConvertSSA final : public IRMutator {
return stmt;
}
}
Stmt Mutate_(const LetStmt* op, const Stmt& s) final {
Stmt VisitStmt_(const LetStmt* op) final {
const VarExpr& v = op->var;
if (defined_.count(v.get())) {
Expr value = IRMutator::Mutate(op->value);
Expr value = this->VisitExpr(op->value);
VarExpr new_var = Variable::make(v.dtype(), v->name_hint);
scope_[v.get()].push_back(new_var);
Stmt body = IRMutator::Mutate(op->body);
Stmt body = this->VisitStmt(op->body);
scope_[v.get()].pop_back();
return LetStmt::make(new_var, value, body);
} else {
defined_.insert(v.get());
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
}
Stmt Mutate_(const For* op, const Stmt& s) final {
Stmt VisitStmt_(const For* op) final {
const VarExpr& v = op->loop_var;
if (defined_.count(v.get())) {
VarExpr new_var = Variable::make(v.dtype(), v->name_hint);
scope_[v.get()].push_back(new_var);
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt stmt = StmtExprMutator::VisitStmt_(op);
scope_[v.get()].pop_back();
op = stmt.as<For>();
return For::make(
new_var, op->min, op->extent, op->for_type, op->device_api, op->body);
} else {
defined_.insert(v.get());
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
}
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
Stmt VisitStmt_(const Allocate* op) final {
const VarExpr& v = op->buffer_var;
if (defined_.count(v.get())) {
VarExpr new_var = Variable::make(v.dtype(), v->name_hint);
scope_[v.get()].push_back(new_var);
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt stmt = StmtExprMutator::VisitStmt_(op);
scope_[v.get()].pop_back();
op = stmt.as<Allocate>();
return Allocate::make(
......@@ -157,23 +161,23 @@ class IRConvertSSA final : public IRMutator {
op->body, op->new_expr, op->free_function);
} else {
defined_.insert(v.get());
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
Stmt VisitStmt_(const AttrStmt* op) final {
if (const Variable* v = op->node.as<Variable>()) {
if (op->attr_key == attr::storage_scope) {
const Allocate* alloc = op->body.as<Allocate>();
if (alloc && op->node.same_as(alloc->buffer_var)) {
Stmt new_alloc = Mutate(op->body);
if (new_alloc.same_as(op->body)) return s;
Stmt new_alloc = this->VisitStmt(op->body);
if (new_alloc.same_as(op->body)) return GetRef<Stmt>(op);
alloc = new_alloc.as<Allocate>();
CHECK(alloc);
return AttrStmt::make(
alloc->buffer_var, op->attr_key, op->value, new_alloc);
}
}
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<AttrStmt>();
if (scope_.count(v) && scope_[v].size() != 0) {
return AttrStmt::make(
......@@ -182,7 +186,7 @@ class IRConvertSSA final : public IRMutator {
return stmt;
}
} else {
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
}
......@@ -194,13 +198,13 @@ class IRConvertSSA final : public IRMutator {
} // namespace
bool VerifySSA(const Stmt& ir) {
IRVerifySSA v;
v.Visit(ir);
return v.is_ssa;
IRVerifySSA visitor;
visitor(ir);
return visitor.is_ssa;
}
Stmt ConvertSSA(Stmt stmt) {
return IRConvertSSA().Mutate(stmt);
return IRConvertSSA()(std::move(stmt));
}
} // namespace ir
......
......@@ -21,7 +21,6 @@
* \file storage_access.cc
*/
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <tvm/target_info.h>
#include <string>
#include <utility>
......@@ -32,7 +31,7 @@
namespace tvm {
namespace ir {
void StorageAccessVisitor::Visit_(const Load* op) {
void StorageAccessVisitor::VisitExpr_(const Load* op) {
const Variable* buf = op->buffer_var.as<Variable>();
StorageScope scope = GetScope(buf);
if (Enabled(buf, scope)) {
......@@ -47,10 +46,10 @@ void StorageAccessVisitor::Visit_(const Load* op) {
curr_stmt_.access.emplace_back(std::move(e));
}
// traverse child
IRVisitor::Visit_(op);
StmtExprVisitor::VisitExpr_(op);
}
void StorageAccessVisitor::Visit_(const Store* op) {
void StorageAccessVisitor::VisitStmt_(const Store* op) {
allow_append_ = true;
CHECK_EQ(curr_stmt_.access.size(), 0U);
curr_stmt_.stmt = op;
......@@ -67,7 +66,7 @@ void StorageAccessVisitor::Visit_(const Store* op) {
curr_stmt_.access.emplace_back(std::move(e));
}
// traverse child
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
// push to the scope
scope_.back().push_back(curr_stmt_);
// clear access entry.
......@@ -75,11 +74,11 @@ void StorageAccessVisitor::Visit_(const Store* op) {
allow_append_ = false;
}
void StorageAccessVisitor::Visit_(const Evaluate* op) {
void StorageAccessVisitor::VisitStmt_(const Evaluate* op) {
allow_append_ = true;
CHECK_EQ(curr_stmt_.access.size(), 0U);
curr_stmt_.stmt = op;
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
// push to the scope
if (curr_stmt_.access.size() != 0) {
scope_.back().push_back(curr_stmt_);
......@@ -88,17 +87,17 @@ void StorageAccessVisitor::Visit_(const Evaluate* op) {
allow_append_ = false;
}
void StorageAccessVisitor::Visit_(const AttrStmt* op) {
void StorageAccessVisitor::VisitStmt_(const AttrStmt* op) {
if (op->attr_key == attr::storage_scope) {
const Variable* buf = op->node.as<Variable>();
storage_scope_[buf] =
StorageScope::make(op->value.as<StringImm>()->value);
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
} else if (op->attr_key == attr::double_buffer_write) {
CHECK(double_buffer_write_ == nullptr);
double_buffer_write_ = op->node.as<Variable>();
scope_.push_back(std::vector<StmtEntry>());
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
StmtEntry s;
s.stmt = op;
s.access = Summarize(std::move(scope_.back()), nullptr);
......@@ -115,7 +114,7 @@ void StorageAccessVisitor::Visit_(const AttrStmt* op) {
} else if (op->attr_key == attr::coproc_scope) {
IterVar iv = Downcast<IterVar>(op->node);
env_threads_.push_back(iv);
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
env_threads_.CopyOnWrite()->data.pop_back();
} else if (op->attr_key == attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
......@@ -123,23 +122,23 @@ void StorageAccessVisitor::Visit_(const AttrStmt* op) {
if (!in_device_env_) {
in_device_env_ = true;
scope_.push_back(std::vector<StmtEntry>());
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
// no need to take the result as the thread barrier automatically syncs.
Summarize(std::move(scope_.back()), nullptr);
in_device_env_ = false;
scope_.pop_back();
} else {
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
}
env_threads_.CopyOnWrite()->data.pop_back();
} else {
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
}
}
void StorageAccessVisitor::Visit_(const For* op) {
void StorageAccessVisitor::VisitStmt_(const For* op) {
scope_.push_back(std::vector<StmtEntry>());
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
StmtEntry s;
s.stmt = op;
s.access = Summarize(std::move(scope_.back()), op);
......@@ -161,11 +160,11 @@ void StorageAccessVisitor::Visit_(const For* op) {
}
}
void StorageAccessVisitor::Visit_(const IfThenElse* op) {
void StorageAccessVisitor::VisitStmt_(const IfThenElse* op) {
++condition_counter_;
this->Visit(op->condition);
this->VisitExpr(op->condition);
scope_.push_back(std::vector<StmtEntry>());
this->Visit(op->then_case);
this->VisitStmt(op->then_case);
StmtEntry s;
s.stmt = op;
s.access = Summarize(std::move(scope_.back()), nullptr);
......@@ -180,10 +179,10 @@ void StorageAccessVisitor::Visit_(const IfThenElse* op) {
--condition_counter_;
}
void StorageAccessVisitor::Visit_(const Call* op) {
void StorageAccessVisitor::VisitExpr_(const Call* op) {
if (op->is_intrinsic(intrinsic::tvm_address_of)) {
const Load *l = op->args[0].as<Load>();
IRVisitor::Visit_(l);
StmtExprVisitor::VisitExpr_(l);
} else if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
CHECK_EQ(op->args.size(), 5U);
DataType dtype = op->args[0].dtype();
......@@ -211,7 +210,7 @@ void StorageAccessVisitor::Visit_(const Call* op) {
curr_stmt_.access.emplace_back(e);
}
}
IRVisitor::Visit_(op);
StmtExprVisitor::VisitExpr_(op);
} else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) {
CHECK(allow_append_);
const std::string& s = op->args[0].as<StringImm>()->value;
......@@ -224,7 +223,7 @@ void StorageAccessVisitor::Visit_(const Call* op) {
curr_stmt_.access.emplace_back(std::move(e));
}
} else {
IRVisitor::Visit_(op);
StmtExprVisitor::VisitExpr_(op);
}
}
......@@ -236,11 +235,12 @@ StorageScope StorageAccessVisitor::GetScope(const Variable* buf) const {
return it->second;
}
class StorageAccessInfoLower : public IRMutator {
class StorageAccessInfoLower : public StmtExprMutator {
public:
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
Stmt VisitStmt_(const Allocate* op) final {
// Lower allocate to device allocate when needed.
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<Allocate>();
// For special memory, remove allocate, or use head expr
auto it = storage_info_.find(op->buffer_var.get());
......@@ -259,7 +259,7 @@ class StorageAccessInfoLower : public IRMutator {
return stmt;
}
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
Stmt VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::storage_scope) {
const Variable* buf = op->node.as<Variable>();
StorageScope scope = StorageScope::make(op->value.as<StringImm>()->value);
......@@ -270,26 +270,26 @@ class StorageAccessInfoLower : public IRMutator {
CHECK(e.info.defined()) << "Cannot find memory info of " << scope.to_string();
}
storage_info_[buf] = e;
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
} else {
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
}
Expr Mutate_(const Call* op, const Expr &e) final {
Expr VisitExpr_(const Call* op) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
return MakeAccessPtr(op, e);
return MakeAccessPtr(op);
} else {
return IRMutator::Mutate_(op, e);
return StmtExprMutator::VisitExpr_(op);
}
}
private:
// tvm_access_ptr
Expr MakeAccessPtr(const Call* op, const Expr& e) {
Expr MakeAccessPtr(const Call* op) {
// Specially handle the buffer packed intrinsic
Expr expr = IRMutator::Mutate_(op, e);
Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Call>();
CHECK_EQ(op->args.size(), 5U);
DataType dtype = op->args[0].dtype();
......@@ -337,7 +337,7 @@ class StorageAccessInfoLower : public IRMutator {
};
Stmt LowerStorageAccessInfo(Stmt stmt) {
return StorageAccessInfoLower().Mutate(stmt);
return StorageAccessInfoLower()(std::move(stmt));
}
LoweredFunc LowerDeviceStorageAccessInfo(LoweredFunc f) {
......
......@@ -27,7 +27,7 @@
#include <tvm/ir.h>
#include <tvm/attrs.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
#include <vector>
#include <unordered_map>
#include "../runtime/thread_storage_scope.h"
......@@ -40,7 +40,7 @@ using runtime::StorageRank;
/*!
* \brief Base class of storage access analysis
*/
class StorageAccessVisitor : public IRVisitor {
class StorageAccessVisitor : public StmtExprVisitor {
public:
/*! \brief Storage access type */
enum AccessType {
......@@ -76,13 +76,13 @@ class StorageAccessVisitor : public IRVisitor {
std::vector<AccessEntry> access;
};
// override visitor pattern
void Visit_(const Load* op) final;
void Visit_(const Store* op) final;
void Visit_(const Evaluate* op) final;
void Visit_(const AttrStmt* op) final;
void Visit_(const For* op) final;
void Visit_(const IfThenElse* op) final;
void Visit_(const Call* op) final;
void VisitExpr_(const Load* op) final;
void VisitStmt_(const Store* op) final;
void VisitStmt_(const Evaluate* op) final;
void VisitStmt_(const AttrStmt* op) final;
void VisitStmt_(const For* op) final;
void VisitStmt_(const IfThenElse* op) final;
void VisitExpr_(const Call* op) final;
protected:
StorageAccessVisitor() {
......
......@@ -26,8 +26,7 @@
#include <tvm/ir.h>
#include <tvm/expr.h>
#include <tvm/operation.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/expr_operator.h>
#include <tvm/ir_pass.h>
#include <tvm/buffer.h>
......@@ -48,7 +47,7 @@ using runtime::StorageScope;
using runtime::ThreadScope;
using intrinsic::tvm_address_of;
class StorageFlattener : public IRMutator {
class StorageFlattener : public StmtExprMutator {
public:
explicit StorageFlattener(Map<Tensor, Buffer> extern_buffer,
int cache_line_size, bool create_bound_attributes,
......@@ -64,8 +63,8 @@ class StorageFlattener : public IRMutator {
cache_line_size_ = cache_line_size;
}
Stmt Mutate_(const Store* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt VisitStmt_(const Store* op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<Store>();
auto it = var_remap_.find(op->buffer_var.get());
if (it != var_remap_.end() &&
......@@ -78,14 +77,14 @@ class StorageFlattener : public IRMutator {
}
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
Stmt VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::realize_scope) {
storage_scope_[op->node.get()] = op->value.as<StringImm>()->value;
return this->Mutate(op->body);
return this->VisitStmt(op->body);
} else if (op->attr_key == attr::double_buffer_scope &&
op->node->IsInstance<OperationNode>()) {
Operation func = Downcast<Operation>(op->node);
Stmt body = Mutate(op->body);
Stmt body = this->VisitStmt(op->body);
for (int i = 0; i < func->num_outputs(); ++i) {
TensorKey key{func, i};
auto it = buf_map_.find(key);
......@@ -99,7 +98,7 @@ class StorageFlattener : public IRMutator {
IterVar iv = Downcast<IterVar>(op->node);
ThreadScope ts = ThreadScope::make(iv->thread_tag);
curr_thread_scope_.push_back(ts);
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt stmt = StmtExprMutator::VisitStmt_(op);
curr_thread_scope_.pop_back();
return stmt;
} else if (op->attr_key == attr::buffer_bind_scope) {
......@@ -116,17 +115,17 @@ class StorageFlattener : public IRMutator {
}
vinfo[dim].align_factor = tuple->args[1].as<IntImm>()->value;
vinfo[dim].align_offset = tuple->args[2].as<IntImm>()->value;
return this->Mutate(op->body);
return this->VisitStmt(op->body);
} else if (op->attr_key == attr::opengl_stage_scope) {
is_opengl_ = true;
}
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
Stmt Mutate_(const Provide* op, const Stmt& s) final {
Stmt VisitStmt_(const Provide* op) final {
if (create_bound_attributes_)
shape_collector_.clear();
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<Provide>();
TensorKey key{op->func, op->value_index};
auto it = buf_map_.find(key);
......@@ -159,11 +158,11 @@ class StorageFlattener : public IRMutator {
}
}
Stmt Mutate_(const Realize* op, const Stmt& s) final {
Stmt VisitStmt_(const Realize* op) final {
TensorKey key{op->func, op->value_index};
if (buf_map_.count(key)) {
CHECK(buf_map_.at(key).external);
return this->Mutate(op->body);
return this->VisitStmt(op->body);
} else {
// create a buffer entry
BufferEntry e;
......@@ -226,7 +225,7 @@ class StorageFlattener : public IRMutator {
align, 0, kDefault);
buf_map_[key] = e;
Stmt body = this->Mutate(op->body);
Stmt body = this->VisitStmt(op->body);
buf_map_[key].released = true;
Stmt ret;
......@@ -263,8 +262,8 @@ class StorageFlattener : public IRMutator {
}
}
Expr Mutate_(const Load* op, const Expr& e) final {
Expr expr = IRMutator::Mutate_(op, e);
Expr VisitExpr_(const Load* op) final {
Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Load>();
auto it = var_remap_.find(op->buffer_var.get());
if (it != var_remap_.end() &&
......@@ -277,17 +276,17 @@ class StorageFlattener : public IRMutator {
}
}
Expr Mutate_(const Variable* op, const Expr& e) final {
Expr VisitExpr_(const Variable* op) final {
auto it = var_remap_.find(op);
if (it != var_remap_.end()) {
return it->second;
} else {
return e;
return GetRef<Expr>(op);
}
}
Expr Mutate_(const Call* op, const Expr& olde) final {
Expr expr = IRMutator::Mutate_(op, olde);
Expr VisitExpr_(const Call* op) final {
Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Call>();
if (op != nullptr && op->call_type == Call::Halide) {
TensorKey key{op->func, op->value_index};
......@@ -308,8 +307,8 @@ class StorageFlattener : public IRMutator {
}
}
Stmt Mutate_(const Prefetch *op, const Stmt &s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt VisitStmt_(const Prefetch *op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<Prefetch>();
CHECK(op != nullptr);
TensorKey key{op->func, op->value_index};
......@@ -443,7 +442,7 @@ class StorageFlattener : public IRMutator {
// Apply the remaps
Stmt body = MergeNest(binder.asserts(), op->body);
body = MergeNest(binder.init_nest(), body);
body = this->Mutate(body);
body = this->VisitStmt(body);
// remove the binds
for (const Var& v : binder.defs()) {
var_remap_.erase(v.get());
......@@ -531,10 +530,10 @@ class StorageFlattener : public IRMutator {
Stmt StorageFlatten(Stmt stmt, Map<Tensor, Buffer> extern_buffer,
int cache_line_size, bool create_bound_attributes) {
IRVisitorWithAnalyzer bounded_analyzer;
bounded_analyzer.Visit(stmt);
bounded_analyzer(stmt);
stmt =
StorageFlattener(extern_buffer, cache_line_size,
create_bound_attributes, &bounded_analyzer).Mutate(stmt);
create_bound_attributes, &bounded_analyzer)(std::move(stmt));
return stmt;
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment