Unverified Commit 203ca7a0 by Tianqi Chen Committed by GitHub

[REFACTOR] Migrate Low-level IR Passes into the New Stmt/Expr Mutator (#4607)

* CombineContextCall

* Migrate BoundChecker

* Migrate CoprocSync

* Migrate detect_device

* Migrate loop_partition

* Migrate infer_fragment

* Migrate inject_copy_intrin

* Migrate inject double buffer

* Migrate lower_intrin and simplify

* Migrate storage flatten

* Migrate inject prefetch

* Migrate inject_virtual_thread

* migrate inline

* Migrate lift attr scope

* Migrate custom datatypes

* migrate lower_thread_all_reduce

* Migrate lower_tvm_builtin

* migrate lower_warp memory

* Migrate make_api.cc

* Migrate remap_thread_axis

* Migrate remove_no_op

* migrate rewrite_unsafe_select

* Migrate skip_assert simple_passes

* Migrate split_host_device

* Migrate ssa

* Migrate storage_access

* Migrate storage_rewrite

* Migrate tensor_core

* Migrate unroll_loop

* Migrate vectorize

* Migrate verify compact_buffer gpu_code

* Migrate verify_memory

* Migrate storage_sync

* Remove unused refs to mutator

* Migrate hybrid_op

* Migrate tensorize

* Migrate schedule ops

* Migrate schedule_dataflow_rewrite

* Migrate auto_inline_elemwise

* Remove unnecessary ref to visitor

* remove unnecessary ref

* Migrate bound_deducer

* Migrate domain_touched

* Migrate autotvm feature touch extractor

* Add annotations
parent 3f43bee0
......@@ -546,6 +546,35 @@ class StmtExprMutator :
}
};
/*!
* \brief recursively visit the ir in post DFS order node, and transform it
*
* \param node The ir to be transformed.
 * \param preorder The function called before recursive mutation
* If preorder returns None, then the transform will proceed to recursive call.
* If preorder returns a not None Stmt/Expr, the transformer will simply return it and
* won't do further recursion.
* \param postorder The function called after recursive mutation.
* The recursive mutation result is passed to postorder for further mutation.
* \param only_enable List of StringImm.
* If it is empty, all IRNode will call preorder/postorder
* If it is not empty, preorder/postorder will only be called
* when the IRNode's type key is in the list.
*/
TVM_DLL Stmt IRTransform(Stmt node,
const runtime::PackedFunc& preorder,
const runtime::PackedFunc& postorder,
const Array<Expr>& only_enable = {});
/*!
* \brief recursively visit the ir in post DFS order node, apply fvisit
* Each node is guaranteed to be visited only once.
* \param node The ir to be visited.
* \param fvisit The visitor function to be applied.
*/
TVM_DLL void PostOrderVisit(const ObjectRef& node, std::function<void(const ObjectRef&)> fvisit);
} // namespace ir
} // namespace tvm
#endif // TVM_IR_FUNCTOR_EXT_H_
......@@ -122,27 +122,6 @@ class TVM_DLL IRMutator {
virtual Expr Mutate_(const StringImm* op, const Expr& e);
virtual Expr Mutate_(const Shuffle* op, const Expr& e);
};
/*!
* \brief recursively visit the ir in post DFS order node, and transform it
*
* \param node The ir to be transformed.
 * \param preorder The function called before recursive mutation
* If preorder returns None, then the transform will proceed to recursive call.
* If preorder returns a not None Stmt/Expr, the transformer will simply return it and
* won't do further recursion.
* \param postorder The function called after recursive mutation.
* The recursive mutation result is passed to postorder for further mutation.
* \param only_enable List of StringImm.
* If it is empty, all IRNode will call preorder/postorder
* If it is not empty, preorder/postorder will only be called
* when the IRNode's type key is in the list.
*/
Stmt IRTransform(Stmt node,
const runtime::PackedFunc& preorder,
const runtime::PackedFunc& postorder,
const Array<Expr>& only_enable = {});
} // namespace ir
} // namespace tvm
#endif // TVM_IR_MUTATOR_H_
......@@ -145,15 +145,6 @@ class TVM_DLL IRVisitor {
virtual void Visit_(const FloatImm* op);
virtual void Visit_(const StringImm* op);
};
/*!
* \brief recursively visit the ir in post DFS order node, apply fvisit
* Each node is guaranteed to be visited only once.
* \param node The ir to be visited.
* \param fvisit The visitor function to be applied.
*/
TVM_DLL void PostOrderVisit(const ObjectRef& node, std::function<void(const ObjectRef&)> fvisit);
} // namespace ir
} // namespace tvm
......
......@@ -25,8 +25,7 @@
#include <tvm/ir.h>
#include <tvm/attrs.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/api_registry.h>
namespace tvm {
......
......@@ -23,7 +23,7 @@
*/
#include <tvm/expr.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/arithmetic.h>
#include <tvm/api_registry.h>
......@@ -38,17 +38,17 @@ using namespace ir;
// a visitor to find the path to the target variable
// from a expression.
class VariablePathFinder: public IRVisitor {
class VariablePathFinder: public ExprVisitor {
public:
explicit VariablePathFinder(Expr target) : target_(target) {}
void Visit(const ObjectRef& node) final {
void VisitExpr(const Expr& node) final {
if (visited_.count(node.get()) != 0) return;
visited_.insert(node.get());
if (!found_) path_.push_back(node.get());
if (node.same_as(target_)) found_ = true;
IRVisitor::Visit(node);
ExprVisitor::VisitExpr(node);
if (!found_) path_.pop_back();
}
......@@ -64,14 +64,14 @@ class VariablePathFinder: public IRVisitor {
// return empty vector to represent failure
std::vector<const Object*> GetPath(Expr target, Expr expr) {
VariablePathFinder v(target);
v.Visit(expr);
v(expr);
return v.path_;
}
enum CompareOp {kGreater, kLess, kEqual};
// a visitor to deduce the bound of a variable from a expression
class BoundDeducer: public IRVisitor {
class BoundDeducer: public ExprVisitor {
public:
friend class BoundDeduceInputChecker;
friend class Converter;
......@@ -82,39 +82,39 @@ class BoundDeducer: public IRVisitor {
void Deduce();
void Visit(const ObjectRef& e) final {
void VisitExpr(const Expr& e) final {
if (!success_) return;
if (e.get() == path_[iter_++]) {
IRVisitor::Visit(e);
ExprVisitor::VisitExpr(e);
} else {
success_ = false;
return;
}
}
void Visit_(const LT* op) final {
void VisitExpr_(const LT* op) final {
LOG(FATAL) << "unable to deduce due to multiple comparison operator";
}
void Visit_(const LE* op) final {
void VisitExpr_(const LE* op) final {
LOG(FATAL) << "unable to deduce due to multiple comparison operator";
}
void Visit_(const GT* op) final {
void VisitExpr_(const GT* op) final {
LOG(FATAL) << "unable to deduce due to multiple comparison operator";
}
void Visit_(const GE* op) final {
void VisitExpr_(const GE* op) final {
LOG(FATAL) << "unable to deduce due to multiple comparison operator";
}
void Visit_(const Add* op) final {
void VisitExpr_(const Add* op) final {
bool left = op->a.get() == path_[iter_];
result_ -= left ? op->b : op->a;
Visit(left ? op->a : op->b);
this->VisitExpr(left ? op->a : op->b);
}
void Visit_(const Sub* op) final {
void VisitExpr_(const Sub* op) final {
bool left = op->a.get() == path_[iter_];
if (left) {
result_ += op->b;
......@@ -123,10 +123,10 @@ class BoundDeducer: public IRVisitor {
result_ = - result_;
comp_op = ReverseOp(comp_op);
}
Visit(left ? op->a : op->b);
this->VisitExpr(left ? op->a : op->b);
}
void Visit_(const Mul* op) final {
void VisitExpr_(const Mul* op) final {
bool left = op->a.get() == path_[iter_];
Expr operand = left ? op->b : op->a;
Expr target_var = left ? op->a : op->b;
......@@ -171,7 +171,7 @@ class BoundDeducer: public IRVisitor {
// ( x <= -3/-2 --> x <= 1)
}
}
Visit(left ? op->a : op->b);
this->VisitExpr(left ? op->a : op->b);
}
Expr result_;
......@@ -194,17 +194,17 @@ class BoundDeducer: public IRVisitor {
Analyzer analyzer_;
};
class BoundDeduceInputChecker: public IRVisitor {
class BoundDeduceInputChecker: public ExprVisitor {
public:
bool Check(BoundDeducer* deducer) {
deducer_ = deducer;
Visit(deducer_->expr_);
this->VisitExpr(deducer_->expr_);
return target_count == 1;
}
void Visit(const ObjectRef& e) final {
void VisitExpr(const Expr& e) final {
if (e.same_as(deducer_->target_)) ++target_count;
IRVisitor::Visit(e);
ExprVisitor::VisitExpr(e);
}
private:
......@@ -305,7 +305,7 @@ void BoundDeducer::Deduce() {
}
expr_map_ = EvalSetForEachSubExpr(expr_, hint_map_);
Visit(expr_);
this->VisitExpr(expr_);
}
void BoundDeducer::Relax() {
......
......@@ -23,7 +23,6 @@
*/
#include <tvm/arithmetic.h>
#include <tvm/expr_operator.h>
#include <tvm/ir_mutator.h>
#include "const_fold.h"
#include "pattern_match.h"
#include "rewrite_simplify.h"
......@@ -435,30 +434,30 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
Expr CanonicalSimplify(Expr expr) {
expr = Mutate(expr);
expr = operator()(expr);
return expr;
}
// override the original mutate function.
Expr Mutate(Expr expr) final {
expr = IRMutator::Mutate(expr);
Expr VisitExpr(const Expr& input_expr) final {
auto expr = Rewriter::VisitExpr(input_expr);
return Normalize(expr);
}
// Normal mutation without normalization.
Expr CanonicalMutate(Expr expr) {
return IRMutator::Mutate(expr);
return Rewriter::VisitExpr(expr);
}
using Rewriter::Mutate_;
Expr Mutate_(const Add* op, const Expr& self) final;
Expr Mutate_(const Sub* op, const Expr& self) final;
Expr Mutate_(const Mul* op, const Expr& self) final;
Expr Mutate_(const Div* op, const Expr& self) final;
Expr Mutate_(const Mod* op, const Expr& self) final;
Expr Mutate_(const FloorDiv* op, const Expr& self) final;
Expr Mutate_(const FloorMod* op, const Expr& self) final;
Expr Mutate_(const Reduce* op, const Expr& self) final;
using Rewriter::VisitExpr_;
Expr VisitExpr_(const Add* op) final;
Expr VisitExpr_(const Sub* op) final;
Expr VisitExpr_(const Mul* op) final;
Expr VisitExpr_(const Div* op) final;
Expr VisitExpr_(const Mod* op) final;
Expr VisitExpr_(const FloorDiv* op) final;
Expr VisitExpr_(const FloorMod* op) final;
Expr VisitExpr_(const Reduce* op) final;
private:
/*!
......@@ -567,9 +566,9 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
};
Expr CanonicalSimplifier::Impl::
Mutate_(const Add* op, const Expr& self) {
VisitExpr_(const Add* op) {
if (!IsIndexType(op->dtype)) {
return Rewriter::Mutate_(op, self);
return Rewriter::VisitExpr_(op);
}
// normalize
Expr a = this->CanonicalMutate(op->a);
......@@ -593,9 +592,9 @@ Mutate_(const Add* op, const Expr& self) {
}
Expr CanonicalSimplifier::Impl::
Mutate_(const Sub* op, const Expr& self) {
VisitExpr_(const Sub* op) {
if (!IsIndexType(op->dtype)) {
return Rewriter::Mutate_(op, self);
return Rewriter::VisitExpr_(op);
}
// normalize
Expr a = this->CanonicalMutate(op->a);
......@@ -620,9 +619,9 @@ Mutate_(const Sub* op, const Expr& self) {
Expr CanonicalSimplifier::Impl::
Mutate_(const Mul* op, const Expr& self) {
VisitExpr_(const Mul* op) {
if (!IsIndexType(op->dtype)) {
return Rewriter::Mutate_(op, self);
return Rewriter::VisitExpr_(op);
}
// normalize
Expr a = this->CanonicalMutate(op->a);
......@@ -652,7 +651,7 @@ Mutate_(const Mul* op, const Expr& self) {
a = Normalize(a);
b = Normalize(b);
if (op->a.same_as(a) && op->b.same_as(b)) {
return self;
return GetRef<Expr>(op);
} else {
return Mul::make(a, b);
}
......@@ -727,9 +726,9 @@ SplitDivConst(SplitExpr lhs, int64_t cval, DivMode div_mode) {
}
Expr CanonicalSimplifier::Impl::
Mutate_(const Div* op, const Expr& self) {
VisitExpr_(const Div* op) {
if (!IsIndexType(op->dtype)) {
return Rewriter::Mutate_(op, self);
return Rewriter::VisitExpr_(op);
}
Expr a = this->CanonicalMutate(op->a);
......@@ -781,16 +780,16 @@ Mutate_(const Div* op, const Expr& self) {
a = Normalize(a);
b = Normalize(b);
if (op->a.same_as(a) && op->b.same_as(b)) {
return self;
return GetRef<Expr>(op);
} else {
return Div::make(a, b);
}
}
Expr CanonicalSimplifier::Impl::
Mutate_(const FloorDiv* op, const Expr& self) {
VisitExpr_(const FloorDiv* op) {
if (!IsIndexType(op->dtype)) {
return Rewriter::Mutate_(op, self);
return Rewriter::VisitExpr_(op);
}
Expr a = this->CanonicalMutate(op->a);
Expr b = this->CanonicalMutate(op->b);
......@@ -837,7 +836,7 @@ Mutate_(const FloorDiv* op, const Expr& self) {
a = Normalize(a);
b = Normalize(b);
if (op->a.same_as(a) && op->b.same_as(b)) {
return self;
return GetRef<Expr>(op);
} else {
return FloorDiv::make(a, b);
}
......@@ -866,7 +865,7 @@ SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode) {
// Do a recursive call to simplify the mod with the new factor.
if (new_upper_factor < lhs->upper_factor &&
lhs->upper_factor != SplitExprNode::kPosInf) {
auto updated = ToSplitExpr(Mutate(ModImpl(
auto updated = ToSplitExpr(this->VisitExpr(ModImpl(
lhs->index, make_const(lhs.dtype(), new_upper_factor), div_mode)));
// re-apply the lower_factor
if (lhs->lower_factor != 1) {
......@@ -894,9 +893,9 @@ SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode) {
}
Expr CanonicalSimplifier::Impl::
Mutate_(const Mod* op, const Expr& self) {
VisitExpr_(const Mod* op) {
if (!IsIndexType(op->dtype)) {
return Rewriter::Mutate_(op, self);
return Rewriter::VisitExpr_(op);
}
// normalize
Expr a = this->CanonicalMutate(op->a);
......@@ -957,16 +956,16 @@ Mutate_(const Mod* op, const Expr& self) {
a = Normalize(a);
b = Normalize(b);
if (op->a.same_as(a) && op->b.same_as(b)) {
return self;
return GetRef<Expr>(op);
} else {
return Mod::make(a, b);
}
}
Expr CanonicalSimplifier::Impl::
Mutate_(const FloorMod* op, const Expr& self) {
VisitExpr_(const FloorMod* op) {
if (!IsIndexType(op->dtype)) {
return Rewriter::Mutate_(op, self);
return Rewriter::VisitExpr_(op);
}
// normalize
Expr a = this->CanonicalMutate(op->a);
......@@ -1017,7 +1016,7 @@ Mutate_(const FloorMod* op, const Expr& self) {
a = Normalize(a);
b = Normalize(b);
if (op->a.same_as(a) && op->b.same_as(b)) {
return self;
return GetRef<Expr>(op);
} else {
return FloorMod::make(a, b);
}
......@@ -1029,7 +1028,7 @@ SimplifyReduceCombiner(const Reduce* op) {
// First simplify the results
Array<Expr> simplified_result;
for (const auto& res : op->combiner->result) {
Expr new_res = Mutate(res);
Expr new_res = this->VisitExpr(res);
simplified_result.push_back(new_res);
}
......@@ -1078,7 +1077,7 @@ SimplifyReduceCombiner(const Reduce* op) {
if (used[i]) {
// We simplify the result and identity, but not the source
new_result.push_back(simplified_result[i]);
new_identity.push_back(Mutate(op->combiner->identity_element[i]));
new_identity.push_back(this->VisitExpr(op->combiner->identity_element[i]));
new_lhs.push_back(op->combiner->lhs[i]);
new_rhs.push_back(op->combiner->rhs[i]);
new_source.push_back(op->source[i]);
......@@ -1095,9 +1094,9 @@ SimplifyReduceCombiner(const Reduce* op) {
}
Expr CanonicalSimplifier::Impl::
Mutate_(const Reduce* op, const Expr& self) {
VisitExpr_(const Reduce* op) {
// Recursively call simplification when necessary.
Expr ret = RewriteSimplifier::Impl::Mutate_(op, self);
Expr ret = RewriteSimplifier::Impl::VisitExpr_(op);
op = ret.as<Reduce>();
// already been simplified by const reduction axis removal
if (op == nullptr) return ret;
......@@ -1106,7 +1105,7 @@ Mutate_(const Reduce* op, const Expr& self) {
// assumption we would have to perform a single iteration of the loop, i.e. use
// `(*op->combiner.get())(op->combiner->identity_element, op->source)[op->value_index]`
// instead of `op->source[op->value_index]`. The former may be more difficult to simplify.
return Mutate(
return this->VisitExpr(
Select::make(op->condition,
op->source[op->value_index],
op->combiner->identity_element[op->value_index]));
......
......@@ -25,7 +25,6 @@
#define TVM_ARITHMETIC_CONST_FOLD_H_
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/expr_operator.h>
#include <algorithm>
#include <cmath>
......
......@@ -23,7 +23,6 @@
*/
#include <tvm/expr.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/arithmetic.h>
......
......@@ -23,7 +23,7 @@
*/
#include <tvm/expr.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/tensor.h>
#include <tvm/api_registry.h>
......@@ -36,13 +36,13 @@ namespace arith {
using namespace ir;
// Find Read region of the tensor in the stmt.
class FuncTouchedDomain final : public IRVisitor {
class FuncTouchedDomain final : public StmtExprVisitor {
public:
FuncTouchedDomain(const Tensor &tensor, bool consider_calls, bool consider_provides)
: tensor_(tensor), consider_calls_(consider_calls), consider_provides_(consider_provides) {}
Domain Find(const Stmt& stmt) {
this->Visit(stmt);
operator()(stmt);
Domain ret;
Range none;
for (size_t i = 0; i < bounds_.size(); ++i) {
......@@ -51,49 +51,49 @@ class FuncTouchedDomain final : public IRVisitor {
return ret;
}
void Visit_(const For *op) final {
void VisitStmt_(const For *op) final {
const Variable* var = op->loop_var.get();
dom_map_[var] = IntSet::range(
Range::make_by_min_extent(op->min, op->extent));
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
dom_map_.erase(var);
}
void Visit_(const LetStmt* op) final {
void VisitStmt_(const LetStmt* op) final {
dom_map_[op->var.get()] =
arith::EvalSet(op->value, dom_map_);
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
dom_map_.erase(op->var.get());
}
/* TODO: Thread extent unit test not generated. */
void Visit_(const AttrStmt* op) final {
void VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::thread_extent) {
const IterVarNode* thread_axis = op->node.as<IterVarNode>();
CHECK(thread_axis);
const Variable* var = thread_axis->var.get();
dom_map_[var] = IntSet::range(Range(make_zero(op->value.dtype()), op->value));
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
dom_map_.erase(var);
} else {
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
}
}
void Visit_(const Call* op) final {
void VisitExpr_(const Call* op) final {
if (consider_calls_ && tensor_->op.same_as(op->func)
&& tensor_->value_index == op->value_index) {
Touch(op->args);
}
IRVisitor::Visit_(op);
StmtExprVisitor::VisitExpr_(op);
}
void Visit_(const Provide* op) final {
void VisitStmt_(const Provide* op) final {
if (consider_provides_ && tensor_->op.same_as(op->func)
&& tensor_->value_index == op->value_index) {
Touch(op->args);
}
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
}
private:
......
......@@ -30,41 +30,44 @@ namespace arith {
using namespace ir;
Stmt IRMutatorWithAnalyzer::
Mutate_(const For* op, const Stmt& s) {
VisitStmt_(const For* op) {
analyzer_->Bind(op->loop_var,
Range::make_by_min_extent(op->min, op->extent));
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
Stmt IRMutatorWithAnalyzer::
Mutate_(const LetStmt* op, const Stmt& s) {
Expr value = this->Mutate(op->value);
VisitStmt_(const LetStmt* op) {
Expr value = this->VisitExpr(op->value);
if (!ir::HasSideEffect(value)) {
analyzer_->Bind(op->var, value);
}
// We keep the let-binding here
// as sub-class may or maynot choose to replace it.
Stmt body = this->Mutate(op->body);
Stmt body = this->VisitStmt(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return s;
return GetRef<Stmt>(op);
} else {
return LetStmt::make(op->var, value, body);
auto n = this->CopyOnWrite(op);
n->value = std::move(value);
n->body = std::move(body);
return Stmt(n);
}
}
Stmt IRMutatorWithAnalyzer::
Mutate_(const IfThenElse* op, const Stmt& s) {
Expr condition = this->Mutate(op->condition);
VisitStmt_(const IfThenElse* op) {
Expr condition = this->VisitExpr(op->condition);
Stmt then_case, else_case;
{
With<ConstraintContext> ctx(analyzer_, condition);
then_case = this->Mutate(op->then_case);
then_case = this->VisitStmt(op->then_case);
}
if (op->else_case.defined()) {
With<ConstraintContext> ctx(analyzer_,
analyzer_->rewrite_simplify(Not::make(condition)));
else_case = this->Mutate(op->else_case);
else_case = this->VisitStmt(op->else_case);
}
if (is_one(condition)) return then_case;
if (is_zero(condition)) {
......@@ -77,57 +80,65 @@ Mutate_(const IfThenElse* op, const Stmt& s) {
if (condition.same_as(op->condition) &&
then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
return s;
return GetRef<Stmt>(op);
} else {
return IfThenElse::make(condition, then_case, else_case);
auto n = this->CopyOnWrite(op);
n->condition = std::move(condition);
n->then_case = std::move(then_case);
n->else_case = std::move(else_case);
return Stmt(n);
}
}
Stmt IRMutatorWithAnalyzer::
Mutate_(const AttrStmt* op, const Stmt& s) {
VisitStmt_(const AttrStmt* op) {
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::virtual_thread) {
IterVar iv = Downcast<IterVar>(op->node);
CHECK_NE(iv->thread_tag.length(), 0U);
analyzer_->Bind(iv->var,
Range::make_by_min_extent(0, op->value));
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt stmt = StmtExprMutator::VisitStmt_(op);
return stmt;
} else {
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
}
Stmt IRMutatorWithAnalyzer::
Mutate_(const AssertStmt* op, const Stmt& s) {
Expr condition = this->Mutate(op->condition);
Expr message = this->Mutate(op->message);
VisitStmt_(const AssertStmt* op) {
Expr condition = this->VisitExpr(op->condition);
Expr message = this->VisitExpr(op->message);
With<ConstraintContext> ctx(analyzer_, condition);
Stmt body = this->Mutate(op->body);
Stmt body = this->VisitStmt(op->body);
if (condition.same_as(op->condition) &&
message.same_as(op->message) &&
body.same_as(op->body)) {
return s;
return GetRef<Stmt>(op);
} else {
return AssertStmt::make(condition, message, body);
auto n = this->CopyOnWrite(op);
n->condition = std::move(condition);
n->message = std::move(message);
n->body = std::move(body);
return Stmt(n);
}
}
Expr IRMutatorWithAnalyzer::
Mutate_(const Call* op, const Expr& self) {
VisitExpr_(const Call* op) {
// add condition context to if_then_else
if (op->is_intrinsic(ir::intrinsic::tvm_if_then_else)) {
Expr cond = Mutate(op->args[0]);
Expr cond = this->VisitExpr(op->args[0]);
Expr true_value, false_value;
{
With<ConstraintContext> constraint(analyzer_, cond);
true_value = Mutate(op->args[1]);
true_value = this->VisitExpr(op->args[1]);
}
{
With<ConstraintContext> constraint(analyzer_,
analyzer_->rewrite_simplify(Not::make(cond)));
false_value = Mutate(op->args[2]);
false_value = this->VisitExpr(op->args[2]);
}
if (is_zero(cond)) {
return false_value;
......@@ -138,45 +149,45 @@ Mutate_(const Call* op, const Expr& self) {
if (cond.same_as(op->args[0]) &&
true_value.same_as(op->args[1]) &&
false_value.same_as(op->args[2])) {
return self;
return GetRef<Expr>(op);
} else {
return Call::make(op->dtype, op->name,
{cond, true_value, false_value},
op->call_type);
}
}
return IRMutator::Mutate_(op, self);
return StmtExprMutator::VisitExpr_(op);
}
Expr IRMutatorWithAnalyzer::
Mutate_(const Let* op, const Expr& self) {
Expr value = this->Mutate(op->value);
VisitExpr_(const Let* op) {
Expr value = this->VisitExpr(op->value);
if (!ir::HasSideEffect(value)) {
analyzer_->Bind(op->var, value);
}
// We keep the let-binding here
// as sub-class may or maynot choose to replace it.
Expr body = this->Mutate(op->body);
Expr body = this->VisitExpr(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return self;
return GetRef<Expr>(op);
} else {
return Let::make(op->var, value, body);
}
}
Expr IRMutatorWithAnalyzer::
Mutate_(const Select* op, const Expr& self) {
Expr cond = Mutate(op->condition);
VisitExpr_(const Select* op) {
Expr cond = this->VisitExpr(op->condition);
Expr true_value, false_value;
{
With<ConstraintContext> constraint(analyzer_, cond);
true_value = Mutate(op->true_value);
true_value = VisitExpr(op->true_value);
}
{
With<ConstraintContext> constraint(analyzer_,
analyzer_->rewrite_simplify(Not::make(cond)));
false_value = Mutate(op->false_value);
false_value = VisitExpr(op->false_value);
}
if (is_zero(cond)) {
return false_value;
......@@ -188,20 +199,20 @@ Mutate_(const Select* op, const Expr& self) {
if (cond.same_as(op->condition) &&
true_value.same_as(op->true_value) &&
false_value.same_as(op->false_value)) {
return self;
return GetRef<Expr>(op);
} else {
return Select::make(cond, true_value, false_value);
}
}
Expr IRMutatorWithAnalyzer::
Mutate_(const Reduce* op, const Expr& self) {
VisitExpr_(const Reduce* op) {
// Setup the domain information before simplification.
for (const IterVar& iv : op->axis) {
analyzer_->Bind(iv->var, iv->dom);
}
// Recursively call simplification when necessary.
return IRMutator::Mutate_(op, self);
return StmtExprMutator::VisitExpr_(op);
}
} // namespace arith
......
......@@ -24,9 +24,9 @@
#ifndef TVM_ARITHMETIC_IR_MUTATOR_WITH_ANALYZER_H_
#define TVM_ARITHMETIC_IR_MUTATOR_WITH_ANALYZER_H_
#include <tvm/ir_mutator.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/arithmetic.h>
#include <utility>
namespace tvm {
namespace arith {
......@@ -40,23 +40,24 @@ namespace arith {
*
* \sa src/arithmetic/ir_mutator_with_analyzer.cc
*/
class IRMutatorWithAnalyzer : public ir::IRMutator {
class IRMutatorWithAnalyzer : public ir::StmtExprMutator {
public:
explicit IRMutatorWithAnalyzer(Analyzer* analyzer)
: analyzer_(analyzer) {}
using IRMutator::Mutate_;
using StmtExprMutator::VisitStmt_;
using StmtExprMutator::VisitExpr_;
// override functions that need to populate the context information.
Stmt Mutate_(const ir::For* op, const Stmt& self) override;
Stmt Mutate_(const ir::LetStmt* op, const Stmt& self) override;
Stmt Mutate_(const ir::IfThenElse* op, const Stmt& self) override;
Stmt Mutate_(const ir::AttrStmt* op, const Stmt& self) override;
Stmt Mutate_(const ir::AssertStmt* op, const Stmt& self) override;
Expr Mutate_(const ir::Let* op, const Expr& self) override;
Expr Mutate_(const ir::Select* op, const Expr& self) override;
Expr Mutate_(const ir::Call* op, const Expr& self) override;
Expr Mutate_(const ir::Reduce* op, const Expr& self) override;
Stmt VisitStmt_(const ir::For* op) override;
Stmt VisitStmt_(const ir::LetStmt* op) override;
Stmt VisitStmt_(const ir::IfThenElse* op) override;
Stmt VisitStmt_(const ir::AttrStmt* op) override;
Stmt VisitStmt_(const ir::AssertStmt* op) override;
Expr VisitExpr_(const ir::Let* op) override;
Expr VisitExpr_(const ir::Select* op) override;
Expr VisitExpr_(const ir::Call* op) override;
Expr VisitExpr_(const ir::Reduce* op) override;
protected:
/*! \brief internal analyzer field. */
......
......@@ -27,43 +27,43 @@
#include <tvm/arithmetic.h>
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
namespace tvm {
namespace ir {
class IRVisitorWithAnalyzer final : public IRVisitor {
class IRVisitorWithAnalyzer final : public StmtExprVisitor {
public:
Expr Simplify(const Expr& expr) {
return analyzer_.Simplify(expr);
}
void Visit_(const For* op) {
void VisitStmt_(const For* op) {
analyzer_.Bind(op->loop_var,
Range::make_by_min_extent(op->min, op->extent));
return IRVisitor::Visit_(op);
return StmtExprVisitor::VisitStmt_(op);
}
void Visit_(const AttrStmt* op) {
void VisitStmt_(const AttrStmt* op) {
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::virtual_thread) {
IterVar iv = Downcast<IterVar>(op->node);
CHECK_NE(iv->thread_tag.length(), 0U);
analyzer_.Bind(iv->var,
Range::make_by_min_extent(0, op->value));
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
} else {
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
}
}
void Visit_(const Reduce* op) {
void VisitExpr_(const Reduce* op) {
// Setup the domain information before simplification.
for (const IterVar& iv : op->axis) {
analyzer_.Bind(iv->var, iv->dom);
}
// Recursively call simplification when necessary.
IRVisitor::Visit_(op);
StmtExprVisitor::VisitExpr_(op);
}
protected:
......
......@@ -24,7 +24,6 @@
// Acknowledgement: Most rewrite-rules are from Halide.
#include <tvm/arithmetic.h>
#include <tvm/expr_operator.h>
#include <tvm/ir_mutator.h>
#include <algorithm>
#include "const_fold.h"
#include "pattern_match.h"
......@@ -69,7 +68,7 @@ using namespace ir;
// try to prove x equals val
RewriteSimplifier::Impl::CompareResult RewriteSimplifier::Impl::
TryCompare(const Expr& x, int64_t val) {
Expr diff = Mutate(x);
Expr diff = this->VisitExpr(x);
if (const auto* ptr = diff.as<IntImm>()) {
if (ptr->value == val) {
return kEQ;
......@@ -117,8 +116,8 @@ Update(const Var& var, const Expr& info, bool override) {
}
Expr RewriteSimplifier::Impl::
Mutate_(const Add* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
VisitExpr_(const Add* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<Add>();
Expr const_res = TryConstFold<Add>(op->a, op->b);
if (const_res.defined()) return const_res;
......@@ -232,8 +231,8 @@ std::function<void()> RewriteSimplifier::Impl::EnterConstraint(const Expr& const
}
Expr RewriteSimplifier::Impl::
Mutate_(const Sub* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
VisitExpr_(const Sub* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<Sub>();
Expr const_res = TryConstFold<Sub>(op->a, op->b);
if (const_res.defined()) return const_res;
......@@ -431,8 +430,8 @@ Mutate_(const Sub* op, const Expr& self) {
}
Expr RewriteSimplifier::Impl::
Mutate_(const Mul* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
VisitExpr_(const Mul* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<Mul>();
Expr const_res = TryConstFold<Mul>(op->a, op->b);
if (const_res.defined()) return const_res;
......@@ -470,8 +469,8 @@ Mutate_(const Mul* op, const Expr& self) {
}
Expr RewriteSimplifier::Impl::
Mutate_(const Div* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
VisitExpr_(const Div* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<Div>();
Expr const_res = TryConstFold<Div>(op->a, op->b);
if (const_res.defined()) return const_res;
......@@ -692,8 +691,8 @@ Mutate_(const Div* op, const Expr& self) {
}
Expr RewriteSimplifier::Impl::
Mutate_(const Mod* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
VisitExpr_(const Mod* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<Mod>();
Expr const_res = TryConstFold<Mod>(op->a, op->b);
if (const_res.defined()) return const_res;
......@@ -782,8 +781,8 @@ Mutate_(const Mod* op, const Expr& self) {
}
Expr RewriteSimplifier::Impl::
Mutate_(const FloorDiv* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
VisitExpr_(const FloorDiv* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<FloorDiv>();
Expr const_res = TryConstFold<FloorDiv>(op->a, op->b);
if (const_res.defined()) return const_res;
......@@ -926,8 +925,8 @@ Mutate_(const FloorDiv* op, const Expr& self) {
}
Expr RewriteSimplifier::Impl::
Mutate_(const FloorMod* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
VisitExpr_(const FloorMod* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<FloorMod>();
Expr const_res = TryConstFold<FloorMod>(op->a, op->b);
if (const_res.defined()) return const_res;
......@@ -996,8 +995,8 @@ Mutate_(const FloorMod* op, const Expr& self) {
}
Expr RewriteSimplifier::Impl::
Mutate_(const Min* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
VisitExpr_(const Min* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<Min>();
Expr const_res = TryConstFold<Min>(op->a, op->b);
if (const_res.defined()) return const_res;
......@@ -1181,8 +1180,8 @@ Mutate_(const Min* op, const Expr& self) {
}
Expr RewriteSimplifier::Impl::
Mutate_(const Max* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
VisitExpr_(const Max* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<Max>();
Expr const_res = TryConstFold<Max>(op->a, op->b);
if (const_res.defined()) return const_res;
......@@ -1354,8 +1353,8 @@ Mutate_(const Max* op, const Expr& self) {
}
Expr RewriteSimplifier::Impl::
Mutate_(const EQ* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
VisitExpr_(const EQ* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<EQ>();
Expr const_res = TryConstFold<EQ>(op->a, op->b);
if (const_res.defined()) return const_res;
......@@ -1388,28 +1387,28 @@ Mutate_(const EQ* op, const Expr& self) {
}
Expr RewriteSimplifier::Impl::
Mutate_(const NE* op, const Expr& self) {
return Mutate(Not::make(op->a == op->b));
VisitExpr_(const NE* op) {
return this->VisitExpr(Not::make(op->a == op->b));
}
Expr RewriteSimplifier::Impl::
Mutate_(const LE* op, const Expr& self) {
return Mutate(Not::make(op->b < op->a));
VisitExpr_(const LE* op) {
return this->VisitExpr(Not::make(op->b < op->a));
}
Expr RewriteSimplifier::Impl::
Mutate_(const GT* op, const Expr& self) {
return Mutate(op->b < op->a);
VisitExpr_(const GT* op) {
return this->VisitExpr(op->b < op->a);
}
Expr RewriteSimplifier::Impl::
Mutate_(const GE* op, const Expr& self) {
return Mutate(Not::make(op->a < op->b));
VisitExpr_(const GE* op) {
return this->VisitExpr(Not::make(op->a < op->b));
}
Expr RewriteSimplifier::Impl::
Mutate_(const LT* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
VisitExpr_(const LT* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<LT>();
Expr const_res = TryConstFold<LT>(op->a, op->b);
if (const_res.defined()) return const_res;
......@@ -1564,8 +1563,8 @@ Mutate_(const LT* op, const Expr& self) {
}
Expr RewriteSimplifier::Impl::
Mutate_(const Not* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
VisitExpr_(const Not* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<Not>();
Expr const_res = TryConstFold<Not>(op->a);
if (const_res.defined()) return const_res;
......@@ -1589,8 +1588,8 @@ Mutate_(const Not* op, const Expr& self) {
}
Expr RewriteSimplifier::Impl::
Mutate_(const And* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
VisitExpr_(const And* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<And>();
Expr const_res = TryConstFold<And>(op->a, op->b);
if (const_res.defined()) return const_res;
......@@ -1638,8 +1637,8 @@ Mutate_(const And* op, const Expr& self) {
}
Expr RewriteSimplifier::Impl::
Mutate_(const Or* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
VisitExpr_(const Or* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<Or>();
Expr const_res = TryConstFold<Or>(op->a, op->b);
if (const_res.defined()) return const_res;
......@@ -1688,8 +1687,8 @@ Mutate_(const Or* op, const Expr& self) {
}
Expr RewriteSimplifier::Impl::
Mutate_(const Select* op, const Expr& self) {
Expr ret = IRMutatorWithAnalyzer::Mutate_(op, self);
VisitExpr_(const Select* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<Select>();
if (op == nullptr) return ret;
// Pattern var to match any expression
......@@ -1699,9 +1698,9 @@ Mutate_(const Select* op, const Expr& self) {
}
Expr RewriteSimplifier::Impl::
Mutate_(const Call* op, const Expr& self) {
VisitExpr_(const Call* op) {
// add condition context to if_then_else
Expr ret = IRMutatorWithAnalyzer::Mutate_(op, self);
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<Call>();
if (op == nullptr) return ret;
if (op->is_intrinsic(Call::likely) && is_const(op->args[0])) {
......@@ -1729,35 +1728,35 @@ Mutate_(const Call* op, const Expr& self) {
}
Expr RewriteSimplifier::Impl::
Mutate_(const Variable* op, const Expr& self) {
VisitExpr_(const Variable* op) {
Var var = GetRef<Var>(op);
auto it = var_map_.find(var);
if (it != var_map_.end()) {
return it->second;
}
return self;
return GetRef<Expr>(op);
}
Expr RewriteSimplifier::Impl::
Mutate_(const Cast* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
VisitExpr_(const Cast* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<Cast>();
return cast(op->dtype, op->value);
}
Expr RewriteSimplifier::Impl::
Mutate_(const Let* op, const Expr& self) {
Expr value = this->Mutate(op->value);
VisitExpr_(const Let* op) {
Expr value = this->VisitExpr(op->value);
if (!ir::HasSideEffect(value)) {
// it is fine to discard the let binding
// because the value will always be inlined in the simplifier.
analyzer_->Bind(op->var, value);
return this->Mutate(op->body);
return this->VisitExpr(op->body);
}
Expr body = this->Mutate(op->body);
Expr body = this->VisitExpr(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return self;
return GetRef<Expr>(op);
} else {
return Let::make(op->var, value, body);
}
......@@ -1768,7 +1767,7 @@ Expr RewriteSimplifier::operator()(const Expr& expr) {
Expr res = expr;
int max_iter = 2;
for (int i = 0; i < max_iter; ++i) {
Expr new_expr = impl_->Mutate(res);
Expr new_expr = impl_->operator()(res);
if (new_expr.same_as(res)) return res;
res = new_expr;
}
......
......@@ -26,7 +26,6 @@
#include <tvm/arithmetic.h>
#include <tvm/expr_operator.h>
#include <tvm/ir_mutator.h>
#include <unordered_map>
#include <vector>
#include "const_fold.h"
......@@ -45,35 +44,35 @@ using namespace ir;
*/
class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
public:
using IRMutatorWithAnalyzer::Mutate_;
using IRMutatorWithAnalyzer::VisitExpr_;
explicit Impl(Analyzer* parent)
: IRMutatorWithAnalyzer(parent) {}
void Update(const Var& var, const Expr& info, bool override);
Expr Mutate_(const Add* op, const Expr& self) override;
Expr Mutate_(const Sub* op, const Expr& self) override;
Expr Mutate_(const Mul* op, const Expr& self) override;
Expr Mutate_(const Div* op, const Expr& self) override;
Expr Mutate_(const Mod* op, const Expr& self) override;
Expr Mutate_(const FloorDiv* op, const Expr& self) override;
Expr Mutate_(const FloorMod* op, const Expr& self) override;
Expr Mutate_(const Min* op, const Expr& self) override;
Expr Mutate_(const Max* op, const Expr& self) override;
Expr Mutate_(const EQ* op, const Expr& self) override;
Expr Mutate_(const NE* op, const Expr& self) override;
Expr Mutate_(const LT* op, const Expr& self) override;
Expr Mutate_(const LE* op, const Expr& self) override;
Expr Mutate_(const GT* op, const Expr& self) override;
Expr Mutate_(const GE* op, const Expr& self) override;
Expr Mutate_(const And* op, const Expr& self) override;
Expr Mutate_(const Or* op, const Expr& self) override;
Expr Mutate_(const Not* op, const Expr& self) override;
Expr Mutate_(const Select* op, const Expr& self) override;
Expr Mutate_(const Call* op, const Expr& self) override;
Expr Mutate_(const Variable* op, const Expr& self) override;
Expr Mutate_(const Cast* op, const Expr& self) override;
Expr Mutate_(const Let* op, const Expr& self) override;
void Update(const Var& var, const Expr& info, bool override_info);
Expr VisitExpr_(const Add* op) override;
Expr VisitExpr_(const Sub* op) override;
Expr VisitExpr_(const Mul* op) override;
Expr VisitExpr_(const Div* op) override;
Expr VisitExpr_(const Mod* op) override;
Expr VisitExpr_(const FloorDiv* op) override;
Expr VisitExpr_(const FloorMod* op) override;
Expr VisitExpr_(const Min* op) override;
Expr VisitExpr_(const Max* op) override;
Expr VisitExpr_(const EQ* op) override;
Expr VisitExpr_(const NE* op) override;
Expr VisitExpr_(const LT* op) override;
Expr VisitExpr_(const LE* op) override;
Expr VisitExpr_(const GT* op) override;
Expr VisitExpr_(const GE* op) override;
Expr VisitExpr_(const And* op) override;
Expr VisitExpr_(const Or* op) override;
Expr VisitExpr_(const Not* op) override;
Expr VisitExpr_(const Select* op) override;
Expr VisitExpr_(const Call* op) override;
Expr VisitExpr_(const Variable* op) override;
Expr VisitExpr_(const Cast* op) override;
Expr VisitExpr_(const Let* op) override;
std::function<void()> EnterConstraint(const Expr& constraint);
......@@ -123,7 +122,7 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
Expr RecursiveRewrite(const Expr& x) {
if (recur_depth_ >= kMaxRecurDepth) return x;
++recur_depth_;
Expr res = Mutate(x);
Expr res = this->VisitExpr(x);
--recur_depth_;
return res;
}
......
......@@ -24,7 +24,6 @@
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/arithmetic.h>
#include <tvm/ir_mutator.h>
#include <tvm/expr_operator.h>
#include <tvm/arithmetic.h>
#include "ir_mutator_with_analyzer.h"
......@@ -40,44 +39,47 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
: IRMutatorWithAnalyzer(analyzer) {}
using Parent = IRMutatorWithAnalyzer;
using Parent::Mutate;
using Parent::Mutate_;
using Parent::VisitStmt;
using Parent::VisitStmt_;
Expr Mutate(Expr expr) final {
Expr VisitExpr(const Expr& expr) final {
return analyzer_->Simplify(expr);
}
Stmt Simplify(Stmt stmt) {
return Mutate(stmt);
return operator()(std::move(stmt));
}
Stmt Mutate_(const For* op, const Stmt& s) final {
Stmt VisitStmt_(const For* op) final {
analyzer_->Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent));
With<ConstraintContext> ctx1(analyzer_, op->loop_var >= op->min);
With<ConstraintContext> ctx2(analyzer_, op->loop_var < op->min + op->extent);
return IRMutator::Mutate_(op, s);
return Parent::VisitStmt_(op);
}
Stmt Mutate_(const LetStmt* op, const Stmt& s) {
Expr value = this->Mutate(op->value);
Stmt VisitStmt_(const LetStmt* op) {
Expr value = this->VisitExpr(op->value);
if (!ir::HasSideEffect(value)) {
// it is fine to discard the let binding
// because the call to simplify will always inline the var.
analyzer_->Bind(op->var, value);
return Mutate(op->body);
return this->VisitStmt(op->body);
}
Stmt body = this->Mutate(op->body);
Stmt body = this->VisitStmt(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return s;
return GetRef<Stmt>(op);
} else {
return LetStmt::make(op->var, value, body);
auto n = this->CopyOnWrite(op);
n->value = std::move(value);
n->body = std::move(body);
return Stmt(n);
}
}
// eliminate useless stores
Stmt Mutate_(const Store* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt VisitStmt_(const Store* op) final {
Stmt stmt = Parent::VisitStmt_(op);
op = stmt.as<Store>();
if (const Load* load = op->value.as<Load>()) {
if (load->buffer_var.same_as(op->buffer_var) &&
......@@ -85,7 +87,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
return Evaluate::make(0);
}
}
return stmt;
return GetRef<Stmt>(op);
}
};
......@@ -98,7 +100,7 @@ Stmt CanonicalSimplify(Stmt stmt, Map<Var, Range> vrange) {
for (auto kv : vrange) {
analyzer.Bind(kv.first, kv.second);
}
return arith::StmtSimplifier(&analyzer).Simplify(stmt);
return arith::StmtSimplifier(&analyzer).Simplify(std::move(stmt));
}
Expr CanonicalSimplify(Expr expr, Map<Var, Range> vrange) {
......@@ -119,7 +121,7 @@ Expr Simplify(Expr expr, Map<Var, Range> vrange) {
}
Stmt Simplify(Stmt stmt, Map<Var, Range> vrange) {
return CanonicalSimplify(stmt, vrange);
return CanonicalSimplify(std::move(stmt), vrange);
}
} // namespace ir
} // namespace tvm
......@@ -29,7 +29,7 @@ namespace tvm {
namespace autotvm {
// for loop
void FeatureVisitor::Visit_(const For *op) {
void FeatureVisitor::VisitStmt_(const For* op) {
const auto *extent = op->extent.as<IntImm>();
int64_t loop_extent = -1;
if (extent != nullptr)
......@@ -51,13 +51,13 @@ void FeatureVisitor::Visit_(const For *op) {
}
if (EnterItervar_(op->loop_var, loop_extent, ann)) {
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
ExitItervar_();
}
}
// parallel axis, virtual thread
void FeatureVisitor::Visit_(const AttrStmt *op) {
void FeatureVisitor::VisitStmt_(const AttrStmt* op) {
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::virtual_thread) {
VarExpr var = op->node.as<tvm::IterVarNode>()->var;
......@@ -86,24 +86,24 @@ void FeatureVisitor::Visit_(const AttrStmt *op) {
}
if (EnterItervar_(var, extent->value, ann)) {
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
ExitItervar_();
}
} else {
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
}
}
// memory access
void FeatureVisitor::Visit_(const Load *op) {
void FeatureVisitor::VisitExpr_(const Load* op) {
EnterMem_(op->buffer_var, op->index);
IRVisitor::Visit_(op);
StmtExprVisitor::VisitExpr_(op);
ExitMem_();
}
void FeatureVisitor::Visit_(const Store *op) {
void FeatureVisitor::VisitStmt_(const Store* op) {
EnterMem_(op->buffer_var, op->index);
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
ExitMem_();
}
......
......@@ -27,7 +27,7 @@
#define TVM_AUTOTVM_FEATURE_VISITOR_H_
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
#include <string>
namespace tvm {
......@@ -48,15 +48,18 @@ enum AnnotationType {
* \brief A base class for feature extractor, used for processing
* for loop and memory access in the IR
*/
class FeatureVisitor : public IRVisitor {
class FeatureVisitor : public StmtExprVisitor {
public:
// for loop
void Visit_(const For *op);
void Visit_(const AttrStmt *op);
void VisitStmt_(const For *op);
void VisitStmt_(const AttrStmt *op);
// memory access
void Visit_(const Load *op);
void Visit_(const Store *op);
void VisitExpr_(const Load *op);
void VisitStmt_(const Store *op);
using StmtExprVisitor::VisitStmt_;
using StmtExprVisitor::VisitExpr_;
protected:
/*!
......
......@@ -44,14 +44,14 @@ int ParallelLevel(AnnotationType ann) {
}
// get touch pattern from index expression
class IndexParser: public IRVisitor {
class IndexParser: public ExprVisitor {
public:
void Parse(Expr expr) {
pattern_map.clear();
this->Visit(expr);
this->VisitExpr(expr);
}
void Visit_(const Variable *op) {
void VisitExpr_(const Variable *op) {
// TODO(lmzheng): handle more index types (multiple occurrence)
if (pattern_map.count(op) == 0) {
pattern_map[op] = TouchPattern();
......@@ -60,13 +60,13 @@ class IndexParser: public IRVisitor {
}
}
void Visit_(const Mul *op) {
void VisitExpr_(const Mul *op) {
if (op->a.as<Variable>()) {
if (const auto stride = op->b.as<IntImm>()) {
next_stride_ = stride->value;
}
}
IRVisitor::Visit_(op);
ExprVisitor::VisitExpr_(op);
}
std::unordered_map<const Variable*, TouchPattern> pattern_map;
......
......@@ -26,7 +26,7 @@
#define TVM_AUTOTVM_TOUCH_EXTRACTOR_H_
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/api_registry.h>
#include <stack>
#include <vector>
......@@ -85,39 +85,39 @@ struct ItervarFeature {
// extract iter vars and their touch pattern from ir
class TouchExtractor : public FeatureVisitor {
public:
void Analyze(Stmt stmt) {
this->Visit(stmt);
void Analyze(const Stmt& stmt) {
operator()(stmt);
}
// arithmetic stats
void Visit_(const Add *op) {
void VisitExpr_(const Add *op) {
if (op->dtype.is_float())
itervar_map[itervar_stack_.back()].add_ct++;
IRVisitor::Visit_(op);
FeatureVisitor::VisitExpr_(op);
}
void Visit_(const Sub *op) {
void VisitExpr_(const Sub *op) {
if (op->dtype.is_float())
itervar_map[itervar_stack_.back()].add_ct++;
IRVisitor::Visit_(op);
FeatureVisitor::VisitExpr_(op);
}
void Visit_(const Mul *op) {
void VisitExpr_(const Mul *op) {
if (op->dtype.is_float())
itervar_map[itervar_stack_.back()].mul_ct++;
IRVisitor::Visit_(op);
FeatureVisitor::VisitExpr_(op);
}
void Visit_(const Div *op) {
void VisitExpr_(const Div *op) {
if (op->dtype.is_float())
itervar_map[itervar_stack_.back()].div_ct++;
IRVisitor::Visit_(op);
FeatureVisitor::VisitExpr_(op);
}
void Visit_(const Mod *op) {
void VisitExpr_(const Mod *op) {
if (op->dtype.is_float())
itervar_map[itervar_stack_.back()].div_ct++;
IRVisitor::Visit_(op);
FeatureVisitor::VisitExpr_(op);
}
std::unordered_map<VarExpr, ItervarFeature, tvm::ExprHash, tvm::ExprEqual> itervar_map;
......@@ -134,7 +134,7 @@ class TouchExtractor : public FeatureVisitor {
std::deque<VarExpr> itervar_stack_; // use deque instead of stack for indexing
std::deque<size_t> skip_stack_size_;
using IRVisitor::Visit_;
using FeatureVisitor::VisitExpr_;
};
} // namespace autotvm
......
......@@ -24,8 +24,8 @@
#include <tvm/operation.h>
#include <tvm/arithmetic.h>
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_functor_ext.h>
#include <unordered_set>
#include <string>
#include <utility>
......@@ -538,7 +538,7 @@ namespace {
* must be Reduce as well; and their inputs should have the
* same attribute except value_index.
*/
class ComputeVerifier final : protected ir::IRVisitor {
class ComputeVerifier final : protected ir::ExprVisitor {
public:
/// Special member functions
//@{
......@@ -567,20 +567,20 @@ class ComputeVerifier final : protected ir::IRVisitor {
}
level_ = 0;
ir::IRVisitor::Visit(e);
ExprVisitor::VisitExpr(e);
}
}
protected:
/// Visitor implementation
//@{
void Visit(const ObjectRef& n) final {
void VisitExpr(const Expr& n) final {
++level_;
ir::IRVisitor::Visit(n);
ExprVisitor::VisitExpr(n);
--level_;
}
void Visit_(const ir::Reduce* op) final {
void VisitExpr_(const ir::Reduce* op) final {
// Check for non top level reductions
CHECK(0 == level_)
<< "Reductions are only allowed at the top level of compute. "
......
......@@ -24,7 +24,7 @@
#include <tvm/operation.h>
#include <tvm/arithmetic.h>
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_pass.h>
#include <tvm/expr_operator.h>
#include <unordered_set>
......@@ -221,7 +221,7 @@ namespace op {
Stmt ApplyLoopShapes(const Stage &stage,
const std::unordered_map<IterVar, Range> &dom_map, Stmt stmt) {
class LoopSpliter : public IRMutator {
class LoopSpliter : public StmtExprMutator {
Expr factor;
const Variable *parent;
IterVar inner, outer;
......@@ -247,7 +247,7 @@ Stmt ApplyLoopShapes(const Stage &stage,
outer = IterVarNode::make(outer_dom, outer_->var, outer_->iter_type);
}
Stmt Mutate_(const For *op, const Stmt &stmt) {
Stmt VisitStmt_(const For *op) final {
if (op->loop_var.get() == parent) {
std::unordered_map<const Variable *, Expr> rmap;
rmap[op->loop_var.get()] = inner + outer * factor;
......@@ -261,11 +261,11 @@ Stmt ApplyLoopShapes(const Stage &stage,
splitted = true;
return ret;
}
return IRMutator::Mutate_(op, stmt);
return StmtExprMutator::VisitStmt_(op);
}
};
class LoopFuser : public IRMutator {
class LoopFuser : public StmtExprMutator {
const IterVar &parent;
const Variable *inner;
const Variable *outer;
......@@ -280,8 +280,7 @@ Stmt ApplyLoopShapes(const Stage &stage,
extent(0), fused(false) {}
// TODO(@were): Handle imperfect loops
Stmt Mutate_(const For *op, const Stmt &stmt) {
Stmt VisitStmt_(const For* op) final {
if (op->loop_var.get() == inner) {
CHECK(under_outer);
std::unordered_map<const Variable *, Expr> rmap;
......@@ -291,7 +290,7 @@ Stmt ApplyLoopShapes(const Stage &stage,
return ir::Substitute(op->body, rmap);
} else if (op->loop_var.get() == outer) {
under_outer = true;
Stmt body = IRMutator::Mutate(op->body);
Stmt body = this->VisitStmt(op->body);
std::unordered_map<const Variable *, Expr> rmap;
rmap[op->loop_var.get()] = indexdiv(parent, extent);
body = ir::Substitute(body, rmap);
......@@ -299,25 +298,25 @@ Stmt ApplyLoopShapes(const Stage &stage,
return For::make(parent->var, Expr(0), extent * op->extent,
op->for_type, op->device_api, body);
} else if (under_outer) {
Stmt body = IRMutator::Mutate(op->body);
Stmt body = this->VisitStmt(op->body);
std::unordered_map<const Variable *, Expr> rmap;
rmap[op->loop_var.get()] = indexmod(indexdiv(parent, extent), op->extent);
body = ir::Substitute(body, rmap);
extent = extent * op->extent;
return body;
}
return IRMutator::Mutate(stmt);
return StmtExprMutator::VisitStmt_(op);
}
};
for (auto &rel : stage->relations) {
if (const SplitNode *split = rel.as<SplitNode>()) {
LoopSpliter Spliter(split, dom_map);
stmt = Spliter.Mutate(stmt);
stmt = Spliter(stmt);
CHECK(Spliter.splitted);
} else if (const FuseNode *fuse = rel.as<FuseNode>()) {
LoopFuser Fuser(fuse);
stmt = Fuser.Mutate(stmt);
stmt = Fuser(stmt);
CHECK(Fuser.fused);
}
}
......@@ -327,14 +326,14 @@ Stmt ApplyLoopShapes(const Stage &stage,
Stmt ApplyLoopAnnotations(const Stage &stage,
const std::unordered_map<IterVar, IterVar> &rebased, Stmt stmt) {
class LoopAnnotator : public IRMutator {
class LoopAnnotator : public StmtMutator {
const Variable *var;
const IterVarAttr &attr;
public:
LoopAnnotator(const Variable *var_, const IterVarAttr &attr_) : var(var_), attr(attr_) {}
Stmt Mutate_(const For *op, const Stmt &stmt) {
Stmt VisitStmt_(const For *op) final {
if (op->loop_var.get() == var) {
if (attr->bind_thread.defined()) {
const auto &iter_var = attr->bind_thread;
......@@ -352,7 +351,7 @@ Stmt ApplyLoopAnnotations(const Stage &stage,
IterVarTypeToForType(attr->iter_type), op->device_api, op->body);
}
}
return IRMutator::Mutate_(op, stmt);
return StmtMutator::VisitStmt_(op);
}
};
......@@ -381,7 +380,7 @@ Stmt ApplyLoopAnnotations(const Stage &stage,
CHECK_EQ(found, 1) << " iter var should be found exactly once!";
if (need_change) {
stmt = LoopAnnotator(var, attr).Mutate(stmt);
stmt = LoopAnnotator(var, attr)(std::move(stmt));
}
}
return stmt;
......@@ -411,7 +410,7 @@ Stmt ApplyLoopOrder(const Stage &stage,
}
}
class LoopReorder : public IRMutator {
class LoopReorder : public StmtMutator {
const Stage &stage;
const std::unordered_map<IterVar, Range> &dom_map;
const std::unordered_map<const Variable *, IterVar> &reorder;
......@@ -422,13 +421,13 @@ Stmt ApplyLoopOrder(const Stage &stage,
const std::unordered_map<const Variable*, IterVar> &reorder)
: stage(stage), dom_map(dom_map), reorder(reorder) {}
Stmt Mutate_(const For *op, const Stmt &stmt) {
Stmt VisitStmt_(const For* op) final {
// Reorder from in to out
Stmt body_ = IRMutator::Mutate(op->body);
Stmt body_ = this->VisitStmt(op->body);
CHECK(reorder.count(op->loop_var.get()));
auto target = reorder.find(op->loop_var.get())->second;
if (body_.same_as(op->body) && op->loop_var.get() == target->var.get())
return stmt;
return GetRef<Stmt>(op);
const Stmt &body = op->body.same_as(body_) ? op->body : body_;
ForType for_type = IterVarTypeToForType(target->iter_type);
if (stage->iter_var_attrs.count(target)) {
......@@ -441,7 +440,7 @@ Stmt ApplyLoopOrder(const Stage &stage,
};
if (need_reorder)
return LoopReorder(stage, dom_map, reorder).Mutate(stmt);
return LoopReorder(stage, dom_map, reorder)(stmt);
return stmt;
}
......@@ -479,21 +478,21 @@ std::vector<IterVar> GatherLoopVars(Stmt stmt) {
}
// replacer to replace tensors' usage in Provide
class ProviderReplacer : public ir::IRMutator {
class ProviderReplacer : public ir::StmtMutator {
public:
explicit ProviderReplacer(const std::unordered_map<Tensor, Tensor> &vmap)
: vmap_(vmap) {}
Stmt Mutate_(const ir::Provide* op, const Stmt &s) {
Stmt VisitStmt_(const ir::Provide* op) final {
Tensor t = Downcast<Operation>(op->func).output(op->value_index);
auto it = vmap_.find(t);
if (it != vmap_.end()) {
Stmt ret = ir::Provide::make(
it->second->op, it->second->value_index, op->value, op->args);
found = true;
return IRMutator::Mutate_(ret.as<ir::Provide>(), ret);
return this->VisitStmt(ret);
}
return IRMutator::Mutate_(op, s);
return StmtMutator::VisitStmt_(op);
}
// whether it is found.
......@@ -506,7 +505,7 @@ class ProviderReplacer : public ir::IRMutator {
Stmt ReplaceProvideTensor(Stmt stmt,
const std::unordered_map<Tensor, Tensor> &replace) {
ProviderReplacer repl(replace);
Stmt ret = repl.Mutate(stmt);
Stmt ret = repl(stmt);
return repl.found ? ret : stmt;
}
} // namespace op
......
......@@ -25,8 +25,6 @@
#define TVM_OP_HYBRID_OP_H_
#include <tvm/expr.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_visitor.h>
#include <tvm/schedule.h>
#include <unordered_map>
#include <unordered_set>
......
......@@ -23,8 +23,8 @@
*/
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/operation.h>
#include <tvm/ir_mutator.h>
#include <string>
#include "op_util.h"
#include "../schedule/message_passing.h"
......@@ -186,12 +186,12 @@ std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates) {
}
// replacer to replace tensors
class TensorReplacer : public ir::IRMutator {
class TensorReplacer : public ir::StmtExprMutator {
public:
explicit TensorReplacer(const std::unordered_map<Tensor, Tensor>& vmap)
: vmap_(vmap) {}
Expr Mutate_(const ir::Call* op, const Expr& e) {
Expr VisitExpr_(const ir::Call* op) final {
if (op->call_type == ir::Call::Halide) {
Tensor t = Downcast<Operation>(op->func).output(op->value_index);
auto it = vmap_.find(t);
......@@ -200,10 +200,10 @@ class TensorReplacer : public ir::IRMutator {
op->dtype, it->second->op->name, op->args,
op->call_type, it->second->op, it->second->value_index);
found = true;
return IRMutator::Mutate_(ret.as<ir::Call>(), ret);
return this->VisitExpr(ret);
}
}
return IRMutator::Mutate_(op, e);
return StmtExprMutator::VisitExpr_(op);
}
// whether it is found.
......@@ -216,13 +216,13 @@ class TensorReplacer : public ir::IRMutator {
Stmt ReplaceTensor(Stmt stmt,
const std::unordered_map<Tensor, Tensor>& replace) {
TensorReplacer repl(replace);
Stmt ret = repl.Mutate(stmt);
Stmt ret = repl(stmt);
return repl.found ? ret : stmt;
}
Expr ReplaceTensor(Expr expr,
const std::unordered_map<Tensor, Tensor>& replace) {
TensorReplacer repl(replace);
Expr ret = repl.Mutate(expr);
Expr ret = repl(expr);
return repl.found ? ret : expr;
}
......
......@@ -24,7 +24,6 @@
#include <tvm/operation.h>
#include <tvm/arithmetic.h>
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_pass.h>
#include <unordered_set>
#include "./op_util.h"
......
......@@ -22,7 +22,7 @@
* \file tensorize.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_pass.h>
#include <tvm/api_registry.h>
#include "op_util.h"
......@@ -157,10 +157,10 @@ void VerifyTensorizeLoopNest(const ComputeOpNode* self,
}
// Remap the tensor placeholder, index and inline things.
class TensorIntrinMatcher final : public IRMutator {
class TensorIntrinMatcher final : public StmtExprMutator {
public:
Expr Mutate_(const Call* op, const Expr& e) final {
Expr expr = IRMutator::Mutate_(op, e);
Expr VisitExpr_(const Call* op) final {
Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Call>();
if (op->call_type == Call::Halide) {
Tensor t = Downcast<Operation>(op->func).output(op->value_index);
......@@ -180,17 +180,17 @@ class TensorIntrinMatcher final : public IRMutator {
return expr;
}
Expr Mutate_(const Variable* op, const Expr& e) final {
Expr VisitExpr_(const Variable* op) final {
auto it = var_remap_.find(op);
if (it != var_remap_.end()) {
return it->second;
} else {
return e;
return GetRef<Expr>(op);
}
}
Expr Mutate_(const Reduce* op, const Expr& e) final {
Expr expr = IRMutator::Mutate_(op, e);
Expr VisitExpr_(const Reduce* op) final {
Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Reduce>();
Array<IterVar> axis;
for (size_t i = 0; i < op->axis.size(); ++i) {
......@@ -317,7 +317,7 @@ Array<Expr> MatchTensorizeBody(
matcher.Init(self, stage, dom_map, out_dom, in_region, intrin, compute_intrin_iter_space);
Array<Expr> ret;
for (Expr expr : self->body) {
ret.push_back(matcher.Mutate(expr));
ret.push_back(matcher(expr));
}
return ret;
}
......
......@@ -23,9 +23,8 @@
// Instrument checkers for out of the bounds access.
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
#include <vector>
#include <unordered_map>
#include <utility>
......@@ -33,48 +32,48 @@
namespace tvm {
namespace ir {
class BoundCollector : public IRVisitor {
class BoundCollector : public StmtVisitor {
public:
BoundCollector() {}
void Visit_(const AttrStmt *op) {
void VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == ir::attr::buffer_bound) {
if (const Variable *key = op->node.as<Variable>()) {
mem_to_shape[key] = op->value;
}
}
IRVisitor::Visit_(op);
StmtVisitor::VisitStmt_(op);
}
// Hashtable which maps buffer_var to shape.
std::unordered_map<const Variable *, Expr> mem_to_shape;
};
class BoundChecker : public IRMutator {
class BoundChecker : public StmtExprMutator {
public:
explicit BoundChecker(
const std::unordered_map<const Variable *, Expr> &mem_to_shape)
: mem_to_shape_(mem_to_shape) {}
Stmt Mutate_(const Allocate *op, const Stmt &s) final {
Stmt VisitStmt_(const Allocate* op) final {
// If the shape was updated we should update the hashtable.
if (UpdateIsNeeded(op->buffer_var)) {
Update(op->buffer_var, op->extents, op->dtype);
}
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
Expr Mutate_(const Call *op, const Expr &ex) final {
Expr VisitExpr_(const Call* op) final {
if (process_store_ && op->is_intrinsic(intrinsic::tvm_if_then_else)) {
unsafe_rewritten_ = true;
}
return IRMutator::Mutate_(op, ex);
return StmtExprMutator::VisitExpr_(op);
}
Stmt Mutate_(const Store *op, const Stmt &s) final {
Stmt VisitStmt_(const Store* op) final {
store_scope_bound_collector_.clear();
process_store_ = true;
unsafe_rewritten_ = false;
IRMutator::Mutate_(op, s);
StmtExprMutator::VisitStmt_(op);
process_store_ = false;
if (CanInstrument(op->index, op->buffer_var)) {
Collect(op->index, op->buffer_var);
......@@ -92,23 +91,24 @@ class BoundChecker : public IRMutator {
return body;
}
}
return s;
return GetRef<Stmt>(op);
}
Expr Mutate_(const Load *op, const Expr &ex) final {
Expr VisitExpr_(const Load* op) final {
if (CanInstrument(op->index, op->buffer_var)) {
Collect(op->index, op->buffer_var);
}
return IRMutator::Mutate_(op, ex);
return StmtExprMutator::VisitExpr_(op);
}
private:
bool UpdateIsNeeded(const VarExpr &buffer_var) const {
bool UpdateIsNeeded(const VarExpr& buffer_var) const {
return (buffer_var.defined() && mem_to_shape_.count(buffer_var.get()));
}
void Update(const VarExpr &buffer_var, const Array<Expr> &new_shape,
const DataType &type) {
void Update(const VarExpr& buffer_var,
const Array<Expr>& new_shape,
const DataType& type) {
// Sanity check at first.
if (!new_shape.size()) {
return;
......@@ -132,7 +132,7 @@ class BoundChecker : public IRMutator {
mem_to_shape_[buffer_var.get()] = shape;
}
bool IndexIsValid(const Expr &index) const {
bool IndexIsValid(const Expr& index) const {
if (!index.defined()) {
return false;
}
......@@ -146,7 +146,7 @@ class BoundChecker : public IRMutator {
return true;
}
bool CanInstrument(const Expr &index, const VarExpr &buffer_var) const {
bool CanInstrument(const Expr& index, const VarExpr& buffer_var) const {
return buffer_var.defined() && mem_to_shape_.count(buffer_var.get()) &&
IndexIsValid(index) && !unsafe_rewritten_;
}
......@@ -206,8 +206,8 @@ class BoundChecker : public IRMutator {
Stmt InstrumentBoundCheckers(Stmt stmt) {
BoundCollector bound_collector;
// At first walk recursively and collect bound attributes.
bound_collector.Visit(stmt);
return BoundChecker(bound_collector.mem_to_shape).Mutate(stmt);
bound_collector(stmt);
return BoundChecker(bound_collector.mem_to_shape)(std::move(stmt));
}
} // namespace ir
} // namespace tvm
......@@ -23,7 +23,7 @@
* \file combine_context_call.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_pass.h>
#include <map>
......@@ -32,7 +32,7 @@ namespace ir {
// Calculate the statistics of packed function.
// These information are needed during codegen.
class ContextCallCombiner final : public IRMutator {
class ContextCallCombiner final : public StmtExprMutator {
public:
struct CompareExpr {
bool operator()(const Expr& lhs, const Expr& rhs) const {
......@@ -40,7 +40,7 @@ class ContextCallCombiner final : public IRMutator {
}
};
Expr Mutate_(const Call* op, const Expr& e) final {
Expr VisitExpr_(const Call* op) final {
if (op->is_intrinsic(intrinsic::tvm_thread_context)) {
CHECK_EQ(op->args.size(), 1U);
Expr ctx = op->args[0];
......@@ -60,39 +60,39 @@ class ContextCallCombiner final : public IRMutator {
return std::move(ctx_var);
}
} else {
return IRMutator::Mutate_(op, e);
return StmtExprMutator::VisitExpr_(op);
}
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
Stmt VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::coproc_uop_scope) {
// Map of comparison expression to variable
std::map<Expr, Var, CompareExpr> temp;
std::swap(temp, ctx_map_);
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt stmt = StmtExprMutator::VisitStmt_(op);
std::swap(temp, ctx_map_);
return BuildContext(temp, stmt);
} else {
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
}
Stmt Mutate_(const For* op, const Stmt& s) final {
Stmt VisitStmt_(const For* op) final {
if (op->for_type == ForType::Parallel) {
// Map of comparison expression to variable
std::map<Expr, Var, CompareExpr> temp;
std::swap(temp, ctx_map_);
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt stmt = StmtExprMutator::VisitStmt_(op);
std::swap(temp, ctx_map_);
return BuildContext(temp, stmt);
} else {
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
}
Stmt Combine(Stmt stmt) {
return BuildContext(ctx_map_, this->Mutate(stmt));
return BuildContext(ctx_map_, this->VisitStmt(stmt));
}
private:
......
......@@ -22,8 +22,7 @@
*/
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
#include <unordered_map>
#include <unordered_set>
#include "ir_util.h"
......@@ -33,25 +32,25 @@ namespace tvm {
namespace ir {
// Visitor to find touched set by co-processor scope.
class CoProcTouchedBuffer : public IRVisitor {
class CoProcTouchedBuffer : public StmtExprVisitor {
public:
void Visit_(const Load* op) final {
void VisitExpr_(const Load* op) final {
if (in_scope_) {
touched_[op->buffer_var.get()].coproc = true;
} else {
touched_[op->buffer_var.get()].normal = true;
}
IRVisitor::Visit_(op);
StmtExprVisitor::VisitExpr_(op);
}
void Visit_(const Store* op) final {
void VisitStmt_(const Store* op) final {
if (in_scope_) {
touched_[op->buffer_var.get()].coproc = true;
} else {
touched_[op->buffer_var.get()].normal = true;
}
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
}
void Visit_(const Call* op) final {
void VisitExpr_(const Call* op) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
const Variable* buffer = op->args[1].as<Variable>();
if (in_scope_) {
......@@ -60,17 +59,17 @@ class CoProcTouchedBuffer : public IRVisitor {
touched_[buffer].normal = true;
}
}
IRVisitor::Visit_(op);
StmtExprVisitor::VisitExpr_(op);
}
void Visit_(const AttrStmt* op) final {
void VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::coproc_scope && !in_scope_) {
in_scope_ = true;
IterVar iv = Downcast<IterVar>(op->node);
coproc_.insert(iv);
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
in_scope_ = false;
} else {
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
}
}
......@@ -96,7 +95,7 @@ class CoProcSyncPlanner : public StorageAccessVisitor {
}
void Plan(const Stmt& stmt) {
this->Visit(stmt);
this->VisitStmt(stmt);
PlanSync(scope_.back(), nullptr, true);
if (sync_.size() == 0) {
sync_[stmt.get()] = GetSync(coproc_name_ + ".coproc_sync");
......@@ -218,14 +217,14 @@ class CoProcBarrierDetector : public StorageAccessVisitor {
write_barrier_name_ = coproc_name + ".coproc_write_barrier";
}
void PlanReadBarrier(Stmt stmt) {
void PlanReadBarrier(const Stmt& stmt) {
read_barrier_ = true;
this->Visit(stmt);
this->VisitStmt(stmt);
PlanReadBarrier(scope_.back(), nullptr);
}
void PlanWriteBarrier(Stmt stmt) {
void PlanWriteBarrier(const Stmt& stmt) {
read_barrier_ = false;
this->Visit(stmt);
this->VisitStmt(stmt);
PlanWriteBarrier(scope_.back(), nullptr);
}
......@@ -356,7 +355,7 @@ class CoProcBarrierDetector : public StorageAccessVisitor {
};
class CoProcInstDepDetector : public IRVisitor {
class CoProcInstDepDetector : public StmtVisitor {
public:
explicit CoProcInstDepDetector(
const IterVar& coproc_axis,
......@@ -366,15 +365,15 @@ class CoProcInstDepDetector : public IRVisitor {
sync_pop_name_ = coproc_name + ".coproc_dep_pop";
}
void Plan(Stmt stmt) {
this->Visit(stmt);
void Plan(const Stmt& stmt) {
this->VisitStmt(stmt);
if (last_state_.node != nullptr) {
MatchFixEnterPop(first_state_);
MatchFixExitPush(last_state_);
}
}
void Visit_(const AttrStmt* op) final {
void VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::coproc_scope &&
op->node.same_as(coproc_axis_)) {
const IntImm* ctx_id = op->value.as<IntImm>();
......@@ -385,15 +384,15 @@ class CoProcInstDepDetector : public IRVisitor {
curr_state_.exit_ctx.insert(ctx_id->value);
UpdateState();
} else {
IRVisitor::Visit_(op);
StmtVisitor::VisitStmt_(op);
}
}
void Visit_(const For* op) final {
void VisitStmt_(const For* op) final {
SyncState temp_first, temp_last;
std::swap(first_state_, temp_first);
std::swap(last_state_, temp_last);
this->Visit(op->body);
this->VisitStmt(op->body);
curr_state_.clear();
if (last_state_.node != nullptr) {
curr_state_.node = op;
......@@ -412,13 +411,13 @@ class CoProcInstDepDetector : public IRVisitor {
}
}
void Visit_(const IfThenElse* op) final {
void VisitStmt_(const IfThenElse* op) final {
SyncState temp_first, temp_last, curr_state;
std::swap(first_state_, temp_first);
std::swap(last_state_, temp_last);
{
// then stmt
this->Visit(op->then_case);
this->VisitStmt(op->then_case);
if (last_state_.node != nullptr) {
curr_state.node = op;
MatchFixEnterPop(first_state_);
......@@ -434,7 +433,7 @@ class CoProcInstDepDetector : public IRVisitor {
last_state_.clear();
}
if (op->else_case.defined()) {
this->Visit(op->else_case);
this->VisitStmt(op->else_case);
if (last_state_.node != nullptr) {
curr_state.node = op;
MatchFixEnterPop(first_state_);
......@@ -606,11 +605,11 @@ class CoProcInstDepDetector : public IRVisitor {
};
class CoProcSyncInserter : public IRMutator {
class CoProcSyncInserter : public StmtMutator {
public:
Stmt Insert(Stmt stmt) {
CoProcTouchedBuffer visitor;
visitor.Visit(stmt);
visitor(stmt);
if (visitor.coproc_.size() == 0) return stmt;
std::unordered_set<const Variable*> touched;
......@@ -652,10 +651,10 @@ class CoProcSyncInserter : public IRMutator {
auto& vec = insert_after_[kv.first];
vec.insert(vec.end(), kv.second.begin(), kv.second.end());
}
return Mutate(stmt);
return operator()(std::move(stmt));
}
Stmt Mutate(Stmt stmt) final {
Stmt VisitStmt(const Stmt& stmt) final {
Stmt before, after;
auto it = insert_before_.find(stmt.get());
if (it != insert_before_.end()) {
......@@ -666,14 +665,14 @@ class CoProcSyncInserter : public IRMutator {
if (it != insert_after_.end()) {
after = MergeSeq(it->second);
}
stmt = IRMutator::Mutate(stmt);
Stmt new_stmt = StmtMutator::VisitStmt(stmt);
if (before.defined()) {
stmt = Block::make(before, stmt);
new_stmt = Block::make(before, new_stmt);
}
if (after.defined()) {
stmt = Block::make(stmt, after);
new_stmt = Block::make(new_stmt, after);
}
return stmt;
return new_stmt;
}
private:
......@@ -685,7 +684,7 @@ class CoProcSyncInserter : public IRMutator {
Stmt CoProcSync(Stmt stmt) {
return CoProcSyncInserter().Insert(stmt);
return CoProcSyncInserter().Insert(std::move(stmt));
}
} // namespace ir
......
......@@ -22,7 +22,6 @@
*/
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include "../pass/ir_util.h"
namespace tvm {
......
......@@ -21,9 +21,7 @@
* \file hoist_if_then_else.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/arithmetic.h>
#include <tvm/api_registry.h>
#include <unordered_map>
......
......@@ -23,8 +23,7 @@
*/
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
#include <unordered_map>
#include <unordered_set>
#include "ir_util.h"
......@@ -35,7 +34,7 @@ namespace tvm {
namespace ir {
// Get fragment information from tensor intrinsics
class FragmentGetter : public IRVisitor {
class FragmentGetter : public StmtExprVisitor {
public:
// fragment metadata
struct FragmentInfo {
......@@ -48,8 +47,8 @@ class FragmentGetter : public IRVisitor {
: m(_m), n(_n), k(_k), layout(_layout) {}
};
void Visit_(const Call* op) final {
IRVisitor::Visit_(op);
void VisitExpr_(const Call* op) final {
StmtExprVisitor::VisitExpr_(op);
if (op->is_intrinsic(intrinsic::tvm_load_matrix_sync) ||
op->is_intrinsic(intrinsic::tvm_store_matrix_sync)) {
......@@ -116,13 +115,13 @@ class FragmentGetter : public IRVisitor {
}
// Get memory scope
void Visit_(const AttrStmt* op) final {
void VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::storage_scope) {
const Variable* buffer = op->node.as<Variable>();
CHECK(buffer);
scopes[buffer] = op->value.as<StringImm>()->value;
}
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
}
// Memory scope for allocations
......@@ -132,11 +131,12 @@ class FragmentGetter : public IRVisitor {
};
// Check shape of fragment making sure it is a valid shape for tvm_mma_sync
class FragmentChecker : public IRVisitor {
class FragmentChecker : public StmtExprVisitor {
public:
explicit FragmentChecker(const FragmentGetter &getter) : fragment_getter(getter) {}
void Visit_(const Call* op) final {
void VisitExpr_(const Call* op) final {
StmtExprVisitor::VisitExpr_(op);
// Check shape when calling tvm_mma_sync
if (op->is_intrinsic(intrinsic::tvm_mma_sync)) {
CHECK_EQ(op->args.size(), 8U);
......@@ -170,12 +170,12 @@ class FragmentChecker : public IRVisitor {
};
// Store the metadata into attributes
class InferFragmenter : public IRMutator {
class InferFragmenter : public StmtMutator {
public:
explicit InferFragmenter(const FragmentGetter &getter) : fragment_getter(getter) {}
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt VisitStmt_(const Allocate* op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
const Variable* buffer = op->buffer_var.get();
if (fragment_getter.fragments.count(buffer)) {
// Add attribute to fragments allocation
......@@ -206,9 +206,10 @@ class InferFragmenter : public IRMutator {
Stmt InferFragment(Stmt stmt) {
FragmentGetter getter;
getter.Visit(stmt);
FragmentChecker(getter).Visit(stmt);
stmt = InferFragmenter(getter).Mutate(stmt);
getter(stmt);
FragmentChecker checker(getter);
checker(stmt);
stmt = InferFragmenter(getter)(std::move(stmt));
return stmt;
}
......
......@@ -23,7 +23,7 @@
*/
#include <tvm/ir.h>
#include <tvm/packed_func_ext.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_pass.h>
#include "../arithmetic/pattern_match.h"
......@@ -32,7 +32,7 @@ namespace ir {
using runtime::PackedFunc;
class CopyIntrinInjector : public IRMutator {
class CopyIntrinInjector : public StmtMutator {
public:
CopyIntrinInjector(const std::string& pragma_key,
const PackedFunc& flower_copy_fromto)
......@@ -40,7 +40,7 @@ class CopyIntrinInjector : public IRMutator {
flower_copy_fromto_(flower_copy_fromto) {
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
Stmt VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::storage_scope) {
const Variable* buf = op->node.as<Variable>();
storage_scope_[buf] = op->value.as<StringImm>()->value;
......@@ -50,7 +50,7 @@ class CopyIntrinInjector : public IRMutator {
<< "Cannot match copy pattern of " << op->body;
return ret;
}
return IRMutator::Mutate_(op, s);
return StmtMutator::VisitStmt_(op);
}
private:
......@@ -193,8 +193,7 @@ class CopyIntrinInjector : public IRMutator {
Stmt InjectCopyIntrin(Stmt stmt,
const std::string& pragma_key,
const PackedFunc& flower_copy_fromto) {
return CopyIntrinInjector(pragma_key, flower_copy_fromto)
.Mutate(stmt);
return CopyIntrinInjector(pragma_key, flower_copy_fromto)(std::move(stmt));
}
} // namespace ir
......
......@@ -22,8 +22,7 @@
* \file inject_double_buffer.cc
*/
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/expr_operator.h>
#include "ir_util.h"
#include "../arithmetic/compute_expr.h"
......@@ -32,18 +31,18 @@ namespace tvm {
namespace ir {
// Detect double buffer variables.
class DoubleBufferDetector : public IRVisitor {
class DoubleBufferDetector : public StmtExprVisitor {
public:
void Visit_(const AttrStmt* op) final {
void VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::double_buffer_scope) {
touched_.insert(op->node.as<Variable>());
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
} else {
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
}
}
void Visit_(const Variable* op) final {
void VisitExpr_(const Variable* op) final {
if (touched_.count(op)) {
touched_.erase(op);
}
......@@ -53,55 +52,55 @@ class DoubleBufferDetector : public IRVisitor {
};
class StripDoubleBufferWrite : public IRMutator {
class StripDoubleBufferWrite : public StmtMutator {
public:
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
Stmt VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::double_buffer_write) {
return Mutate(op->body);
return VisitStmt(op->body);
} else {
return IRMutator::Mutate_(op, s);
return StmtMutator::VisitStmt_(op);
}
}
};
class DoubleBufferInjector : public IRMutator {
class DoubleBufferInjector : public StmtExprMutator {
public:
explicit DoubleBufferInjector(int split_loop)
: split_loop_(split_loop) {}
Stmt Inject(const Stmt& stmt) {
Stmt Inject(Stmt stmt) {
DoubleBufferDetector detector;
detector.Visit(stmt);
detector(stmt);
if (detector.touched_.empty()) return stmt;
for (const Variable* v : detector.touched_) {
dbuffer_info_[v] = StorageEntry();
}
return ConvertSSA(this->Mutate(stmt));
return ConvertSSA(operator()(std::move(stmt)));
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
Stmt VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::storage_scope) {
const Variable* buf = op->node.as<Variable>();
auto it = dbuffer_info_.find(buf);
if (it != dbuffer_info_.end()) {
it->second.scope = op->value.as<StringImm>()->value;
return Mutate(op->body);
return this->VisitStmt(op->body);
} else {
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
} else if (op->attr_key == attr::double_buffer_scope) {
return MakeProducer(op, s);
return MakeProducer(op);
} else {
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
}
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
Stmt VisitStmt_(const Allocate* op) final {
auto it = dbuffer_info_.find(op->buffer_var.get());
if (it != dbuffer_info_.end()) {
it->second.stride = arith::ComputeReduce<Mul>(
op->extents, Expr()) * op->dtype.lanes();
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<Allocate>();
Array<Expr> new_extents{make_const(op->extents[0].dtype(), 2)};
for (Expr e : op->extents) {
......@@ -118,13 +117,13 @@ class DoubleBufferInjector : public IRMutator {
Evaluate::make(0)));
return op->body;
} else {
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
}
Stmt Mutate_(const For* op, const Stmt& s) final {
Stmt VisitStmt_(const For* op) final {
loop_nest_.push_back(op);
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt stmt = StmtExprMutator::VisitStmt_(op);
auto it = loop_pre_.find(op);
if (it != loop_pre_.end()) {
const For* old_loop = stmt.as<For>();
......@@ -151,7 +150,7 @@ class DoubleBufferInjector : public IRMutator {
MergeSeq(loop_seq));
// tail
std::vector<Stmt> tail_seq;
Stmt tail_body = StripDoubleBufferWrite().Mutate(old_loop->body);
Stmt tail_body = StripDoubleBufferWrite()(old_loop->body);
for (int32_t i = 0; i < split_loop_; ++i) {
Expr idx = tail_base + make_const(tail_base.dtype(), i);
vmap[old_loop->loop_var.get()] = idx;
......@@ -171,8 +170,8 @@ class DoubleBufferInjector : public IRMutator {
return stmt;
}
Stmt Mutate_(const Store* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt VisitStmt_(const Store* op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<Store>();
auto it = dbuffer_info_.find(op->buffer_var.get());
if (it != dbuffer_info_.end()) {
......@@ -188,8 +187,8 @@ class DoubleBufferInjector : public IRMutator {
}
}
Expr Mutate_(const Load* op, const Expr& e) final {
Expr expr = IRMutator::Mutate_(op, e);
Expr VisitExpr_(const Load* op) final {
Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Load>();
auto it = dbuffer_info_.find(op->buffer_var.get());
if (it != dbuffer_info_.end()) {
......@@ -205,20 +204,20 @@ class DoubleBufferInjector : public IRMutator {
}
}
Expr Mutate_(const Variable* op, const Expr& e) final {
Expr VisitExpr_(const Variable* op) final {
CHECK(!dbuffer_info_.count(op));
return e;
return GetRef<Expr>(op);
}
private:
Stmt MakeProducer(const AttrStmt* op, const Stmt& s) {
Stmt MakeProducer(const AttrStmt* op) {
const VarExpr buffer = Downcast<VarExpr>(op->node);
CHECK_NE(loop_nest_.size(), 0U)
<< "Double buffer scope must be inside a loop";
auto it = dbuffer_info_.find(buffer.get());
if (it == dbuffer_info_.end()) {
LOG(WARNING) << "Skip double buffer scope " << op->node;
return Mutate(op->body);
return this->VisitStmt(op->body);
}
StorageEntry& e = it->second;
e.loop = loop_nest_.back();
......@@ -230,7 +229,7 @@ class DoubleBufferInjector : public IRMutator {
e.loop->loop_var.dtype());
e.switch_read_var = indexmod(e.loop->loop_var, two);
in_double_buffer_scope_ = true;
Stmt body = Mutate(op->body);
Stmt body = this->VisitStmt(op->body);
in_double_buffer_scope_ = false;
std::unordered_map<const Variable*, Expr> vmap;
vmap[e.switch_write_var.get()] = zero;
......
......@@ -22,8 +22,7 @@
*/
// Inject prefetch op in HalideIR
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_pass.h>
#include <tvm/arithmetic.h>
#include <unordered_set>
......@@ -34,10 +33,10 @@ namespace ir {
using arith::IntSet;
using arith::DomainTouched;
class PrefetchInjector : public IRMutator {
class PrefetchInjector : public StmtMutator {
public:
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
Stmt ret = IRMutator::Mutate_(op, s);
Stmt VisitStmt_(const AttrStmt* op) final {
Stmt ret = StmtMutator::VisitStmt_(op);
op = ret.as<AttrStmt>();
if (op && op->attr_key == attr::prefetch_scope) {
Tensor ts = Downcast<Tensor>(op->node);
......@@ -65,13 +64,13 @@ class PrefetchInjector : public IRMutator {
return ret;
}
Stmt Mutate_(const For* op, const Stmt& s) final {
Stmt VisitStmt_(const For* op) final {
auto &var = op->loop_var;
loop_nest_.push_back(var);
if (op->for_type == ForType::Vectorized) {
vectorized_[var.get()] = IntSet::interval(op->min, (op->min + op->extent) - 1);
}
Stmt ret = IRMutator::Mutate_(op, s);
Stmt ret = StmtMutator::VisitStmt_(op);
if (op->for_type == ForType::Vectorized) {
vectorized_.erase(var.get());
}
......@@ -88,7 +87,7 @@ class PrefetchInjector : public IRMutator {
const Range PrefetchInjector::none;
Stmt InjectPrefetch(Stmt stmt) {
return PrefetchInjector().Mutate(stmt);
return PrefetchInjector()(std::move(stmt));
}
} // namespace ir
......
......@@ -21,8 +21,7 @@
* \file inject_virtual_thread.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_pass.h>
#include <unordered_set>
#include "../arithmetic/compute_expr.h"
......@@ -31,25 +30,30 @@ namespace tvm {
namespace ir {
// If expression is touched by var.
class ExprTouched final : public IRVisitor {
class ExprTouched final : public StmtExprVisitor {
public:
explicit ExprTouched(const std::unordered_set<const Variable*> &touched,
bool check_write)
: touched_var_(touched), check_write_(check_write) {}
void Visit(const ObjectRef& n) final {
void VisitExpr(const Expr& n) final {
// early stopping
if (expr_touched_ && !check_write_) return;
IRVisitor::Visit(n);
StmtExprVisitor::VisitExpr(n);
}
void Visit_(const Load *op) final {
void VisitStmt(const Stmt& n) final {
// early stopping
if (expr_touched_ && !check_write_) return;
StmtExprVisitor::VisitStmt(n);
}
void VisitExpr_(const Load *op) final {
HandleUseVar(op->buffer_var.get());
IRVisitor::Visit_(op);
StmtExprVisitor::VisitExpr_(op);
}
void Visit_(const Variable *op) final {
void VisitExpr_(const Variable *op) final {
HandleUseVar(op);
}
void Visit_(const Call *op) final {
void VisitExpr_(const Call *op) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
int rw_mask = 0;
CHECK(arith::GetConstInt(op->args[4], &rw_mask));
......@@ -62,9 +66,9 @@ class ExprTouched final : public IRVisitor {
if (rw_mask & 2) {
HandleWriteVar(buffer_var);
}
this->Visit(op->args[2]);
this->VisitExpr(op->args[2]);
} else {
IRVisitor::Visit_(op);
StmtExprVisitor::VisitExpr_(op);
}
}
void HandleUseVar(const Variable* var) {
......@@ -90,46 +94,46 @@ class ExprTouched final : public IRVisitor {
};
// Analyze if the buffers are invariant to value of var
class VarTouchedAnalysis : public IRVisitor {
class VarTouchedAnalysis : public StmtVisitor {
public:
void Visit_(const LetStmt *op) {
void VisitStmt_(const LetStmt* op) final {
ExprTouched tc(touched_var_, false);
tc.Visit(op->value);
tc(op->value);
Record(op->var.get(), tc);
this->Visit(op->body);
this->VisitStmt(op->body);
}
void Visit_(const Store *op) {
void VisitStmt_(const Store* op) final {
ExprTouched tc(touched_var_, false);
tc.Visit(op->value);
tc.Visit(op->index);
tc(op->value);
tc(op->index);
Record(op->buffer_var.get(), tc);
}
void Visit_(const For *op) {
void VisitStmt_(const For* op) final {
ExprTouched tc(touched_var_, false);
tc.Visit(op->min);
tc.Visit(op->extent);
tc(op->min);
tc(op->extent);
Record(op->loop_var.get(), tc);
this->Visit(op->body);
this->VisitStmt(op->body);
}
// external function call
void Visit_(const Evaluate *op) {
void VisitStmt_(const Evaluate* op) final {
ExprTouched tc(touched_var_, true);
tc.Visit(op->value);
tc(op->value);
for (const Variable* var : tc.write_vars_) {
Record(var, tc);
}
}
void Visit_(const Allocate *op) {
void VisitStmt_(const Allocate* op) final {
ExprTouched tc(touched_var_, false);
for (size_t i = 0; i < op->extents.size(); ++i) {
tc.Visit(op->extents[i]);
tc(op->extents[i]);
}
tc.Visit(op->condition);
tc.VisitExpr(op->condition);
if (op->new_expr.defined()) {
tc.Visit(op->new_expr);
tc(op->new_expr);
}
Record(op->buffer_var.get(), tc);
this->Visit(op->body);
this->VisitStmt(op->body);
}
void Record(const Variable* var,
const ExprTouched& tc) {
......@@ -149,7 +153,7 @@ class VarTouchedAnalysis : public IRVisitor {
TouchedVar(const Stmt& stmt,
const Variable* var) {
touched_var_.insert(var);
this->Visit(stmt);
this->VisitStmt(stmt);
// do a DFS to push affect around dependency.
std::vector<const Variable*> pending(
touched_var_.begin(), touched_var_.end());
......@@ -177,9 +181,8 @@ class VarTouchedAnalysis : public IRVisitor {
// Inject virtual thread loop
// rewrite the buffer access pattern when necessary.
class VTInjector : public IRMutator {
class VTInjector : public StmtExprMutator {
public:
using IRMutator::Mutate;
// constructor
VTInjector(Var var,
int num_threads,
......@@ -189,9 +192,9 @@ class VTInjector : public IRMutator {
touched_var_(touched_var), allow_share_(allow_share) {
}
// Inject VTLoop when needed.
Stmt Mutate(Stmt stmt) final {
Stmt VisitStmt(const Stmt& s) final {
CHECK(!visit_touched_var_);
stmt = IRMutator::Mutate(stmt);
auto stmt = StmtExprMutator::VisitStmt(s);
if (visit_touched_var_ || trigger_base_inject_) {
if (!vt_loop_injected_) {
return InjectVTLoop(stmt, false);
......@@ -202,20 +205,20 @@ class VTInjector : public IRMutator {
return stmt;
}
// Variable
Expr Mutate_(const Variable *op, const Expr& e) final {
Expr VisitExpr_(const Variable* op) final {
CHECK(!alloc_remap_.count(op))
<< "Buffer address may get rewritten in virtual thread";
if (touched_var_.count(op)) {
visit_touched_var_ = true;
}
return e;
return GetRef<Expr>(op);
}
Expr RewriteIndex(Expr index, Expr alloc_extent) const {
return index + var_ * alloc_extent;
}
// Load
Expr Mutate_(const Load* op, const Expr& e) final {
Expr expr = IRMutator::Mutate_(op, e);
Expr VisitExpr_(const Load* op) final {
Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Load>();
if (touched_var_.count(op->buffer_var.get())) {
visit_touched_var_ = true;
......@@ -230,16 +233,16 @@ class VTInjector : public IRMutator {
}
}
// Expression.
Expr Mutate_(const Call* op, const Expr& e) final {
Expr VisitExpr_(const Call* op) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
CHECK_EQ(op->args.size(), 5U);
DataType dtype = op->args[0].dtype();
const Variable* buffer = op->args[1].as<Variable>();
auto it = alloc_remap_.find(buffer);
if (it == alloc_remap_.end()) return IRMutator::Mutate_(op, e);
if (it == alloc_remap_.end()) return StmtExprMutator::VisitExpr_(op);
visit_touched_var_ = true;
Expr offset = Mutate(op->args[2]);
Expr extent = Mutate(op->args[3]);
Expr offset = this->VisitExpr(op->args[2]);
Expr extent = this->VisitExpr(op->args[3]);
Expr stride =
it->second / make_const(offset.dtype(), dtype.lanes());
offset = stride * var_ + offset;
......@@ -248,18 +251,18 @@ class VTInjector : public IRMutator {
{op->args[0], op->args[1], offset, extent, op->args[4]},
op->call_type);
} else if (op->is_intrinsic(intrinsic::tvm_context_id)) {
return allow_share_ ? e : var_;
return allow_share_ ? GetRef<Expr>(op) : var_;
} else {
return IRMutator::Mutate_(op, e);
return StmtExprMutator::VisitExpr_(op);
}
}
Stmt Mutate_(const Evaluate* op, const Stmt& s) final {
Stmt VisitStmt_(const Evaluate* op) final {
trigger_base_inject_ = !allow_share_;
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
// Store
Stmt Mutate_(const Store* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt VisitStmt_(const Store* op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<Store>();
if (touched_var_.count(op->buffer_var.get())) {
visit_touched_var_ = true;
......@@ -276,114 +279,114 @@ class VTInjector : public IRMutator {
}
}
// Attribute
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
Expr value = Mutate(op->value);
Stmt VisitStmt_(const AttrStmt* op) final {
Expr value = this->VisitExpr(op->value);
if (visit_touched_var_ && !vt_loop_injected_) {
return InjectVTLoop(s, true);
return InjectVTLoop(GetRef<Stmt>(op), true);
} else if (!allow_share_ && !vt_loop_injected_ &&
(op->attr_key == attr::coproc_uop_scope ||
op->attr_key == attr::coproc_scope)) {
return InjectVTLoop(s, true);
return InjectVTLoop(GetRef<Stmt>(op), true);
} else {
Stmt body = Mutate(op->body);
Stmt body = this->VisitStmt(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return s;
return GetRef<Stmt>(op);
} else {
return AttrStmt::make(op->node, op->attr_key, value, body);
}
}
}
// LetStmt
Stmt Mutate_(const LetStmt* op, const Stmt& s) final {
Expr value = this->Mutate(op->value);
Stmt VisitStmt_(const LetStmt* op) final {
Expr value = this->VisitExpr(op->value);
if (visit_touched_var_ && !vt_loop_injected_) {
return InjectVTLoop(s, true);
return InjectVTLoop(GetRef<Stmt>(op), true);
}
visit_touched_var_ = false;
Stmt body = Mutate(op->body);
Stmt body = this->VisitStmt(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return s;
return GetRef<Stmt>(op);
} else {
return LetStmt::make(op->var, value, body);
}
}
// For
Stmt Mutate_(const For* op, const Stmt& s) final {
Stmt VisitStmt_(const For* op) final {
CHECK(is_zero(op->min));
Expr extent = Mutate(op->extent);
Expr extent = this->VisitExpr(op->extent);
if (visit_touched_var_ && !vt_loop_injected_) {
Stmt stmt = InjectVTLoop(s, true);
Stmt stmt = InjectVTLoop(GetRef<Stmt>(op), true);
++max_loop_depth_;
return stmt;
}
visit_touched_var_ = false;
Stmt body = Mutate(op->body);
Stmt body = this->VisitStmt(op->body);
++max_loop_depth_;
if (extent.same_as(op->extent) &&
body.same_as(op->body)) {
return s;
return GetRef<Stmt>(op);
} else {
return For::make(
op->loop_var, op->min, extent, op->for_type, op->device_api, body);
}
}
// IfThenElse
Stmt Mutate_(const IfThenElse* op, const Stmt& s) final {
Expr condition = this->Mutate(op->condition);
Stmt VisitStmt_(const IfThenElse* op) final {
Expr condition = this->VisitExpr(op->condition);
if (visit_touched_var_ && !vt_loop_injected_) {
return InjectVTLoop(s, true);
return InjectVTLoop(GetRef<Stmt>(op), true);
}
visit_touched_var_ = false;
CHECK_EQ(max_loop_depth_, 0);
Stmt then_case = this->Mutate(op->then_case);
Stmt then_case = this->VisitStmt(op->then_case);
Stmt else_case;
if (op->else_case.defined()) {
int temp = max_loop_depth_;
max_loop_depth_ = 0;
else_case = this->Mutate(op->else_case);
else_case = this->VisitStmt(op->else_case);
max_loop_depth_ = std::max(temp, max_loop_depth_);
}
if (condition.same_as(op->condition) &&
then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
return s;
return GetRef<Stmt>(op);
} else {
return IfThenElse::make(condition, then_case, else_case);
}
}
// Block
Stmt Mutate_(const Block* op, const Stmt& s) final {
Stmt VisitStmt_(const Block* op) final {
CHECK_EQ(max_loop_depth_, 0);
Stmt first = this->Mutate(op->first);
Stmt first = this->VisitStmt(op->first);
int temp = max_loop_depth_;
max_loop_depth_ = 0;
Stmt rest = this->Mutate(op->rest);
Stmt rest = this->VisitStmt(op->rest);
max_loop_depth_ = std::max(max_loop_depth_, temp);
if (first.same_as(op->first) &&
rest.same_as(op->rest)) {
return s;
return GetRef<Stmt>(op);
} else {
return Block::make(first, rest);
}
}
// Allocate
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
Stmt VisitStmt_(const Allocate* op) final {
if (op->new_expr.defined() && !vt_loop_injected_) {
return InjectVTLoop(s, true);
return InjectVTLoop(GetRef<Stmt>(op), true);
}
Expr condition = Mutate(op->condition);
Expr condition = this->VisitExpr(op->condition);
if (visit_touched_var_ && !vt_loop_injected_) {
return InjectVTLoop(s, true);
return InjectVTLoop(GetRef<Stmt>(op), true);
}
bool changed = false;
Array<Expr> extents;
for (size_t i = 0; i < op->extents.size(); i++) {
Expr new_ext = Mutate(op->extents[i]);
Expr new_ext = this->VisitExpr(op->extents[i]);
if (visit_touched_var_ && !vt_loop_injected_) {
return InjectVTLoop(s, true);
return InjectVTLoop(GetRef<Stmt>(op), true);
}
if (!new_ext.same_as(op->extents[i])) changed = true;
extents.push_back(new_ext);
......@@ -406,15 +409,15 @@ class VTInjector : public IRMutator {
// mark this buffer get touched.
alloc_remap_[op->buffer_var.get()] = stride;
// Mutate the body.
body = Mutate(op->body);
body = this->VisitStmt(op->body);
} else {
// Mutate the body.
body = Mutate(op->body);
body = this->VisitStmt(op->body);
}
if (!changed &&
body.same_as(op->body) &&
condition.same_as(op->condition)) {
return s;
return GetRef<Stmt>(op);
} else {
return Allocate::make(
op->buffer_var, op->dtype,
......@@ -431,7 +434,7 @@ class VTInjector : public IRMutator {
trigger_base_inject_ = false;
vt_loop_injected_ = true;
if (before_mutation) {
stmt = this->Mutate(stmt);
stmt = this->VisitStmt(stmt);
}
// reset the flags after processing.
vt_loop_injected_ = false;
......@@ -478,10 +481,10 @@ class VTInjector : public IRMutator {
};
class VirtualThreadInjector : public IRMutator {
class VirtualThreadInjector : public StmtMutator {
public:
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt VisitStmt_(const AttrStmt* op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<AttrStmt>();
if (op->attr_key == attr::virtual_thread) {
IterVar iv = Downcast<IterVar>(op->node);
......@@ -490,21 +493,21 @@ class VirtualThreadInjector : public IRMutator {
VarTouchedAnalysis vs;
auto touched = vs.TouchedVar(op->body, iv->var.get());
VTInjector injecter(iv->var, nthread, touched, allow_share);
return injecter.Mutate(op->body);
return injecter(op->body);
} else {
return stmt;
}
}
Stmt Mutate_(const Provide* op, const Stmt& s) final {
Stmt VisitStmt_(const Provide* op) final {
LOG(FATAL) << "Need to call StorageFlatten first";
return s;
return GetRef<Stmt>(op);
}
};
Stmt InjectVirtualThread(Stmt stmt) {
stmt = VirtualThreadInjector().Mutate(stmt);
return ConvertSSA(stmt);
stmt = VirtualThreadInjector()(std::move(stmt));
return ConvertSSA(std::move(stmt));
}
} // namespace ir
......
......@@ -21,8 +21,8 @@
* \file inline.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_functor_ext.h>
namespace tvm {
namespace ir {
......@@ -30,13 +30,13 @@ namespace ir {
// inliner to inline a function
// the result may not be SSA,
// ConvertSSA need to be applied after this pass
class IRInline final : public IRMutator {
class IRInline final : public StmtExprMutator {
public:
IRInline(FunctionRef f, Array<Var> args, Expr body)
: f_(f), args_(args), body_(body) {}
Expr Mutate_(const Call* op, const Expr& e) final {
Expr expr = IRMutator::Mutate_(op, e);
Expr VisitExpr_(const Call* op) final {
Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Call>();
if (op->func == f_) {
......@@ -78,7 +78,7 @@ Stmt Inline(Stmt stmt,
Expr body) {
CHECK_EQ(f->num_outputs(), 1)
<< "can only inline output single value operation";
Stmt ret = IRInline(f, args, body).Mutate(stmt);
Stmt ret = IRInline(f, args, body)(std::move(stmt));
if (ret.same_as(stmt)) return ret;
return ConvertSSA(ret);
}
......
......@@ -21,7 +21,6 @@
*/
#include <tvm/ir_functor_ext.h>
#include <tvm/packed_func_ext.h>
#include <tvm/ir_visitor.h>
namespace tvm {
namespace ir {
......
......@@ -24,7 +24,7 @@
* \file lift_attr_scope.cc
*/
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_functor_ext.h>
#include "ir_util.h"
namespace tvm {
......@@ -32,13 +32,13 @@ namespace ir {
// NOTE: this optimization can only be applied
// to a few specified attr keys
class AttrScopeLifter : public IRMutator {
class AttrScopeLifter : public StmtMutator {
public:
explicit AttrScopeLifter(std::string attr_key)
: attr_key_(attr_key) {}
Stmt Lift(Stmt stmt) {
stmt = Mutate(stmt);
stmt = operator()(std::move(stmt));
if (attr_node_.defined()) {
stmt = AttrStmt::make(
attr_node_, attr_key_, attr_value_, stmt);
......@@ -47,8 +47,8 @@ class AttrScopeLifter : public IRMutator {
}
// do not go beyond
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt VisitStmt_(const Allocate* op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<Allocate>();
if (attr_node_.defined()) {
Stmt body = AttrStmt::make(
......@@ -65,17 +65,17 @@ class AttrScopeLifter : public IRMutator {
}
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
Stmt VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr_key_) {
attr_node_ = op->node;
attr_value_ = op->value;
return op->body;
} else {
return IRMutator::Mutate_(op, s);
return StmtMutator::VisitStmt_(op);
}
}
Stmt Mutate_(const Block* op, const Stmt& s) final {
Stmt VisitStmt_(const Block* op) final {
std::vector<Stmt> seq;
FlattenSeq(op->first, &seq);
FlattenSeq(op->rest, &seq);
......@@ -83,21 +83,21 @@ class AttrScopeLifter : public IRMutator {
if (seq.size() == 2 &&
seq[0].same_as(op->first) &&
seq[1].same_as(op->rest)) {
return s;
return GetRef<Stmt>(op);
}
return MergeSeq(seq);
}
Stmt Mutate_(const IfThenElse* op, const Stmt& s) final {
Stmt VisitStmt_(const IfThenElse* op) final {
if (!op->else_case.defined()) {
return IRMutator::Mutate_(op, s);
return StmtMutator::VisitStmt_(op);
}
Stmt then_case = this->Mutate(op->then_case);
Stmt then_case = this->VisitStmt(op->then_case);
ObjectRef first_node;
Expr first_value;
std::swap(first_node, attr_node_);
std::swap(first_value, attr_value_);
Stmt else_case = this->Mutate(op->else_case);
Stmt else_case = this->VisitStmt(op->else_case);
if (attr_node_.defined() &&
attr_value_.defined() &&
first_node.defined() &&
......@@ -106,7 +106,7 @@ class AttrScopeLifter : public IRMutator {
ValueSame(attr_value_, first_value)) {
if (then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
return s;
return GetRef<Stmt>(op);
} else {
return IfThenElse::make(op->condition, then_case, else_case);
}
......@@ -124,7 +124,7 @@ class AttrScopeLifter : public IRMutator {
}
if (then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
return s;
return GetRef<Stmt>(op);
} else {
return IfThenElse::make(op->condition, then_case, else_case);
}
......@@ -155,7 +155,7 @@ class AttrScopeLifter : public IRMutator {
for (const Stmt & stmt : seq) {
attr_node_ = ObjectRef();
attr_value_ = Expr();
Stmt rest = this->Mutate(stmt);
Stmt rest = this->VisitStmt(stmt);
if (attr_node_.defined() &&
attr_value_.defined() &&
curr_node.defined() &&
......@@ -214,7 +214,7 @@ class AttrScopeLifter : public IRMutator {
};
// Lift attr scopes with the given key toward the outermost position where
// they remain valid (see AttrScopeLifter above for the merging rules).
Stmt LiftAttrScope(Stmt stmt, std::string attr_key) {
  // Move the statement in: Lift mutates it and re-attaches any attr scope
  // that was hoisted out of the body.
  return AttrScopeLifter(attr_key).Lift(std::move(stmt));
}
} // namespace ir
......
......@@ -21,8 +21,7 @@
* \file loop_partition.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_pass.h>
#include <tvm/arithmetic.h>
#include <unordered_map>
......@@ -50,7 +49,6 @@ struct PartitionKeyHash {
// condition cond is proven to have value cond_value (true or false) in interval.
using Partition = std::unordered_map<PartitionKey, IntSet, PartitionKeyHash>;
bool ExprUseVars(Expr expr, const std::unordered_set<const Variable*>& vars) {
bool success = false;
PostOrderVisit(expr, [&vars, &success](const ObjectRef& node) {
......@@ -68,28 +66,28 @@ bool ExprUseVars(Expr expr, const std::unordered_set<const Variable*>& vars) {
// Rule:
// - the range should not be const
// - there exist a condition expression in the scope that use the var
class CandidateSelector final : public IRVisitor {
class CandidateSelector final : public StmtExprVisitor {
public:
using VarIsUsed = bool;
explicit CandidateSelector(bool split_const_loop)
: split_const_loop_(split_const_loop) {}
void Visit_(const For* op) {
void VisitStmt_(const For* op) final {
// partition const loop when sets split_const_loop_
if (!is_const(op->min) || !is_const(op->extent) || split_const_loop_) {
const Variable* var = op->loop_var.get();
record_.insert({var, false});
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
if (record_.at(var) && !no_split_) {
candidates.insert(op);
}
record_.erase(var);
} else {
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
}
}
void Visit_(const AttrStmt* op) {
void VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::thread_extent) {
const IterVarNode *iv = op->node.as<IterVarNode>();
CHECK(iv);
......@@ -97,7 +95,7 @@ class CandidateSelector final : public IRVisitor {
runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag);
if ((scope.rank == 0) && (!is_const(op->value) || split_const_loop_)) {
record_.insert({var.get(), false});
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
if (record_.at(var.get()) && !no_split_) {
candidates.insert(op);
}
......@@ -105,34 +103,34 @@ class CandidateSelector final : public IRVisitor {
return;
}
}
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
}
void Visit_(const Block* op) {
void VisitStmt_(const Block* op) final {
bool temp = no_split_;
this->Visit(op->first);
this->VisitStmt(op->first);
// erase the no split state of first when visit rest.
std::swap(temp, no_split_);
this->Visit(op->rest);
this->VisitStmt(op->rest);
// restore the no split flag.
no_split_ = no_split_ || temp;
}
void Visit_(const Call* op) {
void VisitExpr_(const Call* op) final {
if (op->is_intrinsic(Call::likely)) {
in_likely_ = true;
IRVisitor::Visit_(op);
StmtExprVisitor::VisitExpr_(op);
in_likely_ = false;
} else if (op->is_intrinsic(intrinsic::tvm_thread_allreduce)) {
// no split if the body contains allreduce.
no_split_ = true;
return;
} else {
IRVisitor::Visit_(op);
StmtExprVisitor::VisitExpr_(op);
}
}
void Visit_(const Variable* op) {
void VisitExpr_(const Variable* op) final {
if (in_likely_ && record_.count(op)) {
record_.at(op) = true;
}
......@@ -150,7 +148,7 @@ class CandidateSelector final : public IRVisitor {
// Populate partitions data structure, i.e., for a specific variable,
// find an interval in which each condition
// (currently, "likely" conditions) has fixed true or false value
class PartitionFinder : public IRVisitor {
class PartitionFinder : public StmtExprVisitor {
public:
explicit PartitionFinder(VarExpr current_var,
const std::unordered_map<const Variable*, IntSet>& hint_map,
......@@ -164,18 +162,18 @@ class PartitionFinder : public IRVisitor {
}
}
void Visit_(const For* op) {
void VisitStmt_(const For* op) final {
if (ExprUseVars(op->min, out_vars_) || ExprUseVars(op->extent, out_vars_)) return;
const Variable* var = op->loop_var.get();
hint_map_.insert({var, IntSet::interval(op->min, op->min + op->extent - 1)});
relax_map_.insert({var, IntSet::interval(op->min, op->min + op->extent - 1)});
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
relax_map_.erase(var);
hint_map_.erase(var);
}
void Visit_(const AttrStmt* op) {
void VisitStmt_(const AttrStmt* op) final {
// handle thread_axis
if (op->attr_key == attr::thread_extent) {
const IterVarNode* thread_axis = op->node.as<IterVarNode>();
......@@ -184,15 +182,15 @@ class PartitionFinder : public IRVisitor {
IntSet dom = IntSet::range(Range(make_zero(op->value.dtype()), op->value));
hint_map_.insert({var, dom});
relax_map_.insert({var, dom});
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
relax_map_.erase(var);
hint_map_.erase(var);
} else {
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
}
}
void Visit_(const Call* op) {
void VisitExpr_(const Call* op) final {
if (op->is_intrinsic(Call::likely)) {
Expr cond = op->args[0];
if (ExprUseVars(cond,
......@@ -217,7 +215,7 @@ class PartitionFinder : public IRVisitor {
}
}
} else {
IRVisitor::Visit_(op);
StmtExprVisitor::VisitExpr_(op);
}
}
......@@ -255,17 +253,16 @@ class PartitionFinder : public IRVisitor {
};
// Replace the set of conditions given by ps with cond_value (true or false)
class ConditionEliminator : public IRMutator {
class ConditionEliminator : public StmtExprMutator {
public:
explicit ConditionEliminator(const std::unordered_set<const Object*>& ps, bool cond_value = true)
: ps_(ps), cond_value_(cond_value) {}
using IRMutator::Mutate;
Expr Mutate(Expr e) final {
Expr VisitExpr(const Expr& e) final {
if (ps_.find(e.get()) != ps_.end()) {
return Mutate(cond_value_ ? const_true() : const_false());
return VisitExpr(cond_value_ ? const_true() : const_false());
}
return IRMutator::Mutate(e);
return StmtExprMutator::VisitExpr(e);
}
private:
......@@ -275,26 +272,26 @@ class ConditionEliminator : public IRMutator {
// Insert the partition branch at the innermost thread scope
class ThreadPartitionInserter : public IRMutator {
class ThreadPartitionInserter : public StmtMutator {
public:
explicit ThreadPartitionInserter(const std::unordered_set<const Object*>& ps,
Expr cond) : ps_(ps), cond_(cond), innermost_thread_scope_(false) {}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
Stmt VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::thread_extent) {
innermost_thread_scope_ = true;
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt stmt = StmtMutator::VisitStmt_(op);
// add branch code inside the innermost thread scope
if (innermost_thread_scope_) {
Stmt simplified_body = ConditionEliminator(ps_).Mutate(op->body);
Stmt simplified_body = ConditionEliminator(ps_)(op->body);
Stmt body = IfThenElse::make(cond_, simplified_body, op->body);
Expr value = this->Mutate(op->value);
Expr value = this->VisitExpr(op->value);
stmt = AttrStmt::make(op->node, op->attr_key, value, body);
}
innermost_thread_scope_ = false;
return stmt;
} else {
return IRMutator::Mutate_(op, s);
return StmtMutator::VisitStmt_(op);
}
}
......@@ -306,19 +303,19 @@ class ThreadPartitionInserter : public IRMutator {
// Try to partition range of iteration variables in order to remove (some)
// likely conditions
class LoopPartitioner : public IRMutator {
class LoopPartitioner : public StmtMutator {
public:
explicit LoopPartitioner(bool split_const_loop)
: selector(CandidateSelector(split_const_loop)) {}
Stmt VisitAndMutate(const Stmt& stmt) {
selector.Visit(stmt);
return Mutate(stmt);
Stmt VisitAndMutate(Stmt stmt) {
selector(stmt);
return operator()(std::move(stmt));
}
Stmt Mutate_(const For* op, const Stmt& stmt) {
Stmt VisitStmt_(const For* op) final {
if (selector.candidates.count(op)) {
Stmt s = TryPartition(op, stmt, op->loop_var,
Stmt s = TryPartition(op, GetRef<Stmt>(op), op->loop_var,
op->min, op->min + op->extent - 1, op->body, false);
if (s.defined()) return s;
}
......@@ -327,21 +324,21 @@ class LoopPartitioner : public IRMutator {
// normal loop variable can be put into hint map.
hint_map_.insert({op->loop_var.get(),
IntSet::interval(op->min, op->min + op->extent - 1)});
Stmt res = IRMutator::Mutate_(op, stmt);
Stmt res = StmtMutator::VisitStmt_(op);
hint_map_.erase(op->loop_var.get());
return res;
}
Stmt Mutate_(const AttrStmt* op, const Stmt& stmt) {
Stmt VisitStmt_(const AttrStmt* op) final {
if (op->attr_key != attr::thread_extent) {
return IRMutator::Mutate_(op, stmt);
return StmtMutator::VisitStmt_(op);
}
const IterVarNode *iv = op->node.as<IterVarNode>();
CHECK(iv);
Var var = iv->var;
if (selector.candidates.count(op)) {
Stmt s = TryPartition(op, stmt, var, 0, op->value - 1, op->body, true);
Stmt s = TryPartition(op, GetRef<Stmt>(op), var, 0, op->value - 1, op->body, true);
if (s.defined()) return s;
}
......@@ -352,12 +349,12 @@ class LoopPartitioner : public IRMutator {
// threadIdx should be put into relax map, in case of divergence.
relax_map_.insert({var.get(),
IntSet::interval(make_zero(var.dtype()), op->value - 1)});
res = IRMutator::Mutate_(op, stmt);
res = StmtMutator::VisitStmt_(op);
relax_map_.erase(var.get());
} else {
hint_map_.insert({var.get(),
IntSet::interval(make_zero(var.dtype()), op->value - 1)});
res = IRMutator::Mutate_(op, stmt);
res = StmtMutator::VisitStmt_(op);
hint_map_.erase(var.get());
}
return res;
......@@ -473,7 +470,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node,
hint_map_.insert({var.get(), IntSet::interval(min, max)});
PartitionFinder finder(var, hint_map_, relax_map_);
finder.Visit(body);
finder(body);
hint_map_.erase(var.get());
if (finder.partitions.empty()) return Stmt();
......@@ -564,7 +561,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node,
Stmt mid_stmt;
if (!analyzer_.CanProve(body_begin >= post_doubt_begin)) {
// [body_begin, post_doubt_begin)
Stmt simplified_body = ConditionEliminator(cond_set, cond_value).Mutate(body);
Stmt simplified_body = ConditionEliminator(cond_set, cond_value)(body);
Stmt new_body = Substitute(simplified_body, {{Var{var}, var + body_begin}});
mid_stmt = MakeFor(node, post_doubt_begin - body_begin, new_body);
......@@ -586,7 +583,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node,
Expr cond = const_true();
if (!analyzer_.CanProve(body_begin == min)) cond = cond && (var >= body_begin);
if (!analyzer_.CanProve(post_doubt_begin == (max + 1))) cond = cond && (var < post_doubt_begin);
s = ThreadPartitionInserter(cond_set, cond).Mutate(stmt);
s = ThreadPartitionInserter(cond_set, cond)(stmt);
}
s = ConvertSSA(s);
return s;
......@@ -604,23 +601,21 @@ inline Stmt LoopPartitioner::MakeFor(const Object *node, Expr extent, Stmt body)
}
}
class RemoveLikelyTags : public IRMutator {
class RemoveLikelyTags : public StmtExprMutator {
public:
using IRMutator::Mutate;
Expr Mutate_(const Call *op, const Expr& e) {
Expr VisitExpr_(const Call *op) final {
if (op->is_intrinsic(Call::likely)) {
CHECK_EQ(op->args.size(), 1);
return IRMutator::Mutate(op->args[0]);
return StmtExprMutator::VisitExpr(op->args[0]);
} else {
return IRMutator::Mutate_(op, e);
return StmtExprMutator::VisitExpr_(op);
}
}
};
// Entry point of the loop-partition pass: split loops whose bodies contain
// `likely` conditions into proven/doubt regions, then erase the remaining
// likely-markers so they do not reach later passes.
// \param split_const_loop Also consider loops with constant bounds.
Stmt LoopPartition(Stmt stmt, bool split_const_loop) {
  stmt = LoopPartitioner(split_const_loop).VisitAndMutate(std::move(stmt));
  stmt = RemoveLikelyTags()(std::move(stmt));
  return stmt;
}
......
......@@ -21,7 +21,7 @@
* \brief Pass for lowering custom datatypes
*/
#include <tvm/ir_mutator.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_pass.h>
#include <tvm/packed_func_ext.h>
#include "../codegen/datatype/registry.h"
......@@ -37,17 +37,17 @@ namespace ir {
* datatype) for lowering this type of expression, and uses it to lower the
* expression.
*/
class CustomDatatypesLowerer : public IRMutator {
class CustomDatatypesLowerer : public StmtExprMutator {
public:
explicit CustomDatatypesLowerer(const std::string& target) : target_(target) {}
inline Expr Mutate_(const Cast* op, const Expr& e) final {
inline Expr VisitExpr_(const Cast* op) final {
auto type_code = op->dtype.code();
auto src_type_code = op->value.dtype().code();
// If either datatype is a registered custom datatype, we must lower.
bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code) ||
datatype::Registry::Global()->GetTypeRegistered(src_type_code);
Expr expr = IRMutator::Mutate_(op, e);
Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Cast>();
if (toBeLowered) {
auto lower = datatype::GetCastLowerFunc(target_, type_code, src_type_code);
......@@ -59,8 +59,9 @@ class CustomDatatypesLowerer : public IRMutator {
return expr;
}
inline Expr Mutate_(const FloatImm* imm, const Expr& e) final {
inline Expr VisitExpr_(const FloatImm* imm) final {
auto type_code = imm->dtype.code();
auto e = GetRef<Expr>(imm);
if (datatype::Registry::Global()->GetTypeRegistered(type_code)) {
auto lower = datatype::GetFloatImmLowerFunc(target_, type_code);
CHECK(lower) << "FloatImm lowering function for target " << target_ << " type "
......@@ -70,9 +71,9 @@ class CustomDatatypesLowerer : public IRMutator {
return e;
}
inline Stmt Mutate_(const Allocate* allocate, const Stmt& s) final {
inline Stmt VisitStmt_(const Allocate* allocate) final {
bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(allocate->dtype.code());
Stmt stmt = IRMutator::Mutate_(allocate, s);
Stmt stmt = StmtExprMutator::VisitStmt_(allocate);
allocate = stmt.as<Allocate>();
if (toBeLowered) {
......@@ -84,9 +85,9 @@ class CustomDatatypesLowerer : public IRMutator {
return stmt;
}
inline Expr Mutate_(const Load* load, const Expr& e) final {
inline Expr VisitExpr_(const Load* load) final {
bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(load->dtype.code());
Expr expr = IRMutator::Mutate_(load, e);
Expr expr = StmtExprMutator::VisitExpr_(load);
load = expr.as<Load>();
if (toBeLowered) {
auto new_load_type = DataType::UInt(load->dtype.bits());
......@@ -96,10 +97,10 @@ class CustomDatatypesLowerer : public IRMutator {
}
#define DEFINE_MUTATE__(OP) \
inline Expr Mutate_(const OP* op, const Expr& e) final { \
inline Expr VisitExpr_(const OP* op) final { \
auto type_code = op->dtype.code(); \
bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code); \
Expr expr = IRMutator::Mutate_(op, e); \
Expr expr = StmtExprMutator::VisitExpr_(op); \
op = expr.as<OP>(); \
if (toBeLowered) { \
auto lower = datatype::Get##OP##LowerFunc(target_, type_code); \
......@@ -131,7 +132,7 @@ class CustomDatatypesLowerer : public IRMutator {
// Lower all operations on registered custom datatypes in the function body
// to the target-specific lowering functions found in the datatype registry.
LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target) {
  // Copy-on-write: mutate a fresh node so the caller's function is untouched.
  auto n = make_object<LoweredFuncNode>(*f.operator->());
  n->body = CustomDatatypesLowerer(target)(n->body);
  return LoweredFunc(n);
}
......
......@@ -22,7 +22,6 @@
* \file lower_intrin.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <tvm/api_registry.h>
#include <tvm/expr_operator.h>
......@@ -34,9 +33,10 @@
namespace tvm {
namespace ir {
class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
public:
using IRMutatorWithAnalyzer::Mutate_;
using IRMutatorWithAnalyzer::VisitStmt_;
using IRMutatorWithAnalyzer::VisitExpr_;
IntrinInjecter(arith::Analyzer* analyzer, std::string target)
: IRMutatorWithAnalyzer(analyzer) {
......@@ -51,28 +51,29 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
}
}
Expr Mutate_(const Call* op, const Expr& e) final {
Expr VisitExpr_(const Call* op) final {
if (op->call_type == Call::Intrinsic ||
op->call_type == Call::PureIntrinsic) {
Expr r = ApplyPattern(op->name, e);
Expr r = ApplyPattern(op->name, GetRef<Expr>(op));
if (r.defined()) return r;
}
return IRMutator::Mutate_(op, e);
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
Expr Mutate_(const Add* op, const Expr& e) final {
Expr VisitExpr_(const Add* op) final {
if (const Mul* mb = op->b.as<Mul>()) {
return MakeFMA(mb->a, mb->b, op->a, op, e);
return MakeFMA(mb->a, mb->b, op->a, op);
} else if (const Mul* ma = op->a.as<Mul>()) {
return MakeFMA(ma->a, ma->b, op->b, op, e);
return MakeFMA(ma->a, ma->b, op->b, op);
}
return IRMutator::Mutate_(op, e);
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
// We use floordiv for integer analysis,
// but will need to lower them to native truncdiv instructions
Expr Mutate_(const FloorDiv* op, const Expr& e) final {
Expr ret = IRMutatorWithAnalyzer::Mutate_(op, e);
Expr VisitExpr_(const FloorDiv* op) final {
auto e = GetRef<Expr>(op);
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<FloorDiv>();
if (op == nullptr) return ret;
int shift;
......@@ -117,8 +118,8 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
}
}
Expr Mutate_(const FloorMod* op, const Expr& e) final {
Expr ret = IRMutatorWithAnalyzer::Mutate_(op, e);
Expr VisitExpr_(const FloorMod* op) final {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<FloorMod>();
if (op == nullptr) return ret;
// Lower floordiv to native truncdiv.
......@@ -167,34 +168,37 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
}
}
Expr Mutate_(const Max* op, const Expr& e) final {
Expr VisitExpr_(const Max* op) final {
using namespace arith;
PVar<Expr> x, y;
PVar<Integer> c;
auto e = GetRef<Expr>(op);
if (max(floordiv(x, y), c).Match(e) &&
c.Eval()->value >= 0 &&
analyzer_->CanProveGreaterEqual(y.Eval(), 0)) {
return max(Mutate(truncdiv(x, y).Eval()), c.Eval());
return max(VisitExpr(truncdiv(x, y).Eval()), c.Eval());
}
return IRMutatorWithAnalyzer::Mutate_(op, e);
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
Expr Mutate_(const EQ* op, const Expr& e) final {
Expr VisitExpr_(const EQ* op) final {
using namespace arith;
PVar<Expr> x, y;
auto e = GetRef<Expr>(op);
if ((floormod(x, y) == 0).Match(e)) {
return Mutate((truncmod(x, y) == 0).Eval());
return VisitExpr((truncmod(x, y) == 0).Eval());
}
return IRMutatorWithAnalyzer::Mutate_(op, e);
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
Expr Mutate_(const NE* op, const Expr& e) final {
Expr VisitExpr_(const NE* op) final {
using namespace arith;
PVar<Expr> x, y;
auto e = GetRef<Expr>(op);
if ((floormod(x, y) != 0).Match(e)) {
return Mutate((truncmod(x, y) != 0).Eval());
return VisitExpr((truncmod(x, y) != 0).Eval());
}
return IRMutatorWithAnalyzer::Mutate_(op, e);
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
private:
......@@ -231,7 +235,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
}
Expr MakeFMA(const Expr& a, const Expr& b, const Expr& c,
const Add* op, const Expr& e) {
const Add* op) {
// emit fma instruction: a * b + c
Expr lhs = SwapBroadcastCast(a);
Expr rhs = SwapBroadcastCast(b);
......@@ -239,14 +243,14 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
if (fma_ != nullptr && op->dtype.is_float()) {
Expr r = (*fma_)(Call::make(
op->dtype, "fma", {lhs, rhs, c}, Call::PureIntrinsic));
if (r.defined()) return this->Mutate(r);
if (r.defined()) return this->VisitExpr(r);
} else {
if (!lhs.same_as(a) || !rhs.same_as(b)) {
Expr mul = this->Mutate(Mul::make(lhs, rhs));
return Add::make(mul, this->Mutate(c));
Expr mul = this->VisitExpr(Mul::make(lhs, rhs));
return Add::make(mul, this->VisitExpr(c));
}
}
return IRMutator::Mutate_(op, e);
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
Expr ApplyPattern(const std::string& name, const Expr& e) {
......@@ -262,7 +266,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
Expr r = (*f)(e);
CHECK(r.defined()) << "intrinsic rule must always return valid Expr";
if (!r.same_as(e)) {
return this->Mutate(r);
return this->VisitExpr(r);
}
}
}
......@@ -277,7 +281,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
// Lower target-independent intrinsics (fma patterns, floordiv/floormod,
// registered intrinsic rules, ...) in `stmt` for the given target.
Stmt LowerIntrinStmt(Stmt stmt, const std::string& target) {
  // The analyzer only needs to live for the duration of this one pass.
  arith::Analyzer analyzer;
  return IntrinInjecter(&analyzer, target)(std::move(stmt));
}
LoweredFunc
......
......@@ -22,7 +22,7 @@
* \file lower_thread_allreduce.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_pass.h>
#include <unordered_set>
#include "ir_util.h"
......@@ -32,19 +32,19 @@
namespace tvm {
namespace ir {
class ThreadAllreduceBuilder final : public IRMutator {
class ThreadAllreduceBuilder final : public StmtExprMutator {
public:
explicit ThreadAllreduceBuilder(int warp_size)
: warp_size_(warp_size) {}
Stmt Mutate_(const AttrStmt *op, const Stmt& s) final {
Stmt VisitStmt_(const AttrStmt *op) final {
if (op->attr_key == attr::thread_extent) {
thread_extents_.push_back(op);
Stmt ret = IRMutator::Mutate_(op, s);
Stmt ret = StmtExprMutator::VisitStmt_(op);
thread_extents_.pop_back();
return ret;
} else if (op->attr_key == attr::storage_scope) {
Stmt ret = IRMutator::Mutate_(op, s);
Stmt ret = StmtExprMutator::VisitStmt_(op);
op = ret.as<AttrStmt>();
const Variable* v = op->node.as<Variable>();
if (alloc_remap_.count(v)) {
......@@ -56,15 +56,15 @@ class ThreadAllreduceBuilder final : public IRMutator {
const CommReducerNode *combiner = op->node.as<CommReducerNode>();
CHECK(combiner);
reduce_combiner_.push_back(combiner);
Stmt ret = IRMutator::Mutate_(op, s);
Stmt ret = StmtExprMutator::VisitStmt_(op);
reduce_combiner_.pop_back();
return ret;
} else {
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
}
Stmt Mutate_(const Evaluate* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt VisitStmt_(const Evaluate* op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<Evaluate>();
const Call* call = op->value.as<Call>();
if (call && call->is_intrinsic(intrinsic::tvm_thread_allreduce)) {
......@@ -73,8 +73,8 @@ class ThreadAllreduceBuilder final : public IRMutator {
return stmt;
}
}
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt VisitStmt_(const Allocate* op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<Allocate>();
auto it = alloc_remap_.find(op->buffer_var.get());
if (it != alloc_remap_.end()) {
......@@ -93,13 +93,13 @@ class ThreadAllreduceBuilder final : public IRMutator {
return stmt;
}
}
Expr Mutate_(const Load* op, const Expr& e) final {
Expr VisitExpr_(const Load* op) final {
auto it = load_remap_.find(op->buffer_var.get());
if (it != load_remap_.end()) {
CHECK(is_zero(op->index)) << e;
CHECK(is_zero(op->index));
return it->second;
} else {
return IRMutator::Mutate_(op, e);
return StmtExprMutator::VisitExpr_(op);
}
}
......@@ -339,7 +339,7 @@ LoweredFunc
LowerThreadAllreduce(LoweredFunc f, int warp_size) {
  // Cross-thread allreduce only makes sense inside a device function.
  CHECK_NE(f->func_type, kHostFunc);
  // Copy-on-write: rewrite the body on a fresh node, leaving `f` intact.
  auto n = make_object<LoweredFuncNode>(*f.operator->());
  n->body = ThreadAllreduceBuilder(warp_size)(n->body);
  return LoweredFunc(n);
}
} // namespace ir
......
......@@ -22,7 +22,7 @@
* \file lower_tvm_buildin.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_pass.h>
#include <unordered_set>
#include "ir_util.h"
......@@ -43,14 +43,14 @@ inline Expr StackAlloca(std::string type, size_t num) {
// Calculate the statistics of packed function.
// These information are needed during codegen.
class BuiltinLower : public IRMutator {
class BuiltinLower : public StmtExprMutator {
public:
Stmt Build(Stmt stmt) {
stack_shape_ = Var("stack_shape", DataType::Handle());
stack_array_ = Var("stack_array", DataType::Handle());
stack_value_ = Var("stack_value", DataType::Handle());
stack_tcode_ = Var("stack_tcode", DataType::Handle());
stmt = this->Mutate(stmt);
stmt = this->VisitStmt(stmt);
if (max_shape_stack_ != 0) {
stmt = LetStmt::make(
stack_shape_, StackAlloca("shape", max_shape_stack_), stmt);
......@@ -68,8 +68,8 @@ class BuiltinLower : public IRMutator {
return stmt;
}
Stmt Mutate(Stmt stmt) final {
stmt = IRMutator::Mutate(stmt);
Stmt VisitStmt(const Stmt& s) final {
auto stmt = StmtExprMutator::VisitStmt(s);
CHECK_EQ(run_shape_stack_, 0);
CHECK_EQ(run_array_stack_, 0);
while (prep_seq_.size() != 0) {
......@@ -79,9 +79,9 @@ class BuiltinLower : public IRMutator {
return stmt;
}
Stmt Mutate_(const Allocate* op, const Stmt& s) {
Stmt VisitStmt_(const Allocate* op) {
// Lower allocate to device allocate when needed.
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<Allocate>();
if (op->new_expr.defined()) return stmt;
// Get constant allocation bound.
......@@ -141,39 +141,39 @@ class BuiltinLower : public IRMutator {
return body;
}
Stmt Mutate_(const AttrStmt* op, const Stmt &s) final {
Stmt VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::device_context_id) {
CHECK(!device_id_.defined());
device_id_ = op->value;
return Mutate(op->body);
return this->VisitStmt(op->body);
} else if (op->attr_key == attr::device_context_type) {
CHECK(!device_type_.defined());
device_type_ = op->value;
return Mutate(op->body);
return this->VisitStmt(op->body);
} else {
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
}
Expr Mutate_(const Call* op, const Expr &e) final {
Expr VisitExpr_(const Call* op) final {
if (op->is_intrinsic(intrinsic::tvm_call_packed)) {
return MakeCallPacked(op, e);
return MakeCallPacked(op);
} else if (op->is_intrinsic(intrinsic::tvm_call_trace_packed)) {
return MakeCallTracePacked(op, e);
return MakeCallTracePacked(op);
} else if (op->is_intrinsic(intrinsic::tvm_stack_make_shape)) {
return MakeShape(op, e);
return MakeShape(op);
} else if (op->is_intrinsic(intrinsic::tvm_stack_make_array)) {
return MakeArray(op, e);
return MakeArray(op);
} else if (op->is_intrinsic(intrinsic::tvm_context_id)) {
return make_zero(op->dtype);
} else {
return IRMutator::Mutate_(op, e);
return StmtExprMutator::VisitExpr_(op);
}
}
// call shape
Expr MakeShape(const Call* op, const Expr& e) {
Expr MakeShape(const Call* op) {
size_t stack_begin = run_shape_stack_;
run_shape_stack_ += op->args.size();
Expr expr = IRMutator::Mutate_(op, e);
Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Call>();
for (size_t i = 0; i < op->args.size(); ++i) {
prep_seq_.emplace_back(
......@@ -183,10 +183,10 @@ class BuiltinLower : public IRMutator {
return AddressOffset(stack_shape_, DataType::Int(64), stack_begin);
}
// make array
Expr MakeArray(const Call* op, const Expr& e) {
Expr MakeArray(const Call* op) {
size_t idx = run_array_stack_;
run_array_stack_ += 1;
Expr expr = IRMutator::Mutate_(op, e);
Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Call>();
prep_seq_.emplace_back(
TVMStructSet(stack_array_, idx, intrinsic::kArrData, op->args[0]));
......@@ -230,13 +230,13 @@ class BuiltinLower : public IRMutator {
return TVMStructGet(DataType::Handle(), stack_array_, idx, intrinsic::kArrAddr);
}
// call packed.
Expr MakeCallPacked(const Call* op, const Expr& e) {
Expr MakeCallPacked(const Call* op) {
size_t restore_shape_stack = run_shape_stack_;
size_t restore_array_stack = run_array_stack_;
size_t arg_stack_begin = run_arg_stack_;
run_arg_stack_ += op->args.size();
// Specially handle the buffer packed intrinsic
Expr expr = IRMutator::Mutate_(op, e);
Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Call>();
for (size_t i = 1; i < op->args.size(); ++i) {
Expr stack_index = ConstInt32(arg_stack_begin + i - 1);
......@@ -278,14 +278,14 @@ class BuiltinLower : public IRMutator {
packed_args, Call::Intrinsic);
}
Expr MakeCallTracePacked(const Call *op, const Expr &e) {
Expr MakeCallTracePacked(const Call *op) {
size_t restore_shape_stack = run_shape_stack_;
size_t restore_array_stack = run_array_stack_;
size_t arg_stack_begin = run_arg_stack_;
run_arg_stack_ += op->args.size();
size_t args_size = op->args.size();
CHECK_GT(args_size, 0);
Expr expr = IRMutator::Mutate_(op, e);
Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Call>();
for (size_t i = 1; i < op->args.size(); ++i) {
Expr stack_index = ConstInt32(arg_stack_begin + i - 1);
......
......@@ -26,8 +26,7 @@
// Thanks to Andrew Adams and Vinod Grover for
// explaining the concept of warp shuffle.
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_pass.h>
#include <unordered_set>
#include "ir_util.h"
......@@ -75,7 +74,7 @@ namespace ir {
// Visitor to find m in pattern
// store warp_mem[m * warp_index + (warp_size * m) * y + x]
class WarpStoreCoeffFinder : private IRVisitor {
class WarpStoreCoeffFinder : private StmtVisitor {
public:
WarpStoreCoeffFinder(const Variable* buffer,
Var warp_index,
......@@ -86,13 +85,13 @@ class WarpStoreCoeffFinder : private IRVisitor {
}
// find the warp co-efficient in the statement given the warp size
int Find(const Stmt& stmt) {
this->Visit(stmt);
this->VisitStmt(stmt);
return warp_coeff_;
}
private:
/// Visitor implementation
void Visit_(const Store *op) final {
void VisitStmt_(const Store *op) final {
if (op->buffer_var.get() == buffer_) {
if (op->value.dtype().lanes() == 1) {
UpdatePattern(op->index);
......@@ -104,7 +103,7 @@ class WarpStoreCoeffFinder : private IRVisitor {
UpdatePattern(base);
}
} else {
IRVisitor::Visit_(op);
StmtVisitor::VisitStmt_(op);
}
}
......@@ -141,14 +140,14 @@ class WarpStoreCoeffFinder : private IRVisitor {
// Visitor to find the warp index
class WarpIndexFinder : private IRVisitor {
class WarpIndexFinder : private StmtVisitor {
public:
explicit WarpIndexFinder(int warp_size)
: warp_size_(warp_size) {
}
// find the warp co-efficient in the statement given the warp size
IterVar Find(const Stmt& stmt) {
this->Visit(stmt);
this->VisitStmt(stmt);
CHECK(warp_index_.defined())
<< "Cannot find warp index(threadIdx.x) within the scope of warp memory";
return warp_index_;
......@@ -156,7 +155,7 @@ class WarpIndexFinder : private IRVisitor {
private:
/// Visitor implementation
void Visit_(const AttrStmt *op) final {
void VisitStmt_(const AttrStmt *op) final {
if (op->attr_key == attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
if (iv->thread_tag == "threadIdx.x") {
......@@ -177,7 +176,7 @@ class WarpIndexFinder : private IRVisitor {
}
}
}
IRVisitor::Visit_(op);
StmtVisitor::VisitStmt_(op);
}
// warp size
int warp_size_{0};
......@@ -185,13 +184,13 @@ class WarpIndexFinder : private IRVisitor {
IterVar warp_index_{nullptr};
};
// Mutator to change the read pattern
class WarpAccessRewriter : protected IRMutator {
class WarpAccessRewriter : protected StmtExprMutator {
public:
explicit WarpAccessRewriter(int warp_size, arith::Analyzer* analyzer)
: warp_size_(warp_size), analyzer_(analyzer) {}
// Rewrite the allocate statement which transforms
// warp memory to local memory.
Stmt Rewrite(const Allocate* op, const Stmt& stmt) {
Stmt Rewrite(const Allocate* op) {
buffer_ = op->buffer_var.get();
int alloc_size = op->constant_allocation_size();
CHECK_GT(alloc_size, 0)
......@@ -208,27 +207,27 @@ class WarpAccessRewriter : protected IRMutator {
op->dtype,
{make_const(DataType::Int(32), alloc_size / warp_size_)},
op->condition,
this->Mutate(op->body));
this->VisitStmt(op->body));
}
protected:
Expr Mutate_(const Variable* op, const Expr& expr) {
Expr Mutate_(const Variable* op) {
CHECK(op != buffer_)
<< "Cannot access address of warp memory directly";
return IRMutator::Mutate_(op, expr);
return StmtExprMutator::VisitExpr_(op);
}
Stmt Mutate_(const Store* op, const Stmt& stmt) {
Stmt VisitStmt_(const Store* op) {
if (op->buffer_var.get() == buffer_) {
Expr local_index, group;
std::tie(local_index, group) = SplitIndexByGroup(op->index);
return Store::make(op->buffer_var, op->value, local_index, op->predicate);
} else {
return IRMutator::Mutate_(op, stmt);
return StmtExprMutator::VisitStmt_(op);
}
}
Expr Mutate_(const Load* op, const Expr& expr) {
Expr Mutate_(const Load* op) {
if (op->buffer_var.get() == buffer_) {
Expr local_index, group;
std::tie(local_index, group) = SplitIndexByGroup(op->index);
......@@ -243,7 +242,7 @@ class WarpAccessRewriter : protected IRMutator {
{load_value, group},
Call::Intrinsic);
} else {
return IRMutator::Mutate_(op, expr);
return StmtExprMutator::VisitExpr_(op);
}
}
// Split the index to the two component
......@@ -297,18 +296,18 @@ class WarpAccessRewriter : protected IRMutator {
// Bind bound information of variables to make analyzer more effective
// TODO(tqchen): consider a pass to inline the bound info into the expr
// so analysis can be context independent.
class BindVarBoundInfo : public IRVisitor {
class BindVarBoundInfo : public StmtVisitor {
public:
explicit BindVarBoundInfo(arith::Analyzer* analyzer)
: analyzer_(analyzer) {}
void Visit_(const For* op) final {
void VisitStmt_(const For* op) final {
const Var& loop_var = op->loop_var;
analyzer_->Bind(loop_var, Range::make_by_min_extent(op->min, op->extent));
IRVisitor::Visit_(op);
StmtVisitor::VisitStmt_(op);
}
void Visit_(const AttrStmt* op) {
void VisitStmt_(const AttrStmt* op) {
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::virtual_thread) {
IterVar iv = Downcast<IterVar>(op->node);
......@@ -319,7 +318,7 @@ class BindVarBoundInfo : public IRVisitor {
analyzer_->Bind(iv->var, dom);
}
}
IRVisitor::Visit_(op);
StmtVisitor::VisitStmt_(op);
}
protected:
......@@ -330,7 +329,7 @@ class BindVarBoundInfo : public IRVisitor {
};
// Mutator to change the read pattern
class WarpMemoryRewriter : private IRMutator {
class WarpMemoryRewriter : private StmtMutator {
public:
explicit WarpMemoryRewriter(int warp_size)
: warp_size_(warp_size) {
......@@ -338,36 +337,37 @@ class WarpMemoryRewriter : private IRMutator {
Stmt Rewrite(Stmt stmt) {
if (warp_size_ == 1) return stmt;
BindVarBoundInfo(&analyzer_).Visit(stmt);
stmt = this->Mutate(stmt);
BindVarBoundInfo binder(&analyzer_);
binder(stmt);
stmt = operator()(std::move(stmt));
stmt = CanonicalSimplify(stmt);
return stmt;
}
private:
Stmt Mutate_(const Allocate* op, const Stmt& stmt) {
Stmt VisitStmt_(const Allocate* op) {
if (warp_buffer_.count(op->buffer_var.get())) {
WarpAccessRewriter rewriter(warp_size_, &analyzer_);
return rewriter.Rewrite(op, stmt);
return rewriter.Rewrite(op);
} else {
return IRMutator::Mutate_(op, stmt);
return StmtMutator::VisitStmt_(op);
}
}
Stmt Mutate_(const AttrStmt* op, const Stmt& stmt) {
Stmt VisitStmt_(const AttrStmt* op) {
using runtime::StorageScope;
if (op->attr_key == attr::storage_scope) {
const Variable* buf = op->node.as<Variable>();
StorageScope scope = StorageScope::make(op->value.as<StringImm>()->value);
if (scope.rank == runtime::StorageRank::kWarp) {
warp_buffer_.insert(buf);
Stmt ret = IRMutator::Mutate_(op, stmt);
Stmt ret = StmtMutator::VisitStmt_(op);
op = ret.as<AttrStmt>();
return AttrStmt::make(
op->node, op->attr_key, StringImm::make("local"), op->body);
}
}
return IRMutator::Mutate_(op, stmt);
return StmtMutator::VisitStmt_(op);
}
int warp_size_{0};
......
......@@ -22,8 +22,7 @@
*/
#include <tvm/ir_pass.h>
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/buffer.h>
#include <tvm/runtime/device_api.h>
#include <vector>
......@@ -207,29 +206,29 @@ LoweredFunc MakeAPI(Stmt body,
return f;
}
class DeviceTypeBinder: public IRMutator {
class DeviceTypeBinder: public StmtExprMutator {
public:
explicit DeviceTypeBinder(int device_type)
: device_type_(device_type) {}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
Stmt VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::device_context_type) {
if (const Variable* var = op->value.as<Variable>()) {
var_ = var;
Expr value = make_const(op->value.dtype(), device_type_);
Stmt body = IRMutator::Mutate_(op, s);
Stmt body = StmtExprMutator::VisitStmt_(op);
var_ = nullptr;
std::ostringstream os;
os << "device_type need to be " << device_type_;
return AssertStmt::make(op->value == value, os.str(), body);
}
}
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
Stmt Mutate_(const IfThenElse* op, const Stmt& s) final {
Stmt VisitStmt_(const IfThenElse* op) final {
// eager simplify if guard.
Stmt res = IRMutator::Mutate_(op, s);
Stmt res = StmtExprMutator::VisitStmt_(op);
op = res.as<IfThenElse>();
if (is_zero(op->condition)) {
if (op->else_case.defined()) return op->else_case;
......@@ -241,9 +240,9 @@ class DeviceTypeBinder: public IRMutator {
return res;
}
Expr Mutate_(const NE* op, const Expr& e) final {
Expr VisitExpr_(const NE* op) final {
// eager check NE for device check
Expr res = IRMutator::Mutate_(op, e);
Expr res = StmtExprMutator::VisitExpr_(op);
op = res.as<NE>();
if (ir::Equal(op->a, op->b)) {
return make_const(op->dtype, false);
......@@ -251,11 +250,11 @@ class DeviceTypeBinder: public IRMutator {
return res;
}
Expr Mutate_(const Variable* op, const Expr& e) final {
Expr VisitExpr_(const Variable* op) final {
if (op == var_) {
return make_const(op->dtype, device_type_);
} else {
return e;
return GetRef<Expr>(op);
}
}
......@@ -267,7 +266,7 @@ class DeviceTypeBinder: public IRMutator {
LoweredFunc BindDeviceType(LoweredFunc f,
int device_type) {
auto n = make_object<LoweredFuncNode>(*f.operator->());
n->body = DeviceTypeBinder(device_type).Mutate(n->body);
n->body = DeviceTypeBinder(device_type)(n->body);
return LoweredFunc(n);
}
......
......@@ -21,8 +21,7 @@
* \file remap_thread_axis.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_pass.h>
#include <unordered_map>
......@@ -31,7 +30,7 @@ namespace tvm {
namespace ir {
// Mutator to change the read pattern
class ThreadAxisRewriter : private IRMutator {
class ThreadAxisRewriter : private StmtExprMutator {
public:
explicit ThreadAxisRewriter(
const std::unordered_map<std::string, IterVar>& tmap)
......@@ -39,11 +38,11 @@ class ThreadAxisRewriter : private IRMutator {
}
Stmt Rewrite(Stmt stmt) {
return Mutate(stmt);
return operator()(std::move(stmt));
}
private:
Stmt Mutate_(const AttrStmt* op, const Stmt& stmt) final {
Stmt VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
CHECK_NE(iv->thread_tag.length(), 0U);
......@@ -56,18 +55,18 @@ class ThreadAxisRewriter : private IRMutator {
} else {
CHECK(vmap_[v].same_as(new_iv->var));
}
Stmt body = this->Mutate(op->body);
Stmt body = this->VisitStmt(op->body);
return AttrStmt::make(
new_iv, op->attr_key, op->value, body);
}
}
return IRMutator::Mutate_(op, stmt);
return StmtExprMutator::VisitStmt_(op);
}
Expr Mutate_(const Variable* op, const Expr& expr) final {
Expr VisitExpr_(const Variable* op) final {
auto it = vmap_.find(op);
if (it != vmap_.end()) return it->second;
return IRMutator::Mutate_(op, expr);
return StmtExprMutator::VisitExpr_(op);
}
// The thread map
const std::unordered_map<std::string, IterVar>& tmap_;
......
......@@ -23,30 +23,30 @@
*/
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_functor_ext.h>
#include <unordered_map>
namespace tvm {
namespace ir {
// Mark the statment of each stage.
class NoOpRemover : public IRMutator {
class NoOpRemover : public StmtMutator {
public:
Stmt Mutate_(const LetStmt* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt VisitStmt_(const LetStmt* op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<LetStmt>();
return is_no_op(op->body) ? MakeEvaluate(op->value) : stmt;
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
Stmt VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == "pragma_debug_skip_region") {
return MakeEvaluate(0);
}
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<AttrStmt>();
return is_no_op(op->body) ? MakeEvaluate(op->value) : stmt;
}
Stmt Mutate_(const IfThenElse* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt VisitStmt_(const IfThenElse* op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<IfThenElse>();
if (op->else_case.defined()) {
if (is_no_op(op->else_case)) {
......@@ -66,35 +66,35 @@ class NoOpRemover : public IRMutator {
}
}
}
Stmt Mutate_(const For* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt VisitStmt_(const For* op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<For>();
if (is_zero(op->extent)) {
return Evaluate::make(0);
}
return is_no_op(op->body) ? MakeEvaluate({op->min, op->extent}) : stmt;
}
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt VisitStmt_(const Allocate* op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<Allocate>();
return is_no_op(op->body) ? MakeEvaluate(op->extents) : stmt;
}
Stmt Mutate_(const ProducerConsumer* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt VisitStmt_(const ProducerConsumer* op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<ProducerConsumer>();
return is_no_op(op->body) ? op->body : stmt;
}
Stmt Mutate_(const Realize* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt VisitStmt_(const Realize* op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<Realize>();
return is_no_op(op->body) ? op->body : stmt;
}
Stmt Mutate_(const Evaluate* op, const Stmt& s) final {
if (HasSideEffect(op->value)) return s;
Stmt VisitStmt_(const Evaluate* op) final {
if (HasSideEffect(op->value)) return GetRef<Stmt>(op);
return Evaluate::make(0);
}
Stmt Mutate_(const Block* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt VisitStmt_(const Block* op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<Block>();
if (is_no_op(op->first)) {
return op->rest;
......@@ -129,7 +129,7 @@ class NoOpRemover : public IRMutator {
};
Stmt RemoveNoOp(Stmt stmt) {
return NoOpRemover().Mutate(stmt);
return NoOpRemover()(std::move(stmt));
}
} // namespace ir
} // namespace tvm
......@@ -23,7 +23,6 @@
*/
#include <tvm/ir.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
namespace tvm {
......@@ -109,10 +108,10 @@ class UnsafeExprDetector : public ExprFunctor<bool(const Expr& n)> {
}
};
class UnsafeSelectRewriter : public IRMutator {
class UnsafeSelectRewriter : public StmtExprMutator {
public:
Expr Mutate_(const Select* op, const Expr& e) {
Expr expr = IRMutator::Mutate_(op, e);
Expr VisitExpr_(const Select* op) {
Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Select>();
UnsafeExprDetector unsafe;
bool cond_is_scalar_bool = op->condition.dtype().is_bool() && op->condition.dtype().is_scalar();
......@@ -131,7 +130,7 @@ class UnsafeSelectRewriter : public IRMutator {
};
Stmt RewriteUnsafeSelect(Stmt stmt) {
return UnsafeSelectRewriter().Mutate(stmt);
return UnsafeSelectRewriter()(std::move(stmt));
}
} // namespace ir
......
......@@ -22,25 +22,24 @@
* \brief Implementation of simple passes
*/
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_pass.h>
namespace tvm {
namespace ir {
class IRSideEffect : public IRVisitor {
class IRSideEffect : public ExprVisitor {
public:
void Visit(const ObjectRef& e) final {
void VisitExpr(const Expr& e) final {
if (has_side_effect_) return;
IRVisitor::Visit(e);
ExprVisitor::VisitExpr(e);
}
void Visit_(const Call* op) final {
void VisitExpr_(const Call* op) final {
if (!op->is_pure()) {
has_side_effect_ = true; return;
} else {
IRVisitor::Visit_(op);
ExprVisitor::VisitExpr_(op);
}
}
......@@ -49,23 +48,23 @@ class IRSideEffect : public IRVisitor {
bool HasSideEffect(const Expr& e) {
IRSideEffect v;
v.Visit(e);
v(e);
return v.has_side_effect_;
}
class IRSubstitue : public IRMutator {
class IRSubstitue : public StmtExprMutator {
public:
explicit IRSubstitue(
const std::unordered_map<const Variable*, Expr>& smap)
: smap_(smap) {
}
Expr Mutate_(const Variable* op, const Expr& e) final {
Expr VisitExpr_(const Variable* op) final {
auto it = smap_.find(op);
if (it != smap_.end()) {
return it->second;
} else {
return e;
return GetRef<Expr>(op);
}
}
......@@ -76,13 +75,13 @@ class IRSubstitue : public IRMutator {
Stmt Substitute(Stmt stmt,
const std::unordered_map<const Variable*, Expr>& value_map) {
if (value_map.size() == 0) return stmt;
return IRSubstitue(value_map).Mutate(stmt);
return IRSubstitue(value_map)(std::move(stmt));
}
Expr Substitute(Expr expr,
const std::unordered_map<const Variable*, Expr>& value_map) {
if (value_map.size() == 0) return expr;
return IRSubstitue(value_map).Mutate(expr);
return IRSubstitue(value_map)(std::move(expr));
}
Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map) {
......@@ -101,20 +100,20 @@ Expr Substitute(Expr expr, const Map<Var, Expr>& value_map) {
return Substitute(expr, vmap);
}
class VarTouchVisitor : public IRVisitor {
class VarTouchVisitor : public ExprVisitor {
public:
void Visit(const ObjectRef& e) final {
void VisitExpr(const Expr& e) final {
if (use_var_) return;
IRVisitor::Visit(e);
ExprVisitor::VisitExpr(e);
}
void Visit_(const Variable* op) final {
void VisitExpr_(const Variable* op) final {
Handle(op);
}
void Visit_(const Load* op) final {
void VisitExpr_(const Load* op) final {
Handle(op->buffer_var.get());
IRVisitor::Visit_(op);
ExprVisitor::VisitExpr_(op);
}
virtual void Handle(const Variable* var) = 0;
......@@ -149,14 +148,14 @@ class ExprUseVSetVisitor : public VarTouchVisitor {
bool ExprUseVar(const Expr& e, const Var& v) {
ExprUseVarVisitor visitor(v.get());
visitor.Visit(e);
visitor(e);
return visitor.use_var_;
}
bool ExprUseVar(const Expr& e,
const std::unordered_set<const Variable*>& vset) {
ExprUseVSetVisitor visitor(vset);
visitor.Visit(e);
visitor(e);
return visitor.use_var_;
}
......
......@@ -19,22 +19,22 @@
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_functor_ext.h>
namespace tvm {
namespace ir {
class AssertSkipper : public IRMutator {
class AssertSkipper : public StmtMutator {
public:
Stmt Mutate_(const AssertStmt* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt VisitStmt_(const AssertStmt* op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<AssertStmt>();
return op->body;
}
};
Stmt SkipAssert(Stmt stmt) {
return AssertSkipper().Mutate(stmt);
return AssertSkipper()(std::move(stmt));
}
LoweredFunc SkipAssert(LoweredFunc f) {
......
......@@ -24,7 +24,7 @@
#include <tvm/ir.h>
#include <tvm/lowered_func.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/runtime/module.h>
#include <unordered_map>
......@@ -32,9 +32,9 @@ namespace tvm {
namespace ir {
// use/def analysis, also delete unreferenced lets
class IRUseDefAnalysis : public IRMutator {
class IRUseDefAnalysis : public StmtExprMutator {
public:
Stmt Mutate_(const AttrStmt *op, const Stmt& s) final {
Stmt VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
CHECK_NE(iv->thread_tag.length(), 0U);
......@@ -48,75 +48,77 @@ class IRUseDefAnalysis : public IRMutator {
Expr value = op->value;
if (visit_thread_extent_) {
value = this->Mutate(value);
value = this->VisitExpr(value);
}
Stmt body = this->VisitStmt(op->body);
if (value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
}
Stmt body = this->Mutate(op->body);
if (value.same_as(op->value) && body.same_as(op->body)) return s;
return AttrStmt::make(op->node, op->attr_key, value, body);
} else {
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
}
Stmt Mutate_(const LetStmt *op, const Stmt& s) final {
Stmt VisitStmt_(const LetStmt* op) final {
this->HandleDef(op->var.get());
Stmt body = this->Mutate(op->body);
Stmt body = this->VisitStmt(op->body);
// eliminate unreferenced let
if (use_count_.at(op->var.get()) == 0 &&
!HasSideEffect(op->value)) {
return body;
} else {
Expr value = this->Mutate(op->value);
Expr value = this->VisitExpr(op->value);
if (body.same_as(op->body) &&
value.same_as(op->value)) {
return s;
return GetRef<Stmt>(op);
} else {
return LetStmt::make(op->var, value, body);
}
}
}
Stmt Mutate_(const For *op, const Stmt& s) final {
Stmt VisitStmt_(const For* op) final {
this->HandleDef(op->loop_var.get());
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
Stmt Mutate_(const Allocate *op, const Stmt& s) final {
Stmt VisitStmt_(const Allocate* op) final {
this->HandleDef(op->buffer_var.get());
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
Stmt Mutate_(const Store *op, const Stmt& s) final {
Stmt VisitStmt_(const Store* op) final {
this->HandleUse(op->buffer_var);
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
Expr Mutate_(const Let *op, const Expr& e) final {
Expr VisitExpr_(const Let* op) final {
this->HandleDef(op->var.get());
Expr body = this->Mutate(op->body);
Expr body = this->VisitExpr(op->body);
// eliminate unreferenced let
if (use_count_.at(op->var.get()) == 0 &&
!HasSideEffect(op->value)) {
return body;
} else {
Expr value = this->Mutate(op->value);
Expr value = this->VisitExpr(op->value);
if (body.same_as(op->body) &&
value.same_as(op->value)) {
return e;
return GetRef<Expr>(op);
} else {
return Let::make(op->var, value, body);
}
}
}
Expr Mutate_(const Variable *op, const Expr& e) final {
this->HandleUse(e);
return IRMutator::Mutate_(op, e);
Expr VisitExpr_(const Variable* op) final {
this->HandleUse(GetRef<Expr>(op));
return StmtExprMutator::VisitExpr_(op);
}
Expr Mutate_(const Load *op, const Expr& e) final {
Expr VisitExpr_(const Load* op) final {
this->HandleUse(op->buffer_var);
return IRMutator::Mutate_(op, e);
return StmtExprMutator::VisitExpr_(op);
}
void HandleDef(const Variable* v) {
......@@ -154,20 +156,20 @@ class IRUseDefAnalysis : public IRMutator {
std::unordered_map<const Variable*, int> def_count_;
};
class HostDeviceSplitter : public IRMutator {
class HostDeviceSplitter : public StmtMutator {
public:
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
Stmt VisitStmt_(const Allocate* op) final {
handle_data_type_[op->buffer_var.get()] = make_const(op->dtype, 0);
return IRMutator::Mutate_(op, s);
return StmtMutator::VisitStmt_(op);
}
Stmt Mutate_(const AttrStmt *op, const Stmt& s) final {
Stmt VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::pipeline_exec_scope ||
op->attr_key == attr::device_scope) {
return SplitDeviceFunc(s);
return SplitDeviceFunc(GetRef<Stmt>(op));
}
return IRMutator::Mutate_(op, s);
return StmtMutator::VisitStmt_(op);
}
Array<LoweredFunc> Split(LoweredFunc f) {
......@@ -178,7 +180,7 @@ class HostDeviceSplitter : public IRMutator {
name_ = f->name;
ObjectPtr<LoweredFuncNode> n =
make_object<LoweredFuncNode>(*f.operator->());
n->body = this->Mutate(f->body);
n->body = operator()(f->body);
n->func_type = kHostFunc;
Array<LoweredFunc> ret{LoweredFunc(n)};
for (LoweredFunc x : device_funcs_) {
......@@ -195,7 +197,7 @@ class HostDeviceSplitter : public IRMutator {
// isolate the device function.
IRUseDefAnalysis m;
m.visit_thread_extent_ = false;
n->body = m.Mutate(body);
n->body = m(std::move(body));
n->name = os.str();
n->func_type = kDeviceFunc;
n->thread_axis = m.thread_axis_;
......@@ -243,7 +245,7 @@ Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& args) {
for (Var arg : args) {
m.use_count_[arg.get()] = 0;
}
m.Mutate(stmt);
m(stmt);
return m.undefined_;
}
......
......@@ -24,8 +24,7 @@
* \file ssa.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_pass.h>
#include <unordered_set>
#include <unordered_map>
......@@ -34,29 +33,33 @@
namespace tvm {
namespace ir {
namespace {
class IRVerifySSA final : public IRVisitor {
class IRVerifySSA final : public StmtExprVisitor {
public:
bool is_ssa{true};
void Visit(const ObjectRef& n) final {
void VisitExpr(const Expr& n) final {
if (!is_ssa) return;
IRVisitor::Visit(n);
StmtExprVisitor::VisitExpr(n);
}
void Visit_(const Let* op) final {
void VisitStmt(const Stmt& n) final {
if (!is_ssa) return;
StmtExprVisitor::VisitStmt(n);
}
void VisitExpr_(const Let* op) final {
MarkDef(op->var.get());
IRVisitor::Visit_(op);
StmtExprVisitor::VisitExpr_(op);
}
void Visit_(const LetStmt* op) final {
void VisitStmt_(const LetStmt* op) final {
MarkDef(op->var.get());
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
}
void Visit_(const For* op) final {
void VisitStmt_(const For* op) final {
MarkDef(op->loop_var.get());
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
}
void Visit_(const Allocate* op) final {
void VisitStmt_(const Allocate* op) final {
MarkDef(op->buffer_var.get());
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
}
private:
......@@ -70,31 +73,32 @@ class IRVerifySSA final : public IRVisitor {
std::unordered_map<const Variable*, int> defined_;
};
class IRConvertSSA final : public IRMutator {
class IRConvertSSA final : public StmtExprMutator {
public:
Expr Mutate_(const Variable* op, const Expr& e) final {
Expr VisitExpr_(const Variable* op) final {
if (scope_.count(op)) {
return scope_[op].back();
} else {
return e;
return GetRef<Expr>(op);
}
}
Expr Mutate_(const Let* op, const Expr& e) final {
Expr VisitExpr_(const Let* op) final {
const VarExpr& v = op->var;
if (defined_.count(v.get())) {
Expr value = IRMutator::Mutate(op->value);
Expr value = this->VisitExpr(op->value);
VarExpr new_var = Variable::make(v.dtype(), v->name_hint);
scope_[v.get()].push_back(new_var);
Expr body = IRMutator::Mutate(op->body);
Expr body = this->VisitExpr(op->body);
scope_[v.get()].pop_back();
return Let::make(new_var, value, body);
} else {
defined_.insert(v.get());
return IRMutator::Mutate_(op, e);
return StmtExprMutator::VisitExpr_(op);
}
}
Expr Mutate_(const Load* op, const Expr& e) final {
Expr expr = IRMutator::Mutate_(op, e);
Expr VisitExpr_(const Load* op) final {
Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Load>();
if (scope_.count(op->buffer_var.get())) {
return Load::make(
......@@ -104,8 +108,8 @@ class IRConvertSSA final : public IRMutator {
return expr;
}
}
Stmt Mutate_(const Store* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt VisitStmt_(const Store* op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<Store>();
if (scope_.count(op->buffer_var.get())) {
return Store::make(
......@@ -115,41 +119,41 @@ class IRConvertSSA final : public IRMutator {
return stmt;
}
}
Stmt Mutate_(const LetStmt* op, const Stmt& s) final {
Stmt VisitStmt_(const LetStmt* op) final {
const VarExpr& v = op->var;
if (defined_.count(v.get())) {
Expr value = IRMutator::Mutate(op->value);
Expr value = this->VisitExpr(op->value);
VarExpr new_var = Variable::make(v.dtype(), v->name_hint);
scope_[v.get()].push_back(new_var);
Stmt body = IRMutator::Mutate(op->body);
Stmt body = this->VisitStmt(op->body);
scope_[v.get()].pop_back();
return LetStmt::make(new_var, value, body);
} else {
defined_.insert(v.get());
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
}
Stmt Mutate_(const For* op, const Stmt& s) final {
Stmt VisitStmt_(const For* op) final {
const VarExpr& v = op->loop_var;
if (defined_.count(v.get())) {
VarExpr new_var = Variable::make(v.dtype(), v->name_hint);
scope_[v.get()].push_back(new_var);
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt stmt = StmtExprMutator::VisitStmt_(op);
scope_[v.get()].pop_back();
op = stmt.as<For>();
return For::make(
new_var, op->min, op->extent, op->for_type, op->device_api, op->body);
} else {
defined_.insert(v.get());
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
}
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
Stmt VisitStmt_(const Allocate* op) final {
const VarExpr& v = op->buffer_var;
if (defined_.count(v.get())) {
VarExpr new_var = Variable::make(v.dtype(), v->name_hint);
scope_[v.get()].push_back(new_var);
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt stmt = StmtExprMutator::VisitStmt_(op);
scope_[v.get()].pop_back();
op = stmt.as<Allocate>();
return Allocate::make(
......@@ -157,23 +161,23 @@ class IRConvertSSA final : public IRMutator {
op->body, op->new_expr, op->free_function);
} else {
defined_.insert(v.get());
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
Stmt VisitStmt_(const AttrStmt* op) final {
if (const Variable* v = op->node.as<Variable>()) {
if (op->attr_key == attr::storage_scope) {
const Allocate* alloc = op->body.as<Allocate>();
if (alloc && op->node.same_as(alloc->buffer_var)) {
Stmt new_alloc = Mutate(op->body);
if (new_alloc.same_as(op->body)) return s;
Stmt new_alloc = this->VisitStmt(op->body);
if (new_alloc.same_as(op->body)) return GetRef<Stmt>(op);
alloc = new_alloc.as<Allocate>();
CHECK(alloc);
return AttrStmt::make(
alloc->buffer_var, op->attr_key, op->value, new_alloc);
}
}
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<AttrStmt>();
if (scope_.count(v) && scope_[v].size() != 0) {
return AttrStmt::make(
......@@ -182,7 +186,7 @@ class IRConvertSSA final : public IRMutator {
return stmt;
}
} else {
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
}
......@@ -194,13 +198,13 @@ class IRConvertSSA final : public IRMutator {
} // namespace
bool VerifySSA(const Stmt& ir) {
IRVerifySSA v;
v.Visit(ir);
return v.is_ssa;
IRVerifySSA visitor;
visitor(ir);
return visitor.is_ssa;
}
Stmt ConvertSSA(Stmt stmt) {
return IRConvertSSA().Mutate(stmt);
return IRConvertSSA()(std::move(stmt));
}
} // namespace ir
......
......@@ -21,7 +21,6 @@
* \file storage_access.cc
*/
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <tvm/target_info.h>
#include <string>
#include <utility>
......@@ -32,7 +31,7 @@
namespace tvm {
namespace ir {
void StorageAccessVisitor::Visit_(const Load* op) {
void StorageAccessVisitor::VisitExpr_(const Load* op) {
const Variable* buf = op->buffer_var.as<Variable>();
StorageScope scope = GetScope(buf);
if (Enabled(buf, scope)) {
......@@ -47,10 +46,10 @@ void StorageAccessVisitor::Visit_(const Load* op) {
curr_stmt_.access.emplace_back(std::move(e));
}
// traverse child
IRVisitor::Visit_(op);
StmtExprVisitor::VisitExpr_(op);
}
void StorageAccessVisitor::Visit_(const Store* op) {
void StorageAccessVisitor::VisitStmt_(const Store* op) {
allow_append_ = true;
CHECK_EQ(curr_stmt_.access.size(), 0U);
curr_stmt_.stmt = op;
......@@ -67,7 +66,7 @@ void StorageAccessVisitor::Visit_(const Store* op) {
curr_stmt_.access.emplace_back(std::move(e));
}
// traverse child
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
// push to the scope
scope_.back().push_back(curr_stmt_);
// clear access entry.
......@@ -75,11 +74,11 @@ void StorageAccessVisitor::Visit_(const Store* op) {
allow_append_ = false;
}
void StorageAccessVisitor::Visit_(const Evaluate* op) {
void StorageAccessVisitor::VisitStmt_(const Evaluate* op) {
allow_append_ = true;
CHECK_EQ(curr_stmt_.access.size(), 0U);
curr_stmt_.stmt = op;
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
// push to the scope
if (curr_stmt_.access.size() != 0) {
scope_.back().push_back(curr_stmt_);
......@@ -88,17 +87,17 @@ void StorageAccessVisitor::Visit_(const Evaluate* op) {
allow_append_ = false;
}
void StorageAccessVisitor::Visit_(const AttrStmt* op) {
void StorageAccessVisitor::VisitStmt_(const AttrStmt* op) {
if (op->attr_key == attr::storage_scope) {
const Variable* buf = op->node.as<Variable>();
storage_scope_[buf] =
StorageScope::make(op->value.as<StringImm>()->value);
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
} else if (op->attr_key == attr::double_buffer_write) {
CHECK(double_buffer_write_ == nullptr);
double_buffer_write_ = op->node.as<Variable>();
scope_.push_back(std::vector<StmtEntry>());
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
StmtEntry s;
s.stmt = op;
s.access = Summarize(std::move(scope_.back()), nullptr);
......@@ -115,7 +114,7 @@ void StorageAccessVisitor::Visit_(const AttrStmt* op) {
} else if (op->attr_key == attr::coproc_scope) {
IterVar iv = Downcast<IterVar>(op->node);
env_threads_.push_back(iv);
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
env_threads_.CopyOnWrite()->data.pop_back();
} else if (op->attr_key == attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
......@@ -123,23 +122,23 @@ void StorageAccessVisitor::Visit_(const AttrStmt* op) {
if (!in_device_env_) {
in_device_env_ = true;
scope_.push_back(std::vector<StmtEntry>());
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
// no need to take the result as the thread barrier automatically syncs.
Summarize(std::move(scope_.back()), nullptr);
in_device_env_ = false;
scope_.pop_back();
} else {
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
}
env_threads_.CopyOnWrite()->data.pop_back();
} else {
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
}
}
void StorageAccessVisitor::Visit_(const For* op) {
void StorageAccessVisitor::VisitStmt_(const For* op) {
scope_.push_back(std::vector<StmtEntry>());
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
StmtEntry s;
s.stmt = op;
s.access = Summarize(std::move(scope_.back()), op);
......@@ -161,11 +160,11 @@ void StorageAccessVisitor::Visit_(const For* op) {
}
}
void StorageAccessVisitor::Visit_(const IfThenElse* op) {
void StorageAccessVisitor::VisitStmt_(const IfThenElse* op) {
++condition_counter_;
this->Visit(op->condition);
this->VisitExpr(op->condition);
scope_.push_back(std::vector<StmtEntry>());
this->Visit(op->then_case);
this->VisitStmt(op->then_case);
StmtEntry s;
s.stmt = op;
s.access = Summarize(std::move(scope_.back()), nullptr);
......@@ -180,10 +179,10 @@ void StorageAccessVisitor::Visit_(const IfThenElse* op) {
--condition_counter_;
}
void StorageAccessVisitor::Visit_(const Call* op) {
void StorageAccessVisitor::VisitExpr_(const Call* op) {
if (op->is_intrinsic(intrinsic::tvm_address_of)) {
const Load *l = op->args[0].as<Load>();
IRVisitor::Visit_(l);
StmtExprVisitor::VisitExpr_(l);
} else if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
CHECK_EQ(op->args.size(), 5U);
DataType dtype = op->args[0].dtype();
......@@ -211,7 +210,7 @@ void StorageAccessVisitor::Visit_(const Call* op) {
curr_stmt_.access.emplace_back(e);
}
}
IRVisitor::Visit_(op);
StmtExprVisitor::VisitExpr_(op);
} else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) {
CHECK(allow_append_);
const std::string& s = op->args[0].as<StringImm>()->value;
......@@ -224,7 +223,7 @@ void StorageAccessVisitor::Visit_(const Call* op) {
curr_stmt_.access.emplace_back(std::move(e));
}
} else {
IRVisitor::Visit_(op);
StmtExprVisitor::VisitExpr_(op);
}
}
......@@ -236,11 +235,12 @@ StorageScope StorageAccessVisitor::GetScope(const Variable* buf) const {
return it->second;
}
class StorageAccessInfoLower : public IRMutator {
class StorageAccessInfoLower : public StmtExprMutator {
public:
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
Stmt VisitStmt_(const Allocate* op) final {
// Lower allocate to device allocate when needed.
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<Allocate>();
// For special memory, remove allocate, or use head expr
auto it = storage_info_.find(op->buffer_var.get());
......@@ -259,7 +259,7 @@ class StorageAccessInfoLower : public IRMutator {
return stmt;
}
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
Stmt VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::storage_scope) {
const Variable* buf = op->node.as<Variable>();
StorageScope scope = StorageScope::make(op->value.as<StringImm>()->value);
......@@ -270,26 +270,26 @@ class StorageAccessInfoLower : public IRMutator {
CHECK(e.info.defined()) << "Cannot find memory info of " << scope.to_string();
}
storage_info_[buf] = e;
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
} else {
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
}
Expr Mutate_(const Call* op, const Expr &e) final {
Expr VisitExpr_(const Call* op) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
return MakeAccessPtr(op, e);
return MakeAccessPtr(op);
} else {
return IRMutator::Mutate_(op, e);
return StmtExprMutator::VisitExpr_(op);
}
}
private:
// tvm_access_ptr
Expr MakeAccessPtr(const Call* op, const Expr& e) {
Expr MakeAccessPtr(const Call* op) {
// Specially handle the buffer packed intrinsic
Expr expr = IRMutator::Mutate_(op, e);
Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Call>();
CHECK_EQ(op->args.size(), 5U);
DataType dtype = op->args[0].dtype();
......@@ -337,7 +337,7 @@ class StorageAccessInfoLower : public IRMutator {
};
Stmt LowerStorageAccessInfo(Stmt stmt) {
return StorageAccessInfoLower().Mutate(stmt);
return StorageAccessInfoLower()(std::move(stmt));
}
LoweredFunc LowerDeviceStorageAccessInfo(LoweredFunc f) {
......
......@@ -27,7 +27,7 @@
#include <tvm/ir.h>
#include <tvm/attrs.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
#include <vector>
#include <unordered_map>
#include "../runtime/thread_storage_scope.h"
......@@ -40,7 +40,7 @@ using runtime::StorageRank;
/*!
* \brief Base class of storage access analysis
*/
class StorageAccessVisitor : public IRVisitor {
class StorageAccessVisitor : public StmtExprVisitor {
public:
/*! \brief Storage access type */
enum AccessType {
......@@ -76,13 +76,13 @@ class StorageAccessVisitor : public IRVisitor {
std::vector<AccessEntry> access;
};
// override visitor pattern
void Visit_(const Load* op) final;
void Visit_(const Store* op) final;
void Visit_(const Evaluate* op) final;
void Visit_(const AttrStmt* op) final;
void Visit_(const For* op) final;
void Visit_(const IfThenElse* op) final;
void Visit_(const Call* op) final;
void VisitExpr_(const Load* op) final;
void VisitStmt_(const Store* op) final;
void VisitStmt_(const Evaluate* op) final;
void VisitStmt_(const AttrStmt* op) final;
void VisitStmt_(const For* op) final;
void VisitStmt_(const IfThenElse* op) final;
void VisitExpr_(const Call* op) final;
protected:
StorageAccessVisitor() {
......
......@@ -26,8 +26,7 @@
#include <tvm/ir.h>
#include <tvm/expr.h>
#include <tvm/operation.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/expr_operator.h>
#include <tvm/ir_pass.h>
#include <tvm/buffer.h>
......@@ -48,7 +47,7 @@ using runtime::StorageScope;
using runtime::ThreadScope;
using intrinsic::tvm_address_of;
class StorageFlattener : public IRMutator {
class StorageFlattener : public StmtExprMutator {
public:
explicit StorageFlattener(Map<Tensor, Buffer> extern_buffer,
int cache_line_size, bool create_bound_attributes,
......@@ -64,8 +63,8 @@ class StorageFlattener : public IRMutator {
cache_line_size_ = cache_line_size;
}
Stmt Mutate_(const Store* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt VisitStmt_(const Store* op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<Store>();
auto it = var_remap_.find(op->buffer_var.get());
if (it != var_remap_.end() &&
......@@ -78,14 +77,14 @@ class StorageFlattener : public IRMutator {
}
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
Stmt VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::realize_scope) {
storage_scope_[op->node.get()] = op->value.as<StringImm>()->value;
return this->Mutate(op->body);
return this->VisitStmt(op->body);
} else if (op->attr_key == attr::double_buffer_scope &&
op->node->IsInstance<OperationNode>()) {
Operation func = Downcast<Operation>(op->node);
Stmt body = Mutate(op->body);
Stmt body = this->VisitStmt(op->body);
for (int i = 0; i < func->num_outputs(); ++i) {
TensorKey key{func, i};
auto it = buf_map_.find(key);
......@@ -99,7 +98,7 @@ class StorageFlattener : public IRMutator {
IterVar iv = Downcast<IterVar>(op->node);
ThreadScope ts = ThreadScope::make(iv->thread_tag);
curr_thread_scope_.push_back(ts);
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt stmt = StmtExprMutator::VisitStmt_(op);
curr_thread_scope_.pop_back();
return stmt;
} else if (op->attr_key == attr::buffer_bind_scope) {
......@@ -116,17 +115,17 @@ class StorageFlattener : public IRMutator {
}
vinfo[dim].align_factor = tuple->args[1].as<IntImm>()->value;
vinfo[dim].align_offset = tuple->args[2].as<IntImm>()->value;
return this->Mutate(op->body);
return this->VisitStmt(op->body);
} else if (op->attr_key == attr::opengl_stage_scope) {
is_opengl_ = true;
}
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
Stmt Mutate_(const Provide* op, const Stmt& s) final {
Stmt VisitStmt_(const Provide* op) final {
if (create_bound_attributes_)
shape_collector_.clear();
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<Provide>();
TensorKey key{op->func, op->value_index};
auto it = buf_map_.find(key);
......@@ -159,11 +158,11 @@ class StorageFlattener : public IRMutator {
}
}
Stmt Mutate_(const Realize* op, const Stmt& s) final {
Stmt VisitStmt_(const Realize* op) final {
TensorKey key{op->func, op->value_index};
if (buf_map_.count(key)) {
CHECK(buf_map_.at(key).external);
return this->Mutate(op->body);
return this->VisitStmt(op->body);
} else {
// create a buffer entry
BufferEntry e;
......@@ -226,7 +225,7 @@ class StorageFlattener : public IRMutator {
align, 0, kDefault);
buf_map_[key] = e;
Stmt body = this->Mutate(op->body);
Stmt body = this->VisitStmt(op->body);
buf_map_[key].released = true;
Stmt ret;
......@@ -263,8 +262,8 @@ class StorageFlattener : public IRMutator {
}
}
Expr Mutate_(const Load* op, const Expr& e) final {
Expr expr = IRMutator::Mutate_(op, e);
Expr VisitExpr_(const Load* op) final {
Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Load>();
auto it = var_remap_.find(op->buffer_var.get());
if (it != var_remap_.end() &&
......@@ -277,17 +276,17 @@ class StorageFlattener : public IRMutator {
}
}
Expr Mutate_(const Variable* op, const Expr& e) final {
Expr VisitExpr_(const Variable* op) final {
auto it = var_remap_.find(op);
if (it != var_remap_.end()) {
return it->second;
} else {
return e;
return GetRef<Expr>(op);
}
}
Expr Mutate_(const Call* op, const Expr& olde) final {
Expr expr = IRMutator::Mutate_(op, olde);
Expr VisitExpr_(const Call* op) final {
Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Call>();
if (op != nullptr && op->call_type == Call::Halide) {
TensorKey key{op->func, op->value_index};
......@@ -308,8 +307,8 @@ class StorageFlattener : public IRMutator {
}
}
Stmt Mutate_(const Prefetch *op, const Stmt &s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt VisitStmt_(const Prefetch *op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<Prefetch>();
CHECK(op != nullptr);
TensorKey key{op->func, op->value_index};
......@@ -443,7 +442,7 @@ class StorageFlattener : public IRMutator {
// Apply the remaps
Stmt body = MergeNest(binder.asserts(), op->body);
body = MergeNest(binder.init_nest(), body);
body = this->Mutate(body);
body = this->VisitStmt(body);
// remove the binds
for (const Var& v : binder.defs()) {
var_remap_.erase(v.get());
......@@ -531,10 +530,10 @@ class StorageFlattener : public IRMutator {
Stmt StorageFlatten(Stmt stmt, Map<Tensor, Buffer> extern_buffer,
int cache_line_size, bool create_bound_attributes) {
IRVisitorWithAnalyzer bounded_analyzer;
bounded_analyzer.Visit(stmt);
bounded_analyzer(stmt);
stmt =
StorageFlattener(extern_buffer, cache_line_size,
create_bound_attributes, &bounded_analyzer).Mutate(stmt);
create_bound_attributes, &bounded_analyzer)(std::move(stmt));
return stmt;
}
......
......@@ -24,8 +24,7 @@
*/
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/target_info.h>
#include <map>
#include <unordered_set>
......@@ -54,7 +53,7 @@ using runtime::StorageScope;
// The storage need to be kept alive between allocate and last access.
// The free point is only inserted at the same scope of allocate.
//
class LinearAccessPatternFinder final : public IRVisitor {
class LinearAccessPatternFinder final : public StmtExprVisitor {
public:
/*! \brief record the touch hist of statment. */
struct StmtEntry {
......@@ -78,7 +77,7 @@ class LinearAccessPatternFinder final : public IRVisitor {
const Allocate* alloc{nullptr};
};
void Visit_(const Allocate* op) final {
void VisitStmt_(const Allocate* op) final {
size_t level = scope_.size();
const Variable* buf = op->buffer_var.get();
auto it = alloc_info_.find(buf);
......@@ -86,12 +85,12 @@ class LinearAccessPatternFinder final : public IRVisitor {
CHECK(it->second.alloc == nullptr);
it->second.alloc = op;
it->second.level = level;
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
}
void Visit_(const Store* op) final {
void VisitStmt_(const Store* op) final {
scope_.push_back(StmtEntry());
// visit subexpr
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
// Add write access.
const Variable* buf = op->buffer_var.get();
auto it = alloc_info_.find(buf);
......@@ -106,10 +105,10 @@ class LinearAccessPatternFinder final : public IRVisitor {
linear_seq_.push_back(e);
}
}
void Visit_(const Evaluate* op) final {
void VisitStmt_(const Evaluate* op) final {
scope_.push_back(StmtEntry());
// visit subexpr
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
StmtEntry e = scope_.back();
scope_.pop_back();
if (e.touched.size() != 0) {
......@@ -117,9 +116,9 @@ class LinearAccessPatternFinder final : public IRVisitor {
linear_seq_.push_back(e);
}
}
void Visit_(const Load* op) final {
void VisitExpr_(const Load* op) final {
// Add write access.
IRVisitor::Visit_(op);
StmtExprVisitor::VisitExpr_(op);
const Variable* buf = op->buffer_var.get();
auto it = alloc_info_.find(buf);
if (it != alloc_info_.end() && it->second.alloc) {
......@@ -128,15 +127,15 @@ class LinearAccessPatternFinder final : public IRVisitor {
scope_[it->second.level].touched.push_back(buf);
}
}
void Visit_(const Call* op) final {
void VisitExpr_(const Call* op) final {
if (op->is_intrinsic(intrinsic::tvm_address_of)) {
const Load* l = op->args[0].as<Load>();
this->Visit(l->index);
this->VisitExpr(l->index);
} else {
IRVisitor::Visit_(op);
StmtExprVisitor::VisitExpr_(op);
}
}
void Visit_(const Variable* buf) final {
void VisitExpr_(const Variable* buf) final {
// Directly reference to the variable count as a read.
auto it = alloc_info_.find(buf);
if (it != alloc_info_.end() && it->second.alloc) {
......@@ -153,7 +152,7 @@ class LinearAccessPatternFinder final : public IRVisitor {
int64_t begin_index = static_cast<int64_t>(linear_seq_.size());
// before scope.
linear_seq_.push_back(e);
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
// after scope.
e.touched = std::move(scope_.back().touched);
scope_.pop_back();
......@@ -165,7 +164,7 @@ class LinearAccessPatternFinder final : public IRVisitor {
CHECK_NE(end_index, 0U);
linear_seq_[begin_index].scope_pair_offset = end_index - begin_index;
}
void Visit_(const AttrStmt* op) final {
void VisitStmt_(const AttrStmt* op) final {
// Only record the outer most thread extent.
if (op->attr_key == attr::thread_extent && !in_thread_env_) {
in_thread_env_ = true;
......@@ -179,20 +178,20 @@ class LinearAccessPatternFinder final : public IRVisitor {
const Variable* buf = op->node.as<Variable>();
alloc_info_[buf].storage_scope =
StorageScope::make(op->value.as<StringImm>()->value);
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
} else {
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
}
}
void Visit_(const IfThenElse* op) final {
void VisitStmt_(const IfThenElse* op) final {
VisitNewScope(op);
}
void Visit_(const For* op) final {
void VisitStmt_(const For* op) final {
VisitNewScope(op);
}
void Visit_(const AssertStmt* op) final {
void VisitStmt_(const AssertStmt* op) final {
VisitNewScope(op);
}
......@@ -234,7 +233,7 @@ class LinearAccessPatternFinder final : public IRVisitor {
//
// The code after inplace transformation is no longer idempotent.
//
class InplaceOpVerifier : public IRVisitor {
class InplaceOpVerifier : public StmtExprVisitor {
public:
bool Check(const Object* stmt,
const Variable* dst,
......@@ -243,58 +242,62 @@ class InplaceOpVerifier : public IRVisitor {
src_ = src;
result_ = true;
if (stmt->IsInstance<AttrStmt>()) {
Visit_(static_cast<const AttrStmt*>(stmt));
VisitStmt_(static_cast<const AttrStmt*>(stmt));
} else if (stmt->IsInstance<For>()) {
Visit_(static_cast<const For*>(stmt));
VisitStmt_(static_cast<const For*>(stmt));
} else if (stmt->IsInstance<IfThenElse>()) {
Visit_(static_cast<const IfThenElse*>(stmt));
VisitStmt_(static_cast<const IfThenElse*>(stmt));
} else if (stmt->IsInstance<Store>()) {
Visit_(static_cast<const Store*>(stmt));
VisitStmt_(static_cast<const Store*>(stmt));
} else {
return false;
}
return result_;
}
using IRVisitor::Visit_;
using StmtExprVisitor::VisitStmt_;
void Visit(const ObjectRef& e) final {
void VisitStmt(const Stmt& n) final {
if (!result_) return;
IRVisitor::Visit(e);
StmtExprVisitor::VisitStmt(n);
}
void VisitExpr(const Expr& n) final {
if (!result_) return;
StmtExprVisitor::VisitExpr(n);
}
void Visit_(const Variable* op) final {
void VisitExpr_(const Variable* op) final {
// assume all opaque access is unsafe
if (op == dst_ || op == src_) {
result_ = false; return;
}
}
void Visit_(const Store* op) final {
void VisitStmt_(const Store* op) final {
++mem_nest_;
this->Visit(op->index);
this->VisitExpr(op->index);
--mem_nest_;
if (op->buffer_var.get() == dst_) {
store_ = op;
this->Visit(op->value);
this->Visit(op->predicate);
this->VisitExpr(op->value);
this->VisitExpr(op->predicate);
store_ = nullptr;
} else {
this->Visit(op->value);
this->Visit(op->predicate);
this->VisitExpr(op->value);
this->VisitExpr(op->predicate);
}
}
void Visit_(const AttrStmt* op) final {
void VisitStmt_(const AttrStmt* op) final {
// always reject extern code
if (op->attr_key == attr::extern_scope ||
op->attr_key == attr::volatile_scope) {
result_ = false; return;
}
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
}
void Visit_(const Load* op) final {
void VisitExpr_(const Load* op) final {
const Variable* buf = op->buffer_var.get();
// cannot read from dst_ (no reduction)
if (buf == dst_) {
......@@ -312,7 +315,7 @@ class InplaceOpVerifier : public IRVisitor {
}
}
++mem_nest_;
IRVisitor::Visit_(op);
StmtExprVisitor::VisitExpr_(op);
--mem_nest_;
}
......@@ -332,7 +335,7 @@ class InplaceOpVerifier : public IRVisitor {
};
// Planner to plan and rewrite memory allocation.
class StoragePlanRewriter : public IRMutator {
class StoragePlanRewriter : public StmtExprMutator {
public:
using StmtEntry = LinearAccessPatternFinder::StmtEntry;
using AllocEntry = LinearAccessPatternFinder::AllocEntry;
......@@ -341,12 +344,12 @@ class StoragePlanRewriter : public IRMutator {
detect_inplace_ = detect_inplace;
// plan the rewrite
LinearAccessPatternFinder finder;
finder.Visit(stmt);
finder(stmt);
this->LivenessAnalysis(finder.linear_seq_);
this->PlanMemory(finder.linear_seq_, finder.alloc_info_);
this->PrepareNewAlloc();
// start rewrite
stmt = this->Mutate(stmt);
stmt = operator()(std::move(stmt));
if (attach_map_.count(nullptr)) {
std::vector<Stmt> nest;
for (StorageEntry* e : attach_map_.at(nullptr)) {
......@@ -363,8 +366,8 @@ class StoragePlanRewriter : public IRMutator {
}
return stmt;
}
Stmt Mutate_(const Store* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt VisitStmt_(const Store* op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<Store>();
auto it = alloc_map_.find(op->buffer_var.get());
if (it == alloc_map_.end()) return stmt;
......@@ -373,8 +376,8 @@ class StoragePlanRewriter : public IRMutator {
RemapIndex(op->value.dtype(), op->index, it->second),
op->predicate);
}
Expr Mutate_(const Load* op, const Expr& e) final {
Expr expr = IRMutator::Mutate_(op, e);
Expr VisitExpr_(const Load* op) final {
Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Load>();
auto it = alloc_map_.find(op->buffer_var.get());
if (it == alloc_map_.end()) return expr;
......@@ -383,7 +386,7 @@ class StoragePlanRewriter : public IRMutator {
RemapIndex(op->dtype, op->index, it->second),
op->predicate);
}
Expr Mutate_(const Variable* op, const Expr& e) final {
Expr VisitExpr_(const Variable* op) final {
auto it = alloc_map_.find(op);
if (it != alloc_map_.end()) {
if (it->second->bits_offset != 0) {
......@@ -391,19 +394,21 @@ class StoragePlanRewriter : public IRMutator {
}
return it->second->alloc_var;
} else {
return e;
return GetRef<Expr>(op);
}
}
Expr Mutate_(const Call* op, const Expr& e) final {
Expr VisitExpr_(const Call* op) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
CHECK_EQ(op->args.size(), 5U);
DataType dtype = op->args[0].dtype();
const Variable* buffer = op->args[1].as<Variable>();
auto it = alloc_map_.find(buffer);
if (it == alloc_map_.end()) return IRMutator::Mutate_(op, e);
if (it == alloc_map_.end()) {
return StmtExprMutator::VisitExpr_(op);
}
const StorageEntry* se = it->second;
Expr offset = Mutate(op->args[2]);
Expr extent = Mutate(op->args[3]);
Expr offset = this->VisitExpr(op->args[2]);
Expr extent = this->VisitExpr(op->args[3]);
uint64_t elem_bits = dtype.bits() * dtype.lanes();
CHECK_EQ(se->bits_offset % elem_bits, 0U);
if (se->bits_offset != 0) {
......@@ -414,56 +419,56 @@ class StoragePlanRewriter : public IRMutator {
{op->args[0], se->alloc_var, offset, extent, op->args[4]},
op->call_type);
} else {
return IRMutator::Mutate_(op, e);
return StmtExprMutator::VisitExpr_(op);
}
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
Stmt VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::storage_scope) {
return this->Mutate(op->body);
return this->VisitStmt(op->body);
} else if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::virtual_thread ||
attr::IsPragmaKey(op->attr_key)) {
// remake all the allocation at the attach scope.
if (attach_map_.count(op)) {
auto& svec = attach_map_[op];
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<AttrStmt>();
return AttrStmt::make(
op->node, op->attr_key, op->value,
MakeAttach(svec, op->body));
} else {
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
} else if (op->attr_key == attr::volatile_scope) {
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<AttrStmt>();
auto it = alloc_map_.find(op->node.as<Variable>());
if (it == alloc_map_.end()) return stmt;
return AttrStmt::make(
it->second->alloc_var, op->attr_key, op->value, op->body);
} else {
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
}
Stmt Mutate_(const For* op, const Stmt& s) final {
Stmt VisitStmt_(const For* op) final {
CHECK(op->for_type != ForType::Vectorized)
<< "VectorizeLoop before LiftStorageAlloc";
// remake all the allocation at the attach scope.
if (attach_map_.count(op)) {
auto& svec = attach_map_[op];
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<For>();
return For::make(
op->loop_var, op->min, op->extent, op->for_type, op->device_api,
MakeAttach(svec, op->body));
} else {
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
}
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
return this->Mutate(op->body);
Stmt VisitStmt_(const Allocate* op) final {
return this->VisitStmt(op->body);
}
private:
......@@ -929,28 +934,28 @@ class StoragePlanRewriter : public IRMutator {
// Turn alloc into vector alloc
// if all its access is the same vector type.
class VectorAllocRewriter : public IRMutator {
class VectorAllocRewriter : public StmtExprMutator {
public:
Expr Mutate_(const Load* op, const Expr& e) final {
Expr VisitExpr_(const Load* op) final {
UpdateTypeMap(op->buffer_var.get(), op->dtype);
return IRMutator::Mutate_(op, e);
return StmtExprMutator::VisitExpr_(op);
}
Stmt Mutate_(const Store* op, const Stmt& s) final {
Stmt VisitStmt_(const Store* op) final {
UpdateTypeMap(op->buffer_var.get(), op->value.dtype());
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
Expr Mutate_(const Call* op, const Expr& e) final {
Expr VisitExpr_(const Call* op) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
DataType dtype = op->args[0].dtype();
const Variable* buffer = op->args[1].as<Variable>();
UpdateTypeMap(buffer, dtype);
}
return IRMutator::Mutate_(op, e);
return StmtExprMutator::VisitExpr_(op);
}
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt VisitStmt_(const Allocate* op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<Allocate>();
const auto& tvec = acc_map_[op->buffer_var.get()];
......@@ -989,7 +994,7 @@ class VectorAllocRewriter : public IRMutator {
LoweredFunc PointerValueTypeRewrite(LoweredFunc f) {
auto n = make_object<LoweredFuncNode>(*f.operator->());
VectorAllocRewriter rewriter;
n->body = rewriter.Mutate(n->body);
n->body = rewriter(n->body);
for (Var arg : f->args) {
if (arg.dtype().is_handle()) {
const auto& tvec = rewriter.acc_map_[arg.get()];
......@@ -1010,8 +1015,8 @@ LoweredFunc PointerValueTypeRewrite(LoweredFunc f) {
}
Stmt StorageRewrite(Stmt stmt) {
stmt = StoragePlanRewriter().Rewrite(stmt, true);
return VectorAllocRewriter().Mutate(stmt);
stmt = StoragePlanRewriter().Rewrite(std::move(stmt), true);
return VectorAllocRewriter()(std::move(stmt));
}
} // namespace ir
} // namespace tvm
......@@ -22,8 +22,7 @@
*/
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
#include <unordered_map>
#include <unordered_set>
#include "ir_util.h"
......@@ -197,13 +196,13 @@ class ThreadSyncPlanner : public StorageAccessVisitor {
StorageScope sync_scope_;
};
class ThreadSyncInserter : public IRMutator {
class ThreadSyncInserter : public StmtExprMutator {
public:
ThreadSyncInserter(StorageScope sync_scope,
const std::unordered_set<const Object*>& syncs)
: sync_scope_(sync_scope), syncs_(syncs) {}
Stmt Mutate(Stmt stmt) final {
Stmt VisitStmt(const Stmt& stmt) final {
if (syncs_.size() == 0) return stmt;
if (syncs_.count(stmt.get())) {
Stmt barrier;
......@@ -216,33 +215,33 @@ class ThreadSyncInserter : public IRMutator {
Call::Intrinsic));
}
// Mutate after query, to avoid stmt change.
stmt = IRMutator::Mutate(stmt);
stmt = Block::make(barrier, stmt);
auto ret = StmtExprMutator::VisitStmt(stmt);
ret = Block::make(barrier, ret);
return ret;
} else {
stmt = IRMutator::Mutate(stmt);
return StmtExprMutator::VisitStmt(stmt);
}
return stmt;
}
Expr Mutate_(const Load* op, const Expr& e) final {
Expr VisitExpr_(const Load* op) final {
if (sync_scope_.rank == StorageRank::kGlobal &&
GetScope(op->buffer_var.get()).rank == StorageRank::kGlobal) {
++rw_stats_[op->buffer_var].read_count;
}
return IRMutator::Mutate_(op, e);
return StmtExprMutator::VisitExpr_(op);
}
Stmt Mutate_(const Store* op, const Stmt& s) final {
Stmt VisitStmt_(const Store* op) final {
if (sync_scope_.rank == StorageRank::kGlobal &&
GetScope(op->buffer_var.get()).rank == StorageRank::kGlobal) {
++rw_stats_[op->buffer_var].write_count;
}
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
Stmt VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::thread_extent) {
bool temp = true;
std::swap(temp, in_thread_env_);
thread_extents_.push_back(op);
Stmt ret = IRMutator::Mutate_(op, s);
Stmt ret = StmtExprMutator::VisitStmt_(op);
thread_extents_.pop_back();
std::swap(temp, in_thread_env_);
// first thread scope.
......@@ -256,15 +255,15 @@ class ThreadSyncInserter : public IRMutator {
const Variable* buf = op->node.as<Variable>();
storage_scope_[buf] =
StorageScope::make(op->value.as<StringImm>()->value);
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
} else {
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
}
Expr Mutate_(const Call* op, const Expr& e) final {
Expr VisitExpr_(const Call* op) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
Expr expr = IRMutator::Mutate_(op, e);
Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Call>();
CHECK_EQ(op->args.size(), 5U);
const Variable* buffer_var = op->args[1].as<Variable>();
......@@ -280,7 +279,7 @@ class ThreadSyncInserter : public IRMutator {
}
return expr;
} else {
return IRMutator::Mutate_(op, e);
return StmtExprMutator::VisitExpr_(op);
}
}
......@@ -363,8 +362,8 @@ class ThreadSyncInserter : public IRMutator {
Stmt ThreadSync(Stmt stmt, std::string storage_scope) {
StorageScope sync_scope = StorageScope::make(storage_scope);
ThreadSyncPlanner planner(sync_scope);
planner.Visit(stmt);
return ThreadSyncInserter(sync_scope, planner.syncs_inserted_).Mutate(stmt);
planner(stmt);
return ThreadSyncInserter(sync_scope, planner.syncs_inserted_)(std::move(stmt));
}
LoweredFunc ThreadSync(LoweredFunc f, std::string storage_scope) {
......
......@@ -24,8 +24,7 @@
#include <tvm/ir.h>
#include <tvm/expr.h>
#include <tvm/operation.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/expr_operator.h>
#include <tvm/ir_pass.h>
#include <tvm/buffer.h>
......@@ -73,7 +72,7 @@ Expr unpack_type_cast(const Expr &input, const DataType &target_type) {
// MMAMatcher matches C = Cast(A)*Cast(B)+C,
// where A & B are fp16/int8 local buffers,
// and C is fp32/int32 local buffer.
class MMAMatcher: public IRVisitor {
class MMAMatcher: public StmtVisitor {
public:
explicit MMAMatcher(Map<Tensor, Buffer> extern_buffer) {
for (auto kv : extern_buffer) {
......@@ -84,22 +83,21 @@ class MMAMatcher: public IRVisitor {
buf_map_[TensorKey{kv.first->op, kv.first->value_index}] = bi;
}
}
using IRVisitor::Visit_;
void Visit_(const AttrStmt* op) final {
void VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::pragma_tensor_core) {
tensor_core_on_ = true;
IRVisitor::Visit_(op);
StmtVisitor::VisitStmt_(op);
} else if (op->attr_key == attr::realize_scope) {
storage_scope_[op->node.get()] = op->value.as<StringImm>()->value;
Visit(op->body);
this->VisitStmt(op->body);
} else {
IRVisitor::Visit_(op);
StmtVisitor::VisitStmt_(op);
}
}
void Visit_(const Provide* op) final {
IRVisitor::Visit_(op);
void VisitStmt_(const Provide* op) final {
StmtVisitor::VisitStmt_(op);
auto it = buf_map_.find(TensorKey{op->func, op->value_index});
if (it == buf_map_.end()) {
return;
......@@ -113,19 +111,19 @@ class MMAMatcher: public IRVisitor {
}
}
void Visit_(const Realize* op) final {
void VisitStmt_(const Realize* op) final {
TensorKey key{op->func, op->value_index};
if (buf_map_.count(key)) {
if (!buf_map_.at(key).external) {
return;
}
Visit(op->body);
this->VisitStmt(op->body);
} else {
BufferInfo bi;
bi.name = key.GetName();
bi.dtype = op->dtype;
buf_map_[key] = bi;
Visit(op->body);
this->VisitStmt(op->body);
buf_map_[key].released = true;
}
}
......@@ -236,12 +234,11 @@ class MMAMatcher: public IRVisitor {
// BodyVisitor visits the body stmt of original ComputeOp
// to get the access indices of input matrices,
// if it is recognized as matrix multiply.
class BodyVisitor : public IRVisitor {
class BodyVisitor : public StmtExprVisitor {
public:
BodyVisitor() {}
using IRVisitor::Visit_;
void Visit_(const Reduce* op) final {
void VisitExpr_(const Reduce* op) final {
auto* comm_add = op->combiner->result[0].as<Add>();
if (comm_add == nullptr || op->combiner->result.size() > 1) {
return;
......@@ -254,12 +251,12 @@ class BodyVisitor : public IRVisitor {
}
tensorcore_candidate_ = true;
IRVisitor::Visit(source);
StmtExprVisitor::VisitExpr(source);
}
}
void Visit_(const Call* op) final {
IRVisitor::Visit_(op);
void VisitExpr_(const Call* op) final {
StmtExprVisitor::VisitExpr_(op);
args_.insert(std::make_pair(op->name, op->args));
}
......@@ -298,7 +295,7 @@ class ScheduleAnalyser {
BodyVisitor body_visitor;
for (Expr expr : compute->body) {
body_visitor.Visit(expr);
body_visitor(expr);
}
if (!body_visitor.tensorcore_candidate_) {
continue;
......@@ -370,12 +367,11 @@ class ScheduleAnalyser {
// IndexVisitor visits access index of fragment
// to record variable for loop scaling
class IndexVisitor : public IRVisitor {
class IndexVisitor : public StmtExprVisitor {
public:
IndexVisitor() {}
using IRVisitor::Visit_;
void Visit_(const Variable* op) final {
void VisitExpr_(const Variable* op) final {
loop_scaling_.insert(std::make_pair(op, scaling_factor_));
}
......@@ -389,7 +385,7 @@ class IndexVisitor : public IRVisitor {
// BufferAnalyser gets buffer info,
// e.g. thread tile and warp tile, for TensorCore CodeGen
class BufferAnalyser : public IRVisitor {
class BufferAnalyser : public StmtExprVisitor {
public:
explicit BufferAnalyser(Map<Tensor, Buffer> extern_buffer,
const ScheduleAnalyser &schedule_analyser,
......@@ -407,9 +403,8 @@ class BufferAnalyser : public IRVisitor {
buf_map_[TensorKey{kv.first->op, kv.first->value_index}] = bi;
}
}
using IRVisitor::Visit_;
void Visit_(const AttrStmt* op) final {
void VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::thread_extent) {
if (const IntImm* value = op->value.as<IntImm>()) {
thread_extent_.insert(
......@@ -417,10 +412,10 @@ class BufferAnalyser : public IRVisitor {
op->node.as<IterVarNode>()->var->name_hint,
value->value));
}
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
} else if (op->attr_key == attr::realize_scope) {
storage_scope_[op->node.get()] = op->value.as<StringImm>()->value;
Visit(op->body);
this->VisitStmt(op->body);
} else if (op->attr_key == attr::buffer_dim_align) {
Tensor tensor = Downcast<Tensor>(op->node);
const Call* tuple = op->value.as<Call>();
......@@ -432,14 +427,14 @@ class BufferAnalyser : public IRVisitor {
}
vinfo[dim].align_factor = tuple->args[1].as<IntImm>()->value;
vinfo[dim].align_offset = tuple->args[2].as<IntImm>()->value;
Visit(op->body);
this->VisitStmt(op->body);
} else {
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
}
}
void Visit_(const Provide* op) final {
IRVisitor::Visit_(op);
void VisitStmt_(const Provide* op) final {
StmtExprVisitor::VisitStmt_(op);
TensorKey key{op->func, op->value_index};
auto it = buf_map_.find(key);
CHECK(it != buf_map_.end())
......@@ -503,7 +498,7 @@ class BufferAnalyser : public IRVisitor {
}
auto index = rel_index[i];
auto simplified_index = ir::Simplify(index);
index_visitor.Visit(simplified_index);
index_visitor(simplified_index);
}
std::string input_name = simplify_name(bi.name);
......@@ -550,8 +545,8 @@ class BufferAnalyser : public IRVisitor {
}
}
void Visit_(const Call* op) final {
IRVisitor::Visit_(op);
void VisitExpr_(const Call* op) final {
StmtExprVisitor::VisitExpr_(op);
if (op->call_type == Call::Halide) {
TensorKey key{op->func, op->value_index};
auto it = buf_map_.find(key);
......@@ -606,16 +601,16 @@ class BufferAnalyser : public IRVisitor {
}
auto index = rel_index[i];
auto simplified_index = ir::Simplify(index);
index_visitor.Visit(simplified_index);
index_visitor(simplified_index);
}
}
}
void Visit_(const Realize* op) final {
void VisitStmt_(const Realize* op) final {
TensorKey key{op->func, op->value_index};
if (buf_map_.count(key)) {
CHECK(buf_map_.at(key).external);
Visit(op->body);
this->VisitStmt(op->body);
} else {
// create a buffer entry
BufferInfo bi;
......@@ -653,7 +648,7 @@ class BufferAnalyser : public IRVisitor {
bi.shape = shape;
buf_map_[key] = bi;
Visit(op->body);
this->VisitStmt(op->body);
buf_map_[key].released = true;
}
}
......@@ -761,12 +756,12 @@ class BufferAnalyser : public IRVisitor {
};
// ThreadIdxMutator does the thread index unification inside a warp
class ThreadIdxMutator : public IRMutator {
class ThreadIdxMutator : public StmtExprMutator {
public:
explicit ThreadIdxMutator(Expr warp_y): warp_y_(warp_y) {}
Expr Mutate_(const Variable* op, const Expr& olde) final {
Expr expr = IRMutator::Mutate_(op, olde);
Expr VisitExpr_(const Variable* op) final {
Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Variable>();
if (op != nullptr) {
if (op->name_hint == "threadIdx.x") {
......@@ -788,7 +783,7 @@ class ThreadIdxMutator : public IRMutator {
// TensorCoreIRMutator mutates the AST for TensorCore CodeGen
// based on tensor core intrinsics
class TensorCoreIRMutator : public IRMutator {
class TensorCoreIRMutator : public StmtExprMutator {
public:
explicit TensorCoreIRMutator(const ScheduleAnalyser &schedule_analyser,
const BufferAnalyser &buffer_analyser)
......@@ -803,10 +798,10 @@ class TensorCoreIRMutator : public IRMutator {
warp_tile_(buffer_analyser.warp_tile_),
warp_threads_y_(buffer_analyser.warp_threads_y_) {}
Stmt Mutate_(const Realize* op, const Stmt& s) final {
Stmt VisitStmt_(const Realize* op) final {
TensorKey key{op->func, op->value_index};
bounds_[key] = op->bounds;
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<Realize>();
if (op != nullptr) {
if (!frag_reg_.count(key.GetName())) {
......@@ -833,8 +828,8 @@ class TensorCoreIRMutator : public IRMutator {
return stmt;
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt VisitStmt_(const AttrStmt* op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
if (op->attr_key == attr::realize_scope) {
auto node = op->node.as<OperationNode>();
if (node != nullptr) {
......@@ -846,7 +841,7 @@ class TensorCoreIRMutator : public IRMutator {
CHECK(it != matrix_abc_.end())
<< "Cannot find matrix info for " << node->name;
auto matrix_abc = "wmma." + it->second;
Stmt body = Mutate(op->body);
Stmt body = this->VisitStmt(op->body);
return AttrStmt::make(op->node,
op->attr_key,
matrix_abc,
......@@ -856,8 +851,8 @@ class TensorCoreIRMutator : public IRMutator {
return stmt;
}
Stmt Mutate_(const Provide* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt VisitStmt_(const Provide* op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
auto it = mma_sync_.find(op);
if (it != mma_sync_.end()) {
const auto &operands = it->second;
......@@ -941,7 +936,7 @@ class TensorCoreIRMutator : public IRMutator {
// thread index unification inside a warp
Expr warp_y = IntImm::make(DataType::Int(32), warp_threads_y_);
ThreadIdxMutator thread_idx_mutator(warp_y);
Expr mutated_value = thread_idx_mutator.Mutate(op->value);
Expr mutated_value = thread_idx_mutator(op->value);
Expr src = Call::make(value->dtype,
"&",
{mutated_value},
......@@ -991,7 +986,7 @@ class TensorCoreIRMutator : public IRMutator {
// thread index unification inside a warp
Expr warp_y = IntImm::make(DataType::Int(32), warp_threads_y_);
ThreadIdxMutator thread_idx_mutator(warp_y);
dst = thread_idx_mutator.Mutate(dst);
dst = thread_idx_mutator(dst);
dst = Call::make(DataType::Handle(),
"&",
{dst},
......@@ -1020,8 +1015,8 @@ class TensorCoreIRMutator : public IRMutator {
return stmt;
}
Stmt Mutate_(const For* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt VisitStmt_(const For* op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<For>();
if (op != nullptr) {
auto it = loop_scaling_.find(op->loop_var.get());
......@@ -1177,7 +1172,7 @@ Stmt RewriteForTensorCore(Stmt stmt,
}
MMAMatcher mma_matcher(extern_buffer);
mma_matcher.Visit(stmt);
mma_matcher(stmt);
if (!mma_matcher.Matched()) {
return stmt;
}
......@@ -1189,12 +1184,12 @@ Stmt RewriteForTensorCore(Stmt stmt,
BufferAnalyser buffer_analyser(extern_buffer,
schedule_analyser, mma_matcher);
buffer_analyser.Visit(stmt);
buffer_analyser(stmt);
if (!buffer_analyser.QualifiedForTensorCore()) {
return stmt;
}
return TensorCoreIRMutator(schedule_analyser, buffer_analyser).Mutate(stmt);
return TensorCoreIRMutator(schedule_analyser, buffer_analyser)(std::move(stmt));
}
} // namespace ir
......
......@@ -24,7 +24,7 @@
// Unrolls the loop as in Halide pipeline.
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_functor_ext.h>
#include <unordered_set>
#include <unordered_map>
#include <vector>
......@@ -33,7 +33,7 @@
namespace tvm {
namespace ir {
class LoopUnroller : public IRMutator {
class LoopUnroller : public StmtExprMutator {
public:
explicit LoopUnroller(int auto_max_step,
int auto_max_depth,
......@@ -45,12 +45,12 @@ class LoopUnroller : public IRMutator {
explicit_unroll_(explicit_unroll) {
}
Stmt Mutate_(const AttrStmt* op, const Stmt& stmt) final {
Stmt VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == "pragma_auto_unroll_max_step") {
int value = 0;
CHECK(arith::GetConstInt(op->value, &value));
std::swap(value, auto_max_step_);
Stmt ret = this->Mutate(op->body);
Stmt ret = this->VisitStmt(op->body);
std::swap(value, auto_max_step_);
return ret;
} else if (op->attr_key == "pragma_unroll_explicit") {
......@@ -58,16 +58,16 @@ class LoopUnroller : public IRMutator {
CHECK(arith::GetConstInt(op->value, &value));
bool explicit_unroll = value;
std::swap(explicit_unroll, explicit_unroll_);
Stmt ret = this->Mutate(op->body);
Stmt ret = this->VisitStmt(op->body);
std::swap(explicit_unroll, explicit_unroll_);
return ret;
} else {
return IRMutator::Mutate_(op, stmt);
return StmtExprMutator::VisitStmt_(op);
}
}
Stmt Mutate_(const For* op, const Stmt& s) {
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt VisitStmt_(const For* op) {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<For>();
int value = GetExtent(op);
// condition for auto unroll
......@@ -110,18 +110,18 @@ class LoopUnroller : public IRMutator {
}
}
Stmt Mutate_(const Store* op, const Stmt& stmt) final {
Stmt VisitStmt_(const Store* op) final {
++step_count_;
return IRMutator::Mutate_(op, stmt);
return StmtExprMutator::VisitStmt_(op);
}
Stmt Mutate_(const Evaluate* op, const Stmt& stmt) final {
Stmt VisitStmt_(const Evaluate* op) final {
++step_count_;
return IRMutator::Mutate_(op, stmt);
return StmtExprMutator::VisitStmt_(op);
}
Stmt Mutate_(const Block* op, const Stmt& stmt) final {
Stmt first = this->Mutate(op->first);
Stmt VisitStmt_(const Block* op) final {
Stmt first = this->VisitStmt(op->first);
// cleanup state
int step_count = step_count_;
int unroll_depth = unroll_depth_;
......@@ -130,13 +130,13 @@ class LoopUnroller : public IRMutator {
unroll_depth_ = 0;
normal_loop_depth_ = 0;
// work on rest part
Stmt rest = this->Mutate(op->rest);
Stmt rest = this->VisitStmt(op->rest);
step_count_ += step_count;
normal_loop_depth_ = std::max(normal_loop_depth, normal_loop_depth_);
unroll_depth_ = std::max(unroll_depth_, unroll_depth);
if (first.same_as(op->first) &&
rest.same_as(op->rest)) {
return stmt;
return GetRef<Stmt>(op);
} else {
return Block::make(first, rest);
}
......@@ -204,7 +204,7 @@ Stmt UnrollLoop(Stmt stmt,
auto_max_step,
auto_max_depth,
auto_max_extent,
explicit_unroll).Mutate(stmt);
explicit_unroll)(stmt);
if (!ret.same_as(stmt)) {
return ConvertSSA(ret);
} else {
......
......@@ -23,7 +23,7 @@
// Loop vectorizer as in Halide pipeline.
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/arithmetic.h>
#include <unordered_set>
#include <unordered_map>
......@@ -54,13 +54,13 @@ inline Expr BroadcastTo(Expr e, int lanes) {
//
// The same principle applies when using one thread to simulate multiple context.
//
class VecAllocAccess : public IRMutator {
class VecAllocAccess : public StmtExprMutator {
public:
VecAllocAccess(const Variable* buf, Var var, int var_lanes)
: buf_(buf), var_(var), var_lanes_(var_lanes) {}
// Load
Expr Mutate_(const Load* op, const Expr& e) final {
Expr expr = IRMutator::Mutate_(op, e);
Expr VisitExpr_(const Load* op) final {
Expr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<Load>();
if (op->buffer_var.get() == buf_) {
return Load::make(op->dtype, op->buffer_var,
......@@ -71,8 +71,8 @@ class VecAllocAccess : public IRMutator {
}
}
// Store
Stmt Mutate_(const Store* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt VisitStmt_(const Store* op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<Store>();
if (op->buffer_var.get() == buf_) {
return Store::make(op->buffer_var,
......@@ -93,19 +93,16 @@ class VecAllocAccess : public IRMutator {
int var_lanes_;
};
class Vectorizer : public IRMutator {
class Vectorizer : public StmtExprMutator {
public:
Vectorizer(Var var, int var_lanes)
: var_(var), var_lanes_(var_lanes) {
ramp_ = Ramp::make(0, 1, var_lanes);
}
// user mutate from parent.
using IRMutator::Mutate;
Stmt Mutate(Stmt stmt) final {
Stmt VisitStmt(const Stmt& stmt) final {
CHECK(!need_scalarize_);
Stmt ret = IRMutator::Mutate(stmt);
Stmt ret = StmtExprMutator::VisitStmt(stmt);
if (need_scalarize_) {
need_scalarize_ = false;
return Scalarize(stmt);
......@@ -114,19 +111,18 @@ class Vectorizer : public IRMutator {
}
}
Expr Mutate_(const Add* op, const Expr &e) final {
return AddSubVec(op, e);
Expr VisitExpr_(const Add* op) final {
return AddSubVec(op);
}
Expr Mutate_(const Sub* op, const Expr &e) final {
return AddSubVec(op, e);
Expr VisitExpr_(const Sub* op) final {
return AddSubVec(op);
}
Expr Mutate_(const Mul* op, const Expr &e) final {
Expr a = this->Mutate(op->a);
Expr b = this->Mutate(op->b);
Expr VisitExpr_(const Mul* op) final {
Expr a = this->VisitExpr(op->a);
Expr b = this->VisitExpr(op->b);
if (a.same_as(op->a) &&
b.same_as(op->b)) {
return e;
return GetRef<Expr>(op);
} else {
int lanes = std::max(a.dtype().lanes(), b.dtype().lanes());
if (lanes != 1) {
......@@ -143,53 +139,53 @@ class Vectorizer : public IRMutator {
}
return Mul::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
}
return BinaryVec(op, e);
return BinaryVec(op);
}
Expr Mutate_(const Div* op, const Expr &e) final {
return BinaryVec(op, e);
Expr VisitExpr_(const Div* op) final {
return BinaryVec(op);
}
Expr Mutate_(const Mod* op, const Expr &e) final {
return BinaryVec(op, e);
Expr VisitExpr_(const Mod* op) final {
return BinaryVec(op);
}
Expr Mutate_(const FloorDiv* op, const Expr &e) final {
return BinaryVec(op, e);
Expr VisitExpr_(const FloorDiv* op) final {
return BinaryVec(op);
}
Expr Mutate_(const FloorMod* op, const Expr &e) final {
return BinaryVec(op, e);
Expr VisitExpr_(const FloorMod* op) final {
return BinaryVec(op);
}
Expr Mutate_(const Min* op, const Expr &e) final {
return BinaryVec(op, e);
Expr VisitExpr_(const Min* op) final {
return BinaryVec(op);
}
Expr Mutate_(const Max* op, const Expr &e) final {
return BinaryVec(op, e);
Expr VisitExpr_(const Max* op) final {
return BinaryVec(op);
}
Expr Mutate_(const EQ* op, const Expr &e) final {
return BinaryVec(op, e);
Expr VisitExpr_(const EQ* op) final {
return BinaryVec(op);
}
Expr Mutate_(const NE* op, const Expr &e) final {
return BinaryVec(op, e);
Expr VisitExpr_(const NE* op) final {
return BinaryVec(op);
}
Expr Mutate_(const LT* op, const Expr &e) final {
return BinaryVec(op, e);
Expr VisitExpr_(const LT* op) final {
return BinaryVec(op);
}
Expr Mutate_(const LE* op, const Expr &e) final {
return BinaryVec(op, e);
Expr VisitExpr_(const LE* op) final {
return BinaryVec(op);
}
Expr Mutate_(const GT* op, const Expr &e) final {
return BinaryVec(op, e);
Expr VisitExpr_(const GT* op) final {
return BinaryVec(op);
}
Expr Mutate_(const GE* op, const Expr &e) final {
return BinaryVec(op, e);
Expr VisitExpr_(const GE* op) final {
return BinaryVec(op);
}
Expr Mutate_(const And* op, const Expr &e) final {
return BinaryVec(op, e);
Expr VisitExpr_(const And* op) final {
return BinaryVec(op);
}
Expr Mutate_(const Or* op, const Expr &e) final {
return BinaryVec(op, e);
Expr VisitExpr_(const Or* op) final {
return BinaryVec(op);
}
Expr Mutate_(const Ramp* op, const Expr &e) final {
Expr base = this->Mutate(op->base);
Expr stride = this->Mutate(op->stride);
Expr VisitExpr_(const Ramp* op) final {
Expr base = this->VisitExpr(op->base);
Expr stride = this->VisitExpr(op->stride);
if (base.dtype().lanes() > 1 && stride.dtype().lanes() == 1) {
const Ramp* base_ramp = base.as<Ramp>();
if (analyzer_.CanProve(base_ramp->stride == stride * make_const(stride.dtype(), op->lanes))) {
......@@ -208,14 +204,14 @@ class Vectorizer : public IRMutator {
}
return Shuffle::make_concat(elems);
}
Expr Mutate_(const Select *op, const Expr& e) final {
Expr cond = this->Mutate(op->condition);
Expr t = this->Mutate(op->true_value);
Expr f = this->Mutate(op->false_value);
Expr VisitExpr_(const Select *op) final {
Expr cond = this->VisitExpr(op->condition);
Expr t = this->VisitExpr(op->true_value);
Expr f = this->VisitExpr(op->false_value);
if (cond.same_as(op->condition) &&
t.same_as(op->true_value) &&
f.same_as(op->false_value)) {
return e;
return GetRef<Expr>(op);
} else {
int lanes = std::max(std::max(
cond.dtype().lanes(),
......@@ -223,37 +219,37 @@ class Vectorizer : public IRMutator {
return Select::make(cond, BroadcastTo(t, lanes), BroadcastTo(f, lanes));
}
}
Expr Mutate_(const Cast *op, const Expr& e) final {
Expr value = this->Mutate(op->value);
Expr VisitExpr_(const Cast *op) final {
Expr value = this->VisitExpr(op->value);
if (value.same_as(op->value)) {
return e;
return GetRef<Expr>(op);
} else {
return Cast::make(op->dtype.with_lanes(value.dtype().lanes()), value);
}
}
// Variable
Expr Mutate_(const Variable* v, const Expr& e) final {
Expr VisitExpr_(const Variable* v) final {
if (v == var_.get()) {
return ramp_;
} else if (lets_.count(v)) {
return lets_[v];
} else {
return e;
return GetRef<Expr>(v);
}
}
// IfThenElse expr
Expr MutateIfThenElseExpr_(const Call *op, const Expr& e) {
Expr cond = this->Mutate(op->args[0]);
Expr MutateIfThenElseExpr_(const Call *op) {
Expr cond = this->VisitExpr(op->args[0]);
if (cond.dtype().is_vector()) {
need_scalarize_ = true;
return e;
return GetRef<Expr>(op);
}
Expr t = this->Mutate(op->args[1]);
Expr f = this->Mutate(op->args[2]);
Expr t = this->VisitExpr(op->args[1]);
Expr f = this->VisitExpr(op->args[2]);
if (cond.same_as(op->args[0]) &&
t.same_as(op->args[1]) &&
f.same_as(op->args[2])) {
return e;
return GetRef<Expr>(op);
} else {
int lanes = std::max(t.dtype().lanes(), f.dtype().lanes());
t = BroadcastTo(t, lanes);
......@@ -264,23 +260,23 @@ class Vectorizer : public IRMutator {
}
}
// Call
Expr Mutate_(const Call* op, const Expr& e) final {
Expr VisitExpr_(const Call* op) final {
if (op->name == intrinsic::tvm_if_then_else) {
return MutateIfThenElseExpr_(op, e);
return MutateIfThenElseExpr_(op);
}
if (!op->is_vectorizable()) {
// Cannot vectorize this op
Array<Expr> new_args;
for (auto arg : op->args) {
auto new_arg = this->Mutate(arg);
auto new_arg = this->VisitExpr(arg);
if (new_arg.dtype().is_vector()) {
need_scalarize_ = true;
return e;
return GetRef<Expr>(op);
}
new_args.push_back(new_arg);
}
if (op->args.same_as(new_args)) {
return e;
return GetRef<Expr>(op);
} else {
return Call::make(
op->dtype, op->name, new_args, op->call_type, op->func, op->value_index);
......@@ -290,7 +286,7 @@ class Vectorizer : public IRMutator {
Array<Expr> new_args = MutateArray(op->args, &lane);
// normal code path.
if (op->args.same_as(new_args)) {
return e;
return GetRef<Expr>(op);
} else {
return Call::make(
op->dtype.with_lanes(lane), op->name, new_args,
......@@ -299,11 +295,11 @@ class Vectorizer : public IRMutator {
}
}
// Load
Expr Mutate_(const Load* op, const Expr& e) final {
Expr index = this->Mutate(op->index);
Expr pred = this->Mutate(op->predicate);
Expr VisitExpr_(const Load* op) final {
Expr index = this->VisitExpr(op->index);
Expr pred = this->VisitExpr(op->predicate);
if (index.same_as(op->index) && pred.same_as(op->predicate)) {
return e;
return GetRef<Expr>(op);
} else {
int lanes = std::max(index.dtype().lanes(), pred.dtype().lanes());
return Load::make(
......@@ -314,42 +310,42 @@ class Vectorizer : public IRMutator {
}
}
// Let
Expr Mutate_(const Let* op, const Expr& e) final {
Expr value = this->Mutate(op->value);
Expr VisitExpr_(const Let* op) final {
Expr value = this->VisitExpr(op->value);
CHECK(!lets_.count(op->var.get())) << "not SSA";
if (value.dtype().lanes() != op->value.dtype().lanes()) {
Var v(op->var->name_hint, value.dtype());
lets_[op->var.get()] = v;
return Let::make(v, value, Mutate(op->body));
return Let::make(v, value, this->VisitExpr(op->body));
} else {
Expr body = this->Mutate(op->body);
Expr body = this->VisitExpr(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return e;
return GetRef<Expr>(op);
} else {
return Let::make(op->var, value, body);
}
}
}
// Provide
Stmt Mutate_(const Provide* op, const Stmt& s) final {
Expr new_value = this->Mutate(op->value);
Stmt VisitStmt_(const Provide* op) final {
Expr new_value = this->VisitExpr(op->value);
int lane = new_value.dtype().lanes();
Array<Expr> new_args = MutateArray(op->args, &lane);
if (op->args.same_as(new_args) && op->value.same_as(new_value)) {
return s;
return GetRef<Stmt>(op);
} else {
new_value = BroadcastTo(new_value, lane);
return Provide::make(op->func, op->value_index, new_value, new_args);
}
}
// Store
Stmt Mutate_(const Store* op, const Stmt& s) final {
Expr value = this->Mutate(op->value);
Expr index = this->Mutate(op->index);
Expr pred = this->Mutate(op->predicate);
Stmt VisitStmt_(const Store* op) final {
Expr value = this->VisitExpr(op->value);
Expr index = this->VisitExpr(op->index);
Expr pred = this->VisitExpr(op->predicate);
if (value.same_as(op->value) && index.same_as(op->index)) {
return s;
return GetRef<Stmt>(op);
} else {
int lanes = std::max(value.dtype().lanes(), index.dtype().lanes());
lanes = std::max(lanes, pred.dtype().lanes());
......@@ -360,20 +356,20 @@ class Vectorizer : public IRMutator {
}
}
// For
Stmt Mutate_(const For* op, const Stmt& s) final {
Stmt VisitStmt_(const For* op) final {
if (op->for_type == ForType::Vectorized) {
LOG(WARNING) << "Detect vectorize inside vectorized loop, ignoring...";
}
CHECK(is_zero(op->min));
CHECK(!op->extent.dtype().is_vector());
Expr extent = Mutate(op->extent);
Expr extent = this->VisitExpr(op->extent);
if (extent.dtype().is_vector()) {
return Scalarize(s);
return Scalarize(GetRef<Stmt>(op));
}
Stmt body = Mutate(op->body);
Stmt body = this->VisitStmt(op->body);
if (extent.same_as(op->extent) &&
body.same_as(op->body)) {
return s;
return GetRef<Stmt>(op);
} else {
return For::make(
op->loop_var, op->min, extent,
......@@ -381,47 +377,47 @@ class Vectorizer : public IRMutator {
}
}
// IfThenElse
Stmt Mutate_(const IfThenElse* op, const Stmt& s) final {
Stmt VisitStmt_(const IfThenElse* op) final {
CHECK(!op->condition.dtype().is_vector());
Expr condition = this->Mutate(op->condition);
Expr condition = this->VisitExpr(op->condition);
if (condition.dtype().is_vector()) {
return Scalarize(s);
return Scalarize(GetRef<Stmt>(op));
}
Stmt then_case = this->Mutate(op->then_case);
Stmt then_case = this->VisitStmt(op->then_case);
Stmt else_case;
if (op->else_case.defined()) {
else_case = this->Mutate(op->else_case);
else_case = this->VisitStmt(op->else_case);
}
if (condition.same_as(op->condition) &&
then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
return s;
return GetRef<Stmt>(op);
} else {
return IfThenElse::make(condition, then_case, else_case);
}
}
// LetStmt
Stmt Mutate_(const LetStmt* op, const Stmt& s) final {
Stmt VisitStmt_(const LetStmt* op) final {
LOG(WARNING) << "Cannot vectorize with LetStmt, remove it with Simplify Before Vectorize";
return Scalarize(s);
return Scalarize(GetRef<Stmt>(op));
}
// Allocate
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
Stmt VisitStmt_(const Allocate* op) final {
if (op->new_expr.defined()) {
LOG(WARNING) << "Cannot vectorize with new expr";
return Scalarize(s);
return Scalarize(GetRef<Stmt>(op));
}
Expr condition = Mutate(op->condition);
Expr condition = this->VisitExpr(op->condition);
if (condition.dtype().is_vector()) {
LOG(WARNING) << "Cannot handle vector extent in alloc ";
return Scalarize(s);
return Scalarize(GetRef<Stmt>(op));
}
Array<Expr> extents;
for (size_t i = 0; i < op->extents.size(); i++) {
Expr new_ext = Mutate(op->extents[i]);
Expr new_ext = this->VisitExpr(op->extents[i]);
if (new_ext.dtype().is_vector()) {
LOG(WARNING) << "Cannot handle vector extent in alloc ";
return Scalarize(s);
return Scalarize(GetRef<Stmt>(op));
}
extents.push_back(new_ext);
}
......@@ -429,8 +425,8 @@ class Vectorizer : public IRMutator {
extents.push_back(var_lanes_);
// rewrite access to buffer internally.
Stmt body = VecAllocAccess(
op->buffer_var.get(), var_, var_lanes_).Mutate(op->body);
body = Mutate(body);
op->buffer_var.get(), var_, var_lanes_)(op->body);
body = this->VisitStmt(body);
return Allocate::make(
op->buffer_var, op->dtype,
extents, condition, body,
......@@ -466,7 +462,7 @@ class Vectorizer : public IRMutator {
std::vector<Expr> new_arr(arr.size());
for (size_t i = 0; i < arr.size(); i++) {
Expr old_elem = arr[i];
Expr new_elem = this->Mutate(old_elem);
Expr new_elem = this->VisitExpr(old_elem);
if (!new_elem.same_as(old_elem)) changed = true;
new_arr[i] = new_elem;
lanes = std::max(lanes, new_elem.dtype().lanes());
......@@ -482,24 +478,24 @@ class Vectorizer : public IRMutator {
return Array<Expr>(new_arr);
}
template<typename T>
Expr BinaryVec(const T* op, const Expr& e) {
Expr a = this->Mutate(op->a);
Expr b = this->Mutate(op->b);
Expr BinaryVec(const T* op) {
Expr a = this->VisitExpr(op->a);
Expr b = this->VisitExpr(op->b);
if (a.same_as(op->a) &&
b.same_as(op->b)) {
return e;
return GetRef<Expr>(op);
} else {
int lanes = std::max(a.dtype().lanes(), b.dtype().lanes());
return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
}
}
template<typename T>
Expr AddSubVec(const T* op, const Expr& e) {
Expr a = this->Mutate(op->a);
Expr b = this->Mutate(op->b);
Expr AddSubVec(const T* op) {
Expr a = this->VisitExpr(op->a);
Expr b = this->VisitExpr(op->b);
if (a.same_as(op->a) &&
b.same_as(op->b)) {
return e;
return GetRef<Expr>(op);
} else {
int lanes = std::max(a.dtype().lanes(), b.dtype().lanes());
if (lanes != 1) {
......@@ -521,9 +517,9 @@ class Vectorizer : public IRMutator {
}
};
class LoopVectorizer : public IRMutator {
class LoopVectorizer : public StmtMutator {
public:
Stmt Mutate_(const For* op, const Stmt& s) final {
Stmt VisitStmt_(const For* op) final {
if (op->for_type == ForType::Vectorized) {
CHECK(is_zero(op->min));
int lanes = 0;
......@@ -531,21 +527,21 @@ class LoopVectorizer : public IRMutator {
if (!succ || lanes < 1) {
LOG(FATAL) << "Failed to vectorize loop with extent " << op->extent;
}
return Vectorizer(op->loop_var, lanes).Mutate(op->body);
return Vectorizer(op->loop_var, lanes)(op->body);
} else {
return IRMutator::Mutate_(op, s);
return StmtMutator::VisitStmt_(op);
}
}
};
Stmt VectorizeLoop(Stmt stmt) {
return LoopVectorizer().Mutate(stmt);
return LoopVectorizer()(std::move(stmt));
}
class VectorizeSkipper : public IRMutator {
class VectorizeSkipper : public StmtMutator {
public:
Stmt Mutate_(const For* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
Stmt VisitStmt_(const For* op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<For>();
if (op->for_type == ForType::Vectorized) {
return For::make(op->loop_var, op->min, op->extent, ForType::Serial, op->device_api,
......@@ -557,7 +553,7 @@ class VectorizeSkipper : public IRMutator {
};
Stmt SkipVectorize(Stmt stmt) {
return VectorizeSkipper().Mutate(stmt);
return VectorizeSkipper()(std::move(stmt));
}
} // namespace ir
......
......@@ -24,7 +24,7 @@
#include <tvm/buffer.h>
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/tensor.h>
#include <unordered_map>
......@@ -32,15 +32,15 @@
namespace tvm {
namespace ir {
class VerifyBuffer : public IRVisitor {
class VerifyBuffer : public StmtVisitor {
public:
bool Verify(const Stmt& stmt) {
this->Visit(stmt);
this->VisitStmt(stmt);
return is_compact_;
}
void Visit_(const AttrStmt* op) final {
IRVisitor::Visit_(op);
void VisitStmt_(const AttrStmt* op) final {
StmtVisitor::VisitStmt_(op);
if (op->attr_key == attr::buffer_bind_scope) {
is_compact_ = true;
}
......
......@@ -26,12 +26,12 @@
#include <tvm/api_registry.h>
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
namespace tvm {
namespace ir {
class GPUCodeVerifier : public IRVisitor {
class GPUCodeVerifier : public StmtVisitor {
public:
bool Verify(tvm::Stmt stmt,
int64_t max_local_memory_per_block,
......@@ -49,12 +49,12 @@ class GPUCodeVerifier : public IRVisitor {
Reset_();
this->Visit(stmt);
this->VisitStmt(stmt);
return valid_;
}
void Visit_(const ProducerConsumer *op) {
void VisitStmt_(const ProducerConsumer* op) final {
if (nest_level_ == 0) {
// enter a new kernel, reset statistics
Reset_();
......@@ -62,10 +62,10 @@ class GPUCodeVerifier : public IRVisitor {
if (op->is_producer) {
nest_level_++;
IRVisitor::Visit_(op);
StmtVisitor::VisitStmt_(op);
nest_level_--;
} else {
IRVisitor::Visit_(op);
StmtVisitor::VisitStmt_(op);
}
if (nest_level_ == 0) {
......@@ -77,8 +77,8 @@ class GPUCodeVerifier : public IRVisitor {
}
}
void Visit_(const Allocate *op) {
IRVisitor::Visit_(op);
void VisitStmt_(const Allocate* op) final {
StmtVisitor::VisitStmt_(op);
// visit an allocation of a buffer in shared memory, record its size
if (visited_local_buffers_.count(op->buffer_var.get()) != 0) {
size_t size = static_cast<size_t>(op->constant_allocation_size());
......@@ -89,7 +89,7 @@ class GPUCodeVerifier : public IRVisitor {
}
}
void Visit_(const AttrStmt *op) {
void VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::storage_scope) {
std::string op_value = op->value.as<StringImm>()->value;
if (op_value == "local") {
......@@ -132,7 +132,7 @@ class GPUCodeVerifier : public IRVisitor {
}
}
}
IRVisitor::Visit_(op);
StmtVisitor::VisitStmt_(op);
}
private:
......
......@@ -22,8 +22,9 @@
* \brief Pass to check if memory accesses are legal.
*/
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_functor_ext.h>
namespace tvm {
namespace ir {
......@@ -39,7 +40,7 @@ namespace {
* This pass performs such verification by checking if all Producer/Consumer
* with memory accesses are bound with threads when device type is GPU.
*/
class MemoryAccessVerifier final : protected IRVisitor {
class MemoryAccessVerifier final : protected StmtExprVisitor {
public:
/// Special member functions
//@{
......@@ -55,7 +56,7 @@ class MemoryAccessVerifier final : protected IRVisitor {
/// Interface to perform memory access verification
void Run() {
if (!IsGPUDevice(dev_type_) && !IsFPGADevice(dev_type_)) return;
IRVisitor::Visit(func_->body);
StmtExprVisitor::VisitStmt(func_->body);
}
/// Verification result
......@@ -64,42 +65,47 @@ class MemoryAccessVerifier final : protected IRVisitor {
protected:
/// Visitor implementation
//@{
void Visit(const ObjectRef &n) final {
void VisitExpr(const Expr &n) final {
if (Failed()) return;
StmtExprVisitor::VisitExpr(n);
}
void VisitStmt(const Stmt &n) final {
if (Failed()) return;
IRVisitor::Visit(n);
StmtExprVisitor::VisitStmt(n);
}
void Visit_(const LetStmt *op) final {
void VisitStmt_(const LetStmt* op) final {
// Book keep definitions
defs_[op->var.get()] = op->value;
return IRVisitor::Visit_(op);
return StmtExprVisitor::VisitStmt_(op);
}
void Visit_(const AttrStmt *op) final {
void VisitStmt_(const AttrStmt* op) final {
if (!InThreadEnv() && (op->attr_key == attr::thread_extent ||
op->attr_key == attr::pipeline_exec_scope)) {
EnterThreadEnv();
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
ExitThreadEnv();
} else {
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
}
}
void Visit_(const ProducerConsumer *op) final {
void VisitStmt_(const ProducerConsumer* op) final {
EnterProducerConsumer(op);
IRVisitor::Visit_(op);
StmtExprVisitor::VisitStmt_(op);
ExitProducerConsumer();
}
void Visit_(const Load *op) final {
void VisitExpr_(const Load* op) final {
HandleLoadStoreToVariable(op->buffer_var);
return IRVisitor::Visit_(op);
return StmtExprVisitor::VisitExpr_(op);
}
void Visit_(const Store *op) final {
void VisitStmt_(const Store* op) final {
HandleLoadStoreToVariable(op->buffer_var);
return IRVisitor::Visit_(op);
return StmtExprVisitor::VisitStmt_(op);
}
//@}
......
......@@ -22,23 +22,23 @@
*/
#include <tvm/schedule_pass.h>
#include <tvm/operation.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
namespace tvm {
namespace schedule {
using namespace ir;
class ElemWiseDetector : public ir::IRVisitor {
class ElemWiseDetector : public ir::ExprVisitor {
public:
explicit ElemWiseDetector(Array<IterVar> axis) : axis_(axis) {}
void Visit(const ObjectRef& e) final {
void VisitExpr(const Expr& e) final {
if (!is_elem_wise_) return;
IRVisitor::Visit(e);
ExprVisitor::VisitExpr(e);
}
void Visit_(const Call* op) final {
void VisitExpr_(const Call* op) final {
Array<Expr> axis = op->args;
if (axis_.size() != axis.size()) {
is_elem_wise_ = false;
......@@ -51,7 +51,7 @@ class ElemWiseDetector : public ir::IRVisitor {
return;
}
}
IRVisitor::Visit_(op);
ExprVisitor::VisitExpr_(op);
}
bool is_elem_wise_{true};
......@@ -64,7 +64,7 @@ class ElemWiseDetector : public ir::IRVisitor {
bool IsElemWise(const Operation& op) {
if (const ComputeOpNode* compute = op.as<ComputeOpNode>()) {
ElemWiseDetector v = ElemWiseDetector(compute->axis);
for (auto& e : compute->body) v.Visit(e);
for (auto& e : compute->body) v(e);
return v.is_elem_wise_;
}
return false;
......
......@@ -21,7 +21,6 @@
* \file bound.cc
* \brief The bound inference logic.
*/
#include <tvm/ir_visitor.h>
#include <tvm/schedule_pass.h>
#include <tvm/operation.h>
#include <tvm/ir_pass.h>
......
......@@ -22,7 +22,7 @@
* \brief Utilities to get information about schedule graph.
*/
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/operation.h>
#include <utility>
#include <unordered_set>
......
......@@ -22,7 +22,7 @@
*/
#include <tvm/schedule.h>
#include <tvm/operation.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_pass.h>
#include <unordered_set>
#include "message_passing.h"
......@@ -42,24 +42,24 @@ size_t FindNodeRef(ArrayNode* array_node, const T& v) {
}
// The replacer of cache.
class VarReplacer : public ir::IRMutator {
class VarReplacer : public ir::StmtExprMutator {
public:
explicit VarReplacer(
const std::unordered_map<const Variable*, Expr>& vsub)
: vsub_(vsub) {}
Expr Mutate_(const Variable* op, const Expr& e) {
Expr VisitExpr_(const Variable* op) final {
auto it = vsub_.find(op);
if (it != vsub_.end()) return it->second;
return e;
return GetRef<Expr>(op);
}
ir::CommReducer MutateCommReducer(ir::CommReducer combiner) {
// Replace free variables in combiner
auto new_identity = ir::UpdateArray(combiner->identity_element, [this] (const Expr& e) {
return this->Mutate(e);
return this->VisitExpr(e);
});
auto new_result = ir::UpdateArray(combiner->result, [this] (const Expr& e) {
return this->Mutate(e);
return this->VisitExpr(e);
});
if (combiner->identity_element.same_as(new_identity) &&
......@@ -71,8 +71,8 @@ class VarReplacer : public ir::IRMutator {
}
}
Expr Mutate_(const ir::Reduce* op, const Expr& e) {
Expr new_e = IRMutator::Mutate_(op, e);
Expr VisitExpr_(const ir::Reduce* op) final {
Expr new_e = StmtExprMutator::VisitExpr_(op);
const ir::Reduce* new_reduce = new_e.as<ir::Reduce>();
ir::CommReducer new_combiner = MutateCommReducer(op->combiner);
if (op->combiner.same_as(new_combiner)) {
......@@ -316,9 +316,9 @@ Array<Tensor> CacheWriteWithReLayout(Schedule sch,
Array<Expr> body_list;
const ir::Reduce* first_reduce = nullptr;
for (auto cbody : compute->body) {
body = VarReplacer(vsub).Mutate(cbody);
body = VarReplacer(vsub)(cbody);
body = InjectPredicate(predicates, body);
body = VarReplacer(vsub2newvar).Mutate(body);
body = VarReplacer(vsub2newvar)(body);
// Reduce nodes in ONE computeOp must be the same except value_index
// This is right only if the original body ensures Reduce nodes are the same
if (body->IsInstance<ir::Reduce>()) {
......@@ -404,8 +404,8 @@ Array<Tensor> CacheWriteWithReLayoutTensor(Schedule sch,
for (Region old_region : tensor_op->input_regions) {
Region region;
for (Range r : old_region) {
Expr min = VarReplacer(vsub2newvar).Mutate(r->min);
Expr extent = VarReplacer(vsub2newvar).Mutate(r->extent);
Expr min = VarReplacer(vsub2newvar)(r->min);
Expr extent = VarReplacer(vsub2newvar)(r->extent);
region.push_back(Range::make_by_min_extent(min, extent));
}
new_regions.push_back(region);
......@@ -413,7 +413,7 @@ Array<Tensor> CacheWriteWithReLayoutTensor(Schedule sch,
Array<Expr> new_scalar_inputs;
for (Expr old_input : tensor_op->scalar_inputs) {
new_scalar_inputs.push_back(VarReplacer(vsub2newvar).Mutate(old_input));
new_scalar_inputs.push_back(VarReplacer(vsub2newvar)(old_input));
}
Operation cache_op = TensorComputeOpNode::make(
......@@ -786,9 +786,9 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
}
VarReplacer replacer(vsub);
Array<Expr> new_source = ir::UpdateArray(reduce->source,
[&replacer] (const Expr& e) { return replacer.Mutate(e); });
[&replacer] (const Expr& e) { return replacer(e); });
Expr new_pred = replacer.Mutate(predicate);
Expr new_pred = replacer(predicate);
std::vector<Expr> body;
for (size_t idx = 0; idx < reduce->source.size(); ++idx) {
......
......@@ -22,7 +22,6 @@
*/
#include <tvm/schedule.h>
#include <tvm/operation.h>
#include <tvm/ir_mutator.h>
#include <unordered_set>
#include "graph.h"
......
......@@ -21,9 +21,8 @@
* \file schedule_ops.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/operation.h>
#include <tvm/schedule_pass.h>
#include <utility>
......@@ -71,7 +70,7 @@ Stmt MakePipeline(const Stage& s,
}
// inject the operator's realization on the stmt.
class InjectAttach : public IRMutator {
class InjectAttach : public StmtMutator {
public:
InjectAttach(const Stage& stage,
const Stage& attach_spec,
......@@ -80,9 +79,9 @@ class InjectAttach : public IRMutator {
: stage_(stage), attach_spec_(attach_spec), dom_map_(dom_map),
debug_keep_trivial_loop_(debug_keep_trivial_loop) {}
Stmt Mutate(Stmt stmt) final {
CHECK(stmt.defined());
stmt = IRMutator::Mutate(stmt);
Stmt VisitStmt(const Stmt& input_stmt) final {
CHECK(input_stmt.defined());
auto stmt = StmtMutator::VisitStmt(input_stmt);
const AttrStmt* op = stmt.as<AttrStmt>();
if (op != nullptr &&
op->attr_key == attr::loop_scope) {
......@@ -115,7 +114,7 @@ class InjectAttach : public IRMutator {
};
// inject the operator's realization on the stmt.
class InjectScanStep : public IRMutator {
class InjectScanStep : public StmtMutator {
public:
InjectScanStep(const Stage& stage,
const Operation& scan_op,
......@@ -125,9 +124,9 @@ class InjectScanStep : public IRMutator {
: stage_(stage), scan_op_(scan_op),
dom_map_(dom_map), is_init_(is_init), debug_keep_trivial_loop_(debug_keep_trivial_loop) {}
Stmt Mutate(Stmt stmt) final {
CHECK(stmt.defined());
stmt = IRMutator::Mutate(stmt);
Stmt VisitStmt(const Stmt& input_stmt) final {
CHECK(input_stmt.defined());
auto stmt = StmtMutator::VisitStmt(input_stmt);
// update
const AttrStmt* op = stmt.as<AttrStmt>();
if (op != nullptr &&
......@@ -161,12 +160,12 @@ class InjectScanStep : public IRMutator {
// Postprocessing of schedule op
// Replace the init and update's expression by scan's buffer.
class SchedulePostProc : public IRMutator {
class SchedulePostProc : public StmtExprMutator {
public:
Stmt Mutate_(const ProducerConsumer* op, const Stmt& s) final {
Stmt VisitStmt_(const ProducerConsumer* op) final {
auto it = replace_op_.find(op->func.get());
if (it != replace_op_.end()) {
Stmt body = this->Mutate(op->body);
Stmt body = this->VisitStmt(op->body);
if (it->second.defined()) {
return ProducerConsumer::make(
it->second, op->is_producer, body);
......@@ -174,36 +173,36 @@ class SchedulePostProc : public IRMutator {
return body;
}
} else {
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
}
Stmt Mutate_(const LetStmt* op, const Stmt& s) final {
Stmt VisitStmt_(const LetStmt* op) final {
if (!HasSideEffect(op->value)) {
var_value_[op->var.get()] = Mutate(op->value);
return this->Mutate(op->body);
var_value_[op->var.get()] = this->VisitExpr(op->value);
return this->VisitStmt(op->body);
} else {
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
Stmt VisitStmt_(const AttrStmt* op) final {
if (op->attr_key == attr::loop_scope ||
op->attr_key == attr::scan_init_scope) {
return this->Mutate(op->body);
return this->VisitStmt(op->body);
} else if (op->attr_key == attr::scan_update_scope) {
const ScanOpNode* scan = op->node.as<ScanOpNode>();
CHECK(scan);
var_value_[scan->scan_axis->var.get()] = op->value;
return this->Mutate(op->body);
return this->VisitStmt(op->body);
} else if (op->attr_key == attr::thread_extent) {
// delete duplicated thread extent attr
auto it = thread_extent_scope_.find(op->node.get());
if (it != thread_extent_scope_.end()) {
CHECK(is_zero(ir::Simplify(it->second - op->value)));
return this->Mutate(op->body);
return this->VisitStmt(op->body);
} else {
thread_extent_scope_[op->node.get()] = op->value;
Stmt ret = IRMutator::Mutate_(op, s);
Stmt ret = StmtExprMutator::VisitStmt_(op);
thread_extent_scope_.erase(op->node.get());
return ret;
}
......@@ -214,9 +213,9 @@ class SchedulePostProc : public IRMutator {
if (it->second.defined()) {
Stmt ret = AttrStmt::make(
it->second, op->attr_key, op->value, op->body);
return this->Mutate(ret);
return this->VisitStmt(ret);
} else {
return this->Mutate(op->body);
return this->VisitStmt(op->body);
}
}
} else if (op->attr_key == ir::attr::buffer_bind_scope) {
......@@ -227,9 +226,9 @@ class SchedulePostProc : public IRMutator {
if (it->second.defined()) {
return AttrStmt::make(
Array<ObjectRef>{tuple[0], it->second.output(tensor->value_index)},
op->attr_key, op->value, Mutate(op->body));
op->attr_key, op->value, this->VisitStmt(op->body));
} else {
return this->Mutate(op->body);
return this->VisitStmt(op->body);
}
}
} else if (op->attr_key == ir::attr::buffer_dim_align) {
......@@ -239,16 +238,16 @@ class SchedulePostProc : public IRMutator {
if (it->second.defined()) {
return AttrStmt::make(
it->second.output(tensor->value_index),
op->attr_key, op->value, Mutate(op->body));
op->attr_key, op->value, this->VisitStmt(op->body));
} else {
return this->Mutate(op->body);
return this->VisitStmt(op->body);
}
}
}
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
Stmt Mutate_(const Realize* op, const Stmt& s) final {
Stmt VisitStmt_(const Realize* op) final {
TensorKey key{op->func, op->value_index};
auto it = replace_realize_.find(key);
if (it != replace_realize_.end()) {
......@@ -256,29 +255,29 @@ class SchedulePostProc : public IRMutator {
Stmt ret = Realize::make(
it->second->op, it->second->value_index,
op->dtype, op->bounds, op->condition, op->body);
return this->Mutate(ret);
return this->VisitStmt(ret);
} else {
return this->Mutate(op->body);
return this->VisitStmt(op->body);
}
} else {
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
}
Stmt Mutate_(const Provide* op, const Stmt& s) final {
Stmt VisitStmt_(const Provide* op) final {
TensorKey key{op->func, op->value_index};
auto it = replace_buffer_.find(key);
if (it != replace_buffer_.end()) {
const Tensor& dst = it->second;
Stmt ret = Provide::make(
dst->op, dst->value_index, op->value, op->args);
return this->Mutate(ret);
return this->VisitStmt(ret);
} else {
return IRMutator::Mutate_(op, s);
return StmtExprMutator::VisitStmt_(op);
}
}
Expr Mutate_(const Call* op, const Expr& e) final {
Expr VisitExpr_(const Call* op) final {
if (op->call_type == Call::Halide) {
TensorKey key{op->func, op->value_index};
auto it = replace_buffer_.find(key);
......@@ -287,18 +286,18 @@ class SchedulePostProc : public IRMutator {
Expr ret = Call::make(
op->dtype, dst->op->name, op->args,
op->call_type, dst->op, dst->value_index);
return this->Mutate(ret);
return this->VisitExpr(ret);
}
}
return IRMutator::Mutate_(op, e);
return StmtExprMutator::VisitExpr_(op);
}
Expr Mutate_(const Variable* op, const Expr& e) final {
Expr VisitExpr_(const Variable* op) final {
auto it = var_value_.find(op);
if (it != var_value_.end()) {
return it->second;
} else {
return e;
return GetRef<Expr>(op);
}
}
......@@ -392,14 +391,14 @@ Stmt ScheduleOps(
if (scan_init.count(s->op)) {
CHECK(body.defined());
InjectScanStep mu(s, scan_init.at(s->op), dom_map, true, debug_keep_trivial_loop);
body = mu.Mutate(body);
body = mu(std::move(body));
CHECK(mu.found_attach)
<< "did not find attachment point for scan.init";
} else if (attach_spec->attach_type == kScanUpdate) {
// Handle scan update
CHECK(body.defined());
InjectScanStep mu(s, attach_spec->attach_stage->op, dom_map, false, debug_keep_trivial_loop);
body = mu.Mutate(body);
body = mu(std::move(body));
CHECK(mu.found_attach)
<< "did not find attachment point for scan.update";
} else if (attach_spec->attach_type == kInlinedAlready) {
......@@ -411,7 +410,7 @@ Stmt ScheduleOps(
CHECK_EQ(attach_spec->attach_type, kScope);
CHECK(body.defined());
InjectAttach mutator(s, attach_spec, dom_map, debug_keep_trivial_loop);
body = mutator.Mutate(body);
body = mutator(std::move(body));
CHECK(mutator.found_attach)
<< "did not find attachment point for " << s << " in "
<< attach_spec->attach_stage->op << " x " << attach_spec->attach_ivar
......@@ -421,7 +420,7 @@ Stmt ScheduleOps(
}
SchedulePostProc post_proc;
post_proc.Init(sch);
return post_proc.Mutate(body);
return post_proc(std::move(body));
}
} // namespace schedule
......
......@@ -20,6 +20,7 @@
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_pass.h>
TEST(IRVisitor, CountVar) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment