Commit 05ea6018 by Tianqi Chen Committed by ziheng

[RELAY][PASS] FoldScaleAxis Forward (#2020)

* [RELAY][PASS] FoldScaleAxis Forward

* Introduce helper function type_as

* Update per review comment

* Fix according to comments
parent 26e3aa19
...@@ -94,15 +94,16 @@ struct InitOpAttrs : public tvm::AttrsNode<InitOpAttrs> { ...@@ -94,15 +94,16 @@ struct InitOpAttrs : public tvm::AttrsNode<InitOpAttrs> {
/*! \brief Attributes used in squeeze operators */ /*! \brief Attributes used in squeeze operators */
struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> { struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
Array<IndexExpr> axes; // use axis to make the name numpy compatible.
Array<Integer> axis;
TVM_DECLARE_ATTRS(SqueezeAttrs, "relay.attrs.SqueezeAttrs") { TVM_DECLARE_ATTRS(SqueezeAttrs, "relay.attrs.SqueezeAttrs") {
TVM_ATTR_FIELD(axes) TVM_ATTR_FIELD(axis)
.describe("The axes to squeeze in the input tensor." .describe("The axis to squeeze in the input tensor."
"If `axes = []`, all axis of dimension 1 get squeezed;" "If `axis = None`, all axis of dimension 1 get squeezed;"
"Else, the dimension in axes get squeezed." "Else, the dimension in axes get squeezed."
"It is an error if an axes does not has dimension 1.") "It is an error if an axis does not has dimension 1.")
.set_default(Array<IndexExpr>({})); .set_default(NullValue<Array<Integer> >());
} }
}; // struct SqueezeAttrs }; // struct SqueezeAttrs
......
...@@ -40,6 +40,18 @@ class ExprNode : public RelayNode { ...@@ -40,6 +40,18 @@ class ExprNode : public RelayNode {
"field for this node"; "field for this node";
return this->checked_type_; return this->checked_type_;
} }
/*!
* \brief Check if the inferred(checked) type of the Expr
* is backed by a TTypeNode and return it.
*
* \note This function will thrown an error if the node type
* of this Expr is not TTypeNode.
*
* \return The corresponding TTypeNode pointer.
* \tparam The specific TypeNode we look for.
*/
template<typename TTypeNode>
inline const TTypeNode* type_as() const;
static constexpr const char* _type_key = "relay.Expr"; static constexpr const char* _type_key = "relay.Expr";
TVM_DECLARE_BASE_NODE_INFO(ExprNode, RelayNode); TVM_DECLARE_BASE_NODE_INFO(ExprNode, RelayNode);
...@@ -391,6 +403,20 @@ class TupleGetItemNode : public ExprNode { ...@@ -391,6 +403,20 @@ class TupleGetItemNode : public ExprNode {
RELAY_DEFINE_NODE_REF(TupleGetItem, TupleGetItemNode, Expr); RELAY_DEFINE_NODE_REF(TupleGetItem, TupleGetItemNode, Expr);
// implementataions
template<typename TTypeNode>
inline const TTypeNode* ExprNode::type_as() const {
static_assert(std::is_base_of<TypeNode, TTypeNode>::value,
"TType must be a special case of type");
CHECK(checked_type_.defined())
<< "Type inference for this Expr has not completed";
const TTypeNode* node = checked_type_.as<TTypeNode>();
CHECK(node != nullptr)
<< "Expected type to be " << TTypeNode::_type_key
<< ", but get " << checked_type_->type_key();
return node;
}
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_EXPR_H_ #endif // TVM_RELAY_EXPR_H_
...@@ -150,7 +150,14 @@ class ExprVisitor ...@@ -150,7 +150,14 @@ class ExprVisitor
class ExprMutator class ExprMutator
: public ::tvm::relay::ExprFunctor<Expr(const Expr&)> { : public ::tvm::relay::ExprFunctor<Expr(const Expr&)> {
public: public:
Expr Mutate(const Expr& expr); /*!
* \brief Mutate is alias for VisitExpr
* \return expr.
*/
Expr Mutate(const Expr& expr) {
return this->VisitExpr(expr);
}
Expr VisitExpr(const Expr& expr) override;
Expr VisitExpr_(const VarNode* op) override; Expr VisitExpr_(const VarNode* op) override;
Expr VisitExpr_(const ConstantNode* op) override; Expr VisitExpr_(const ConstantNode* op) override;
Expr VisitExpr_(const GlobalVarNode* op) override; Expr VisitExpr_(const GlobalVarNode* op) override;
...@@ -161,7 +168,8 @@ class ExprMutator ...@@ -161,7 +168,8 @@ class ExprMutator
Expr VisitExpr_(const LetNode* op) override; Expr VisitExpr_(const LetNode* op) override;
Expr VisitExpr_(const IfNode* op) override; Expr VisitExpr_(const IfNode* op) override;
Expr VisitExpr_(const TupleGetItemNode* op) override; Expr VisitExpr_(const TupleGetItemNode* op) override;
/*! \brief Used to visit the types inside of expressions. /*!
* \brief Used to visit the types inside of expressions.
* *
* Can be overloaded to transform the types in arbitrary * Can be overloaded to transform the types in arbitrary
* ways, one way would be to define a sub-class of type * ways, one way would be to define a sub-class of type
...@@ -169,7 +177,7 @@ class ExprMutator ...@@ -169,7 +177,7 @@ class ExprMutator
*/ */
virtual Type VisitType(const Type& t); virtual Type VisitType(const Type& t);
private: protected:
/*! \brief Internal map used for memoization. */ /*! \brief Internal map used for memoization. */
std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo_; std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo_;
}; };
......
...@@ -74,6 +74,17 @@ class OpNode : public relay::ExprNode { ...@@ -74,6 +74,17 @@ class OpNode : public relay::ExprNode {
v->Visit("support_level", &support_level); v->Visit("support_level", &support_level);
} }
/*!
* \brief Check that if current op is a "primtive operator".
* That is the arguments are all type variables, and there is a single
* type relation applied to the input and output types.
*/
bool IsPrimitiveOp() const {
if (is_primitive_ != -1) return is_primitive_ != 0;
is_primitive_ = this->IsPrimitiveOp_() ? 1 : 0;
return is_primitive_ != 0;
}
static constexpr const char* _type_key = "relay.Op"; static constexpr const char* _type_key = "relay.Op";
TVM_DECLARE_NODE_TYPE_INFO(OpNode, ExprNode); TVM_DECLARE_NODE_TYPE_INFO(OpNode, ExprNode);
...@@ -81,9 +92,24 @@ class OpNode : public relay::ExprNode { ...@@ -81,9 +92,24 @@ class OpNode : public relay::ExprNode {
// friend class // friend class
friend class GenericOpMap; friend class GenericOpMap;
friend class OpRegistry; friend class OpRegistry;
friend bool IsPrimitiveOp(const Expr&);
// Program internal unique index of operator. // Program internal unique index of operator.
// Used to help index the program. // Used to help index the program.
uint32_t index_{0}; uint32_t index_{0};
// whether this is a primitive op. -1 means unknown.
mutable int is_primitive_{-1};
// Internal function to compute if it is primitive op
bool IsPrimitiveOp_() const {
const auto& fn_ty = this->op_type;
if (fn_ty->type_constraints.size() != 1) return false;
const TypeRelationNode* rel = fn_ty->type_constraints[0].as<TypeRelationNode>();
if (rel == nullptr) return false;
// validate if the type parameter matches up
for (size_t i = 0; i < fn_ty->type_params.size(); ++i) {
if (!fn_ty->type_params[i].same_as(rel->args[i])) return false;
}
return true;
}
}; };
/*! /*!
...@@ -497,22 +523,7 @@ inline ValueType OpMap<ValueType>::get(const Op& op, ...@@ -497,22 +523,7 @@ inline ValueType OpMap<ValueType>::get(const Op& op,
*/ */
inline bool IsPrimitiveOp(const Expr& expr) { inline bool IsPrimitiveOp(const Expr& expr) {
const auto* op = expr.as<OpNode>(); const auto* op = expr.as<OpNode>();
return op != nullptr && op->IsPrimitiveOp();
if (!op) {
return false;
}
const auto& fn_ty = op->op_type;
if (fn_ty->type_constraints.size() != 1) return false;
const TypeRelationNode* rel = fn_ty->type_constraints[0].as<TypeRelationNode>();
if (rel == nullptr) return false;
// validate if the type parameter matches up
for (size_t i = 0; i < fn_ty->type_params.size(); ++i) {
if (!fn_ty->type_params[i].same_as(rel->args[i])) return false;
}
return true;
} }
} // namespace relay } // namespace relay
......
...@@ -10,6 +10,7 @@ from . import _make ...@@ -10,6 +10,7 @@ from . import _make
from .expr import Expr from .expr import Expr
from .ty import Type from .ty import Type
def infer_type(expr, env=None): def infer_type(expr, env=None):
"""Infer the type of expr under the context of env. """Infer the type of expr under the context of env.
...@@ -30,6 +31,23 @@ def infer_type(expr, env=None): ...@@ -30,6 +31,23 @@ def infer_type(expr, env=None):
return _ir_pass.infer_type(expr, env) return _ir_pass.infer_type(expr, env)
def forward_fold_scale_axis(expr):
"""Fold the scaling of axis into weights of conv2d/dense.
Parameters
----------
expr : tvm.relay.Expr
The input expression, we expect that expr's types
should be fully inferred by infer_type.
Returns
-------
folded_expr : tvm.relay.Expr
The folded expression after transformation.
"""
return _ir_pass.forward_fold_scale_axis(expr)
def well_formed(expr): def well_formed(expr):
"""Check that each Var is only bound once (well formed). """Check that each Var is only bound once (well formed).
...@@ -149,6 +167,7 @@ def alpha_equal(lhs, rhs): ...@@ -149,6 +167,7 @@ def alpha_equal(lhs, rhs):
""" """
return bool(_make._alpha_equal(lhs, rhs)) return bool(_make._alpha_equal(lhs, rhs))
def graph_equal(lhs, rhs): def graph_equal(lhs, rhs):
"""Compare two Relay expr for data-flow equivalence. """Compare two Relay expr for data-flow equivalence.
The difference between this and alpha-equality is that The difference between this and alpha-equality is that
...@@ -170,6 +189,7 @@ def graph_equal(lhs, rhs): ...@@ -170,6 +189,7 @@ def graph_equal(lhs, rhs):
""" """
return bool(_make._graph_equal(lhs, rhs)) return bool(_make._graph_equal(lhs, rhs))
def structural_hash(value): def structural_hash(value):
"""Hash a Relay expression structurally. """Hash a Relay expression structurally.
......
...@@ -49,27 +49,25 @@ def transpose(data, axes=None): ...@@ -49,27 +49,25 @@ def transpose(data, axes=None):
return _make.transpose(data, list(axes)) return _make.transpose(data, list(axes))
def squeeze(data, axes=None): def squeeze(data, axis=None):
"""Squeeze axes in the array. """Squeeze axes in the array.
Parameters Parameters
---------- ----------
data : relay.Expr data : tvm.relay.Expr
The input data to the operator. The input data to the operator.
axes : None or List[int] axis : None or List[int]
Axes to remove. The set of axes to remove.
If axes = [] or = None, remove all axis of dimensions 1. If axis = None, remove all axis of dimensions 1.
Otherwise, remove all axis in axes. If any specified axis has dimension that does not equal 1, it is an error.
If any axis in axes has dimension that does not equal 1, it is an error.
Returns Returns
------- -------
result : relay.Expr result : tvm.relay.Expr
The squeezed result. The squeezed result.
""" """
axes = axes or [] return _make.squeeze(data, axis)
return _make.squeeze(data, list(axes))
def reshape(data, newshape): def reshape(data, newshape):
......
...@@ -296,13 +296,23 @@ class AlphaEqualHandler: ...@@ -296,13 +296,23 @@ class AlphaEqualHandler:
if (const CallNode* rhs = other.as<CallNode>()) { if (const CallNode* rhs = other.as<CallNode>()) {
if (!ExprEqual(lhs->op, rhs->op)) return false; if (!ExprEqual(lhs->op, rhs->op)) return false;
if (lhs->args.size() != rhs->args.size()) return false; if (lhs->args.size() != rhs->args.size()) return false;
if (lhs->type_args.size() != rhs->type_args.size()) return false; // skip type_args check for primitive ops.
bool is_primitive = IsPrimitiveOp(lhs->op);
if (!is_primitive) {
if (lhs->type_args.size() != rhs->type_args.size()) {
return false;
}
}
for (size_t i = 0; i < lhs->args.size(); ++i) { for (size_t i = 0; i < lhs->args.size(); ++i) {
if (!ExprEqual(lhs->args[i], rhs->args[i])) return false; if (!ExprEqual(lhs->args[i], rhs->args[i])) {
return false;
}
} }
for (size_t i = 0; i < lhs->type_args.size(); ++i) {
if (!TypeEqual(lhs->type_args[i], rhs->type_args[i])) return false; if (!is_primitive) {
for (size_t i = 0; i < lhs->type_args.size(); ++i) {
if (!TypeEqual(lhs->type_args[i], rhs->type_args[i])) return false;
}
} }
return AttrEqual(lhs->attrs, rhs->attrs); return AttrEqual(lhs->attrs, rhs->attrs);
} else { } else {
......
...@@ -12,12 +12,12 @@ ...@@ -12,12 +12,12 @@
namespace tvm { namespace tvm {
namespace relay { namespace relay {
Expr ExprMutator::Mutate(const Expr& expr) { Expr ExprMutator::VisitExpr(const Expr& expr) {
auto it = this->memo_.find(expr); auto it = this->memo_.find(expr);
if (it != this->memo_.end()) { if (it != this->memo_.end()) {
return it->second; return it->second;
} else { } else {
Expr new_expr = ExprMutator::VisitExpr(expr); Expr new_expr = ExprFunctor::VisitExpr(expr);
memo_[expr] = new_expr; memo_[expr] = new_expr;
return new_expr; return new_expr;
} }
......
...@@ -761,9 +761,9 @@ Examples:: ...@@ -761,9 +761,9 @@ Examples::
TVM_REGISTER_NODE_TYPE(SqueezeAttrs); TVM_REGISTER_NODE_TYPE(SqueezeAttrs);
Expr MakeSqueeze(Expr data, Expr MakeSqueeze(Expr data,
Array<IndexExpr> axes) { Array<Integer> axis) {
auto attrs = make_node<SqueezeAttrs>(); auto attrs = make_node<SqueezeAttrs>();
attrs->axes = std::move(axes); attrs->axis = std::move(axis);
static const Op& op = Op::Get("squeeze"); static const Op& op = Op::Get("squeeze");
return CallNode::make(op, {data}, Attrs(attrs), {}); return CallNode::make(op, {data}, Attrs(attrs), {});
} }
...@@ -785,8 +785,8 @@ bool SqueezeRel(const Array<Type>& types, ...@@ -785,8 +785,8 @@ bool SqueezeRel(const Array<Type>& types,
const auto* param = attrs.as<SqueezeAttrs>(); const auto* param = attrs.as<SqueezeAttrs>();
CHECK(param != nullptr); CHECK(param != nullptr);
std::vector<IndexExpr> result_shape; std::vector<IndexExpr> result_shape;
// if axes is empty, squeeze all axes of dimension 1 // if axes is None, squeeze all axes of dimension 1
if (param->axes.size() == 0) { if (!param->axis.defined()) {
for (const auto& e : data->shape) { for (const auto& e : data->shape) {
const int64_t* axis_ptr = as_const_int(e); const int64_t* axis_ptr = as_const_int(e);
CHECK(axis_ptr != nullptr) << "the axes attribute must be concrete"; CHECK(axis_ptr != nullptr) << "the axes attribute must be concrete";
...@@ -800,10 +800,8 @@ bool SqueezeRel(const Array<Type>& types, ...@@ -800,10 +800,8 @@ bool SqueezeRel(const Array<Type>& types,
for (const auto& e : data->shape) { for (const auto& e : data->shape) {
original_shape.push_back(std::pair<IndexExpr, bool>(e, true)); original_shape.push_back(std::pair<IndexExpr, bool>(e, true));
} }
for (const auto& e : param->axes) { for (const auto& e : param->axis) {
const int64_t* axis_ptr = as_const_int(e); original_shape.at(e->value).second = false;
CHECK(axis_ptr != nullptr);
original_shape.at(*axis_ptr).second = false;
} }
for (const auto p : original_shape) { for (const auto p : original_shape) {
if (p.second) { if (p.second) {
......
/*!
* Copyright (c) 2018 by Contributors
*
* \file fold_scale_axis.cc
*
* \brief Fold axis scaling into weights of
* conv/dense operators.
*/
#include <tvm/relay/pass.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/expr_functor.h>
#include "pattern_util.h"
#include "../op/nn/layout.h"
namespace tvm {
namespace relay {
/*!
* \brief namespace of fold scale axis
*
* Use namespace to reduce potential naming conflict.
*/
namespace fold_scale_axis {
using runtime::TypedPackedFunc;
// FoldScaleAxisFoward algorithm:
//
// The general idea is that we transform Expr to tuple of
// (value, axes, scale), where the final result satiesfies:
//
// result = value
// for i, k in enumerate(axes):
// k-ith dimension of result *= i-th dimension of scale
//
// Then we can propagate this signal along and fold the scale if necessary.
// However, it is possible that certain scale may never be consumed
// if there is no dense/conv2d that follows multiplication.
//
// In order to make sure all the scale we sent out can be consumed eventually,
// we run a backward "preparation phase", which propagates the demand
// of the potential axes scaling back to its input.
//
// The folding process is done in two steps:
// - Prepare phase: backward propagation of demand.
// - Transform phase: forward transformation,
/*!
* \brief sorted array axis, can also be nullptr.
*
* nullptr means no scaling request can be done.
*/
using AxesSet = Array<Integer>;
/*!
* \brief Merge two axis set together by taking
* intersection.
*
* \note The axes in a AxesSet should be sorted.
*
* \param lhs The left axis.
* \param rhs The right axis.
* \return The result of the inersection.
*/
AxesSet Intersect(const AxesSet& lhs, const AxesSet& rhs) {
if (!lhs.defined()) return lhs;
if (!rhs.defined()) return rhs;
// This code relies on axes in a AxesSet to be sorted.
AxesSet ret;
size_t i = 0, j = 0;
while (i < lhs.size() && j < rhs.size()) {
if (lhs[i]->value < rhs[j]->value) {
++i;
} else if (lhs[i]->value > rhs[j]->value) {
++j;
} else {
ret.push_back(lhs[i]);
++i; ++j;
}
}
return ret;
}
/*!
* \param Get function from op_map.
* \param op_map The OpMap.
* \param op The operator being called.
* \tparam ValueType the content value type.
* \return The result value map.
*/
template<typename ValueType>
ValueType GetFunc(const OpMap<ValueType>& op_map,
const Expr& op) {
if (const OpNode* opnode = op.as<OpNode>()) {
return op_map.get(GetRef<Op>(opnode), ValueType());
} else {
return ValueType();
}
}
/*!
* \brief Preparation function for for pass scale forward.
* \param call The call node.
* \param out_scale_axes Possible scaling on axes of the output.
* \return The result scaling on axes of the input.
*/
using FForwardPrep = runtime::TypedPackedFunc<
Array<AxesSet> (const Call& call, const AxesSet& out_scale_axes)>;
/*! \brief Axis scale tuple. */
class STupleNode : public Node {
public:
/*! \brief The value */
Expr value;
/*! \brief The axes to scale, can be nullptr(means no-scaling) */
AxesSet axes = NullValue<AxesSet>();
/*! \brief The scaling factor */
Expr scale = NullValue<Expr>();
void VisitAttrs(AttrVisitor* v) final {
v->Visit("value", &value);
v->Visit("axes", &axes);
v->Visit("scale", &scale);
}
static constexpr const char* _type_key = "relay.fold_scale_axis.STupleNode";
TVM_DECLARE_NODE_TYPE_INFO(STupleNode, Node);
};
RELAY_DEFINE_NODE_REF(STuple, STupleNode, NodeRef);
/*!
* \brief The transform function, transform an old call to
* a new one given the new args.
* \param ref_call Reference call node that represent the op and the types.
* \param expected_out_axes The scale axes allowed in the output.
* \param sargs The input arguments.
*/
using FForwardTransform = TypedPackedFunc<
STuple(const Call& ref_call,
const AxesSet& expected_out_axes,
const Array<STuple>& sargs)>;
//----------------------------------------------
// Generic Visitors for FScaleAxisForward
//----------------------------------------------
class FScaleAxisForwardPrep : private ExprVisitor {
public:
std::unordered_map<const Node*, AxesSet>
Prepare(const Expr& body) {
this->Update(body, NullValue<AxesSet>());
this->VisitExpr(body);
// flist is added in the Post-DFS order
// which is a special case of topological order.
// We reversely traverse the list to invoke the lazy functions.
// This act like a backprop of valid scale axis messages
for (auto it = flist_.rbegin(); it != flist_.rend(); ++it) {
(*it)();
}
// return the created message;
return std::move(message_);
}
private:
// The invoke list
std::vector<std::function<void()> > flist_;
// The message on each node.
std::unordered_map<const Node*, AxesSet> message_;
// Update the message stored at node.
void Update(const Expr& node, const AxesSet& axes) {
// We run intersection of messages:
//
// %y = multiply(%x, %scale)
// %z1 = conv2d(%y, %w)
// %z2 = exp(%y)
//
// Consider the above code example,
// because %z2 will propagate null to %y,
// the AxesSet on %y is also null,
// and the forward folding won't be triggered.
const Node* key = node.get();
if (message_.count(key)) {
message_[key] = Intersect(message_[key], axes);
} else {
message_[key] = axes;
}
}
// Visitor pattern override.
void VisitExpr_(const LetNode* call) {
LOG(FATAL) << "FoldScaleAxis only accept dataflow-form";
}
void VisitExpr_(const FunctionNode* op) {
ExprVisitor::VisitExpr_(op);
auto flazy = [this, op] {
this->Update(op->body, NullValue<AxesSet>());
};
flist_.push_back(flazy);
}
void VisitExpr_(const CallNode* call) {
ExprVisitor::VisitExpr_(call);
// function to be lazily invoked
auto flazy = [this, call]() {
static const auto& fprep =
Op::GetAttr<FForwardPrep>("FScaleAxisForwardPrep");
// find the message send to this node.
auto it = message_.find(call);
AxesSet out_axes;
if (it != message_.end()) {
out_axes = it->second;
} else {
out_axes = NullValue<AxesSet>();
}
// pass the message back to all the children it references.
auto f = GetFunc(fprep, call->op);
if (f != nullptr) {
Array<AxesSet> in_axes = f(GetRef<Call>(call), out_axes);
CHECK_EQ(in_axes.size(), call->args.size());
for (size_t i = 0; i < call->args.size(); ++i) {
this->Update(call->args[i], in_axes[i]);
}
} else {
for (size_t i = 0; i < call->args.size(); ++i) {
this->Update(call->args[i], NullValue<AxesSet>());
}
}
};
flist_.push_back(flazy);
}
void VisitExpr_(const TupleNode* op) {
ExprVisitor::VisitExpr_(op);
// do not support pass scale through tuple for now.
auto flazy = [this, op]() {
for (const Expr& field : op->fields) {
this->Update(field, NullValue<AxesSet>());
}
};
flist_.push_back(flazy);
}
void VisitExpr_(const IfNode* op) {
ExprVisitor::VisitExpr_(op);
// do pass through condition
// by assigning NullValue<AxesSet>
// it means fuse signal cannot pass
// through into these subexpressions.
auto flazy = [this, op]() {
this->Update(op->cond, NullValue<AxesSet>());
this->Update(op->true_branch, NullValue<AxesSet>());
this->Update(op->false_branch, NullValue<AxesSet>());
};
flist_.push_back(flazy);
}
};
class FScaleAxisForwardTransform : private ExprMutator {
public:
// Transform expression.
Expr Transform(Expr expr) {
expected_scale_axes_ =
FScaleAxisForwardPrep().Prepare(expr);
return this->Mutate(expr);
}
private:
// Valid axes on each node.
std::unordered_map<const Node*, AxesSet> expected_scale_axes_;
std::unordered_map<const Node*, STuple> scale_memo_;
// If user simply call mutate,
// then only Expr is returned and we cannot
// accept outstanding scales.
Expr VisitExpr(const Expr& expr) final {
Expr res = ExprMutator::VisitExpr(expr);
CHECK(!scale_memo_.count(expr.get()))
<< "Outstanding scale";
return res;
}
STuple GetSTuple(const Expr& expr) {
Expr res = ExprMutator::VisitExpr(expr);
auto it = scale_memo_.find(expr.get());
if (it != scale_memo_.end()) {
CHECK(it->second->value.same_as(res));
return it->second;
} else {
auto node = make_node<STupleNode>();
node->value = res;
return STuple(node);
}
}
Expr VisitExpr_(const CallNode* call_node) final {
static const auto& ftransform =
Op::GetAttr<FForwardTransform>("FScaleAxisForwardTransform");
auto new_op = this->Mutate(call_node->op);
bool has_scale = false;
bool unchanged = call_node->op.same_as(new_op);
Array<STuple> call_sargs;
Array<Expr> call_args;
for (auto arg : call_node->args) {
STuple new_sarg = this->GetSTuple(arg);
unchanged &= new_sarg->value.same_as(arg);
if (new_sarg->axes.defined()) has_scale = true;
call_sargs.push_back(new_sarg);
call_args.push_back(new_sarg->value);
}
// get expected scale axes.
AxesSet expected_out_axes;
auto axis_it = expected_scale_axes_.find(call_node);
if (axis_it != expected_scale_axes_.end()) {
expected_out_axes = axis_it->second;
}
// propagation function
auto f = GetFunc(ftransform, call_node->op);
if (f != nullptr) {
STuple sret = f(GetRef<Call>(call_node), expected_out_axes, call_sargs);
if (sret.defined()) {
if (sret->axes.defined()) {
scale_memo_[call_node] = sret;
}
return sret->value;
}
}
// normal path
CHECK(!has_scale) << "Outstanding scale, on op=" << call_node->op;
if (unchanged) {
return GetRef<Expr>(call_node);
} else {
return CallNode::make(
new_op, call_args, call_node->attrs, call_node->type_args);
}
}
};
//----------------------------------------------
// Per operator defs for FScaleAxisForward
//----------------------------------------------
// Intermediate operators
Array<AxesSet> ReluForwardPrep(const Call& call, AxesSet out) {
return {out};
}
STuple ReluForwardTransform(const Call& ref_call,
const AxesSet& expected_axes,
const Array<STuple>& sargs) {
if (!sargs[0]->axes.defined()) return STuple();
// return transformed conv2d
auto rnode = make_node<STupleNode>();
rnode->value = CallNode::make(
ref_call->op, {sargs[0]->value}, ref_call->attrs, {});
rnode->scale = sargs[0]->scale;
rnode->axes = sargs[0]->axes;
return STuple(rnode);
}
RELAY_REGISTER_OP("nn.relu")
.set_attr<FForwardPrep>("FScaleAxisForwardPrep", ReluForwardPrep);
RELAY_REGISTER_OP("nn.relu")
.set_attr<FForwardTransform>("FScaleAxisForwardTransform", ReluForwardTransform);
RELAY_REGISTER_OP("nn.leaky_relu")
.set_attr<FForwardPrep>("FScaleAxisForwardPrep", ReluForwardPrep);
RELAY_REGISTER_OP("nn.leaky_relu")
.set_attr<FForwardTransform>("FScaleAxisForwardTransform", ReluForwardTransform);
// AddSub
Array<AxesSet> AddSubForwardPrep(const Call& call, AxesSet out_axes) {
const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
auto none = NullValue<AxesSet>();
if (MatchBroadcastToLeftAxes(tlhs, trhs, out_axes)) {
return {out_axes, none};
} else if (MatchBroadcastToLeftAxes(trhs, tlhs, out_axes)) {
return {none, out_axes};
} else {
return {none, none};
}
}
STuple AddSubForwardTransform(const Call& ref_call,
const AxesSet& expected_out_axes,
const Array<STuple>& sargs) {
if (!sargs[0]->axes.defined() && !sargs[1]->axes.defined()) {
return STuple();
}
const auto* tlhs = ref_call->args[0]->type_as<TensorTypeNode>();
const auto* trhs = ref_call->args[1]->type_as<TensorTypeNode>();
auto rnode = make_node<STupleNode>();
if (sargs[0]->axes.defined()) {
CHECK(!sargs[1]->axes.defined());
CHECK(MatchBroadcastToLeftAxes(tlhs, trhs, sargs[0]->axes));
Expr scale = ExpandBiasToMatchAxis(
sargs[0]->scale, tlhs->shape.size(), sargs[0]->axes);
Expr rhs = Divide(sargs[1]->value, scale);
rnode->value = CallNode::make(ref_call->op, {sargs[0]->value, rhs},
ref_call->attrs, ref_call->type_args);
rnode->scale = sargs[0]->scale;
rnode->axes = sargs[0]->axes;
} else {
CHECK(sargs[1]->axes.defined());
CHECK(sargs[0]->axes.defined());
CHECK(MatchBroadcastToLeftAxes(trhs, tlhs, sargs[1]->axes));
Expr scale = ExpandBiasToMatchAxis(
sargs[1]->scale, trhs->shape.size(), sargs[1]->axes);
Expr lhs = Divide(sargs[0]->value, scale);
rnode->value = CallNode::make(ref_call->op, {lhs, sargs[1]->value},
ref_call->attrs, ref_call->type_args);
rnode->scale = sargs[1]->scale;
rnode->axes = sargs[1]->axes;
}
return STuple(rnode);
}
RELAY_REGISTER_OP("add")
.set_attr<FForwardPrep>("FScaleAxisForwardPrep", AddSubForwardPrep);
RELAY_REGISTER_OP("add")
.set_attr<FForwardTransform>("FScaleAxisForwardTransform", AddSubForwardTransform);
RELAY_REGISTER_OP("subtract")
.set_attr<FForwardPrep>("FScaleAxisForwardPrep", AddSubForwardPrep);
RELAY_REGISTER_OP("subtract")
.set_attr<FForwardTransform>("FScaleAxisForwardTransform", AddSubForwardTransform);
// Producer operators
// Multiply produces the scale-axis pair.
STuple MultiplyForwardTransform(const Call& ref_call,
const AxesSet& expected_out_axes,
const Array<STuple>& sargs) {
if (!expected_out_axes.defined()) return STuple();
// TODO(tvm-team) allow same axes accumulation
// not as important because it is less common in nn.
CHECK(!sargs[0]->axes.defined());
CHECK(!sargs[1]->axes.defined());
const auto* tlhs = ref_call->args[0]->type_as<TensorTypeNode>();
const auto* trhs = ref_call->args[1]->type_as<TensorTypeNode>();
Expr lhs = sargs[0]->value;
Expr rhs = sargs[1]->value;
auto rnode = make_node<STupleNode>();
if (MatchBroadcastToLeftAxes(tlhs, trhs, expected_out_axes, &rhs)) {
rnode->value = lhs;
rnode->scale = rhs;
rnode->axes = expected_out_axes;
} else if (MatchBroadcastToLeftAxes(trhs, tlhs, expected_out_axes, &lhs)) {
rnode->value = rhs;
rnode->scale = lhs;
rnode->axes = expected_out_axes;
}
return STuple(rnode);
}
RELAY_REGISTER_OP("multiply")
.set_attr<FForwardTransform>("FScaleAxisForwardTransform", MultiplyForwardTransform);
// Consumer operators
// Conv2D send out requirement of axis folding.
Array<AxesSet> Conv2DForwardPrep(const Call& call, AxesSet out) {
// TODO(tvm-team) support general data layout
// by transforming weight
const auto* param = call->attrs.as<Conv2DAttrs>();
CHECK(param != nullptr);
Layout data_layout(param->data_layout);
Layout weight_layout(param->weight_layout);
int c_big_axis = data_layout.indexof('C');
int c_small_axis = data_layout.indexof('c');
const auto* tdata = call->args[0]->type_as<TensorTypeNode>();
CHECK(tdata) << "require checked type";
CHECK_GE(c_big_axis, 0);
AxesSet data_axes = NullValue<AxesSet>();
// For now, we only support simple pattern (no folded weight/data)
// More general layout can be supported under the current framework.
// By using a unified layout transformation.
// We only need to change the Prep and Mutate function.
//
// only handle depthwise or full conv2d.
// TODO(tvm-team) handle grouped conv by reshape + bcast
bool is_depthwise_conv2d =
is_const_int(tdata->shape[c_big_axis], param->groups);
if (weight_layout.indexof('i') < 0 &&
c_small_axis < 0 &&
(param->groups == 1 || is_depthwise_conv2d)) {
data_axes = {c_big_axis};
}
return {data_axes, NullValue<AxesSet>()};
}
// Conv2D consumes the scale axis during transformation.
STuple Conv2DForwardTransform(const Call& ref_call,
const AxesSet& expected_axes,
const Array<STuple>& sargs) {
// if data do not have scale, normal transform path.
STuple sdata = sargs[0];
if (!sdata->scale.defined()) return STuple();
CHECK(sdata->axes.defined());
const auto* param = ref_call->attrs.as<Conv2DAttrs>();
CHECK(param != nullptr);
Layout data_layout(param->data_layout);
Layout weight_layout(param->weight_layout);
int c_big_axis = data_layout.indexof('C');
CHECK_GE(c_big_axis, 0);
// For now, we only support simple pattern (no folded weight/data)
// TODO(tvm-team) support general data layout
CHECK_EQ(weight_layout.indexof('i'), -1);
CHECK(sdata->axes.size() == 1 &&
c_big_axis == sdata->axes[0]->value);
int big_ic_axis = weight_layout.indexof('I');
const auto* tdata = ref_call->args[0]->type_as<TensorTypeNode>();
// Check it must be depthwise or full conv2d.
bool is_depthwise_conv2d =
is_const_int(tdata->shape[c_big_axis], param->groups);
CHECK(param->groups == 1 || is_depthwise_conv2d);
// match the ic_axis
Expr scale = ExpandBiasToMatchAxis(
sdata->scale, weight_layout.ndim(), {big_ic_axis});
Expr weight = Multiply(sargs[1]->value, scale);
// return transformed conv2d
auto rnode = make_node<STupleNode>();
rnode->value = CallNode::make(
ref_call->op, {sdata->value, weight}, ref_call->attrs, ref_call->type_args);
return STuple(rnode);
}
RELAY_REGISTER_OP("nn.conv2d")
.set_attr<FForwardPrep>("FScaleAxisForwardPrep", Conv2DForwardPrep);
RELAY_REGISTER_OP("nn.conv2d")
.set_attr<FForwardTransform>("FScaleAxisForwardTransform", Conv2DForwardTransform);
Expr ForwardFoldScaleAxis(Expr data) {
return FScaleAxisForwardTransform().Transform(data);
}
// Expose the FoldScaleAxisFoward
TVM_REGISTER_API("relay._ir_pass.forward_fold_scale_axis")
.set_body_typed<Expr(Expr)>(ForwardFoldScaleAxis);
} // namespace fold_scale_axis
} // namespace relay
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors.
*
* \file tvm/relay/pass/pattern_util.h
* \brief Header of internal operator functions
* These can be used for writing passes.
*/
#ifndef TVM_RELAY_PASS_PATTERN_UTIL_H_
#define TVM_RELAY_PASS_PATTERN_UTIL_H_
#include <tvm/relay/op.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/attrs/transform.h>
namespace tvm {
namespace relay {
/*!
* \brief Try to match lhs and rhs via broadcasting rule, such that:
*
* rhs matches the dimension of lhs specified by lhs_axes
* rhs's value equals 1 on rest of dimensions.
*
* \param tlhs The type of left operand (data)
* \param trhs The type right operand (bias)
* \param lhs_axes The axes on lhs to match.
* \param rhs_value A squeezed version of rhs which only contains matched dimension.
* \return Whether match is successful.
*/
inline bool MatchBroadcastToLeftAxes(const TensorTypeNode* tlhs,
const TensorTypeNode* trhs,
const Array<Integer>& lhs_axes,
Expr* rhs_value = nullptr) {
if (tlhs->shape.size() < trhs->shape.size()) return false;
AttrsEqual equal;
size_t base = tlhs->shape.size() - trhs->shape.size();
size_t j = 0;
NodePtr<SqueezeAttrs> squeeze_attrs;
if (rhs_value != nullptr) {
squeeze_attrs = make_node<SqueezeAttrs>();
}
for (size_t i = 0; i < tlhs->shape.size(); ++i) {
if (j < lhs_axes.size() && i == static_cast<size_t>(lhs_axes[j]->value)) {
if (i < base || !equal(tlhs->shape[i], trhs->shape[i - base])) {
return false;
}
++j;
} else if (i >= base) {
if (!is_const_int(trhs->shape[i - base], 1)) {
return false;
}
if (rhs_value != nullptr) {
squeeze_attrs->axis.push_back(static_cast<int>(i - base));
}
}
}
if (rhs_value != nullptr && squeeze_attrs->axis.size() != 0) {
static const Op& squeeze_op = Op::Get("squeeze");
*rhs_value = CallNode::make(squeeze_op, {rhs_value[0]}, Attrs(squeeze_attrs), {});
}
return true;
}
/*!
* \brief Expand 1D Tensor to match axis.
*
* The result bias can be used to add or multiply to
* the target Tensor on the specified axis via broadcasting rule.
*
* \param bias The bias.
* \param target_ndim target dimension.
* \param axes The axis on the output we want to match on.
*/
inline Expr ExpandBiasToMatchAxis(Expr bias,
int target_ndim,
const Array<Integer>& axes) {
static const Op& expand_dims = Op::Get("expand_dims");
for (size_t i = axes.size(); i != 0; --i) {
if (i == axes.size()) {
int64_t num_pad_axis = target_ndim - axes[i - 1]->value - 1;
if (num_pad_axis > 0) {
auto attrs = make_node<ExpandDimsAttrs>();
attrs->axis = i;
attrs->num_newaxis = static_cast<int>(num_pad_axis);
bias = CallNode::make(expand_dims, {bias}, Attrs(attrs), {});
}
} else {
int64_t diff = axes[i]->value - axes[i - 1]->value;
CHECK_GE(diff, 0L);
if (diff > 0) {
auto attrs = make_node<ExpandDimsAttrs>();
attrs->axis = i;
attrs->num_newaxis = static_cast<int>(diff);
bias = CallNode::make(expand_dims, {bias}, Attrs(attrs), {});
}
}
}
return bias;
}
inline Expr Multiply(Expr lhs, Expr rhs) {
static const Op& op = Op::Get("multiply");
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
}
inline Expr Divide(Expr lhs, Expr rhs) {
static const Op& op = Op::Get("divide");
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
}
inline Expr ReshapeLike(Expr lhs, Expr rhs) {
static const Op& op = Op::Get("reshape_like");
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
}
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_PATTERN_UTIL_H_
...@@ -406,28 +406,57 @@ class TypeInferencer::Resolver : public ExprMutator { ...@@ -406,28 +406,57 @@ class TypeInferencer::Resolver : public ExprMutator {
CHECK(checked_type.as<IncompleteTypeNode>() == nullptr) CHECK(checked_type.as<IncompleteTypeNode>() == nullptr)
<< "Cannot resolve type of " << GetRef<Expr>(op) << "Cannot resolve type of " << GetRef<Expr>(op)
<< " at " << op->span; << " at " << op->span;
Expr new_e = ExprMutator::VisitExpr_(op); Expr new_e = ExprMutator::VisitExpr_(op);
if (!checked_type.same_as(new_e->checked_type_)) { // new_call and new_var's code is only going to be valid for VarNode/CallNode.
// Compiler optimization will likely fold these away for other nodes.
CallNode* new_call =(
std::is_base_of<CallNode, T>::value ?
static_cast<CallNode*>(new_e.node_.get()) : nullptr);
VarNode* new_var =(
std::is_base_of<VarNode, T>::value ?
static_cast<VarNode*>(new_e.node_.get()) : nullptr);
// check if we need update the new_e
bool need_update_type = !checked_type.same_as(new_e->checked_type_);
bool need_update_call = (
std::is_base_of<CallNode, T>::value &&
it->second.type_args.defined() &&
!it->second.type_args.same_as(new_call->type_args));
bool need_update_var = (
std::is_base_of<VarNode, T>::value &&
update_missing_type_annotation_ &&
!new_var->type_annotation.defined());
if (!need_update_type && !need_update_var && !need_update_call) return new_e;
if (!new_e.node_.unique()) {
// Copy on write optimization // Copy on write optimization
// If new_e is an old expression, // If new_e is an old expression,
// we make a copy mutating an existing reference. // we make a copy mutating an existing reference.
if (!new_e.node_.unique()) { new_e = Expr(make_node<T>(*new_e.as<T>()));
new_e = Expr(make_node<T>(*new_e.as<T>())); new_call = (
} std::is_base_of<CallNode, T>::value ?
new_e->checked_type_ = checked_type; static_cast<CallNode*>(new_e.node_.get()) : nullptr);
new_var = (
std::is_base_of<VarNode, T>::value ?
static_cast<VarNode*>(new_e.node_.get()) : nullptr);
} }
if (it->second.type_args.defined()) { // attach the information.
Call call = Downcast<Call>(new_e); if (need_update_type) {
const CallNode* const_call_ref = call.operator->(); new_e->checked_type_ = checked_type;
CallNode* call_ref = const_cast<CallNode*>(const_call_ref); }
call_ref->type_args = it->second.type_args;
for (size_t i = 0; i < call->type_args.size(); i++) { if (need_update_call) {
call_ref->type_args.Set(i, solver_->Resolve(call->type_args[i])); new_call->type_args = it->second.type_args;
for (size_t i = 0; i < new_call->type_args.size(); i++) {
new_call->type_args.Set(i, solver_->Resolve(new_call->type_args[i]));
} }
} }
if (need_update_var) {
new_var->type_annotation = checked_type;
}
return new_e; return new_e;
} }
...@@ -438,6 +467,9 @@ class TypeInferencer::Resolver : public ExprMutator { ...@@ -438,6 +467,9 @@ class TypeInferencer::Resolver : public ExprMutator {
private: private:
const std::unordered_map<Expr, ResolvedTypeInfo, NodeHash, NodeEqual>& tmap_; const std::unordered_map<Expr, ResolvedTypeInfo, NodeHash, NodeEqual>& tmap_;
TypeSolver* solver_; TypeSolver* solver_;
// whether attach the checked type as type_annotation
// if original type anntation is missing.
bool update_missing_type_annotation_{true};
}; };
......
...@@ -55,8 +55,8 @@ def test_transpose_infer_type(): ...@@ -55,8 +55,8 @@ def test_transpose_infer_type():
def test_squeeze_infer_type(): def test_squeeze_infer_type():
n, t, d = 1, 4, 1 n, t, d = 1, 4, 1
x = relay.var("x", relay.TensorType((n, t, d), "float32")) x = relay.var("x", relay.TensorType((n, t, d), "float32"))
y = relay.squeeze(x, axes=(2,)) y = relay.squeeze(x, axis=(2,))
assert "axes=" in y.astext() assert "axis=" in y.astext()
yy = relay.ir_pass.infer_type(y) yy = relay.ir_pass.infer_type(y)
assert yy.checked_type == relay.TensorType( assert yy.checked_type == relay.TensorType(
(1, 4), "float32") (1, 4), "float32")
...@@ -64,7 +64,7 @@ def test_squeeze_infer_type(): ...@@ -64,7 +64,7 @@ def test_squeeze_infer_type():
n, t, d = 1, 4, 1 n, t, d = 1, 4, 1
x = relay.var("x", relay.TensorType((n, t, d), "float32")) x = relay.var("x", relay.TensorType((n, t, d), "float32"))
y = relay.squeeze(x) y = relay.squeeze(x)
assert "axes=" not in y.astext() assert "axis=" not in y.astext()
yy = relay.ir_pass.infer_type(y) yy = relay.ir_pass.infer_type(y)
assert yy.checked_type == relay.TensorType( assert yy.checked_type == relay.TensorType(
(4,), "float32") (4,), "float32")
...@@ -74,7 +74,7 @@ def test_squeeze_infer_type(): ...@@ -74,7 +74,7 @@ def test_squeeze_infer_type():
def test_squeeze_bad_axes_infer_type(): def test_squeeze_bad_axes_infer_type():
n, t, d = 1, 4, 1 n, t, d = 1, 4, 1
x = relay.var("x", relay.TensorType((n, t, d), "float32")) x = relay.var("x", relay.TensorType((n, t, d), "float32"))
y = relay.squeeze(x, axes=(1,)) y = relay.squeeze(x, axis=(1,))
yy = relay.ir_pass.infer_type(y) yy = relay.ir_pass.infer_type(y)
......
from tvm import relay
def test_fold_fwd_simple():
"""Simple testcase."""
def before(x, conv_weight, in_bias, in_scale, channels):
args = [x, conv_weight, in_bias, in_scale]
in_scale = relay.expand_dims(in_scale, axis=1, num_newaxis=2)
in_bias = relay.expand_dims(in_bias, axis=1, num_newaxis=2)
x = relay.multiply(x, in_scale)
x = relay.nn.relu(x)
x = relay.add(x, in_bias)
y = relay.nn.conv2d(x, conv_weight,
channels=channels,
kernel_size=(3, 3),
padding=(1, 1))
return relay.Function(args, y)
def expected(x, conv_weight, in_bias, in_scale, channels):
# use a fixed order of args so alpha equal check can pass
args = [x, conv_weight, in_bias, in_scale]
in_scale = relay.expand_dims(in_scale, axis=1, num_newaxis=2)
in_bias = relay.expand_dims(in_bias, axis=1, num_newaxis=2)
squeezed_scale = relay.squeeze(in_scale, axis=[1,2])
x = relay.nn.relu(x)
in_bias = relay.divide(in_bias, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2))
x = relay.add(x, in_bias)
conv_weight = relay.multiply(
conv_weight , relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2))
y = relay.nn.conv2d(x, conv_weight,
channels=channels,
kernel_size=(3, 3),
padding=(1, 1))
return relay.Function(args, y)
def check(shape, channels):
x = relay.var("x", shape=shape)
in_channels = shape[1]
weight = relay.var("weight")
in_bias = relay.var("in_bias", shape=(in_channels,))
in_scale = relay.var("in_scale", shape=(in_channels,))
y1 = before(x, weight, in_bias, in_scale, channels)
y1 = relay.ir_pass.infer_type(y1)
type_dict = {x.name_hint:x.checked_type for x in y1.params}
weight = relay.var("weight", type_dict["weight"])
y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)
y1_expected = expected(x, weight, in_bias, in_scale, channels)
assert relay.ir_pass.alpha_equal(y1_folded, y1_expected)
check((2, 4, 10, 10), 2)
def test_fold_fwd_dual_path():
"""scale axis being consumed by two consumers"""
def before(x, conv_weight, in_bias, in_scale, channels):
args = [x, conv_weight, in_bias, in_scale]
x = relay.multiply(in_scale, x)
x = relay.nn.relu(x)
x = relay.subtract(x, in_bias)
y1 = relay.nn.conv2d(x, conv_weight,
channels=channels,
kernel_size=(3, 3),
data_layout="NHWC",
weight_layout="HWOI",
groups=channels,
padding=(1, 1))
y2 = relay.nn.conv2d(x, conv_weight,
channels=channels,
kernel_size=(3, 3),
data_layout="NHWC",
weight_layout="HWOI",
groups=channels,
padding=(1, 1))
z = relay.add(y1, y2)
return relay.Function(args, z)
def expected(x, conv_weight, in_bias, in_scale, channels):
args = [x, conv_weight, in_bias, in_scale]
x = relay.nn.relu(x)
in_bias = relay.divide(in_bias, in_scale)
x = relay.subtract(x, in_bias)
y1 = relay.nn.conv2d(x,
relay.multiply(conv_weight, in_scale),
channels=channels,
kernel_size=(3, 3),
data_layout="NHWC",
weight_layout="HWOI",
groups=channels,
padding=(1, 1))
y2 = relay.nn.conv2d(x,
relay.multiply(conv_weight, in_scale),
channels=channels,
kernel_size=(3, 3),
data_layout="NHWC",
weight_layout="HWOI",
groups=channels,
padding=(1, 1))
z = relay.add(y1, y2)
return relay.Function(args, z)
def check(shape, channels):
x = relay.var("x", shape=shape)
in_channels = shape[-1]
# test depthwise
assert in_channels == channels
weight = relay.var("weight")
in_bias = relay.var("in_bias", shape=(in_channels,))
in_scale = relay.var("in_scale", shape=(in_channels,))
y1 = before(x, weight, in_bias, in_scale, channels)
y1 = relay.ir_pass.infer_type(y1)
y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)
type_dict = {x.name_hint:x.checked_type for x in y1.params}
weight = relay.var("weight", type_dict["weight"])
y1_expected = expected(x, weight, in_bias, in_scale, channels)
assert relay.ir_pass.alpha_equal(y1_folded, y1_expected)
check((2, 4, 10, 3), 3)
def test_fold_fwd_fail():
"""testcase where we canont fold"""
def before(x, conv_weight, in_bias, in_scale, channels):
x = relay.multiply(x, in_scale)
xx = relay.nn.leaky_relu(x, alpha=0.1)
y1 = relay.nn.conv2d(xx, conv_weight,
channels=channels,
kernel_size=(3, 3),
data_layout="NHWC",
padding=(1, 1))
z = relay.add(y1, x)
return relay.Function(relay.ir_pass.free_vars(z), z)
def check(shape, channels):
x = relay.var("x", shape=shape)
in_channels = shape[-1]
# test depthwise
assert in_channels == channels
weight = relay.var("weight")
in_bias = relay.var("in_bias", shape=(in_channels,))
in_scale = relay.var("in_scale", shape=(in_channels,))
y1 = before(x, weight, in_bias, in_scale, channels)
y1 = relay.ir_pass.infer_type(y1)
y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)
assert relay.ir_pass.alpha_equal(y1, y1_folded)
check((2, 11, 10, 4), 4)
if __name__ == "__main__":
test_fold_fwd_simple()
test_fold_fwd_dual_path()
test_fold_fwd_fail()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment