Unverified Commit a3530f8f by Tianqi Chen Committed by GitHub

[RELAY] Add multiref trigger to ForwardRewrite (#2168)

parent 0a1f3d41
......@@ -164,11 +164,14 @@ Expr FuseOps(const Expr& expr, int fuse_opt_level);
* \param rewrite_map_attr_name The Op's attr name which corresponds to the rewrite
* rule function.
* \param fcontext Additional callback to provide context argument for each call node.
* \param fmulti_ref_trigger Transformation function to be called when
* an Expr consumed by multiple callers.
* \return The rewritten expression.
*/
Expr ForwardRewrite(const Expr& expr,
const std::string& rewrite_map_attr_name,
std::function<NodeRef(const Call&)> fcontext = nullptr);
std::function<NodeRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
/*! \brief A hashing structure in the style of std::hash. */
struct StructuralHash {
......
......@@ -320,6 +320,23 @@ class TupleGetItem(Expr):
_make.TupleGetItem, tuple_value, index)
class TempExpr(Expr):
"""Baseclass of all TempExpr.
TempExprs are pass specific expression that can be
useful to define intermediate result in the
rewriting pass such as layout or type transformation.
"""
def realize(self):
"""Convert the expression to a normal(non-temp) Expr.
Returns
-------
The corresponding normal expression.
"""
return _expr.TempExprRealize(self)
class ExprFunctor(object):
"""
An abstract visitor defined over Expr.
......
......@@ -258,5 +258,13 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")";
});
TVM_REGISTER_API("relay._expr.TempExprRealize")
.set_body([](TVMArgs args, TVMRetValue* ret) {
TempExpr temp = args[0];
*ret = temp->Realize();
});
} // namespace relay
} // namespace tvm
......@@ -7,6 +7,7 @@
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include "pass_util.h"
namespace tvm {
namespace relay {
......@@ -42,13 +43,18 @@ class TempRealizer : private ExprMutator {
class ForwardRewriter : private ExprMutator {
public:
ForwardRewriter(const OpMap<FForwardRewrite>& rewrite_map,
std::function<NodeRef(const Call&)> fcontext)
std::function<NodeRef(const Call&)> fcontext,
std::function<Expr(const Expr&)> fmulti_ref_trigger)
: rewrite_map_(rewrite_map),
fcontext_(fcontext) {
fcontext_(fcontext),
fmulti_ref_trigger_(fmulti_ref_trigger) {
}
// Transform expression.
Expr Rewrite(Expr expr) {
if (fmulti_ref_trigger_ != nullptr) {
ref_counter_ = GetExprRefCount(expr);
}
return this->VisitExpr(expr);
}
......@@ -57,6 +63,10 @@ class ForwardRewriter : private ExprMutator {
const OpMap<FForwardRewrite>& rewrite_map_;
// The context.
std::function<NodeRef(const Call&)> fcontext_{nullptr};
// The multiple reference trigger
std::function<Expr(const Expr&)> fmulti_ref_trigger_{nullptr};
// Internal ref counter
std::unordered_map<const Node*, size_t> ref_counter_;
// internal realizer
TempRealizer realizer_;
......@@ -67,8 +77,18 @@ class ForwardRewriter : private ExprMutator {
// Visit and allow non-realized version.
Expr GetTempExpr(const Expr& expr) {
if (fmulti_ref_trigger_ != nullptr) {
Expr ret = ExprMutator::VisitExpr(expr);
auto it = ref_counter_.find(expr.get());
CHECK(it != ref_counter_.end());
if (it->second > 1) {
ret = fmulti_ref_trigger_(ret);
}
return ret;
} else {
return ExprMutator::VisitExpr(expr);
}
}
// Automatic fold TupleGetItem.
Expr VisitExpr_(const TupleGetItemNode* op) final {
......@@ -124,9 +144,12 @@ class ForwardRewriter : private ExprMutator {
Expr ForwardRewrite(const Expr& expr,
const std::string& rewrite_map_name,
std::function<NodeRef(const Call&)> fcontext) {
std::function<NodeRef(const Call&)> fcontext,
std::function<Expr(const Expr&)> fmulti_ref_trigger) {
auto rewrite_map = Op::GetAttr<FForwardRewrite>(rewrite_map_name);
return ForwardRewriter(rewrite_map, fcontext).Rewrite(expr);
return ForwardRewriter(rewrite_map,
fcontext,
fmulti_ref_trigger).Rewrite(expr);
}
} // namespace relay
} // namespace tvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment