Unverified Commit 518c3fd0 by Tianqi Chen Committed by GitHub

[REFACTOR] Remove old Low-level Visitor/Mutator (#4612)

parent e8a2c9b3
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/ir_mutator.h
* \brief Defines general IRMutation pass
*/
#ifndef TVM_IR_MUTATOR_H_
#define TVM_IR_MUTATOR_H_
#include <unordered_map>
#include <utility>
#include "expr.h"
#include "ir.h"
#include "tvm/node/functor.h"
namespace tvm {
namespace ir {
/*!
* \brief a base class for mutator to iterative mutate the IR
*
* This IRMutator is implemented via Visitor Pattern.
* Also you can implement via NodeFunctor.
* This enables easy extensions of possible new Node.
* It also makes changing return types easier.
*
* \note If you want to return a different type other than Expr and Stmt,
* Simply following the same pattern as IRMutator and create a seperate class.
* \sa NodeFunctor
*/
class TVM_DLL IRMutator {
public:
/*!
* \brief mutate expression
* \return the mutated expr
*/
virtual Expr Mutate(Expr expr) {
static const FMutateExpr& f = vtable_expr();
return f(expr, expr, this);
}
/*!
* \brief mutate expression
* \return the mutated stmt
*/
virtual Stmt Mutate(Stmt stmt) {
static const FMutateStmt& f = vtable_stmt();
return f(stmt, stmt, this);
}
/*! \brief destructor */
virtual ~IRMutator() {}
/*! \brief functor type of expr mutation */
using FMutateExpr = NodeFunctor<Expr(const ObjectRef&, const Expr&, IRMutator*)>;
/*! \brief functor type of stmt mutation */
using FMutateStmt = NodeFunctor<Stmt(const ObjectRef&, const Stmt&, IRMutator*)>;
/*! \return internal vtable of expr */
static FMutateExpr& vtable_expr(); // NOLINT(*)
/*! \return internal stmt of expr */
static FMutateStmt& vtable_stmt(); // NOLINT(*)
// Set of overloadable functions
// The underscore allows Mutate not to be shadowed by inheritance
virtual Stmt Mutate_(const LetStmt* op, const Stmt& s);
virtual Stmt Mutate_(const AttrStmt* op, const Stmt& s);
virtual Stmt Mutate_(const IfThenElse* op, const Stmt& s);
virtual Stmt Mutate_(const For* op, const Stmt& s);
virtual Stmt Mutate_(const Allocate* op, const Stmt& s);
virtual Stmt Mutate_(const Store* op, const Stmt& s);
virtual Stmt Mutate_(const Free* op, const Stmt& s);
virtual Stmt Mutate_(const AssertStmt* op, const Stmt& s);
virtual Stmt Mutate_(const ProducerConsumer* op, const Stmt& s);
virtual Stmt Mutate_(const Provide* op, const Stmt& s);
virtual Stmt Mutate_(const Realize* op, const Stmt& s);
virtual Stmt Mutate_(const Prefetch* op, const Stmt& s);
virtual Stmt Mutate_(const Block* op, const Stmt& s);
virtual Stmt Mutate_(const Evaluate* op, const Stmt& s);
virtual Expr Mutate_(const Variable* op, const Expr& e);
virtual Expr Mutate_(const Load* op, const Expr& e);
virtual Expr Mutate_(const Let* op, const Expr& e);
virtual Expr Mutate_(const Call* op, const Expr& e);
virtual Expr Mutate_(const Add* op, const Expr& e);
virtual Expr Mutate_(const Sub* op, const Expr& e);
virtual Expr Mutate_(const Mul* op, const Expr& e);
virtual Expr Mutate_(const Div* op, const Expr& e);
virtual Expr Mutate_(const Mod* op, const Expr& e);
virtual Expr Mutate_(const FloorDiv* op, const Expr& e);
virtual Expr Mutate_(const FloorMod* op, const Expr& e);
virtual Expr Mutate_(const Min* op, const Expr& e);
virtual Expr Mutate_(const Max* op, const Expr& e);
virtual Expr Mutate_(const EQ* op, const Expr& e);
virtual Expr Mutate_(const NE* op, const Expr& e);
virtual Expr Mutate_(const LT* op, const Expr& e);
virtual Expr Mutate_(const LE* op, const Expr& e);
virtual Expr Mutate_(const GT* op, const Expr& e);
virtual Expr Mutate_(const GE* op, const Expr& e);
virtual Expr Mutate_(const And* op, const Expr& e);
virtual Expr Mutate_(const Or* op, const Expr& e);
virtual Expr Mutate_(const Reduce* op, const Expr& e);
virtual Expr Mutate_(const Cast* op, const Expr& e);
virtual Expr Mutate_(const Not* op, const Expr& e);
virtual Expr Mutate_(const Select* op, const Expr& e);
virtual Expr Mutate_(const Ramp* op, const Expr& e);
virtual Expr Mutate_(const Broadcast* op, const Expr& e);
virtual Expr Mutate_(const IntImm* op, const Expr& e);
virtual Expr Mutate_(const UIntImm* op, const Expr& e);
virtual Expr Mutate_(const FloatImm* op, const Expr& e);
virtual Expr Mutate_(const StringImm* op, const Expr& e);
virtual Expr Mutate_(const Shuffle* op, const Expr& e);
};
} // namespace ir
} // namespace tvm
#endif // TVM_IR_MUTATOR_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/ir_visitor.h
* \brief Visitor to quickly visit IR trees
*/
#ifndef TVM_IR_VISITOR_H_
#define TVM_IR_VISITOR_H_
#include "ir.h"
#include "tvm/node/functor.h"
namespace tvm {
namespace ir {
/*!
* \brief a base class for visitor to iterative traverse the IR
*
* This IRVisitor is implemented via NodeFunctor
* This enables extensions of possible new Node.
*
* \sa ExprFunctor, StmtFunctor, PostOrderVisit
*
* \note If you need to return values during Visit:
* - If it is mutation of the IR, use IRMutator
* - If you want to return other things, consider use ExprFunctor/StmtFunctor
* - Watch out for possible bug pattern if you use IRVisitor to simulate returns.
*
* \code
*
* // This is an example code to show cases for traps in IRVisitor
* // The use case is to count number of Variables in the ir tree.
* class MyCounter : public IRVisitor {
* public:
* int Count(const ObjectRef& n) {
* ret_ = 0;
* this->Visit(n);
* return ret_;
* }
* void Visit_(const Variable* op) final {
* ret_ = 1;
* }
* void Visit_(const Add* op) final {
* ret_ = count(op->a) + count(op->b);
* }
* private:
* int ret_;
* };
* MyCounter counter;
* Var x("x");
* // this returns 2
* CHECK_EQ(counter.Count(x + x), 2);
* // Think what is the result of the following count
* counter.count(Max::make(x, x));
* // The result is actually 1
* // This is because Visit is not overriden for Max
* // so it simply calls Visit for the left and right children
* // and because Count is not called, ret_ is not cleared.
* // There can also be cases where ret_ is forgetten to be set.
*
* // These traps may not happen if we program carefully
* // But it is recommended to use ExprFunctor, which allows direct
* // return the value, this helps us to avoid such problems.
*
* \endcode
*/
class TVM_DLL IRVisitor {
public:
/*!
* \brief recursively visit an IR node
*/
virtual void Visit(const ObjectRef& node) {
static const FVisit& f = vtable();
if (node.defined()) f(node, this);
}
/*! \brief destructor */
virtual ~IRVisitor() {}
/*! \brief functor type of visitor */
using FVisit = NodeFunctor<void(const ObjectRef&, IRVisitor*)>;
/*! \return internal vtable*/
static FVisit& vtable();
// overloadable visit function.
virtual void Visit_(const Variable* op);
virtual void Visit_(const LetStmt* op);
virtual void Visit_(const AttrStmt* op);
virtual void Visit_(const IfThenElse* op);
virtual void Visit_(const For* op);
virtual void Visit_(const Allocate* op);
virtual void Visit_(const Load* op);
virtual void Visit_(const Store* op);
virtual void Visit_(const Let* op);
virtual void Visit_(const Free* op);
virtual void Visit_(const Call* op);
virtual void Visit_(const Add* op);
virtual void Visit_(const Sub* op);
virtual void Visit_(const Mul* op);
virtual void Visit_(const Div* op);
virtual void Visit_(const Mod* op);
virtual void Visit_(const FloorDiv* op);
virtual void Visit_(const FloorMod* op);
virtual void Visit_(const Min* op);
virtual void Visit_(const Max* op);
virtual void Visit_(const EQ* op);
virtual void Visit_(const NE* op);
virtual void Visit_(const LT* op);
virtual void Visit_(const LE* op);
virtual void Visit_(const GT* op);
virtual void Visit_(const GE* op);
virtual void Visit_(const And* op);
virtual void Visit_(const Or* op);
virtual void Visit_(const Reduce* op);
virtual void Visit_(const Cast* op);
virtual void Visit_(const Not* op);
virtual void Visit_(const Select* op);
virtual void Visit_(const Ramp* op);
virtual void Visit_(const Shuffle* op);
virtual void Visit_(const Broadcast* op);
virtual void Visit_(const AssertStmt* op);
virtual void Visit_(const ProducerConsumer* op);
virtual void Visit_(const Provide* op);
virtual void Visit_(const Realize* op);
virtual void Visit_(const Prefetch* op);
virtual void Visit_(const Block* op);
virtual void Visit_(const Evaluate* op);
virtual void Visit_(const IntImm* op);
virtual void Visit_(const UIntImm* op);
virtual void Visit_(const FloatImm* op);
virtual void Visit_(const StringImm* op);
};
} // namespace ir
} // namespace tvm
#endif // TVM_IR_VISITOR_H_
......@@ -51,12 +51,12 @@ enum AnnotationType {
class FeatureVisitor : public StmtExprVisitor {
public:
// for loop
void VisitStmt_(const For *op);
void VisitStmt_(const AttrStmt *op);
void VisitStmt_(const For* op) final;
void VisitStmt_(const AttrStmt* op) final;
// memory access
void VisitExpr_(const Load *op);
void VisitStmt_(const Store *op);
void VisitExpr_(const Load* op) final;
void VisitStmt_(const Store* op) final;
using StmtExprVisitor::VisitStmt_;
using StmtExprVisitor::VisitExpr_;
......
......@@ -51,7 +51,7 @@ class IndexParser: public ExprVisitor {
this->VisitExpr(expr);
}
void VisitExpr_(const Variable *op) {
void VisitExpr_(const Variable* op) final {
// TODO(lmzheng): handle more index types (multiple occurrence)
if (pattern_map.count(op) == 0) {
pattern_map[op] = TouchPattern();
......@@ -60,7 +60,7 @@ class IndexParser: public ExprVisitor {
}
}
void VisitExpr_(const Mul *op) {
void VisitExpr_(const Mul* op) final {
if (op->a.as<Variable>()) {
if (const auto stride = op->b.as<IntImm>()) {
next_stride_ = stride->value;
......
......@@ -90,31 +90,31 @@ class TouchExtractor : public FeatureVisitor {
}
// arithmetic stats
void VisitExpr_(const Add *op) {
void VisitExpr_(const Add* op) final {
if (op->dtype.is_float())
itervar_map[itervar_stack_.back()].add_ct++;
FeatureVisitor::VisitExpr_(op);
}
void VisitExpr_(const Sub *op) {
void VisitExpr_(const Sub* op) final {
if (op->dtype.is_float())
itervar_map[itervar_stack_.back()].add_ct++;
FeatureVisitor::VisitExpr_(op);
}
void VisitExpr_(const Mul *op) {
void VisitExpr_(const Mul* op) final {
if (op->dtype.is_float())
itervar_map[itervar_stack_.back()].mul_ct++;
FeatureVisitor::VisitExpr_(op);
}
void VisitExpr_(const Div *op) {
void VisitExpr_(const Div* op) final {
if (op->dtype.is_float())
itervar_map[itervar_stack_.back()].div_ct++;
FeatureVisitor::VisitExpr_(op);
}
void VisitExpr_(const Mod *op) {
void VisitExpr_(const Mod* op) final {
if (op->dtype.is_float())
itervar_map[itervar_stack_.back()].div_ct++;
FeatureVisitor::VisitExpr_(op);
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file ir_mutator.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/packed_func_ext.h>
#include "ir_util.h"
namespace tvm {
namespace ir {
IRMutator::FMutateExpr& IRMutator::vtable_expr() { // NOLINT(*)
static FMutateExpr inst; return inst;
}
IRMutator::FMutateStmt& IRMutator::vtable_stmt() { // NOLINT(*)
static FMutateStmt inst; return inst;
}
inline Array<Expr> MutateArray(Array<Expr> arr, IRMutator* m) {
return UpdateArray(arr, [&m](const Expr& e) { return m->Mutate(e); });
}
inline Array<IterVar> MutateIterVarArr(Array<IterVar> rdom, IRMutator* m) {
std::vector<IterVar> new_dom(rdom.size());
bool changed = false;
for (size_t i = 0; i < rdom.size(); i++) {
IterVar v = rdom[i];
Range r = v->dom;
Expr new_min = m->Mutate(r->min);
Expr new_extent = m->Mutate(r->extent);
if (!r->min.same_as(new_min)) changed = true;
if (!r->extent.same_as(new_extent)) changed = true;
new_dom[i] = IterVarNode::make(
Range::make_by_min_extent(new_min, new_extent),
v->var, v->iter_type, v->thread_tag);
}
if (!changed) {
return rdom;
} else {
return Array<IterVar>(new_dom);
}
}
// Mutate Stmt
#define DISPATCH_TO_MUTATE_STMT(OP) \
set_dispatch<OP>([](const ObjectRef& node, const Stmt& s, IRMutator* m) { \
return m->Mutate_(static_cast<const OP*>(node.get()), s); \
})
Stmt IRMutator::Mutate_(const AttrStmt* op, const Stmt& s) {
Expr value = this->Mutate(op->value);
Stmt body = this->Mutate(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return s;
} else {
return AttrStmt::make(op->node, op->attr_key, value, body);
}
}
Stmt IRMutator::Mutate_(const LetStmt* op, const Stmt& s) {
Expr value = this->Mutate(op->value);
Stmt body = this->Mutate(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return s;
} else {
return LetStmt::make(op->var, value, body);
}
}
Stmt IRMutator::Mutate_(const For* op, const Stmt& s) {
Expr min = this->Mutate(op->min);
Expr extent = this->Mutate(op->extent);
Stmt body = this->Mutate(op->body);
if (min.same_as(op->min) &&
extent.same_as(op->extent) &&
body.same_as(op->body)) {
return s;
} else {
return For::make(
op->loop_var, min, extent, op->for_type, op->device_api, body);
}
}
Stmt IRMutator::Mutate_(const Allocate* op, const Stmt& s) {
IRMutator* m = this;
std::vector<Expr> new_extents;
bool all_extents_unmodified = true;
for (size_t i = 0; i < op->extents.size(); i++) {
new_extents.push_back(m->Mutate(op->extents[i]));
all_extents_unmodified &= new_extents[i].same_as(op->extents[i]);
}
Stmt body = m->Mutate(op->body);
Expr condition = m->Mutate(op->condition);
Expr new_expr;
if (op->new_expr.defined()) {
new_expr = m->Mutate(op->new_expr);
}
if (all_extents_unmodified &&
body.same_as(op->body) &&
condition.same_as(op->condition) &&
new_expr.same_as(op->new_expr)) {
return s;
} else {
return Allocate::make(
op->buffer_var, op->dtype,
new_extents, condition, body,
new_expr, op->free_function);
}
}
Stmt IRMutator::Mutate_(const IfThenElse* op, const Stmt& s) {
Expr condition = this->Mutate(op->condition);
Stmt then_case = this->Mutate(op->then_case);
Stmt else_case;
if (op->else_case.defined()) {
else_case = this->Mutate(op->else_case);
}
if (condition.same_as(op->condition) &&
then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
return s;
} else {
return IfThenElse::make(condition, then_case, else_case);
}
}
Stmt IRMutator::Mutate_(const Store* op, const Stmt& s) {
Expr value = this->Mutate(op->value);
Expr index = this->Mutate(op->index);
Expr pred = this->Mutate(op->predicate);
if (value.same_as(op->value) && index.same_as(op->index) && pred.same_as(op->predicate)) {
return s;
} else {
return Store::make(op->buffer_var, value, index, pred);
}
}
Stmt IRMutator::Mutate_(const Provide* op, const Stmt& s) {
auto new_args = MutateArray(op->args, this);
auto new_value = this->Mutate(op->value);
if (op->args.same_as(new_args) && op->value.same_as(new_value)) {
return s;
} else {
return Provide::make(op->func, op->value_index, new_value, new_args);
}
}
Stmt IRMutator::Mutate_(const Realize* op, const Stmt& s) {
IRMutator* m = this;
Region new_bounds;
bool bounds_changed = false;
// Mutate the bounds
for (size_t i = 0; i < op->bounds.size(); i++) {
Expr old_min = op->bounds[i]->min;
Expr old_extent = op->bounds[i]->extent;
Expr new_min = m->Mutate(old_min);
Expr new_extent = m->Mutate(old_extent);
if (!new_min.same_as(old_min)) bounds_changed = true;
if (!new_extent.same_as(old_extent)) bounds_changed = true;
new_bounds.push_back(
Range::make_by_min_extent(new_min, new_extent));
}
Stmt body = m->Mutate(op->body);
Expr condition = m->Mutate(op->condition);
if (!bounds_changed &&
body.same_as(op->body) &&
condition.same_as(op->condition)) {
return s;
} else {
return Realize::make(op->func, op->value_index,
op->dtype, new_bounds,
condition, body);
}
}
Stmt IRMutator::Mutate_(const Prefetch* op, const Stmt& s) {
IRMutator* m = this;
Region new_bounds;
bool bounds_changed = false;
// Mutate the bounds
for (size_t i = 0; i < op->bounds.size(); i++) {
Expr old_min = op->bounds[i]->min;
Expr old_extent = op->bounds[i]->extent;
Expr new_min = m->Mutate(old_min);
Expr new_extent = m->Mutate(old_extent);
if (!new_min.same_as(old_min)) bounds_changed = true;
if (!new_extent.same_as(old_extent)) bounds_changed = true;
new_bounds.push_back(
Range::make_by_min_extent(new_min, new_extent));
}
if (!bounds_changed) {
return s;
} else {
return Prefetch::make(op->func, op->value_index,
op->dtype, new_bounds);
}
}
Stmt IRMutator::Mutate_(const Block* op, const Stmt& s) {
Stmt first = this->Mutate(op->first);
Stmt rest = this->Mutate(op->rest);
if (first.same_as(op->first) &&
rest.same_as(op->rest)) {
return s;
} else {
return Block::make(first, rest);
}
}
Stmt IRMutator::Mutate_(const AssertStmt* op, const Stmt& s) {
Expr condition = this->Mutate(op->condition);
Expr message = this->Mutate(op->message);
Stmt body = this->Mutate(op->body);
if (condition.same_as(op->condition) &&
message.same_as(op->message) &&
body.same_as(op->body)) {
return s;
} else {
return AssertStmt::make(condition, message, body);
}
}
Stmt IRMutator::Mutate_(const ProducerConsumer* op, const Stmt& s) {
Stmt body = this->Mutate(op->body);
if (body.same_as(op->body)) {
return s;
} else {
return ProducerConsumer::make(op->func, op->is_producer, body);
}
}
Stmt IRMutator::Mutate_(const Evaluate* op, const Stmt& s) {
Expr v = this->Mutate(op->value);
if (v.same_as(op->value)) {
return s;
} else {
return Evaluate::make(v);
}
}
Stmt IRMutator::Mutate_(const Free* op, const Stmt& s) {
return s;
}
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
.DISPATCH_TO_MUTATE_STMT(LetStmt)
.DISPATCH_TO_MUTATE_STMT(AttrStmt)
.DISPATCH_TO_MUTATE_STMT(IfThenElse)
.DISPATCH_TO_MUTATE_STMT(For)
.DISPATCH_TO_MUTATE_STMT(Allocate)
.DISPATCH_TO_MUTATE_STMT(Store)
.DISPATCH_TO_MUTATE_STMT(Free)
.DISPATCH_TO_MUTATE_STMT(AssertStmt)
.DISPATCH_TO_MUTATE_STMT(ProducerConsumer)
.DISPATCH_TO_MUTATE_STMT(Provide)
.DISPATCH_TO_MUTATE_STMT(Realize)
.DISPATCH_TO_MUTATE_STMT(Block)
.DISPATCH_TO_MUTATE_STMT(Evaluate)
.DISPATCH_TO_MUTATE_STMT(Prefetch);
// Mutate Expr
#define DISPATCH_TO_MUTATE_EXPR(OP) \
set_dispatch<OP>([](const ObjectRef& node, const Expr& e, IRMutator* m) { \
return m->Mutate_(static_cast<const OP*>(node.get()), e); \
})
Expr IRMutator::Mutate_(const Variable* op, const Expr& e) {
return e;
}
Expr IRMutator::Mutate_(const Load* op, const Expr& e) {
Expr index = this->Mutate(op->index);
Expr pred = this->Mutate(op->predicate);
if (index.same_as(op->index) && pred.same_as(op->predicate)) {
return e;
} else {
return Load::make(op->dtype, op->buffer_var, index, pred);
}
}
Expr IRMutator::Mutate_(const Let* op, const Expr& e) {
Expr value = this->Mutate(op->value);
Expr body = this->Mutate(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return e;
} else {
return Let::make(op->var, value, body);
}
}
Expr IRMutator::Mutate_(const Call* op, const Expr& e) {
auto new_args = MutateArray(op->args, this);
if (op->args.same_as(new_args)) {
return e;
} else {
return Call::make(op->dtype, op->name, new_args, op->call_type,
op->func, op->value_index);
}
}
#define DEFINE_BIOP_EXPR_MUTATE_(OP) \
Expr IRMutator::Mutate_(const OP* op, const Expr& e) { \
Expr a = this->Mutate(op->a); \
Expr b = this->Mutate(op->b); \
if (a.same_as(op->a) && \
b.same_as(op->b)) { \
return e; \
} else { \
return OP::make(a, b); \
} \
}
DEFINE_BIOP_EXPR_MUTATE_(Add)
DEFINE_BIOP_EXPR_MUTATE_(Sub)
DEFINE_BIOP_EXPR_MUTATE_(Mul)
DEFINE_BIOP_EXPR_MUTATE_(Div)
DEFINE_BIOP_EXPR_MUTATE_(Mod)
DEFINE_BIOP_EXPR_MUTATE_(FloorDiv)
DEFINE_BIOP_EXPR_MUTATE_(FloorMod)
DEFINE_BIOP_EXPR_MUTATE_(Min)
DEFINE_BIOP_EXPR_MUTATE_(Max)
DEFINE_BIOP_EXPR_MUTATE_(EQ)
DEFINE_BIOP_EXPR_MUTATE_(NE)
DEFINE_BIOP_EXPR_MUTATE_(LT)
DEFINE_BIOP_EXPR_MUTATE_(LE)
DEFINE_BIOP_EXPR_MUTATE_(GT)
DEFINE_BIOP_EXPR_MUTATE_(GE)
DEFINE_BIOP_EXPR_MUTATE_(And)
DEFINE_BIOP_EXPR_MUTATE_(Or)
Expr IRMutator::Mutate_(const Reduce* op, const Expr& e) {
Array<IterVar> new_axis = MutateIterVarArr(op->axis, this);
Array<Expr> new_source = MutateArray(op->source, this);
Expr new_cond = this->Mutate(op->condition);
if (op->axis.same_as(new_axis) &&
op->source.same_as(new_source) &&
op->condition.same_as(new_cond)) {
return e;
} else {
return Reduce::make(
op->combiner, new_source, new_axis, new_cond, op->value_index);
}
}
Expr IRMutator::Mutate_(const Cast* op, const Expr& e) {
Expr value = this->Mutate(op->value);
if (value.same_as(op->value)) {
return e;
} else {
return Cast::make(op->dtype, value);
}
}
Expr IRMutator::Mutate_(const Not* op, const Expr& e) {
Expr a = this->Mutate(op->a);
if (a.same_as(op->a)) {
return e;
} else {
return Not::make(a);
}
}
Expr IRMutator::Mutate_(const Select* op, const Expr& e) {
Expr cond = this->Mutate(op->condition);
Expr t = this->Mutate(op->true_value);
Expr f = this->Mutate(op->false_value);
if (cond.same_as(op->condition) &&
t.same_as(op->true_value) &&
f.same_as(op->false_value)) {
return e;
} else {
return Select::make(cond, t, f);
}
}
Expr IRMutator::Mutate_(const Ramp* op, const Expr& e) {
Expr base = this->Mutate(op->base);
Expr stride = this->Mutate(op->stride);
if (base.same_as(op->base) &&
stride.same_as(op->stride)) {
return e;
} else {
return Ramp::make(base, stride, op->lanes);
}
}
Expr IRMutator::Mutate_(const Broadcast* op, const Expr& e) {
Expr value = this->Mutate(op->value);
if (value.same_as(op->value)) {
return e;
} else {
return Broadcast::make(value, op->lanes);
}
}
Expr IRMutator::Mutate_(const Shuffle* op, const Expr& e) {
auto new_vec = MutateArray(op->vectors, this);
if (new_vec.same_as(op->vectors)) {
return e;
} else {
return Shuffle::make(new_vec, op->indices);
}
}
#define DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(OP) \
Expr IRMutator::Mutate_(const OP *op, const Expr& e) { \
return e; \
}
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImm)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(UIntImm)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(FloatImm)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImm)
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.DISPATCH_TO_MUTATE_EXPR(Variable)
.DISPATCH_TO_MUTATE_EXPR(Load)
.DISPATCH_TO_MUTATE_EXPR(Let)
.DISPATCH_TO_MUTATE_EXPR(Call)
.DISPATCH_TO_MUTATE_EXPR(Add)
.DISPATCH_TO_MUTATE_EXPR(Sub)
.DISPATCH_TO_MUTATE_EXPR(Mul)
.DISPATCH_TO_MUTATE_EXPR(Div)
.DISPATCH_TO_MUTATE_EXPR(Mod)
.DISPATCH_TO_MUTATE_EXPR(FloorDiv)
.DISPATCH_TO_MUTATE_EXPR(FloorMod)
.DISPATCH_TO_MUTATE_EXPR(Min)
.DISPATCH_TO_MUTATE_EXPR(Max)
.DISPATCH_TO_MUTATE_EXPR(EQ)
.DISPATCH_TO_MUTATE_EXPR(NE)
.DISPATCH_TO_MUTATE_EXPR(LT)
.DISPATCH_TO_MUTATE_EXPR(LE)
.DISPATCH_TO_MUTATE_EXPR(GT)
.DISPATCH_TO_MUTATE_EXPR(GE)
.DISPATCH_TO_MUTATE_EXPR(And)
.DISPATCH_TO_MUTATE_EXPR(Or)
.DISPATCH_TO_MUTATE_EXPR(Reduce)
.DISPATCH_TO_MUTATE_EXPR(Cast)
.DISPATCH_TO_MUTATE_EXPR(Not)
.DISPATCH_TO_MUTATE_EXPR(Select)
.DISPATCH_TO_MUTATE_EXPR(Ramp)
.DISPATCH_TO_MUTATE_EXPR(Broadcast)
.DISPATCH_TO_MUTATE_EXPR(IntImm)
.DISPATCH_TO_MUTATE_EXPR(UIntImm)
.DISPATCH_TO_MUTATE_EXPR(FloatImm)
.DISPATCH_TO_MUTATE_EXPR(StringImm)
.DISPATCH_TO_MUTATE_EXPR(Shuffle);
} // namespace ir
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file ir_visitor.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <unordered_set>
namespace tvm {
namespace ir {
IRVisitor::FVisit& IRVisitor::vtable() { // NOLINT(*)
static FVisit inst; return inst;
}
inline void VisitArray(const Array<Expr>& arr, IRVisitor* v) {
for (size_t i = 0; i < arr.size(); i++) {
v->Visit(arr[i]);
}
}
inline void VisitRDom(const Array<IterVar>& rdom, IRVisitor* v) {
for (size_t i = 0; i < rdom.size(); i++) {
Range r = rdom[i]->dom;
v->Visit(r->min);
v->Visit(r->extent);
}
}
void IRVisitor::Visit_(const Variable* op) {}
void IRVisitor::Visit_(const LetStmt* op) {
this->Visit(op->value);
this->Visit(op->body);
}
void IRVisitor::Visit_(const AttrStmt* op) {
this->Visit(op->value);
this->Visit(op->body);
}
void IRVisitor::Visit_(const For* op) {
IRVisitor* v = this;
v->Visit(op->min);
v->Visit(op->extent);
v->Visit(op->body);
}
void IRVisitor::Visit_(const Allocate* op) {
IRVisitor* v = this;
for (size_t i = 0; i < op->extents.size(); i++) {
v->Visit(op->extents[i]);
}
v->Visit(op->body);
v->Visit(op->condition);
if (op->new_expr.defined()) {
v->Visit(op->new_expr);
}
}
void IRVisitor::Visit_(const Load* op) {
this->Visit(op->index);
this->Visit(op->predicate);
}
void IRVisitor::Visit_(const Store* op) {
this->Visit(op->value);
this->Visit(op->index);
this->Visit(op->predicate);
}
void IRVisitor::Visit_(const IfThenElse* op) {
this->Visit(op->condition);
this->Visit(op->then_case);
if (op->else_case.defined()) {
this->Visit(op->else_case);
}
}
void IRVisitor::Visit_(const Let* op) {
this->Visit(op->value);
this->Visit(op->body);
}
void IRVisitor::Visit_(const Free* op) {}
void IRVisitor::Visit_(const Call* op) {
VisitArray(op->args, this);
}
#define DEFINE_BINOP_VISIT_(OP) \
void IRVisitor::Visit_(const OP* op) { \
this->Visit(op->a); \
this->Visit(op->b); \
}
DEFINE_BINOP_VISIT_(Add)
DEFINE_BINOP_VISIT_(Sub)
DEFINE_BINOP_VISIT_(Mul)
DEFINE_BINOP_VISIT_(Div)
DEFINE_BINOP_VISIT_(Mod)
DEFINE_BINOP_VISIT_(FloorDiv)
DEFINE_BINOP_VISIT_(FloorMod)
DEFINE_BINOP_VISIT_(Min)
DEFINE_BINOP_VISIT_(Max)
DEFINE_BINOP_VISIT_(EQ)
DEFINE_BINOP_VISIT_(NE)
DEFINE_BINOP_VISIT_(LT)
DEFINE_BINOP_VISIT_(LE)
DEFINE_BINOP_VISIT_(GT)
DEFINE_BINOP_VISIT_(GE)
DEFINE_BINOP_VISIT_(And)
DEFINE_BINOP_VISIT_(Or)
void IRVisitor::Visit_(const Reduce* op) {
VisitRDom(op->axis, this);
VisitArray(op->source, this);
this->Visit(op->condition);
}
void IRVisitor::Visit_(const Cast* op) {
this->Visit(op->value);
}
void IRVisitor::Visit_(const Not* op) {
this->Visit(op->a);
}
void IRVisitor::Visit_(const Select* op) {
this->Visit(op->condition);
this->Visit(op->true_value);
this->Visit(op->false_value);
}
void IRVisitor::Visit_(const Ramp* op) {
this->Visit(op->base);
this->Visit(op->stride);
}
void IRVisitor::Visit_(const Shuffle* op) {
for (const auto& elem : op->indices)
this->Visit(elem);
for (const auto& elem : op->vectors)
this->Visit(elem);
}
void IRVisitor::Visit_(const Broadcast* op) {
this->Visit(op->value);
}
void IRVisitor::Visit_(const AssertStmt* op) {
this->Visit(op->condition);
this->Visit(op->message);
this->Visit(op->body);
}
void IRVisitor::Visit_(const ProducerConsumer* op) {
this->Visit(op->body);
}
void IRVisitor::Visit_(const Provide* op) {
VisitArray(op->args, this);
this->Visit(op->value);
}
void IRVisitor::Visit_(const Realize* op) {
for (size_t i = 0; i < op->bounds.size(); i++) {
this->Visit(op->bounds[i]->min);
this->Visit(op->bounds[i]->extent);
}
this->Visit(op->body);
this->Visit(op->condition);
}
void IRVisitor::Visit_(const Prefetch* op) {
for (size_t i = 0; i < op->bounds.size(); i++) {
this->Visit(op->bounds[i]->min);
this->Visit(op->bounds[i]->extent);
}
}
void IRVisitor::Visit_(const Block* op) {
this->Visit(op->first);
this->Visit(op->rest);
}
void IRVisitor::Visit_(const Evaluate* op) {
this->Visit(op->value);
}
#define DEFINE_OP_NO_VISIT_(OP) \
void IRVisitor::Visit_(const OP* op) {}
DEFINE_OP_NO_VISIT_(IntImm)
DEFINE_OP_NO_VISIT_(UIntImm)
DEFINE_OP_NO_VISIT_(FloatImm)
DEFINE_OP_NO_VISIT_(StringImm)
#define DISPATCH_TO_VISIT(OP) \
set_dispatch<OP>([](const ObjectRef& node, IRVisitor* v) { \
v->Visit_(static_cast<const OP*>(node.get())); \
})
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.DISPATCH_TO_VISIT(Variable)
.DISPATCH_TO_VISIT(LetStmt)
.DISPATCH_TO_VISIT(AttrStmt)
.DISPATCH_TO_VISIT(IfThenElse)
.DISPATCH_TO_VISIT(For)
.DISPATCH_TO_VISIT(Allocate)
.DISPATCH_TO_VISIT(Load)
.DISPATCH_TO_VISIT(Store)
.DISPATCH_TO_VISIT(Let)
.DISPATCH_TO_VISIT(Free)
.DISPATCH_TO_VISIT(Call)
.DISPATCH_TO_VISIT(Add)
.DISPATCH_TO_VISIT(Sub)
.DISPATCH_TO_VISIT(Mul)
.DISPATCH_TO_VISIT(Div)
.DISPATCH_TO_VISIT(Mod)
.DISPATCH_TO_VISIT(FloorDiv)
.DISPATCH_TO_VISIT(FloorMod)
.DISPATCH_TO_VISIT(Min)
.DISPATCH_TO_VISIT(Max)
.DISPATCH_TO_VISIT(EQ)
.DISPATCH_TO_VISIT(NE)
.DISPATCH_TO_VISIT(LT)
.DISPATCH_TO_VISIT(LE)
.DISPATCH_TO_VISIT(GT)
.DISPATCH_TO_VISIT(GE)
.DISPATCH_TO_VISIT(And)
.DISPATCH_TO_VISIT(Or)
.DISPATCH_TO_VISIT(Reduce)
.DISPATCH_TO_VISIT(Cast)
.DISPATCH_TO_VISIT(Not)
.DISPATCH_TO_VISIT(Select)
.DISPATCH_TO_VISIT(Ramp)
.DISPATCH_TO_VISIT(Shuffle)
.DISPATCH_TO_VISIT(Broadcast)
.DISPATCH_TO_VISIT(AssertStmt)
.DISPATCH_TO_VISIT(ProducerConsumer)
.DISPATCH_TO_VISIT(Provide)
.DISPATCH_TO_VISIT(Realize)
.DISPATCH_TO_VISIT(Block)
.DISPATCH_TO_VISIT(Evaluate)
.DISPATCH_TO_VISIT(IntImm)
.DISPATCH_TO_VISIT(UIntImm)
.DISPATCH_TO_VISIT(FloatImm)
.DISPATCH_TO_VISIT(StringImm)
.DISPATCH_TO_VISIT(Prefetch);
} // namespace ir
} // namespace tvm
......@@ -41,6 +41,19 @@ TEST(IRF, Basic) {
CHECK_EQ(f(z, 2), 4);
}
TEST(IRF, CountVar) {
using namespace tvm;
int n_var = 0;
Var x("x"), y;
auto z = x + 1 + y + y;
ir::PostOrderVisit(z, [&n_var](const ObjectRef& n) {
if (n.as<Variable>()) ++n_var;
});
CHECK_EQ(n_var, 2);
}
TEST(IRF, ExprTransform) {
using namespace tvm;
using namespace tvm::ir;
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/ir_mutator.h>
#include <tvm/expr_operator.h>
namespace {
using namespace tvm;
using namespace tvm::ir;
// replace variable to constant
class IRVar2Const : public IRMutator {
public:
Var var;
int int_val;
Expr Mutate(Expr expr) final {
static const FMutateExpr& f = IRVar2Const::vtable_expr();
return (f.can_dispatch(expr) ?
f(expr, expr, this) : IRMutator::Mutate(expr));
}
static FMutateExpr &vtable_expr();
};
// implement vtable
IRMutator::FMutateExpr &IRVar2Const::vtable_expr() { // NOLINT(*)
static FMutateExpr inst; return inst;
}
TVM_STATIC_IR_FUNCTOR(IRVar2Const, vtable_expr)
.set_dispatch<Variable>([](const ObjectRef& ref, const Expr &e, IRMutator* m) {
IRVar2Const* vm = static_cast<IRVar2Const*>(m);
if (e.same_as(vm->var)) {
return Expr(IntImm::make(DataType::Int(32), vm->int_val));
} else {
return e;
}
});
} // namespace
TEST(IRMutator, Basic) {
using namespace tvm::ir;
using namespace tvm;
Var x("x"), y;
auto z = x + y;
IRVar2Const mu;
mu.var = y;
mu.int_val = 10;
auto zz = mu.Mutate(z);
std::ostringstream os;
os << zz;
CHECK(os.str() == "(x + 10)");
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_pass.h>
TEST(IRVisitor, CountVar) {
using namespace tvm;
int n_var = 0;
Var x("x"), y;
auto z = x + 1 + y + y;
ir::PostOrderVisit(z, [&n_var](const ObjectRef& n) {
if (n.as<Variable>()) ++n_var;
});
CHECK_EQ(n_var, 2);
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment