Unverified Commit 59c70a0e by Tianqi Chen Committed by GitHub

[RELAY][[PASS] Consolidate ForwardRewrite pass. (#2124)

parent de3b63a4
......@@ -415,6 +415,32 @@ class TupleGetItemNode : public ExprNode {
RELAY_DEFINE_NODE_REF(TupleGetItem, TupleGetItemNode, Expr);
/*!
* \brief Base class of the temporary expression.
*
* TempExprs are pass specific expression that can be
* useful to define intermediate result in the
* rewriting pass such as layout or type transformation.
*
* Subclass TempExprNode allows us to pattern match on
* specific kind TempExpr and use them for expression rewriting.
*
* TempExpr should only be used within a pass,
*/
class TempExprNode : public ExprNode {
public:
/*!
* \brief Convert the expression to a normal(non-temp) Expr.
* \return The corresponding normal(non-temp) expression.
*/
virtual Expr Realize() const = 0;
static constexpr const char* _type_key = "relay.TempExpr";
TVM_DECLARE_BASE_NODE_INFO(TempExprNode, ExprNode);
};
RELAY_DEFINE_NODE_REF(TempExpr, TempExprNode, Expr);
// implementataions
template<typename TTypeNode>
inline const TTypeNode* ExprNode::type_as() const {
......
......@@ -276,6 +276,16 @@ class GenericOpMap {
*/
template <typename ValueType>
inline ValueType get(const Op& op, ValueType def_value) const;
/*!
* \brief get the corresponding value element at op with default value.
* \param expr The key to the map
* \param def_value The default value when the key does not exist
* or if expr is not an Op.
* \return the const reference to the content value.
* \tparam ValueType The content value type.
*/
template <typename ValueType>
inline ValueType get(const Expr& expr, ValueType def_value) const;
private:
friend class OpRegistry;
......@@ -313,6 +323,14 @@ class OpMap {
* \return the const reference to the content value.
*/
inline ValueType get(const Op& op, ValueType def_value) const;
/*!
* \brief get the corresponding value element at op with default value.
* \param expr The key to the map
* \param def_value The default value when the key does not exist
* or if expr is not an Op.
* \return the const reference to the content value.
*/
inline ValueType get(const Expr& expr, ValueType def_value) const;
private:
friend class Op;
......@@ -497,6 +515,21 @@ inline ValueType GenericOpMap::get(const Op& op, ValueType value) const {
}
template <typename ValueType>
inline ValueType GenericOpMap::get(const Expr& expr, ValueType value) const {
CHECK(expr.defined());
if (const OpNode* op = expr.as<OpNode>()) {
const uint32_t idx = op->index_;
if (idx < data_.size() && data_[idx].second != 0) {
return data_[idx].first;
} else {
return value;
}
} else {
return value;
}
}
template <typename ValueType>
inline int OpMap<ValueType>::count(const Op& op) const {
return map_.count(op);
}
......@@ -505,12 +538,19 @@ template <typename ValueType>
inline ValueType OpMap<ValueType>::operator[](const Op& op) const {
return map_[op];
}
template <typename ValueType>
inline ValueType OpMap<ValueType>::get(const Op& op,
ValueType def_value) const {
return map_.get<ValueType>(op, def_value);
}
template <typename ValueType>
inline ValueType OpMap<ValueType>::get(const Expr& expr,
ValueType def_value) const {
return map_.get<ValueType>(expr, def_value);
}
/*!
* \brief Check that an expression is a "primtive operator".
*
......
......@@ -85,6 +85,25 @@ using FTVMSchedule = runtime::TypedPackedFunc<
Schedule(const Attrs& attrs,
const Array<Tensor>& outs,
const Target& target)>;
/*!
* \brief Forward rewriting rule for a specific op.
*
* \param ref_call The reference old call type to be rewritten.
* We can make use of the op and type information.
* \param new_args The new arguments (some of them could be TempExpr).
* \param ctx Optional context information about ref_call.
* \return The rewriten result call, can also return nullptr,
* which indicate the rewriter should use the default fallback
* rule that realizes all its input and compose the call.
*
* \note When we register the function, we can register
* a different signature with ctx to be a specific node type.
*/
using FForwardRewrite = runtime::TypedPackedFunc<
Expr(const Call& ref_call,
const Array<Expr>& new_args,
const NodeRef& ctx)>;
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_OP_ATTR_TYPES_H_
......@@ -158,6 +158,17 @@ Expr FoldConstant(const Expr& expr);
*/
Expr FuseOps(const Expr& expr, int fuse_opt_level);
/*!
* \brief Apply rewrite rules to rewrite the expr in post DFS order.
* \param expr The expression.
* \param rewrite_map_attr_name The Op's attr name which corresponds to the rewrite
* rule function.
* \param fcontext Additional callback to provide context argument for each call node.
* \return The rewritten expression.
*/
Expr ForwardRewrite(const Expr& expr,
const std::string& rewrite_map_attr_name,
std::function<NodeRef(const Call&)> fcontext = nullptr);
/*! \brief A hashing structure in the style of std::hash. */
struct StructuralHash {
......
......@@ -73,6 +73,8 @@ class PackedFunc {
using FType = std::function<void (TVMArgs args, TVMRetValue* rv)>;
/*! \brief default constructor */
PackedFunc() {}
/*! \brief constructor from null */
PackedFunc(std::nullptr_t null) {} // NOLINT(*)
/*!
* \brief constructing a packed function from a std::function.
* \param body the internal container of packed function.
......
......@@ -88,23 +88,6 @@ AxesSet Intersect(const AxesSet& lhs, const AxesSet& rhs) {
}
/*!
* \param Get function from op_map.
* \param op_map The OpMap.
* \param op The operator being called.
* \tparam ValueType the content value type.
* \return The result value map.
*/
template<typename ValueType>
ValueType GetFunc(const OpMap<ValueType>& op_map,
const Expr& op) {
if (const OpNode* opnode = op.as<OpNode>()) {
return op_map.get(GetRef<Op>(opnode), ValueType());
} else {
return ValueType();
}
}
/*!
* \brief Preparation function for pass scale forward.
* \param call The call node.
* \param out_scale_axes Possible scaling on axes of the output.
......@@ -114,7 +97,7 @@ using FForwardPrep = runtime::TypedPackedFunc<
Array<AxesSet> (const Call& call, const AxesSet& out_scale_axes)>;
/*! \brief Axis scale tuple. */
class STupleNode : public Node {
class ScaledExprNode : public TempExprNode {
public:
/*! \brief The value */
Expr value;
......@@ -123,29 +106,26 @@ class STupleNode : public Node {
/*! \brief The scaling factor */
Expr scale = NullValue<Expr>();
Expr Realize() const final {
CHECK(!axes.defined())
<< "outstanding scale";
return value;
}
void VisitAttrs(AttrVisitor* v) final {
v->Visit("value", &value);
v->Visit("axes", &axes);
v->Visit("scale", &scale);
}
static constexpr const char* _type_key = "relay.fold_scale_axis.STupleNode";
TVM_DECLARE_NODE_TYPE_INFO(STupleNode, Node);
static constexpr const char* _type_key = "relay.fold_scale_axis.ScaledExpr";
TVM_DECLARE_NODE_TYPE_INFO(ScaledExprNode, TempExprNode);
};
RELAY_DEFINE_NODE_REF(STuple, STupleNode, NodeRef);
/*!
* \brief The transform function, transform an old call to
* a new one given the new args.
* \param ref_call Reference call node that represent the op and the types.
* \param expected_out_axes The scale axes allowed in the output.
* \param sargs The input arguments.
*/
using FForwardTransform = TypedPackedFunc<
STuple(const Call& ref_call,
const AxesSet& expected_out_axes,
const Array<STuple>& sargs)>;
using FForwardRewrite = TypedPackedFunc<
Expr(const Call& ref_call,
const Array<Expr>& new_args,
const AxesSet& expeced_out_axes)>;
//----------------------------------------------
// Generic Visitors for FScaleAxisForward
......@@ -219,7 +199,7 @@ class ForwardPrep : private ExprVisitor {
out_axes = NullValue<AxesSet>();
}
// pass the message back to all the children it references.
auto f = GetFunc(fprep, call->op);
auto f = fprep.get(call->op, nullptr);
if (f != nullptr) {
Array<AxesSet> in_axes = f(GetRef<Call>(call), out_axes);
CHECK_EQ(in_axes.size(), call->args.size());
......@@ -261,87 +241,6 @@ class ForwardPrep : private ExprVisitor {
}
};
class ForwardTransformer : private ExprMutator {
public:
// Transform expression.
Expr Fold(Expr expr) {
expected_scale_axes_ =
ForwardPrep().Prepare(expr);
return this->Mutate(expr);
}
private:
// Valid axes on each node.
std::unordered_map<const Node*, AxesSet> expected_scale_axes_;
std::unordered_map<const Node*, STuple> scale_memo_;
// If user simply call mutate,
// then only Expr is returned and we cannot
// accept outstanding scales.
Expr VisitExpr(const Expr& expr) final {
Expr res = ExprMutator::VisitExpr(expr);
CHECK(!scale_memo_.count(expr.get()))
<< "Outstanding scale";
return res;
}
STuple GetSTuple(const Expr& expr) {
Expr res = ExprMutator::VisitExpr(expr);
auto it = scale_memo_.find(expr.get());
if (it != scale_memo_.end()) {
CHECK(it->second->value.same_as(res));
return it->second;
} else {
auto node = make_node<STupleNode>();
node->value = res;
return STuple(node);
}
}
Expr VisitExpr_(const CallNode* call_node) final {
static const auto& ftransform =
Op::GetAttr<FForwardTransform>("FScaleAxisForwardTransform");
auto new_op = this->Mutate(call_node->op);
bool has_scale = false;
bool unchanged = call_node->op.same_as(new_op);
Array<STuple> call_sargs;
Array<Expr> call_args;
for (auto arg : call_node->args) {
STuple new_sarg = this->GetSTuple(arg);
unchanged &= new_sarg->value.same_as(arg);
if (new_sarg->axes.defined()) has_scale = true;
call_sargs.push_back(new_sarg);
call_args.push_back(new_sarg->value);
}
// get expected scale axes.
AxesSet expected_out_axes;
auto axis_it = expected_scale_axes_.find(call_node);
if (axis_it != expected_scale_axes_.end()) {
expected_out_axes = axis_it->second;
}
// propagation function
auto f = GetFunc(ftransform, call_node->op);
if (f != nullptr) {
STuple sret = f(GetRef<Call>(call_node), expected_out_axes, call_sargs);
if (sret.defined()) {
if (sret->axes.defined()) {
scale_memo_[call_node] = sret;
}
return sret->value;
}
}
// normal path
CHECK(!has_scale) << "Outstanding scale, on op=" << call_node->op;
if (unchanged) {
return GetRef<Expr>(call_node);
} else {
return CallNode::make(
new_op, call_args, call_node->attrs, call_node->type_args);
}
}
};
//----------------------------------------------
// Per operator defs for FScaleAxisForward
//----------------------------------------------
......@@ -351,30 +250,31 @@ Array<AxesSet> ReluForwardPrep(const Call& call, AxesSet out) {
return {out};
}
STuple ReluForwardTransform(const Call& ref_call,
const AxesSet& expected_axes,
const Array<STuple>& sargs) {
if (!sargs[0]->axes.defined()) return STuple();
Expr ReluForwardRewrite(const Call& ref_call,
const Array<Expr>& new_args,
const AxesSet& expected_axes) {
const auto* input = new_args[0].as<ScaledExprNode>();
if (input == nullptr) return Expr(nullptr);
// return transformed conv2d
auto rnode = make_node<STupleNode>();
auto rnode = make_node<ScaledExprNode>();
rnode->value = CallNode::make(
ref_call->op, {sargs[0]->value}, ref_call->attrs, ref_call->type_args);
rnode->scale = sargs[0]->scale;
rnode->axes = sargs[0]->axes;
return STuple(rnode);
ref_call->op, {input->value}, ref_call->attrs, ref_call->type_args);
rnode->scale = input->scale;
rnode->axes = input->axes;
return Expr(rnode);
}
RELAY_REGISTER_OP("nn.relu")
.set_attr<FForwardPrep>("FScaleAxisForwardPrep", ReluForwardPrep);
RELAY_REGISTER_OP("nn.relu")
.set_attr<FForwardTransform>("FScaleAxisForwardTransform", ReluForwardTransform);
.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", ReluForwardRewrite);
RELAY_REGISTER_OP("nn.leaky_relu")
.set_attr<FForwardPrep>("FScaleAxisForwardPrep", ReluForwardPrep);
RELAY_REGISTER_OP("nn.leaky_relu")
.set_attr<FForwardTransform>("FScaleAxisForwardTransform", ReluForwardTransform);
.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", ReluForwardRewrite);
// AddSub
Array<AxesSet> AddSubForwardPrep(const Call& call, AxesSet out_axes) {
......@@ -391,69 +291,69 @@ Array<AxesSet> AddSubForwardPrep(const Call& call, AxesSet out_axes) {
}
}
STuple AddSubForwardTransform(const Call& ref_call,
const AxesSet& expected_out_axes,
const Array<STuple>& sargs) {
if (!sargs[0]->axes.defined() && !sargs[1]->axes.defined()) {
return STuple();
}
Expr AddSubForwardRewrite(const Call& ref_call,
const Array<Expr>& new_args,
const AxesSet& expected_out_axes) {
const auto* slhs = new_args[0].as<ScaledExprNode>();
const auto* srhs = new_args[1].as<ScaledExprNode>();
if (!slhs && !srhs) return Expr();
const auto* tlhs = ref_call->args[0]->type_as<TensorTypeNode>();
const auto* trhs = ref_call->args[1]->type_as<TensorTypeNode>();
auto rnode = make_node<ScaledExprNode>();
auto rnode = make_node<STupleNode>();
if (sargs[0]->axes.defined()) {
CHECK(!sargs[1]->axes.defined());
CHECK(MatchBroadcastToLeftAxes(tlhs, trhs, sargs[0]->axes));
if (slhs != nullptr) {
CHECK(srhs == nullptr);
CHECK(MatchBroadcastToLeftAxes(tlhs, trhs, slhs->axes));
Expr scale = ExpandBiasToMatchAxis(
sargs[0]->scale, tlhs->shape.size(), sargs[0]->axes);
Expr rhs = Divide(sargs[1]->value, scale);
rnode->value = CallNode::make(ref_call->op, {sargs[0]->value, rhs},
slhs->scale, tlhs->shape.size(), slhs->axes);
Expr rhs = Divide(new_args[1], scale);
rnode->value = CallNode::make(ref_call->op, {slhs->value, rhs},
ref_call->attrs, ref_call->type_args);
rnode->scale = sargs[0]->scale;
rnode->axes = sargs[0]->axes;
rnode->scale = slhs->scale;
rnode->axes = slhs->axes;
} else {
CHECK(sargs[1]->axes.defined());
CHECK(sargs[0]->axes.defined());
CHECK(MatchBroadcastToLeftAxes(trhs, tlhs, sargs[1]->axes));
CHECK(slhs != nullptr);
CHECK(MatchBroadcastToLeftAxes(trhs, tlhs, srhs->axes));
Expr scale = ExpandBiasToMatchAxis(
sargs[1]->scale, trhs->shape.size(), sargs[1]->axes);
Expr lhs = Divide(sargs[0]->value, scale);
rnode->value = CallNode::make(ref_call->op, {lhs, sargs[1]->value},
srhs->scale, trhs->shape.size(), srhs->axes);
Expr lhs = Divide(new_args[0], scale);
rnode->value = CallNode::make(ref_call->op, {lhs, srhs->value},
ref_call->attrs, ref_call->type_args);
rnode->scale = sargs[1]->scale;
rnode->axes = sargs[1]->axes;
rnode->scale = srhs->scale;
rnode->axes = srhs->axes;
}
return STuple(rnode);
return Expr(rnode);
}
RELAY_REGISTER_OP("add")
.set_attr<FForwardPrep>("FScaleAxisForwardPrep", AddSubForwardPrep);
RELAY_REGISTER_OP("add")
.set_attr<FForwardTransform>("FScaleAxisForwardTransform", AddSubForwardTransform);
.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", AddSubForwardRewrite);
RELAY_REGISTER_OP("subtract")
.set_attr<FForwardPrep>("FScaleAxisForwardPrep", AddSubForwardPrep);
RELAY_REGISTER_OP("subtract")
.set_attr<FForwardTransform>("FScaleAxisForwardTransform", AddSubForwardTransform);
.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", AddSubForwardRewrite);
// Producer operators
// Multiply produces the scale-axis pair.
STuple MultiplyForwardTransform(const Call& ref_call,
const AxesSet& expected_out_axes,
const Array<STuple>& sargs) {
if (!expected_out_axes.defined()) return STuple();
Expr MultiplyForwardRewrite(const Call& ref_call,
const Array<Expr>& new_args,
const AxesSet& expected_out_axes) {
if (!expected_out_axes.defined()) return Expr();
// TODO(tvm-team) allow same axes accumulation
// not as important because it is less common in nn.
CHECK(!sargs[0]->axes.defined());
CHECK(!sargs[1]->axes.defined());
const auto* slhs = new_args[0].as<ScaledExprNode>();
const auto* srhs = new_args[1].as<ScaledExprNode>();
CHECK(!slhs && !srhs);
const auto* tlhs = ref_call->args[0]->type_as<TensorTypeNode>();
const auto* trhs = ref_call->args[1]->type_as<TensorTypeNode>();
Expr lhs = sargs[0]->value;
Expr rhs = sargs[1]->value;
auto rnode = make_node<STupleNode>();
Expr lhs = new_args[0];
Expr rhs = new_args[1];
auto rnode = make_node<ScaledExprNode>();
if (MatchBroadcastToLeftAxes(tlhs, trhs, expected_out_axes, &rhs)) {
rnode->value = lhs;
rnode->scale = rhs;
......@@ -463,11 +363,11 @@ STuple MultiplyForwardTransform(const Call& ref_call,
rnode->scale = lhs;
rnode->axes = expected_out_axes;
}
return STuple(rnode);
return Expr(rnode);
}
RELAY_REGISTER_OP("multiply")
.set_attr<FForwardTransform>("FScaleAxisForwardTransform", MultiplyForwardTransform);
.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", MultiplyForwardRewrite);
// Consumer operators
// Conv2D send out requirement of axis folding.
......@@ -500,13 +400,14 @@ Array<AxesSet> Conv2DForwardPrep(const Call& call, AxesSet out) {
}
// Conv2D consumes the scale axis during transformation.
STuple Conv2DForwardTransform(const Call& ref_call,
const AxesSet& expected_axes,
const Array<STuple>& sargs) {
Expr Conv2DForwardRewrite(const Call& ref_call,
const Array<Expr>& new_args,
const AxesSet& expected_axes) {
// if data do not have scale, normal transform path.
STuple sdata = sargs[0];
if (!sdata->scale.defined()) return STuple();
CHECK(sdata->axes.defined());
const auto* sdata = new_args[0].as<ScaledExprNode>();
const auto* sweight = new_args[1].as<ScaledExprNode>();
if (sdata == nullptr) return Expr();
if (sweight != nullptr) return Expr();
const auto* param = ref_call->attrs.as<Conv2DAttrs>();
CHECK(param != nullptr);
Layout data_layout(param->data_layout);
......@@ -524,7 +425,8 @@ STuple Conv2DForwardTransform(const Call& ref_call,
// Check it must be depthwise or full conv2d.
bool is_depthwise_conv2d = IsDepthwiseConv2D(ref_call, param, weight_layout);
CHECK(param->groups == 1 || is_depthwise_conv2d);
Expr weight = sargs[1]->value;
Expr weight = new_args[1];
// match the ic_axis
if (is_depthwise_conv2d) {
......@@ -537,21 +439,30 @@ STuple Conv2DForwardTransform(const Call& ref_call,
weight = Multiply(weight, scale);
}
// return transformed conv2d
auto rnode = make_node<STupleNode>();
rnode->value = CallNode::make(
return CallNode::make(
ref_call->op, {sdata->value, weight}, ref_call->attrs, ref_call->type_args);
return STuple(rnode);
}
RELAY_REGISTER_OP("nn.conv2d")
.set_attr<FForwardPrep>("FScaleAxisForwardPrep", Conv2DForwardPrep);
RELAY_REGISTER_OP("nn.conv2d")
.set_attr<FForwardTransform>("FScaleAxisForwardTransform", Conv2DForwardTransform);
.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", Conv2DForwardRewrite);
Expr ForwardFoldScaleAxis(Expr data) {
return ForwardTransformer().Fold(data);
auto expected_scale_axes =
ForwardPrep().Prepare(data);
auto fcontext = [&](const Call& call) -> NodeRef{
auto it = expected_scale_axes.find(call.get());
if (it != expected_scale_axes.end()) {
return it->second;
} else {
return NodeRef(nullptr);
}
};
return ForwardRewrite(
data, "FScaleAxisForwardRewrite", fcontext);
}
// Expose the FoldScaleAxisFoward
......@@ -602,7 +513,7 @@ class BackwardPrep : private ExprVisitor {
ExprVisitor::VisitExpr_(call);
static const auto& fprep =
Op::GetAttr<FBackwardPrep>("FScaleAxisBackwardPrep");
auto f = GetFunc(fprep, call->op);
auto f = fprep.get(call->op, nullptr);
if (f == nullptr) return;
auto rit = ref_counter_.find(call);
CHECK(rit != ref_counter_.end());
......@@ -705,7 +616,7 @@ Expr BackwardTransformerNode::Transform(
const CallNode* call_node, AxesSet axes, Expr scale) {
static const auto& ftransform =
Op::GetAttr<FBackwardTransform>("FScaleAxisBackwardTransform");
auto f = GetFunc(ftransform, call_node->op);
auto f = ftransform.get(call_node->op, nullptr);
if (f != nullptr) {
return f(GetRef<Call>(call_node),
axes,
......
/*!
* Copyright (c) 2018 by Contributors
*
* \file forward_rewrite.cc
* \brief Apply rewriting rules in a forward fashion.
*/
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
namespace tvm {
namespace relay {
// Realizer class that realizes the expression
// Note that we can take benefit of its internal memo
// so that calling realize repeatively won't hurt perf.
class TempRealizer : private ExprMutator {
public:
Expr Realize(Expr expr) {
return VisitExpr(expr);
}
private:
Expr VisitExpr(const Expr& expr) final {
auto it = memo_.find(expr);
if (it != memo_.end()) {
return it->second;
} else {
Expr res;
if (const auto* temp = expr.as_derived<TempExprNode>()) {
res = temp->Realize();
} else {
res = ExprFunctor::VisitExpr(expr);
}
memo_[res] = res;
return res;
}
}
};
class ForwardRewriter : private ExprMutator {
public:
ForwardRewriter(const OpMap<FForwardRewrite>& rewrite_map,
std::function<NodeRef(const Call&)> fcontext)
: rewrite_map_(rewrite_map),
fcontext_(fcontext) {
}
// Transform expression.
Expr Rewrite(Expr expr) {
return this->VisitExpr(expr);
}
private:
// The rewrite rule.
const OpMap<FForwardRewrite>& rewrite_map_;
// The context.
std::function<NodeRef(const Call&)> fcontext_{nullptr};
// internal realizer
TempRealizer realizer_;
Expr VisitExpr(const Expr& expr) final {
// by default always realize.
return realizer_.Realize(ExprMutator::VisitExpr(expr));
}
// Visit and allow non-realized version.
Expr GetTempExpr(const Expr& expr) {
return ExprMutator::VisitExpr(expr);
}
// Automatic fold TupleGetItem.
Expr VisitExpr_(const TupleGetItemNode* op) final {
Expr tuple = this->GetTempExpr(op->tuple);
if (const auto* ptuple = tuple.as<TupleNode>()) {
return ptuple->fields[op->index];
} else {
if (tuple.same_as(op->tuple)) {
return GetRef<Expr>(op);
} else {
return TupleGetItemNode::make(tuple, op->index);
}
}
}
Expr VisitExpr_(const CallNode* call_node) final {
const Call& ref_call = GetRef<Call>(call_node);
PackedFunc frewrite = rewrite_map_.get(call_node->op, nullptr);
auto new_op = this->Mutate(call_node->op);
bool unchanged = call_node->op.same_as(new_op);
Array<Expr> call_args;
for (auto arg : call_node->args) {
Expr new_arg = this->GetTempExpr(arg);
if (frewrite == nullptr) {
new_arg = realizer_.Realize(new_arg);
}
unchanged &= new_arg.same_as(arg);
call_args.push_back(new_arg);
}
// try to rewrite.
if (frewrite != nullptr) {
Expr res = frewrite(
ref_call, call_args,
fcontext_ != nullptr ? fcontext_(ref_call) : NodeRef(nullptr));
if (res.defined()) return res;
// abort, use old rule
for (size_t i = 0; i < call_args.size(); ++i) {
Expr arg = call_args[i];
Expr new_arg = realizer_.Realize(arg);
if (!arg.same_as(new_arg)) {
call_args.Set(i, new_arg);
unchanged = false;
}
}
}
if (unchanged) return ref_call;
return CallNode::make(
new_op, call_args, call_node->attrs, call_node->type_args);
}
};
Expr ForwardRewrite(const Expr& expr,
const std::string& rewrite_map_name,
std::function<NodeRef(const Call&)> fcontext) {
auto rewrite_map = Op::GetAttr<FForwardRewrite>(rewrite_map_name);
return ForwardRewriter(rewrite_map, fcontext).Rewrite(expr);
}
} // namespace relay
} // namespace tvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment