Unverified Commit 59c70a0e by Tianqi Chen Committed by GitHub

[RELAY][[PASS] Consolidate ForwardRewrite pass. (#2124)

parent de3b63a4
...@@ -415,6 +415,32 @@ class TupleGetItemNode : public ExprNode { ...@@ -415,6 +415,32 @@ class TupleGetItemNode : public ExprNode {
RELAY_DEFINE_NODE_REF(TupleGetItem, TupleGetItemNode, Expr); RELAY_DEFINE_NODE_REF(TupleGetItem, TupleGetItemNode, Expr);
/*!
* \brief Base class of the temporary expression.
*
* TempExprs are pass specific expression that can be
* useful to define intermediate result in the
* rewriting pass such as layout or type transformation.
*
* Subclass TempExprNode allows us to pattern match on
* specific kind TempExpr and use them for expression rewriting.
*
* TempExpr should only be used within a pass,
*/
class TempExprNode : public ExprNode {
public:
/*!
* \brief Convert the expression to a normal(non-temp) Expr.
* \return The corresponding normal(non-temp) expression.
*/
virtual Expr Realize() const = 0;
static constexpr const char* _type_key = "relay.TempExpr";
TVM_DECLARE_BASE_NODE_INFO(TempExprNode, ExprNode);
};
RELAY_DEFINE_NODE_REF(TempExpr, TempExprNode, Expr);
// implementataions // implementataions
template<typename TTypeNode> template<typename TTypeNode>
inline const TTypeNode* ExprNode::type_as() const { inline const TTypeNode* ExprNode::type_as() const {
......
...@@ -276,6 +276,16 @@ class GenericOpMap { ...@@ -276,6 +276,16 @@ class GenericOpMap {
*/ */
template <typename ValueType> template <typename ValueType>
inline ValueType get(const Op& op, ValueType def_value) const; inline ValueType get(const Op& op, ValueType def_value) const;
/*!
* \brief get the corresponding value element at op with default value.
* \param expr The key to the map
* \param def_value The default value when the key does not exist
* or if expr is not an Op.
* \return the const reference to the content value.
* \tparam ValueType The content value type.
*/
template <typename ValueType>
inline ValueType get(const Expr& expr, ValueType def_value) const;
private: private:
friend class OpRegistry; friend class OpRegistry;
...@@ -313,6 +323,14 @@ class OpMap { ...@@ -313,6 +323,14 @@ class OpMap {
* \return the const reference to the content value. * \return the const reference to the content value.
*/ */
inline ValueType get(const Op& op, ValueType def_value) const; inline ValueType get(const Op& op, ValueType def_value) const;
/*!
* \brief get the corresponding value element at op with default value.
* \param expr The key to the map
* \param def_value The default value when the key does not exist
* or if expr is not an Op.
* \return the const reference to the content value.
*/
inline ValueType get(const Expr& expr, ValueType def_value) const;
private: private:
friend class Op; friend class Op;
...@@ -497,6 +515,21 @@ inline ValueType GenericOpMap::get(const Op& op, ValueType value) const { ...@@ -497,6 +515,21 @@ inline ValueType GenericOpMap::get(const Op& op, ValueType value) const {
} }
template <typename ValueType> template <typename ValueType>
inline ValueType GenericOpMap::get(const Expr& expr, ValueType value) const {
CHECK(expr.defined());
if (const OpNode* op = expr.as<OpNode>()) {
const uint32_t idx = op->index_;
if (idx < data_.size() && data_[idx].second != 0) {
return data_[idx].first;
} else {
return value;
}
} else {
return value;
}
}
template <typename ValueType>
inline int OpMap<ValueType>::count(const Op& op) const { inline int OpMap<ValueType>::count(const Op& op) const {
return map_.count(op); return map_.count(op);
} }
...@@ -505,12 +538,19 @@ template <typename ValueType> ...@@ -505,12 +538,19 @@ template <typename ValueType>
inline ValueType OpMap<ValueType>::operator[](const Op& op) const { inline ValueType OpMap<ValueType>::operator[](const Op& op) const {
return map_[op]; return map_[op];
} }
template <typename ValueType> template <typename ValueType>
inline ValueType OpMap<ValueType>::get(const Op& op, inline ValueType OpMap<ValueType>::get(const Op& op,
ValueType def_value) const { ValueType def_value) const {
return map_.get<ValueType>(op, def_value); return map_.get<ValueType>(op, def_value);
} }
template <typename ValueType>
inline ValueType OpMap<ValueType>::get(const Expr& expr,
ValueType def_value) const {
return map_.get<ValueType>(expr, def_value);
}
/*! /*!
* \brief Check that an expression is a "primtive operator". * \brief Check that an expression is a "primtive operator".
* *
......
...@@ -85,6 +85,25 @@ using FTVMSchedule = runtime::TypedPackedFunc< ...@@ -85,6 +85,25 @@ using FTVMSchedule = runtime::TypedPackedFunc<
Schedule(const Attrs& attrs, Schedule(const Attrs& attrs,
const Array<Tensor>& outs, const Array<Tensor>& outs,
const Target& target)>; const Target& target)>;
/*!
* \brief Forward rewriting rule for a specific op.
*
* \param ref_call The reference old call type to be rewritten.
* We can make use of the op and type information.
* \param new_args The new arguments (some of them could be TempExpr).
* \param ctx Optional context information about ref_call.
* \return The rewriten result call, can also return nullptr,
* which indicate the rewriter should use the default fallback
* rule that realizes all its input and compose the call.
*
* \note When we register the function, we can register
* a different signature with ctx to be a specific node type.
*/
using FForwardRewrite = runtime::TypedPackedFunc<
Expr(const Call& ref_call,
const Array<Expr>& new_args,
const NodeRef& ctx)>;
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_OP_ATTR_TYPES_H_ #endif // TVM_RELAY_OP_ATTR_TYPES_H_
...@@ -158,6 +158,17 @@ Expr FoldConstant(const Expr& expr); ...@@ -158,6 +158,17 @@ Expr FoldConstant(const Expr& expr);
*/ */
Expr FuseOps(const Expr& expr, int fuse_opt_level); Expr FuseOps(const Expr& expr, int fuse_opt_level);
/*!
* \brief Apply rewrite rules to rewrite the expr in post DFS order.
* \param expr The expression.
* \param rewrite_map_attr_name The Op's attr name which corresponds to the rewrite
* rule function.
* \param fcontext Additional callback to provide context argument for each call node.
* \return The rewritten expression.
*/
Expr ForwardRewrite(const Expr& expr,
const std::string& rewrite_map_attr_name,
std::function<NodeRef(const Call&)> fcontext = nullptr);
/*! \brief A hashing structure in the style of std::hash. */ /*! \brief A hashing structure in the style of std::hash. */
struct StructuralHash { struct StructuralHash {
......
...@@ -73,6 +73,8 @@ class PackedFunc { ...@@ -73,6 +73,8 @@ class PackedFunc {
using FType = std::function<void (TVMArgs args, TVMRetValue* rv)>; using FType = std::function<void (TVMArgs args, TVMRetValue* rv)>;
/*! \brief default constructor */ /*! \brief default constructor */
PackedFunc() {} PackedFunc() {}
/*! \brief constructor from null */
PackedFunc(std::nullptr_t null) {} // NOLINT(*)
/*! /*!
* \brief constructing a packed function from a std::function. * \brief constructing a packed function from a std::function.
* \param body the internal container of packed function. * \param body the internal container of packed function.
......
/*!
* Copyright (c) 2018 by Contributors
*
* \file forward_rewrite.cc
* \brief Apply rewriting rules in a forward fashion.
*/
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
namespace tvm {
namespace relay {
// Realizer class that realizes the expression
// Note that we can take benefit of its internal memo
// so that calling realize repeatively won't hurt perf.
class TempRealizer : private ExprMutator {
public:
Expr Realize(Expr expr) {
return VisitExpr(expr);
}
private:
Expr VisitExpr(const Expr& expr) final {
auto it = memo_.find(expr);
if (it != memo_.end()) {
return it->second;
} else {
Expr res;
if (const auto* temp = expr.as_derived<TempExprNode>()) {
res = temp->Realize();
} else {
res = ExprFunctor::VisitExpr(expr);
}
memo_[res] = res;
return res;
}
}
};
class ForwardRewriter : private ExprMutator {
public:
ForwardRewriter(const OpMap<FForwardRewrite>& rewrite_map,
std::function<NodeRef(const Call&)> fcontext)
: rewrite_map_(rewrite_map),
fcontext_(fcontext) {
}
// Transform expression.
Expr Rewrite(Expr expr) {
return this->VisitExpr(expr);
}
private:
// The rewrite rule.
const OpMap<FForwardRewrite>& rewrite_map_;
// The context.
std::function<NodeRef(const Call&)> fcontext_{nullptr};
// internal realizer
TempRealizer realizer_;
Expr VisitExpr(const Expr& expr) final {
// by default always realize.
return realizer_.Realize(ExprMutator::VisitExpr(expr));
}
// Visit and allow non-realized version.
Expr GetTempExpr(const Expr& expr) {
return ExprMutator::VisitExpr(expr);
}
// Automatic fold TupleGetItem.
Expr VisitExpr_(const TupleGetItemNode* op) final {
Expr tuple = this->GetTempExpr(op->tuple);
if (const auto* ptuple = tuple.as<TupleNode>()) {
return ptuple->fields[op->index];
} else {
if (tuple.same_as(op->tuple)) {
return GetRef<Expr>(op);
} else {
return TupleGetItemNode::make(tuple, op->index);
}
}
}
Expr VisitExpr_(const CallNode* call_node) final {
const Call& ref_call = GetRef<Call>(call_node);
PackedFunc frewrite = rewrite_map_.get(call_node->op, nullptr);
auto new_op = this->Mutate(call_node->op);
bool unchanged = call_node->op.same_as(new_op);
Array<Expr> call_args;
for (auto arg : call_node->args) {
Expr new_arg = this->GetTempExpr(arg);
if (frewrite == nullptr) {
new_arg = realizer_.Realize(new_arg);
}
unchanged &= new_arg.same_as(arg);
call_args.push_back(new_arg);
}
// try to rewrite.
if (frewrite != nullptr) {
Expr res = frewrite(
ref_call, call_args,
fcontext_ != nullptr ? fcontext_(ref_call) : NodeRef(nullptr));
if (res.defined()) return res;
// abort, use old rule
for (size_t i = 0; i < call_args.size(); ++i) {
Expr arg = call_args[i];
Expr new_arg = realizer_.Realize(arg);
if (!arg.same_as(new_arg)) {
call_args.Set(i, new_arg);
unchanged = false;
}
}
}
if (unchanged) return ref_call;
return CallNode::make(
new_op, call_args, call_node->attrs, call_node->type_args);
}
};
Expr ForwardRewrite(const Expr& expr,
const std::string& rewrite_map_name,
std::function<NodeRef(const Call&)> fcontext) {
auto rewrite_map = Op::GetAttr<FForwardRewrite>(rewrite_map_name);
return ForwardRewriter(rewrite_map, fcontext).Rewrite(expr);
}
} // namespace relay
} // namespace tvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment