Commit 0a1f3d41 by ziheng Committed by Tianqi Chen

[PASS] PostOrderVisit (#2169)

parent b5e0d790
...@@ -151,6 +151,19 @@ struct SliceLikeAttrs : public tvm::AttrsNode<SliceLikeAttrs> { ...@@ -151,6 +151,19 @@ struct SliceLikeAttrs : public tvm::AttrsNode<SliceLikeAttrs> {
} }
}; };
// Clip
struct ClipAttrs : public tvm::AttrsNode<ClipAttrs> {
double a_min;
double a_max;
TVM_DECLARE_ATTRS(ClipAttrs, "relay.attrs.ClipAttrs") {
TVM_ATTR_FIELD(a_min)
.describe("The minimum clip value.");
TVM_ATTR_FIELD(a_max)
.describe("The maximum clip value.");
}
};
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_ #endif // TVM_RELAY_ATTRS_TRANSFORM_H_
...@@ -182,6 +182,14 @@ class ExprMutator ...@@ -182,6 +182,14 @@ class ExprMutator
std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo_; std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo_;
}; };
/*!
* \brief recursively visit the ir in post DFS order node, apply fvisit
* Each node is guaranteed to be visited only once.
* \param node The ir to be visited.
* \param fvisit The visitor function to be applied.
*/
void PostOrderVisit(const NodeRef& node, std::function<void(const NodeRef&)> fvisit);
/* /*
* \brief Bind function parameters or free variables. * \brief Bind function parameters or free variables.
* *
......
...@@ -10,6 +10,19 @@ from . import _make ...@@ -10,6 +10,19 @@ from . import _make
from .expr import Expr from .expr import Expr
from .ty import Type from .ty import Type
def post_order_visit(expr, fvisit):
"""Recursively visit the ir in post DFS order node,
apply fvisit. Each node is guaranteed to be visited
only once.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
fvisit : function
The visitor function to be applied.
"""
return _ir_pass.post_order_visit(expr, fvisit)
def infer_type(expr, mod=None): def infer_type(expr, mod=None):
"""Infer the type of expr under the context of mod. """Infer the type of expr under the context of mod.
......
...@@ -228,6 +228,36 @@ void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) { ...@@ -228,6 +228,36 @@ void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) {
void ExprVisitor::VisitType(const Type& t) { return; } void ExprVisitor::VisitType(const Type& t) { return; }
// visitor to implement apply
class ExprApplyVisit : public ExprVisitor {
public:
explicit ExprApplyVisit(std::function<void(const Expr&)> f) : f_(f) {}
void VisitExpr(const Expr& e) final {
if (visited_.count(e.get()) != 0) return;
visited_.insert(e.get());
ExprVisitor::VisitExpr(e);
f_(e);
}
private:
std::function<void(const Expr&)> f_;
std::unordered_set<const Node*> visited_;
};
void PostOrderVisit(const Expr& e, std::function<void(const Expr&)> fvisit) {
ExprApplyVisit(fvisit).VisitExpr(e);
}
TVM_REGISTER_API("relay._ir_pass.post_order_visit")
.set_body([](TVMArgs args, TVMRetValue *ret) {
PackedFunc f = args[1];
PostOrderVisit(args[0], [f](const Expr& n) {
f(n);
});
});
// Implement bind. // Implement bind.
class ExprBinder : public ExprMutator { class ExprBinder : public ExprMutator {
public: public:
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
*/ */
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/attrs/transform.h>
#include <topi/elemwise.h> #include <topi/elemwise.h>
#include "../type_relations.h" #include "../type_relations.h"
#include "../op_common.h" #include "../op_common.h"
...@@ -89,19 +90,8 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "copy") ...@@ -89,19 +90,8 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "copy")
.add_type_rel("Identity", IdentityRel) .add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::identity)); .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::identity));
// relay.clip
// Clip TVM_REGISTER_NODE_TYPE(ClipAttrs);
struct ClipAttrs : public tvm::AttrsNode<ClipAttrs> {
double a_min;
double a_max;
TVM_DECLARE_ATTRS(ClipAttrs, "relay.attrs.ClipAttrs") {
TVM_ATTR_FIELD(a_min)
.describe("The minimum clip value.");
TVM_ATTR_FIELD(a_max)
.describe("The maximum clip value.");
}
};
TVM_REGISTER_API("relay.op._make.clip") TVM_REGISTER_API("relay.op._make.clip")
.set_body_typed<Expr(Expr, double, double)>([](Expr a, double a_min, double a_max) { .set_body_typed<Expr(Expr, double, double)>([](Expr a, double a_min, double a_max) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment