Commit 05ea6018 by Tianqi Chen Committed by ziheng

[RELAY][PASS] FoldScaleAxis Forward (#2020)

* [RELAY][PASS] FoldScaleAxis Forward

* Introduce helper function type_as

* Update per review comment

* Fix according to comments
parent 26e3aa19
......@@ -94,15 +94,16 @@ struct InitOpAttrs : public tvm::AttrsNode<InitOpAttrs> {
/*! \brief Attributes used in squeeze operators */
struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
Array<IndexExpr> axes;
// use axis to make the name numpy compatible.
Array<Integer> axis;
TVM_DECLARE_ATTRS(SqueezeAttrs, "relay.attrs.SqueezeAttrs") {
TVM_ATTR_FIELD(axes)
.describe("The axes to squeeze in the input tensor."
"If `axes = []`, all axis of dimension 1 get squeezed;"
TVM_ATTR_FIELD(axis)
.describe("The axis to squeeze in the input tensor."
"If `axis = None`, all axis of dimension 1 get squeezed;"
"Else, the dimension in axes get squeezed."
"It is an error if an axes does not has dimension 1.")
.set_default(Array<IndexExpr>({}));
"It is an error if an axis does not has dimension 1.")
.set_default(NullValue<Array<Integer> >());
}
}; // struct SqueezeAttrs
......
......@@ -40,6 +40,18 @@ class ExprNode : public RelayNode {
"field for this node";
return this->checked_type_;
}
/*!
* \brief Check if the inferred (checked) type of the Expr
* is backed by a TTypeNode and return it.
*
* \note This function will throw an error if the node type
* of this Expr is not TTypeNode.
*
* \return The corresponding TTypeNode pointer.
* \tparam TTypeNode The specific TypeNode we look for.
*/
template<typename TTypeNode>
inline const TTypeNode* type_as() const;
static constexpr const char* _type_key = "relay.Expr";
TVM_DECLARE_BASE_NODE_INFO(ExprNode, RelayNode);
......@@ -391,6 +403,20 @@ class TupleGetItemNode : public ExprNode {
RELAY_DEFINE_NODE_REF(TupleGetItem, TupleGetItemNode, Expr);
// implementations
// Out-of-class definition of ExprNode::type_as.
// Returns checked_type_ viewed as a TTypeNode; a CHECK fails when the
// checked type is not defined or is backed by a different node type.
template<typename TTypeNode>
inline const TTypeNode* ExprNode::type_as() const {
// Compile-time guard: TTypeNode has to derive from TypeNode.
static_assert(std::is_base_of<TypeNode, TTypeNode>::value,
"TType must be a special case of type");
// The checked type must be populated before it can be inspected.
CHECK(checked_type_.defined())
<< "Type inference for this Expr has not completed";
const TTypeNode* node = checked_type_.as<TTypeNode>();
// Report both the requested and the actual type key on mismatch.
CHECK(node != nullptr)
<< "Expected type to be " << TTypeNode::_type_key
<< ", but get " << checked_type_->type_key();
return node;
}
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_EXPR_H_
......@@ -150,7 +150,14 @@ class ExprVisitor
class ExprMutator
: public ::tvm::relay::ExprFunctor<Expr(const Expr&)> {
public:
Expr Mutate(const Expr& expr);
/*!
* \brief Mutate is alias for VisitExpr
* \return expr.
*/
Expr Mutate(const Expr& expr) {
return this->VisitExpr(expr);
}
Expr VisitExpr(const Expr& expr) override;
Expr VisitExpr_(const VarNode* op) override;
Expr VisitExpr_(const ConstantNode* op) override;
Expr VisitExpr_(const GlobalVarNode* op) override;
......@@ -161,7 +168,8 @@ class ExprMutator
Expr VisitExpr_(const LetNode* op) override;
Expr VisitExpr_(const IfNode* op) override;
Expr VisitExpr_(const TupleGetItemNode* op) override;
/*! \brief Used to visit the types inside of expressions.
/*!
* \brief Used to visit the types inside of expressions.
*
* Can be overloaded to transform the types in arbitrary
* ways, one way would be to define a sub-class of type
......@@ -169,7 +177,7 @@ class ExprMutator
*/
virtual Type VisitType(const Type& t);
private:
protected:
/*! \brief Internal map used for memoization. */
std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo_;
};
......
......@@ -74,6 +74,17 @@ class OpNode : public relay::ExprNode {
v->Visit("support_level", &support_level);
}
/*!
 * \brief Check whether the current op is a "primitive operator".
 *  That is, the arguments are all type variables, and there is a single
 *  type relation applied to the input and output types.
 *
 * The answer is computed by IsPrimitiveOp_() on the first query and
 * stored in is_primitive_ for later calls.
 *
 * \return Whether this op is a primitive operator.
 */
bool IsPrimitiveOp() const {
  if (is_primitive_ == -1) {
    // -1 marks "not computed yet"; fill in the cached flag once.
    is_primitive_ = this->IsPrimitiveOp_() ? 1 : 0;
  }
  return is_primitive_ != 0;
}
static constexpr const char* _type_key = "relay.Op";
TVM_DECLARE_NODE_TYPE_INFO(OpNode, ExprNode);
......@@ -81,9 +92,24 @@ class OpNode : public relay::ExprNode {
// friend class
friend class GenericOpMap;
friend class OpRegistry;
friend bool IsPrimitiveOp(const Expr&);
// Program internal unique index of operator.
// Used to help index the program.
uint32_t index_{0};
// whether this is a primitive op. -1 means unknown.
mutable int is_primitive_{-1};
// Internal function to compute if it is primitive op
bool IsPrimitiveOp_() const {
const auto& fn_ty = this->op_type;
if (fn_ty->type_constraints.size() != 1) return false;
const TypeRelationNode* rel = fn_ty->type_constraints[0].as<TypeRelationNode>();
if (rel == nullptr) return false;
// validate if the type parameter matches up
for (size_t i = 0; i < fn_ty->type_params.size(); ++i) {
if (!fn_ty->type_params[i].same_as(rel->args[i])) return false;
}
return true;
}
};
/*!
......@@ -497,22 +523,7 @@ inline ValueType OpMap<ValueType>::get(const Op& op,
*/
inline bool IsPrimitiveOp(const Expr& expr) {
const auto* op = expr.as<OpNode>();
if (!op) {
return false;
}
const auto& fn_ty = op->op_type;
if (fn_ty->type_constraints.size() != 1) return false;
const TypeRelationNode* rel = fn_ty->type_constraints[0].as<TypeRelationNode>();
if (rel == nullptr) return false;
// validate if the type parameter matches up
for (size_t i = 0; i < fn_ty->type_params.size(); ++i) {
if (!fn_ty->type_params[i].same_as(rel->args[i])) return false;
}
return true;
return op != nullptr && op->IsPrimitiveOp();
}
} // namespace relay
......
......@@ -10,6 +10,7 @@ from . import _make
from .expr import Expr
from .ty import Type
def infer_type(expr, env=None):
"""Infer the type of expr under the context of env.
......@@ -30,6 +31,23 @@ def infer_type(expr, env=None):
return _ir_pass.infer_type(expr, env)
def forward_fold_scale_axis(expr):
    """Fold the scaling of axis into weights of conv2d/dense.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The expression to transform. Its types are expected to be
        fully inferred by infer_type beforehand.

    Returns
    -------
    folded_expr : tvm.relay.Expr
        The folded expression after transformation.
    """
    folded_expr = _ir_pass.forward_fold_scale_axis(expr)
    return folded_expr
def well_formed(expr):
"""Check that each Var is only bound once (well formed).
......@@ -149,6 +167,7 @@ def alpha_equal(lhs, rhs):
"""
return bool(_make._alpha_equal(lhs, rhs))
def graph_equal(lhs, rhs):
"""Compare two Relay expr for data-flow equivalence.
The difference between this and alpha-equality is that
......@@ -170,6 +189,7 @@ def graph_equal(lhs, rhs):
"""
return bool(_make._graph_equal(lhs, rhs))
def structural_hash(value):
"""Hash a Relay expression structurally.
......
......@@ -49,27 +49,25 @@ def transpose(data, axes=None):
return _make.transpose(data, list(axes))
def squeeze(data, axes=None):
def squeeze(data, axis=None):
"""Squeeze axes in the array.
Parameters
----------
data : relay.Expr
data : tvm.relay.Expr
The input data to the operator.
axes : None or List[int]
Axes to remove.
If axes = [] or = None, remove all axis of dimensions 1.
Otherwise, remove all axis in axes.
If any axis in axes has dimension that does not equal 1, it is an error.
axis : None or List[int]
The set of axes to remove.
If axis = None, remove all axis of dimensions 1.
If any specified axis has dimension that does not equal 1, it is an error.
Returns
-------
result : relay.Expr
result : tvm.relay.Expr
The squeezed result.
"""
axes = axes or []
return _make.squeeze(data, list(axes))
return _make.squeeze(data, axis)
def reshape(data, newshape):
......
......@@ -296,13 +296,23 @@ class AlphaEqualHandler:
if (const CallNode* rhs = other.as<CallNode>()) {
if (!ExprEqual(lhs->op, rhs->op)) return false;
if (lhs->args.size() != rhs->args.size()) return false;
if (lhs->type_args.size() != rhs->type_args.size()) return false;
// skip type_args check for primitive ops.
bool is_primitive = IsPrimitiveOp(lhs->op);
if (!is_primitive) {
if (lhs->type_args.size() != rhs->type_args.size()) {
return false;
}
}
for (size_t i = 0; i < lhs->args.size(); ++i) {
if (!ExprEqual(lhs->args[i], rhs->args[i])) return false;
if (!ExprEqual(lhs->args[i], rhs->args[i])) {
return false;
}
}
for (size_t i = 0; i < lhs->type_args.size(); ++i) {
if (!TypeEqual(lhs->type_args[i], rhs->type_args[i])) return false;
if (!is_primitive) {
for (size_t i = 0; i < lhs->type_args.size(); ++i) {
if (!TypeEqual(lhs->type_args[i], rhs->type_args[i])) return false;
}
}
return AttrEqual(lhs->attrs, rhs->attrs);
} else {
......
......@@ -12,12 +12,12 @@
namespace tvm {
namespace relay {
Expr ExprMutator::Mutate(const Expr& expr) {
Expr ExprMutator::VisitExpr(const Expr& expr) {
auto it = this->memo_.find(expr);
if (it != this->memo_.end()) {
return it->second;
} else {
Expr new_expr = ExprMutator::VisitExpr(expr);
Expr new_expr = ExprFunctor::VisitExpr(expr);
memo_[expr] = new_expr;
return new_expr;
}
......
......@@ -761,9 +761,9 @@ Examples::
TVM_REGISTER_NODE_TYPE(SqueezeAttrs);
Expr MakeSqueeze(Expr data,
Array<IndexExpr> axes) {
Array<Integer> axis) {
auto attrs = make_node<SqueezeAttrs>();
attrs->axes = std::move(axes);
attrs->axis = std::move(axis);
static const Op& op = Op::Get("squeeze");
return CallNode::make(op, {data}, Attrs(attrs), {});
}
......@@ -785,8 +785,8 @@ bool SqueezeRel(const Array<Type>& types,
const auto* param = attrs.as<SqueezeAttrs>();
CHECK(param != nullptr);
std::vector<IndexExpr> result_shape;
// if axes is empty, squeeze all axes of dimension 1
if (param->axes.size() == 0) {
// if axes is None, squeeze all axes of dimension 1
if (!param->axis.defined()) {
for (const auto& e : data->shape) {
const int64_t* axis_ptr = as_const_int(e);
CHECK(axis_ptr != nullptr) << "the axes attribute must be concrete";
......@@ -800,10 +800,8 @@ bool SqueezeRel(const Array<Type>& types,
for (const auto& e : data->shape) {
original_shape.push_back(std::pair<IndexExpr, bool>(e, true));
}
for (const auto& e : param->axes) {
const int64_t* axis_ptr = as_const_int(e);
CHECK(axis_ptr != nullptr);
original_shape.at(*axis_ptr).second = false;
for (const auto& e : param->axis) {
original_shape.at(e->value).second = false;
}
for (const auto p : original_shape) {
if (p.second) {
......
/*!
* Copyright (c) 2018 by Contributors
*
* \file fold_scale_axis.cc
*
* \brief Fold axis scaling into weights of
* conv/dense operators.
*/
#include <tvm/relay/pass.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/expr_functor.h>
#include "pattern_util.h"
#include "../op/nn/layout.h"
namespace tvm {
namespace relay {
/*!
* \brief namespace of fold scale axis
*
* Use namespace to reduce potential naming conflict.
*/
namespace fold_scale_axis {
using runtime::TypedPackedFunc;
// FoldScaleAxisForward algorithm:
//
// The general idea is that we transform Expr to tuple of
// (value, axes, scale), where the final result satisfies:
//
// result = value
// for i, k in enumerate(axes):
// k-th dimension of result *= i-th dimension of scale
//
// Then we can propagate this signal along and fold the scale if necessary.
// However, it is possible that certain scale may never be consumed
// if there is no dense/conv2d that follows multiplication.
//
// In order to make sure all the scale we sent out can be consumed eventually,
// we run a backward "preparation phase", which propagates the demand
// of the potential axes scaling back to its input.
//
// The folding process is done in two steps:
// - Prepare phase: backward propagation of demand.
// - Transform phase: forward transformation.
/*!
* \brief sorted array axis, can also be nullptr.
*
* nullptr means no scaling request can be done.
*/
using AxesSet = Array<Integer>;
/*!
 * \brief Intersect two axis sets.
 *
 * An undefined (null) set is returned unchanged as soon as either
 * operand is undefined; otherwise the common axes are collected.
 *
 * \note Both inputs are expected to be sorted; the merge below
 *       relies on that ordering.
 *
 * \param lhs The left axis set.
 * \param rhs The right axis set.
 * \return The result of the intersection.
 */
AxesSet Intersect(const AxesSet& lhs, const AxesSet& rhs) {
  if (!lhs.defined()) return lhs;
  if (!rhs.defined()) return rhs;
  // Two-pointer merge over the sorted inputs.
  AxesSet common;
  const size_t lhs_len = lhs.size();
  const size_t rhs_len = rhs.size();
  for (size_t a = 0, b = 0; a < lhs_len && b < rhs_len;) {
    const auto left_axis = lhs[a]->value;
    const auto right_axis = rhs[b]->value;
    if (left_axis == right_axis) {
      common.push_back(lhs[a]);
      ++a;
      ++b;
    } else if (left_axis < right_axis) {
      ++a;
    } else {
      ++b;
    }
  }
  return common;
}
/*!
* \param Get function from op_map.
* \param op_map The OpMap.
* \param op The operator being called.
* \tparam ValueType the content value type.
* \return The result value map.
*/
template<typename ValueType>
ValueType GetFunc(const OpMap<ValueType>& op_map,
const Expr& op) {
if (const OpNode* opnode = op.as<OpNode>()) {
return op_map.get(GetRef<Op>(opnode), ValueType());
} else {
return ValueType();
}
}
/*!
* \brief Preparation function for passing scale forward.
* \param call The call node.
* \param out_scale_axes Possible scaling on axes of the output.
* \return The result scaling on axes of the input.
*/
using FForwardPrep = runtime::TypedPackedFunc<
Array<AxesSet> (const Call& call, const AxesSet& out_scale_axes)>;
/*!
* \brief Axis scale tuple.
*
* Bundles an expression with an optional pending scale: the set of
* axes the scale applies to and the scaling factor itself.
*/
class STupleNode : public Node {
public:
/*! \brief The value */
Expr value;
/*! \brief The axes to scale, can be nullptr(means no-scaling) */
AxesSet axes = NullValue<AxesSet>();
/*! \brief The scaling factor, null by default. */
Expr scale = NullValue<Expr>();
/*! \brief Expose the three fields to the attribute visitor. */
void VisitAttrs(AttrVisitor* v) final {
v->Visit("value", &value);
v->Visit("axes", &axes);
v->Visit("scale", &scale);
}
static constexpr const char* _type_key = "relay.fold_scale_axis.STupleNode";
TVM_DECLARE_NODE_TYPE_INFO(STupleNode, Node);
};
RELAY_DEFINE_NODE_REF(STuple, STupleNode, NodeRef);
/*!
* \brief The transform function, transform an old call to
* a new one given the new args.
* \param ref_call Reference call node that represent the op and the types.
* \param expected_out_axes The scale axes allowed in the output.
* \param sargs The input arguments.
*/
using FForwardTransform = TypedPackedFunc<
STuple(const Call& ref_call,
const AxesSet& expected_out_axes,
const Array<STuple>& sargs)>;
//----------------------------------------------
// Generic Visitors for FScaleAxisForward
//----------------------------------------------
/*!
* \brief Prepare phase of the forward scale-axis folding.
*
* Visits the expression, queues one lazy callback per visited
* function/call/tuple/if node, then runs the callbacks in reverse
* order so that each node pushes its axes message to the
* expressions it references.
*/
class FScaleAxisForwardPrep : private ExprVisitor {
public:
/*!
* \brief Compute the axes message for the nodes of body.
* \param body The expression to analyze.
* \return Map from node pointer to the AxesSet message recorded for it.
*/
std::unordered_map<const Node*, AxesSet>
Prepare(const Expr& body) {
// The root gets a null message before the traversal starts.
this->Update(body, NullValue<AxesSet>());
this->VisitExpr(body);
// flist is added in the Post-DFS order
// which is a special case of topological order.
// We reversely traverse the list to invoke the lazy functions.
// This acts like a backprop of valid scale axis messages.
for (auto it = flist_.rbegin(); it != flist_.rend(); ++it) {
(*it)();
}
// return the created message;
// message_ is moved out, leaving this object without its map.
return std::move(message_);
}
private:
// The invoke list
std::vector<std::function<void()> > flist_;
// The message on each node.
std::unordered_map<const Node*, AxesSet> message_;
// Update the message stored at node.
// The first message for a node is stored as is; later messages are
// combined with the stored one through Intersect.
void Update(const Expr& node, const AxesSet& axes) {
// We run intersection of messages:
//
// %y = multiply(%x, %scale)
// %z1 = conv2d(%y, %w)
// %z2 = exp(%y)
//
// Consider the above code example,
// because %z2 will propagate null to %y,
// the AxesSet on %y is also null,
// and the forward folding won't be triggered.
const Node* key = node.get();
if (message_.count(key)) {
message_[key] = Intersect(message_[key], axes);
} else {
message_[key] = axes;
}
}
// Visitor pattern override.
// Let bindings are rejected outright.
void VisitExpr_(const LetNode* call) {
LOG(FATAL) << "FoldScaleAxis only accept dataflow-form";
}
// The function body receives a null message.
void VisitExpr_(const FunctionNode* op) {
ExprVisitor::VisitExpr_(op);
auto flazy = [this, op] {
this->Update(op->body, NullValue<AxesSet>());
};
flist_.push_back(flazy);
}
// Calls forward their own message to the arguments through the
// operator's "FScaleAxisForwardPrep" attribute when one is registered.
void VisitExpr_(const CallNode* call) {
ExprVisitor::VisitExpr_(call);
// function to be lazily invoked
auto flazy = [this, call]() {
static const auto& fprep =
Op::GetAttr<FForwardPrep>("FScaleAxisForwardPrep");
// find the message sent to this node.
auto it = message_.find(call);
AxesSet out_axes;
if (it != message_.end()) {
out_axes = it->second;
} else {
// No message recorded: treat it as a null message.
out_axes = NullValue<AxesSet>();
}
// pass the message back to all the children it references.
auto f = GetFunc(fprep, call->op);
if (f != nullptr) {
Array<AxesSet> in_axes = f(GetRef<Call>(call), out_axes);
// The prep function must return one entry per argument.
CHECK_EQ(in_axes.size(), call->args.size());
for (size_t i = 0; i < call->args.size(); ++i) {
this->Update(call->args[i], in_axes[i]);
}
} else {
// No prep function registered: every argument gets a null message.
for (size_t i = 0; i < call->args.size(); ++i) {
this->Update(call->args[i], NullValue<AxesSet>());
}
}
};
flist_.push_back(flazy);
}
void VisitExpr_(const TupleNode* op) {
ExprVisitor::VisitExpr_(op);
// do not support pass scale through tuple for now.
auto flazy = [this, op]() {
for (const Expr& field : op->fields) {
this->Update(field, NullValue<AxesSet>());
}
};
flist_.push_back(flazy);
}
void VisitExpr_(const IfNode* op) {
ExprVisitor::VisitExpr_(op);
// Do not pass the signal through the condition or either branch:
// assigning NullValue<AxesSet> means the fuse signal cannot pass
// through into these subexpressions.
auto flazy = [this, op]() {
this->Update(op->cond, NullValue<AxesSet>());
this->Update(op->true_branch, NullValue<AxesSet>());
this->Update(op->false_branch, NullValue<AxesSet>());
};
flist_.push_back(flazy);
}
};
/*!
 * \brief Transform phase of the forward scale-axis folding.
 *
 * Runs FScaleAxisForwardPrep to learn which axes may be scaled on each
 * node, then rewrites calls through the per-operator
 * "FScaleAxisForwardTransform" attribute. A pending scale produced by a
 * call is remembered in scale_memo_ until a consumer picks it up via
 * GetSTuple.
 */
class FScaleAxisForwardTransform : private ExprMutator {
 public:
  /*!
   * \brief Transform expression.
   * \param expr The expression to rewrite.
   * \return The rewritten expression.
   */
  Expr Transform(Expr expr) {
    expected_scale_axes_ =
        FScaleAxisForwardPrep().Prepare(expr);
    return this->Mutate(expr);
  }

 private:
  // Valid axes on each node.
  std::unordered_map<const Node*, AxesSet> expected_scale_axes_;
  // Scale tuples recorded for calls whose result carries a pending scale.
  std::unordered_map<const Node*, STuple> scale_memo_;
  // If user simply call mutate,
  // then only Expr is returned and we cannot
  // accept outstanding scales.
  Expr VisitExpr(const Expr& expr) final {
    Expr res = ExprMutator::VisitExpr(expr);
    CHECK(!scale_memo_.count(expr.get()))
        << "Outstanding scale";
    return res;
  }
  // Mutate expr and return it together with its pending scale, if any.
  // Without a recorded scale the tuple only carries the value.
  STuple GetSTuple(const Expr& expr) {
    Expr res = ExprMutator::VisitExpr(expr);
    auto it = scale_memo_.find(expr.get());
    if (it != scale_memo_.end()) {
      CHECK(it->second->value.same_as(res));
      return it->second;
    } else {
      auto node = make_node<STupleNode>();
      node->value = res;
      return STuple(node);
    }
  }
  // Rewrite a call: first try the operator specific transform, and fall
  // back to rebuilding the call from the mutated arguments.
  Expr VisitExpr_(const CallNode* call_node) final {
    static const auto& ftransform =
        Op::GetAttr<FForwardTransform>("FScaleAxisForwardTransform");
    auto new_op = this->Mutate(call_node->op);
    bool has_scale = false;
    bool unchanged = call_node->op.same_as(new_op);

    Array<STuple> call_sargs;
    Array<Expr> call_args;
    for (auto arg : call_node->args) {
      STuple new_sarg = this->GetSTuple(arg);
      unchanged &= new_sarg->value.same_as(arg);
      if (new_sarg->axes.defined()) has_scale = true;
      call_sargs.push_back(new_sarg);
      call_args.push_back(new_sarg->value);
    }

    // get expected scale axes.
    // Start from the explicit null value, exactly as the prepare phase
    // does for a node without a recorded message, so that a missing
    // entry always reads as "no scaling request".
    AxesSet expected_out_axes = NullValue<AxesSet>();
    auto axis_it = expected_scale_axes_.find(call_node);
    if (axis_it != expected_scale_axes_.end()) {
      expected_out_axes = axis_it->second;
    }
    // propagation function
    auto f = GetFunc(ftransform, call_node->op);
    if (f != nullptr) {
      STuple sret = f(GetRef<Call>(call_node), expected_out_axes, call_sargs);
      if (sret.defined()) {
        // Only remember the tuple when a scale is still pending on it.
        if (sret->axes.defined()) {
          scale_memo_[call_node] = sret;
        }
        return sret->value;
      }
    }
    // normal path
    CHECK(!has_scale) << "Outstanding scale, on op=" << call_node->op;
    if (unchanged) {
      return GetRef<Expr>(call_node);
    } else {
      return CallNode::make(
          new_op, call_args, call_node->attrs, call_node->type_args);
    }
  }
};
//----------------------------------------------
// Per operator defs for FScaleAxisForward
//----------------------------------------------
// Intermediate operators
/*!
 * \brief Prepare function shared by nn.relu and nn.leaky_relu.
 * \param call The call node (unused).
 * \param out The axes message of the output.
 * \return A one-element array holding the output message unchanged.
 */
Array<AxesSet> ReluForwardPrep(const Call& call, AxesSet out) {
  return Array<AxesSet>{out};
}
/*!
 * \brief Transform function shared by nn.relu and nn.leaky_relu.
 *
 * When the input carries a pending scale, the op is re-applied to the
 * unscaled value and the same scale/axes are forwarded to the result.
 *
 * NOTE(review): moving a scale across relu-like ops is only exact for
 * positive scale values; nothing here checks that - confirm upstream.
 *
 * \param ref_call The original call.
 * \param expected_axes The axes allowed on the output (unused).
 * \param sargs The transformed arguments.
 * \return Undefined STuple when the input has no scale, else the new tuple.
 */
STuple ReluForwardTransform(const Call& ref_call,
                            const AxesSet& expected_axes,
                            const Array<STuple>& sargs) {
  const STuple input = sargs[0];
  if (!input->axes.defined()) {
    return STuple();
  }
  // Rebuild the activation on the unscaled input value.
  Expr activated = CallNode::make(
      ref_call->op, {input->value}, ref_call->attrs, {});
  auto rnode = make_node<STupleNode>();
  rnode->value = activated;
  rnode->scale = input->scale;
  rnode->axes = input->axes;
  return STuple(rnode);
}
RELAY_REGISTER_OP("nn.relu")
.set_attr<FForwardPrep>("FScaleAxisForwardPrep", ReluForwardPrep);
RELAY_REGISTER_OP("nn.relu")
.set_attr<FForwardTransform>("FScaleAxisForwardTransform", ReluForwardTransform);
RELAY_REGISTER_OP("nn.leaky_relu")
.set_attr<FForwardPrep>("FScaleAxisForwardPrep", ReluForwardPrep);
RELAY_REGISTER_OP("nn.leaky_relu")
.set_attr<FForwardTransform>("FScaleAxisForwardTransform", ReluForwardTransform);
// AddSub
/*!
 * \brief Prepare function for add/subtract.
 *
 * The output message is forwarded to whichever operand the other one
 * broadcasts into along out_axes; the remaining operand gets null.
 *
 * \param call The call node.
 * \param out_axes The axes message of the output.
 * \return The messages for the two arguments.
 */
Array<AxesSet> AddSubForwardPrep(const Call& call, AxesSet out_axes) {
  const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
  const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
  const AxesSet none = NullValue<AxesSet>();
  // lhs is the data, rhs is the bias.
  if (MatchBroadcastToLeftAxes(tlhs, trhs, out_axes)) {
    return {out_axes, none};
  }
  // rhs is the data, lhs is the bias.
  if (MatchBroadcastToLeftAxes(trhs, tlhs, out_axes)) {
    return {none, out_axes};
  }
  return {none, none};
}
/*!
 * \brief Transform function for add/subtract.
 *
 * Exactly one operand may carry a pending scale. The other operand is
 * divided by that scale (expanded to broadcast along the scaled axes)
 * so the scale can keep travelling with the result.
 *
 * \param ref_call The original call.
 * \param expected_out_axes The axes allowed on the output (unused).
 * \param sargs The transformed arguments.
 * \return Undefined STuple when no operand is scaled, else the new tuple.
 */
STuple AddSubForwardTransform(const Call& ref_call,
                              const AxesSet& expected_out_axes,
                              const Array<STuple>& sargs) {
  if (!sargs[0]->axes.defined() && !sargs[1]->axes.defined()) {
    return STuple();
  }
  const auto* tlhs = ref_call->args[0]->type_as<TensorTypeNode>();
  const auto* trhs = ref_call->args[1]->type_as<TensorTypeNode>();

  auto rnode = make_node<STupleNode>();
  if (sargs[0]->axes.defined()) {
    // Scale lives on the lhs; the rhs must be unscaled.
    CHECK(!sargs[1]->axes.defined());
    CHECK(MatchBroadcastToLeftAxes(tlhs, trhs, sargs[0]->axes));
    Expr scale = ExpandBiasToMatchAxis(
        sargs[0]->scale, tlhs->shape.size(), sargs[0]->axes);
    Expr rhs = Divide(sargs[1]->value, scale);
    rnode->value = CallNode::make(ref_call->op, {sargs[0]->value, rhs},
                                  ref_call->attrs, ref_call->type_args);
    rnode->scale = sargs[0]->scale;
    rnode->axes = sargs[0]->axes;
  } else {
    // Scale lives on the rhs. The lhs is known to be unscaled in this
    // branch, so only the rhs needs to be validated.
    CHECK(sargs[1]->axes.defined());
    CHECK(MatchBroadcastToLeftAxes(trhs, tlhs, sargs[1]->axes));
    Expr scale = ExpandBiasToMatchAxis(
        sargs[1]->scale, trhs->shape.size(), sargs[1]->axes);
    Expr lhs = Divide(sargs[0]->value, scale);
    rnode->value = CallNode::make(ref_call->op, {lhs, sargs[1]->value},
                                  ref_call->attrs, ref_call->type_args);
    rnode->scale = sargs[1]->scale;
    rnode->axes = sargs[1]->axes;
  }
  return STuple(rnode);
}
RELAY_REGISTER_OP("add")
.set_attr<FForwardPrep>("FScaleAxisForwardPrep", AddSubForwardPrep);
RELAY_REGISTER_OP("add")
.set_attr<FForwardTransform>("FScaleAxisForwardTransform", AddSubForwardTransform);
RELAY_REGISTER_OP("subtract")
.set_attr<FForwardPrep>("FScaleAxisForwardPrep", AddSubForwardPrep);
RELAY_REGISTER_OP("subtract")
.set_attr<FForwardTransform>("FScaleAxisForwardTransform", AddSubForwardTransform);
// Producer operators
// Multiply produces the scale-axis pair.
/*!
 * \brief Transform function for multiply, the producer of scales.
 *
 * When the consumers demand a scale on expected_out_axes and one operand
 * broadcasts into the other along exactly those axes, the multiply is
 * removed: the data operand becomes the value and the (squeezed) other
 * operand becomes the pending scale.
 *
 * \param ref_call The original call.
 * \param expected_out_axes The axes allowed on the output.
 * \param sargs The transformed arguments.
 * \return Undefined STuple when nothing can be folded (the caller then
 *         takes its normal path), else the value/scale/axes tuple.
 */
STuple MultiplyForwardTransform(const Call& ref_call,
                                const AxesSet& expected_out_axes,
                                const Array<STuple>& sargs) {
  if (!expected_out_axes.defined()) return STuple();
  // TODO(tvm-team) allow same axes accumulation
  // not as important because it is less common in nn.
  CHECK(!sargs[0]->axes.defined());
  CHECK(!sargs[1]->axes.defined());
  const auto* tlhs = ref_call->args[0]->type_as<TensorTypeNode>();
  const auto* trhs = ref_call->args[1]->type_as<TensorTypeNode>();

  Expr lhs = sargs[0]->value;
  Expr rhs = sargs[1]->value;
  auto rnode = make_node<STupleNode>();
  if (MatchBroadcastToLeftAxes(tlhs, trhs, expected_out_axes, &rhs)) {
    rnode->value = lhs;
    rnode->scale = rhs;
  } else if (MatchBroadcastToLeftAxes(trhs, tlhs, expected_out_axes, &lhs)) {
    rnode->value = rhs;
    rnode->scale = lhs;
  } else {
    // Neither operand matches the requested axes. Returning a tuple
    // with an unset value would hand a null Expr to the caller, so
    // signal "no transform" instead.
    return STuple();
  }
  rnode->axes = expected_out_axes;
  return STuple(rnode);
}
RELAY_REGISTER_OP("multiply")
.set_attr<FForwardTransform>("FScaleAxisForwardTransform", MultiplyForwardTransform);
// Consumer operators
// Conv2D send out requirement of axis folding.
/*!
 * \brief Prepare function for nn.conv2d.
 *
 * Requests a scale on the 'C' axis of the data when the layouts are
 * plain (no 'c' in the data layout, no 'i' in the weight layout) and
 * the convolution is either ungrouped or depthwise. The weight never
 * receives a request.
 *
 * \param call The conv2d call.
 * \param out The axes message of the output (unused).
 * \return The messages for data and weight.
 */
Array<AxesSet> Conv2DForwardPrep(const Call& call, AxesSet out) {
  // TODO(tvm-team) support general data layout
  // by transforming weight
  const auto* param = call->attrs.as<Conv2DAttrs>();
  CHECK(param != nullptr);
  Layout data_layout(param->data_layout);
  Layout weight_layout(param->weight_layout);
  int c_big_axis = data_layout.indexof('C');
  int c_small_axis = data_layout.indexof('c');
  const auto* tdata = call->args[0]->type_as<TensorTypeNode>();
  CHECK(tdata) << "require checked type";
  CHECK_GE(c_big_axis, 0);

  // For now, we only support simple pattern (no folded weight/data)
  // More general layout can be supported under the current framework.
  // By using a unified layout transformation.
  // We only need to change the Prep and Mutate function.
  //
  // only handle depthwise or full conv2d.
  // TODO(tvm-team) handle grouped conv by reshape + bcast
  const bool is_depthwise_conv2d =
      is_const_int(tdata->shape[c_big_axis], param->groups);
  const bool plain_layout =
      weight_layout.indexof('i') < 0 && c_small_axis < 0;
  const bool supported_grouping =
      param->groups == 1 || is_depthwise_conv2d;

  AxesSet data_axes = NullValue<AxesSet>();
  if (plain_layout && supported_grouping) {
    data_axes = {c_big_axis};
  }
  return {data_axes, NullValue<AxesSet>()};
}
// Conv2D consumes the scale axis during transformation.
// The pending scale on the data is multiplied into the weight along the
// weight's 'I' axis, and the returned tuple carries no scale.
STuple Conv2DForwardTransform(const Call& ref_call,
const AxesSet& expected_axes,
const Array<STuple>& sargs) {
// if data do not have scale, normal transform path.
STuple sdata = sargs[0];
if (!sdata->scale.defined()) return STuple();
CHECK(sdata->axes.defined());
const auto* param = ref_call->attrs.as<Conv2DAttrs>();
CHECK(param != nullptr);
Layout data_layout(param->data_layout);
Layout weight_layout(param->weight_layout);
int c_big_axis = data_layout.indexof('C');
CHECK_GE(c_big_axis, 0);
// For now, we only support simple pattern (no folded weight/data)
// TODO(tvm-team) support general data layout
CHECK_EQ(weight_layout.indexof('i'), -1);
// The scale must sit on exactly the data's 'C' axis.
CHECK(sdata->axes.size() == 1 &&
c_big_axis == sdata->axes[0]->value);
// NOTE(review): big_ic_axis is not checked for >= 0 - presumably every
// conv2d weight layout contains 'I'; confirm.
int big_ic_axis = weight_layout.indexof('I');
const auto* tdata = ref_call->args[0]->type_as<TensorTypeNode>();
// Check it must be depthwise or full conv2d.
bool is_depthwise_conv2d =
is_const_int(tdata->shape[c_big_axis], param->groups);
CHECK(param->groups == 1 || is_depthwise_conv2d);
// match the ic_axis
Expr scale = ExpandBiasToMatchAxis(
sdata->scale, weight_layout.ndim(), {big_ic_axis});
Expr weight = Multiply(sargs[1]->value, scale);
// return transformed conv2d
// axes/scale are left at their null defaults: the scale is consumed.
auto rnode = make_node<STupleNode>();
rnode->value = CallNode::make(
ref_call->op, {sdata->value, weight}, ref_call->attrs, ref_call->type_args);
return STuple(rnode);
}
RELAY_REGISTER_OP("nn.conv2d")
.set_attr<FForwardPrep>("FScaleAxisForwardPrep", Conv2DForwardPrep);
RELAY_REGISTER_OP("nn.conv2d")
.set_attr<FForwardTransform>("FScaleAxisForwardTransform", Conv2DForwardTransform);
/*!
 * \brief Entry point of the forward scale-axis folding.
 * \param data The expression to transform.
 * \return The result of FScaleAxisForwardTransform::Transform on data.
 */
Expr ForwardFoldScaleAxis(Expr data) {
  FScaleAxisForwardTransform transformer;
  return transformer.Transform(data);
}
// Expose the FoldScaleAxisForward
TVM_REGISTER_API("relay._ir_pass.forward_fold_scale_axis")
.set_body_typed<Expr(Expr)>(ForwardFoldScaleAxis);
} // namespace fold_scale_axis
} // namespace relay
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors.
*
* \file tvm/relay/pass/pattern_util.h
* \brief Header of internal operator functions
* These can be used for writing passes.
*/
#ifndef TVM_RELAY_PASS_PATTERN_UTIL_H_
#define TVM_RELAY_PASS_PATTERN_UTIL_H_
#include <tvm/relay/op.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/attrs/transform.h>
namespace tvm {
namespace relay {
/*!
* \brief Try to match lhs and rhs via broadcasting rule, such that:
*
* rhs matches the dimension of lhs specified by lhs_axes
* rhs's value equals 1 on rest of dimensions.
*
* \param tlhs The type of left operand (data)
* \param trhs The type right operand (bias)
* \param lhs_axes The axes on lhs to match.
* \param rhs_value A squeezed version of rhs which only contains matched dimension.
* \return Whether match is successful.
*/
inline bool MatchBroadcastToLeftAxes(const TensorTypeNode* tlhs,
const TensorTypeNode* trhs,
const Array<Integer>& lhs_axes,
Expr* rhs_value = nullptr) {
if (tlhs->shape.size() < trhs->shape.size()) return false;
AttrsEqual equal;
size_t base = tlhs->shape.size() - trhs->shape.size();
size_t j = 0;
NodePtr<SqueezeAttrs> squeeze_attrs;
if (rhs_value != nullptr) {
squeeze_attrs = make_node<SqueezeAttrs>();
}
for (size_t i = 0; i < tlhs->shape.size(); ++i) {
if (j < lhs_axes.size() && i == static_cast<size_t>(lhs_axes[j]->value)) {
if (i < base || !equal(tlhs->shape[i], trhs->shape[i - base])) {
return false;
}
++j;
} else if (i >= base) {
if (!is_const_int(trhs->shape[i - base], 1)) {
return false;
}
if (rhs_value != nullptr) {
squeeze_attrs->axis.push_back(static_cast<int>(i - base));
}
}
}
if (rhs_value != nullptr && squeeze_attrs->axis.size() != 0) {
static const Op& squeeze_op = Op::Get("squeeze");
*rhs_value = CallNode::make(squeeze_op, {rhs_value[0]}, Attrs(squeeze_attrs), {});
}
return true;
}
/*!
* \brief Expand 1D Tensor to match axis.
*
* The result bias can be used to add or multiply to
* the target Tensor on the specified axis via broadcasting rule.
*
* \param bias The bias.
* \param target_ndim target dimension.
* \param axes The axis on the output we want to match on.
* \return bias wrapped in zero or more expand_dims calls.
*/
inline Expr ExpandBiasToMatchAxis(Expr bias,
int target_ndim,
const Array<Integer>& axes) {
static const Op& expand_dims = Op::Get("expand_dims");
// Walk the axes from last to first.
for (size_t i = axes.size(); i != 0; --i) {
if (i == axes.size()) {
// Last axis: pad the dimensions that follow it in the target.
int64_t num_pad_axis = target_ndim - axes[i - 1]->value - 1;
if (num_pad_axis > 0) {
auto attrs = make_node<ExpandDimsAttrs>();
attrs->axis = i;
attrs->num_newaxis = static_cast<int>(num_pad_axis);
bias = CallNode::make(expand_dims, {bias}, Attrs(attrs), {});
}
} else {
// Earlier axes: insert new dimensions between consecutive axes.
// NOTE(review): the count is the plain difference of the two axis
// values rather than the number of dimensions strictly between
// them - confirm this is intended for multi-axis input; the calls
// in fold_scale_axis.cc appear to pass a single axis only.
int64_t diff = axes[i]->value - axes[i - 1]->value;
CHECK_GE(diff, 0L);
if (diff > 0) {
auto attrs = make_node<ExpandDimsAttrs>();
attrs->axis = i;
attrs->num_newaxis = static_cast<int>(diff);
bias = CallNode::make(expand_dims, {bias}, Attrs(attrs), {});
}
}
}
return bias;
}
/*!
 * \brief Build a call to the "multiply" operator.
 * \param lhs The left operand.
 * \param rhs The right operand.
 * \return The call expression, with empty attrs and type args.
 */
inline Expr Multiply(Expr lhs, Expr rhs) {
  static const Op& multiply_op = Op::Get("multiply");
  return CallNode::make(multiply_op, {lhs, rhs}, Attrs(), {});
}
/*!
 * \brief Build a call to the "divide" operator.
 * \param lhs The dividend.
 * \param rhs The divisor.
 * \return The call expression, with empty attrs and type args.
 */
inline Expr Divide(Expr lhs, Expr rhs) {
  static const Op& divide_op = Op::Get("divide");
  return CallNode::make(divide_op, {lhs, rhs}, Attrs(), {});
}
/*!
 * \brief Build a call to the "reshape_like" operator.
 * \param lhs The first argument of the call.
 * \param rhs The second argument of the call.
 * \return The call expression, with empty attrs and type args.
 */
inline Expr ReshapeLike(Expr lhs, Expr rhs) {
  static const Op& reshape_like_op = Op::Get("reshape_like");
  return CallNode::make(reshape_like_op, {lhs, rhs}, Attrs(), {});
}
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_PATTERN_UTIL_H_
......@@ -406,28 +406,57 @@ class TypeInferencer::Resolver : public ExprMutator {
CHECK(checked_type.as<IncompleteTypeNode>() == nullptr)
<< "Cannot resolve type of " << GetRef<Expr>(op)
<< " at " << op->span;
Expr new_e = ExprMutator::VisitExpr_(op);
if (!checked_type.same_as(new_e->checked_type_)) {
// new_call and new_var's code is only going to be valid for VarNode/CallNode.
// Compiler optimization will likely fold these away for other nodes.
CallNode* new_call =(
std::is_base_of<CallNode, T>::value ?
static_cast<CallNode*>(new_e.node_.get()) : nullptr);
VarNode* new_var =(
std::is_base_of<VarNode, T>::value ?
static_cast<VarNode*>(new_e.node_.get()) : nullptr);
// check if we need update the new_e
bool need_update_type = !checked_type.same_as(new_e->checked_type_);
bool need_update_call = (
std::is_base_of<CallNode, T>::value &&
it->second.type_args.defined() &&
!it->second.type_args.same_as(new_call->type_args));
bool need_update_var = (
std::is_base_of<VarNode, T>::value &&
update_missing_type_annotation_ &&
!new_var->type_annotation.defined());
if (!need_update_type && !need_update_var && !need_update_call) return new_e;
if (!new_e.node_.unique()) {
// Copy on write optimization
// If new_e is an old expression,
// we make a copy mutating an existing reference.
if (!new_e.node_.unique()) {
new_e = Expr(make_node<T>(*new_e.as<T>()));
}
new_e->checked_type_ = checked_type;
new_e = Expr(make_node<T>(*new_e.as<T>()));
new_call = (
std::is_base_of<CallNode, T>::value ?
static_cast<CallNode*>(new_e.node_.get()) : nullptr);
new_var = (
std::is_base_of<VarNode, T>::value ?
static_cast<VarNode*>(new_e.node_.get()) : nullptr);
}
if (it->second.type_args.defined()) {
Call call = Downcast<Call>(new_e);
const CallNode* const_call_ref = call.operator->();
CallNode* call_ref = const_cast<CallNode*>(const_call_ref);
call_ref->type_args = it->second.type_args;
// attach the information.
if (need_update_type) {
new_e->checked_type_ = checked_type;
}
for (size_t i = 0; i < call->type_args.size(); i++) {
call_ref->type_args.Set(i, solver_->Resolve(call->type_args[i]));
if (need_update_call) {
new_call->type_args = it->second.type_args;
for (size_t i = 0; i < new_call->type_args.size(); i++) {
new_call->type_args.Set(i, solver_->Resolve(new_call->type_args[i]));
}
}
if (need_update_var) {
new_var->type_annotation = checked_type;
}
return new_e;
}
......@@ -438,6 +467,9 @@ class TypeInferencer::Resolver : public ExprMutator {
private:
const std::unordered_map<Expr, ResolvedTypeInfo, NodeHash, NodeEqual>& tmap_;
TypeSolver* solver_;
// Whether to attach the checked type as the type_annotation
// when the original type annotation is missing.
bool update_missing_type_annotation_{true};
};
......
......@@ -55,8 +55,8 @@ def test_transpose_infer_type():
def test_squeeze_infer_type():
n, t, d = 1, 4, 1
x = relay.var("x", relay.TensorType((n, t, d), "float32"))
y = relay.squeeze(x, axes=(2,))
assert "axes=" in y.astext()
y = relay.squeeze(x, axis=(2,))
assert "axis=" in y.astext()
yy = relay.ir_pass.infer_type(y)
assert yy.checked_type == relay.TensorType(
(1, 4), "float32")
......@@ -64,7 +64,7 @@ def test_squeeze_infer_type():
n, t, d = 1, 4, 1
x = relay.var("x", relay.TensorType((n, t, d), "float32"))
y = relay.squeeze(x)
assert "axes=" not in y.astext()
assert "axis=" not in y.astext()
yy = relay.ir_pass.infer_type(y)
assert yy.checked_type == relay.TensorType(
(4,), "float32")
......@@ -74,7 +74,7 @@ def test_squeeze_infer_type():
def test_squeeze_bad_axes_infer_type():
n, t, d = 1, 4, 1
x = relay.var("x", relay.TensorType((n, t, d), "float32"))
y = relay.squeeze(x, axes=(1,))
y = relay.squeeze(x, axis=(1,))
yy = relay.ir_pass.infer_type(y)
......
from tvm import relay
def test_fold_fwd_simple():
    """Simple case: one scale -> relu -> bias-add -> conv2d chain."""

    def before(x, conv_weight, in_bias, in_scale, channels):
        # Graph handed to the pass: conv2d(relu(x * scale) + bias, weight).
        params = [x, conv_weight, in_bias, in_scale]
        scale_nd = relay.expand_dims(in_scale, axis=1, num_newaxis=2)
        bias_nd = relay.expand_dims(in_bias, axis=1, num_newaxis=2)
        data = relay.multiply(x, scale_nd)
        data = relay.nn.relu(data)
        data = relay.add(data, bias_nd)
        out = relay.nn.conv2d(
            data,
            conv_weight,
            channels=channels,
            kernel_size=(3, 3),
            padding=(1, 1))
        return relay.Function(params, out)

    def expected(x, conv_weight, in_bias, in_scale, channels):
        # Keep the parameter order fixed so the alpha-equal check can pass.
        params = [x, conv_weight, in_bias, in_scale]
        scale_nd = relay.expand_dims(in_scale, axis=1, num_newaxis=2)
        bias_nd = relay.expand_dims(in_bias, axis=1, num_newaxis=2)
        squeezed_scale = relay.squeeze(scale_nd, axis=[1, 2])
        data = relay.nn.relu(x)
        # The bias is divided by the scale; the weight is multiplied by it.
        bias_nd = relay.divide(
            bias_nd,
            relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2))
        data = relay.add(data, bias_nd)
        scaled_weight = relay.multiply(
            conv_weight,
            relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2))
        out = relay.nn.conv2d(
            data,
            scaled_weight,
            channels=channels,
            kernel_size=(3, 3),
            padding=(1, 1))
        return relay.Function(params, out)

    def check(shape, channels):
        data_var = relay.var("x", shape=shape)
        in_channels = shape[1]
        untyped_weight = relay.var("weight")
        bias_var = relay.var("in_bias", shape=(in_channels,))
        scale_var = relay.var("in_scale", shape=(in_channels,))
        func = before(data_var, untyped_weight, bias_var, scale_var, channels)
        func = relay.ir_pass.infer_type(func)
        # Rebuild the weight var with the inferred type for the reference graph.
        param_types = {}
        for param in func.params:
            param_types[param.name_hint] = param.checked_type
        typed_weight = relay.var("weight", param_types["weight"])
        folded = relay.ir_pass.forward_fold_scale_axis(func)
        reference = expected(data_var, typed_weight, bias_var, scale_var, channels)
        assert relay.ir_pass.alpha_equal(folded, reference)

    check((2, 4, 10, 10), 2)
def test_fold_fwd_dual_path():
    """Scale axis consumed by two conv2d consumers."""

    def before(x, conv_weight, in_bias, in_scale, channels):
        params = [x, conv_weight, in_bias, in_scale]
        # Both convolutions share the same attributes.
        conv_attrs = dict(
            channels=channels,
            kernel_size=(3, 3),
            data_layout="NHWC",
            weight_layout="HWOI",
            groups=channels,
            padding=(1, 1))
        data = relay.multiply(in_scale, x)
        data = relay.nn.relu(data)
        data = relay.subtract(data, in_bias)
        branch_a = relay.nn.conv2d(data, conv_weight, **conv_attrs)
        branch_b = relay.nn.conv2d(data, conv_weight, **conv_attrs)
        total = relay.add(branch_a, branch_b)
        return relay.Function(params, total)

    def expected(x, conv_weight, in_bias, in_scale, channels):
        params = [x, conv_weight, in_bias, in_scale]
        conv_attrs = dict(
            channels=channels,
            kernel_size=(3, 3),
            data_layout="NHWC",
            weight_layout="HWOI",
            groups=channels,
            padding=(1, 1))
        data = relay.nn.relu(x)
        scaled_bias = relay.divide(in_bias, in_scale)
        data = relay.subtract(data, scaled_bias)
        # Each branch gets its own multiply node on the weight.
        branch_a = relay.nn.conv2d(
            data, relay.multiply(conv_weight, in_scale), **conv_attrs)
        branch_b = relay.nn.conv2d(
            data, relay.multiply(conv_weight, in_scale), **conv_attrs)
        total = relay.add(branch_a, branch_b)
        return relay.Function(params, total)

    def check(shape, channels):
        data_var = relay.var("x", shape=shape)
        in_channels = shape[-1]
        # test depthwise
        assert in_channels == channels
        untyped_weight = relay.var("weight")
        bias_var = relay.var("in_bias", shape=(in_channels,))
        scale_var = relay.var("in_scale", shape=(in_channels,))
        func = before(data_var, untyped_weight, bias_var, scale_var, channels)
        func = relay.ir_pass.infer_type(func)
        folded = relay.ir_pass.forward_fold_scale_axis(func)
        # Rebuild the weight var with the inferred type for the reference graph.
        param_types = {}
        for param in func.params:
            param_types[param.name_hint] = param.checked_type
        typed_weight = relay.var("weight", param_types["weight"])
        reference = expected(data_var, typed_weight, bias_var, scale_var, channels)
        assert relay.ir_pass.alpha_equal(folded, reference)

    check((2, 4, 10, 3), 3)
def test_fold_fwd_fail():
    """Case where the scale cannot be folded; the pass must leave the graph as is."""

    def before(x, conv_weight, in_bias, in_scale, channels):
        scaled = relay.multiply(x, in_scale)
        activated = relay.nn.leaky_relu(scaled, alpha=0.1)
        conv_out = relay.nn.conv2d(
            activated,
            conv_weight,
            channels=channels,
            kernel_size=(3, 3),
            data_layout="NHWC",
            padding=(1, 1))
        # The scaled input is also consumed directly by the final add.
        total = relay.add(conv_out, scaled)
        return relay.Function(relay.ir_pass.free_vars(total), total)

    def check(shape, channels):
        data_var = relay.var("x", shape=shape)
        in_channels = shape[-1]
        # test depthwise
        assert in_channels == channels
        weight_var = relay.var("weight")
        bias_var = relay.var("in_bias", shape=(in_channels,))
        scale_var = relay.var("in_scale", shape=(in_channels,))
        func = before(data_var, weight_var, bias_var, scale_var, channels)
        func = relay.ir_pass.infer_type(func)
        folded = relay.ir_pass.forward_fold_scale_axis(func)
        # Folding is expected to be a no-op here.
        assert relay.ir_pass.alpha_equal(func, folded)

    check((2, 11, 10, 4), 4)
if __name__ == "__main__":
    # Run every forward fold-scale-axis test, in order, when executed as a script.
    for case in (test_fold_fwd_simple,
                 test_fold_fwd_dual_path,
                 test_fold_fwd_fail):
        case()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment