Unverified Commit 983eba88 by Tianqi Chen Committed by GitHub

[IR] Unify approach to Visitor/Mutator under Functor (#4606)

IRMutator and IRVisitor were the main data structures for doing low level IR visiting.
As the project evolves, we start to introduce more powerful variants such as StmtFunctor and ExprFunctor.
This PR brings new classes that allows us to migrate the visitor mutator to be sub-class of these functors.

List of changes:

- Create separate class for ExprMutator and StmtMutator, following convention used in relay.
- Introduce copy-on-write to StmtMutator that can later benefit the statement mutations
  if we use move semantics and keep a single copy of stmt.
- Move two generic visit mutate util to use the new classes.

We will send followup PRs to migrate the existing passes that use the legacy visitors
to the new one.
parent 1ef1605a
......@@ -287,6 +287,265 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
#undef EXPR_FUNCTOR_DEFAULT
#undef STMT_FUNCTOR_DEFAULT
/*!
* \brief ExprVisitor
*/
class TVM_DLL ExprVisitor :
public ExprFunctor<void(const Expr&)> {
public:
using ExprFunctor::operator();
protected:
using ExprFunctor::VisitExpr;
// list of functions to override.
void VisitExpr_(const Variable* op) override;
void VisitExpr_(const Load* op) override;
void VisitExpr_(const Let* op) override;
void VisitExpr_(const Call* op) override;
void VisitExpr_(const Add* op) override;
void VisitExpr_(const Sub* op) override;
void VisitExpr_(const Mul* op) override;
void VisitExpr_(const Div* op) override;
void VisitExpr_(const Mod* op) override;
void VisitExpr_(const FloorDiv* op) override;
void VisitExpr_(const FloorMod* op) override;
void VisitExpr_(const Min* op) override;
void VisitExpr_(const Max* op) override;
void VisitExpr_(const EQ* op) override;
void VisitExpr_(const NE* op) override;
void VisitExpr_(const LT* op) override;
void VisitExpr_(const LE* op) override;
void VisitExpr_(const GT* op) override;
void VisitExpr_(const GE* op) override;
void VisitExpr_(const And* op) override;
void VisitExpr_(const Or* op) override;
void VisitExpr_(const Reduce* op) override;
void VisitExpr_(const Cast* op) override;
void VisitExpr_(const Not* op) override;
void VisitExpr_(const Select* op) override;
void VisitExpr_(const Ramp* op) override;
void VisitExpr_(const Broadcast* op) override;
void VisitExpr_(const Shuffle* op) override;
void VisitExpr_(const IntImm* op) override;
void VisitExpr_(const UIntImm* op) override;
void VisitExpr_(const FloatImm* op) override;
void VisitExpr_(const StringImm* op) override;
};
/*!
* \brief ExprMutator that mutates expressions.
*/
class TVM_DLL ExprMutator :
protected ExprFunctor<Expr(const Expr&)> {
public:
using ExprFunctor::operator();
protected:
using ExprFunctor::VisitExpr;
// list of functions to override.
Expr VisitExpr_(const Variable* op) override;
Expr VisitExpr_(const Load* op) override;
Expr VisitExpr_(const Let* op) override;
Expr VisitExpr_(const Call* op) override;
Expr VisitExpr_(const Add* op) override;
Expr VisitExpr_(const Sub* op) override;
Expr VisitExpr_(const Mul* op) override;
Expr VisitExpr_(const Div* op) override;
Expr VisitExpr_(const Mod* op) override;
Expr VisitExpr_(const FloorDiv* op) override;
Expr VisitExpr_(const FloorMod* op) override;
Expr VisitExpr_(const Min* op) override;
Expr VisitExpr_(const Max* op) override;
Expr VisitExpr_(const EQ* op) override;
Expr VisitExpr_(const NE* op) override;
Expr VisitExpr_(const LT* op) override;
Expr VisitExpr_(const LE* op) override;
Expr VisitExpr_(const GT* op) override;
Expr VisitExpr_(const GE* op) override;
Expr VisitExpr_(const And* op) override;
Expr VisitExpr_(const Or* op) override;
Expr VisitExpr_(const Reduce* op) override;
Expr VisitExpr_(const Cast* op) override;
Expr VisitExpr_(const Not* op) override;
Expr VisitExpr_(const Select* op) override;
Expr VisitExpr_(const Ramp* op) override;
Expr VisitExpr_(const Broadcast* op) override;
Expr VisitExpr_(const Shuffle* op) override;
Expr VisitExpr_(const IntImm* op) override;
Expr VisitExpr_(const UIntImm* op) override;
Expr VisitExpr_(const FloatImm* op) override;
Expr VisitExpr_(const StringImm* op) override;
};
/*!
* \brief StmtVisitor.
*/
class TVM_DLL StmtVisitor :
protected StmtFunctor<void(const Stmt&)> {
public:
using StmtFunctor::operator();
protected:
using StmtFunctor::VisitStmt;
/*!
* \brief Visitor to Exprs, can be overriden
* to do recursive changes to Exprs.
* \note A common pattern is to call ExprVisitor here,
* or have a class sub-class both StmtVisitor and ExprVisitor
* and redirect Visit to ExprMutator::VisitExpr(Expr)
*/
virtual void VisitExpr(const Expr& e) {}
// statement visitor
void VisitStmt_(const AttrStmt* op) override;
void VisitStmt_(const IfThenElse* op) override;
void VisitStmt_(const LetStmt* op) override;
void VisitStmt_(const For* op) override;
void VisitStmt_(const Allocate* op) override;
void VisitStmt_(const Store* op) override;
void VisitStmt_(const Free* op) override;
void VisitStmt_(const AssertStmt* op) override;
void VisitStmt_(const ProducerConsumer* op) override;
void VisitStmt_(const Provide* op) override;
void VisitStmt_(const Realize* op) override;
void VisitStmt_(const Prefetch* op) override;
void VisitStmt_(const Block* op) override;
void VisitStmt_(const Evaluate* op) override;
};
/*!
* \brief StmtMutator that mutates the statements.
*/
class TVM_DLL StmtMutator :
protected StmtFunctor<Stmt(const Stmt&)> {
public:
/*!
* \brief Mutate stmt.
* \param stmt The input statement to be mutated.
* \return The result of the call
* \note It is important that stmt is passed by value.
* so copy on write can be triggered correctly.
* do mutator(std::move(stmt)) or when copy elison is triggered.
*/
Stmt operator()(Stmt stmt) {
allow_copy_on_write_ = true;
return VisitStmt(stmt);
}
protected:
// We perform copy on write optimizations on the StmtMutator
// so that an unique copy of parent can be mutated inplace
// when some of its children changed.
// We only do such optimization for Stmt nests(instead of Exprs) for now
// as Stmt's parent state is more likely remain unchanged when one of
// its child block changes.
/*!
* \brief Internal state to indicate whether copy on write is enabled.
* COW is enabled iff all the parents of the node are unique.
*/
bool allow_copy_on_write_{false};
/*!
* \brief Perform copy on write on node.
*
* If CopyOnWrite is allowed, directly return
* a strong reference to the node container.
* Otherwise, return a copy of the node.
*
* \return The result object pointer.
*/
template<typename TNode>
ObjectPtr<TNode> CopyOnWrite(const TNode* node) {
if (allow_copy_on_write_) {
// return the old node.
return runtime::GetObjectPtr<TNode>(const_cast<TNode*>(node));
} else {
// Make a new copy of the node.
// need to rely on the default copy constructor
return runtime::make_object<TNode>(*node);
}
}
/*!
* \brief Internal mutator that everyone calls.
* \note To override mutate's behavior, override VisitExpr instead.
* \param stmt The input stmt.
* \return The mutated results.
*/
Stmt VisitStmt(const Stmt& stmt) override {
if (allow_copy_on_write_ && !stmt.unique()) {
allow_copy_on_write_ = false;
Stmt ret = StmtFunctor::VisitStmt(stmt);
allow_copy_on_write_ = true;
return ret;
} else {
return StmtFunctor::VisitStmt(stmt);
}
}
/*!
* \brief Visitor to Exprs, can be overriden
* to do recursive changes to Exprs.
* \note A common pattern is to call ExprMutator here,
* or have a class sub-class both StmtMutator and ExprMutator
* and redirect Mutate to ExprMutator::Mutate(Expr)
*/
virtual Expr VisitExpr(const Expr& e) {
return e;
}
// statement visitor
Stmt VisitStmt_(const AttrStmt* op) override;
Stmt VisitStmt_(const IfThenElse* op) override;
Stmt VisitStmt_(const LetStmt* op) override;
Stmt VisitStmt_(const For* op) override;
Stmt VisitStmt_(const Allocate* op) override;
Stmt VisitStmt_(const Store* op) override;
Stmt VisitStmt_(const Free* op) override;
Stmt VisitStmt_(const AssertStmt* op) override;
Stmt VisitStmt_(const ProducerConsumer* op) override;
Stmt VisitStmt_(const Provide* op) override;
Stmt VisitStmt_(const Realize* op) override;
Stmt VisitStmt_(const Prefetch* op) override;
Stmt VisitStmt_(const Block* op) override;
Stmt VisitStmt_(const Evaluate* op) override;
// internal helper.
class Internal;
};
/*!
* \brief Visitor that recursively visit stmts and exprs on them.
*/
class StmtExprVisitor :
public StmtVisitor,
public ExprVisitor {
public:
using StmtVisitor::operator();
using ExprVisitor::operator();
protected:
using StmtVisitor::VisitStmt;
using ExprVisitor::VisitExpr;
void VisitExpr(const Expr& e) override {
return ExprVisitor::VisitExpr(e);
}
};
/*!
* \brief Mutator that recursively mutates stmts and exprs on them.
*/
class StmtExprMutator :
public StmtMutator,
public ExprMutator {
public:
using StmtMutator::operator();
using ExprMutator::operator();
protected:
using StmtMutator::VisitExpr;
using ExprMutator::VisitExpr;
Expr VisitExpr(const Expr& e) override {
return ExprMutator::VisitExpr(e);
}
};
} // namespace ir
} // namespace tvm
#endif // TVM_IR_FUNCTOR_EXT_H_
......@@ -123,6 +123,7 @@ class TVM_DLL IRMutator {
virtual Expr Mutate_(const Shuffle* op, const Expr& e);
};
/*!
* \brief recursively visit the ir in post DFS order node, and transform it
*
......@@ -138,7 +139,7 @@ class TVM_DLL IRMutator {
* If it is not empty, preorder/postorder will only be called
* when the IRNode's type key is in the list.
*/
Stmt IRTransform(const Stmt& node,
Stmt IRTransform(Stmt node,
const runtime::PackedFunc& preorder,
const runtime::PackedFunc& postorder,
const Array<Expr>& only_enable = {});
......
......@@ -284,6 +284,48 @@ class Array : public ObjectRef {
inline bool empty() const {
return size() == 0;
}
/*!
* \brief Helper function to apply fmutate to mutate an array.
* \param fmutate The transformation function T -> T.
* \tparam F the type of the mutation function.
* \note This function performs copy on write optimization.
*/
template<typename F>
inline void MutateByApply(F fmutate) {
ArrayNode* ptr = static_cast<ArrayNode*>(data_.get());
if (ptr == nullptr) return;
if (data_.unique()) {
// Copy on write optimization.
// Perform inplace update because this is an unique copy.
for (size_t i = 0; i < ptr->data.size(); ++i) {
// It is important to use move here
// to make prevent the element's ref count from increasing
// so fmutate itself can perform copy-on-write optimization
T old_elem = DowncastNoCheck<T>(std::move(ptr->data[i]));
T new_elem = fmutate(std::move(old_elem));
ptr->data[i] = std::move(new_elem);
}
} else {
// lazily trigger copy if there is element change.
ObjectPtr<ArrayNode> copy;
for (size_t i = 0; i < ptr->data.size(); ++i) {
T old_elem = DowncastNoCheck<T>(ptr->data[i]);
T new_elem = fmutate(old_elem);
if (!new_elem.same_as(ptr->data[i])) {
// copy the old array
if (copy == nullptr) {
copy = runtime::make_object<ArrayNode>(*ptr);
}
copy->data[i] = std::move(new_elem);
}
}
// replace the data with the new copy.
if (copy != nullptr) {
data_ = std::move(copy);
}
}
}
/*! \brief specify container node */
using ContainerType = ArrayNode;
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file ir_functor.cc
*/
#include <tvm/ir_functor_ext.h>
#include <tvm/packed_func_ext.h>
#include <tvm/ir_visitor.h>
namespace tvm {
namespace ir {
// visitor to implement apply
class IRApplyVisit :
public StmtExprVisitor {
public:
explicit IRApplyVisit(std::function<void(const ObjectRef&)> f) : f_(f) {}
void VisitExpr(const Expr& node) final {
if (visited_.count(node.get()) != 0) return;
visited_.insert(node.get());
ExprVisitor::VisitExpr(node);
f_(node);
}
void VisitStmt(const Stmt& node) final {
if (visited_.count(node.get()) != 0) return;
visited_.insert(node.get());
StmtVisitor::VisitStmt(node);
f_(node);
}
private:
std::function<void(const ObjectRef&)> f_;
std::unordered_set<const Object*> visited_;
};
void PostOrderVisit(const ObjectRef& node,
std::function<void(const ObjectRef&)> fvisit) {
if (node.as<StmtNode>()) {
IRApplyVisit visitor(fvisit);
visitor(Downcast<Stmt>(node));
} else {
IRApplyVisit visitor(fvisit);
visitor(Downcast<Expr>(node));
}
}
class IRTransformer final :
public StmtExprMutator {
public:
IRTransformer(const runtime::PackedFunc& f_preorder,
const runtime::PackedFunc& f_postorder,
const std::unordered_set<uint32_t>& only_enable)
: f_preorder_(f_preorder),
f_postorder_(f_postorder),
only_enable_(only_enable) {
}
Stmt VisitStmt(const Stmt& stmt) final {
return MutateInternal<Stmt>(stmt, [this](const Stmt& s) {
return StmtMutator::VisitStmt(s);
});
}
Expr VisitExpr(const Expr& expr) final {
return MutateInternal<Expr>(expr, [this](const Expr& e) {
return ExprMutator::VisitExpr(e);
});
}
private:
template <typename T, typename F>
T MutateInternal(const T& node, F fmutate) {
if (only_enable_.size() &&
!only_enable_.count(node->type_index())) {
return fmutate(node);
}
if (f_preorder_ != nullptr) {
T pre = f_preorder_(node);
if (pre.defined()) return pre;
}
T new_node = fmutate(node);
if (f_postorder_ != nullptr) {
T post = f_postorder_(new_node);
if (post.defined()) return post;
}
return new_node;
}
// The functions
const runtime::PackedFunc& f_preorder_;
const runtime::PackedFunc& f_postorder_;
// type indices enabled.
const std::unordered_set<uint32_t>& only_enable_;
};
Stmt IRTransform(Stmt ir_node,
const runtime::PackedFunc& f_preorder,
const runtime::PackedFunc& f_postorder,
const Array<Expr>& only_enable) {
std::unordered_set<uint32_t> only_type_index;
for (Expr s : only_enable) {
only_type_index.insert(Object::TypeKey2Index(s.as<StringImm>()->value.c_str()));
}
IRTransformer transform(f_preorder, f_postorder, only_type_index);
return transform(std::move(ir_node));
}
// Implementation of Visitors
template<typename T, typename F>
inline void VisitArray(const Array<T>& arr, F fvisit) {
for (size_t i = 0; i < arr.size(); i++) {
fvisit(arr[i]);
}
}
void StmtVisitor::VisitStmt_(const LetStmt* op) {
this->VisitExpr(op->value);
this->VisitStmt(op->body);
}
void StmtVisitor::VisitStmt_(const AttrStmt* op) {
this->VisitExpr(op->value);
this->VisitStmt(op->body);
}
void StmtVisitor::VisitStmt_(const For* op) {
this->VisitExpr(op->min);
this->VisitExpr(op->extent);
this->VisitStmt(op->body);
}
void StmtVisitor::VisitStmt_(const Allocate* op) {
VisitArray(op->extents, [this](const Expr& e) { this->VisitExpr(e); });
this->VisitStmt(op->body);
this->VisitExpr(op->condition);
if (op->new_expr.defined()) {
this->VisitExpr(op->new_expr);
}
}
void StmtVisitor::VisitStmt_(const Store* op) {
this->VisitExpr(op->value);
this->VisitExpr(op->index);
this->VisitExpr(op->predicate);
}
void StmtVisitor::VisitStmt_(const IfThenElse* op) {
this->VisitExpr(op->condition);
this->VisitStmt(op->then_case);
if (op->else_case.defined()) {
this->VisitStmt(op->else_case);
}
}
void StmtVisitor::VisitStmt_(const Free* op) {}
void StmtVisitor::VisitStmt_(const AssertStmt* op) {
this->VisitExpr(op->condition);
this->VisitExpr(op->message);
this->VisitStmt(op->body);
}
void StmtVisitor::VisitStmt_(const ProducerConsumer* op) {
this->VisitStmt(op->body);
}
void StmtVisitor::VisitStmt_(const Provide* op) {
VisitArray(op->args, [this](const Expr& e) { this->VisitExpr(e); });
this->VisitExpr(op->value);
}
void StmtVisitor::VisitStmt_(const Realize* op) {
VisitArray(op->bounds, [this](const Range& r) {
this->VisitExpr(r->min);
this->VisitExpr(r->extent);
});
this->VisitStmt(op->body);
this->VisitExpr(op->condition);
}
void StmtVisitor::VisitStmt_(const Prefetch* op) {
VisitArray(op->bounds, [this](const Range& r) {
this->VisitExpr(r->min);
this->VisitExpr(r->extent);
});
}
void StmtVisitor::VisitStmt_(const Block* op) {
this->VisitStmt(op->first);
this->VisitStmt(op->rest);
}
void StmtVisitor::VisitStmt_(const Evaluate* op) {
this->VisitExpr(op->value);
}
void ExprVisitor::VisitExpr_(const Variable* op) {}
void ExprVisitor::VisitExpr_(const Load* op) {
this->VisitExpr(op->index);
this->VisitExpr(op->predicate);
}
void ExprVisitor::VisitExpr_(const Let* op) {
this->VisitExpr(op->value);
this->VisitExpr(op->body);
}
void ExprVisitor::VisitExpr_(const Call* op) {
VisitArray(op->args, [this](const Expr& e) { this->VisitExpr(e); });
}
#define DEFINE_BINOP_VISIT_(OP) \
void ExprVisitor::VisitExpr_(const OP* op) { \
this->VisitExpr(op->a); \
this->VisitExpr(op->b); \
}
DEFINE_BINOP_VISIT_(Add);
DEFINE_BINOP_VISIT_(Sub);
DEFINE_BINOP_VISIT_(Mul);
DEFINE_BINOP_VISIT_(Div);
DEFINE_BINOP_VISIT_(Mod);
DEFINE_BINOP_VISIT_(FloorDiv);
DEFINE_BINOP_VISIT_(FloorMod);
DEFINE_BINOP_VISIT_(Min);
DEFINE_BINOP_VISIT_(Max);
DEFINE_BINOP_VISIT_(EQ);
DEFINE_BINOP_VISIT_(NE);
DEFINE_BINOP_VISIT_(LT);
DEFINE_BINOP_VISIT_(LE);
DEFINE_BINOP_VISIT_(GT);
DEFINE_BINOP_VISIT_(GE);
DEFINE_BINOP_VISIT_(And);
DEFINE_BINOP_VISIT_(Or);
void ExprVisitor::VisitExpr_(const IntImm* op) {}
void ExprVisitor::VisitExpr_(const UIntImm* op) {}
void ExprVisitor::VisitExpr_(const FloatImm* op) {}
void ExprVisitor::VisitExpr_(const StringImm* op) {}
void ExprVisitor::VisitExpr_(const Reduce* op) {
VisitArray(op->axis, [this](const IterVar& r) {
this->VisitExpr(r->dom->min);
this->VisitExpr(r->dom->extent);
});
VisitArray(op->source, [this](const Expr& e) { this->VisitExpr(e); });
this->VisitExpr(op->condition);
}
void ExprVisitor::VisitExpr_(const Cast* op) {
this->VisitExpr(op->value);
}
void ExprVisitor::VisitExpr_(const Not* op) {
this->VisitExpr(op->a);
}
void ExprVisitor::VisitExpr_(const Select* op) {
this->VisitExpr(op->condition);
this->VisitExpr(op->true_value);
this->VisitExpr(op->false_value);
}
void ExprVisitor::VisitExpr_(const Ramp* op) {
this->VisitExpr(op->base);
this->VisitExpr(op->stride);
}
void ExprVisitor::VisitExpr_(const Shuffle* op) {
VisitArray(op->indices, [this](const Expr& e) { this->VisitExpr(e); });
VisitArray(op->vectors, [this](const Expr& e) { this->VisitExpr(e); });
}
void ExprVisitor::VisitExpr_(const Broadcast* op) {
this->VisitExpr(op->value);
}
// Implementation of mutators
template<typename T, typename F>
inline Array<T> MutateArray(const Array<T>& arr,
F fmutate,
bool allow_copy_on_write = false) {
if (allow_copy_on_write) {
// if we allow copy on write, we can directly
// call the inplace mutate function.
const_cast<Array<T>&>(arr).MutateByApply(fmutate);
return arr;
} else {
Array<T> copy = arr;
copy.MutateByApply(fmutate);
return copy;
}
}
class StmtMutator::Internal {
public:
static Array<Expr> Mutate(StmtMutator* self, const Array<Expr>& arr) {
auto fmutate = [self](const Expr& e) { return self->VisitExpr(e); };
return MutateArray(arr, fmutate, self->allow_copy_on_write_);
}
static Array<Stmt> Mutate(StmtMutator* self, const Array<Stmt>& arr) {
auto fmutate = [self](const Stmt& s) { return self->VisitStmt(s); };
return MutateArray(arr, fmutate, self->allow_copy_on_write_);
}
static Array<Range> Mutate(StmtMutator* self, const Array<Range>& arr) {
auto fmutate = [self](const Range& r) {
Expr min = self->VisitExpr(r->min);
Expr extent = self->VisitExpr(r->extent);
if (min.same_as(r->min) && extent.same_as(r->extent)) {
return r;
} else {
return Range::make_by_min_extent(min, extent);
}
};
return MutateArray(arr, fmutate, self->allow_copy_on_write_);
}
};
Stmt StmtMutator::VisitStmt_(const AttrStmt* op) {
Expr value = this->VisitExpr(op->value);
Stmt body = this->VisitStmt(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return GetRef<Stmt>(op);
} else {
auto n = CopyOnWrite(op);
n->value = std::move(value);
n->body = std::move(body);
return Stmt(n);
}
}
Stmt StmtMutator::VisitStmt_(const LetStmt* op) {
Expr value = this->VisitExpr(op->value);
Stmt body = this->VisitStmt(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return GetRef<Stmt>(op);
} else {
auto n = CopyOnWrite(op);
n->value = std::move(value);
n->body = std::move(body);
return Stmt(n);
}
}
Stmt StmtMutator::VisitStmt_(const For* op) {
Expr min = this->VisitExpr(op->min);
Expr extent = this->VisitExpr(op->extent);
Stmt body = this->VisitStmt(op->body);
if (min.same_as(op->min) &&
extent.same_as(op->extent) &&
body.same_as(op->body)) {
return GetRef<Stmt>(op);
} else {
auto n = CopyOnWrite(op);
n->min = std::move(min);
n->extent = std::move(extent);
n->body = std::move(body);
return Stmt(n);
}
}
Stmt StmtMutator::VisitStmt_(const Allocate* op) {
Array<Expr> extents = Internal::Mutate(this, op->extents);
Stmt body = this->VisitStmt(op->body);
Expr condition = this->VisitExpr(op->condition);
Expr new_expr;
if (op->new_expr.defined()) {
new_expr = this->VisitExpr(op->new_expr);
}
if (extents.same_as(op->extents) &&
body.same_as(op->body) &&
condition.same_as(op->condition) &&
new_expr.same_as(op->new_expr)) {
return GetRef<Stmt>(op);
} else {
auto n = CopyOnWrite(op);
n->extents = std::move(extents);
n->body = std::move(body);
n->condition = std::move(condition);
n->new_expr = std::move(new_expr);
return Stmt(n);
}
}
Stmt StmtMutator::VisitStmt_(const IfThenElse* op) {
Expr condition = this->VisitExpr(op->condition);
Stmt then_case = this->VisitStmt(op->then_case);
Stmt else_case;
if (op->else_case.defined()) {
else_case = this->VisitStmt(op->else_case);
}
if (condition.same_as(op->condition) &&
then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
return GetRef<Stmt>(op);
} else {
auto n = CopyOnWrite(op);
n->condition = std::move(condition);
n->then_case = std::move(then_case);
n->else_case = std::move(then_case);
return Stmt(n);
}
}
Stmt StmtMutator::VisitStmt_(const Store* op) {
Expr value = this->VisitExpr(op->value);
Expr index = this->VisitExpr(op->index);
Expr predicate = this->VisitExpr(op->predicate);
if (value.same_as(op->value) &&
index.same_as(op->index) &&
predicate.same_as(op->predicate)) {
return GetRef<Stmt>(op);
} else {
auto n = CopyOnWrite(op);
n->value = std::move(value);
n->index = std::move(index);
n->predicate = std::move(predicate);
return Stmt(n);
}
}
Stmt StmtMutator::VisitStmt_(const Provide* op) {
Array<Expr> args = Internal::Mutate(this, op->args);
Expr value = this->VisitExpr(op->value);
if (args.same_as(op->args) &&
value.same_as(op->value)) {
return GetRef<Stmt>(op);
} else {
auto n = CopyOnWrite(op);
n->args = std::move(args);
n->value = std::move(value);
return Stmt(n);
}
}
Stmt StmtMutator::VisitStmt_(const Realize* op) {
Region bounds = Internal::Mutate(this, op->bounds);
Stmt body = this->VisitStmt(op->body);
Expr condition = this->VisitExpr(op->condition);
if (bounds.same_as(op->bounds) &&
body.same_as(op->body) &&
condition.same_as(op->condition)) {
return GetRef<Stmt>(op);
} else {
auto n = CopyOnWrite(op);
n->bounds = std::move(bounds);
n->body = std::move(body);
n->condition = std::move(condition);
return Stmt(n);
}
}
Stmt StmtMutator::VisitStmt_(const Prefetch* op) {
Region bounds = Internal::Mutate(this, op->bounds);
if (bounds.same_as(op->bounds)) {
return GetRef<Stmt>(op);
} else {
auto n = CopyOnWrite(op);
n->bounds = std::move(bounds);
return Stmt(n);
}
}
Stmt StmtMutator::VisitStmt_(const Block* op) {
Stmt first = this->VisitStmt(op->first);
Stmt rest = this->VisitStmt(op->rest);
if (first.same_as(op->first) &&
rest.same_as(op->rest)) {
return GetRef<Stmt>(op);
} else {
auto n = CopyOnWrite(op);
n->first = std::move(first);
n->rest = std::move(rest);
return Stmt(n);
}
}
Stmt StmtMutator::VisitStmt_(const AssertStmt* op) {
Expr condition = this->VisitExpr(op->condition);
Expr message = this->VisitExpr(op->message);
Stmt body = this->VisitStmt(op->body);
if (condition.same_as(op->condition) &&
message.same_as(op->message) &&
body.same_as(op->body)) {
return GetRef<Stmt>(op);
} else {
auto n = CopyOnWrite(op);
n->condition = std::move(condition);
n->message = std::move(message);
n->body = std::move(body);
return Stmt(n);
}
}
Stmt StmtMutator::VisitStmt_(const ProducerConsumer* op) {
Stmt body = this->VisitStmt(op->body);
if (body.same_as(op->body)) {
return GetRef<Stmt>(op);
} else {
auto n = CopyOnWrite(op);
n->body = std::move(body);
return Stmt(n);
}
}
Stmt StmtMutator::VisitStmt_(const Evaluate* op) {
Expr value = this->VisitExpr(op->value);
if (value.same_as(op->value)) {
return GetRef<Stmt>(op);
} else {
auto n = CopyOnWrite(op);
n->value = std::move(value);
return Stmt(n);
}
}
Stmt StmtMutator::VisitStmt_(const Free* op) {
return GetRef<Stmt>(op);
}
Expr ExprMutator::VisitExpr_(const Variable* op) {
return GetRef<Expr>(op);
}
Expr ExprMutator::VisitExpr_(const Load* op) {
Expr index = this->VisitExpr(op->index);
Expr predicate = this->VisitExpr(op->predicate);
if (index.same_as(op->index) && predicate.same_as(op->predicate)) {
return GetRef<Expr>(op);
} else {
return Load::make(op->dtype, op->buffer_var, index, predicate);
}
}
Expr ExprMutator::VisitExpr_(const Let* op) {
Expr value = this->VisitExpr(op->value);
Expr body = this->VisitExpr(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return GetRef<Expr>(op);
} else {
return Let::make(op->var, value, body);
}
}
Expr ExprMutator::VisitExpr_(const Call* op) {
auto fmutate = [this](const Expr& e) { return this->VisitExpr(e); };
Array<Expr> args = MutateArray(op->args, fmutate);
if (args.same_as(op->args)) {
return GetRef<Expr>(op);
} else {
return Call::make(op->dtype,
op->name,
args,
op->call_type,
op->func,
op->value_index);
}
}
#define DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(OP) \
Expr ExprMutator::VisitExpr_(const OP *op) { \
return GetRef<Expr>(op); \
}
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImm)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(UIntImm)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(FloatImm)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImm)
#define DEFINE_BIOP_EXPR_MUTATE_(OP) \
Expr ExprMutator::VisitExpr_(const OP* op) { \
Expr a = this->VisitExpr(op->a); \
Expr b = this->VisitExpr(op->b); \
if (a.same_as(op->a) && \
b.same_as(op->b)) { \
return GetRef<Expr>(op); \
} else { \
return OP::make(a, b); \
} \
}
DEFINE_BIOP_EXPR_MUTATE_(Add);
DEFINE_BIOP_EXPR_MUTATE_(Sub);
DEFINE_BIOP_EXPR_MUTATE_(Mul);
DEFINE_BIOP_EXPR_MUTATE_(Div);
DEFINE_BIOP_EXPR_MUTATE_(Mod);
DEFINE_BIOP_EXPR_MUTATE_(FloorDiv);
DEFINE_BIOP_EXPR_MUTATE_(FloorMod);
DEFINE_BIOP_EXPR_MUTATE_(Min);
DEFINE_BIOP_EXPR_MUTATE_(Max);
DEFINE_BIOP_EXPR_MUTATE_(EQ);
DEFINE_BIOP_EXPR_MUTATE_(NE);
DEFINE_BIOP_EXPR_MUTATE_(LT);
DEFINE_BIOP_EXPR_MUTATE_(LE);
DEFINE_BIOP_EXPR_MUTATE_(GT);
DEFINE_BIOP_EXPR_MUTATE_(GE);
DEFINE_BIOP_EXPR_MUTATE_(And);
DEFINE_BIOP_EXPR_MUTATE_(Or);
Expr ExprMutator::VisitExpr_(const Reduce* op) {
auto fitervar = [this](const IterVar& v) {
Range r = v->dom;
Expr min = this->VisitExpr(r->min);
Expr extent = this->VisitExpr(r->extent);
if (min.same_as(r->min) &&
extent.same_as(r->extent)) {
return v;
} else {
return IterVarNode::make(
Range::make_by_min_extent(min, extent),
v->var, v->iter_type, v->thread_tag);
}
};
Array<IterVar> axis = MutateArray(op->axis, fitervar);
auto fexpr = [this](const Expr& e) { return this->VisitExpr(e); };
Array<Expr> source = MutateArray(op->source, fexpr);
Expr condition = this->VisitExpr(op->condition);
if (axis.same_as(op->axis) &&
source.same_as(op->source) &&
condition.same_as(op->condition)) {
return GetRef<Expr>(op);
} else {
return Reduce::make(
op->combiner, source, axis, condition, op->value_index);
}
}
Expr ExprMutator::VisitExpr_(const Cast* op) {
Expr value = this->VisitExpr(op->value);
if (value.same_as(op->value)) {
return GetRef<Expr>(op);
} else {
return Cast::make(op->dtype, value);
}
}
Expr ExprMutator::VisitExpr_(const Not* op) {
Expr a = this->VisitExpr(op->a);
if (a.same_as(op->a)) {
return GetRef<Expr>(op);
} else {
return Not::make(a);
}
}
Expr ExprMutator::VisitExpr_(const Select* op) {
Expr condition = this->VisitExpr(op->condition);
Expr true_value = this->VisitExpr(op->true_value);
Expr false_value = this->VisitExpr(op->false_value);
if (condition.same_as(op->condition) &&
true_value.same_as(op->true_value) &&
false_value.same_as(op->false_value)) {
return GetRef<Expr>(op);
} else {
return Select::make(condition, true_value, false_value);
}
}
Expr ExprMutator::VisitExpr_(const Ramp* op) {
Expr base = this->VisitExpr(op->base);
Expr stride = this->VisitExpr(op->stride);
if (base.same_as(op->base) &&
stride.same_as(op->stride)) {
return GetRef<Expr>(op);
} else {
return Ramp::make(base, stride, op->lanes);
}
}
Expr ExprMutator::VisitExpr_(const Broadcast* op) {
Expr value = this->VisitExpr(op->value);
if (value.same_as(op->value)) {
return GetRef<Expr>(op);
} else {
return Broadcast::make(value, op->lanes);
}
}
Expr ExprMutator::VisitExpr_(const Shuffle* op) {
auto fexpr = [this](const Expr& e) { return this->VisitExpr(e); };
auto vectors = MutateArray(op->vectors, fexpr);
if (vectors.same_as(op->vectors)) {
return GetRef<Expr>(op);
} else {
return Shuffle::make(vectors, op->indices);
}
}
} // namespace ir
} // namespace tvm
......@@ -28,59 +28,6 @@
namespace tvm {
namespace ir {
class IRTransformer final : public IRMutator {
public:
IRTransformer(const runtime::PackedFunc& f_preorder,
const runtime::PackedFunc& f_postorder,
const std::unordered_set<uint32_t>& only_enable)
: f_preorder_(f_preorder),
f_postorder_(f_postorder),
only_enable_(only_enable) {
}
Stmt Mutate(Stmt stmt) final {
return MutateInternal<Stmt>(stmt);
}
Expr Mutate(Expr expr) final {
return MutateInternal<Expr>(expr);
}
private:
template <typename T>
T MutateInternal(T node) {
if (only_enable_.size() &&
!only_enable_.count(node->type_index())) {
return IRMutator::Mutate(node);
}
if (f_preorder_ != nullptr) {
T pre = f_preorder_(node);
if (pre.defined()) return pre;
}
node = IRMutator::Mutate(node);
if (f_postorder_ != nullptr) {
T post = f_postorder_(node);
if (post.defined()) return post;
}
return node;
}
// The functions
const runtime::PackedFunc& f_preorder_;
const runtime::PackedFunc& f_postorder_;
// type indices enabled.
const std::unordered_set<uint32_t>& only_enable_;
};
Stmt IRTransform(const Stmt& ir_node,
const runtime::PackedFunc& f_preorder,
const runtime::PackedFunc& f_postorder,
const Array<Expr>& only_enable) {
std::unordered_set<uint32_t> only_type_index;
for (Expr s : only_enable) {
only_type_index.insert(Object::TypeKey2Index(s.as<StringImm>()->value.c_str()));
}
return IRTransformer(f_preorder, f_postorder, only_type_index)
.Mutate(ir_node);
}
IRMutator::FMutateExpr& IRMutator::vtable_expr() { // NOLINT(*)
static FMutateExpr inst; return inst;
}
......
......@@ -26,26 +26,6 @@
namespace tvm {
namespace ir {
// visitor to implement apply
class IRApplyVisit : public IRVisitor {
public:
explicit IRApplyVisit(std::function<void(const ObjectRef&)> f) : f_(f) {}
void Visit(const ObjectRef& node) final {
if (visited_.count(node.get()) != 0) return;
visited_.insert(node.get());
IRVisitor::Visit(node);
f_(node);
}
private:
std::function<void(const ObjectRef&)> f_;
std::unordered_set<const Object*> visited_;
};
void PostOrderVisit(const ObjectRef& node, std::function<void(const ObjectRef&)> fvisit) {
IRApplyVisit(fvisit).Visit(node);
}
IRVisitor::FVisit& IRVisitor::vtable() { // NOLINT(*)
static FVisit inst; return inst;
......
......@@ -31,7 +31,6 @@ TEST(IRF, Basic) {
auto z = x + 1;
NodeFunctor<int(const ObjectRef& n, int b)> f;
LOG(INFO) << "x";
f.set_dispatch<Variable>([](const ObjectRef& n, int b) {
return b;
});
......@@ -101,6 +100,98 @@ TEST(IRF, ExprVisit) {
CHECK_EQ(v.count, 1);
}
TEST(IRF, StmtVisitor) {
using namespace tvm;
using namespace tvm::ir;
Var x("x");
class MyVisitor
: public StmtExprVisitor {
public:
int count = 0;
// implementation
void VisitExpr_(const Variable* op) final {
++count;
}
};
MyVisitor v;
auto fmaketest = [&]() {
auto z = x + 1;
Stmt body = Evaluate::make(z);
Var buffer("b", DataType::Handle());
return Allocate::make(buffer, DataType::Float(32), {z, z}, const_true(), body);
};
v(fmaketest());
CHECK_EQ(v.count, 3);
}
TEST(IRF, StmtMutator) {
using namespace tvm;
using namespace tvm::ir;
Var x("x");
class MyVisitor
: public ir::StmtMutator,
public ir::ExprMutator {
public:
using StmtMutator::operator();
using ExprMutator::operator();
protected:
// implementation
Expr VisitExpr_(const Add* op) final {
return op->a;
}
Expr VisitExpr(const Expr& expr) final {
return ExprMutator::VisitExpr(expr);
}
};
auto fmaketest = [&]() {
auto z = x + 1;
Stmt body = Evaluate::make(z);
Var buffer("b", DataType::Handle());
return Allocate::make(buffer, DataType::Float(32), {1, z}, const_true(), body);
};
MyVisitor v;
{
auto body = fmaketest();
Stmt body2 = Evaluate::make(1);
Stmt bref = body.as<Allocate>()->body;
auto* extentptr = body.as<Allocate>()->extents.get();
Array<Stmt> arr{std::move(body), body2, body2};
auto* arrptr = arr.get();
arr.MutateByApply([&](Stmt s) { return v(std::move(s)); });
CHECK(arr.get() == arrptr);
// inplace update body
CHECK(arr[0].as<Allocate>()->extents[1].same_as(x));
CHECK(arr[0].as<Allocate>()->extents.get() == extentptr);
// copy because there is additional refs
CHECK(!arr[0].as<Allocate>()->body.same_as(bref));
CHECK(arr[0].as<Allocate>()->body.as<Evaluate>()->value.same_as(x));
CHECK(bref.as<Evaluate>()->value.as<Add>());
}
{
Array<Stmt> arr{fmaketest()};
// mutate array get reference by another one, triiger copy.
Array<Stmt> arr2 = arr;
auto* arrptr = arr.get();
arr.MutateByApply([&](Stmt s) { return v(std::move(s)); });
CHECK(arr.get() != arrptr);
CHECK(arr[0].as<Allocate>()->extents[1].same_as(x));
CHECK(!arr2[0].as<Allocate>()->extents[1].same_as(x));
// mutate but no content change.
arr2 = arr;
arr.MutateByApply([&](Stmt s) { return v(std::move(s)); });
CHECK(arr2.get() == arr.get());
}
{
auto body = Evaluate::make(Call::make(DataType::Int(32), "xyz", {x + 1}, Call::Extern));
auto res = v(std::move(body));
CHECK(res.as<Evaluate>()->value.as<Call>()->args[0].same_as(x));
}
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment