Unverified Commit 203ca7a0 by Tianqi Chen Committed by GitHub

[REFACTOR] Migrate Low-level IR Passes into the New Stmt/Expr Mutator (#4607)

* CombineContextCall

* Migrate BoundChecker

* Migrate CoprocSync

* Migrate detect_device

* Migrate loop_partition

* Migrate infer_fragement

* Migrate inject_copy_intrin

* Migrate inject double buffer

* Migrate lower_intrin and simplify

* Migrate storage flatten

* Migrate inject prefetch

* Migrate inject_virtual_thread

* migrate inline

* Migrate lift attr scope

* Migrate custom datatypes

* migrate lower_thread_all_reduce

* Migrate lower_tvm_builtin

* migrate lower_warp memory

* Migrate make_api.cc

* Migrate remap_thread_axis

* Migrate remove_no_op

* migrate rewrite_unsafe_select

* Migrate skip_assert simple_passes

* Migrate split_host_device

* Migrate ssa

* Migrate storage_access

* Migrate storage_rewrite

* Migrate tensor_core

* Migrate unroll_loop

* Migrate vectorize

* Migrate verify compact_buffer gpu_code

* Migrate verify_memory

* Migrate storage_sync

* Remove unused refs to mutator

* Migrate hybrid_op

* Migrate tensorize

* Migrate schedule ops

* Migrate schedule_dataflow_rewrite

* Migrate auto_inline_elemwise

* Remove unecessary ref to visitor

* remove unecessary ref

* Migrate bound_deducer

* Migrate domain_touched

* Migrate autotvm feature touch extractor

* Add annotations
parent 3f43bee0
...@@ -546,6 +546,35 @@ class StmtExprMutator : ...@@ -546,6 +546,35 @@ class StmtExprMutator :
} }
}; };
/*!
* \brief recursively visit the ir in post DFS order node, and transform it
*
* \param node The ir to be transformed.
* \param preorder The function called in before recursive mutation
* If preorder returns None, then the transform will proceed to recursive call.
* If preorder returns a not None Stmt/Expr, the transformer will simply return it and
* won't do further recursion.
* \param postorder The function called after recursive mutation.
* The recursive mutation result is passed to postorder for further mutation.
* \param only_enable List of StringImm.
* If it is empty, all IRNode will call preorder/postorder
* If it is not empty, preorder/postorder will only be called
* when the IRNode's type key is in the list.
*/
TVM_DLL Stmt IRTransform(Stmt node,
const runtime::PackedFunc& preorder,
const runtime::PackedFunc& postorder,
const Array<Expr>& only_enable = {});
/*!
* \brief recursively visit the ir in post DFS order node, apply fvisit
* Each node is guaranteed to be visited only once.
* \param node The ir to be visited.
* \param fvisit The visitor function to be applied.
*/
TVM_DLL void PostOrderVisit(const ObjectRef& node, std::function<void(const ObjectRef&)> fvisit);
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
#endif // TVM_IR_FUNCTOR_EXT_H_ #endif // TVM_IR_FUNCTOR_EXT_H_
...@@ -122,27 +122,6 @@ class TVM_DLL IRMutator { ...@@ -122,27 +122,6 @@ class TVM_DLL IRMutator {
virtual Expr Mutate_(const StringImm* op, const Expr& e); virtual Expr Mutate_(const StringImm* op, const Expr& e);
virtual Expr Mutate_(const Shuffle* op, const Expr& e); virtual Expr Mutate_(const Shuffle* op, const Expr& e);
}; };
/*!
* \brief recursively visit the ir in post DFS order node, and transform it
*
* \param node The ir to be transformed.
* \param preorder The function called in before recursive mutation
* If preorder returns None, then the transform will proceed to recursive call.
* If preorder returns a not None Stmt/Expr, the transformer will simply return it and
* won't do further recursion.
* \param postorder The function called after recursive mutation.
* The recursive mutation result is passed to postorder for further mutation.
* \param only_enable List of StringImm.
* If it is empty, all IRNode will call preorder/postorder
* If it is not empty, preorder/postorder will only be called
* when the IRNode's type key is in the list.
*/
Stmt IRTransform(Stmt node,
const runtime::PackedFunc& preorder,
const runtime::PackedFunc& postorder,
const Array<Expr>& only_enable = {});
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
#endif // TVM_IR_MUTATOR_H_ #endif // TVM_IR_MUTATOR_H_
...@@ -145,15 +145,6 @@ class TVM_DLL IRVisitor { ...@@ -145,15 +145,6 @@ class TVM_DLL IRVisitor {
virtual void Visit_(const FloatImm* op); virtual void Visit_(const FloatImm* op);
virtual void Visit_(const StringImm* op); virtual void Visit_(const StringImm* op);
}; };
/*!
* \brief recursively visit the ir in post DFS order node, apply fvisit
* Each node is guaranteed to be visited only once.
* \param node The ir to be visited.
* \param fvisit The visitor function to be applied.
*/
TVM_DLL void PostOrderVisit(const ObjectRef& node, std::function<void(const ObjectRef&)> fvisit);
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
......
...@@ -25,8 +25,7 @@ ...@@ -25,8 +25,7 @@
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/attrs.h> #include <tvm/attrs.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h> #include <tvm/ir_functor_ext.h>
#include <tvm/ir_mutator.h>
#include <tvm/api_registry.h> #include <tvm/api_registry.h>
namespace tvm { namespace tvm {
......
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
*/ */
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h> #include <tvm/ir_functor_ext.h>
#include <tvm/arithmetic.h> #include <tvm/arithmetic.h>
#include <tvm/api_registry.h> #include <tvm/api_registry.h>
...@@ -38,17 +38,17 @@ using namespace ir; ...@@ -38,17 +38,17 @@ using namespace ir;
// a visitor to find the path to the target variable // a visitor to find the path to the target variable
// from a expression. // from a expression.
class VariablePathFinder: public IRVisitor { class VariablePathFinder: public ExprVisitor {
public: public:
explicit VariablePathFinder(Expr target) : target_(target) {} explicit VariablePathFinder(Expr target) : target_(target) {}
void Visit(const ObjectRef& node) final { void VisitExpr(const Expr& node) final {
if (visited_.count(node.get()) != 0) return; if (visited_.count(node.get()) != 0) return;
visited_.insert(node.get()); visited_.insert(node.get());
if (!found_) path_.push_back(node.get()); if (!found_) path_.push_back(node.get());
if (node.same_as(target_)) found_ = true; if (node.same_as(target_)) found_ = true;
IRVisitor::Visit(node); ExprVisitor::VisitExpr(node);
if (!found_) path_.pop_back(); if (!found_) path_.pop_back();
} }
...@@ -64,14 +64,14 @@ class VariablePathFinder: public IRVisitor { ...@@ -64,14 +64,14 @@ class VariablePathFinder: public IRVisitor {
// return empty vector to represent failure // return empty vector to represent failure
std::vector<const Object*> GetPath(Expr target, Expr expr) { std::vector<const Object*> GetPath(Expr target, Expr expr) {
VariablePathFinder v(target); VariablePathFinder v(target);
v.Visit(expr); v(expr);
return v.path_; return v.path_;
} }
enum CompareOp {kGreater, kLess, kEqual}; enum CompareOp {kGreater, kLess, kEqual};
// a visitor to deduce the bound of a variable from a expression // a visitor to deduce the bound of a variable from a expression
class BoundDeducer: public IRVisitor { class BoundDeducer: public ExprVisitor {
public: public:
friend class BoundDeduceInputChecker; friend class BoundDeduceInputChecker;
friend class Converter; friend class Converter;
...@@ -82,39 +82,39 @@ class BoundDeducer: public IRVisitor { ...@@ -82,39 +82,39 @@ class BoundDeducer: public IRVisitor {
void Deduce(); void Deduce();
void Visit(const ObjectRef& e) final { void VisitExpr(const Expr& e) final {
if (!success_) return; if (!success_) return;
if (e.get() == path_[iter_++]) { if (e.get() == path_[iter_++]) {
IRVisitor::Visit(e); ExprVisitor::VisitExpr(e);
} else { } else {
success_ = false; success_ = false;
return; return;
} }
} }
void Visit_(const LT* op) final { void VisitExpr_(const LT* op) final {
LOG(FATAL) << "unable to deduce due to multiple comparison operator"; LOG(FATAL) << "unable to deduce due to multiple comparison operator";
} }
void Visit_(const LE* op) final { void VisitExpr_(const LE* op) final {
LOG(FATAL) << "unable to deduce due to multiple comparison operator"; LOG(FATAL) << "unable to deduce due to multiple comparison operator";
} }
void Visit_(const GT* op) final { void VisitExpr_(const GT* op) final {
LOG(FATAL) << "unable to deduce due to multiple comparison operator"; LOG(FATAL) << "unable to deduce due to multiple comparison operator";
} }
void Visit_(const GE* op) final { void VisitExpr_(const GE* op) final {
LOG(FATAL) << "unable to deduce due to multiple comparison operator"; LOG(FATAL) << "unable to deduce due to multiple comparison operator";
} }
void Visit_(const Add* op) final { void VisitExpr_(const Add* op) final {
bool left = op->a.get() == path_[iter_]; bool left = op->a.get() == path_[iter_];
result_ -= left ? op->b : op->a; result_ -= left ? op->b : op->a;
Visit(left ? op->a : op->b); this->VisitExpr(left ? op->a : op->b);
} }
void Visit_(const Sub* op) final { void VisitExpr_(const Sub* op) final {
bool left = op->a.get() == path_[iter_]; bool left = op->a.get() == path_[iter_];
if (left) { if (left) {
result_ += op->b; result_ += op->b;
...@@ -123,10 +123,10 @@ class BoundDeducer: public IRVisitor { ...@@ -123,10 +123,10 @@ class BoundDeducer: public IRVisitor {
result_ = - result_; result_ = - result_;
comp_op = ReverseOp(comp_op); comp_op = ReverseOp(comp_op);
} }
Visit(left ? op->a : op->b); this->VisitExpr(left ? op->a : op->b);
} }
void Visit_(const Mul* op) final { void VisitExpr_(const Mul* op) final {
bool left = op->a.get() == path_[iter_]; bool left = op->a.get() == path_[iter_];
Expr operand = left ? op->b : op->a; Expr operand = left ? op->b : op->a;
Expr target_var = left ? op->a : op->b; Expr target_var = left ? op->a : op->b;
...@@ -171,7 +171,7 @@ class BoundDeducer: public IRVisitor { ...@@ -171,7 +171,7 @@ class BoundDeducer: public IRVisitor {
// ( x <= -3/-2 --> x <= 1) // ( x <= -3/-2 --> x <= 1)
} }
} }
Visit(left ? op->a : op->b); this->VisitExpr(left ? op->a : op->b);
} }
Expr result_; Expr result_;
...@@ -194,17 +194,17 @@ class BoundDeducer: public IRVisitor { ...@@ -194,17 +194,17 @@ class BoundDeducer: public IRVisitor {
Analyzer analyzer_; Analyzer analyzer_;
}; };
class BoundDeduceInputChecker: public IRVisitor { class BoundDeduceInputChecker: public ExprVisitor {
public: public:
bool Check(BoundDeducer* deducer) { bool Check(BoundDeducer* deducer) {
deducer_ = deducer; deducer_ = deducer;
Visit(deducer_->expr_); this->VisitExpr(deducer_->expr_);
return target_count == 1; return target_count == 1;
} }
void Visit(const ObjectRef& e) final { void VisitExpr(const Expr& e) final {
if (e.same_as(deducer_->target_)) ++target_count; if (e.same_as(deducer_->target_)) ++target_count;
IRVisitor::Visit(e); ExprVisitor::VisitExpr(e);
} }
private: private:
...@@ -305,7 +305,7 @@ void BoundDeducer::Deduce() { ...@@ -305,7 +305,7 @@ void BoundDeducer::Deduce() {
} }
expr_map_ = EvalSetForEachSubExpr(expr_, hint_map_); expr_map_ = EvalSetForEachSubExpr(expr_, hint_map_);
Visit(expr_); this->VisitExpr(expr_);
} }
void BoundDeducer::Relax() { void BoundDeducer::Relax() {
......
...@@ -23,7 +23,6 @@ ...@@ -23,7 +23,6 @@
*/ */
#include <tvm/arithmetic.h> #include <tvm/arithmetic.h>
#include <tvm/expr_operator.h> #include <tvm/expr_operator.h>
#include <tvm/ir_mutator.h>
#include "const_fold.h" #include "const_fold.h"
#include "pattern_match.h" #include "pattern_match.h"
#include "rewrite_simplify.h" #include "rewrite_simplify.h"
...@@ -435,30 +434,30 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { ...@@ -435,30 +434,30 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
Expr CanonicalSimplify(Expr expr) { Expr CanonicalSimplify(Expr expr) {
expr = Mutate(expr); expr = operator()(expr);
return expr; return expr;
} }
// override the original mutate function. // override the original mutate function.
Expr Mutate(Expr expr) final { Expr VisitExpr(const Expr& input_expr) final {
expr = IRMutator::Mutate(expr); auto expr = Rewriter::VisitExpr(input_expr);
return Normalize(expr); return Normalize(expr);
} }
// Normal mutation without normalization. // Normal mutation without normalization.
Expr CanonicalMutate(Expr expr) { Expr CanonicalMutate(Expr expr) {
return IRMutator::Mutate(expr); return Rewriter::VisitExpr(expr);
} }
using Rewriter::Mutate_; using Rewriter::VisitExpr_;
Expr Mutate_(const Add* op, const Expr& self) final; Expr VisitExpr_(const Add* op) final;
Expr Mutate_(const Sub* op, const Expr& self) final; Expr VisitExpr_(const Sub* op) final;
Expr Mutate_(const Mul* op, const Expr& self) final; Expr VisitExpr_(const Mul* op) final;
Expr Mutate_(const Div* op, const Expr& self) final; Expr VisitExpr_(const Div* op) final;
Expr Mutate_(const Mod* op, const Expr& self) final; Expr VisitExpr_(const Mod* op) final;
Expr Mutate_(const FloorDiv* op, const Expr& self) final; Expr VisitExpr_(const FloorDiv* op) final;
Expr Mutate_(const FloorMod* op, const Expr& self) final; Expr VisitExpr_(const FloorMod* op) final;
Expr Mutate_(const Reduce* op, const Expr& self) final; Expr VisitExpr_(const Reduce* op) final;
private: private:
/*! /*!
...@@ -567,9 +566,9 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { ...@@ -567,9 +566,9 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
}; };
Expr CanonicalSimplifier::Impl:: Expr CanonicalSimplifier::Impl::
Mutate_(const Add* op, const Expr& self) { VisitExpr_(const Add* op) {
if (!IsIndexType(op->dtype)) { if (!IsIndexType(op->dtype)) {
return Rewriter::Mutate_(op, self); return Rewriter::VisitExpr_(op);
} }
// normalize // normalize
Expr a = this->CanonicalMutate(op->a); Expr a = this->CanonicalMutate(op->a);
...@@ -593,9 +592,9 @@ Mutate_(const Add* op, const Expr& self) { ...@@ -593,9 +592,9 @@ Mutate_(const Add* op, const Expr& self) {
} }
Expr CanonicalSimplifier::Impl:: Expr CanonicalSimplifier::Impl::
Mutate_(const Sub* op, const Expr& self) { VisitExpr_(const Sub* op) {
if (!IsIndexType(op->dtype)) { if (!IsIndexType(op->dtype)) {
return Rewriter::Mutate_(op, self); return Rewriter::VisitExpr_(op);
} }
// normalize // normalize
Expr a = this->CanonicalMutate(op->a); Expr a = this->CanonicalMutate(op->a);
...@@ -620,9 +619,9 @@ Mutate_(const Sub* op, const Expr& self) { ...@@ -620,9 +619,9 @@ Mutate_(const Sub* op, const Expr& self) {
Expr CanonicalSimplifier::Impl:: Expr CanonicalSimplifier::Impl::
Mutate_(const Mul* op, const Expr& self) { VisitExpr_(const Mul* op) {
if (!IsIndexType(op->dtype)) { if (!IsIndexType(op->dtype)) {
return Rewriter::Mutate_(op, self); return Rewriter::VisitExpr_(op);
} }
// normalize // normalize
Expr a = this->CanonicalMutate(op->a); Expr a = this->CanonicalMutate(op->a);
...@@ -652,7 +651,7 @@ Mutate_(const Mul* op, const Expr& self) { ...@@ -652,7 +651,7 @@ Mutate_(const Mul* op, const Expr& self) {
a = Normalize(a); a = Normalize(a);
b = Normalize(b); b = Normalize(b);
if (op->a.same_as(a) && op->b.same_as(b)) { if (op->a.same_as(a) && op->b.same_as(b)) {
return self; return GetRef<Expr>(op);
} else { } else {
return Mul::make(a, b); return Mul::make(a, b);
} }
...@@ -727,9 +726,9 @@ SplitDivConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { ...@@ -727,9 +726,9 @@ SplitDivConst(SplitExpr lhs, int64_t cval, DivMode div_mode) {
} }
Expr CanonicalSimplifier::Impl:: Expr CanonicalSimplifier::Impl::
Mutate_(const Div* op, const Expr& self) { VisitExpr_(const Div* op) {
if (!IsIndexType(op->dtype)) { if (!IsIndexType(op->dtype)) {
return Rewriter::Mutate_(op, self); return Rewriter::VisitExpr_(op);
} }
Expr a = this->CanonicalMutate(op->a); Expr a = this->CanonicalMutate(op->a);
...@@ -781,16 +780,16 @@ Mutate_(const Div* op, const Expr& self) { ...@@ -781,16 +780,16 @@ Mutate_(const Div* op, const Expr& self) {
a = Normalize(a); a = Normalize(a);
b = Normalize(b); b = Normalize(b);
if (op->a.same_as(a) && op->b.same_as(b)) { if (op->a.same_as(a) && op->b.same_as(b)) {
return self; return GetRef<Expr>(op);
} else { } else {
return Div::make(a, b); return Div::make(a, b);
} }
} }
Expr CanonicalSimplifier::Impl:: Expr CanonicalSimplifier::Impl::
Mutate_(const FloorDiv* op, const Expr& self) { VisitExpr_(const FloorDiv* op) {
if (!IsIndexType(op->dtype)) { if (!IsIndexType(op->dtype)) {
return Rewriter::Mutate_(op, self); return Rewriter::VisitExpr_(op);
} }
Expr a = this->CanonicalMutate(op->a); Expr a = this->CanonicalMutate(op->a);
Expr b = this->CanonicalMutate(op->b); Expr b = this->CanonicalMutate(op->b);
...@@ -837,7 +836,7 @@ Mutate_(const FloorDiv* op, const Expr& self) { ...@@ -837,7 +836,7 @@ Mutate_(const FloorDiv* op, const Expr& self) {
a = Normalize(a); a = Normalize(a);
b = Normalize(b); b = Normalize(b);
if (op->a.same_as(a) && op->b.same_as(b)) { if (op->a.same_as(a) && op->b.same_as(b)) {
return self; return GetRef<Expr>(op);
} else { } else {
return FloorDiv::make(a, b); return FloorDiv::make(a, b);
} }
...@@ -866,7 +865,7 @@ SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { ...@@ -866,7 +865,7 @@ SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode) {
// Do a recursive call to simplify the mod with the new factor. // Do a recursive call to simplify the mod with the new factor.
if (new_upper_factor < lhs->upper_factor && if (new_upper_factor < lhs->upper_factor &&
lhs->upper_factor != SplitExprNode::kPosInf) { lhs->upper_factor != SplitExprNode::kPosInf) {
auto updated = ToSplitExpr(Mutate(ModImpl( auto updated = ToSplitExpr(this->VisitExpr(ModImpl(
lhs->index, make_const(lhs.dtype(), new_upper_factor), div_mode))); lhs->index, make_const(lhs.dtype(), new_upper_factor), div_mode)));
// re-apply the lower_factor // re-apply the lower_factor
if (lhs->lower_factor != 1) { if (lhs->lower_factor != 1) {
...@@ -894,9 +893,9 @@ SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { ...@@ -894,9 +893,9 @@ SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode) {
} }
Expr CanonicalSimplifier::Impl:: Expr CanonicalSimplifier::Impl::
Mutate_(const Mod* op, const Expr& self) { VisitExpr_(const Mod* op) {
if (!IsIndexType(op->dtype)) { if (!IsIndexType(op->dtype)) {
return Rewriter::Mutate_(op, self); return Rewriter::VisitExpr_(op);
} }
// normalize // normalize
Expr a = this->CanonicalMutate(op->a); Expr a = this->CanonicalMutate(op->a);
...@@ -957,16 +956,16 @@ Mutate_(const Mod* op, const Expr& self) { ...@@ -957,16 +956,16 @@ Mutate_(const Mod* op, const Expr& self) {
a = Normalize(a); a = Normalize(a);
b = Normalize(b); b = Normalize(b);
if (op->a.same_as(a) && op->b.same_as(b)) { if (op->a.same_as(a) && op->b.same_as(b)) {
return self; return GetRef<Expr>(op);
} else { } else {
return Mod::make(a, b); return Mod::make(a, b);
} }
} }
Expr CanonicalSimplifier::Impl:: Expr CanonicalSimplifier::Impl::
Mutate_(const FloorMod* op, const Expr& self) { VisitExpr_(const FloorMod* op) {
if (!IsIndexType(op->dtype)) { if (!IsIndexType(op->dtype)) {
return Rewriter::Mutate_(op, self); return Rewriter::VisitExpr_(op);
} }
// normalize // normalize
Expr a = this->CanonicalMutate(op->a); Expr a = this->CanonicalMutate(op->a);
...@@ -1017,7 +1016,7 @@ Mutate_(const FloorMod* op, const Expr& self) { ...@@ -1017,7 +1016,7 @@ Mutate_(const FloorMod* op, const Expr& self) {
a = Normalize(a); a = Normalize(a);
b = Normalize(b); b = Normalize(b);
if (op->a.same_as(a) && op->b.same_as(b)) { if (op->a.same_as(a) && op->b.same_as(b)) {
return self; return GetRef<Expr>(op);
} else { } else {
return FloorMod::make(a, b); return FloorMod::make(a, b);
} }
...@@ -1029,7 +1028,7 @@ SimplifyReduceCombiner(const Reduce* op) { ...@@ -1029,7 +1028,7 @@ SimplifyReduceCombiner(const Reduce* op) {
// First simplify the results // First simplify the results
Array<Expr> simplified_result; Array<Expr> simplified_result;
for (const auto& res : op->combiner->result) { for (const auto& res : op->combiner->result) {
Expr new_res = Mutate(res); Expr new_res = this->VisitExpr(res);
simplified_result.push_back(new_res); simplified_result.push_back(new_res);
} }
...@@ -1078,7 +1077,7 @@ SimplifyReduceCombiner(const Reduce* op) { ...@@ -1078,7 +1077,7 @@ SimplifyReduceCombiner(const Reduce* op) {
if (used[i]) { if (used[i]) {
// We simplify the result and identity, but not the source // We simplify the result and identity, but not the source
new_result.push_back(simplified_result[i]); new_result.push_back(simplified_result[i]);
new_identity.push_back(Mutate(op->combiner->identity_element[i])); new_identity.push_back(this->VisitExpr(op->combiner->identity_element[i]));
new_lhs.push_back(op->combiner->lhs[i]); new_lhs.push_back(op->combiner->lhs[i]);
new_rhs.push_back(op->combiner->rhs[i]); new_rhs.push_back(op->combiner->rhs[i]);
new_source.push_back(op->source[i]); new_source.push_back(op->source[i]);
...@@ -1095,9 +1094,9 @@ SimplifyReduceCombiner(const Reduce* op) { ...@@ -1095,9 +1094,9 @@ SimplifyReduceCombiner(const Reduce* op) {
} }
Expr CanonicalSimplifier::Impl:: Expr CanonicalSimplifier::Impl::
Mutate_(const Reduce* op, const Expr& self) { VisitExpr_(const Reduce* op) {
// Recursively call simplification when necessary. // Recursively call simplification when necessary.
Expr ret = RewriteSimplifier::Impl::Mutate_(op, self); Expr ret = RewriteSimplifier::Impl::VisitExpr_(op);
op = ret.as<Reduce>(); op = ret.as<Reduce>();
// already been simplified by const reduction axis removal // already been simplified by const reduction axis removal
if (op == nullptr) return ret; if (op == nullptr) return ret;
...@@ -1106,7 +1105,7 @@ Mutate_(const Reduce* op, const Expr& self) { ...@@ -1106,7 +1105,7 @@ Mutate_(const Reduce* op, const Expr& self) {
// assumption we would have to perform a single iteration of the loop, i.e. use // assumption we would have to perform a single iteration of the loop, i.e. use
// `(*op->combiner.get())(op->combineop->identity_element, op->source)[op->value_index]` // `(*op->combiner.get())(op->combineop->identity_element, op->source)[op->value_index]`
// instead of `op->source[op->value_index]`. The former may be more difficult to simplify. // instead of `op->source[op->value_index]`. The former may be more difficult to simplify.
return Mutate( return this->VisitExpr(
Select::make(op->condition, Select::make(op->condition,
op->source[op->value_index], op->source[op->value_index],
op->combiner->identity_element[op->value_index])); op->combiner->identity_element[op->value_index]));
......
...@@ -25,7 +25,6 @@ ...@@ -25,7 +25,6 @@
#define TVM_ARITHMETIC_CONST_FOLD_H_ #define TVM_ARITHMETIC_CONST_FOLD_H_
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/expr_operator.h> #include <tvm/expr_operator.h>
#include <algorithm> #include <algorithm>
#include <cmath> #include <cmath>
......
...@@ -23,7 +23,6 @@ ...@@ -23,7 +23,6 @@
*/ */
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h> #include <tvm/ir_functor_ext.h>
#include <tvm/arithmetic.h> #include <tvm/arithmetic.h>
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
*/ */
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h> #include <tvm/ir_functor_ext.h>
#include <tvm/tensor.h> #include <tvm/tensor.h>
#include <tvm/api_registry.h> #include <tvm/api_registry.h>
...@@ -36,13 +36,13 @@ namespace arith { ...@@ -36,13 +36,13 @@ namespace arith {
using namespace ir; using namespace ir;
// Find Read region of the tensor in the stmt. // Find Read region of the tensor in the stmt.
class FuncTouchedDomain final : public IRVisitor { class FuncTouchedDomain final : public StmtExprVisitor {
public: public:
FuncTouchedDomain(const Tensor &tensor, bool consider_calls, bool consider_provides) FuncTouchedDomain(const Tensor &tensor, bool consider_calls, bool consider_provides)
: tensor_(tensor), consider_calls_(consider_calls), consider_provides_(consider_provides) {} : tensor_(tensor), consider_calls_(consider_calls), consider_provides_(consider_provides) {}
Domain Find(const Stmt& stmt) { Domain Find(const Stmt& stmt) {
this->Visit(stmt); operator()(stmt);
Domain ret; Domain ret;
Range none; Range none;
for (size_t i = 0; i < bounds_.size(); ++i) { for (size_t i = 0; i < bounds_.size(); ++i) {
...@@ -51,49 +51,49 @@ class FuncTouchedDomain final : public IRVisitor { ...@@ -51,49 +51,49 @@ class FuncTouchedDomain final : public IRVisitor {
return ret; return ret;
} }
void Visit_(const For *op) final { void VisitStmt_(const For *op) final {
const Variable* var = op->loop_var.get(); const Variable* var = op->loop_var.get();
dom_map_[var] = IntSet::range( dom_map_[var] = IntSet::range(
Range::make_by_min_extent(op->min, op->extent)); Range::make_by_min_extent(op->min, op->extent));
IRVisitor::Visit_(op); StmtExprVisitor::VisitStmt_(op);
dom_map_.erase(var); dom_map_.erase(var);
} }
void Visit_(const LetStmt* op) final { void VisitStmt_(const LetStmt* op) final {
dom_map_[op->var.get()] = dom_map_[op->var.get()] =
arith::EvalSet(op->value, dom_map_); arith::EvalSet(op->value, dom_map_);
IRVisitor::Visit_(op); StmtExprVisitor::VisitStmt_(op);
dom_map_.erase(op->var.get()); dom_map_.erase(op->var.get());
} }
/* TODO: Thread extent unitest not generated.*/ /* TODO: Thread extent unitest not generated.*/
void Visit_(const AttrStmt* op) final { void VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::thread_extent) { if (op->attr_key == attr::thread_extent) {
const IterVarNode* thread_axis = op->node.as<IterVarNode>(); const IterVarNode* thread_axis = op->node.as<IterVarNode>();
CHECK(thread_axis); CHECK(thread_axis);
const Variable* var = thread_axis->var.get(); const Variable* var = thread_axis->var.get();
dom_map_[var] = IntSet::range(Range(make_zero(op->value.dtype()), op->value)); dom_map_[var] = IntSet::range(Range(make_zero(op->value.dtype()), op->value));
IRVisitor::Visit_(op); StmtExprVisitor::VisitStmt_(op);
dom_map_.erase(var); dom_map_.erase(var);
} else { } else {
IRVisitor::Visit_(op); StmtExprVisitor::VisitStmt_(op);
} }
} }
void Visit_(const Call* op) final { void VisitExpr_(const Call* op) final {
if (consider_calls_ && tensor_->op.same_as(op->func) if (consider_calls_ && tensor_->op.same_as(op->func)
&& tensor_->value_index == op->value_index) { && tensor_->value_index == op->value_index) {
Touch(op->args); Touch(op->args);
} }
IRVisitor::Visit_(op); StmtExprVisitor::VisitExpr_(op);
} }
void Visit_(const Provide* op) final { void VisitStmt_(const Provide* op) final {
if (consider_provides_ && tensor_->op.same_as(op->func) if (consider_provides_ && tensor_->op.same_as(op->func)
&& tensor_->value_index == op->value_index) { && tensor_->value_index == op->value_index) {
Touch(op->args); Touch(op->args);
} }
IRVisitor::Visit_(op); StmtExprVisitor::VisitStmt_(op);
} }
private: private:
......
...@@ -30,41 +30,44 @@ namespace arith { ...@@ -30,41 +30,44 @@ namespace arith {
using namespace ir; using namespace ir;
Stmt IRMutatorWithAnalyzer:: Stmt IRMutatorWithAnalyzer::
Mutate_(const For* op, const Stmt& s) { VisitStmt_(const For* op) {
analyzer_->Bind(op->loop_var, analyzer_->Bind(op->loop_var,
Range::make_by_min_extent(op->min, op->extent)); Range::make_by_min_extent(op->min, op->extent));
return IRMutator::Mutate_(op, s); return StmtExprMutator::VisitStmt_(op);
} }
Stmt IRMutatorWithAnalyzer:: Stmt IRMutatorWithAnalyzer::
Mutate_(const LetStmt* op, const Stmt& s) { VisitStmt_(const LetStmt* op) {
Expr value = this->Mutate(op->value); Expr value = this->VisitExpr(op->value);
if (!ir::HasSideEffect(value)) { if (!ir::HasSideEffect(value)) {
analyzer_->Bind(op->var, value); analyzer_->Bind(op->var, value);
} }
// We keep the let-binding here // We keep the let-binding here
// as sub-class may or maynot choose to replace it. // as sub-class may or maynot choose to replace it.
Stmt body = this->Mutate(op->body); Stmt body = this->VisitStmt(op->body);
if (value.same_as(op->value) && if (value.same_as(op->value) &&
body.same_as(op->body)) { body.same_as(op->body)) {
return s; return GetRef<Stmt>(op);
} else { } else {
return LetStmt::make(op->var, value, body); auto n = this->CopyOnWrite(op);
n->value = std::move(value);
n->body = std::move(body);
return Stmt(n);
} }
} }
Stmt IRMutatorWithAnalyzer:: Stmt IRMutatorWithAnalyzer::
Mutate_(const IfThenElse* op, const Stmt& s) { VisitStmt_(const IfThenElse* op) {
Expr condition = this->Mutate(op->condition); Expr condition = this->VisitExpr(op->condition);
Stmt then_case, else_case; Stmt then_case, else_case;
{ {
With<ConstraintContext> ctx(analyzer_, condition); With<ConstraintContext> ctx(analyzer_, condition);
then_case = this->Mutate(op->then_case); then_case = this->VisitStmt(op->then_case);
} }
if (op->else_case.defined()) { if (op->else_case.defined()) {
With<ConstraintContext> ctx(analyzer_, With<ConstraintContext> ctx(analyzer_,
analyzer_->rewrite_simplify(Not::make(condition))); analyzer_->rewrite_simplify(Not::make(condition)));
else_case = this->Mutate(op->else_case); else_case = this->VisitStmt(op->else_case);
} }
if (is_one(condition)) return then_case; if (is_one(condition)) return then_case;
if (is_zero(condition)) { if (is_zero(condition)) {
...@@ -77,57 +80,65 @@ Mutate_(const IfThenElse* op, const Stmt& s) { ...@@ -77,57 +80,65 @@ Mutate_(const IfThenElse* op, const Stmt& s) {
if (condition.same_as(op->condition) && if (condition.same_as(op->condition) &&
then_case.same_as(op->then_case) && then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) { else_case.same_as(op->else_case)) {
return s; return GetRef<Stmt>(op);
} else { } else {
return IfThenElse::make(condition, then_case, else_case); auto n = this->CopyOnWrite(op);
n->condition = std::move(condition);
n->then_case = std::move(then_case);
n->else_case = std::move(else_case);
return Stmt(n);
} }
} }
Stmt IRMutatorWithAnalyzer:: Stmt IRMutatorWithAnalyzer::
Mutate_(const AttrStmt* op, const Stmt& s) { VisitStmt_(const AttrStmt* op) {
if (op->attr_key == attr::thread_extent || if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::virtual_thread) { op->attr_key == attr::virtual_thread) {
IterVar iv = Downcast<IterVar>(op->node); IterVar iv = Downcast<IterVar>(op->node);
CHECK_NE(iv->thread_tag.length(), 0U); CHECK_NE(iv->thread_tag.length(), 0U);
analyzer_->Bind(iv->var, analyzer_->Bind(iv->var,
Range::make_by_min_extent(0, op->value)); Range::make_by_min_extent(0, op->value));
Stmt stmt = IRMutator::Mutate_(op, s); Stmt stmt = StmtExprMutator::VisitStmt_(op);
return stmt; return stmt;
} else { } else {
return IRMutator::Mutate_(op, s); return StmtExprMutator::VisitStmt_(op);
} }
} }
Stmt IRMutatorWithAnalyzer:: Stmt IRMutatorWithAnalyzer::
Mutate_(const AssertStmt* op, const Stmt& s) { VisitStmt_(const AssertStmt* op) {
Expr condition = this->Mutate(op->condition); Expr condition = this->VisitExpr(op->condition);
Expr message = this->Mutate(op->message); Expr message = this->VisitExpr(op->message);
With<ConstraintContext> ctx(analyzer_, condition); With<ConstraintContext> ctx(analyzer_, condition);
Stmt body = this->Mutate(op->body); Stmt body = this->VisitStmt(op->body);
if (condition.same_as(op->condition) && if (condition.same_as(op->condition) &&
message.same_as(op->message) && message.same_as(op->message) &&
body.same_as(op->body)) { body.same_as(op->body)) {
return s; return GetRef<Stmt>(op);
} else { } else {
return AssertStmt::make(condition, message, body); auto n = this->CopyOnWrite(op);
n->condition = std::move(condition);
n->message = std::move(message);
n->body = std::move(body);
return Stmt(n);
} }
} }
Expr IRMutatorWithAnalyzer:: Expr IRMutatorWithAnalyzer::
Mutate_(const Call* op, const Expr& self) { VisitExpr_(const Call* op) {
// add condition context to if_then_else // add condition context to if_then_else
if (op->is_intrinsic(ir::intrinsic::tvm_if_then_else)) { if (op->is_intrinsic(ir::intrinsic::tvm_if_then_else)) {
Expr cond = Mutate(op->args[0]); Expr cond = this->VisitExpr(op->args[0]);
Expr true_value, false_value; Expr true_value, false_value;
{ {
With<ConstraintContext> constraint(analyzer_, cond); With<ConstraintContext> constraint(analyzer_, cond);
true_value = Mutate(op->args[1]); true_value = this->VisitExpr(op->args[1]);
} }
{ {
With<ConstraintContext> constraint(analyzer_, With<ConstraintContext> constraint(analyzer_,
analyzer_->rewrite_simplify(Not::make(cond))); analyzer_->rewrite_simplify(Not::make(cond)));
false_value = Mutate(op->args[2]); false_value = this->VisitExpr(op->args[2]);
} }
if (is_zero(cond)) { if (is_zero(cond)) {
return false_value; return false_value;
...@@ -138,45 +149,45 @@ Mutate_(const Call* op, const Expr& self) { ...@@ -138,45 +149,45 @@ Mutate_(const Call* op, const Expr& self) {
if (cond.same_as(op->args[0]) && if (cond.same_as(op->args[0]) &&
true_value.same_as(op->args[1]) && true_value.same_as(op->args[1]) &&
false_value.same_as(op->args[2])) { false_value.same_as(op->args[2])) {
return self; return GetRef<Expr>(op);
} else { } else {
return Call::make(op->dtype, op->name, return Call::make(op->dtype, op->name,
{cond, true_value, false_value}, {cond, true_value, false_value},
op->call_type); op->call_type);
} }
} }
return IRMutator::Mutate_(op, self); return StmtExprMutator::VisitExpr_(op);
} }
Expr IRMutatorWithAnalyzer:: Expr IRMutatorWithAnalyzer::
Mutate_(const Let* op, const Expr& self) { VisitExpr_(const Let* op) {
Expr value = this->Mutate(op->value); Expr value = this->VisitExpr(op->value);
if (!ir::HasSideEffect(value)) { if (!ir::HasSideEffect(value)) {
analyzer_->Bind(op->var, value); analyzer_->Bind(op->var, value);
} }
// We keep the let-binding here // We keep the let-binding here
// as sub-class may or maynot choose to replace it. // as sub-class may or maynot choose to replace it.
Expr body = this->Mutate(op->body); Expr body = this->VisitExpr(op->body);
if (value.same_as(op->value) && if (value.same_as(op->value) &&
body.same_as(op->body)) { body.same_as(op->body)) {
return self; return GetRef<Expr>(op);
} else { } else {
return Let::make(op->var, value, body); return Let::make(op->var, value, body);
} }
} }
Expr IRMutatorWithAnalyzer:: Expr IRMutatorWithAnalyzer::
Mutate_(const Select* op, const Expr& self) { VisitExpr_(const Select* op) {
Expr cond = Mutate(op->condition); Expr cond = this->VisitExpr(op->condition);
Expr true_value, false_value; Expr true_value, false_value;
{ {
With<ConstraintContext> constraint(analyzer_, cond); With<ConstraintContext> constraint(analyzer_, cond);
true_value = Mutate(op->true_value); true_value = VisitExpr(op->true_value);
} }
{ {
With<ConstraintContext> constraint(analyzer_, With<ConstraintContext> constraint(analyzer_,
analyzer_->rewrite_simplify(Not::make(cond))); analyzer_->rewrite_simplify(Not::make(cond)));
false_value = Mutate(op->false_value); false_value = VisitExpr(op->false_value);
} }
if (is_zero(cond)) { if (is_zero(cond)) {
return false_value; return false_value;
...@@ -188,20 +199,20 @@ Mutate_(const Select* op, const Expr& self) { ...@@ -188,20 +199,20 @@ Mutate_(const Select* op, const Expr& self) {
if (cond.same_as(op->condition) && if (cond.same_as(op->condition) &&
true_value.same_as(op->true_value) && true_value.same_as(op->true_value) &&
false_value.same_as(op->false_value)) { false_value.same_as(op->false_value)) {
return self; return GetRef<Expr>(op);
} else { } else {
return Select::make(cond, true_value, false_value); return Select::make(cond, true_value, false_value);
} }
} }
Expr IRMutatorWithAnalyzer:: Expr IRMutatorWithAnalyzer::
Mutate_(const Reduce* op, const Expr& self) { VisitExpr_(const Reduce* op) {
// Setup the domain information before simplification. // Setup the domain information before simplification.
for (const IterVar& iv : op->axis) { for (const IterVar& iv : op->axis) {
analyzer_->Bind(iv->var, iv->dom); analyzer_->Bind(iv->var, iv->dom);
} }
// Recursively call simplification when necessary. // Recursively call simplification when necessary.
return IRMutator::Mutate_(op, self); return StmtExprMutator::VisitExpr_(op);
} }
} // namespace arith } // namespace arith
......
...@@ -24,9 +24,9 @@ ...@@ -24,9 +24,9 @@
#ifndef TVM_ARITHMETIC_IR_MUTATOR_WITH_ANALYZER_H_ #ifndef TVM_ARITHMETIC_IR_MUTATOR_WITH_ANALYZER_H_
#define TVM_ARITHMETIC_IR_MUTATOR_WITH_ANALYZER_H_ #define TVM_ARITHMETIC_IR_MUTATOR_WITH_ANALYZER_H_
#include <tvm/ir_mutator.h> #include <tvm/ir_functor_ext.h>
#include <tvm/arithmetic.h> #include <tvm/arithmetic.h>
#include <utility>
namespace tvm { namespace tvm {
namespace arith { namespace arith {
...@@ -40,23 +40,24 @@ namespace arith { ...@@ -40,23 +40,24 @@ namespace arith {
* *
* \sa src/arithmetic/ir_mutator_with_analyzer.cc * \sa src/arithmetic/ir_mutator_with_analyzer.cc
*/ */
class IRMutatorWithAnalyzer : public ir::IRMutator { class IRMutatorWithAnalyzer : public ir::StmtExprMutator {
public: public:
explicit IRMutatorWithAnalyzer(Analyzer* analyzer) explicit IRMutatorWithAnalyzer(Analyzer* analyzer)
: analyzer_(analyzer) {} : analyzer_(analyzer) {}
using IRMutator::Mutate_; using StmtExprMutator::VisitStmt_;
using StmtExprMutator::VisitExpr_;
// override functions that need to populate the context information. // override functions that need to populate the context information.
Stmt Mutate_(const ir::For* op, const Stmt& self) override; Stmt VisitStmt_(const ir::For* op) override;
Stmt Mutate_(const ir::LetStmt* op, const Stmt& self) override; Stmt VisitStmt_(const ir::LetStmt* op) override;
Stmt Mutate_(const ir::IfThenElse* op, const Stmt& self) override; Stmt VisitStmt_(const ir::IfThenElse* op) override;
Stmt Mutate_(const ir::AttrStmt* op, const Stmt& self) override; Stmt VisitStmt_(const ir::AttrStmt* op) override;
Stmt Mutate_(const ir::AssertStmt* op, const Stmt& self) override; Stmt VisitStmt_(const ir::AssertStmt* op) override;
Expr Mutate_(const ir::Let* op, const Expr& self) override; Expr VisitExpr_(const ir::Let* op) override;
Expr Mutate_(const ir::Select* op, const Expr& self) override; Expr VisitExpr_(const ir::Select* op) override;
Expr Mutate_(const ir::Call* op, const Expr& self) override; Expr VisitExpr_(const ir::Call* op) override;
Expr Mutate_(const ir::Reduce* op, const Expr& self) override; Expr VisitExpr_(const ir::Reduce* op) override;
protected: protected:
/*! \brief internal analyzer field. */ /*! \brief internal analyzer field. */
......
...@@ -27,43 +27,43 @@ ...@@ -27,43 +27,43 @@
#include <tvm/arithmetic.h> #include <tvm/arithmetic.h>
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_visitor.h> #include <tvm/ir_functor_ext.h>
namespace tvm { namespace tvm {
namespace ir { namespace ir {
class IRVisitorWithAnalyzer final : public IRVisitor { class IRVisitorWithAnalyzer final : public StmtExprVisitor {
public: public:
Expr Simplify(const Expr& expr) { Expr Simplify(const Expr& expr) {
return analyzer_.Simplify(expr); return analyzer_.Simplify(expr);
} }
void Visit_(const For* op) { void VisitStmt_(const For* op) {
analyzer_.Bind(op->loop_var, analyzer_.Bind(op->loop_var,
Range::make_by_min_extent(op->min, op->extent)); Range::make_by_min_extent(op->min, op->extent));
return IRVisitor::Visit_(op); return StmtExprVisitor::VisitStmt_(op);
} }
void Visit_(const AttrStmt* op) { void VisitStmt_(const AttrStmt* op) {
if (op->attr_key == attr::thread_extent || if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::virtual_thread) { op->attr_key == attr::virtual_thread) {
IterVar iv = Downcast<IterVar>(op->node); IterVar iv = Downcast<IterVar>(op->node);
CHECK_NE(iv->thread_tag.length(), 0U); CHECK_NE(iv->thread_tag.length(), 0U);
analyzer_.Bind(iv->var, analyzer_.Bind(iv->var,
Range::make_by_min_extent(0, op->value)); Range::make_by_min_extent(0, op->value));
IRVisitor::Visit_(op); StmtExprVisitor::VisitStmt_(op);
} else { } else {
IRVisitor::Visit_(op); StmtExprVisitor::VisitStmt_(op);
} }
} }
void Visit_(const Reduce* op) { void VisitExpr_(const Reduce* op) {
// Setup the domain information before simplification. // Setup the domain information before simplification.
for (const IterVar& iv : op->axis) { for (const IterVar& iv : op->axis) {
analyzer_.Bind(iv->var, iv->dom); analyzer_.Bind(iv->var, iv->dom);
} }
// Recursively call simplification when necessary. // Recursively call simplification when necessary.
IRVisitor::Visit_(op); StmtExprVisitor::VisitExpr_(op);
} }
protected: protected:
......
...@@ -24,7 +24,6 @@ ...@@ -24,7 +24,6 @@
// Acknowledgement: Most rewrite-rules are from Halide. // Acknowledgement: Most rewrite-rules are from Halide.
#include <tvm/arithmetic.h> #include <tvm/arithmetic.h>
#include <tvm/expr_operator.h> #include <tvm/expr_operator.h>
#include <tvm/ir_mutator.h>
#include <algorithm> #include <algorithm>
#include "const_fold.h" #include "const_fold.h"
#include "pattern_match.h" #include "pattern_match.h"
...@@ -69,7 +68,7 @@ using namespace ir; ...@@ -69,7 +68,7 @@ using namespace ir;
// try to prove x equals val // try to prove x equals val
RewriteSimplifier::Impl::CompareResult RewriteSimplifier::Impl:: RewriteSimplifier::Impl::CompareResult RewriteSimplifier::Impl::
TryCompare(const Expr& x, int64_t val) { TryCompare(const Expr& x, int64_t val) {
Expr diff = Mutate(x); Expr diff = this->VisitExpr(x);
if (const auto* ptr = diff.as<IntImm>()) { if (const auto* ptr = diff.as<IntImm>()) {
if (ptr->value == val) { if (ptr->value == val) {
return kEQ; return kEQ;
...@@ -117,8 +116,8 @@ Update(const Var& var, const Expr& info, bool override) { ...@@ -117,8 +116,8 @@ Update(const Var& var, const Expr& info, bool override) {
} }
Expr RewriteSimplifier::Impl:: Expr RewriteSimplifier::Impl::
Mutate_(const Add* op, const Expr& self) { VisitExpr_(const Add* op) {
Expr ret = IRMutator::Mutate_(op, self); Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<Add>(); op = ret.as<Add>();
Expr const_res = TryConstFold<Add>(op->a, op->b); Expr const_res = TryConstFold<Add>(op->a, op->b);
if (const_res.defined()) return const_res; if (const_res.defined()) return const_res;
...@@ -232,8 +231,8 @@ std::function<void()> RewriteSimplifier::Impl::EnterConstraint(const Expr& const ...@@ -232,8 +231,8 @@ std::function<void()> RewriteSimplifier::Impl::EnterConstraint(const Expr& const
} }
Expr RewriteSimplifier::Impl:: Expr RewriteSimplifier::Impl::
Mutate_(const Sub* op, const Expr& self) { VisitExpr_(const Sub* op) {
Expr ret = IRMutator::Mutate_(op, self); Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<Sub>(); op = ret.as<Sub>();
Expr const_res = TryConstFold<Sub>(op->a, op->b); Expr const_res = TryConstFold<Sub>(op->a, op->b);
if (const_res.defined()) return const_res; if (const_res.defined()) return const_res;
...@@ -431,8 +430,8 @@ Mutate_(const Sub* op, const Expr& self) { ...@@ -431,8 +430,8 @@ Mutate_(const Sub* op, const Expr& self) {
} }
Expr RewriteSimplifier::Impl:: Expr RewriteSimplifier::Impl::
Mutate_(const Mul* op, const Expr& self) { VisitExpr_(const Mul* op) {
Expr ret = IRMutator::Mutate_(op, self); Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<Mul>(); op = ret.as<Mul>();
Expr const_res = TryConstFold<Mul>(op->a, op->b); Expr const_res = TryConstFold<Mul>(op->a, op->b);
if (const_res.defined()) return const_res; if (const_res.defined()) return const_res;
...@@ -470,8 +469,8 @@ Mutate_(const Mul* op, const Expr& self) { ...@@ -470,8 +469,8 @@ Mutate_(const Mul* op, const Expr& self) {
} }
Expr RewriteSimplifier::Impl:: Expr RewriteSimplifier::Impl::
Mutate_(const Div* op, const Expr& self) { VisitExpr_(const Div* op) {
Expr ret = IRMutator::Mutate_(op, self); Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<Div>(); op = ret.as<Div>();
Expr const_res = TryConstFold<Div>(op->a, op->b); Expr const_res = TryConstFold<Div>(op->a, op->b);
if (const_res.defined()) return const_res; if (const_res.defined()) return const_res;
...@@ -692,8 +691,8 @@ Mutate_(const Div* op, const Expr& self) { ...@@ -692,8 +691,8 @@ Mutate_(const Div* op, const Expr& self) {
} }
Expr RewriteSimplifier::Impl:: Expr RewriteSimplifier::Impl::
Mutate_(const Mod* op, const Expr& self) { VisitExpr_(const Mod* op) {
Expr ret = IRMutator::Mutate_(op, self); Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<Mod>(); op = ret.as<Mod>();
Expr const_res = TryConstFold<Mod>(op->a, op->b); Expr const_res = TryConstFold<Mod>(op->a, op->b);
if (const_res.defined()) return const_res; if (const_res.defined()) return const_res;
...@@ -782,8 +781,8 @@ Mutate_(const Mod* op, const Expr& self) { ...@@ -782,8 +781,8 @@ Mutate_(const Mod* op, const Expr& self) {
} }
Expr RewriteSimplifier::Impl:: Expr RewriteSimplifier::Impl::
Mutate_(const FloorDiv* op, const Expr& self) { VisitExpr_(const FloorDiv* op) {
Expr ret = IRMutator::Mutate_(op, self); Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<FloorDiv>(); op = ret.as<FloorDiv>();
Expr const_res = TryConstFold<FloorDiv>(op->a, op->b); Expr const_res = TryConstFold<FloorDiv>(op->a, op->b);
if (const_res.defined()) return const_res; if (const_res.defined()) return const_res;
...@@ -926,8 +925,8 @@ Mutate_(const FloorDiv* op, const Expr& self) { ...@@ -926,8 +925,8 @@ Mutate_(const FloorDiv* op, const Expr& self) {
} }
Expr RewriteSimplifier::Impl:: Expr RewriteSimplifier::Impl::
Mutate_(const FloorMod* op, const Expr& self) { VisitExpr_(const FloorMod* op) {
Expr ret = IRMutator::Mutate_(op, self); Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<FloorMod>(); op = ret.as<FloorMod>();
Expr const_res = TryConstFold<FloorMod>(op->a, op->b); Expr const_res = TryConstFold<FloorMod>(op->a, op->b);
if (const_res.defined()) return const_res; if (const_res.defined()) return const_res;
...@@ -996,8 +995,8 @@ Mutate_(const FloorMod* op, const Expr& self) { ...@@ -996,8 +995,8 @@ Mutate_(const FloorMod* op, const Expr& self) {
} }
Expr RewriteSimplifier::Impl:: Expr RewriteSimplifier::Impl::
Mutate_(const Min* op, const Expr& self) { VisitExpr_(const Min* op) {
Expr ret = IRMutator::Mutate_(op, self); Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<Min>(); op = ret.as<Min>();
Expr const_res = TryConstFold<Min>(op->a, op->b); Expr const_res = TryConstFold<Min>(op->a, op->b);
if (const_res.defined()) return const_res; if (const_res.defined()) return const_res;
...@@ -1181,8 +1180,8 @@ Mutate_(const Min* op, const Expr& self) { ...@@ -1181,8 +1180,8 @@ Mutate_(const Min* op, const Expr& self) {
} }
Expr RewriteSimplifier::Impl:: Expr RewriteSimplifier::Impl::
Mutate_(const Max* op, const Expr& self) { VisitExpr_(const Max* op) {
Expr ret = IRMutator::Mutate_(op, self); Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<Max>(); op = ret.as<Max>();
Expr const_res = TryConstFold<Max>(op->a, op->b); Expr const_res = TryConstFold<Max>(op->a, op->b);
if (const_res.defined()) return const_res; if (const_res.defined()) return const_res;
...@@ -1354,8 +1353,8 @@ Mutate_(const Max* op, const Expr& self) { ...@@ -1354,8 +1353,8 @@ Mutate_(const Max* op, const Expr& self) {
} }
Expr RewriteSimplifier::Impl:: Expr RewriteSimplifier::Impl::
Mutate_(const EQ* op, const Expr& self) { VisitExpr_(const EQ* op) {
Expr ret = IRMutator::Mutate_(op, self); Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<EQ>(); op = ret.as<EQ>();
Expr const_res = TryConstFold<EQ>(op->a, op->b); Expr const_res = TryConstFold<EQ>(op->a, op->b);
if (const_res.defined()) return const_res; if (const_res.defined()) return const_res;
...@@ -1388,28 +1387,28 @@ Mutate_(const EQ* op, const Expr& self) { ...@@ -1388,28 +1387,28 @@ Mutate_(const EQ* op, const Expr& self) {
} }
Expr RewriteSimplifier::Impl:: Expr RewriteSimplifier::Impl::
Mutate_(const NE* op, const Expr& self) { VisitExpr_(const NE* op) {
return Mutate(Not::make(op->a == op->b)); return this->VisitExpr(Not::make(op->a == op->b));
} }
Expr RewriteSimplifier::Impl:: Expr RewriteSimplifier::Impl::
Mutate_(const LE* op, const Expr& self) { VisitExpr_(const LE* op) {
return Mutate(Not::make(op->b < op->a)); return this->VisitExpr(Not::make(op->b < op->a));
} }
Expr RewriteSimplifier::Impl:: Expr RewriteSimplifier::Impl::
Mutate_(const GT* op, const Expr& self) { VisitExpr_(const GT* op) {
return Mutate(op->b < op->a); return this->VisitExpr(op->b < op->a);
} }
Expr RewriteSimplifier::Impl:: Expr RewriteSimplifier::Impl::
Mutate_(const GE* op, const Expr& self) { VisitExpr_(const GE* op) {
return Mutate(Not::make(op->a < op->b)); return this->VisitExpr(Not::make(op->a < op->b));
} }
Expr RewriteSimplifier::Impl:: Expr RewriteSimplifier::Impl::
Mutate_(const LT* op, const Expr& self) { VisitExpr_(const LT* op) {
Expr ret = IRMutator::Mutate_(op, self); Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<LT>(); op = ret.as<LT>();
Expr const_res = TryConstFold<LT>(op->a, op->b); Expr const_res = TryConstFold<LT>(op->a, op->b);
if (const_res.defined()) return const_res; if (const_res.defined()) return const_res;
...@@ -1564,8 +1563,8 @@ Mutate_(const LT* op, const Expr& self) { ...@@ -1564,8 +1563,8 @@ Mutate_(const LT* op, const Expr& self) {
} }
Expr RewriteSimplifier::Impl:: Expr RewriteSimplifier::Impl::
Mutate_(const Not* op, const Expr& self) { VisitExpr_(const Not* op) {
Expr ret = IRMutator::Mutate_(op, self); Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<Not>(); op = ret.as<Not>();
Expr const_res = TryConstFold<Not>(op->a); Expr const_res = TryConstFold<Not>(op->a);
if (const_res.defined()) return const_res; if (const_res.defined()) return const_res;
...@@ -1589,8 +1588,8 @@ Mutate_(const Not* op, const Expr& self) { ...@@ -1589,8 +1588,8 @@ Mutate_(const Not* op, const Expr& self) {
} }
Expr RewriteSimplifier::Impl:: Expr RewriteSimplifier::Impl::
Mutate_(const And* op, const Expr& self) { VisitExpr_(const And* op) {
Expr ret = IRMutator::Mutate_(op, self); Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<And>(); op = ret.as<And>();
Expr const_res = TryConstFold<And>(op->a, op->b); Expr const_res = TryConstFold<And>(op->a, op->b);
if (const_res.defined()) return const_res; if (const_res.defined()) return const_res;
...@@ -1638,8 +1637,8 @@ Mutate_(const And* op, const Expr& self) { ...@@ -1638,8 +1637,8 @@ Mutate_(const And* op, const Expr& self) {
} }
Expr RewriteSimplifier::Impl:: Expr RewriteSimplifier::Impl::
Mutate_(const Or* op, const Expr& self) { VisitExpr_(const Or* op) {
Expr ret = IRMutator::Mutate_(op, self); Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<Or>(); op = ret.as<Or>();
Expr const_res = TryConstFold<Or>(op->a, op->b); Expr const_res = TryConstFold<Or>(op->a, op->b);
if (const_res.defined()) return const_res; if (const_res.defined()) return const_res;
...@@ -1688,8 +1687,8 @@ Mutate_(const Or* op, const Expr& self) { ...@@ -1688,8 +1687,8 @@ Mutate_(const Or* op, const Expr& self) {
} }
Expr RewriteSimplifier::Impl:: Expr RewriteSimplifier::Impl::
Mutate_(const Select* op, const Expr& self) { VisitExpr_(const Select* op) {
Expr ret = IRMutatorWithAnalyzer::Mutate_(op, self); Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<Select>(); op = ret.as<Select>();
if (op == nullptr) return ret; if (op == nullptr) return ret;
// Pattern var to match any expression // Pattern var to match any expression
...@@ -1699,9 +1698,9 @@ Mutate_(const Select* op, const Expr& self) { ...@@ -1699,9 +1698,9 @@ Mutate_(const Select* op, const Expr& self) {
} }
Expr RewriteSimplifier::Impl:: Expr RewriteSimplifier::Impl::
Mutate_(const Call* op, const Expr& self) { VisitExpr_(const Call* op) {
// add condition context to if_then_else // add condition context to if_then_else
Expr ret = IRMutatorWithAnalyzer::Mutate_(op, self); Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<Call>(); op = ret.as<Call>();
if (op == nullptr) return ret; if (op == nullptr) return ret;
if (op->is_intrinsic(Call::likely) && is_const(op->args[0])) { if (op->is_intrinsic(Call::likely) && is_const(op->args[0])) {
...@@ -1729,35 +1728,35 @@ Mutate_(const Call* op, const Expr& self) { ...@@ -1729,35 +1728,35 @@ Mutate_(const Call* op, const Expr& self) {
} }
Expr RewriteSimplifier::Impl:: Expr RewriteSimplifier::Impl::
Mutate_(const Variable* op, const Expr& self) { VisitExpr_(const Variable* op) {
Var var = GetRef<Var>(op); Var var = GetRef<Var>(op);
auto it = var_map_.find(var); auto it = var_map_.find(var);
if (it != var_map_.end()) { if (it != var_map_.end()) {
return it->second; return it->second;
} }
return self; return GetRef<Expr>(op);
} }
Expr RewriteSimplifier::Impl:: Expr RewriteSimplifier::Impl::
Mutate_(const Cast* op, const Expr& self) { VisitExpr_(const Cast* op) {
Expr ret = IRMutator::Mutate_(op, self); Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<Cast>(); op = ret.as<Cast>();
return cast(op->dtype, op->value); return cast(op->dtype, op->value);
} }
Expr RewriteSimplifier::Impl:: Expr RewriteSimplifier::Impl::
Mutate_(const Let* op, const Expr& self) { VisitExpr_(const Let* op) {
Expr value = this->Mutate(op->value); Expr value = this->VisitExpr(op->value);
if (!ir::HasSideEffect(value)) { if (!ir::HasSideEffect(value)) {
// it is fine to discard the let binding // it is fine to discard the let binding
// because the value will always be inlined in the simplifier. // because the value will always be inlined in the simplifier.
analyzer_->Bind(op->var, value); analyzer_->Bind(op->var, value);
return this->Mutate(op->body); return this->VisitExpr(op->body);
} }
Expr body = this->Mutate(op->body); Expr body = this->VisitExpr(op->body);
if (value.same_as(op->value) && if (value.same_as(op->value) &&
body.same_as(op->body)) { body.same_as(op->body)) {
return self; return GetRef<Expr>(op);
} else { } else {
return Let::make(op->var, value, body); return Let::make(op->var, value, body);
} }
...@@ -1768,7 +1767,7 @@ Expr RewriteSimplifier::operator()(const Expr& expr) { ...@@ -1768,7 +1767,7 @@ Expr RewriteSimplifier::operator()(const Expr& expr) {
Expr res = expr; Expr res = expr;
int max_iter = 2; int max_iter = 2;
for (int i = 0; i < max_iter; ++i) { for (int i = 0; i < max_iter; ++i) {
Expr new_expr = impl_->Mutate(res); Expr new_expr = impl_->operator()(res);
if (new_expr.same_as(res)) return res; if (new_expr.same_as(res)) return res;
res = new_expr; res = new_expr;
} }
......
...@@ -26,7 +26,6 @@ ...@@ -26,7 +26,6 @@
#include <tvm/arithmetic.h> #include <tvm/arithmetic.h>
#include <tvm/expr_operator.h> #include <tvm/expr_operator.h>
#include <tvm/ir_mutator.h>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "const_fold.h" #include "const_fold.h"
...@@ -45,35 +44,35 @@ using namespace ir; ...@@ -45,35 +44,35 @@ using namespace ir;
*/ */
class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
public: public:
using IRMutatorWithAnalyzer::Mutate_; using IRMutatorWithAnalyzer::VisitExpr_;
explicit Impl(Analyzer* parent) explicit Impl(Analyzer* parent)
: IRMutatorWithAnalyzer(parent) {} : IRMutatorWithAnalyzer(parent) {}
void Update(const Var& var, const Expr& info, bool override); void Update(const Var& var, const Expr& info, bool override_info);
Expr Mutate_(const Add* op, const Expr& self) override; Expr VisitExpr_(const Add* op) override;
Expr Mutate_(const Sub* op, const Expr& self) override; Expr VisitExpr_(const Sub* op) override;
Expr Mutate_(const Mul* op, const Expr& self) override; Expr VisitExpr_(const Mul* op) override;
Expr Mutate_(const Div* op, const Expr& self) override; Expr VisitExpr_(const Div* op) override;
Expr Mutate_(const Mod* op, const Expr& self) override; Expr VisitExpr_(const Mod* op) override;
Expr Mutate_(const FloorDiv* op, const Expr& self) override; Expr VisitExpr_(const FloorDiv* op) override;
Expr Mutate_(const FloorMod* op, const Expr& self) override; Expr VisitExpr_(const FloorMod* op) override;
Expr Mutate_(const Min* op, const Expr& self) override; Expr VisitExpr_(const Min* op) override;
Expr Mutate_(const Max* op, const Expr& self) override; Expr VisitExpr_(const Max* op) override;
Expr Mutate_(const EQ* op, const Expr& self) override; Expr VisitExpr_(const EQ* op) override;
Expr Mutate_(const NE* op, const Expr& self) override; Expr VisitExpr_(const NE* op) override;
Expr Mutate_(const LT* op, const Expr& self) override; Expr VisitExpr_(const LT* op) override;
Expr Mutate_(const LE* op, const Expr& self) override; Expr VisitExpr_(const LE* op) override;
Expr Mutate_(const GT* op, const Expr& self) override; Expr VisitExpr_(const GT* op) override;
Expr Mutate_(const GE* op, const Expr& self) override; Expr VisitExpr_(const GE* op) override;
Expr Mutate_(const And* op, const Expr& self) override; Expr VisitExpr_(const And* op) override;
Expr Mutate_(const Or* op, const Expr& self) override; Expr VisitExpr_(const Or* op) override;
Expr Mutate_(const Not* op, const Expr& self) override; Expr VisitExpr_(const Not* op) override;
Expr Mutate_(const Select* op, const Expr& self) override; Expr VisitExpr_(const Select* op) override;
Expr Mutate_(const Call* op, const Expr& self) override; Expr VisitExpr_(const Call* op) override;
Expr Mutate_(const Variable* op, const Expr& self) override; Expr VisitExpr_(const Variable* op) override;
Expr Mutate_(const Cast* op, const Expr& self) override; Expr VisitExpr_(const Cast* op) override;
Expr Mutate_(const Let* op, const Expr& self) override; Expr VisitExpr_(const Let* op) override;
std::function<void()> EnterConstraint(const Expr& constraint); std::function<void()> EnterConstraint(const Expr& constraint);
...@@ -123,7 +122,7 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { ...@@ -123,7 +122,7 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
Expr RecursiveRewrite(const Expr& x) { Expr RecursiveRewrite(const Expr& x) {
if (recur_depth_ >= kMaxRecurDepth) return x; if (recur_depth_ >= kMaxRecurDepth) return x;
++recur_depth_; ++recur_depth_;
Expr res = Mutate(x); Expr res = this->VisitExpr(x);
--recur_depth_; --recur_depth_;
return res; return res;
} }
......
...@@ -24,7 +24,6 @@ ...@@ -24,7 +24,6 @@
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/arithmetic.h> #include <tvm/arithmetic.h>
#include <tvm/ir_mutator.h>
#include <tvm/expr_operator.h> #include <tvm/expr_operator.h>
#include <tvm/arithmetic.h> #include <tvm/arithmetic.h>
#include "ir_mutator_with_analyzer.h" #include "ir_mutator_with_analyzer.h"
...@@ -40,44 +39,47 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { ...@@ -40,44 +39,47 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
: IRMutatorWithAnalyzer(analyzer) {} : IRMutatorWithAnalyzer(analyzer) {}
using Parent = IRMutatorWithAnalyzer; using Parent = IRMutatorWithAnalyzer;
using Parent::Mutate; using Parent::VisitStmt;
using Parent::Mutate_; using Parent::VisitStmt_;
Expr Mutate(Expr expr) final { Expr VisitExpr(const Expr& expr) final {
return analyzer_->Simplify(expr); return analyzer_->Simplify(expr);
} }
Stmt Simplify(Stmt stmt) { Stmt Simplify(Stmt stmt) {
return Mutate(stmt); return operator()(std::move(stmt));
} }
Stmt Mutate_(const For* op, const Stmt& s) final { Stmt VisitStmt_(const For* op) final {
analyzer_->Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent)); analyzer_->Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent));
With<ConstraintContext> ctx1(analyzer_, op->loop_var >= op->min); With<ConstraintContext> ctx1(analyzer_, op->loop_var >= op->min);
With<ConstraintContext> ctx2(analyzer_, op->loop_var < op->min + op->extent); With<ConstraintContext> ctx2(analyzer_, op->loop_var < op->min + op->extent);
return IRMutator::Mutate_(op, s); return Parent::VisitStmt_(op);
} }
Stmt Mutate_(const LetStmt* op, const Stmt& s) { Stmt VisitStmt_(const LetStmt* op) {
Expr value = this->Mutate(op->value); Expr value = this->VisitExpr(op->value);
if (!ir::HasSideEffect(value)) { if (!ir::HasSideEffect(value)) {
// it is fine to discard the let binding // it is fine to discard the let binding
// because the call to simplify will always inline the var. // because the call to simplify will always inline the var.
analyzer_->Bind(op->var, value); analyzer_->Bind(op->var, value);
return Mutate(op->body); return this->VisitStmt(op->body);
} }
Stmt body = this->Mutate(op->body); Stmt body = this->VisitStmt(op->body);
if (value.same_as(op->value) && if (value.same_as(op->value) &&
body.same_as(op->body)) { body.same_as(op->body)) {
return s; return GetRef<Stmt>(op);
} else { } else {
return LetStmt::make(op->var, value, body); auto n = this->CopyOnWrite(op);
n->value = std::move(value);
n->body = std::move(body);
return Stmt(n);
} }
} }
// eliminate useless stores // eliminate useless stores
Stmt Mutate_(const Store* op, const Stmt& s) final { Stmt VisitStmt_(const Store* op) final {
Stmt stmt = IRMutator::Mutate_(op, s); Stmt stmt = Parent::VisitStmt_(op);
op = stmt.as<Store>(); op = stmt.as<Store>();
if (const Load* load = op->value.as<Load>()) { if (const Load* load = op->value.as<Load>()) {
if (load->buffer_var.same_as(op->buffer_var) && if (load->buffer_var.same_as(op->buffer_var) &&
...@@ -85,7 +87,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { ...@@ -85,7 +87,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
return Evaluate::make(0); return Evaluate::make(0);
} }
} }
return stmt; return GetRef<Stmt>(op);
} }
}; };
...@@ -98,7 +100,7 @@ Stmt CanonicalSimplify(Stmt stmt, Map<Var, Range> vrange) { ...@@ -98,7 +100,7 @@ Stmt CanonicalSimplify(Stmt stmt, Map<Var, Range> vrange) {
for (auto kv : vrange) { for (auto kv : vrange) {
analyzer.Bind(kv.first, kv.second); analyzer.Bind(kv.first, kv.second);
} }
return arith::StmtSimplifier(&analyzer).Simplify(stmt); return arith::StmtSimplifier(&analyzer).Simplify(std::move(stmt));
} }
Expr CanonicalSimplify(Expr expr, Map<Var, Range> vrange) { Expr CanonicalSimplify(Expr expr, Map<Var, Range> vrange) {
...@@ -119,7 +121,7 @@ Expr Simplify(Expr expr, Map<Var, Range> vrange) { ...@@ -119,7 +121,7 @@ Expr Simplify(Expr expr, Map<Var, Range> vrange) {
} }
Stmt Simplify(Stmt stmt, Map<Var, Range> vrange) { Stmt Simplify(Stmt stmt, Map<Var, Range> vrange) {
return CanonicalSimplify(stmt, vrange); return CanonicalSimplify(std::move(stmt), vrange);
} }
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -29,7 +29,7 @@ namespace tvm { ...@@ -29,7 +29,7 @@ namespace tvm {
namespace autotvm { namespace autotvm {
// for loop // for loop
void FeatureVisitor::Visit_(const For *op) { void FeatureVisitor::VisitStmt_(const For* op) {
const auto *extent = op->extent.as<IntImm>(); const auto *extent = op->extent.as<IntImm>();
int64_t loop_extent = -1; int64_t loop_extent = -1;
if (extent != nullptr) if (extent != nullptr)
...@@ -51,13 +51,13 @@ void FeatureVisitor::Visit_(const For *op) { ...@@ -51,13 +51,13 @@ void FeatureVisitor::Visit_(const For *op) {
} }
if (EnterItervar_(op->loop_var, loop_extent, ann)) { if (EnterItervar_(op->loop_var, loop_extent, ann)) {
IRVisitor::Visit_(op); StmtExprVisitor::VisitStmt_(op);
ExitItervar_(); ExitItervar_();
} }
} }
// parallel axis, virtual thread // parallel axis, virtual thread
void FeatureVisitor::Visit_(const AttrStmt *op) { void FeatureVisitor::VisitStmt_(const AttrStmt* op) {
if (op->attr_key == attr::thread_extent || if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::virtual_thread) { op->attr_key == attr::virtual_thread) {
VarExpr var = op->node.as<tvm::IterVarNode>()->var; VarExpr var = op->node.as<tvm::IterVarNode>()->var;
...@@ -86,24 +86,24 @@ void FeatureVisitor::Visit_(const AttrStmt *op) { ...@@ -86,24 +86,24 @@ void FeatureVisitor::Visit_(const AttrStmt *op) {
} }
if (EnterItervar_(var, extent->value, ann)) { if (EnterItervar_(var, extent->value, ann)) {
IRVisitor::Visit_(op); StmtExprVisitor::VisitStmt_(op);
ExitItervar_(); ExitItervar_();
} }
} else { } else {
IRVisitor::Visit_(op); StmtExprVisitor::VisitStmt_(op);
} }
} }
// memory access // memory access
void FeatureVisitor::Visit_(const Load *op) { void FeatureVisitor::VisitExpr_(const Load* op) {
EnterMem_(op->buffer_var, op->index); EnterMem_(op->buffer_var, op->index);
IRVisitor::Visit_(op); StmtExprVisitor::VisitExpr_(op);
ExitMem_(); ExitMem_();
} }
void FeatureVisitor::Visit_(const Store *op) { void FeatureVisitor::VisitStmt_(const Store* op) {
EnterMem_(op->buffer_var, op->index); EnterMem_(op->buffer_var, op->index);
IRVisitor::Visit_(op); StmtExprVisitor::VisitStmt_(op);
ExitMem_(); ExitMem_();
} }
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#define TVM_AUTOTVM_FEATURE_VISITOR_H_ #define TVM_AUTOTVM_FEATURE_VISITOR_H_
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_visitor.h> #include <tvm/ir_functor_ext.h>
#include <string> #include <string>
namespace tvm { namespace tvm {
...@@ -48,15 +48,18 @@ enum AnnotationType { ...@@ -48,15 +48,18 @@ enum AnnotationType {
* \brief A base class for feature extractor, used for processing * \brief A base class for feature extractor, used for processing
* for loop and memory access in the IR * for loop and memory access in the IR
*/ */
class FeatureVisitor : public IRVisitor { class FeatureVisitor : public StmtExprVisitor {
public: public:
// for loop // for loop
void Visit_(const For *op); void VisitStmt_(const For *op);
void Visit_(const AttrStmt *op); void VisitStmt_(const AttrStmt *op);
// memory access // memory access
void Visit_(const Load *op); void VisitExpr_(const Load *op);
void Visit_(const Store *op); void VisitStmt_(const Store *op);
using StmtExprVisitor::VisitStmt_;
using StmtExprVisitor::VisitExpr_;
protected: protected:
/*! /*!
......
...@@ -44,14 +44,14 @@ int ParallelLevel(AnnotationType ann) { ...@@ -44,14 +44,14 @@ int ParallelLevel(AnnotationType ann) {
} }
// get touch pattern from index expression // get touch pattern from index expression
class IndexParser: public IRVisitor { class IndexParser: public ExprVisitor {
public: public:
void Parse(Expr expr) { void Parse(Expr expr) {
pattern_map.clear(); pattern_map.clear();
this->Visit(expr); this->VisitExpr(expr);
} }
void Visit_(const Variable *op) { void VisitExpr_(const Variable *op) {
// TODO(lmzheng): handle more index types (multiple occurrence) // TODO(lmzheng): handle more index types (multiple occurrence)
if (pattern_map.count(op) == 0) { if (pattern_map.count(op) == 0) {
pattern_map[op] = TouchPattern(); pattern_map[op] = TouchPattern();
...@@ -60,13 +60,13 @@ class IndexParser: public IRVisitor { ...@@ -60,13 +60,13 @@ class IndexParser: public IRVisitor {
} }
} }
void Visit_(const Mul *op) { void VisitExpr_(const Mul *op) {
if (op->a.as<Variable>()) { if (op->a.as<Variable>()) {
if (const auto stride = op->b.as<IntImm>()) { if (const auto stride = op->b.as<IntImm>()) {
next_stride_ = stride->value; next_stride_ = stride->value;
} }
} }
IRVisitor::Visit_(op); ExprVisitor::VisitExpr_(op);
} }
std::unordered_map<const Variable*, TouchPattern> pattern_map; std::unordered_map<const Variable*, TouchPattern> pattern_map;
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -26,7 +26,7 @@ ...@@ -26,7 +26,7 @@
#define TVM_AUTOTVM_TOUCH_EXTRACTOR_H_ #define TVM_AUTOTVM_TOUCH_EXTRACTOR_H_
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_visitor.h> #include <tvm/ir_functor_ext.h>
#include <tvm/api_registry.h> #include <tvm/api_registry.h>
#include <stack> #include <stack>
#include <vector> #include <vector>
...@@ -85,39 +85,39 @@ struct ItervarFeature { ...@@ -85,39 +85,39 @@ struct ItervarFeature {
// extract iter vars and their touch pattern from ir // extract iter vars and their touch pattern from ir
class TouchExtractor : public FeatureVisitor { class TouchExtractor : public FeatureVisitor {
public: public:
void Analyze(Stmt stmt) { void Analyze(const Stmt& stmt) {
this->Visit(stmt); operator()(stmt);
} }
// arithmetic stats // arithmetic stats
void Visit_(const Add *op) { void VisitExpr_(const Add *op) {
if (op->dtype.is_float()) if (op->dtype.is_float())
itervar_map[itervar_stack_.back()].add_ct++; itervar_map[itervar_stack_.back()].add_ct++;
IRVisitor::Visit_(op); FeatureVisitor::VisitExpr_(op);
} }
void Visit_(const Sub *op) { void VisitExpr_(const Sub *op) {
if (op->dtype.is_float()) if (op->dtype.is_float())
itervar_map[itervar_stack_.back()].add_ct++; itervar_map[itervar_stack_.back()].add_ct++;
IRVisitor::Visit_(op); FeatureVisitor::VisitExpr_(op);
} }
void Visit_(const Mul *op) { void VisitExpr_(const Mul *op) {
if (op->dtype.is_float()) if (op->dtype.is_float())
itervar_map[itervar_stack_.back()].mul_ct++; itervar_map[itervar_stack_.back()].mul_ct++;
IRVisitor::Visit_(op); FeatureVisitor::VisitExpr_(op);
} }
void Visit_(const Div *op) { void VisitExpr_(const Div *op) {
if (op->dtype.is_float()) if (op->dtype.is_float())
itervar_map[itervar_stack_.back()].div_ct++; itervar_map[itervar_stack_.back()].div_ct++;
IRVisitor::Visit_(op); FeatureVisitor::VisitExpr_(op);
} }
void Visit_(const Mod *op) { void VisitExpr_(const Mod *op) {
if (op->dtype.is_float()) if (op->dtype.is_float())
itervar_map[itervar_stack_.back()].div_ct++; itervar_map[itervar_stack_.back()].div_ct++;
IRVisitor::Visit_(op); FeatureVisitor::VisitExpr_(op);
} }
std::unordered_map<VarExpr, ItervarFeature, tvm::ExprHash, tvm::ExprEqual> itervar_map; std::unordered_map<VarExpr, ItervarFeature, tvm::ExprHash, tvm::ExprEqual> itervar_map;
...@@ -134,7 +134,7 @@ class TouchExtractor : public FeatureVisitor { ...@@ -134,7 +134,7 @@ class TouchExtractor : public FeatureVisitor {
std::deque<VarExpr> itervar_stack_; // use deque instead of stack for indexing std::deque<VarExpr> itervar_stack_; // use deque instead of stack for indexing
std::deque<size_t> skip_stack_size_; std::deque<size_t> skip_stack_size_;
using IRVisitor::Visit_; using FeatureVisitor::VisitExpr_;
}; };
} // namespace autotvm } // namespace autotvm
......
...@@ -24,8 +24,8 @@ ...@@ -24,8 +24,8 @@
#include <tvm/operation.h> #include <tvm/operation.h>
#include <tvm/arithmetic.h> #include <tvm/arithmetic.h>
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/ir_functor_ext.h>
#include <unordered_set> #include <unordered_set>
#include <string> #include <string>
#include <utility> #include <utility>
...@@ -538,7 +538,7 @@ namespace { ...@@ -538,7 +538,7 @@ namespace {
* must be Reduce as well; and their inputs should have the * must be Reduce as well; and their inputs should have the
* same attribute except value_index. * same attribute except value_index.
*/ */
class ComputeVerifier final : protected ir::IRVisitor { class ComputeVerifier final : protected ir::ExprVisitor {
public: public:
/// Special member functions /// Special member functions
//@{ //@{
...@@ -567,20 +567,20 @@ class ComputeVerifier final : protected ir::IRVisitor { ...@@ -567,20 +567,20 @@ class ComputeVerifier final : protected ir::IRVisitor {
} }
level_ = 0; level_ = 0;
ir::IRVisitor::Visit(e); ExprVisitor::VisitExpr(e);
} }
} }
protected: protected:
/// Visitor implementation /// Visitor implementation
//@{ //@{
void Visit(const ObjectRef& n) final { void VisitExpr(const Expr& n) final {
++level_; ++level_;
ir::IRVisitor::Visit(n); ExprVisitor::VisitExpr(n);
--level_; --level_;
} }
void Visit_(const ir::Reduce* op) final { void VisitExpr_(const ir::Reduce* op) final {
// Check for non top level reductions // Check for non top level reductions
CHECK(0 == level_) CHECK(0 == level_)
<< "Reductions are only allowed at the top level of compute. " << "Reductions are only allowed at the top level of compute. "
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include <tvm/operation.h> #include <tvm/operation.h>
#include <tvm/arithmetic.h> #include <tvm/arithmetic.h>
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_mutator.h> #include <tvm/ir_functor_ext.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/expr_operator.h> #include <tvm/expr_operator.h>
#include <unordered_set> #include <unordered_set>
...@@ -221,7 +221,7 @@ namespace op { ...@@ -221,7 +221,7 @@ namespace op {
Stmt ApplyLoopShapes(const Stage &stage, Stmt ApplyLoopShapes(const Stage &stage,
const std::unordered_map<IterVar, Range> &dom_map, Stmt stmt) { const std::unordered_map<IterVar, Range> &dom_map, Stmt stmt) {
class LoopSpliter : public IRMutator { class LoopSpliter : public StmtExprMutator {
Expr factor; Expr factor;
const Variable *parent; const Variable *parent;
IterVar inner, outer; IterVar inner, outer;
...@@ -247,7 +247,7 @@ Stmt ApplyLoopShapes(const Stage &stage, ...@@ -247,7 +247,7 @@ Stmt ApplyLoopShapes(const Stage &stage,
outer = IterVarNode::make(outer_dom, outer_->var, outer_->iter_type); outer = IterVarNode::make(outer_dom, outer_->var, outer_->iter_type);
} }
Stmt Mutate_(const For *op, const Stmt &stmt) { Stmt VisitStmt_(const For *op) final {
if (op->loop_var.get() == parent) { if (op->loop_var.get() == parent) {
std::unordered_map<const Variable *, Expr> rmap; std::unordered_map<const Variable *, Expr> rmap;
rmap[op->loop_var.get()] = inner + outer * factor; rmap[op->loop_var.get()] = inner + outer * factor;
...@@ -261,11 +261,11 @@ Stmt ApplyLoopShapes(const Stage &stage, ...@@ -261,11 +261,11 @@ Stmt ApplyLoopShapes(const Stage &stage,
splitted = true; splitted = true;
return ret; return ret;
} }
return IRMutator::Mutate_(op, stmt); return StmtExprMutator::VisitStmt_(op);
} }
}; };
class LoopFuser : public IRMutator { class LoopFuser : public StmtExprMutator {
const IterVar &parent; const IterVar &parent;
const Variable *inner; const Variable *inner;
const Variable *outer; const Variable *outer;
...@@ -280,8 +280,7 @@ Stmt ApplyLoopShapes(const Stage &stage, ...@@ -280,8 +280,7 @@ Stmt ApplyLoopShapes(const Stage &stage,
extent(0), fused(false) {} extent(0), fused(false) {}
// TODO(@were): Handle imperfect loops // TODO(@were): Handle imperfect loops
Stmt VisitStmt_(const For* op) final {
Stmt Mutate_(const For *op, const Stmt &stmt) {
if (op->loop_var.get() == inner) { if (op->loop_var.get() == inner) {
CHECK(under_outer); CHECK(under_outer);
std::unordered_map<const Variable *, Expr> rmap; std::unordered_map<const Variable *, Expr> rmap;
...@@ -291,7 +290,7 @@ Stmt ApplyLoopShapes(const Stage &stage, ...@@ -291,7 +290,7 @@ Stmt ApplyLoopShapes(const Stage &stage,
return ir::Substitute(op->body, rmap); return ir::Substitute(op->body, rmap);
} else if (op->loop_var.get() == outer) { } else if (op->loop_var.get() == outer) {
under_outer = true; under_outer = true;
Stmt body = IRMutator::Mutate(op->body); Stmt body = this->VisitStmt(op->body);
std::unordered_map<const Variable *, Expr> rmap; std::unordered_map<const Variable *, Expr> rmap;
rmap[op->loop_var.get()] = indexdiv(parent, extent); rmap[op->loop_var.get()] = indexdiv(parent, extent);
body = ir::Substitute(body, rmap); body = ir::Substitute(body, rmap);
...@@ -299,25 +298,25 @@ Stmt ApplyLoopShapes(const Stage &stage, ...@@ -299,25 +298,25 @@ Stmt ApplyLoopShapes(const Stage &stage,
return For::make(parent->var, Expr(0), extent * op->extent, return For::make(parent->var, Expr(0), extent * op->extent,
op->for_type, op->device_api, body); op->for_type, op->device_api, body);
} else if (under_outer) { } else if (under_outer) {
Stmt body = IRMutator::Mutate(op->body); Stmt body = this->VisitStmt(op->body);
std::unordered_map<const Variable *, Expr> rmap; std::unordered_map<const Variable *, Expr> rmap;
rmap[op->loop_var.get()] = indexmod(indexdiv(parent, extent), op->extent); rmap[op->loop_var.get()] = indexmod(indexdiv(parent, extent), op->extent);
body = ir::Substitute(body, rmap); body = ir::Substitute(body, rmap);
extent = extent * op->extent; extent = extent * op->extent;
return body; return body;
} }
return IRMutator::Mutate(stmt); return StmtExprMutator::VisitStmt_(op);
} }
}; };
for (auto &rel : stage->relations) { for (auto &rel : stage->relations) {
if (const SplitNode *split = rel.as<SplitNode>()) { if (const SplitNode *split = rel.as<SplitNode>()) {
LoopSpliter Spliter(split, dom_map); LoopSpliter Spliter(split, dom_map);
stmt = Spliter.Mutate(stmt); stmt = Spliter(stmt);
CHECK(Spliter.splitted); CHECK(Spliter.splitted);
} else if (const FuseNode *fuse = rel.as<FuseNode>()) { } else if (const FuseNode *fuse = rel.as<FuseNode>()) {
LoopFuser Fuser(fuse); LoopFuser Fuser(fuse);
stmt = Fuser.Mutate(stmt); stmt = Fuser(stmt);
CHECK(Fuser.fused); CHECK(Fuser.fused);
} }
} }
...@@ -327,14 +326,14 @@ Stmt ApplyLoopShapes(const Stage &stage, ...@@ -327,14 +326,14 @@ Stmt ApplyLoopShapes(const Stage &stage,
Stmt ApplyLoopAnnotations(const Stage &stage, Stmt ApplyLoopAnnotations(const Stage &stage,
const std::unordered_map<IterVar, IterVar> &rebased, Stmt stmt) { const std::unordered_map<IterVar, IterVar> &rebased, Stmt stmt) {
class LoopAnnotator : public IRMutator { class LoopAnnotator : public StmtMutator {
const Variable *var; const Variable *var;
const IterVarAttr &attr; const IterVarAttr &attr;
public: public:
LoopAnnotator(const Variable *var_, const IterVarAttr &attr_) : var(var_), attr(attr_) {} LoopAnnotator(const Variable *var_, const IterVarAttr &attr_) : var(var_), attr(attr_) {}
Stmt Mutate_(const For *op, const Stmt &stmt) { Stmt VisitStmt_(const For *op) final {
if (op->loop_var.get() == var) { if (op->loop_var.get() == var) {
if (attr->bind_thread.defined()) { if (attr->bind_thread.defined()) {
const auto &iter_var = attr->bind_thread; const auto &iter_var = attr->bind_thread;
...@@ -352,7 +351,7 @@ Stmt ApplyLoopAnnotations(const Stage &stage, ...@@ -352,7 +351,7 @@ Stmt ApplyLoopAnnotations(const Stage &stage,
IterVarTypeToForType(attr->iter_type), op->device_api, op->body); IterVarTypeToForType(attr->iter_type), op->device_api, op->body);
} }
} }
return IRMutator::Mutate_(op, stmt); return StmtMutator::VisitStmt_(op);
} }
}; };
...@@ -381,7 +380,7 @@ Stmt ApplyLoopAnnotations(const Stage &stage, ...@@ -381,7 +380,7 @@ Stmt ApplyLoopAnnotations(const Stage &stage,
CHECK_EQ(found, 1) << " iter var should be found exactly once!"; CHECK_EQ(found, 1) << " iter var should be found exactly once!";
if (need_change) { if (need_change) {
stmt = LoopAnnotator(var, attr).Mutate(stmt); stmt = LoopAnnotator(var, attr)(std::move(stmt));
} }
} }
return stmt; return stmt;
...@@ -411,7 +410,7 @@ Stmt ApplyLoopOrder(const Stage &stage, ...@@ -411,7 +410,7 @@ Stmt ApplyLoopOrder(const Stage &stage,
} }
} }
class LoopReorder : public IRMutator { class LoopReorder : public StmtMutator {
const Stage &stage; const Stage &stage;
const std::unordered_map<IterVar, Range> &dom_map; const std::unordered_map<IterVar, Range> &dom_map;
const std::unordered_map<const Variable *, IterVar> &reorder; const std::unordered_map<const Variable *, IterVar> &reorder;
...@@ -422,13 +421,13 @@ Stmt ApplyLoopOrder(const Stage &stage, ...@@ -422,13 +421,13 @@ Stmt ApplyLoopOrder(const Stage &stage,
const std::unordered_map<const Variable*, IterVar> &reorder) const std::unordered_map<const Variable*, IterVar> &reorder)
: stage(stage), dom_map(dom_map), reorder(reorder) {} : stage(stage), dom_map(dom_map), reorder(reorder) {}
Stmt Mutate_(const For *op, const Stmt &stmt) { Stmt VisitStmt_(const For* op) final {
// Reorder from in to out // Reorder from in to out
Stmt body_ = IRMutator::Mutate(op->body); Stmt body_ = this->VisitStmt(op->body);
CHECK(reorder.count(op->loop_var.get())); CHECK(reorder.count(op->loop_var.get()));
auto target = reorder.find(op->loop_var.get())->second; auto target = reorder.find(op->loop_var.get())->second;
if (body_.same_as(op->body) && op->loop_var.get() == target->var.get()) if (body_.same_as(op->body) && op->loop_var.get() == target->var.get())
return stmt; return GetRef<Stmt>(op);
const Stmt &body = op->body.same_as(body_) ? op->body : body_; const Stmt &body = op->body.same_as(body_) ? op->body : body_;
ForType for_type = IterVarTypeToForType(target->iter_type); ForType for_type = IterVarTypeToForType(target->iter_type);
if (stage->iter_var_attrs.count(target)) { if (stage->iter_var_attrs.count(target)) {
...@@ -441,7 +440,7 @@ Stmt ApplyLoopOrder(const Stage &stage, ...@@ -441,7 +440,7 @@ Stmt ApplyLoopOrder(const Stage &stage,
}; };
if (need_reorder) if (need_reorder)
return LoopReorder(stage, dom_map, reorder).Mutate(stmt); return LoopReorder(stage, dom_map, reorder)(stmt);
return stmt; return stmt;
} }
...@@ -479,21 +478,21 @@ std::vector<IterVar> GatherLoopVars(Stmt stmt) { ...@@ -479,21 +478,21 @@ std::vector<IterVar> GatherLoopVars(Stmt stmt) {
} }
// replacer to replace tensors' usage in Provide // replacer to replace tensors' usage in Provide
class ProviderReplacer : public ir::IRMutator { class ProviderReplacer : public ir::StmtMutator {
public: public:
explicit ProviderReplacer(const std::unordered_map<Tensor, Tensor> &vmap) explicit ProviderReplacer(const std::unordered_map<Tensor, Tensor> &vmap)
: vmap_(vmap) {} : vmap_(vmap) {}
Stmt Mutate_(const ir::Provide* op, const Stmt &s) { Stmt VisitStmt_(const ir::Provide* op) final {
Tensor t = Downcast<Operation>(op->func).output(op->value_index); Tensor t = Downcast<Operation>(op->func).output(op->value_index);
auto it = vmap_.find(t); auto it = vmap_.find(t);
if (it != vmap_.end()) { if (it != vmap_.end()) {
Stmt ret = ir::Provide::make( Stmt ret = ir::Provide::make(
it->second->op, it->second->value_index, op->value, op->args); it->second->op, it->second->value_index, op->value, op->args);
found = true; found = true;
return IRMutator::Mutate_(ret.as<ir::Provide>(), ret); return this->VisitStmt(ret);
} }
return IRMutator::Mutate_(op, s); return StmtMutator::VisitStmt_(op);
} }
// whether it is found. // whether it is found.
...@@ -506,7 +505,7 @@ class ProviderReplacer : public ir::IRMutator { ...@@ -506,7 +505,7 @@ class ProviderReplacer : public ir::IRMutator {
Stmt ReplaceProvideTensor(Stmt stmt, Stmt ReplaceProvideTensor(Stmt stmt,
const std::unordered_map<Tensor, Tensor> &replace) { const std::unordered_map<Tensor, Tensor> &replace) {
ProviderReplacer repl(replace); ProviderReplacer repl(replace);
Stmt ret = repl.Mutate(stmt); Stmt ret = repl(stmt);
return repl.found ? ret : stmt; return repl.found ? ret : stmt;
} }
} // namespace op } // namespace op
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -25,8 +25,6 @@ ...@@ -25,8 +25,6 @@
#define TVM_OP_HYBRID_OP_H_ #define TVM_OP_HYBRID_OP_H_
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_visitor.h>
#include <tvm/schedule.h> #include <tvm/schedule.h>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
......
...@@ -23,8 +23,8 @@ ...@@ -23,8 +23,8 @@
*/ */
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/operation.h> #include <tvm/operation.h>
#include <tvm/ir_mutator.h>
#include <string> #include <string>
#include "op_util.h" #include "op_util.h"
#include "../schedule/message_passing.h" #include "../schedule/message_passing.h"
...@@ -186,12 +186,12 @@ std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates) { ...@@ -186,12 +186,12 @@ std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates) {
} }
// replacer to replace tensors // replacer to replace tensors
class TensorReplacer : public ir::IRMutator { class TensorReplacer : public ir::StmtExprMutator {
public: public:
explicit TensorReplacer(const std::unordered_map<Tensor, Tensor>& vmap) explicit TensorReplacer(const std::unordered_map<Tensor, Tensor>& vmap)
: vmap_(vmap) {} : vmap_(vmap) {}
Expr Mutate_(const ir::Call* op, const Expr& e) { Expr VisitExpr_(const ir::Call* op) final {
if (op->call_type == ir::Call::Halide) { if (op->call_type == ir::Call::Halide) {
Tensor t = Downcast<Operation>(op->func).output(op->value_index); Tensor t = Downcast<Operation>(op->func).output(op->value_index);
auto it = vmap_.find(t); auto it = vmap_.find(t);
...@@ -200,10 +200,10 @@ class TensorReplacer : public ir::IRMutator { ...@@ -200,10 +200,10 @@ class TensorReplacer : public ir::IRMutator {
op->dtype, it->second->op->name, op->args, op->dtype, it->second->op->name, op->args,
op->call_type, it->second->op, it->second->value_index); op->call_type, it->second->op, it->second->value_index);
found = true; found = true;
return IRMutator::Mutate_(ret.as<ir::Call>(), ret); return this->VisitExpr(ret);
} }
} }
return IRMutator::Mutate_(op, e); return StmtExprMutator::VisitExpr_(op);
} }
// whether it is found. // whether it is found.
...@@ -216,13 +216,13 @@ class TensorReplacer : public ir::IRMutator { ...@@ -216,13 +216,13 @@ class TensorReplacer : public ir::IRMutator {
Stmt ReplaceTensor(Stmt stmt, Stmt ReplaceTensor(Stmt stmt,
const std::unordered_map<Tensor, Tensor>& replace) { const std::unordered_map<Tensor, Tensor>& replace) {
TensorReplacer repl(replace); TensorReplacer repl(replace);
Stmt ret = repl.Mutate(stmt); Stmt ret = repl(stmt);
return repl.found ? ret : stmt; return repl.found ? ret : stmt;
} }
Expr ReplaceTensor(Expr expr, Expr ReplaceTensor(Expr expr,
const std::unordered_map<Tensor, Tensor>& replace) { const std::unordered_map<Tensor, Tensor>& replace) {
TensorReplacer repl(replace); TensorReplacer repl(replace);
Expr ret = repl.Mutate(expr); Expr ret = repl(expr);
return repl.found ? ret : expr; return repl.found ? ret : expr;
} }
......
...@@ -24,7 +24,6 @@ ...@@ -24,7 +24,6 @@
#include <tvm/operation.h> #include <tvm/operation.h>
#include <tvm/arithmetic.h> #include <tvm/arithmetic.h>
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <unordered_set> #include <unordered_set>
#include "./op_util.h" #include "./op_util.h"
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
* \file tensorize.cc * \file tensorize.cc
*/ */
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_mutator.h> #include <tvm/ir_functor_ext.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/api_registry.h> #include <tvm/api_registry.h>
#include "op_util.h" #include "op_util.h"
...@@ -157,10 +157,10 @@ void VerifyTensorizeLoopNest(const ComputeOpNode* self, ...@@ -157,10 +157,10 @@ void VerifyTensorizeLoopNest(const ComputeOpNode* self,
} }
// Remap the tensor placeholder, index and inline things. // Remap the tensor placeholder, index and inline things.
class TensorIntrinMatcher final : public IRMutator { class TensorIntrinMatcher final : public StmtExprMutator {
public: public:
Expr Mutate_(const Call* op, const Expr& e) final { Expr VisitExpr_(const Call* op) final {
Expr expr = IRMutator::Mutate_(op, e); Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Call>(); op = expr.as<Call>();
if (op->call_type == Call::Halide) { if (op->call_type == Call::Halide) {
Tensor t = Downcast<Operation>(op->func).output(op->value_index); Tensor t = Downcast<Operation>(op->func).output(op->value_index);
...@@ -180,17 +180,17 @@ class TensorIntrinMatcher final : public IRMutator { ...@@ -180,17 +180,17 @@ class TensorIntrinMatcher final : public IRMutator {
return expr; return expr;
} }
Expr Mutate_(const Variable* op, const Expr& e) final { Expr VisitExpr_(const Variable* op) final {
auto it = var_remap_.find(op); auto it = var_remap_.find(op);
if (it != var_remap_.end()) { if (it != var_remap_.end()) {
return it->second; return it->second;
} else { } else {
return e; return GetRef<Expr>(op);
} }
} }
Expr Mutate_(const Reduce* op, const Expr& e) final { Expr VisitExpr_(const Reduce* op) final {
Expr expr = IRMutator::Mutate_(op, e); Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Reduce>(); op = expr.as<Reduce>();
Array<IterVar> axis; Array<IterVar> axis;
for (size_t i = 0; i < op->axis.size(); ++i) { for (size_t i = 0; i < op->axis.size(); ++i) {
...@@ -317,7 +317,7 @@ Array<Expr> MatchTensorizeBody( ...@@ -317,7 +317,7 @@ Array<Expr> MatchTensorizeBody(
matcher.Init(self, stage, dom_map, out_dom, in_region, intrin, compute_intrin_iter_space); matcher.Init(self, stage, dom_map, out_dom, in_region, intrin, compute_intrin_iter_space);
Array<Expr> ret; Array<Expr> ret;
for (Expr expr : self->body) { for (Expr expr : self->body) {
ret.push_back(matcher.Mutate(expr)); ret.push_back(matcher(expr));
} }
return ret; return ret;
} }
......
...@@ -23,9 +23,8 @@ ...@@ -23,9 +23,8 @@
// Instrument checkers for out of the bounds access. // Instrument checkers for out of the bounds access.
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h> #include <tvm/ir_functor_ext.h>
#include <vector> #include <vector>
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
...@@ -33,48 +32,48 @@ ...@@ -33,48 +32,48 @@
namespace tvm { namespace tvm {
namespace ir { namespace ir {
class BoundCollector : public IRVisitor { class BoundCollector : public StmtVisitor {
public: public:
BoundCollector() {} BoundCollector() {}
void Visit_(const AttrStmt *op) { void VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == ir::attr::buffer_bound) { if (op->attr_key == ir::attr::buffer_bound) {
if (const Variable *key = op->node.as<Variable>()) { if (const Variable *key = op->node.as<Variable>()) {
mem_to_shape[key] = op->value; mem_to_shape[key] = op->value;
} }
} }
IRVisitor::Visit_(op); StmtVisitor::VisitStmt_(op);
} }
// Hashtable which maps buffer_var to shape. // Hashtable which maps buffer_var to shape.
std::unordered_map<const Variable *, Expr> mem_to_shape; std::unordered_map<const Variable *, Expr> mem_to_shape;
}; };
class BoundChecker : public IRMutator { class BoundChecker : public StmtExprMutator {
public: public:
explicit BoundChecker( explicit BoundChecker(
const std::unordered_map<const Variable *, Expr> &mem_to_shape) const std::unordered_map<const Variable *, Expr> &mem_to_shape)
: mem_to_shape_(mem_to_shape) {} : mem_to_shape_(mem_to_shape) {}
Stmt Mutate_(const Allocate *op, const Stmt &s) final { Stmt VisitStmt_(const Allocate* op) final {
// If the shape was updated we should update the hashtable. // If the shape was updated we should update the hashtable.
if (UpdateIsNeeded(op->buffer_var)) { if (UpdateIsNeeded(op->buffer_var)) {
Update(op->buffer_var, op->extents, op->dtype); Update(op->buffer_var, op->extents, op->dtype);
} }
return IRMutator::Mutate_(op, s); return StmtExprMutator::VisitStmt_(op);
} }
Expr Mutate_(const Call *op, const Expr &ex) final { Expr VisitExpr_(const Call* op) final {
if (process_store_ && op->is_intrinsic(intrinsic::tvm_if_then_else)) { if (process_store_ && op->is_intrinsic(intrinsic::tvm_if_then_else)) {
unsafe_rewritten_ = true; unsafe_rewritten_ = true;
} }
return IRMutator::Mutate_(op, ex); return StmtExprMutator::VisitExpr_(op);
} }
Stmt Mutate_(const Store *op, const Stmt &s) final { Stmt VisitStmt_(const Store* op) final {
store_scope_bound_collector_.clear(); store_scope_bound_collector_.clear();
process_store_ = true; process_store_ = true;
unsafe_rewritten_ = false; unsafe_rewritten_ = false;
IRMutator::Mutate_(op, s); StmtExprMutator::VisitStmt_(op);
process_store_ = false; process_store_ = false;
if (CanInstrument(op->index, op->buffer_var)) { if (CanInstrument(op->index, op->buffer_var)) {
Collect(op->index, op->buffer_var); Collect(op->index, op->buffer_var);
...@@ -92,23 +91,24 @@ class BoundChecker : public IRMutator { ...@@ -92,23 +91,24 @@ class BoundChecker : public IRMutator {
return body; return body;
} }
} }
return s; return GetRef<Stmt>(op);
} }
Expr Mutate_(const Load *op, const Expr &ex) final { Expr VisitExpr_(const Load* op) final {
if (CanInstrument(op->index, op->buffer_var)) { if (CanInstrument(op->index, op->buffer_var)) {
Collect(op->index, op->buffer_var); Collect(op->index, op->buffer_var);
} }
return IRMutator::Mutate_(op, ex); return StmtExprMutator::VisitExpr_(op);
} }
private: private:
bool UpdateIsNeeded(const VarExpr &buffer_var) const { bool UpdateIsNeeded(const VarExpr& buffer_var) const {
return (buffer_var.defined() && mem_to_shape_.count(buffer_var.get())); return (buffer_var.defined() && mem_to_shape_.count(buffer_var.get()));
} }
void Update(const VarExpr &buffer_var, const Array<Expr> &new_shape, void Update(const VarExpr& buffer_var,
const DataType &type) { const Array<Expr>& new_shape,
const DataType& type) {
// Sanity check at first. // Sanity check at first.
if (!new_shape.size()) { if (!new_shape.size()) {
return; return;
...@@ -132,7 +132,7 @@ class BoundChecker : public IRMutator { ...@@ -132,7 +132,7 @@ class BoundChecker : public IRMutator {
mem_to_shape_[buffer_var.get()] = shape; mem_to_shape_[buffer_var.get()] = shape;
} }
bool IndexIsValid(const Expr &index) const { bool IndexIsValid(const Expr& index) const {
if (!index.defined()) { if (!index.defined()) {
return false; return false;
} }
...@@ -146,7 +146,7 @@ class BoundChecker : public IRMutator { ...@@ -146,7 +146,7 @@ class BoundChecker : public IRMutator {
return true; return true;
} }
bool CanInstrument(const Expr &index, const VarExpr &buffer_var) const { bool CanInstrument(const Expr& index, const VarExpr& buffer_var) const {
return buffer_var.defined() && mem_to_shape_.count(buffer_var.get()) && return buffer_var.defined() && mem_to_shape_.count(buffer_var.get()) &&
IndexIsValid(index) && !unsafe_rewritten_; IndexIsValid(index) && !unsafe_rewritten_;
} }
...@@ -206,8 +206,8 @@ class BoundChecker : public IRMutator { ...@@ -206,8 +206,8 @@ class BoundChecker : public IRMutator {
Stmt InstrumentBoundCheckers(Stmt stmt) { Stmt InstrumentBoundCheckers(Stmt stmt) {
BoundCollector bound_collector; BoundCollector bound_collector;
// At first walk recursively and collect bound attributes. // At first walk recursively and collect bound attributes.
bound_collector.Visit(stmt); bound_collector(stmt);
return BoundChecker(bound_collector.mem_to_shape).Mutate(stmt); return BoundChecker(bound_collector.mem_to_shape)(std::move(stmt));
} }
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
* \file combine_context_call.cc * \file combine_context_call.cc
*/ */
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_mutator.h> #include <tvm/ir_functor_ext.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <map> #include <map>
...@@ -32,7 +32,7 @@ namespace ir { ...@@ -32,7 +32,7 @@ namespace ir {
// Calculate the statistics of packed function. // Calculate the statistics of packed function.
// These information are needed during codegen. // These information are needed during codegen.
class ContextCallCombiner final : public IRMutator { class ContextCallCombiner final : public StmtExprMutator {
public: public:
struct CompareExpr { struct CompareExpr {
bool operator()(const Expr& lhs, const Expr& rhs) const { bool operator()(const Expr& lhs, const Expr& rhs) const {
...@@ -40,7 +40,7 @@ class ContextCallCombiner final : public IRMutator { ...@@ -40,7 +40,7 @@ class ContextCallCombiner final : public IRMutator {
} }
}; };
Expr Mutate_(const Call* op, const Expr& e) final { Expr VisitExpr_(const Call* op) final {
if (op->is_intrinsic(intrinsic::tvm_thread_context)) { if (op->is_intrinsic(intrinsic::tvm_thread_context)) {
CHECK_EQ(op->args.size(), 1U); CHECK_EQ(op->args.size(), 1U);
Expr ctx = op->args[0]; Expr ctx = op->args[0];
...@@ -60,39 +60,39 @@ class ContextCallCombiner final : public IRMutator { ...@@ -60,39 +60,39 @@ class ContextCallCombiner final : public IRMutator {
return std::move(ctx_var); return std::move(ctx_var);
} }
} else { } else {
return IRMutator::Mutate_(op, e); return StmtExprMutator::VisitExpr_(op);
} }
} }
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { Stmt VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::thread_extent || if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::coproc_uop_scope) { op->attr_key == attr::coproc_uop_scope) {
// Map of comparison expression to variable // Map of comparison expression to variable
std::map<Expr, Var, CompareExpr> temp; std::map<Expr, Var, CompareExpr> temp;
std::swap(temp, ctx_map_); std::swap(temp, ctx_map_);
Stmt stmt = IRMutator::Mutate_(op, s); Stmt stmt = StmtExprMutator::VisitStmt_(op);
std::swap(temp, ctx_map_); std::swap(temp, ctx_map_);
return BuildContext(temp, stmt); return BuildContext(temp, stmt);
} else { } else {
return IRMutator::Mutate_(op, s); return StmtExprMutator::VisitStmt_(op);
} }
} }
Stmt Mutate_(const For* op, const Stmt& s) final { Stmt VisitStmt_(const For* op) final {
if (op->for_type == ForType::Parallel) { if (op->for_type == ForType::Parallel) {
// Map of comparison expression to variable // Map of comparison expression to variable
std::map<Expr, Var, CompareExpr> temp; std::map<Expr, Var, CompareExpr> temp;
std::swap(temp, ctx_map_); std::swap(temp, ctx_map_);
Stmt stmt = IRMutator::Mutate_(op, s); Stmt stmt = StmtExprMutator::VisitStmt_(op);
std::swap(temp, ctx_map_); std::swap(temp, ctx_map_);
return BuildContext(temp, stmt); return BuildContext(temp, stmt);
} else { } else {
return IRMutator::Mutate_(op, s); return StmtExprMutator::VisitStmt_(op);
} }
} }
Stmt Combine(Stmt stmt) { Stmt Combine(Stmt stmt) {
return BuildContext(ctx_map_, this->Mutate(stmt)); return BuildContext(ctx_map_, this->VisitStmt(stmt));
} }
private: private:
......
...@@ -22,8 +22,7 @@ ...@@ -22,8 +22,7 @@
*/ */
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h> #include <tvm/ir_functor_ext.h>
#include <tvm/ir_visitor.h>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include "ir_util.h" #include "ir_util.h"
...@@ -33,25 +32,25 @@ namespace tvm { ...@@ -33,25 +32,25 @@ namespace tvm {
namespace ir { namespace ir {
// Visitor to find touched set by co-processor scope. // Visitor to find touched set by co-processor scope.
class CoProcTouchedBuffer : public IRVisitor { class CoProcTouchedBuffer : public StmtExprVisitor {
public: public:
void Visit_(const Load* op) final { void VisitExpr_(const Load* op) final {
if (in_scope_) { if (in_scope_) {
touched_[op->buffer_var.get()].coproc = true; touched_[op->buffer_var.get()].coproc = true;
} else { } else {
touched_[op->buffer_var.get()].normal = true; touched_[op->buffer_var.get()].normal = true;
} }
IRVisitor::Visit_(op); StmtExprVisitor::VisitExpr_(op);
} }
void Visit_(const Store* op) final { void VisitStmt_(const Store* op) final {
if (in_scope_) { if (in_scope_) {
touched_[op->buffer_var.get()].coproc = true; touched_[op->buffer_var.get()].coproc = true;
} else { } else {
touched_[op->buffer_var.get()].normal = true; touched_[op->buffer_var.get()].normal = true;
} }
IRVisitor::Visit_(op); StmtExprVisitor::VisitStmt_(op);
} }
void Visit_(const Call* op) final { void VisitExpr_(const Call* op) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
const Variable* buffer = op->args[1].as<Variable>(); const Variable* buffer = op->args[1].as<Variable>();
if (in_scope_) { if (in_scope_) {
...@@ -60,17 +59,17 @@ class CoProcTouchedBuffer : public IRVisitor { ...@@ -60,17 +59,17 @@ class CoProcTouchedBuffer : public IRVisitor {
touched_[buffer].normal = true; touched_[buffer].normal = true;
} }
} }
IRVisitor::Visit_(op); StmtExprVisitor::VisitExpr_(op);
} }
void Visit_(const AttrStmt* op) final { void VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::coproc_scope && !in_scope_) { if (op->attr_key == attr::coproc_scope && !in_scope_) {
in_scope_ = true; in_scope_ = true;
IterVar iv = Downcast<IterVar>(op->node); IterVar iv = Downcast<IterVar>(op->node);
coproc_.insert(iv); coproc_.insert(iv);
IRVisitor::Visit_(op); StmtExprVisitor::VisitStmt_(op);
in_scope_ = false; in_scope_ = false;
} else { } else {
IRVisitor::Visit_(op); StmtExprVisitor::VisitStmt_(op);
} }
} }
...@@ -96,7 +95,7 @@ class CoProcSyncPlanner : public StorageAccessVisitor { ...@@ -96,7 +95,7 @@ class CoProcSyncPlanner : public StorageAccessVisitor {
} }
void Plan(const Stmt& stmt) { void Plan(const Stmt& stmt) {
this->Visit(stmt); this->VisitStmt(stmt);
PlanSync(scope_.back(), nullptr, true); PlanSync(scope_.back(), nullptr, true);
if (sync_.size() == 0) { if (sync_.size() == 0) {
sync_[stmt.get()] = GetSync(coproc_name_ + ".coproc_sync"); sync_[stmt.get()] = GetSync(coproc_name_ + ".coproc_sync");
...@@ -218,14 +217,14 @@ class CoProcBarrierDetector : public StorageAccessVisitor { ...@@ -218,14 +217,14 @@ class CoProcBarrierDetector : public StorageAccessVisitor {
write_barrier_name_ = coproc_name + ".coproc_write_barrier"; write_barrier_name_ = coproc_name + ".coproc_write_barrier";
} }
void PlanReadBarrier(Stmt stmt) { void PlanReadBarrier(const Stmt& stmt) {
read_barrier_ = true; read_barrier_ = true;
this->Visit(stmt); this->VisitStmt(stmt);
PlanReadBarrier(scope_.back(), nullptr); PlanReadBarrier(scope_.back(), nullptr);
} }
void PlanWriteBarrier(Stmt stmt) { void PlanWriteBarrier(const Stmt& stmt) {
read_barrier_ = false; read_barrier_ = false;
this->Visit(stmt); this->VisitStmt(stmt);
PlanWriteBarrier(scope_.back(), nullptr); PlanWriteBarrier(scope_.back(), nullptr);
} }
...@@ -356,7 +355,7 @@ class CoProcBarrierDetector : public StorageAccessVisitor { ...@@ -356,7 +355,7 @@ class CoProcBarrierDetector : public StorageAccessVisitor {
}; };
class CoProcInstDepDetector : public IRVisitor { class CoProcInstDepDetector : public StmtVisitor {
public: public:
explicit CoProcInstDepDetector( explicit CoProcInstDepDetector(
const IterVar& coproc_axis, const IterVar& coproc_axis,
...@@ -366,15 +365,15 @@ class CoProcInstDepDetector : public IRVisitor { ...@@ -366,15 +365,15 @@ class CoProcInstDepDetector : public IRVisitor {
sync_pop_name_ = coproc_name + ".coproc_dep_pop"; sync_pop_name_ = coproc_name + ".coproc_dep_pop";
} }
void Plan(Stmt stmt) { void Plan(const Stmt& stmt) {
this->Visit(stmt); this->VisitStmt(stmt);
if (last_state_.node != nullptr) { if (last_state_.node != nullptr) {
MatchFixEnterPop(first_state_); MatchFixEnterPop(first_state_);
MatchFixExitPush(last_state_); MatchFixExitPush(last_state_);
} }
} }
void Visit_(const AttrStmt* op) final { void VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::coproc_scope && if (op->attr_key == attr::coproc_scope &&
op->node.same_as(coproc_axis_)) { op->node.same_as(coproc_axis_)) {
const IntImm* ctx_id = op->value.as<IntImm>(); const IntImm* ctx_id = op->value.as<IntImm>();
...@@ -385,15 +384,15 @@ class CoProcInstDepDetector : public IRVisitor { ...@@ -385,15 +384,15 @@ class CoProcInstDepDetector : public IRVisitor {
curr_state_.exit_ctx.insert(ctx_id->value); curr_state_.exit_ctx.insert(ctx_id->value);
UpdateState(); UpdateState();
} else { } else {
IRVisitor::Visit_(op); StmtVisitor::VisitStmt_(op);
} }
} }
void Visit_(const For* op) final { void VisitStmt_(const For* op) final {
SyncState temp_first, temp_last; SyncState temp_first, temp_last;
std::swap(first_state_, temp_first); std::swap(first_state_, temp_first);
std::swap(last_state_, temp_last); std::swap(last_state_, temp_last);
this->Visit(op->body); this->VisitStmt(op->body);
curr_state_.clear(); curr_state_.clear();
if (last_state_.node != nullptr) { if (last_state_.node != nullptr) {
curr_state_.node = op; curr_state_.node = op;
...@@ -412,13 +411,13 @@ class CoProcInstDepDetector : public IRVisitor { ...@@ -412,13 +411,13 @@ class CoProcInstDepDetector : public IRVisitor {
} }
} }
void Visit_(const IfThenElse* op) final { void VisitStmt_(const IfThenElse* op) final {
SyncState temp_first, temp_last, curr_state; SyncState temp_first, temp_last, curr_state;
std::swap(first_state_, temp_first); std::swap(first_state_, temp_first);
std::swap(last_state_, temp_last); std::swap(last_state_, temp_last);
{ {
// then stmt // then stmt
this->Visit(op->then_case); this->VisitStmt(op->then_case);
if (last_state_.node != nullptr) { if (last_state_.node != nullptr) {
curr_state.node = op; curr_state.node = op;
MatchFixEnterPop(first_state_); MatchFixEnterPop(first_state_);
...@@ -434,7 +433,7 @@ class CoProcInstDepDetector : public IRVisitor { ...@@ -434,7 +433,7 @@ class CoProcInstDepDetector : public IRVisitor {
last_state_.clear(); last_state_.clear();
} }
if (op->else_case.defined()) { if (op->else_case.defined()) {
this->Visit(op->else_case); this->VisitStmt(op->else_case);
if (last_state_.node != nullptr) { if (last_state_.node != nullptr) {
curr_state.node = op; curr_state.node = op;
MatchFixEnterPop(first_state_); MatchFixEnterPop(first_state_);
...@@ -606,11 +605,11 @@ class CoProcInstDepDetector : public IRVisitor { ...@@ -606,11 +605,11 @@ class CoProcInstDepDetector : public IRVisitor {
}; };
class CoProcSyncInserter : public IRMutator { class CoProcSyncInserter : public StmtMutator {
public: public:
Stmt Insert(Stmt stmt) { Stmt Insert(Stmt stmt) {
CoProcTouchedBuffer visitor; CoProcTouchedBuffer visitor;
visitor.Visit(stmt); visitor(stmt);
if (visitor.coproc_.size() == 0) return stmt; if (visitor.coproc_.size() == 0) return stmt;
std::unordered_set<const Variable*> touched; std::unordered_set<const Variable*> touched;
...@@ -652,10 +651,10 @@ class CoProcSyncInserter : public IRMutator { ...@@ -652,10 +651,10 @@ class CoProcSyncInserter : public IRMutator {
auto& vec = insert_after_[kv.first]; auto& vec = insert_after_[kv.first];
vec.insert(vec.end(), kv.second.begin(), kv.second.end()); vec.insert(vec.end(), kv.second.begin(), kv.second.end());
} }
return Mutate(stmt); return operator()(std::move(stmt));
} }
Stmt Mutate(Stmt stmt) final { Stmt VisitStmt(const Stmt& stmt) final {
Stmt before, after; Stmt before, after;
auto it = insert_before_.find(stmt.get()); auto it = insert_before_.find(stmt.get());
if (it != insert_before_.end()) { if (it != insert_before_.end()) {
...@@ -666,14 +665,14 @@ class CoProcSyncInserter : public IRMutator { ...@@ -666,14 +665,14 @@ class CoProcSyncInserter : public IRMutator {
if (it != insert_after_.end()) { if (it != insert_after_.end()) {
after = MergeSeq(it->second); after = MergeSeq(it->second);
} }
stmt = IRMutator::Mutate(stmt); Stmt new_stmt = StmtMutator::VisitStmt(stmt);
if (before.defined()) { if (before.defined()) {
stmt = Block::make(before, stmt); new_stmt = Block::make(before, new_stmt);
} }
if (after.defined()) { if (after.defined()) {
stmt = Block::make(stmt, after); new_stmt = Block::make(new_stmt, after);
} }
return stmt; return new_stmt;
} }
private: private:
...@@ -685,7 +684,7 @@ class CoProcSyncInserter : public IRMutator { ...@@ -685,7 +684,7 @@ class CoProcSyncInserter : public IRMutator {
Stmt CoProcSync(Stmt stmt) { Stmt CoProcSync(Stmt stmt) {
return CoProcSyncInserter().Insert(stmt); return CoProcSyncInserter().Insert(std::move(stmt));
} }
} // namespace ir } // namespace ir
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -22,7 +22,6 @@ ...@@ -22,7 +22,6 @@
*/ */
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include "../pass/ir_util.h" #include "../pass/ir_util.h"
namespace tvm { namespace tvm {
......
...@@ -21,9 +21,7 @@ ...@@ -21,9 +21,7 @@
* \file hoist_if_then_else.cc * \file hoist_if_then_else.cc
*/ */
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_visitor.h> #include <tvm/ir_functor_ext.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <tvm/arithmetic.h> #include <tvm/arithmetic.h>
#include <tvm/api_registry.h> #include <tvm/api_registry.h>
#include <unordered_map> #include <unordered_map>
......
...@@ -23,8 +23,7 @@ ...@@ -23,8 +23,7 @@
*/ */
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h> #include <tvm/ir_functor_ext.h>
#include <tvm/ir_visitor.h>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include "ir_util.h" #include "ir_util.h"
...@@ -35,7 +34,7 @@ namespace tvm { ...@@ -35,7 +34,7 @@ namespace tvm {
namespace ir { namespace ir {
// Get fragment information from tensor intrinsics // Get fragment information from tensor intrinsics
class FragmentGetter : public IRVisitor { class FragmentGetter : public StmtExprVisitor {
public: public:
// fragment metadata // fragment metadata
struct FragmentInfo { struct FragmentInfo {
...@@ -48,8 +47,8 @@ class FragmentGetter : public IRVisitor { ...@@ -48,8 +47,8 @@ class FragmentGetter : public IRVisitor {
: m(_m), n(_n), k(_k), layout(_layout) {} : m(_m), n(_n), k(_k), layout(_layout) {}
}; };
void Visit_(const Call* op) final { void VisitExpr_(const Call* op) final {
IRVisitor::Visit_(op); StmtExprVisitor::VisitExpr_(op);
if (op->is_intrinsic(intrinsic::tvm_load_matrix_sync) || if (op->is_intrinsic(intrinsic::tvm_load_matrix_sync) ||
op->is_intrinsic(intrinsic::tvm_store_matrix_sync)) { op->is_intrinsic(intrinsic::tvm_store_matrix_sync)) {
...@@ -116,13 +115,13 @@ class FragmentGetter : public IRVisitor { ...@@ -116,13 +115,13 @@ class FragmentGetter : public IRVisitor {
} }
// Get memory scope // Get memory scope
void Visit_(const AttrStmt* op) final { void VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::storage_scope) { if (op->attr_key == attr::storage_scope) {
const Variable* buffer = op->node.as<Variable>(); const Variable* buffer = op->node.as<Variable>();
CHECK(buffer); CHECK(buffer);
scopes[buffer] = op->value.as<StringImm>()->value; scopes[buffer] = op->value.as<StringImm>()->value;
} }
IRVisitor::Visit_(op); StmtExprVisitor::VisitStmt_(op);
} }
// Memory scope for allocations // Memory scope for allocations
...@@ -132,11 +131,12 @@ class FragmentGetter : public IRVisitor { ...@@ -132,11 +131,12 @@ class FragmentGetter : public IRVisitor {
}; };
// Check shape of fragment making sure it is a valid shape for tvm_mma_sync // Check shape of fragment making sure it is a valid shape for tvm_mma_sync
class FragmentChecker : public IRVisitor { class FragmentChecker : public StmtExprVisitor {
public: public:
explicit FragmentChecker(const FragmentGetter &getter) : fragment_getter(getter) {} explicit FragmentChecker(const FragmentGetter &getter) : fragment_getter(getter) {}
void Visit_(const Call* op) final { void VisitExpr_(const Call* op) final {
StmtExprVisitor::VisitExpr_(op);
// Check shape when calling tvm_mma_sync // Check shape when calling tvm_mma_sync
if (op->is_intrinsic(intrinsic::tvm_mma_sync)) { if (op->is_intrinsic(intrinsic::tvm_mma_sync)) {
CHECK_EQ(op->args.size(), 8U); CHECK_EQ(op->args.size(), 8U);
...@@ -170,12 +170,12 @@ class FragmentChecker : public IRVisitor { ...@@ -170,12 +170,12 @@ class FragmentChecker : public IRVisitor {
}; };
// Store the metadata into attributes // Store the metadata into attributes
class InferFragmenter : public IRMutator { class InferFragmenter : public StmtMutator {
public: public:
explicit InferFragmenter(const FragmentGetter &getter) : fragment_getter(getter) {} explicit InferFragmenter(const FragmentGetter &getter) : fragment_getter(getter) {}
Stmt Mutate_(const Allocate* op, const Stmt& s) final { Stmt VisitStmt_(const Allocate* op) final {
Stmt stmt = IRMutator::Mutate_(op, s); Stmt stmt = StmtMutator::VisitStmt_(op);
const Variable* buffer = op->buffer_var.get(); const Variable* buffer = op->buffer_var.get();
if (fragment_getter.fragments.count(buffer)) { if (fragment_getter.fragments.count(buffer)) {
// Add attribute to fragments allocation // Add attribute to fragments allocation
...@@ -206,9 +206,10 @@ class InferFragmenter : public IRMutator { ...@@ -206,9 +206,10 @@ class InferFragmenter : public IRMutator {
Stmt InferFragment(Stmt stmt) { Stmt InferFragment(Stmt stmt) {
FragmentGetter getter; FragmentGetter getter;
getter.Visit(stmt); getter(stmt);
FragmentChecker(getter).Visit(stmt); FragmentChecker checker(getter);
stmt = InferFragmenter(getter).Mutate(stmt); checker(stmt);
stmt = InferFragmenter(getter)(std::move(stmt));
return stmt; return stmt;
} }
......
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
*/ */
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/packed_func_ext.h> #include <tvm/packed_func_ext.h>
#include <tvm/ir_mutator.h> #include <tvm/ir_functor_ext.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include "../arithmetic/pattern_match.h" #include "../arithmetic/pattern_match.h"
...@@ -32,7 +32,7 @@ namespace ir { ...@@ -32,7 +32,7 @@ namespace ir {
using runtime::PackedFunc; using runtime::PackedFunc;
class CopyIntrinInjector : public IRMutator { class CopyIntrinInjector : public StmtMutator {
public: public:
CopyIntrinInjector(const std::string& pragma_key, CopyIntrinInjector(const std::string& pragma_key,
const PackedFunc& flower_copy_fromto) const PackedFunc& flower_copy_fromto)
...@@ -40,7 +40,7 @@ class CopyIntrinInjector : public IRMutator { ...@@ -40,7 +40,7 @@ class CopyIntrinInjector : public IRMutator {
flower_copy_fromto_(flower_copy_fromto) { flower_copy_fromto_(flower_copy_fromto) {
} }
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { Stmt VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::storage_scope) { if (op->attr_key == attr::storage_scope) {
const Variable* buf = op->node.as<Variable>(); const Variable* buf = op->node.as<Variable>();
storage_scope_[buf] = op->value.as<StringImm>()->value; storage_scope_[buf] = op->value.as<StringImm>()->value;
...@@ -50,7 +50,7 @@ class CopyIntrinInjector : public IRMutator { ...@@ -50,7 +50,7 @@ class CopyIntrinInjector : public IRMutator {
<< "Cannot match copy pattern of " << op->body; << "Cannot match copy pattern of " << op->body;
return ret; return ret;
} }
return IRMutator::Mutate_(op, s); return StmtMutator::VisitStmt_(op);
} }
private: private:
...@@ -193,8 +193,7 @@ class CopyIntrinInjector : public IRMutator { ...@@ -193,8 +193,7 @@ class CopyIntrinInjector : public IRMutator {
Stmt InjectCopyIntrin(Stmt stmt, Stmt InjectCopyIntrin(Stmt stmt,
const std::string& pragma_key, const std::string& pragma_key,
const PackedFunc& flower_copy_fromto) { const PackedFunc& flower_copy_fromto) {
return CopyIntrinInjector(pragma_key, flower_copy_fromto) return CopyIntrinInjector(pragma_key, flower_copy_fromto)(std::move(stmt));
.Mutate(stmt);
} }
} // namespace ir } // namespace ir
......
...@@ -22,8 +22,7 @@ ...@@ -22,8 +22,7 @@
* \file inject_double_buffer.cc * \file inject_double_buffer.cc
*/ */
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h> #include <tvm/ir_functor_ext.h>
#include <tvm/ir_mutator.h>
#include <tvm/expr_operator.h> #include <tvm/expr_operator.h>
#include "ir_util.h" #include "ir_util.h"
#include "../arithmetic/compute_expr.h" #include "../arithmetic/compute_expr.h"
...@@ -32,18 +31,18 @@ namespace tvm { ...@@ -32,18 +31,18 @@ namespace tvm {
namespace ir { namespace ir {
// Detect double buffer variables. // Detect double buffer variables.
class DoubleBufferDetector : public IRVisitor { class DoubleBufferDetector : public StmtExprVisitor {
public: public:
void Visit_(const AttrStmt* op) final { void VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::double_buffer_scope) { if (op->attr_key == attr::double_buffer_scope) {
touched_.insert(op->node.as<Variable>()); touched_.insert(op->node.as<Variable>());
IRVisitor::Visit_(op); StmtExprVisitor::VisitStmt_(op);
} else { } else {
IRVisitor::Visit_(op); StmtExprVisitor::VisitStmt_(op);
} }
} }
void Visit_(const Variable* op) final { void VisitExpr_(const Variable* op) final {
if (touched_.count(op)) { if (touched_.count(op)) {
touched_.erase(op); touched_.erase(op);
} }
...@@ -53,55 +52,55 @@ class DoubleBufferDetector : public IRVisitor { ...@@ -53,55 +52,55 @@ class DoubleBufferDetector : public IRVisitor {
}; };
class StripDoubleBufferWrite : public IRMutator { class StripDoubleBufferWrite : public StmtMutator {
public: public:
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { Stmt VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::double_buffer_write) { if (op->attr_key == attr::double_buffer_write) {
return Mutate(op->body); return VisitStmt(op->body);
} else { } else {
return IRMutator::Mutate_(op, s); return StmtMutator::VisitStmt_(op);
} }
} }
}; };
class DoubleBufferInjector : public IRMutator { class DoubleBufferInjector : public StmtExprMutator {
public: public:
explicit DoubleBufferInjector(int split_loop) explicit DoubleBufferInjector(int split_loop)
: split_loop_(split_loop) {} : split_loop_(split_loop) {}
Stmt Inject(const Stmt& stmt) { Stmt Inject(Stmt stmt) {
DoubleBufferDetector detector; DoubleBufferDetector detector;
detector.Visit(stmt); detector(stmt);
if (detector.touched_.empty()) return stmt; if (detector.touched_.empty()) return stmt;
for (const Variable* v : detector.touched_) { for (const Variable* v : detector.touched_) {
dbuffer_info_[v] = StorageEntry(); dbuffer_info_[v] = StorageEntry();
} }
return ConvertSSA(this->Mutate(stmt)); return ConvertSSA(operator()(std::move(stmt)));
} }
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { Stmt VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::storage_scope) { if (op->attr_key == attr::storage_scope) {
const Variable* buf = op->node.as<Variable>(); const Variable* buf = op->node.as<Variable>();
auto it = dbuffer_info_.find(buf); auto it = dbuffer_info_.find(buf);
if (it != dbuffer_info_.end()) { if (it != dbuffer_info_.end()) {
it->second.scope = op->value.as<StringImm>()->value; it->second.scope = op->value.as<StringImm>()->value;
return Mutate(op->body); return this->VisitStmt(op->body);
} else { } else {
return IRMutator::Mutate_(op, s); return StmtExprMutator::VisitStmt_(op);
} }
} else if (op->attr_key == attr::double_buffer_scope) { } else if (op->attr_key == attr::double_buffer_scope) {
return MakeProducer(op, s); return MakeProducer(op);
} else { } else {
return IRMutator::Mutate_(op, s); return StmtExprMutator::VisitStmt_(op);
} }
} }
Stmt Mutate_(const Allocate* op, const Stmt& s) final { Stmt VisitStmt_(const Allocate* op) final {
auto it = dbuffer_info_.find(op->buffer_var.get()); auto it = dbuffer_info_.find(op->buffer_var.get());
if (it != dbuffer_info_.end()) { if (it != dbuffer_info_.end()) {
it->second.stride = arith::ComputeReduce<Mul>( it->second.stride = arith::ComputeReduce<Mul>(
op->extents, Expr()) * op->dtype.lanes(); op->extents, Expr()) * op->dtype.lanes();
Stmt stmt = IRMutator::Mutate_(op, s); Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<Allocate>(); op = stmt.as<Allocate>();
Array<Expr> new_extents{make_const(op->extents[0].dtype(), 2)}; Array<Expr> new_extents{make_const(op->extents[0].dtype(), 2)};
for (Expr e : op->extents) { for (Expr e : op->extents) {
...@@ -118,13 +117,13 @@ class DoubleBufferInjector : public IRMutator { ...@@ -118,13 +117,13 @@ class DoubleBufferInjector : public IRMutator {
Evaluate::make(0))); Evaluate::make(0)));
return op->body; return op->body;
} else { } else {
return IRMutator::Mutate_(op, s); return StmtExprMutator::VisitStmt_(op);
} }
} }
Stmt Mutate_(const For* op, const Stmt& s) final { Stmt VisitStmt_(const For* op) final {
loop_nest_.push_back(op); loop_nest_.push_back(op);
Stmt stmt = IRMutator::Mutate_(op, s); Stmt stmt = StmtExprMutator::VisitStmt_(op);
auto it = loop_pre_.find(op); auto it = loop_pre_.find(op);
if (it != loop_pre_.end()) { if (it != loop_pre_.end()) {
const For* old_loop = stmt.as<For>(); const For* old_loop = stmt.as<For>();
...@@ -151,7 +150,7 @@ class DoubleBufferInjector : public IRMutator { ...@@ -151,7 +150,7 @@ class DoubleBufferInjector : public IRMutator {
MergeSeq(loop_seq)); MergeSeq(loop_seq));
// tail // tail
std::vector<Stmt> tail_seq; std::vector<Stmt> tail_seq;
Stmt tail_body = StripDoubleBufferWrite().Mutate(old_loop->body); Stmt tail_body = StripDoubleBufferWrite()(old_loop->body);
for (int32_t i = 0; i < split_loop_; ++i) { for (int32_t i = 0; i < split_loop_; ++i) {
Expr idx = tail_base + make_const(tail_base.dtype(), i); Expr idx = tail_base + make_const(tail_base.dtype(), i);
vmap[old_loop->loop_var.get()] = idx; vmap[old_loop->loop_var.get()] = idx;
...@@ -171,8 +170,8 @@ class DoubleBufferInjector : public IRMutator { ...@@ -171,8 +170,8 @@ class DoubleBufferInjector : public IRMutator {
return stmt; return stmt;
} }
Stmt Mutate_(const Store* op, const Stmt& s) final { Stmt VisitStmt_(const Store* op) final {
Stmt stmt = IRMutator::Mutate_(op, s); Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<Store>(); op = stmt.as<Store>();
auto it = dbuffer_info_.find(op->buffer_var.get()); auto it = dbuffer_info_.find(op->buffer_var.get());
if (it != dbuffer_info_.end()) { if (it != dbuffer_info_.end()) {
...@@ -188,8 +187,8 @@ class DoubleBufferInjector : public IRMutator { ...@@ -188,8 +187,8 @@ class DoubleBufferInjector : public IRMutator {
} }
} }
Expr Mutate_(const Load* op, const Expr& e) final { Expr VisitExpr_(const Load* op) final {
Expr expr = IRMutator::Mutate_(op, e); Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Load>(); op = expr.as<Load>();
auto it = dbuffer_info_.find(op->buffer_var.get()); auto it = dbuffer_info_.find(op->buffer_var.get());
if (it != dbuffer_info_.end()) { if (it != dbuffer_info_.end()) {
...@@ -205,20 +204,20 @@ class DoubleBufferInjector : public IRMutator { ...@@ -205,20 +204,20 @@ class DoubleBufferInjector : public IRMutator {
} }
} }
Expr Mutate_(const Variable* op, const Expr& e) final { Expr VisitExpr_(const Variable* op) final {
CHECK(!dbuffer_info_.count(op)); CHECK(!dbuffer_info_.count(op));
return e; return GetRef<Expr>(op);
} }
private: private:
Stmt MakeProducer(const AttrStmt* op, const Stmt& s) { Stmt MakeProducer(const AttrStmt* op) {
const VarExpr buffer = Downcast<VarExpr>(op->node); const VarExpr buffer = Downcast<VarExpr>(op->node);
CHECK_NE(loop_nest_.size(), 0U) CHECK_NE(loop_nest_.size(), 0U)
<< "Double buffer scope must be inside a loop"; << "Double buffer scope must be inside a loop";
auto it = dbuffer_info_.find(buffer.get()); auto it = dbuffer_info_.find(buffer.get());
if (it == dbuffer_info_.end()) { if (it == dbuffer_info_.end()) {
LOG(WARNING) << "Skip double buffer scope " << op->node; LOG(WARNING) << "Skip double buffer scope " << op->node;
return Mutate(op->body); return this->VisitStmt(op->body);
} }
StorageEntry& e = it->second; StorageEntry& e = it->second;
e.loop = loop_nest_.back(); e.loop = loop_nest_.back();
...@@ -230,7 +229,7 @@ class DoubleBufferInjector : public IRMutator { ...@@ -230,7 +229,7 @@ class DoubleBufferInjector : public IRMutator {
e.loop->loop_var.dtype()); e.loop->loop_var.dtype());
e.switch_read_var = indexmod(e.loop->loop_var, two); e.switch_read_var = indexmod(e.loop->loop_var, two);
in_double_buffer_scope_ = true; in_double_buffer_scope_ = true;
Stmt body = Mutate(op->body); Stmt body = this->VisitStmt(op->body);
in_double_buffer_scope_ = false; in_double_buffer_scope_ = false;
std::unordered_map<const Variable*, Expr> vmap; std::unordered_map<const Variable*, Expr> vmap;
vmap[e.switch_write_var.get()] = zero; vmap[e.switch_write_var.get()] = zero;
......
...@@ -22,8 +22,7 @@ ...@@ -22,8 +22,7 @@
*/ */
// Inject prefetch op in HalideIR // Inject prefetch op in HalideIR
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_mutator.h> #include <tvm/ir_functor_ext.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/arithmetic.h> #include <tvm/arithmetic.h>
#include <unordered_set> #include <unordered_set>
...@@ -34,10 +33,10 @@ namespace ir { ...@@ -34,10 +33,10 @@ namespace ir {
using arith::IntSet; using arith::IntSet;
using arith::DomainTouched; using arith::DomainTouched;
class PrefetchInjector : public IRMutator { class PrefetchInjector : public StmtMutator {
public: public:
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { Stmt VisitStmt_(const AttrStmt* op) final {
Stmt ret = IRMutator::Mutate_(op, s); Stmt ret = StmtMutator::VisitStmt_(op);
op = ret.as<AttrStmt>(); op = ret.as<AttrStmt>();
if (op && op->attr_key == attr::prefetch_scope) { if (op && op->attr_key == attr::prefetch_scope) {
Tensor ts = Downcast<Tensor>(op->node); Tensor ts = Downcast<Tensor>(op->node);
...@@ -65,13 +64,13 @@ class PrefetchInjector : public IRMutator { ...@@ -65,13 +64,13 @@ class PrefetchInjector : public IRMutator {
return ret; return ret;
} }
Stmt Mutate_(const For* op, const Stmt& s) final { Stmt VisitStmt_(const For* op) final {
auto &var = op->loop_var; auto &var = op->loop_var;
loop_nest_.push_back(var); loop_nest_.push_back(var);
if (op->for_type == ForType::Vectorized) { if (op->for_type == ForType::Vectorized) {
vectorized_[var.get()] = IntSet::interval(op->min, (op->min + op->extent) - 1); vectorized_[var.get()] = IntSet::interval(op->min, (op->min + op->extent) - 1);
} }
Stmt ret = IRMutator::Mutate_(op, s); Stmt ret = StmtMutator::VisitStmt_(op);
if (op->for_type == ForType::Vectorized) { if (op->for_type == ForType::Vectorized) {
vectorized_.erase(var.get()); vectorized_.erase(var.get());
} }
...@@ -88,7 +87,7 @@ class PrefetchInjector : public IRMutator { ...@@ -88,7 +87,7 @@ class PrefetchInjector : public IRMutator {
const Range PrefetchInjector::none; const Range PrefetchInjector::none;
Stmt InjectPrefetch(Stmt stmt) { Stmt InjectPrefetch(Stmt stmt) {
return PrefetchInjector().Mutate(stmt); return PrefetchInjector()(std::move(stmt));
} }
} // namespace ir } // namespace ir
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -21,8 +21,8 @@ ...@@ -21,8 +21,8 @@
* \file inline.cc * \file inline.cc
*/ */
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/ir_functor_ext.h>
namespace tvm { namespace tvm {
namespace ir { namespace ir {
...@@ -30,13 +30,13 @@ namespace ir { ...@@ -30,13 +30,13 @@ namespace ir {
// inliner to inline a function // inliner to inline a function
// the result may not be SSA, // the result may not be SSA,
// ConvertSSA need to be applied after this pass // ConvertSSA need to be applied after this pass
class IRInline final : public IRMutator { class IRInline final : public StmtExprMutator {
public: public:
IRInline(FunctionRef f, Array<Var> args, Expr body) IRInline(FunctionRef f, Array<Var> args, Expr body)
: f_(f), args_(args), body_(body) {} : f_(f), args_(args), body_(body) {}
Expr Mutate_(const Call* op, const Expr& e) final { Expr VisitExpr_(const Call* op) final {
Expr expr = IRMutator::Mutate_(op, e); Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Call>(); op = expr.as<Call>();
if (op->func == f_) { if (op->func == f_) {
...@@ -78,7 +78,7 @@ Stmt Inline(Stmt stmt, ...@@ -78,7 +78,7 @@ Stmt Inline(Stmt stmt,
Expr body) { Expr body) {
CHECK_EQ(f->num_outputs(), 1) CHECK_EQ(f->num_outputs(), 1)
<< "can only inline output single value operation"; << "can only inline output single value operation";
Stmt ret = IRInline(f, args, body).Mutate(stmt); Stmt ret = IRInline(f, args, body)(std::move(stmt));
if (ret.same_as(stmt)) return ret; if (ret.same_as(stmt)) return ret;
return ConvertSSA(ret); return ConvertSSA(ret);
} }
......
...@@ -21,7 +21,6 @@ ...@@ -21,7 +21,6 @@
*/ */
#include <tvm/ir_functor_ext.h> #include <tvm/ir_functor_ext.h>
#include <tvm/packed_func_ext.h> #include <tvm/packed_func_ext.h>
#include <tvm/ir_visitor.h>
namespace tvm { namespace tvm {
namespace ir { namespace ir {
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
* \file lift_attr_scope.cc * \file lift_attr_scope.cc
*/ */
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h> #include <tvm/ir_functor_ext.h>
#include "ir_util.h" #include "ir_util.h"
namespace tvm { namespace tvm {
...@@ -32,13 +32,13 @@ namespace ir { ...@@ -32,13 +32,13 @@ namespace ir {
// NOTE: this optimization can only be applied // NOTE: this optimization can only be applied
// to a few specified attr keys // to a few specified attr keys
class AttrScopeLifter : public IRMutator { class AttrScopeLifter : public StmtMutator {
public: public:
explicit AttrScopeLifter(std::string attr_key) explicit AttrScopeLifter(std::string attr_key)
: attr_key_(attr_key) {} : attr_key_(attr_key) {}
Stmt Lift(Stmt stmt) { Stmt Lift(Stmt stmt) {
stmt = Mutate(stmt); stmt = operator()(std::move(stmt));
if (attr_node_.defined()) { if (attr_node_.defined()) {
stmt = AttrStmt::make( stmt = AttrStmt::make(
attr_node_, attr_key_, attr_value_, stmt); attr_node_, attr_key_, attr_value_, stmt);
...@@ -47,8 +47,8 @@ class AttrScopeLifter : public IRMutator { ...@@ -47,8 +47,8 @@ class AttrScopeLifter : public IRMutator {
} }
// do not go beyond // do not go beyond
Stmt Mutate_(const Allocate* op, const Stmt& s) final { Stmt VisitStmt_(const Allocate* op) final {
Stmt stmt = IRMutator::Mutate_(op, s); Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<Allocate>(); op = stmt.as<Allocate>();
if (attr_node_.defined()) { if (attr_node_.defined()) {
Stmt body = AttrStmt::make( Stmt body = AttrStmt::make(
...@@ -65,17 +65,17 @@ class AttrScopeLifter : public IRMutator { ...@@ -65,17 +65,17 @@ class AttrScopeLifter : public IRMutator {
} }
} }
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { Stmt VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr_key_) { if (op->attr_key == attr_key_) {
attr_node_ = op->node; attr_node_ = op->node;
attr_value_ = op->value; attr_value_ = op->value;
return op->body; return op->body;
} else { } else {
return IRMutator::Mutate_(op, s); return StmtMutator::VisitStmt_(op);
} }
} }
Stmt Mutate_(const Block* op, const Stmt& s) final { Stmt VisitStmt_(const Block* op) final {
std::vector<Stmt> seq; std::vector<Stmt> seq;
FlattenSeq(op->first, &seq); FlattenSeq(op->first, &seq);
FlattenSeq(op->rest, &seq); FlattenSeq(op->rest, &seq);
...@@ -83,21 +83,21 @@ class AttrScopeLifter : public IRMutator { ...@@ -83,21 +83,21 @@ class AttrScopeLifter : public IRMutator {
if (seq.size() == 2 && if (seq.size() == 2 &&
seq[0].same_as(op->first) && seq[0].same_as(op->first) &&
seq[1].same_as(op->rest)) { seq[1].same_as(op->rest)) {
return s; return GetRef<Stmt>(op);
} }
return MergeSeq(seq); return MergeSeq(seq);
} }
Stmt Mutate_(const IfThenElse* op, const Stmt& s) final { Stmt VisitStmt_(const IfThenElse* op) final {
if (!op->else_case.defined()) { if (!op->else_case.defined()) {
return IRMutator::Mutate_(op, s); return StmtMutator::VisitStmt_(op);
} }
Stmt then_case = this->Mutate(op->then_case); Stmt then_case = this->VisitStmt(op->then_case);
ObjectRef first_node; ObjectRef first_node;
Expr first_value; Expr first_value;
std::swap(first_node, attr_node_); std::swap(first_node, attr_node_);
std::swap(first_value, attr_value_); std::swap(first_value, attr_value_);
Stmt else_case = this->Mutate(op->else_case); Stmt else_case = this->VisitStmt(op->else_case);
if (attr_node_.defined() && if (attr_node_.defined() &&
attr_value_.defined() && attr_value_.defined() &&
first_node.defined() && first_node.defined() &&
...@@ -106,7 +106,7 @@ class AttrScopeLifter : public IRMutator { ...@@ -106,7 +106,7 @@ class AttrScopeLifter : public IRMutator {
ValueSame(attr_value_, first_value)) { ValueSame(attr_value_, first_value)) {
if (then_case.same_as(op->then_case) && if (then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) { else_case.same_as(op->else_case)) {
return s; return GetRef<Stmt>(op);
} else { } else {
return IfThenElse::make(op->condition, then_case, else_case); return IfThenElse::make(op->condition, then_case, else_case);
} }
...@@ -124,7 +124,7 @@ class AttrScopeLifter : public IRMutator { ...@@ -124,7 +124,7 @@ class AttrScopeLifter : public IRMutator {
} }
if (then_case.same_as(op->then_case) && if (then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) { else_case.same_as(op->else_case)) {
return s; return GetRef<Stmt>(op);
} else { } else {
return IfThenElse::make(op->condition, then_case, else_case); return IfThenElse::make(op->condition, then_case, else_case);
} }
...@@ -155,7 +155,7 @@ class AttrScopeLifter : public IRMutator { ...@@ -155,7 +155,7 @@ class AttrScopeLifter : public IRMutator {
for (const Stmt & stmt : seq) { for (const Stmt & stmt : seq) {
attr_node_ = ObjectRef(); attr_node_ = ObjectRef();
attr_value_ = Expr(); attr_value_ = Expr();
Stmt rest = this->Mutate(stmt); Stmt rest = this->VisitStmt(stmt);
if (attr_node_.defined() && if (attr_node_.defined() &&
attr_value_.defined() && attr_value_.defined() &&
curr_node.defined() && curr_node.defined() &&
...@@ -214,7 +214,7 @@ class AttrScopeLifter : public IRMutator { ...@@ -214,7 +214,7 @@ class AttrScopeLifter : public IRMutator {
}; };
Stmt LiftAttrScope(Stmt stmt, std::string attr_key) { Stmt LiftAttrScope(Stmt stmt, std::string attr_key) {
return AttrScopeLifter(attr_key).Lift(stmt); return AttrScopeLifter(attr_key).Lift(std::move(stmt));
} }
} // namespace ir } // namespace ir
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
* \brief Pass for lowering custom datatypes * \brief Pass for lowering custom datatypes
*/ */
#include <tvm/ir_mutator.h> #include <tvm/ir_functor_ext.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/packed_func_ext.h> #include <tvm/packed_func_ext.h>
#include "../codegen/datatype/registry.h" #include "../codegen/datatype/registry.h"
...@@ -37,17 +37,17 @@ namespace ir { ...@@ -37,17 +37,17 @@ namespace ir {
* datatype) for lowering this type of expression, and uses it to lower the * datatype) for lowering this type of expression, and uses it to lower the
* expression. * expression.
*/ */
class CustomDatatypesLowerer : public IRMutator { class CustomDatatypesLowerer : public StmtExprMutator {
public: public:
explicit CustomDatatypesLowerer(const std::string& target) : target_(target) {} explicit CustomDatatypesLowerer(const std::string& target) : target_(target) {}
inline Expr Mutate_(const Cast* op, const Expr& e) final { inline Expr VisitExpr_(const Cast* op) final {
auto type_code = op->dtype.code(); auto type_code = op->dtype.code();
auto src_type_code = op->value.dtype().code(); auto src_type_code = op->value.dtype().code();
// If either datatype is a registered custom datatype, we must lower. // If either datatype is a registered custom datatype, we must lower.
bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code) || bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code) ||
datatype::Registry::Global()->GetTypeRegistered(src_type_code); datatype::Registry::Global()->GetTypeRegistered(src_type_code);
Expr expr = IRMutator::Mutate_(op, e); Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Cast>(); op = expr.as<Cast>();
if (toBeLowered) { if (toBeLowered) {
auto lower = datatype::GetCastLowerFunc(target_, type_code, src_type_code); auto lower = datatype::GetCastLowerFunc(target_, type_code, src_type_code);
...@@ -59,8 +59,9 @@ class CustomDatatypesLowerer : public IRMutator { ...@@ -59,8 +59,9 @@ class CustomDatatypesLowerer : public IRMutator {
return expr; return expr;
} }
inline Expr Mutate_(const FloatImm* imm, const Expr& e) final { inline Expr VisitExpr_(const FloatImm* imm) final {
auto type_code = imm->dtype.code(); auto type_code = imm->dtype.code();
auto e = GetRef<Expr>(imm);
if (datatype::Registry::Global()->GetTypeRegistered(type_code)) { if (datatype::Registry::Global()->GetTypeRegistered(type_code)) {
auto lower = datatype::GetFloatImmLowerFunc(target_, type_code); auto lower = datatype::GetFloatImmLowerFunc(target_, type_code);
CHECK(lower) << "FloatImm lowering function for target " << target_ << " type " CHECK(lower) << "FloatImm lowering function for target " << target_ << " type "
...@@ -70,9 +71,9 @@ class CustomDatatypesLowerer : public IRMutator { ...@@ -70,9 +71,9 @@ class CustomDatatypesLowerer : public IRMutator {
return e; return e;
} }
inline Stmt Mutate_(const Allocate* allocate, const Stmt& s) final { inline Stmt VisitStmt_(const Allocate* allocate) final {
bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(allocate->dtype.code()); bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(allocate->dtype.code());
Stmt stmt = IRMutator::Mutate_(allocate, s); Stmt stmt = StmtExprMutator::VisitStmt_(allocate);
allocate = stmt.as<Allocate>(); allocate = stmt.as<Allocate>();
if (toBeLowered) { if (toBeLowered) {
...@@ -84,9 +85,9 @@ class CustomDatatypesLowerer : public IRMutator { ...@@ -84,9 +85,9 @@ class CustomDatatypesLowerer : public IRMutator {
return stmt; return stmt;
} }
inline Expr Mutate_(const Load* load, const Expr& e) final { inline Expr VisitExpr_(const Load* load) final {
bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(load->dtype.code()); bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(load->dtype.code());
Expr expr = IRMutator::Mutate_(load, e); Expr expr = StmtExprMutator::VisitExpr_(load);
load = expr.as<Load>(); load = expr.as<Load>();
if (toBeLowered) { if (toBeLowered) {
auto new_load_type = DataType::UInt(load->dtype.bits()); auto new_load_type = DataType::UInt(load->dtype.bits());
...@@ -96,10 +97,10 @@ class CustomDatatypesLowerer : public IRMutator { ...@@ -96,10 +97,10 @@ class CustomDatatypesLowerer : public IRMutator {
} }
#define DEFINE_MUTATE__(OP) \ #define DEFINE_MUTATE__(OP) \
inline Expr Mutate_(const OP* op, const Expr& e) final { \ inline Expr VisitExpr_(const OP* op) final { \
auto type_code = op->dtype.code(); \ auto type_code = op->dtype.code(); \
bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code); \ bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code); \
Expr expr = IRMutator::Mutate_(op, e); \ Expr expr = StmtExprMutator::VisitExpr_(op); \
op = expr.as<OP>(); \ op = expr.as<OP>(); \
if (toBeLowered) { \ if (toBeLowered) { \
auto lower = datatype::Get##OP##LowerFunc(target_, type_code); \ auto lower = datatype::Get##OP##LowerFunc(target_, type_code); \
...@@ -131,7 +132,7 @@ class CustomDatatypesLowerer : public IRMutator { ...@@ -131,7 +132,7 @@ class CustomDatatypesLowerer : public IRMutator {
LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target) { LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target) {
auto n = make_object<LoweredFuncNode>(*f.operator->()); auto n = make_object<LoweredFuncNode>(*f.operator->());
n->body = CustomDatatypesLowerer(target).Mutate(n->body); n->body = CustomDatatypesLowerer(target)(n->body);
return LoweredFunc(n); return LoweredFunc(n);
} }
......
...@@ -22,7 +22,6 @@ ...@@ -22,7 +22,6 @@
* \file lower_intrin.cc * \file lower_intrin.cc
*/ */
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/api_registry.h> #include <tvm/api_registry.h>
#include <tvm/expr_operator.h> #include <tvm/expr_operator.h>
...@@ -34,9 +33,10 @@ ...@@ -34,9 +33,10 @@
namespace tvm { namespace tvm {
namespace ir { namespace ir {
class IntrinInjecter : public arith::IRMutatorWithAnalyzer { class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
public: public:
using IRMutatorWithAnalyzer::Mutate_; using IRMutatorWithAnalyzer::VisitStmt_;
using IRMutatorWithAnalyzer::VisitExpr_;
IntrinInjecter(arith::Analyzer* analyzer, std::string target) IntrinInjecter(arith::Analyzer* analyzer, std::string target)
: IRMutatorWithAnalyzer(analyzer) { : IRMutatorWithAnalyzer(analyzer) {
...@@ -51,28 +51,29 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer { ...@@ -51,28 +51,29 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
} }
} }
Expr Mutate_(const Call* op, const Expr& e) final { Expr VisitExpr_(const Call* op) final {
if (op->call_type == Call::Intrinsic || if (op->call_type == Call::Intrinsic ||
op->call_type == Call::PureIntrinsic) { op->call_type == Call::PureIntrinsic) {
Expr r = ApplyPattern(op->name, e); Expr r = ApplyPattern(op->name, GetRef<Expr>(op));
if (r.defined()) return r; if (r.defined()) return r;
} }
return IRMutator::Mutate_(op, e); return IRMutatorWithAnalyzer::VisitExpr_(op);
} }
Expr Mutate_(const Add* op, const Expr& e) final { Expr VisitExpr_(const Add* op) final {
if (const Mul* mb = op->b.as<Mul>()) { if (const Mul* mb = op->b.as<Mul>()) {
return MakeFMA(mb->a, mb->b, op->a, op, e); return MakeFMA(mb->a, mb->b, op->a, op);
} else if (const Mul* ma = op->a.as<Mul>()) { } else if (const Mul* ma = op->a.as<Mul>()) {
return MakeFMA(ma->a, ma->b, op->b, op, e); return MakeFMA(ma->a, ma->b, op->b, op);
} }
return IRMutator::Mutate_(op, e); return IRMutatorWithAnalyzer::VisitExpr_(op);
} }
// We use floordiv for integer analysis, // We use floordiv for integer analysis,
// but will need to lower them to native truncdiv instructions // but will need to lower them to native truncdiv instructions
Expr Mutate_(const FloorDiv* op, const Expr& e) final { Expr VisitExpr_(const FloorDiv* op) final {
Expr ret = IRMutatorWithAnalyzer::Mutate_(op, e); auto e = GetRef<Expr>(op);
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<FloorDiv>(); op = ret.as<FloorDiv>();
if (op == nullptr) return ret; if (op == nullptr) return ret;
int shift; int shift;
...@@ -117,8 +118,8 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer { ...@@ -117,8 +118,8 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
} }
} }
Expr Mutate_(const FloorMod* op, const Expr& e) final { Expr VisitExpr_(const FloorMod* op) final {
Expr ret = IRMutatorWithAnalyzer::Mutate_(op, e); Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<FloorMod>(); op = ret.as<FloorMod>();
if (op == nullptr) return ret; if (op == nullptr) return ret;
// Lower floordiv to native truncdiv. // Lower floordiv to native truncdiv.
...@@ -167,34 +168,37 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer { ...@@ -167,34 +168,37 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
} }
} }
Expr Mutate_(const Max* op, const Expr& e) final { Expr VisitExpr_(const Max* op) final {
using namespace arith; using namespace arith;
PVar<Expr> x, y; PVar<Expr> x, y;
PVar<Integer> c; PVar<Integer> c;
auto e = GetRef<Expr>(op);
if (max(floordiv(x, y), c).Match(e) && if (max(floordiv(x, y), c).Match(e) &&
c.Eval()->value >= 0 && c.Eval()->value >= 0 &&
analyzer_->CanProveGreaterEqual(y.Eval(), 0)) { analyzer_->CanProveGreaterEqual(y.Eval(), 0)) {
return max(Mutate(truncdiv(x, y).Eval()), c.Eval()); return max(VisitExpr(truncdiv(x, y).Eval()), c.Eval());
} }
return IRMutatorWithAnalyzer::Mutate_(op, e); return IRMutatorWithAnalyzer::VisitExpr_(op);
} }
Expr Mutate_(const EQ* op, const Expr& e) final { Expr VisitExpr_(const EQ* op) final {
using namespace arith; using namespace arith;
PVar<Expr> x, y; PVar<Expr> x, y;
auto e = GetRef<Expr>(op);
if ((floormod(x, y) == 0).Match(e)) { if ((floormod(x, y) == 0).Match(e)) {
return Mutate((truncmod(x, y) == 0).Eval()); return VisitExpr((truncmod(x, y) == 0).Eval());
} }
return IRMutatorWithAnalyzer::Mutate_(op, e); return IRMutatorWithAnalyzer::VisitExpr_(op);
} }
Expr Mutate_(const NE* op, const Expr& e) final { Expr VisitExpr_(const NE* op) final {
using namespace arith; using namespace arith;
PVar<Expr> x, y; PVar<Expr> x, y;
auto e = GetRef<Expr>(op);
if ((floormod(x, y) != 0).Match(e)) { if ((floormod(x, y) != 0).Match(e)) {
return Mutate((truncmod(x, y) != 0).Eval()); return VisitExpr((truncmod(x, y) != 0).Eval());
} }
return IRMutatorWithAnalyzer::Mutate_(op, e); return IRMutatorWithAnalyzer::VisitExpr_(op);
} }
private: private:
...@@ -231,7 +235,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer { ...@@ -231,7 +235,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
} }
Expr MakeFMA(const Expr& a, const Expr& b, const Expr& c, Expr MakeFMA(const Expr& a, const Expr& b, const Expr& c,
const Add* op, const Expr& e) { const Add* op) {
// emit fma instruction: a * b + c // emit fma instruction: a * b + c
Expr lhs = SwapBroadcastCast(a); Expr lhs = SwapBroadcastCast(a);
Expr rhs = SwapBroadcastCast(b); Expr rhs = SwapBroadcastCast(b);
...@@ -239,14 +243,14 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer { ...@@ -239,14 +243,14 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
if (fma_ != nullptr && op->dtype.is_float()) { if (fma_ != nullptr && op->dtype.is_float()) {
Expr r = (*fma_)(Call::make( Expr r = (*fma_)(Call::make(
op->dtype, "fma", {lhs, rhs, c}, Call::PureIntrinsic)); op->dtype, "fma", {lhs, rhs, c}, Call::PureIntrinsic));
if (r.defined()) return this->Mutate(r); if (r.defined()) return this->VisitExpr(r);
} else { } else {
if (!lhs.same_as(a) || !rhs.same_as(b)) { if (!lhs.same_as(a) || !rhs.same_as(b)) {
Expr mul = this->Mutate(Mul::make(lhs, rhs)); Expr mul = this->VisitExpr(Mul::make(lhs, rhs));
return Add::make(mul, this->Mutate(c)); return Add::make(mul, this->VisitExpr(c));
} }
} }
return IRMutator::Mutate_(op, e); return IRMutatorWithAnalyzer::VisitExpr_(op);
} }
Expr ApplyPattern(const std::string& name, const Expr& e) { Expr ApplyPattern(const std::string& name, const Expr& e) {
...@@ -262,7 +266,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer { ...@@ -262,7 +266,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
Expr r = (*f)(e); Expr r = (*f)(e);
CHECK(r.defined()) << "intrinsic rule must always return valid Expr"; CHECK(r.defined()) << "intrinsic rule must always return valid Expr";
if (!r.same_as(e)) { if (!r.same_as(e)) {
return this->Mutate(r); return this->VisitExpr(r);
} }
} }
} }
...@@ -277,7 +281,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer { ...@@ -277,7 +281,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
Stmt LowerIntrinStmt(Stmt stmt, const std::string& target) { Stmt LowerIntrinStmt(Stmt stmt, const std::string& target) {
arith::Analyzer analyzer; arith::Analyzer analyzer;
return IntrinInjecter(&analyzer, target).Mutate(stmt); return IntrinInjecter(&analyzer, target)(std::move(stmt));
} }
LoweredFunc LoweredFunc
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
* \file lower_thread_allreduce.cc * \file lower_thread_allreduce.cc
*/ */
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_mutator.h> #include <tvm/ir_functor_ext.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <unordered_set> #include <unordered_set>
#include "ir_util.h" #include "ir_util.h"
...@@ -32,19 +32,19 @@ ...@@ -32,19 +32,19 @@
namespace tvm { namespace tvm {
namespace ir { namespace ir {
class ThreadAllreduceBuilder final : public IRMutator { class ThreadAllreduceBuilder final : public StmtExprMutator {
public: public:
explicit ThreadAllreduceBuilder(int warp_size) explicit ThreadAllreduceBuilder(int warp_size)
: warp_size_(warp_size) {} : warp_size_(warp_size) {}
Stmt Mutate_(const AttrStmt *op, const Stmt& s) final { Stmt VisitStmt_(const AttrStmt *op) final {
if (op->attr_key == attr::thread_extent) { if (op->attr_key == attr::thread_extent) {
thread_extents_.push_back(op); thread_extents_.push_back(op);
Stmt ret = IRMutator::Mutate_(op, s); Stmt ret = StmtExprMutator::VisitStmt_(op);
thread_extents_.pop_back(); thread_extents_.pop_back();
return ret; return ret;
} else if (op->attr_key == attr::storage_scope) { } else if (op->attr_key == attr::storage_scope) {
Stmt ret = IRMutator::Mutate_(op, s); Stmt ret = StmtExprMutator::VisitStmt_(op);
op = ret.as<AttrStmt>(); op = ret.as<AttrStmt>();
const Variable* v = op->node.as<Variable>(); const Variable* v = op->node.as<Variable>();
if (alloc_remap_.count(v)) { if (alloc_remap_.count(v)) {
...@@ -56,15 +56,15 @@ class ThreadAllreduceBuilder final : public IRMutator { ...@@ -56,15 +56,15 @@ class ThreadAllreduceBuilder final : public IRMutator {
const CommReducerNode *combiner = op->node.as<CommReducerNode>(); const CommReducerNode *combiner = op->node.as<CommReducerNode>();
CHECK(combiner); CHECK(combiner);
reduce_combiner_.push_back(combiner); reduce_combiner_.push_back(combiner);
Stmt ret = IRMutator::Mutate_(op, s); Stmt ret = StmtExprMutator::VisitStmt_(op);
reduce_combiner_.pop_back(); reduce_combiner_.pop_back();
return ret; return ret;
} else { } else {
return IRMutator::Mutate_(op, s); return StmtExprMutator::VisitStmt_(op);
} }
} }
Stmt Mutate_(const Evaluate* op, const Stmt& s) final { Stmt VisitStmt_(const Evaluate* op) final {
Stmt stmt = IRMutator::Mutate_(op, s); Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<Evaluate>(); op = stmt.as<Evaluate>();
const Call* call = op->value.as<Call>(); const Call* call = op->value.as<Call>();
if (call && call->is_intrinsic(intrinsic::tvm_thread_allreduce)) { if (call && call->is_intrinsic(intrinsic::tvm_thread_allreduce)) {
...@@ -73,8 +73,8 @@ class ThreadAllreduceBuilder final : public IRMutator { ...@@ -73,8 +73,8 @@ class ThreadAllreduceBuilder final : public IRMutator {
return stmt; return stmt;
} }
} }
Stmt Mutate_(const Allocate* op, const Stmt& s) final { Stmt VisitStmt_(const Allocate* op) final {
Stmt stmt = IRMutator::Mutate_(op, s); Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<Allocate>(); op = stmt.as<Allocate>();
auto it = alloc_remap_.find(op->buffer_var.get()); auto it = alloc_remap_.find(op->buffer_var.get());
if (it != alloc_remap_.end()) { if (it != alloc_remap_.end()) {
...@@ -93,13 +93,13 @@ class ThreadAllreduceBuilder final : public IRMutator { ...@@ -93,13 +93,13 @@ class ThreadAllreduceBuilder final : public IRMutator {
return stmt; return stmt;
} }
} }
Expr Mutate_(const Load* op, const Expr& e) final { Expr VisitExpr_(const Load* op) final {
auto it = load_remap_.find(op->buffer_var.get()); auto it = load_remap_.find(op->buffer_var.get());
if (it != load_remap_.end()) { if (it != load_remap_.end()) {
CHECK(is_zero(op->index)) << e; CHECK(is_zero(op->index));
return it->second; return it->second;
} else { } else {
return IRMutator::Mutate_(op, e); return StmtExprMutator::VisitExpr_(op);
} }
} }
...@@ -339,7 +339,7 @@ LoweredFunc ...@@ -339,7 +339,7 @@ LoweredFunc
LowerThreadAllreduce(LoweredFunc f, int warp_size) { LowerThreadAllreduce(LoweredFunc f, int warp_size) {
CHECK_NE(f->func_type, kHostFunc); CHECK_NE(f->func_type, kHostFunc);
auto n = make_object<LoweredFuncNode>(*f.operator->()); auto n = make_object<LoweredFuncNode>(*f.operator->());
n->body = ThreadAllreduceBuilder(warp_size).Mutate(n->body); n->body = ThreadAllreduceBuilder(warp_size)(n->body);
return LoweredFunc(n); return LoweredFunc(n);
} }
} // namespace ir } // namespace ir
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
* \file lower_tvm_buildin.cc * \file lower_tvm_buildin.cc
*/ */
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_mutator.h> #include <tvm/ir_functor_ext.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <unordered_set> #include <unordered_set>
#include "ir_util.h" #include "ir_util.h"
...@@ -43,14 +43,14 @@ inline Expr StackAlloca(std::string type, size_t num) { ...@@ -43,14 +43,14 @@ inline Expr StackAlloca(std::string type, size_t num) {
// Calculate the statistics of packed function. // Calculate the statistics of packed function.
// These information are needed during codegen. // These information are needed during codegen.
class BuiltinLower : public IRMutator { class BuiltinLower : public StmtExprMutator {
public: public:
Stmt Build(Stmt stmt) { Stmt Build(Stmt stmt) {
stack_shape_ = Var("stack_shape", DataType::Handle()); stack_shape_ = Var("stack_shape", DataType::Handle());
stack_array_ = Var("stack_array", DataType::Handle()); stack_array_ = Var("stack_array", DataType::Handle());
stack_value_ = Var("stack_value", DataType::Handle()); stack_value_ = Var("stack_value", DataType::Handle());
stack_tcode_ = Var("stack_tcode", DataType::Handle()); stack_tcode_ = Var("stack_tcode", DataType::Handle());
stmt = this->Mutate(stmt); stmt = this->VisitStmt(stmt);
if (max_shape_stack_ != 0) { if (max_shape_stack_ != 0) {
stmt = LetStmt::make( stmt = LetStmt::make(
stack_shape_, StackAlloca("shape", max_shape_stack_), stmt); stack_shape_, StackAlloca("shape", max_shape_stack_), stmt);
...@@ -68,8 +68,8 @@ class BuiltinLower : public IRMutator { ...@@ -68,8 +68,8 @@ class BuiltinLower : public IRMutator {
return stmt; return stmt;
} }
Stmt Mutate(Stmt stmt) final { Stmt VisitStmt(const Stmt& s) final {
stmt = IRMutator::Mutate(stmt); auto stmt = StmtExprMutator::VisitStmt(s);
CHECK_EQ(run_shape_stack_, 0); CHECK_EQ(run_shape_stack_, 0);
CHECK_EQ(run_array_stack_, 0); CHECK_EQ(run_array_stack_, 0);
while (prep_seq_.size() != 0) { while (prep_seq_.size() != 0) {
...@@ -79,9 +79,9 @@ class BuiltinLower : public IRMutator { ...@@ -79,9 +79,9 @@ class BuiltinLower : public IRMutator {
return stmt; return stmt;
} }
Stmt Mutate_(const Allocate* op, const Stmt& s) { Stmt VisitStmt_(const Allocate* op) {
// Lower allocate to device allocate when needed. // Lower allocate to device allocate when needed.
Stmt stmt = IRMutator::Mutate_(op, s); Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<Allocate>(); op = stmt.as<Allocate>();
if (op->new_expr.defined()) return stmt; if (op->new_expr.defined()) return stmt;
// Get constant allocation bound. // Get constant allocation bound.
...@@ -141,39 +141,39 @@ class BuiltinLower : public IRMutator { ...@@ -141,39 +141,39 @@ class BuiltinLower : public IRMutator {
return body; return body;
} }
Stmt Mutate_(const AttrStmt* op, const Stmt &s) final { Stmt VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::device_context_id) { if (op->attr_key == attr::device_context_id) {
CHECK(!device_id_.defined()); CHECK(!device_id_.defined());
device_id_ = op->value; device_id_ = op->value;
return Mutate(op->body); return this->VisitStmt(op->body);
} else if (op->attr_key == attr::device_context_type) { } else if (op->attr_key == attr::device_context_type) {
CHECK(!device_type_.defined()); CHECK(!device_type_.defined());
device_type_ = op->value; device_type_ = op->value;
return Mutate(op->body); return this->VisitStmt(op->body);
} else { } else {
return IRMutator::Mutate_(op, s); return StmtExprMutator::VisitStmt_(op);
} }
} }
Expr Mutate_(const Call* op, const Expr &e) final { Expr VisitExpr_(const Call* op) final {
if (op->is_intrinsic(intrinsic::tvm_call_packed)) { if (op->is_intrinsic(intrinsic::tvm_call_packed)) {
return MakeCallPacked(op, e); return MakeCallPacked(op);
} else if (op->is_intrinsic(intrinsic::tvm_call_trace_packed)) { } else if (op->is_intrinsic(intrinsic::tvm_call_trace_packed)) {
return MakeCallTracePacked(op, e); return MakeCallTracePacked(op);
} else if (op->is_intrinsic(intrinsic::tvm_stack_make_shape)) { } else if (op->is_intrinsic(intrinsic::tvm_stack_make_shape)) {
return MakeShape(op, e); return MakeShape(op);
} else if (op->is_intrinsic(intrinsic::tvm_stack_make_array)) { } else if (op->is_intrinsic(intrinsic::tvm_stack_make_array)) {
return MakeArray(op, e); return MakeArray(op);
} else if (op->is_intrinsic(intrinsic::tvm_context_id)) { } else if (op->is_intrinsic(intrinsic::tvm_context_id)) {
return make_zero(op->dtype); return make_zero(op->dtype);
} else { } else {
return IRMutator::Mutate_(op, e); return StmtExprMutator::VisitExpr_(op);
} }
} }
// call shape // call shape
Expr MakeShape(const Call* op, const Expr& e) { Expr MakeShape(const Call* op) {
size_t stack_begin = run_shape_stack_; size_t stack_begin = run_shape_stack_;
run_shape_stack_ += op->args.size(); run_shape_stack_ += op->args.size();
Expr expr = IRMutator::Mutate_(op, e); Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Call>(); op = expr.as<Call>();
for (size_t i = 0; i < op->args.size(); ++i) { for (size_t i = 0; i < op->args.size(); ++i) {
prep_seq_.emplace_back( prep_seq_.emplace_back(
...@@ -183,10 +183,10 @@ class BuiltinLower : public IRMutator { ...@@ -183,10 +183,10 @@ class BuiltinLower : public IRMutator {
return AddressOffset(stack_shape_, DataType::Int(64), stack_begin); return AddressOffset(stack_shape_, DataType::Int(64), stack_begin);
} }
// make array // make array
Expr MakeArray(const Call* op, const Expr& e) { Expr MakeArray(const Call* op) {
size_t idx = run_array_stack_; size_t idx = run_array_stack_;
run_array_stack_ += 1; run_array_stack_ += 1;
Expr expr = IRMutator::Mutate_(op, e); Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Call>(); op = expr.as<Call>();
prep_seq_.emplace_back( prep_seq_.emplace_back(
TVMStructSet(stack_array_, idx, intrinsic::kArrData, op->args[0])); TVMStructSet(stack_array_, idx, intrinsic::kArrData, op->args[0]));
...@@ -230,13 +230,13 @@ class BuiltinLower : public IRMutator { ...@@ -230,13 +230,13 @@ class BuiltinLower : public IRMutator {
return TVMStructGet(DataType::Handle(), stack_array_, idx, intrinsic::kArrAddr); return TVMStructGet(DataType::Handle(), stack_array_, idx, intrinsic::kArrAddr);
} }
// call packed. // call packed.
Expr MakeCallPacked(const Call* op, const Expr& e) { Expr MakeCallPacked(const Call* op) {
size_t restore_shape_stack = run_shape_stack_; size_t restore_shape_stack = run_shape_stack_;
size_t restore_array_stack = run_array_stack_; size_t restore_array_stack = run_array_stack_;
size_t arg_stack_begin = run_arg_stack_; size_t arg_stack_begin = run_arg_stack_;
run_arg_stack_ += op->args.size(); run_arg_stack_ += op->args.size();
// Specially handle the buffer packed intrinsic // Specially handle the buffer packed intrinsic
Expr expr = IRMutator::Mutate_(op, e); Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Call>(); op = expr.as<Call>();
for (size_t i = 1; i < op->args.size(); ++i) { for (size_t i = 1; i < op->args.size(); ++i) {
Expr stack_index = ConstInt32(arg_stack_begin + i - 1); Expr stack_index = ConstInt32(arg_stack_begin + i - 1);
...@@ -278,14 +278,14 @@ class BuiltinLower : public IRMutator { ...@@ -278,14 +278,14 @@ class BuiltinLower : public IRMutator {
packed_args, Call::Intrinsic); packed_args, Call::Intrinsic);
} }
Expr MakeCallTracePacked(const Call *op, const Expr &e) { Expr MakeCallTracePacked(const Call *op) {
size_t restore_shape_stack = run_shape_stack_; size_t restore_shape_stack = run_shape_stack_;
size_t restore_array_stack = run_array_stack_; size_t restore_array_stack = run_array_stack_;
size_t arg_stack_begin = run_arg_stack_; size_t arg_stack_begin = run_arg_stack_;
run_arg_stack_ += op->args.size(); run_arg_stack_ += op->args.size();
size_t args_size = op->args.size(); size_t args_size = op->args.size();
CHECK_GT(args_size, 0); CHECK_GT(args_size, 0);
Expr expr = IRMutator::Mutate_(op, e); Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Call>(); op = expr.as<Call>();
for (size_t i = 1; i < op->args.size(); ++i) { for (size_t i = 1; i < op->args.size(); ++i) {
Expr stack_index = ConstInt32(arg_stack_begin + i - 1); Expr stack_index = ConstInt32(arg_stack_begin + i - 1);
......
...@@ -26,8 +26,7 @@ ...@@ -26,8 +26,7 @@
// Thanks to Andrew Adams and Vinod Grover for // Thanks to Andrew Adams and Vinod Grover for
// explaining the concept of warp shuffle. // explaining the concept of warp shuffle.
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_mutator.h> #include <tvm/ir_functor_ext.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <unordered_set> #include <unordered_set>
#include "ir_util.h" #include "ir_util.h"
...@@ -75,7 +74,7 @@ namespace ir { ...@@ -75,7 +74,7 @@ namespace ir {
// Visitor to find m in pattern // Visitor to find m in pattern
// store warp_mem[m * warp_index + (warp_size * m) * y + x] // store warp_mem[m * warp_index + (warp_size * m) * y + x]
class WarpStoreCoeffFinder : private IRVisitor { class WarpStoreCoeffFinder : private StmtVisitor {
public: public:
WarpStoreCoeffFinder(const Variable* buffer, WarpStoreCoeffFinder(const Variable* buffer,
Var warp_index, Var warp_index,
...@@ -86,13 +85,13 @@ class WarpStoreCoeffFinder : private IRVisitor { ...@@ -86,13 +85,13 @@ class WarpStoreCoeffFinder : private IRVisitor {
} }
// find the warp co-efficient in the statement given the warp size // find the warp co-efficient in the statement given the warp size
int Find(const Stmt& stmt) { int Find(const Stmt& stmt) {
this->Visit(stmt); this->VisitStmt(stmt);
return warp_coeff_; return warp_coeff_;
} }
private: private:
/// Visitor implementation /// Visitor implementation
void Visit_(const Store *op) final { void VisitStmt_(const Store *op) final {
if (op->buffer_var.get() == buffer_) { if (op->buffer_var.get() == buffer_) {
if (op->value.dtype().lanes() == 1) { if (op->value.dtype().lanes() == 1) {
UpdatePattern(op->index); UpdatePattern(op->index);
...@@ -104,7 +103,7 @@ class WarpStoreCoeffFinder : private IRVisitor { ...@@ -104,7 +103,7 @@ class WarpStoreCoeffFinder : private IRVisitor {
UpdatePattern(base); UpdatePattern(base);
} }
} else { } else {
IRVisitor::Visit_(op); StmtVisitor::VisitStmt_(op);
} }
} }
...@@ -141,14 +140,14 @@ class WarpStoreCoeffFinder : private IRVisitor { ...@@ -141,14 +140,14 @@ class WarpStoreCoeffFinder : private IRVisitor {
// Visitor to find the warp index // Visitor to find the warp index
class WarpIndexFinder : private IRVisitor { class WarpIndexFinder : private StmtVisitor {
public: public:
explicit WarpIndexFinder(int warp_size) explicit WarpIndexFinder(int warp_size)
: warp_size_(warp_size) { : warp_size_(warp_size) {
} }
// find the warp co-efficient in the statement given the warp size // find the warp co-efficient in the statement given the warp size
IterVar Find(const Stmt& stmt) { IterVar Find(const Stmt& stmt) {
this->Visit(stmt); this->VisitStmt(stmt);
CHECK(warp_index_.defined()) CHECK(warp_index_.defined())
<< "Cannot find warp index(threadIdx.x) within the scope of warp memory"; << "Cannot find warp index(threadIdx.x) within the scope of warp memory";
return warp_index_; return warp_index_;
...@@ -156,7 +155,7 @@ class WarpIndexFinder : private IRVisitor { ...@@ -156,7 +155,7 @@ class WarpIndexFinder : private IRVisitor {
private: private:
/// Visitor implementation /// Visitor implementation
void Visit_(const AttrStmt *op) final { void VisitStmt_(const AttrStmt *op) final {
if (op->attr_key == attr::thread_extent) { if (op->attr_key == attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node); IterVar iv = Downcast<IterVar>(op->node);
if (iv->thread_tag == "threadIdx.x") { if (iv->thread_tag == "threadIdx.x") {
...@@ -177,7 +176,7 @@ class WarpIndexFinder : private IRVisitor { ...@@ -177,7 +176,7 @@ class WarpIndexFinder : private IRVisitor {
} }
} }
} }
IRVisitor::Visit_(op); StmtVisitor::VisitStmt_(op);
} }
// warp size // warp size
int warp_size_{0}; int warp_size_{0};
...@@ -185,13 +184,13 @@ class WarpIndexFinder : private IRVisitor { ...@@ -185,13 +184,13 @@ class WarpIndexFinder : private IRVisitor {
IterVar warp_index_{nullptr}; IterVar warp_index_{nullptr};
}; };
// Mutator to change the read pattern // Mutator to change the read pattern
class WarpAccessRewriter : protected IRMutator { class WarpAccessRewriter : protected StmtExprMutator {
public: public:
explicit WarpAccessRewriter(int warp_size, arith::Analyzer* analyzer) explicit WarpAccessRewriter(int warp_size, arith::Analyzer* analyzer)
: warp_size_(warp_size), analyzer_(analyzer) {} : warp_size_(warp_size), analyzer_(analyzer) {}
// Rewrite the allocate statement which transforms // Rewrite the allocate statement which transforms
// warp memory to local memory. // warp memory to local memory.
Stmt Rewrite(const Allocate* op, const Stmt& stmt) { Stmt Rewrite(const Allocate* op) {
buffer_ = op->buffer_var.get(); buffer_ = op->buffer_var.get();
int alloc_size = op->constant_allocation_size(); int alloc_size = op->constant_allocation_size();
CHECK_GT(alloc_size, 0) CHECK_GT(alloc_size, 0)
...@@ -208,27 +207,27 @@ class WarpAccessRewriter : protected IRMutator { ...@@ -208,27 +207,27 @@ class WarpAccessRewriter : protected IRMutator {
op->dtype, op->dtype,
{make_const(DataType::Int(32), alloc_size / warp_size_)}, {make_const(DataType::Int(32), alloc_size / warp_size_)},
op->condition, op->condition,
this->Mutate(op->body)); this->VisitStmt(op->body));
} }
protected: protected:
Expr Mutate_(const Variable* op, const Expr& expr) { Expr Mutate_(const Variable* op) {
CHECK(op != buffer_) CHECK(op != buffer_)
<< "Cannot access address of warp memory directly"; << "Cannot access address of warp memory directly";
return IRMutator::Mutate_(op, expr); return StmtExprMutator::VisitExpr_(op);
} }
Stmt Mutate_(const Store* op, const Stmt& stmt) { Stmt VisitStmt_(const Store* op) {
if (op->buffer_var.get() == buffer_) { if (op->buffer_var.get() == buffer_) {
Expr local_index, group; Expr local_index, group;
std::tie(local_index, group) = SplitIndexByGroup(op->index); std::tie(local_index, group) = SplitIndexByGroup(op->index);
return Store::make(op->buffer_var, op->value, local_index, op->predicate); return Store::make(op->buffer_var, op->value, local_index, op->predicate);
} else { } else {
return IRMutator::Mutate_(op, stmt); return StmtExprMutator::VisitStmt_(op);
} }
} }
Expr Mutate_(const Load* op, const Expr& expr) { Expr Mutate_(const Load* op) {
if (op->buffer_var.get() == buffer_) { if (op->buffer_var.get() == buffer_) {
Expr local_index, group; Expr local_index, group;
std::tie(local_index, group) = SplitIndexByGroup(op->index); std::tie(local_index, group) = SplitIndexByGroup(op->index);
...@@ -243,7 +242,7 @@ class WarpAccessRewriter : protected IRMutator { ...@@ -243,7 +242,7 @@ class WarpAccessRewriter : protected IRMutator {
{load_value, group}, {load_value, group},
Call::Intrinsic); Call::Intrinsic);
} else { } else {
return IRMutator::Mutate_(op, expr); return StmtExprMutator::VisitExpr_(op);
} }
} }
// Split the index to the two component // Split the index to the two component
...@@ -297,18 +296,18 @@ class WarpAccessRewriter : protected IRMutator { ...@@ -297,18 +296,18 @@ class WarpAccessRewriter : protected IRMutator {
// Bind bound information of variables to make analyzer more effective // Bind bound information of variables to make analyzer more effective
// TODO(tqchen): consider a pass to inline the bound info into the expr // TODO(tqchen): consider a pass to inline the bound info into the expr
// so analysis can be context independent. // so analysis can be context independent.
class BindVarBoundInfo : public IRVisitor { class BindVarBoundInfo : public StmtVisitor {
public: public:
explicit BindVarBoundInfo(arith::Analyzer* analyzer) explicit BindVarBoundInfo(arith::Analyzer* analyzer)
: analyzer_(analyzer) {} : analyzer_(analyzer) {}
void Visit_(const For* op) final { void VisitStmt_(const For* op) final {
const Var& loop_var = op->loop_var; const Var& loop_var = op->loop_var;
analyzer_->Bind(loop_var, Range::make_by_min_extent(op->min, op->extent)); analyzer_->Bind(loop_var, Range::make_by_min_extent(op->min, op->extent));
IRVisitor::Visit_(op); StmtVisitor::VisitStmt_(op);
} }
void Visit_(const AttrStmt* op) { void VisitStmt_(const AttrStmt* op) {
if (op->attr_key == attr::thread_extent || if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::virtual_thread) { op->attr_key == attr::virtual_thread) {
IterVar iv = Downcast<IterVar>(op->node); IterVar iv = Downcast<IterVar>(op->node);
...@@ -319,7 +318,7 @@ class BindVarBoundInfo : public IRVisitor { ...@@ -319,7 +318,7 @@ class BindVarBoundInfo : public IRVisitor {
analyzer_->Bind(iv->var, dom); analyzer_->Bind(iv->var, dom);
} }
} }
IRVisitor::Visit_(op); StmtVisitor::VisitStmt_(op);
} }
protected: protected:
...@@ -330,7 +329,7 @@ class BindVarBoundInfo : public IRVisitor { ...@@ -330,7 +329,7 @@ class BindVarBoundInfo : public IRVisitor {
}; };
// Mutator to change the read pattern // Mutator to change the read pattern
class WarpMemoryRewriter : private IRMutator { class WarpMemoryRewriter : private StmtMutator {
public: public:
explicit WarpMemoryRewriter(int warp_size) explicit WarpMemoryRewriter(int warp_size)
: warp_size_(warp_size) { : warp_size_(warp_size) {
...@@ -338,36 +337,37 @@ class WarpMemoryRewriter : private IRMutator { ...@@ -338,36 +337,37 @@ class WarpMemoryRewriter : private IRMutator {
Stmt Rewrite(Stmt stmt) { Stmt Rewrite(Stmt stmt) {
if (warp_size_ == 1) return stmt; if (warp_size_ == 1) return stmt;
BindVarBoundInfo(&analyzer_).Visit(stmt); BindVarBoundInfo binder(&analyzer_);
stmt = this->Mutate(stmt); binder(stmt);
stmt = operator()(std::move(stmt));
stmt = CanonicalSimplify(stmt); stmt = CanonicalSimplify(stmt);
return stmt; return stmt;
} }
private: private:
Stmt Mutate_(const Allocate* op, const Stmt& stmt) { Stmt VisitStmt_(const Allocate* op) {
if (warp_buffer_.count(op->buffer_var.get())) { if (warp_buffer_.count(op->buffer_var.get())) {
WarpAccessRewriter rewriter(warp_size_, &analyzer_); WarpAccessRewriter rewriter(warp_size_, &analyzer_);
return rewriter.Rewrite(op, stmt); return rewriter.Rewrite(op);
} else { } else {
return IRMutator::Mutate_(op, stmt); return StmtMutator::VisitStmt_(op);
} }
} }
Stmt Mutate_(const AttrStmt* op, const Stmt& stmt) { Stmt VisitStmt_(const AttrStmt* op) {
using runtime::StorageScope; using runtime::StorageScope;
if (op->attr_key == attr::storage_scope) { if (op->attr_key == attr::storage_scope) {
const Variable* buf = op->node.as<Variable>(); const Variable* buf = op->node.as<Variable>();
StorageScope scope = StorageScope::make(op->value.as<StringImm>()->value); StorageScope scope = StorageScope::make(op->value.as<StringImm>()->value);
if (scope.rank == runtime::StorageRank::kWarp) { if (scope.rank == runtime::StorageRank::kWarp) {
warp_buffer_.insert(buf); warp_buffer_.insert(buf);
Stmt ret = IRMutator::Mutate_(op, stmt); Stmt ret = StmtMutator::VisitStmt_(op);
op = ret.as<AttrStmt>(); op = ret.as<AttrStmt>();
return AttrStmt::make( return AttrStmt::make(
op->node, op->attr_key, StringImm::make("local"), op->body); op->node, op->attr_key, StringImm::make("local"), op->body);
} }
} }
return IRMutator::Mutate_(op, stmt); return StmtMutator::VisitStmt_(op);
} }
int warp_size_{0}; int warp_size_{0};
......
...@@ -22,8 +22,7 @@ ...@@ -22,8 +22,7 @@
*/ */
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_visitor.h> #include <tvm/ir_functor_ext.h>
#include <tvm/ir_mutator.h>
#include <tvm/buffer.h> #include <tvm/buffer.h>
#include <tvm/runtime/device_api.h> #include <tvm/runtime/device_api.h>
#include <vector> #include <vector>
...@@ -207,29 +206,29 @@ LoweredFunc MakeAPI(Stmt body, ...@@ -207,29 +206,29 @@ LoweredFunc MakeAPI(Stmt body,
return f; return f;
} }
class DeviceTypeBinder: public IRMutator { class DeviceTypeBinder: public StmtExprMutator {
public: public:
explicit DeviceTypeBinder(int device_type) explicit DeviceTypeBinder(int device_type)
: device_type_(device_type) {} : device_type_(device_type) {}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { Stmt VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::device_context_type) { if (op->attr_key == attr::device_context_type) {
if (const Variable* var = op->value.as<Variable>()) { if (const Variable* var = op->value.as<Variable>()) {
var_ = var; var_ = var;
Expr value = make_const(op->value.dtype(), device_type_); Expr value = make_const(op->value.dtype(), device_type_);
Stmt body = IRMutator::Mutate_(op, s); Stmt body = StmtExprMutator::VisitStmt_(op);
var_ = nullptr; var_ = nullptr;
std::ostringstream os; std::ostringstream os;
os << "device_type need to be " << device_type_; os << "device_type need to be " << device_type_;
return AssertStmt::make(op->value == value, os.str(), body); return AssertStmt::make(op->value == value, os.str(), body);
} }
} }
return IRMutator::Mutate_(op, s); return StmtExprMutator::VisitStmt_(op);
} }
Stmt Mutate_(const IfThenElse* op, const Stmt& s) final { Stmt VisitStmt_(const IfThenElse* op) final {
// eager simplify if guard. // eager simplify if guard.
Stmt res = IRMutator::Mutate_(op, s); Stmt res = StmtExprMutator::VisitStmt_(op);
op = res.as<IfThenElse>(); op = res.as<IfThenElse>();
if (is_zero(op->condition)) { if (is_zero(op->condition)) {
if (op->else_case.defined()) return op->else_case; if (op->else_case.defined()) return op->else_case;
...@@ -241,9 +240,9 @@ class DeviceTypeBinder: public IRMutator { ...@@ -241,9 +240,9 @@ class DeviceTypeBinder: public IRMutator {
return res; return res;
} }
Expr Mutate_(const NE* op, const Expr& e) final { Expr VisitExpr_(const NE* op) final {
// eager check NE for device check // eager check NE for device check
Expr res = IRMutator::Mutate_(op, e); Expr res = StmtExprMutator::VisitExpr_(op);
op = res.as<NE>(); op = res.as<NE>();
if (ir::Equal(op->a, op->b)) { if (ir::Equal(op->a, op->b)) {
return make_const(op->dtype, false); return make_const(op->dtype, false);
...@@ -251,11 +250,11 @@ class DeviceTypeBinder: public IRMutator { ...@@ -251,11 +250,11 @@ class DeviceTypeBinder: public IRMutator {
return res; return res;
} }
Expr Mutate_(const Variable* op, const Expr& e) final { Expr VisitExpr_(const Variable* op) final {
if (op == var_) { if (op == var_) {
return make_const(op->dtype, device_type_); return make_const(op->dtype, device_type_);
} else { } else {
return e; return GetRef<Expr>(op);
} }
} }
...@@ -267,7 +266,7 @@ class DeviceTypeBinder: public IRMutator { ...@@ -267,7 +266,7 @@ class DeviceTypeBinder: public IRMutator {
LoweredFunc BindDeviceType(LoweredFunc f, LoweredFunc BindDeviceType(LoweredFunc f,
int device_type) { int device_type) {
auto n = make_object<LoweredFuncNode>(*f.operator->()); auto n = make_object<LoweredFuncNode>(*f.operator->());
n->body = DeviceTypeBinder(device_type).Mutate(n->body); n->body = DeviceTypeBinder(device_type)(n->body);
return LoweredFunc(n); return LoweredFunc(n);
} }
......
...@@ -21,8 +21,7 @@ ...@@ -21,8 +21,7 @@
* \file remap_thread_axis.cc * \file remap_thread_axis.cc
*/ */
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_mutator.h> #include <tvm/ir_functor_ext.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <unordered_map> #include <unordered_map>
...@@ -31,7 +30,7 @@ namespace tvm { ...@@ -31,7 +30,7 @@ namespace tvm {
namespace ir { namespace ir {
// Mutator to change the read pattern // Mutator to change the read pattern
class ThreadAxisRewriter : private IRMutator { class ThreadAxisRewriter : private StmtExprMutator {
public: public:
explicit ThreadAxisRewriter( explicit ThreadAxisRewriter(
const std::unordered_map<std::string, IterVar>& tmap) const std::unordered_map<std::string, IterVar>& tmap)
...@@ -39,11 +38,11 @@ class ThreadAxisRewriter : private IRMutator { ...@@ -39,11 +38,11 @@ class ThreadAxisRewriter : private IRMutator {
} }
Stmt Rewrite(Stmt stmt) { Stmt Rewrite(Stmt stmt) {
return Mutate(stmt); return operator()(std::move(stmt));
} }
private: private:
Stmt Mutate_(const AttrStmt* op, const Stmt& stmt) final { Stmt VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::thread_extent) { if (op->attr_key == attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node); IterVar iv = Downcast<IterVar>(op->node);
CHECK_NE(iv->thread_tag.length(), 0U); CHECK_NE(iv->thread_tag.length(), 0U);
...@@ -56,18 +55,18 @@ class ThreadAxisRewriter : private IRMutator { ...@@ -56,18 +55,18 @@ class ThreadAxisRewriter : private IRMutator {
} else { } else {
CHECK(vmap_[v].same_as(new_iv->var)); CHECK(vmap_[v].same_as(new_iv->var));
} }
Stmt body = this->Mutate(op->body); Stmt body = this->VisitStmt(op->body);
return AttrStmt::make( return AttrStmt::make(
new_iv, op->attr_key, op->value, body); new_iv, op->attr_key, op->value, body);
} }
} }
return IRMutator::Mutate_(op, stmt); return StmtExprMutator::VisitStmt_(op);
} }
Expr Mutate_(const Variable* op, const Expr& expr) final { Expr VisitExpr_(const Variable* op) final {
auto it = vmap_.find(op); auto it = vmap_.find(op);
if (it != vmap_.end()) return it->second; if (it != vmap_.end()) return it->second;
return IRMutator::Mutate_(op, expr); return StmtExprMutator::VisitExpr_(op);
} }
// The thread map // The thread map
const std::unordered_map<std::string, IterVar>& tmap_; const std::unordered_map<std::string, IterVar>& tmap_;
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -23,30 +23,30 @@ ...@@ -23,30 +23,30 @@
*/ */
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h> #include <tvm/ir_functor_ext.h>
#include <unordered_map> #include <unordered_map>
namespace tvm { namespace tvm {
namespace ir { namespace ir {
// Mark the statment of each stage. // Mark the statment of each stage.
class NoOpRemover : public IRMutator { class NoOpRemover : public StmtMutator {
public: public:
Stmt Mutate_(const LetStmt* op, const Stmt& s) final { Stmt VisitStmt_(const LetStmt* op) final {
Stmt stmt = IRMutator::Mutate_(op, s); Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<LetStmt>(); op = stmt.as<LetStmt>();
return is_no_op(op->body) ? MakeEvaluate(op->value) : stmt; return is_no_op(op->body) ? MakeEvaluate(op->value) : stmt;
} }
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { Stmt VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == "pragma_debug_skip_region") { if (op->attr_key == "pragma_debug_skip_region") {
return MakeEvaluate(0); return MakeEvaluate(0);
} }
Stmt stmt = IRMutator::Mutate_(op, s); Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<AttrStmt>(); op = stmt.as<AttrStmt>();
return is_no_op(op->body) ? MakeEvaluate(op->value) : stmt; return is_no_op(op->body) ? MakeEvaluate(op->value) : stmt;
} }
Stmt Mutate_(const IfThenElse* op, const Stmt& s) final { Stmt VisitStmt_(const IfThenElse* op) final {
Stmt stmt = IRMutator::Mutate_(op, s); Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<IfThenElse>(); op = stmt.as<IfThenElse>();
if (op->else_case.defined()) { if (op->else_case.defined()) {
if (is_no_op(op->else_case)) { if (is_no_op(op->else_case)) {
...@@ -66,35 +66,35 @@ class NoOpRemover : public IRMutator { ...@@ -66,35 +66,35 @@ class NoOpRemover : public IRMutator {
} }
} }
} }
Stmt Mutate_(const For* op, const Stmt& s) final { Stmt VisitStmt_(const For* op) final {
Stmt stmt = IRMutator::Mutate_(op, s); Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<For>(); op = stmt.as<For>();
if (is_zero(op->extent)) { if (is_zero(op->extent)) {
return Evaluate::make(0); return Evaluate::make(0);
} }
return is_no_op(op->body) ? MakeEvaluate({op->min, op->extent}) : stmt; return is_no_op(op->body) ? MakeEvaluate({op->min, op->extent}) : stmt;
} }
Stmt Mutate_(const Allocate* op, const Stmt& s) final { Stmt VisitStmt_(const Allocate* op) final {
Stmt stmt = IRMutator::Mutate_(op, s); Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<Allocate>(); op = stmt.as<Allocate>();
return is_no_op(op->body) ? MakeEvaluate(op->extents) : stmt; return is_no_op(op->body) ? MakeEvaluate(op->extents) : stmt;
} }
Stmt Mutate_(const ProducerConsumer* op, const Stmt& s) final { Stmt VisitStmt_(const ProducerConsumer* op) final {
Stmt stmt = IRMutator::Mutate_(op, s); Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<ProducerConsumer>(); op = stmt.as<ProducerConsumer>();
return is_no_op(op->body) ? op->body : stmt; return is_no_op(op->body) ? op->body : stmt;
} }
Stmt Mutate_(const Realize* op, const Stmt& s) final { Stmt VisitStmt_(const Realize* op) final {
Stmt stmt = IRMutator::Mutate_(op, s); Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<Realize>(); op = stmt.as<Realize>();
return is_no_op(op->body) ? op->body : stmt; return is_no_op(op->body) ? op->body : stmt;
} }
Stmt Mutate_(const Evaluate* op, const Stmt& s) final { Stmt VisitStmt_(const Evaluate* op) final {
if (HasSideEffect(op->value)) return s; if (HasSideEffect(op->value)) return GetRef<Stmt>(op);
return Evaluate::make(0); return Evaluate::make(0);
} }
Stmt Mutate_(const Block* op, const Stmt& s) final { Stmt VisitStmt_(const Block* op) final {
Stmt stmt = IRMutator::Mutate_(op, s); Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<Block>(); op = stmt.as<Block>();
if (is_no_op(op->first)) { if (is_no_op(op->first)) {
return op->rest; return op->rest;
...@@ -129,7 +129,7 @@ class NoOpRemover : public IRMutator { ...@@ -129,7 +129,7 @@ class NoOpRemover : public IRMutator {
}; };
Stmt RemoveNoOp(Stmt stmt) { Stmt RemoveNoOp(Stmt stmt) {
return NoOpRemover().Mutate(stmt); return NoOpRemover()(std::move(stmt));
} }
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
...@@ -23,7 +23,6 @@ ...@@ -23,7 +23,6 @@
*/ */
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_functor_ext.h> #include <tvm/ir_functor_ext.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
namespace tvm { namespace tvm {
...@@ -109,10 +108,10 @@ class UnsafeExprDetector : public ExprFunctor<bool(const Expr& n)> { ...@@ -109,10 +108,10 @@ class UnsafeExprDetector : public ExprFunctor<bool(const Expr& n)> {
} }
}; };
class UnsafeSelectRewriter : public IRMutator { class UnsafeSelectRewriter : public StmtExprMutator {
public: public:
Expr Mutate_(const Select* op, const Expr& e) { Expr VisitExpr_(const Select* op) {
Expr expr = IRMutator::Mutate_(op, e); Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Select>(); op = expr.as<Select>();
UnsafeExprDetector unsafe; UnsafeExprDetector unsafe;
bool cond_is_scalar_bool = op->condition.dtype().is_bool() && op->condition.dtype().is_scalar(); bool cond_is_scalar_bool = op->condition.dtype().is_bool() && op->condition.dtype().is_scalar();
...@@ -131,7 +130,7 @@ class UnsafeSelectRewriter : public IRMutator { ...@@ -131,7 +130,7 @@ class UnsafeSelectRewriter : public IRMutator {
}; };
Stmt RewriteUnsafeSelect(Stmt stmt) { Stmt RewriteUnsafeSelect(Stmt stmt) {
return UnsafeSelectRewriter().Mutate(stmt); return UnsafeSelectRewriter()(std::move(stmt));
} }
} // namespace ir } // namespace ir
......
...@@ -22,25 +22,24 @@ ...@@ -22,25 +22,24 @@
* \brief Implementation of simple passes * \brief Implementation of simple passes
*/ */
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_visitor.h> #include <tvm/ir_functor_ext.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
namespace tvm { namespace tvm {
namespace ir { namespace ir {
class IRSideEffect : public IRVisitor { class IRSideEffect : public ExprVisitor {
public: public:
void Visit(const ObjectRef& e) final { void VisitExpr(const Expr& e) final {
if (has_side_effect_) return; if (has_side_effect_) return;
IRVisitor::Visit(e); ExprVisitor::VisitExpr(e);
} }
void Visit_(const Call* op) final { void VisitExpr_(const Call* op) final {
if (!op->is_pure()) { if (!op->is_pure()) {
has_side_effect_ = true; return; has_side_effect_ = true; return;
} else { } else {
IRVisitor::Visit_(op); ExprVisitor::VisitExpr_(op);
} }
} }
...@@ -49,23 +48,23 @@ class IRSideEffect : public IRVisitor { ...@@ -49,23 +48,23 @@ class IRSideEffect : public IRVisitor {
bool HasSideEffect(const Expr& e) { bool HasSideEffect(const Expr& e) {
IRSideEffect v; IRSideEffect v;
v.Visit(e); v(e);
return v.has_side_effect_; return v.has_side_effect_;
} }
class IRSubstitue : public IRMutator { class IRSubstitue : public StmtExprMutator {
public: public:
explicit IRSubstitue( explicit IRSubstitue(
const std::unordered_map<const Variable*, Expr>& smap) const std::unordered_map<const Variable*, Expr>& smap)
: smap_(smap) { : smap_(smap) {
} }
Expr Mutate_(const Variable* op, const Expr& e) final { Expr VisitExpr_(const Variable* op) final {
auto it = smap_.find(op); auto it = smap_.find(op);
if (it != smap_.end()) { if (it != smap_.end()) {
return it->second; return it->second;
} else { } else {
return e; return GetRef<Expr>(op);
} }
} }
...@@ -76,13 +75,13 @@ class IRSubstitue : public IRMutator { ...@@ -76,13 +75,13 @@ class IRSubstitue : public IRMutator {
Stmt Substitute(Stmt stmt, Stmt Substitute(Stmt stmt,
const std::unordered_map<const Variable*, Expr>& value_map) { const std::unordered_map<const Variable*, Expr>& value_map) {
if (value_map.size() == 0) return stmt; if (value_map.size() == 0) return stmt;
return IRSubstitue(value_map).Mutate(stmt); return IRSubstitue(value_map)(std::move(stmt));
} }
Expr Substitute(Expr expr, Expr Substitute(Expr expr,
const std::unordered_map<const Variable*, Expr>& value_map) { const std::unordered_map<const Variable*, Expr>& value_map) {
if (value_map.size() == 0) return expr; if (value_map.size() == 0) return expr;
return IRSubstitue(value_map).Mutate(expr); return IRSubstitue(value_map)(std::move(expr));
} }
Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map) { Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map) {
...@@ -101,20 +100,20 @@ Expr Substitute(Expr expr, const Map<Var, Expr>& value_map) { ...@@ -101,20 +100,20 @@ Expr Substitute(Expr expr, const Map<Var, Expr>& value_map) {
return Substitute(expr, vmap); return Substitute(expr, vmap);
} }
class VarTouchVisitor : public IRVisitor { class VarTouchVisitor : public ExprVisitor {
public: public:
void Visit(const ObjectRef& e) final { void VisitExpr(const Expr& e) final {
if (use_var_) return; if (use_var_) return;
IRVisitor::Visit(e); ExprVisitor::VisitExpr(e);
} }
void Visit_(const Variable* op) final { void VisitExpr_(const Variable* op) final {
Handle(op); Handle(op);
} }
void Visit_(const Load* op) final { void VisitExpr_(const Load* op) final {
Handle(op->buffer_var.get()); Handle(op->buffer_var.get());
IRVisitor::Visit_(op); ExprVisitor::VisitExpr_(op);
} }
virtual void Handle(const Variable* var) = 0; virtual void Handle(const Variable* var) = 0;
...@@ -149,14 +148,14 @@ class ExprUseVSetVisitor : public VarTouchVisitor { ...@@ -149,14 +148,14 @@ class ExprUseVSetVisitor : public VarTouchVisitor {
bool ExprUseVar(const Expr& e, const Var& v) { bool ExprUseVar(const Expr& e, const Var& v) {
ExprUseVarVisitor visitor(v.get()); ExprUseVarVisitor visitor(v.get());
visitor.Visit(e); visitor(e);
return visitor.use_var_; return visitor.use_var_;
} }
bool ExprUseVar(const Expr& e, bool ExprUseVar(const Expr& e,
const std::unordered_set<const Variable*>& vset) { const std::unordered_set<const Variable*>& vset) {
ExprUseVSetVisitor visitor(vset); ExprUseVSetVisitor visitor(vset);
visitor.Visit(e); visitor(e);
return visitor.use_var_; return visitor.use_var_;
} }
......
...@@ -19,22 +19,22 @@ ...@@ -19,22 +19,22 @@
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h> #include <tvm/ir_functor_ext.h>
namespace tvm { namespace tvm {
namespace ir { namespace ir {
class AssertSkipper : public IRMutator { class AssertSkipper : public StmtMutator {
public: public:
Stmt Mutate_(const AssertStmt* op, const Stmt& s) final { Stmt VisitStmt_(const AssertStmt* op) final {
Stmt stmt = IRMutator::Mutate_(op, s); Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<AssertStmt>(); op = stmt.as<AssertStmt>();
return op->body; return op->body;
} }
}; };
Stmt SkipAssert(Stmt stmt) { Stmt SkipAssert(Stmt stmt) {
return AssertSkipper().Mutate(stmt); return AssertSkipper()(std::move(stmt));
} }
LoweredFunc SkipAssert(LoweredFunc f) { LoweredFunc SkipAssert(LoweredFunc f) {
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/lowered_func.h> #include <tvm/lowered_func.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h> #include <tvm/ir_functor_ext.h>
#include <tvm/runtime/module.h> #include <tvm/runtime/module.h>
#include <unordered_map> #include <unordered_map>
...@@ -32,9 +32,9 @@ namespace tvm { ...@@ -32,9 +32,9 @@ namespace tvm {
namespace ir { namespace ir {
// use/def analysis, also delete unreferenced lets // use/def analysis, also delete unreferenced lets
class IRUseDefAnalysis : public IRMutator { class IRUseDefAnalysis : public StmtExprMutator {
public: public:
Stmt Mutate_(const AttrStmt *op, const Stmt& s) final { Stmt VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::thread_extent) { if (op->attr_key == attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node); IterVar iv = Downcast<IterVar>(op->node);
CHECK_NE(iv->thread_tag.length(), 0U); CHECK_NE(iv->thread_tag.length(), 0U);
...@@ -48,75 +48,77 @@ class IRUseDefAnalysis : public IRMutator { ...@@ -48,75 +48,77 @@ class IRUseDefAnalysis : public IRMutator {
Expr value = op->value; Expr value = op->value;
if (visit_thread_extent_) { if (visit_thread_extent_) {
value = this->Mutate(value); value = this->VisitExpr(value);
}
Stmt body = this->VisitStmt(op->body);
if (value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
} }
Stmt body = this->Mutate(op->body);
if (value.same_as(op->value) && body.same_as(op->body)) return s;
return AttrStmt::make(op->node, op->attr_key, value, body); return AttrStmt::make(op->node, op->attr_key, value, body);
} else { } else {
return IRMutator::Mutate_(op, s); return StmtExprMutator::VisitStmt_(op);
} }
} }
Stmt Mutate_(const LetStmt *op, const Stmt& s) final { Stmt VisitStmt_(const LetStmt* op) final {
this->HandleDef(op->var.get()); this->HandleDef(op->var.get());
Stmt body = this->Mutate(op->body); Stmt body = this->VisitStmt(op->body);
// eliminate unreferenced let // eliminate unreferenced let
if (use_count_.at(op->var.get()) == 0 && if (use_count_.at(op->var.get()) == 0 &&
!HasSideEffect(op->value)) { !HasSideEffect(op->value)) {
return body; return body;
} else { } else {
Expr value = this->Mutate(op->value); Expr value = this->VisitExpr(op->value);
if (body.same_as(op->body) && if (body.same_as(op->body) &&
value.same_as(op->value)) { value.same_as(op->value)) {
return s; return GetRef<Stmt>(op);
} else { } else {
return LetStmt::make(op->var, value, body); return LetStmt::make(op->var, value, body);
} }
} }
} }
Stmt Mutate_(const For *op, const Stmt& s) final { Stmt VisitStmt_(const For* op) final {
this->HandleDef(op->loop_var.get()); this->HandleDef(op->loop_var.get());
return IRMutator::Mutate_(op, s); return StmtExprMutator::VisitStmt_(op);
} }
Stmt Mutate_(const Allocate *op, const Stmt& s) final { Stmt VisitStmt_(const Allocate* op) final {
this->HandleDef(op->buffer_var.get()); this->HandleDef(op->buffer_var.get());
return IRMutator::Mutate_(op, s); return StmtExprMutator::VisitStmt_(op);
} }
Stmt Mutate_(const Store *op, const Stmt& s) final { Stmt VisitStmt_(const Store* op) final {
this->HandleUse(op->buffer_var); this->HandleUse(op->buffer_var);
return IRMutator::Mutate_(op, s); return StmtExprMutator::VisitStmt_(op);
} }
Expr Mutate_(const Let *op, const Expr& e) final { Expr VisitExpr_(const Let* op) final {
this->HandleDef(op->var.get()); this->HandleDef(op->var.get());
Expr body = this->Mutate(op->body); Expr body = this->VisitExpr(op->body);
// eliminate unreferenced let // eliminate unreferenced let
if (use_count_.at(op->var.get()) == 0 && if (use_count_.at(op->var.get()) == 0 &&
!HasSideEffect(op->value)) { !HasSideEffect(op->value)) {
return body; return body;
} else { } else {
Expr value = this->Mutate(op->value); Expr value = this->VisitExpr(op->value);
if (body.same_as(op->body) && if (body.same_as(op->body) &&
value.same_as(op->value)) { value.same_as(op->value)) {
return e; return GetRef<Expr>(op);
} else { } else {
return Let::make(op->var, value, body); return Let::make(op->var, value, body);
} }
} }
} }
Expr Mutate_(const Variable *op, const Expr& e) final { Expr VisitExpr_(const Variable* op) final {
this->HandleUse(e); this->HandleUse(GetRef<Expr>(op));
return IRMutator::Mutate_(op, e); return StmtExprMutator::VisitExpr_(op);
} }
Expr Mutate_(const Load *op, const Expr& e) final { Expr VisitExpr_(const Load* op) final {
this->HandleUse(op->buffer_var); this->HandleUse(op->buffer_var);
return IRMutator::Mutate_(op, e); return StmtExprMutator::VisitExpr_(op);
} }
void HandleDef(const Variable* v) { void HandleDef(const Variable* v) {
...@@ -154,20 +156,20 @@ class IRUseDefAnalysis : public IRMutator { ...@@ -154,20 +156,20 @@ class IRUseDefAnalysis : public IRMutator {
std::unordered_map<const Variable*, int> def_count_; std::unordered_map<const Variable*, int> def_count_;
}; };
class HostDeviceSplitter : public IRMutator { class HostDeviceSplitter : public StmtMutator {
public: public:
Stmt Mutate_(const Allocate* op, const Stmt& s) final { Stmt VisitStmt_(const Allocate* op) final {
handle_data_type_[op->buffer_var.get()] = make_const(op->dtype, 0); handle_data_type_[op->buffer_var.get()] = make_const(op->dtype, 0);
return IRMutator::Mutate_(op, s); return StmtMutator::VisitStmt_(op);
} }
Stmt Mutate_(const AttrStmt *op, const Stmt& s) final { Stmt VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::thread_extent || if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::pipeline_exec_scope || op->attr_key == attr::pipeline_exec_scope ||
op->attr_key == attr::device_scope) { op->attr_key == attr::device_scope) {
return SplitDeviceFunc(s); return SplitDeviceFunc(GetRef<Stmt>(op));
} }
return IRMutator::Mutate_(op, s); return StmtMutator::VisitStmt_(op);
} }
Array<LoweredFunc> Split(LoweredFunc f) { Array<LoweredFunc> Split(LoweredFunc f) {
...@@ -178,7 +180,7 @@ class HostDeviceSplitter : public IRMutator { ...@@ -178,7 +180,7 @@ class HostDeviceSplitter : public IRMutator {
name_ = f->name; name_ = f->name;
ObjectPtr<LoweredFuncNode> n = ObjectPtr<LoweredFuncNode> n =
make_object<LoweredFuncNode>(*f.operator->()); make_object<LoweredFuncNode>(*f.operator->());
n->body = this->Mutate(f->body); n->body = operator()(f->body);
n->func_type = kHostFunc; n->func_type = kHostFunc;
Array<LoweredFunc> ret{LoweredFunc(n)}; Array<LoweredFunc> ret{LoweredFunc(n)};
for (LoweredFunc x : device_funcs_) { for (LoweredFunc x : device_funcs_) {
...@@ -195,7 +197,7 @@ class HostDeviceSplitter : public IRMutator { ...@@ -195,7 +197,7 @@ class HostDeviceSplitter : public IRMutator {
// isolate the device function. // isolate the device function.
IRUseDefAnalysis m; IRUseDefAnalysis m;
m.visit_thread_extent_ = false; m.visit_thread_extent_ = false;
n->body = m.Mutate(body); n->body = m(std::move(body));
n->name = os.str(); n->name = os.str();
n->func_type = kDeviceFunc; n->func_type = kDeviceFunc;
n->thread_axis = m.thread_axis_; n->thread_axis = m.thread_axis_;
...@@ -243,7 +245,7 @@ Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& args) { ...@@ -243,7 +245,7 @@ Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& args) {
for (Var arg : args) { for (Var arg : args) {
m.use_count_[arg.get()] = 0; m.use_count_[arg.get()] = 0;
} }
m.Mutate(stmt); m(stmt);
return m.undefined_; return m.undefined_;
} }
......
...@@ -24,8 +24,7 @@ ...@@ -24,8 +24,7 @@
* \file ssa.cc * \file ssa.cc
*/ */
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_visitor.h> #include <tvm/ir_functor_ext.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <unordered_set> #include <unordered_set>
#include <unordered_map> #include <unordered_map>
...@@ -34,29 +33,33 @@ ...@@ -34,29 +33,33 @@
namespace tvm { namespace tvm {
namespace ir { namespace ir {
namespace { namespace {
class IRVerifySSA final : public IRVisitor { class IRVerifySSA final : public StmtExprVisitor {
public: public:
bool is_ssa{true}; bool is_ssa{true};
void Visit(const ObjectRef& n) final { void VisitExpr(const Expr& n) final {
if (!is_ssa) return; if (!is_ssa) return;
IRVisitor::Visit(n); StmtExprVisitor::VisitExpr(n);
} }
void Visit_(const Let* op) final { void VisitStmt(const Stmt& n) final {
if (!is_ssa) return;
StmtExprVisitor::VisitStmt(n);
}
void VisitExpr_(const Let* op) final {
MarkDef(op->var.get()); MarkDef(op->var.get());
IRVisitor::Visit_(op); StmtExprVisitor::VisitExpr_(op);
} }
void Visit_(const LetStmt* op) final { void VisitStmt_(const LetStmt* op) final {
MarkDef(op->var.get()); MarkDef(op->var.get());
IRVisitor::Visit_(op); StmtExprVisitor::VisitStmt_(op);
} }
void Visit_(const For* op) final { void VisitStmt_(const For* op) final {
MarkDef(op->loop_var.get()); MarkDef(op->loop_var.get());
IRVisitor::Visit_(op); StmtExprVisitor::VisitStmt_(op);
} }
void Visit_(const Allocate* op) final { void VisitStmt_(const Allocate* op) final {
MarkDef(op->buffer_var.get()); MarkDef(op->buffer_var.get());
IRVisitor::Visit_(op); StmtExprVisitor::VisitStmt_(op);
} }
private: private:
...@@ -70,31 +73,32 @@ class IRVerifySSA final : public IRVisitor { ...@@ -70,31 +73,32 @@ class IRVerifySSA final : public IRVisitor {
std::unordered_map<const Variable*, int> defined_; std::unordered_map<const Variable*, int> defined_;
}; };
class IRConvertSSA final : public IRMutator {
class IRConvertSSA final : public StmtExprMutator {
public: public:
Expr Mutate_(const Variable* op, const Expr& e) final { Expr VisitExpr_(const Variable* op) final {
if (scope_.count(op)) { if (scope_.count(op)) {
return scope_[op].back(); return scope_[op].back();
} else { } else {
return e; return GetRef<Expr>(op);
} }
} }
Expr Mutate_(const Let* op, const Expr& e) final { Expr VisitExpr_(const Let* op) final {
const VarExpr& v = op->var; const VarExpr& v = op->var;
if (defined_.count(v.get())) { if (defined_.count(v.get())) {
Expr value = IRMutator::Mutate(op->value); Expr value = this->VisitExpr(op->value);
VarExpr new_var = Variable::make(v.dtype(), v->name_hint); VarExpr new_var = Variable::make(v.dtype(), v->name_hint);
scope_[v.get()].push_back(new_var); scope_[v.get()].push_back(new_var);
Expr body = IRMutator::Mutate(op->body); Expr body = this->VisitExpr(op->body);
scope_[v.get()].pop_back(); scope_[v.get()].pop_back();
return Let::make(new_var, value, body); return Let::make(new_var, value, body);
} else { } else {
defined_.insert(v.get()); defined_.insert(v.get());
return IRMutator::Mutate_(op, e); return StmtExprMutator::VisitExpr_(op);
} }
} }
Expr Mutate_(const Load* op, const Expr& e) final { Expr VisitExpr_(const Load* op) final {
Expr expr = IRMutator::Mutate_(op, e); Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Load>(); op = expr.as<Load>();
if (scope_.count(op->buffer_var.get())) { if (scope_.count(op->buffer_var.get())) {
return Load::make( return Load::make(
...@@ -104,8 +108,8 @@ class IRConvertSSA final : public IRMutator { ...@@ -104,8 +108,8 @@ class IRConvertSSA final : public IRMutator {
return expr; return expr;
} }
} }
Stmt Mutate_(const Store* op, const Stmt& s) final { Stmt VisitStmt_(const Store* op) final {
Stmt stmt = IRMutator::Mutate_(op, s); Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<Store>(); op = stmt.as<Store>();
if (scope_.count(op->buffer_var.get())) { if (scope_.count(op->buffer_var.get())) {
return Store::make( return Store::make(
...@@ -115,41 +119,41 @@ class IRConvertSSA final : public IRMutator { ...@@ -115,41 +119,41 @@ class IRConvertSSA final : public IRMutator {
return stmt; return stmt;
} }
} }
Stmt Mutate_(const LetStmt* op, const Stmt& s) final { Stmt VisitStmt_(const LetStmt* op) final {
const VarExpr& v = op->var; const VarExpr& v = op->var;
if (defined_.count(v.get())) { if (defined_.count(v.get())) {
Expr value = IRMutator::Mutate(op->value); Expr value = this->VisitExpr(op->value);
VarExpr new_var = Variable::make(v.dtype(), v->name_hint); VarExpr new_var = Variable::make(v.dtype(), v->name_hint);
scope_[v.get()].push_back(new_var); scope_[v.get()].push_back(new_var);
Stmt body = IRMutator::Mutate(op->body); Stmt body = this->VisitStmt(op->body);
scope_[v.get()].pop_back(); scope_[v.get()].pop_back();
return LetStmt::make(new_var, value, body); return LetStmt::make(new_var, value, body);
} else { } else {
defined_.insert(v.get()); defined_.insert(v.get());
return IRMutator::Mutate_(op, s); return StmtExprMutator::VisitStmt_(op);
} }
} }
Stmt Mutate_(const For* op, const Stmt& s) final { Stmt VisitStmt_(const For* op) final {
const VarExpr& v = op->loop_var; const VarExpr& v = op->loop_var;
if (defined_.count(v.get())) { if (defined_.count(v.get())) {
VarExpr new_var = Variable::make(v.dtype(), v->name_hint); VarExpr new_var = Variable::make(v.dtype(), v->name_hint);
scope_[v.get()].push_back(new_var); scope_[v.get()].push_back(new_var);
Stmt stmt = IRMutator::Mutate_(op, s); Stmt stmt = StmtExprMutator::VisitStmt_(op);
scope_[v.get()].pop_back(); scope_[v.get()].pop_back();
op = stmt.as<For>(); op = stmt.as<For>();
return For::make( return For::make(
new_var, op->min, op->extent, op->for_type, op->device_api, op->body); new_var, op->min, op->extent, op->for_type, op->device_api, op->body);
} else { } else {
defined_.insert(v.get()); defined_.insert(v.get());
return IRMutator::Mutate_(op, s); return StmtExprMutator::VisitStmt_(op);
} }
} }
Stmt Mutate_(const Allocate* op, const Stmt& s) final { Stmt VisitStmt_(const Allocate* op) final {
const VarExpr& v = op->buffer_var; const VarExpr& v = op->buffer_var;
if (defined_.count(v.get())) { if (defined_.count(v.get())) {
VarExpr new_var = Variable::make(v.dtype(), v->name_hint); VarExpr new_var = Variable::make(v.dtype(), v->name_hint);
scope_[v.get()].push_back(new_var); scope_[v.get()].push_back(new_var);
Stmt stmt = IRMutator::Mutate_(op, s); Stmt stmt = StmtExprMutator::VisitStmt_(op);
scope_[v.get()].pop_back(); scope_[v.get()].pop_back();
op = stmt.as<Allocate>(); op = stmt.as<Allocate>();
return Allocate::make( return Allocate::make(
...@@ -157,23 +161,23 @@ class IRConvertSSA final : public IRMutator { ...@@ -157,23 +161,23 @@ class IRConvertSSA final : public IRMutator {
op->body, op->new_expr, op->free_function); op->body, op->new_expr, op->free_function);
} else { } else {
defined_.insert(v.get()); defined_.insert(v.get());
return IRMutator::Mutate_(op, s); return StmtExprMutator::VisitStmt_(op);
} }
} }
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { Stmt VisitStmt_(const AttrStmt* op) final {
if (const Variable* v = op->node.as<Variable>()) { if (const Variable* v = op->node.as<Variable>()) {
if (op->attr_key == attr::storage_scope) { if (op->attr_key == attr::storage_scope) {
const Allocate* alloc = op->body.as<Allocate>(); const Allocate* alloc = op->body.as<Allocate>();
if (alloc && op->node.same_as(alloc->buffer_var)) { if (alloc && op->node.same_as(alloc->buffer_var)) {
Stmt new_alloc = Mutate(op->body); Stmt new_alloc = this->VisitStmt(op->body);
if (new_alloc.same_as(op->body)) return s; if (new_alloc.same_as(op->body)) return GetRef<Stmt>(op);
alloc = new_alloc.as<Allocate>(); alloc = new_alloc.as<Allocate>();
CHECK(alloc); CHECK(alloc);
return AttrStmt::make( return AttrStmt::make(
alloc->buffer_var, op->attr_key, op->value, new_alloc); alloc->buffer_var, op->attr_key, op->value, new_alloc);
} }
} }
Stmt stmt = IRMutator::Mutate_(op, s); Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<AttrStmt>(); op = stmt.as<AttrStmt>();
if (scope_.count(v) && scope_[v].size() != 0) { if (scope_.count(v) && scope_[v].size() != 0) {
return AttrStmt::make( return AttrStmt::make(
...@@ -182,7 +186,7 @@ class IRConvertSSA final : public IRMutator { ...@@ -182,7 +186,7 @@ class IRConvertSSA final : public IRMutator {
return stmt; return stmt;
} }
} else { } else {
return IRMutator::Mutate_(op, s); return StmtExprMutator::VisitStmt_(op);
} }
} }
...@@ -194,13 +198,13 @@ class IRConvertSSA final : public IRMutator { ...@@ -194,13 +198,13 @@ class IRConvertSSA final : public IRMutator {
} // namespace } // namespace
bool VerifySSA(const Stmt& ir) { bool VerifySSA(const Stmt& ir) {
IRVerifySSA v; IRVerifySSA visitor;
v.Visit(ir); visitor(ir);
return v.is_ssa; return visitor.is_ssa;
} }
Stmt ConvertSSA(Stmt stmt) { Stmt ConvertSSA(Stmt stmt) {
return IRConvertSSA().Mutate(stmt); return IRConvertSSA()(std::move(stmt));
} }
} // namespace ir } // namespace ir
......
...@@ -21,7 +21,6 @@ ...@@ -21,7 +21,6 @@
* \file storage_access.cc * \file storage_access.cc
*/ */
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <tvm/target_info.h> #include <tvm/target_info.h>
#include <string> #include <string>
#include <utility> #include <utility>
...@@ -32,7 +31,7 @@ ...@@ -32,7 +31,7 @@
namespace tvm { namespace tvm {
namespace ir { namespace ir {
void StorageAccessVisitor::Visit_(const Load* op) { void StorageAccessVisitor::VisitExpr_(const Load* op) {
const Variable* buf = op->buffer_var.as<Variable>(); const Variable* buf = op->buffer_var.as<Variable>();
StorageScope scope = GetScope(buf); StorageScope scope = GetScope(buf);
if (Enabled(buf, scope)) { if (Enabled(buf, scope)) {
...@@ -47,10 +46,10 @@ void StorageAccessVisitor::Visit_(const Load* op) { ...@@ -47,10 +46,10 @@ void StorageAccessVisitor::Visit_(const Load* op) {
curr_stmt_.access.emplace_back(std::move(e)); curr_stmt_.access.emplace_back(std::move(e));
} }
// traverse child // traverse child
IRVisitor::Visit_(op); StmtExprVisitor::VisitExpr_(op);
} }
void StorageAccessVisitor::Visit_(const Store* op) { void StorageAccessVisitor::VisitStmt_(const Store* op) {
allow_append_ = true; allow_append_ = true;
CHECK_EQ(curr_stmt_.access.size(), 0U); CHECK_EQ(curr_stmt_.access.size(), 0U);
curr_stmt_.stmt = op; curr_stmt_.stmt = op;
...@@ -67,7 +66,7 @@ void StorageAccessVisitor::Visit_(const Store* op) { ...@@ -67,7 +66,7 @@ void StorageAccessVisitor::Visit_(const Store* op) {
curr_stmt_.access.emplace_back(std::move(e)); curr_stmt_.access.emplace_back(std::move(e));
} }
// traverse child // traverse child
IRVisitor::Visit_(op); StmtExprVisitor::VisitStmt_(op);
// push to the scope // push to the scope
scope_.back().push_back(curr_stmt_); scope_.back().push_back(curr_stmt_);
// clear access entry. // clear access entry.
...@@ -75,11 +74,11 @@ void StorageAccessVisitor::Visit_(const Store* op) { ...@@ -75,11 +74,11 @@ void StorageAccessVisitor::Visit_(const Store* op) {
allow_append_ = false; allow_append_ = false;
} }
void StorageAccessVisitor::Visit_(const Evaluate* op) { void StorageAccessVisitor::VisitStmt_(const Evaluate* op) {
allow_append_ = true; allow_append_ = true;
CHECK_EQ(curr_stmt_.access.size(), 0U); CHECK_EQ(curr_stmt_.access.size(), 0U);
curr_stmt_.stmt = op; curr_stmt_.stmt = op;
IRVisitor::Visit_(op); StmtExprVisitor::VisitStmt_(op);
// push to the scope // push to the scope
if (curr_stmt_.access.size() != 0) { if (curr_stmt_.access.size() != 0) {
scope_.back().push_back(curr_stmt_); scope_.back().push_back(curr_stmt_);
...@@ -88,17 +87,17 @@ void StorageAccessVisitor::Visit_(const Evaluate* op) { ...@@ -88,17 +87,17 @@ void StorageAccessVisitor::Visit_(const Evaluate* op) {
allow_append_ = false; allow_append_ = false;
} }
void StorageAccessVisitor::Visit_(const AttrStmt* op) { void StorageAccessVisitor::VisitStmt_(const AttrStmt* op) {
if (op->attr_key == attr::storage_scope) { if (op->attr_key == attr::storage_scope) {
const Variable* buf = op->node.as<Variable>(); const Variable* buf = op->node.as<Variable>();
storage_scope_[buf] = storage_scope_[buf] =
StorageScope::make(op->value.as<StringImm>()->value); StorageScope::make(op->value.as<StringImm>()->value);
IRVisitor::Visit_(op); StmtExprVisitor::VisitStmt_(op);
} else if (op->attr_key == attr::double_buffer_write) { } else if (op->attr_key == attr::double_buffer_write) {
CHECK(double_buffer_write_ == nullptr); CHECK(double_buffer_write_ == nullptr);
double_buffer_write_ = op->node.as<Variable>(); double_buffer_write_ = op->node.as<Variable>();
scope_.push_back(std::vector<StmtEntry>()); scope_.push_back(std::vector<StmtEntry>());
IRVisitor::Visit_(op); StmtExprVisitor::VisitStmt_(op);
StmtEntry s; StmtEntry s;
s.stmt = op; s.stmt = op;
s.access = Summarize(std::move(scope_.back()), nullptr); s.access = Summarize(std::move(scope_.back()), nullptr);
...@@ -115,7 +114,7 @@ void StorageAccessVisitor::Visit_(const AttrStmt* op) { ...@@ -115,7 +114,7 @@ void StorageAccessVisitor::Visit_(const AttrStmt* op) {
} else if (op->attr_key == attr::coproc_scope) { } else if (op->attr_key == attr::coproc_scope) {
IterVar iv = Downcast<IterVar>(op->node); IterVar iv = Downcast<IterVar>(op->node);
env_threads_.push_back(iv); env_threads_.push_back(iv);
IRVisitor::Visit_(op); StmtExprVisitor::VisitStmt_(op);
env_threads_.CopyOnWrite()->data.pop_back(); env_threads_.CopyOnWrite()->data.pop_back();
} else if (op->attr_key == attr::thread_extent) { } else if (op->attr_key == attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node); IterVar iv = Downcast<IterVar>(op->node);
...@@ -123,23 +122,23 @@ void StorageAccessVisitor::Visit_(const AttrStmt* op) { ...@@ -123,23 +122,23 @@ void StorageAccessVisitor::Visit_(const AttrStmt* op) {
if (!in_device_env_) { if (!in_device_env_) {
in_device_env_ = true; in_device_env_ = true;
scope_.push_back(std::vector<StmtEntry>()); scope_.push_back(std::vector<StmtEntry>());
IRVisitor::Visit_(op); StmtExprVisitor::VisitStmt_(op);
// no need to take the result as the thread barrier automatically syncs. // no need to take the result as the thread barrier automatically syncs.
Summarize(std::move(scope_.back()), nullptr); Summarize(std::move(scope_.back()), nullptr);
in_device_env_ = false; in_device_env_ = false;
scope_.pop_back(); scope_.pop_back();
} else { } else {
IRVisitor::Visit_(op); StmtExprVisitor::VisitStmt_(op);
} }
env_threads_.CopyOnWrite()->data.pop_back(); env_threads_.CopyOnWrite()->data.pop_back();
} else { } else {
IRVisitor::Visit_(op); StmtExprVisitor::VisitStmt_(op);
} }
} }
void StorageAccessVisitor::Visit_(const For* op) { void StorageAccessVisitor::VisitStmt_(const For* op) {
scope_.push_back(std::vector<StmtEntry>()); scope_.push_back(std::vector<StmtEntry>());
IRVisitor::Visit_(op); StmtExprVisitor::VisitStmt_(op);
StmtEntry s; StmtEntry s;
s.stmt = op; s.stmt = op;
s.access = Summarize(std::move(scope_.back()), op); s.access = Summarize(std::move(scope_.back()), op);
...@@ -161,11 +160,11 @@ void StorageAccessVisitor::Visit_(const For* op) { ...@@ -161,11 +160,11 @@ void StorageAccessVisitor::Visit_(const For* op) {
} }
} }
void StorageAccessVisitor::Visit_(const IfThenElse* op) { void StorageAccessVisitor::VisitStmt_(const IfThenElse* op) {
++condition_counter_; ++condition_counter_;
this->Visit(op->condition); this->VisitExpr(op->condition);
scope_.push_back(std::vector<StmtEntry>()); scope_.push_back(std::vector<StmtEntry>());
this->Visit(op->then_case); this->VisitStmt(op->then_case);
StmtEntry s; StmtEntry s;
s.stmt = op; s.stmt = op;
s.access = Summarize(std::move(scope_.back()), nullptr); s.access = Summarize(std::move(scope_.back()), nullptr);
...@@ -180,10 +179,10 @@ void StorageAccessVisitor::Visit_(const IfThenElse* op) { ...@@ -180,10 +179,10 @@ void StorageAccessVisitor::Visit_(const IfThenElse* op) {
--condition_counter_; --condition_counter_;
} }
void StorageAccessVisitor::Visit_(const Call* op) { void StorageAccessVisitor::VisitExpr_(const Call* op) {
if (op->is_intrinsic(intrinsic::tvm_address_of)) { if (op->is_intrinsic(intrinsic::tvm_address_of)) {
const Load *l = op->args[0].as<Load>(); const Load *l = op->args[0].as<Load>();
IRVisitor::Visit_(l); StmtExprVisitor::VisitExpr_(l);
} else if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { } else if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
CHECK_EQ(op->args.size(), 5U); CHECK_EQ(op->args.size(), 5U);
DataType dtype = op->args[0].dtype(); DataType dtype = op->args[0].dtype();
...@@ -211,7 +210,7 @@ void StorageAccessVisitor::Visit_(const Call* op) { ...@@ -211,7 +210,7 @@ void StorageAccessVisitor::Visit_(const Call* op) {
curr_stmt_.access.emplace_back(e); curr_stmt_.access.emplace_back(e);
} }
} }
IRVisitor::Visit_(op); StmtExprVisitor::VisitExpr_(op);
} else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) { } else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) {
CHECK(allow_append_); CHECK(allow_append_);
const std::string& s = op->args[0].as<StringImm>()->value; const std::string& s = op->args[0].as<StringImm>()->value;
...@@ -224,7 +223,7 @@ void StorageAccessVisitor::Visit_(const Call* op) { ...@@ -224,7 +223,7 @@ void StorageAccessVisitor::Visit_(const Call* op) {
curr_stmt_.access.emplace_back(std::move(e)); curr_stmt_.access.emplace_back(std::move(e));
} }
} else { } else {
IRVisitor::Visit_(op); StmtExprVisitor::VisitExpr_(op);
} }
} }
...@@ -236,11 +235,12 @@ StorageScope StorageAccessVisitor::GetScope(const Variable* buf) const { ...@@ -236,11 +235,12 @@ StorageScope StorageAccessVisitor::GetScope(const Variable* buf) const {
return it->second; return it->second;
} }
class StorageAccessInfoLower : public IRMutator {
class StorageAccessInfoLower : public StmtExprMutator {
public: public:
Stmt Mutate_(const Allocate* op, const Stmt& s) final { Stmt VisitStmt_(const Allocate* op) final {
// Lower allocate to device allocate when needed. // Lower allocate to device allocate when needed.
Stmt stmt = IRMutator::Mutate_(op, s); Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<Allocate>(); op = stmt.as<Allocate>();
// For special memory, remove allocate, or use head expr // For special memory, remove allocate, or use head expr
auto it = storage_info_.find(op->buffer_var.get()); auto it = storage_info_.find(op->buffer_var.get());
...@@ -259,7 +259,7 @@ class StorageAccessInfoLower : public IRMutator { ...@@ -259,7 +259,7 @@ class StorageAccessInfoLower : public IRMutator {
return stmt; return stmt;
} }
} }
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { Stmt VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::storage_scope) { if (op->attr_key == attr::storage_scope) {
const Variable* buf = op->node.as<Variable>(); const Variable* buf = op->node.as<Variable>();
StorageScope scope = StorageScope::make(op->value.as<StringImm>()->value); StorageScope scope = StorageScope::make(op->value.as<StringImm>()->value);
...@@ -270,26 +270,26 @@ class StorageAccessInfoLower : public IRMutator { ...@@ -270,26 +270,26 @@ class StorageAccessInfoLower : public IRMutator {
CHECK(e.info.defined()) << "Cannot find memory info of " << scope.to_string(); CHECK(e.info.defined()) << "Cannot find memory info of " << scope.to_string();
} }
storage_info_[buf] = e; storage_info_[buf] = e;
return IRMutator::Mutate_(op, s); return StmtExprMutator::VisitStmt_(op);
} else { } else {
return IRMutator::Mutate_(op, s); return StmtExprMutator::VisitStmt_(op);
} }
} }
Expr Mutate_(const Call* op, const Expr &e) final { Expr VisitExpr_(const Call* op) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
return MakeAccessPtr(op, e); return MakeAccessPtr(op);
} else { } else {
return IRMutator::Mutate_(op, e); return StmtExprMutator::VisitExpr_(op);
} }
} }
private: private:
// tvm_access_ptr // tvm_access_ptr
Expr MakeAccessPtr(const Call* op, const Expr& e) { Expr MakeAccessPtr(const Call* op) {
// Specially handle the buffer packed intrinsic // Specially handle the buffer packed intrinsic
Expr expr = IRMutator::Mutate_(op, e); Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Call>(); op = expr.as<Call>();
CHECK_EQ(op->args.size(), 5U); CHECK_EQ(op->args.size(), 5U);
DataType dtype = op->args[0].dtype(); DataType dtype = op->args[0].dtype();
...@@ -337,7 +337,7 @@ class StorageAccessInfoLower : public IRMutator { ...@@ -337,7 +337,7 @@ class StorageAccessInfoLower : public IRMutator {
}; };
Stmt LowerStorageAccessInfo(Stmt stmt) { Stmt LowerStorageAccessInfo(Stmt stmt) {
return StorageAccessInfoLower().Mutate(stmt); return StorageAccessInfoLower()(std::move(stmt));
} }
LoweredFunc LowerDeviceStorageAccessInfo(LoweredFunc f) { LoweredFunc LowerDeviceStorageAccessInfo(LoweredFunc f) {
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/attrs.h> #include <tvm/attrs.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h> #include <tvm/ir_functor_ext.h>
#include <vector> #include <vector>
#include <unordered_map> #include <unordered_map>
#include "../runtime/thread_storage_scope.h" #include "../runtime/thread_storage_scope.h"
...@@ -40,7 +40,7 @@ using runtime::StorageRank; ...@@ -40,7 +40,7 @@ using runtime::StorageRank;
/*! /*!
* \brief Base class of storage access analysis * \brief Base class of storage access analysis
*/ */
class StorageAccessVisitor : public IRVisitor { class StorageAccessVisitor : public StmtExprVisitor {
public: public:
/*! \brief Storage access type */ /*! \brief Storage access type */
enum AccessType { enum AccessType {
...@@ -76,13 +76,13 @@ class StorageAccessVisitor : public IRVisitor { ...@@ -76,13 +76,13 @@ class StorageAccessVisitor : public IRVisitor {
std::vector<AccessEntry> access; std::vector<AccessEntry> access;
}; };
// override visitor pattern // override visitor pattern
void Visit_(const Load* op) final; void VisitExpr_(const Load* op) final;
void Visit_(const Store* op) final; void VisitStmt_(const Store* op) final;
void Visit_(const Evaluate* op) final; void VisitStmt_(const Evaluate* op) final;
void Visit_(const AttrStmt* op) final; void VisitStmt_(const AttrStmt* op) final;
void Visit_(const For* op) final; void VisitStmt_(const For* op) final;
void Visit_(const IfThenElse* op) final; void VisitStmt_(const IfThenElse* op) final;
void Visit_(const Call* op) final; void VisitExpr_(const Call* op) final;
protected: protected:
StorageAccessVisitor() { StorageAccessVisitor() {
......
...@@ -26,8 +26,7 @@ ...@@ -26,8 +26,7 @@
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/operation.h> #include <tvm/operation.h>
#include <tvm/ir_mutator.h> #include <tvm/ir_functor_ext.h>
#include <tvm/ir_visitor.h>
#include <tvm/expr_operator.h> #include <tvm/expr_operator.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/buffer.h> #include <tvm/buffer.h>
...@@ -48,7 +47,7 @@ using runtime::StorageScope; ...@@ -48,7 +47,7 @@ using runtime::StorageScope;
using runtime::ThreadScope; using runtime::ThreadScope;
using intrinsic::tvm_address_of; using intrinsic::tvm_address_of;
class StorageFlattener : public IRMutator { class StorageFlattener : public StmtExprMutator {
public: public:
explicit StorageFlattener(Map<Tensor, Buffer> extern_buffer, explicit StorageFlattener(Map<Tensor, Buffer> extern_buffer,
int cache_line_size, bool create_bound_attributes, int cache_line_size, bool create_bound_attributes,
...@@ -64,8 +63,8 @@ class StorageFlattener : public IRMutator { ...@@ -64,8 +63,8 @@ class StorageFlattener : public IRMutator {
cache_line_size_ = cache_line_size; cache_line_size_ = cache_line_size;
} }
Stmt Mutate_(const Store* op, const Stmt& s) final { Stmt VisitStmt_(const Store* op) final {
Stmt stmt = IRMutator::Mutate_(op, s); Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<Store>(); op = stmt.as<Store>();
auto it = var_remap_.find(op->buffer_var.get()); auto it = var_remap_.find(op->buffer_var.get());
if (it != var_remap_.end() && if (it != var_remap_.end() &&
...@@ -78,14 +77,14 @@ class StorageFlattener : public IRMutator { ...@@ -78,14 +77,14 @@ class StorageFlattener : public IRMutator {
} }
} }
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { Stmt VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::realize_scope) { if (op->attr_key == attr::realize_scope) {
storage_scope_[op->node.get()] = op->value.as<StringImm>()->value; storage_scope_[op->node.get()] = op->value.as<StringImm>()->value;
return this->Mutate(op->body); return this->VisitStmt(op->body);
} else if (op->attr_key == attr::double_buffer_scope && } else if (op->attr_key == attr::double_buffer_scope &&
op->node->IsInstance<OperationNode>()) { op->node->IsInstance<OperationNode>()) {
Operation func = Downcast<Operation>(op->node); Operation func = Downcast<Operation>(op->node);
Stmt body = Mutate(op->body); Stmt body = this->VisitStmt(op->body);
for (int i = 0; i < func->num_outputs(); ++i) { for (int i = 0; i < func->num_outputs(); ++i) {
TensorKey key{func, i}; TensorKey key{func, i};
auto it = buf_map_.find(key); auto it = buf_map_.find(key);
...@@ -99,7 +98,7 @@ class StorageFlattener : public IRMutator { ...@@ -99,7 +98,7 @@ class StorageFlattener : public IRMutator {
IterVar iv = Downcast<IterVar>(op->node); IterVar iv = Downcast<IterVar>(op->node);
ThreadScope ts = ThreadScope::make(iv->thread_tag); ThreadScope ts = ThreadScope::make(iv->thread_tag);
curr_thread_scope_.push_back(ts); curr_thread_scope_.push_back(ts);
Stmt stmt = IRMutator::Mutate_(op, s); Stmt stmt = StmtExprMutator::VisitStmt_(op);
curr_thread_scope_.pop_back(); curr_thread_scope_.pop_back();
return stmt; return stmt;
} else if (op->attr_key == attr::buffer_bind_scope) { } else if (op->attr_key == attr::buffer_bind_scope) {
...@@ -116,17 +115,17 @@ class StorageFlattener : public IRMutator { ...@@ -116,17 +115,17 @@ class StorageFlattener : public IRMutator {
} }
vinfo[dim].align_factor = tuple->args[1].as<IntImm>()->value; vinfo[dim].align_factor = tuple->args[1].as<IntImm>()->value;
vinfo[dim].align_offset = tuple->args[2].as<IntImm>()->value; vinfo[dim].align_offset = tuple->args[2].as<IntImm>()->value;
return this->Mutate(op->body); return this->VisitStmt(op->body);
} else if (op->attr_key == attr::opengl_stage_scope) { } else if (op->attr_key == attr::opengl_stage_scope) {
is_opengl_ = true; is_opengl_ = true;
} }
return IRMutator::Mutate_(op, s); return StmtExprMutator::VisitStmt_(op);
} }
Stmt Mutate_(const Provide* op, const Stmt& s) final { Stmt VisitStmt_(const Provide* op) final {
if (create_bound_attributes_) if (create_bound_attributes_)
shape_collector_.clear(); shape_collector_.clear();
Stmt stmt = IRMutator::Mutate_(op, s); Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<Provide>(); op = stmt.as<Provide>();
TensorKey key{op->func, op->value_index}; TensorKey key{op->func, op->value_index};
auto it = buf_map_.find(key); auto it = buf_map_.find(key);
...@@ -159,11 +158,11 @@ class StorageFlattener : public IRMutator { ...@@ -159,11 +158,11 @@ class StorageFlattener : public IRMutator {
} }
} }
Stmt Mutate_(const Realize* op, const Stmt& s) final { Stmt VisitStmt_(const Realize* op) final {
TensorKey key{op->func, op->value_index}; TensorKey key{op->func, op->value_index};
if (buf_map_.count(key)) { if (buf_map_.count(key)) {
CHECK(buf_map_.at(key).external); CHECK(buf_map_.at(key).external);
return this->Mutate(op->body); return this->VisitStmt(op->body);
} else { } else {
// create a buffer entry // create a buffer entry
BufferEntry e; BufferEntry e;
...@@ -226,7 +225,7 @@ class StorageFlattener : public IRMutator { ...@@ -226,7 +225,7 @@ class StorageFlattener : public IRMutator {
align, 0, kDefault); align, 0, kDefault);
buf_map_[key] = e; buf_map_[key] = e;
Stmt body = this->Mutate(op->body); Stmt body = this->VisitStmt(op->body);
buf_map_[key].released = true; buf_map_[key].released = true;
Stmt ret; Stmt ret;
...@@ -263,8 +262,8 @@ class StorageFlattener : public IRMutator { ...@@ -263,8 +262,8 @@ class StorageFlattener : public IRMutator {
} }
} }
Expr Mutate_(const Load* op, const Expr& e) final { Expr VisitExpr_(const Load* op) final {
Expr expr = IRMutator::Mutate_(op, e); Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Load>(); op = expr.as<Load>();
auto it = var_remap_.find(op->buffer_var.get()); auto it = var_remap_.find(op->buffer_var.get());
if (it != var_remap_.end() && if (it != var_remap_.end() &&
...@@ -277,17 +276,17 @@ class StorageFlattener : public IRMutator { ...@@ -277,17 +276,17 @@ class StorageFlattener : public IRMutator {
} }
} }
Expr Mutate_(const Variable* op, const Expr& e) final { Expr VisitExpr_(const Variable* op) final {
auto it = var_remap_.find(op); auto it = var_remap_.find(op);
if (it != var_remap_.end()) { if (it != var_remap_.end()) {
return it->second; return it->second;
} else { } else {
return e; return GetRef<Expr>(op);
} }
} }
Expr Mutate_(const Call* op, const Expr& olde) final { Expr VisitExpr_(const Call* op) final {
Expr expr = IRMutator::Mutate_(op, olde); Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Call>(); op = expr.as<Call>();
if (op != nullptr && op->call_type == Call::Halide) { if (op != nullptr && op->call_type == Call::Halide) {
TensorKey key{op->func, op->value_index}; TensorKey key{op->func, op->value_index};
...@@ -308,8 +307,8 @@ class StorageFlattener : public IRMutator { ...@@ -308,8 +307,8 @@ class StorageFlattener : public IRMutator {
} }
} }
Stmt Mutate_(const Prefetch *op, const Stmt &s) final { Stmt VisitStmt_(const Prefetch *op) final {
Stmt stmt = IRMutator::Mutate_(op, s); Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<Prefetch>(); op = stmt.as<Prefetch>();
CHECK(op != nullptr); CHECK(op != nullptr);
TensorKey key{op->func, op->value_index}; TensorKey key{op->func, op->value_index};
...@@ -443,7 +442,7 @@ class StorageFlattener : public IRMutator { ...@@ -443,7 +442,7 @@ class StorageFlattener : public IRMutator {
// Apply the remaps // Apply the remaps
Stmt body = MergeNest(binder.asserts(), op->body); Stmt body = MergeNest(binder.asserts(), op->body);
body = MergeNest(binder.init_nest(), body); body = MergeNest(binder.init_nest(), body);
body = this->Mutate(body); body = this->VisitStmt(body);
// remove the binds // remove the binds
for (const Var& v : binder.defs()) { for (const Var& v : binder.defs()) {
var_remap_.erase(v.get()); var_remap_.erase(v.get());
...@@ -531,10 +530,10 @@ class StorageFlattener : public IRMutator { ...@@ -531,10 +530,10 @@ class StorageFlattener : public IRMutator {
Stmt StorageFlatten(Stmt stmt, Map<Tensor, Buffer> extern_buffer, Stmt StorageFlatten(Stmt stmt, Map<Tensor, Buffer> extern_buffer,
int cache_line_size, bool create_bound_attributes) { int cache_line_size, bool create_bound_attributes) {
IRVisitorWithAnalyzer bounded_analyzer; IRVisitorWithAnalyzer bounded_analyzer;
bounded_analyzer.Visit(stmt); bounded_analyzer(stmt);
stmt = stmt =
StorageFlattener(extern_buffer, cache_line_size, StorageFlattener(extern_buffer, cache_line_size,
create_bound_attributes, &bounded_analyzer).Mutate(stmt); create_bound_attributes, &bounded_analyzer)(std::move(stmt));
return stmt; return stmt;
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment