Unverified Commit 983eba88 by Tianqi Chen Committed by GitHub

[IR] Unify approach to Visitor/Mutator under Functor (#4606)

IRMutator and IRVisitor were the main data structures for doing low level IR visiting.
As the project evolves, we start to introduce more powerful variants such as StmtFunctor and ExprFunctor.
This PR brings new classes that allows us to migrate the visitor mutator to be sub-class of these functors.

List of changes:

- Create separate class for ExprMutator and StmtMutator, following convention used in relay.
- Introduce copy-on-write to StmtMutator that can later benefit the statement mutations
  if we use move semantics and keep a single copy of stmt.
- Move two generic visit mutate util to use the new classes.

We will send followup PRs to migrate the existing passes that use the legacy visitors
to the new one.
parent 1ef1605a
...@@ -287,6 +287,265 @@ class StmtFunctor<R(const Stmt& n, Args... args)> { ...@@ -287,6 +287,265 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
#undef EXPR_FUNCTOR_DEFAULT #undef EXPR_FUNCTOR_DEFAULT
#undef STMT_FUNCTOR_DEFAULT #undef STMT_FUNCTOR_DEFAULT
/*!
* \brief ExprVisitor
*/
class TVM_DLL ExprVisitor :
public ExprFunctor<void(const Expr&)> {
public:
using ExprFunctor::operator();
protected:
using ExprFunctor::VisitExpr;
// list of functions to override.
void VisitExpr_(const Variable* op) override;
void VisitExpr_(const Load* op) override;
void VisitExpr_(const Let* op) override;
void VisitExpr_(const Call* op) override;
void VisitExpr_(const Add* op) override;
void VisitExpr_(const Sub* op) override;
void VisitExpr_(const Mul* op) override;
void VisitExpr_(const Div* op) override;
void VisitExpr_(const Mod* op) override;
void VisitExpr_(const FloorDiv* op) override;
void VisitExpr_(const FloorMod* op) override;
void VisitExpr_(const Min* op) override;
void VisitExpr_(const Max* op) override;
void VisitExpr_(const EQ* op) override;
void VisitExpr_(const NE* op) override;
void VisitExpr_(const LT* op) override;
void VisitExpr_(const LE* op) override;
void VisitExpr_(const GT* op) override;
void VisitExpr_(const GE* op) override;
void VisitExpr_(const And* op) override;
void VisitExpr_(const Or* op) override;
void VisitExpr_(const Reduce* op) override;
void VisitExpr_(const Cast* op) override;
void VisitExpr_(const Not* op) override;
void VisitExpr_(const Select* op) override;
void VisitExpr_(const Ramp* op) override;
void VisitExpr_(const Broadcast* op) override;
void VisitExpr_(const Shuffle* op) override;
void VisitExpr_(const IntImm* op) override;
void VisitExpr_(const UIntImm* op) override;
void VisitExpr_(const FloatImm* op) override;
void VisitExpr_(const StringImm* op) override;
};
/*!
* \brief ExprMutator that mutates expressions.
*/
class TVM_DLL ExprMutator :
protected ExprFunctor<Expr(const Expr&)> {
public:
using ExprFunctor::operator();
protected:
using ExprFunctor::VisitExpr;
// list of functions to override.
Expr VisitExpr_(const Variable* op) override;
Expr VisitExpr_(const Load* op) override;
Expr VisitExpr_(const Let* op) override;
Expr VisitExpr_(const Call* op) override;
Expr VisitExpr_(const Add* op) override;
Expr VisitExpr_(const Sub* op) override;
Expr VisitExpr_(const Mul* op) override;
Expr VisitExpr_(const Div* op) override;
Expr VisitExpr_(const Mod* op) override;
Expr VisitExpr_(const FloorDiv* op) override;
Expr VisitExpr_(const FloorMod* op) override;
Expr VisitExpr_(const Min* op) override;
Expr VisitExpr_(const Max* op) override;
Expr VisitExpr_(const EQ* op) override;
Expr VisitExpr_(const NE* op) override;
Expr VisitExpr_(const LT* op) override;
Expr VisitExpr_(const LE* op) override;
Expr VisitExpr_(const GT* op) override;
Expr VisitExpr_(const GE* op) override;
Expr VisitExpr_(const And* op) override;
Expr VisitExpr_(const Or* op) override;
Expr VisitExpr_(const Reduce* op) override;
Expr VisitExpr_(const Cast* op) override;
Expr VisitExpr_(const Not* op) override;
Expr VisitExpr_(const Select* op) override;
Expr VisitExpr_(const Ramp* op) override;
Expr VisitExpr_(const Broadcast* op) override;
Expr VisitExpr_(const Shuffle* op) override;
Expr VisitExpr_(const IntImm* op) override;
Expr VisitExpr_(const UIntImm* op) override;
Expr VisitExpr_(const FloatImm* op) override;
Expr VisitExpr_(const StringImm* op) override;
};
/*!
* \brief StmtVisitor.
*/
class TVM_DLL StmtVisitor :
protected StmtFunctor<void(const Stmt&)> {
public:
using StmtFunctor::operator();
protected:
using StmtFunctor::VisitStmt;
/*!
* \brief Visitor to Exprs, can be overriden
* to do recursive changes to Exprs.
* \note A common pattern is to call ExprVisitor here,
* or have a class sub-class both StmtVisitor and ExprVisitor
* and redirect Visit to ExprMutator::VisitExpr(Expr)
*/
virtual void VisitExpr(const Expr& e) {}
// statement visitor
void VisitStmt_(const AttrStmt* op) override;
void VisitStmt_(const IfThenElse* op) override;
void VisitStmt_(const LetStmt* op) override;
void VisitStmt_(const For* op) override;
void VisitStmt_(const Allocate* op) override;
void VisitStmt_(const Store* op) override;
void VisitStmt_(const Free* op) override;
void VisitStmt_(const AssertStmt* op) override;
void VisitStmt_(const ProducerConsumer* op) override;
void VisitStmt_(const Provide* op) override;
void VisitStmt_(const Realize* op) override;
void VisitStmt_(const Prefetch* op) override;
void VisitStmt_(const Block* op) override;
void VisitStmt_(const Evaluate* op) override;
};
/*!
* \brief StmtMutator that mutates the statements.
*/
class TVM_DLL StmtMutator :
protected StmtFunctor<Stmt(const Stmt&)> {
public:
/*!
* \brief Mutate stmt.
* \param stmt The input statement to be mutated.
* \return The result of the call
* \note It is important that stmt is passed by value.
* so copy on write can be triggered correctly.
* do mutator(std::move(stmt)) or when copy elison is triggered.
*/
Stmt operator()(Stmt stmt) {
allow_copy_on_write_ = true;
return VisitStmt(stmt);
}
protected:
// We perform copy on write optimizations on the StmtMutator
// so that an unique copy of parent can be mutated inplace
// when some of its children changed.
// We only do such optimization for Stmt nests(instead of Exprs) for now
// as Stmt's parent state is more likely remain unchanged when one of
// its child block changes.
/*!
* \brief Internal state to indicate whether copy on write is enabled.
* COW is enabled iff all the parents of the node are unique.
*/
bool allow_copy_on_write_{false};
/*!
* \brief Perform copy on write on node.
*
* If CopyOnWrite is allowed, directly return
* a strong reference to the node container.
* Otherwise, return a copy of the node.
*
* \return The result object pointer.
*/
template<typename TNode>
ObjectPtr<TNode> CopyOnWrite(const TNode* node) {
if (allow_copy_on_write_) {
// return the old node.
return runtime::GetObjectPtr<TNode>(const_cast<TNode*>(node));
} else {
// Make a new copy of the node.
// need to rely on the default copy constructor
return runtime::make_object<TNode>(*node);
}
}
/*!
* \brief Internal mutator that everyone calls.
* \note To override mutate's behavior, override VisitExpr instead.
* \param stmt The input stmt.
* \return The mutated results.
*/
Stmt VisitStmt(const Stmt& stmt) override {
if (allow_copy_on_write_ && !stmt.unique()) {
allow_copy_on_write_ = false;
Stmt ret = StmtFunctor::VisitStmt(stmt);
allow_copy_on_write_ = true;
return ret;
} else {
return StmtFunctor::VisitStmt(stmt);
}
}
/*!
* \brief Visitor to Exprs, can be overriden
* to do recursive changes to Exprs.
* \note A common pattern is to call ExprMutator here,
* or have a class sub-class both StmtMutator and ExprMutator
* and redirect Mutate to ExprMutator::Mutate(Expr)
*/
virtual Expr VisitExpr(const Expr& e) {
return e;
}
// statement visitor
Stmt VisitStmt_(const AttrStmt* op) override;
Stmt VisitStmt_(const IfThenElse* op) override;
Stmt VisitStmt_(const LetStmt* op) override;
Stmt VisitStmt_(const For* op) override;
Stmt VisitStmt_(const Allocate* op) override;
Stmt VisitStmt_(const Store* op) override;
Stmt VisitStmt_(const Free* op) override;
Stmt VisitStmt_(const AssertStmt* op) override;
Stmt VisitStmt_(const ProducerConsumer* op) override;
Stmt VisitStmt_(const Provide* op) override;
Stmt VisitStmt_(const Realize* op) override;
Stmt VisitStmt_(const Prefetch* op) override;
Stmt VisitStmt_(const Block* op) override;
Stmt VisitStmt_(const Evaluate* op) override;
// internal helper.
class Internal;
};
/*!
* \brief Visitor that recursively visit stmts and exprs on them.
*/
class StmtExprVisitor :
public StmtVisitor,
public ExprVisitor {
public:
using StmtVisitor::operator();
using ExprVisitor::operator();
protected:
using StmtVisitor::VisitStmt;
using ExprVisitor::VisitExpr;
void VisitExpr(const Expr& e) override {
return ExprVisitor::VisitExpr(e);
}
};
/*!
* \brief Mutator that recursively mutates stmts and exprs on them.
*/
class StmtExprMutator :
public StmtMutator,
public ExprMutator {
public:
using StmtMutator::operator();
using ExprMutator::operator();
protected:
using StmtMutator::VisitExpr;
using ExprMutator::VisitExpr;
Expr VisitExpr(const Expr& e) override {
return ExprMutator::VisitExpr(e);
}
};
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
#endif // TVM_IR_FUNCTOR_EXT_H_ #endif // TVM_IR_FUNCTOR_EXT_H_
...@@ -123,6 +123,7 @@ class TVM_DLL IRMutator { ...@@ -123,6 +123,7 @@ class TVM_DLL IRMutator {
virtual Expr Mutate_(const Shuffle* op, const Expr& e); virtual Expr Mutate_(const Shuffle* op, const Expr& e);
}; };
/*! /*!
* \brief recursively visit the ir in post DFS order node, and transform it * \brief recursively visit the ir in post DFS order node, and transform it
* *
...@@ -138,7 +139,7 @@ class TVM_DLL IRMutator { ...@@ -138,7 +139,7 @@ class TVM_DLL IRMutator {
* If it is not empty, preorder/postorder will only be called * If it is not empty, preorder/postorder will only be called
* when the IRNode's type key is in the list. * when the IRNode's type key is in the list.
*/ */
Stmt IRTransform(const Stmt& node, Stmt IRTransform(Stmt node,
const runtime::PackedFunc& preorder, const runtime::PackedFunc& preorder,
const runtime::PackedFunc& postorder, const runtime::PackedFunc& postorder,
const Array<Expr>& only_enable = {}); const Array<Expr>& only_enable = {});
......
...@@ -284,6 +284,48 @@ class Array : public ObjectRef { ...@@ -284,6 +284,48 @@ class Array : public ObjectRef {
inline bool empty() const { inline bool empty() const {
return size() == 0; return size() == 0;
} }
/*!
* \brief Helper function to apply fmutate to mutate an array.
* \param fmutate The transformation function T -> T.
* \tparam F the type of the mutation function.
* \note This function performs copy on write optimization.
*/
template<typename F>
inline void MutateByApply(F fmutate) {
ArrayNode* ptr = static_cast<ArrayNode*>(data_.get());
if (ptr == nullptr) return;
if (data_.unique()) {
// Copy on write optimization.
// Perform inplace update because this is an unique copy.
for (size_t i = 0; i < ptr->data.size(); ++i) {
// It is important to use move here
// to make prevent the element's ref count from increasing
// so fmutate itself can perform copy-on-write optimization
T old_elem = DowncastNoCheck<T>(std::move(ptr->data[i]));
T new_elem = fmutate(std::move(old_elem));
ptr->data[i] = std::move(new_elem);
}
} else {
// lazily trigger copy if there is element change.
ObjectPtr<ArrayNode> copy;
for (size_t i = 0; i < ptr->data.size(); ++i) {
T old_elem = DowncastNoCheck<T>(ptr->data[i]);
T new_elem = fmutate(old_elem);
if (!new_elem.same_as(ptr->data[i])) {
// copy the old array
if (copy == nullptr) {
copy = runtime::make_object<ArrayNode>(*ptr);
}
copy->data[i] = std::move(new_elem);
}
}
// replace the data with the new copy.
if (copy != nullptr) {
data_ = std::move(copy);
}
}
}
/*! \brief specify container node */ /*! \brief specify container node */
using ContainerType = ArrayNode; using ContainerType = ArrayNode;
......
...@@ -28,59 +28,6 @@ ...@@ -28,59 +28,6 @@
namespace tvm { namespace tvm {
namespace ir { namespace ir {
class IRTransformer final : public IRMutator {
public:
IRTransformer(const runtime::PackedFunc& f_preorder,
const runtime::PackedFunc& f_postorder,
const std::unordered_set<uint32_t>& only_enable)
: f_preorder_(f_preorder),
f_postorder_(f_postorder),
only_enable_(only_enable) {
}
Stmt Mutate(Stmt stmt) final {
return MutateInternal<Stmt>(stmt);
}
Expr Mutate(Expr expr) final {
return MutateInternal<Expr>(expr);
}
private:
template <typename T>
T MutateInternal(T node) {
if (only_enable_.size() &&
!only_enable_.count(node->type_index())) {
return IRMutator::Mutate(node);
}
if (f_preorder_ != nullptr) {
T pre = f_preorder_(node);
if (pre.defined()) return pre;
}
node = IRMutator::Mutate(node);
if (f_postorder_ != nullptr) {
T post = f_postorder_(node);
if (post.defined()) return post;
}
return node;
}
// The functions
const runtime::PackedFunc& f_preorder_;
const runtime::PackedFunc& f_postorder_;
// type indices enabled.
const std::unordered_set<uint32_t>& only_enable_;
};
Stmt IRTransform(const Stmt& ir_node,
const runtime::PackedFunc& f_preorder,
const runtime::PackedFunc& f_postorder,
const Array<Expr>& only_enable) {
std::unordered_set<uint32_t> only_type_index;
for (Expr s : only_enable) {
only_type_index.insert(Object::TypeKey2Index(s.as<StringImm>()->value.c_str()));
}
return IRTransformer(f_preorder, f_postorder, only_type_index)
.Mutate(ir_node);
}
IRMutator::FMutateExpr& IRMutator::vtable_expr() { // NOLINT(*) IRMutator::FMutateExpr& IRMutator::vtable_expr() { // NOLINT(*)
static FMutateExpr inst; return inst; static FMutateExpr inst; return inst;
} }
......
...@@ -26,26 +26,6 @@ ...@@ -26,26 +26,6 @@
namespace tvm { namespace tvm {
namespace ir { namespace ir {
// visitor to implement apply
class IRApplyVisit : public IRVisitor {
public:
explicit IRApplyVisit(std::function<void(const ObjectRef&)> f) : f_(f) {}
void Visit(const ObjectRef& node) final {
if (visited_.count(node.get()) != 0) return;
visited_.insert(node.get());
IRVisitor::Visit(node);
f_(node);
}
private:
std::function<void(const ObjectRef&)> f_;
std::unordered_set<const Object*> visited_;
};
void PostOrderVisit(const ObjectRef& node, std::function<void(const ObjectRef&)> fvisit) {
IRApplyVisit(fvisit).Visit(node);
}
IRVisitor::FVisit& IRVisitor::vtable() { // NOLINT(*) IRVisitor::FVisit& IRVisitor::vtable() { // NOLINT(*)
static FVisit inst; return inst; static FVisit inst; return inst;
......
...@@ -31,7 +31,6 @@ TEST(IRF, Basic) { ...@@ -31,7 +31,6 @@ TEST(IRF, Basic) {
auto z = x + 1; auto z = x + 1;
NodeFunctor<int(const ObjectRef& n, int b)> f; NodeFunctor<int(const ObjectRef& n, int b)> f;
LOG(INFO) << "x";
f.set_dispatch<Variable>([](const ObjectRef& n, int b) { f.set_dispatch<Variable>([](const ObjectRef& n, int b) {
return b; return b;
}); });
...@@ -101,6 +100,98 @@ TEST(IRF, ExprVisit) { ...@@ -101,6 +100,98 @@ TEST(IRF, ExprVisit) {
CHECK_EQ(v.count, 1); CHECK_EQ(v.count, 1);
} }
TEST(IRF, StmtVisitor) {
using namespace tvm;
using namespace tvm::ir;
Var x("x");
class MyVisitor
: public StmtExprVisitor {
public:
int count = 0;
// implementation
void VisitExpr_(const Variable* op) final {
++count;
}
};
MyVisitor v;
auto fmaketest = [&]() {
auto z = x + 1;
Stmt body = Evaluate::make(z);
Var buffer("b", DataType::Handle());
return Allocate::make(buffer, DataType::Float(32), {z, z}, const_true(), body);
};
v(fmaketest());
CHECK_EQ(v.count, 3);
}
TEST(IRF, StmtMutator) {
using namespace tvm;
using namespace tvm::ir;
Var x("x");
class MyVisitor
: public ir::StmtMutator,
public ir::ExprMutator {
public:
using StmtMutator::operator();
using ExprMutator::operator();
protected:
// implementation
Expr VisitExpr_(const Add* op) final {
return op->a;
}
Expr VisitExpr(const Expr& expr) final {
return ExprMutator::VisitExpr(expr);
}
};
auto fmaketest = [&]() {
auto z = x + 1;
Stmt body = Evaluate::make(z);
Var buffer("b", DataType::Handle());
return Allocate::make(buffer, DataType::Float(32), {1, z}, const_true(), body);
};
MyVisitor v;
{
auto body = fmaketest();
Stmt body2 = Evaluate::make(1);
Stmt bref = body.as<Allocate>()->body;
auto* extentptr = body.as<Allocate>()->extents.get();
Array<Stmt> arr{std::move(body), body2, body2};
auto* arrptr = arr.get();
arr.MutateByApply([&](Stmt s) { return v(std::move(s)); });
CHECK(arr.get() == arrptr);
// inplace update body
CHECK(arr[0].as<Allocate>()->extents[1].same_as(x));
CHECK(arr[0].as<Allocate>()->extents.get() == extentptr);
// copy because there is additional refs
CHECK(!arr[0].as<Allocate>()->body.same_as(bref));
CHECK(arr[0].as<Allocate>()->body.as<Evaluate>()->value.same_as(x));
CHECK(bref.as<Evaluate>()->value.as<Add>());
}
{
Array<Stmt> arr{fmaketest()};
// mutate array get reference by another one, triiger copy.
Array<Stmt> arr2 = arr;
auto* arrptr = arr.get();
arr.MutateByApply([&](Stmt s) { return v(std::move(s)); });
CHECK(arr.get() != arrptr);
CHECK(arr[0].as<Allocate>()->extents[1].same_as(x));
CHECK(!arr2[0].as<Allocate>()->extents[1].same_as(x));
// mutate but no content change.
arr2 = arr;
arr.MutateByApply([&](Stmt s) { return v(std::move(s)); });
CHECK(arr2.get() == arr.get());
}
{
auto body = Evaluate::make(Call::make(DataType::Int(32), "xyz", {x + 1}, Call::Extern));
auto res = v(std::move(body));
CHECK(res.as<Evaluate>()->value.as<Call>()->args[0].same_as(x));
}
}
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe"; testing::FLAGS_gtest_death_test_style = "threadsafe";
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment