Commit 05ea6018 by Tianqi Chen Committed by ziheng

[RELAY][PASS] FoldScaleAxis Forward (#2020)

* [RELAY][PASS] FoldScaleAxis Forward

* Introduce helper function type_as

* Update per review comment

* Fix according to comments
parent 26e3aa19
...@@ -94,15 +94,16 @@ struct InitOpAttrs : public tvm::AttrsNode<InitOpAttrs> { ...@@ -94,15 +94,16 @@ struct InitOpAttrs : public tvm::AttrsNode<InitOpAttrs> {
/*! \brief Attributes used in squeeze operators */ /*! \brief Attributes used in squeeze operators */
struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> { struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
Array<IndexExpr> axes; // use axis to make the name numpy compatible.
Array<Integer> axis;
TVM_DECLARE_ATTRS(SqueezeAttrs, "relay.attrs.SqueezeAttrs") { TVM_DECLARE_ATTRS(SqueezeAttrs, "relay.attrs.SqueezeAttrs") {
TVM_ATTR_FIELD(axes) TVM_ATTR_FIELD(axis)
.describe("The axes to squeeze in the input tensor." .describe("The axis to squeeze in the input tensor."
"If `axes = []`, all axis of dimension 1 get squeezed;" "If `axis = None`, all axis of dimension 1 get squeezed;"
"Else, the dimension in axes get squeezed." "Else, the dimension in axes get squeezed."
"It is an error if an axes does not has dimension 1.") "It is an error if an axis does not has dimension 1.")
.set_default(Array<IndexExpr>({})); .set_default(NullValue<Array<Integer> >());
} }
}; // struct SqueezeAttrs }; // struct SqueezeAttrs
......
...@@ -40,6 +40,18 @@ class ExprNode : public RelayNode { ...@@ -40,6 +40,18 @@ class ExprNode : public RelayNode {
"field for this node"; "field for this node";
return this->checked_type_; return this->checked_type_;
} }
/*!
* \brief Check if the inferred(checked) type of the Expr
* is backed by a TTypeNode and return it.
*
* \note This function will thrown an error if the node type
* of this Expr is not TTypeNode.
*
* \return The corresponding TTypeNode pointer.
* \tparam The specific TypeNode we look for.
*/
template<typename TTypeNode>
inline const TTypeNode* type_as() const;
static constexpr const char* _type_key = "relay.Expr"; static constexpr const char* _type_key = "relay.Expr";
TVM_DECLARE_BASE_NODE_INFO(ExprNode, RelayNode); TVM_DECLARE_BASE_NODE_INFO(ExprNode, RelayNode);
...@@ -391,6 +403,20 @@ class TupleGetItemNode : public ExprNode { ...@@ -391,6 +403,20 @@ class TupleGetItemNode : public ExprNode {
RELAY_DEFINE_NODE_REF(TupleGetItem, TupleGetItemNode, Expr); RELAY_DEFINE_NODE_REF(TupleGetItem, TupleGetItemNode, Expr);
// implementataions
template<typename TTypeNode>
inline const TTypeNode* ExprNode::type_as() const {
static_assert(std::is_base_of<TypeNode, TTypeNode>::value,
"TType must be a special case of type");
CHECK(checked_type_.defined())
<< "Type inference for this Expr has not completed";
const TTypeNode* node = checked_type_.as<TTypeNode>();
CHECK(node != nullptr)
<< "Expected type to be " << TTypeNode::_type_key
<< ", but get " << checked_type_->type_key();
return node;
}
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_EXPR_H_ #endif // TVM_RELAY_EXPR_H_
...@@ -150,7 +150,14 @@ class ExprVisitor ...@@ -150,7 +150,14 @@ class ExprVisitor
class ExprMutator class ExprMutator
: public ::tvm::relay::ExprFunctor<Expr(const Expr&)> { : public ::tvm::relay::ExprFunctor<Expr(const Expr&)> {
public: public:
Expr Mutate(const Expr& expr); /*!
* \brief Mutate is alias for VisitExpr
* \return expr.
*/
Expr Mutate(const Expr& expr) {
return this->VisitExpr(expr);
}
Expr VisitExpr(const Expr& expr) override;
Expr VisitExpr_(const VarNode* op) override; Expr VisitExpr_(const VarNode* op) override;
Expr VisitExpr_(const ConstantNode* op) override; Expr VisitExpr_(const ConstantNode* op) override;
Expr VisitExpr_(const GlobalVarNode* op) override; Expr VisitExpr_(const GlobalVarNode* op) override;
...@@ -161,7 +168,8 @@ class ExprMutator ...@@ -161,7 +168,8 @@ class ExprMutator
Expr VisitExpr_(const LetNode* op) override; Expr VisitExpr_(const LetNode* op) override;
Expr VisitExpr_(const IfNode* op) override; Expr VisitExpr_(const IfNode* op) override;
Expr VisitExpr_(const TupleGetItemNode* op) override; Expr VisitExpr_(const TupleGetItemNode* op) override;
/*! \brief Used to visit the types inside of expressions. /*!
* \brief Used to visit the types inside of expressions.
* *
* Can be overloaded to transform the types in arbitrary * Can be overloaded to transform the types in arbitrary
* ways, one way would be to define a sub-class of type * ways, one way would be to define a sub-class of type
...@@ -169,7 +177,7 @@ class ExprMutator ...@@ -169,7 +177,7 @@ class ExprMutator
*/ */
virtual Type VisitType(const Type& t); virtual Type VisitType(const Type& t);
private: protected:
/*! \brief Internal map used for memoization. */ /*! \brief Internal map used for memoization. */
std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo_; std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo_;
}; };
......
...@@ -74,6 +74,17 @@ class OpNode : public relay::ExprNode { ...@@ -74,6 +74,17 @@ class OpNode : public relay::ExprNode {
v->Visit("support_level", &support_level); v->Visit("support_level", &support_level);
} }
/*!
* \brief Check that if current op is a "primtive operator".
* That is the arguments are all type variables, and there is a single
* type relation applied to the input and output types.
*/
bool IsPrimitiveOp() const {
if (is_primitive_ != -1) return is_primitive_ != 0;
is_primitive_ = this->IsPrimitiveOp_() ? 1 : 0;
return is_primitive_ != 0;
}
static constexpr const char* _type_key = "relay.Op"; static constexpr const char* _type_key = "relay.Op";
TVM_DECLARE_NODE_TYPE_INFO(OpNode, ExprNode); TVM_DECLARE_NODE_TYPE_INFO(OpNode, ExprNode);
...@@ -81,9 +92,24 @@ class OpNode : public relay::ExprNode { ...@@ -81,9 +92,24 @@ class OpNode : public relay::ExprNode {
// friend class // friend class
friend class GenericOpMap; friend class GenericOpMap;
friend class OpRegistry; friend class OpRegistry;
friend bool IsPrimitiveOp(const Expr&);
// Program internal unique index of operator. // Program internal unique index of operator.
// Used to help index the program. // Used to help index the program.
uint32_t index_{0}; uint32_t index_{0};
// whether this is a primitive op. -1 means unknown.
mutable int is_primitive_{-1};
// Internal function to compute if it is primitive op
bool IsPrimitiveOp_() const {
const auto& fn_ty = this->op_type;
if (fn_ty->type_constraints.size() != 1) return false;
const TypeRelationNode* rel = fn_ty->type_constraints[0].as<TypeRelationNode>();
if (rel == nullptr) return false;
// validate if the type parameter matches up
for (size_t i = 0; i < fn_ty->type_params.size(); ++i) {
if (!fn_ty->type_params[i].same_as(rel->args[i])) return false;
}
return true;
}
}; };
/*! /*!
...@@ -497,22 +523,7 @@ inline ValueType OpMap<ValueType>::get(const Op& op, ...@@ -497,22 +523,7 @@ inline ValueType OpMap<ValueType>::get(const Op& op,
*/ */
inline bool IsPrimitiveOp(const Expr& expr) { inline bool IsPrimitiveOp(const Expr& expr) {
const auto* op = expr.as<OpNode>(); const auto* op = expr.as<OpNode>();
return op != nullptr && op->IsPrimitiveOp();
if (!op) {
return false;
}
const auto& fn_ty = op->op_type;
if (fn_ty->type_constraints.size() != 1) return false;
const TypeRelationNode* rel = fn_ty->type_constraints[0].as<TypeRelationNode>();
if (rel == nullptr) return false;
// validate if the type parameter matches up
for (size_t i = 0; i < fn_ty->type_params.size(); ++i) {
if (!fn_ty->type_params[i].same_as(rel->args[i])) return false;
}
return true;
} }
} // namespace relay } // namespace relay
......
...@@ -10,6 +10,7 @@ from . import _make ...@@ -10,6 +10,7 @@ from . import _make
from .expr import Expr from .expr import Expr
from .ty import Type from .ty import Type
def infer_type(expr, env=None): def infer_type(expr, env=None):
"""Infer the type of expr under the context of env. """Infer the type of expr under the context of env.
...@@ -30,6 +31,23 @@ def infer_type(expr, env=None): ...@@ -30,6 +31,23 @@ def infer_type(expr, env=None):
return _ir_pass.infer_type(expr, env) return _ir_pass.infer_type(expr, env)
def forward_fold_scale_axis(expr):
"""Fold the scaling of axis into weights of conv2d/dense.
Parameters
----------
expr : tvm.relay.Expr
The input expression, we expect that expr's types
should be fully inferred by infer_type.
Returns
-------
folded_expr : tvm.relay.Expr
The folded expression after transformation.
"""
return _ir_pass.forward_fold_scale_axis(expr)
def well_formed(expr): def well_formed(expr):
"""Check that each Var is only bound once (well formed). """Check that each Var is only bound once (well formed).
...@@ -149,6 +167,7 @@ def alpha_equal(lhs, rhs): ...@@ -149,6 +167,7 @@ def alpha_equal(lhs, rhs):
""" """
return bool(_make._alpha_equal(lhs, rhs)) return bool(_make._alpha_equal(lhs, rhs))
def graph_equal(lhs, rhs): def graph_equal(lhs, rhs):
"""Compare two Relay expr for data-flow equivalence. """Compare two Relay expr for data-flow equivalence.
The difference between this and alpha-equality is that The difference between this and alpha-equality is that
...@@ -170,6 +189,7 @@ def graph_equal(lhs, rhs): ...@@ -170,6 +189,7 @@ def graph_equal(lhs, rhs):
""" """
return bool(_make._graph_equal(lhs, rhs)) return bool(_make._graph_equal(lhs, rhs))
def structural_hash(value): def structural_hash(value):
"""Hash a Relay expression structurally. """Hash a Relay expression structurally.
......
...@@ -49,27 +49,25 @@ def transpose(data, axes=None): ...@@ -49,27 +49,25 @@ def transpose(data, axes=None):
return _make.transpose(data, list(axes)) return _make.transpose(data, list(axes))
def squeeze(data, axes=None): def squeeze(data, axis=None):
"""Squeeze axes in the array. """Squeeze axes in the array.
Parameters Parameters
---------- ----------
data : relay.Expr data : tvm.relay.Expr
The input data to the operator. The input data to the operator.
axes : None or List[int] axis : None or List[int]
Axes to remove. The set of axes to remove.
If axes = [] or = None, remove all axis of dimensions 1. If axis = None, remove all axis of dimensions 1.
Otherwise, remove all axis in axes. If any specified axis has dimension that does not equal 1, it is an error.
If any axis in axes has dimension that does not equal 1, it is an error.
Returns Returns
------- -------
result : relay.Expr result : tvm.relay.Expr
The squeezed result. The squeezed result.
""" """
axes = axes or [] return _make.squeeze(data, axis)
return _make.squeeze(data, list(axes))
def reshape(data, newshape): def reshape(data, newshape):
......
...@@ -296,14 +296,24 @@ class AlphaEqualHandler: ...@@ -296,14 +296,24 @@ class AlphaEqualHandler:
if (const CallNode* rhs = other.as<CallNode>()) { if (const CallNode* rhs = other.as<CallNode>()) {
if (!ExprEqual(lhs->op, rhs->op)) return false; if (!ExprEqual(lhs->op, rhs->op)) return false;
if (lhs->args.size() != rhs->args.size()) return false; if (lhs->args.size() != rhs->args.size()) return false;
if (lhs->type_args.size() != rhs->type_args.size()) return false; // skip type_args check for primitive ops.
bool is_primitive = IsPrimitiveOp(lhs->op);
if (!is_primitive) {
if (lhs->type_args.size() != rhs->type_args.size()) {
return false;
}
}
for (size_t i = 0; i < lhs->args.size(); ++i) { for (size_t i = 0; i < lhs->args.size(); ++i) {
if (!ExprEqual(lhs->args[i], rhs->args[i])) return false; if (!ExprEqual(lhs->args[i], rhs->args[i])) {
return false;
}
} }
if (!is_primitive) {
for (size_t i = 0; i < lhs->type_args.size(); ++i) { for (size_t i = 0; i < lhs->type_args.size(); ++i) {
if (!TypeEqual(lhs->type_args[i], rhs->type_args[i])) return false; if (!TypeEqual(lhs->type_args[i], rhs->type_args[i])) return false;
} }
}
return AttrEqual(lhs->attrs, rhs->attrs); return AttrEqual(lhs->attrs, rhs->attrs);
} else { } else {
return false; return false;
......
...@@ -12,12 +12,12 @@ ...@@ -12,12 +12,12 @@
namespace tvm { namespace tvm {
namespace relay { namespace relay {
Expr ExprMutator::Mutate(const Expr& expr) { Expr ExprMutator::VisitExpr(const Expr& expr) {
auto it = this->memo_.find(expr); auto it = this->memo_.find(expr);
if (it != this->memo_.end()) { if (it != this->memo_.end()) {
return it->second; return it->second;
} else { } else {
Expr new_expr = ExprMutator::VisitExpr(expr); Expr new_expr = ExprFunctor::VisitExpr(expr);
memo_[expr] = new_expr; memo_[expr] = new_expr;
return new_expr; return new_expr;
} }
......
...@@ -761,9 +761,9 @@ Examples:: ...@@ -761,9 +761,9 @@ Examples::
TVM_REGISTER_NODE_TYPE(SqueezeAttrs); TVM_REGISTER_NODE_TYPE(SqueezeAttrs);
Expr MakeSqueeze(Expr data, Expr MakeSqueeze(Expr data,
Array<IndexExpr> axes) { Array<Integer> axis) {
auto attrs = make_node<SqueezeAttrs>(); auto attrs = make_node<SqueezeAttrs>();
attrs->axes = std::move(axes); attrs->axis = std::move(axis);
static const Op& op = Op::Get("squeeze"); static const Op& op = Op::Get("squeeze");
return CallNode::make(op, {data}, Attrs(attrs), {}); return CallNode::make(op, {data}, Attrs(attrs), {});
} }
...@@ -785,8 +785,8 @@ bool SqueezeRel(const Array<Type>& types, ...@@ -785,8 +785,8 @@ bool SqueezeRel(const Array<Type>& types,
const auto* param = attrs.as<SqueezeAttrs>(); const auto* param = attrs.as<SqueezeAttrs>();
CHECK(param != nullptr); CHECK(param != nullptr);
std::vector<IndexExpr> result_shape; std::vector<IndexExpr> result_shape;
// if axes is empty, squeeze all axes of dimension 1 // if axes is None, squeeze all axes of dimension 1
if (param->axes.size() == 0) { if (!param->axis.defined()) {
for (const auto& e : data->shape) { for (const auto& e : data->shape) {
const int64_t* axis_ptr = as_const_int(e); const int64_t* axis_ptr = as_const_int(e);
CHECK(axis_ptr != nullptr) << "the axes attribute must be concrete"; CHECK(axis_ptr != nullptr) << "the axes attribute must be concrete";
...@@ -800,10 +800,8 @@ bool SqueezeRel(const Array<Type>& types, ...@@ -800,10 +800,8 @@ bool SqueezeRel(const Array<Type>& types,
for (const auto& e : data->shape) { for (const auto& e : data->shape) {
original_shape.push_back(std::pair<IndexExpr, bool>(e, true)); original_shape.push_back(std::pair<IndexExpr, bool>(e, true));
} }
for (const auto& e : param->axes) { for (const auto& e : param->axis) {
const int64_t* axis_ptr = as_const_int(e); original_shape.at(e->value).second = false;
CHECK(axis_ptr != nullptr);
original_shape.at(*axis_ptr).second = false;
} }
for (const auto p : original_shape) { for (const auto p : original_shape) {
if (p.second) { if (p.second) {
......
/*!
* Copyright (c) 2018 by Contributors.
*
* \file tvm/relay/pass/pattern_util.h
* \brief Header of internal operator functions
* These can be used for writing passes.
*/
#ifndef TVM_RELAY_PASS_PATTERN_UTIL_H_
#define TVM_RELAY_PASS_PATTERN_UTIL_H_
#include <tvm/relay/op.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/attrs/transform.h>
namespace tvm {
namespace relay {
/*!
* \brief Try to match lhs and rhs via broadcasting rule, such that:
*
* rhs matches the dimension of lhs specified by lhs_axes
* rhs's value equals 1 on rest of dimensions.
*
* \param tlhs The type of left operand (data)
* \param trhs The type right operand (bias)
* \param lhs_axes The axes on lhs to match.
* \param rhs_value A squeezed version of rhs which only contains matched dimension.
* \return Whether match is successful.
*/
inline bool MatchBroadcastToLeftAxes(const TensorTypeNode* tlhs,
const TensorTypeNode* trhs,
const Array<Integer>& lhs_axes,
Expr* rhs_value = nullptr) {
if (tlhs->shape.size() < trhs->shape.size()) return false;
AttrsEqual equal;
size_t base = tlhs->shape.size() - trhs->shape.size();
size_t j = 0;
NodePtr<SqueezeAttrs> squeeze_attrs;
if (rhs_value != nullptr) {
squeeze_attrs = make_node<SqueezeAttrs>();
}
for (size_t i = 0; i < tlhs->shape.size(); ++i) {
if (j < lhs_axes.size() && i == static_cast<size_t>(lhs_axes[j]->value)) {
if (i < base || !equal(tlhs->shape[i], trhs->shape[i - base])) {
return false;
}
++j;
} else if (i >= base) {
if (!is_const_int(trhs->shape[i - base], 1)) {
return false;
}
if (rhs_value != nullptr) {
squeeze_attrs->axis.push_back(static_cast<int>(i - base));
}
}
}
if (rhs_value != nullptr && squeeze_attrs->axis.size() != 0) {
static const Op& squeeze_op = Op::Get("squeeze");
*rhs_value = CallNode::make(squeeze_op, {rhs_value[0]}, Attrs(squeeze_attrs), {});
}
return true;
}
/*!
* \brief Expand 1D Tensor to match axis.
*
* The result bias can be used to add or multiply to
* the target Tensor on the specified axis via broadcasting rule.
*
* \param bias The bias.
* \param target_ndim target dimension.
* \param axes The axis on the output we want to match on.
*/
inline Expr ExpandBiasToMatchAxis(Expr bias,
int target_ndim,
const Array<Integer>& axes) {
static const Op& expand_dims = Op::Get("expand_dims");
for (size_t i = axes.size(); i != 0; --i) {
if (i == axes.size()) {
int64_t num_pad_axis = target_ndim - axes[i - 1]->value - 1;
if (num_pad_axis > 0) {
auto attrs = make_node<ExpandDimsAttrs>();
attrs->axis = i;
attrs->num_newaxis = static_cast<int>(num_pad_axis);
bias = CallNode::make(expand_dims, {bias}, Attrs(attrs), {});
}
} else {
int64_t diff = axes[i]->value - axes[i - 1]->value;
CHECK_GE(diff, 0L);
if (diff > 0) {
auto attrs = make_node<ExpandDimsAttrs>();
attrs->axis = i;
attrs->num_newaxis = static_cast<int>(diff);
bias = CallNode::make(expand_dims, {bias}, Attrs(attrs), {});
}
}
}
return bias;
}
inline Expr Multiply(Expr lhs, Expr rhs) {
static const Op& op = Op::Get("multiply");
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
}
inline Expr Divide(Expr lhs, Expr rhs) {
static const Op& op = Op::Get("divide");
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
}
inline Expr ReshapeLike(Expr lhs, Expr rhs) {
static const Op& op = Op::Get("reshape_like");
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
}
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_PATTERN_UTIL_H_
...@@ -406,28 +406,57 @@ class TypeInferencer::Resolver : public ExprMutator { ...@@ -406,28 +406,57 @@ class TypeInferencer::Resolver : public ExprMutator {
CHECK(checked_type.as<IncompleteTypeNode>() == nullptr) CHECK(checked_type.as<IncompleteTypeNode>() == nullptr)
<< "Cannot resolve type of " << GetRef<Expr>(op) << "Cannot resolve type of " << GetRef<Expr>(op)
<< " at " << op->span; << " at " << op->span;
Expr new_e = ExprMutator::VisitExpr_(op); Expr new_e = ExprMutator::VisitExpr_(op);
if (!checked_type.same_as(new_e->checked_type_)) { // new_call and new_var's code is only going to be valid for VarNode/CallNode.
// Compiler optimization will likely fold these away for other nodes.
CallNode* new_call =(
std::is_base_of<CallNode, T>::value ?
static_cast<CallNode*>(new_e.node_.get()) : nullptr);
VarNode* new_var =(
std::is_base_of<VarNode, T>::value ?
static_cast<VarNode*>(new_e.node_.get()) : nullptr);
// check if we need update the new_e
bool need_update_type = !checked_type.same_as(new_e->checked_type_);
bool need_update_call = (
std::is_base_of<CallNode, T>::value &&
it->second.type_args.defined() &&
!it->second.type_args.same_as(new_call->type_args));
bool need_update_var = (
std::is_base_of<VarNode, T>::value &&
update_missing_type_annotation_ &&
!new_var->type_annotation.defined());
if (!need_update_type && !need_update_var && !need_update_call) return new_e;
if (!new_e.node_.unique()) {
// Copy on write optimization // Copy on write optimization
// If new_e is an old expression, // If new_e is an old expression,
// we make a copy mutating an existing reference. // we make a copy mutating an existing reference.
if (!new_e.node_.unique()) {
new_e = Expr(make_node<T>(*new_e.as<T>())); new_e = Expr(make_node<T>(*new_e.as<T>()));
new_call = (
std::is_base_of<CallNode, T>::value ?
static_cast<CallNode*>(new_e.node_.get()) : nullptr);
new_var = (
std::is_base_of<VarNode, T>::value ?
static_cast<VarNode*>(new_e.node_.get()) : nullptr);
} }
// attach the information.
if (need_update_type) {
new_e->checked_type_ = checked_type; new_e->checked_type_ = checked_type;
} }
if (it->second.type_args.defined()) { if (need_update_call) {
Call call = Downcast<Call>(new_e); new_call->type_args = it->second.type_args;
const CallNode* const_call_ref = call.operator->(); for (size_t i = 0; i < new_call->type_args.size(); i++) {
CallNode* call_ref = const_cast<CallNode*>(const_call_ref); new_call->type_args.Set(i, solver_->Resolve(new_call->type_args[i]));
call_ref->type_args = it->second.type_args;
for (size_t i = 0; i < call->type_args.size(); i++) {
call_ref->type_args.Set(i, solver_->Resolve(call->type_args[i]));
} }
} }
if (need_update_var) {
new_var->type_annotation = checked_type;
}
return new_e; return new_e;
} }
...@@ -438,6 +467,9 @@ class TypeInferencer::Resolver : public ExprMutator { ...@@ -438,6 +467,9 @@ class TypeInferencer::Resolver : public ExprMutator {
private: private:
const std::unordered_map<Expr, ResolvedTypeInfo, NodeHash, NodeEqual>& tmap_; const std::unordered_map<Expr, ResolvedTypeInfo, NodeHash, NodeEqual>& tmap_;
TypeSolver* solver_; TypeSolver* solver_;
// whether attach the checked type as type_annotation
// if original type anntation is missing.
bool update_missing_type_annotation_{true};
}; };
......
...@@ -55,8 +55,8 @@ def test_transpose_infer_type(): ...@@ -55,8 +55,8 @@ def test_transpose_infer_type():
def test_squeeze_infer_type(): def test_squeeze_infer_type():
n, t, d = 1, 4, 1 n, t, d = 1, 4, 1
x = relay.var("x", relay.TensorType((n, t, d), "float32")) x = relay.var("x", relay.TensorType((n, t, d), "float32"))
y = relay.squeeze(x, axes=(2,)) y = relay.squeeze(x, axis=(2,))
assert "axes=" in y.astext() assert "axis=" in y.astext()
yy = relay.ir_pass.infer_type(y) yy = relay.ir_pass.infer_type(y)
assert yy.checked_type == relay.TensorType( assert yy.checked_type == relay.TensorType(
(1, 4), "float32") (1, 4), "float32")
...@@ -64,7 +64,7 @@ def test_squeeze_infer_type(): ...@@ -64,7 +64,7 @@ def test_squeeze_infer_type():
n, t, d = 1, 4, 1 n, t, d = 1, 4, 1
x = relay.var("x", relay.TensorType((n, t, d), "float32")) x = relay.var("x", relay.TensorType((n, t, d), "float32"))
y = relay.squeeze(x) y = relay.squeeze(x)
assert "axes=" not in y.astext() assert "axis=" not in y.astext()
yy = relay.ir_pass.infer_type(y) yy = relay.ir_pass.infer_type(y)
assert yy.checked_type == relay.TensorType( assert yy.checked_type == relay.TensorType(
(4,), "float32") (4,), "float32")
...@@ -74,7 +74,7 @@ def test_squeeze_infer_type(): ...@@ -74,7 +74,7 @@ def test_squeeze_infer_type():
def test_squeeze_bad_axes_infer_type(): def test_squeeze_bad_axes_infer_type():
n, t, d = 1, 4, 1 n, t, d = 1, 4, 1
x = relay.var("x", relay.TensorType((n, t, d), "float32")) x = relay.var("x", relay.TensorType((n, t, d), "float32"))
y = relay.squeeze(x, axes=(1,)) y = relay.squeeze(x, axis=(1,))
yy = relay.ir_pass.infer_type(y) yy = relay.ir_pass.infer_type(y)
......
from tvm import relay
def test_fold_fwd_simple():
"""Simple testcase."""
def before(x, conv_weight, in_bias, in_scale, channels):
args = [x, conv_weight, in_bias, in_scale]
in_scale = relay.expand_dims(in_scale, axis=1, num_newaxis=2)
in_bias = relay.expand_dims(in_bias, axis=1, num_newaxis=2)
x = relay.multiply(x, in_scale)
x = relay.nn.relu(x)
x = relay.add(x, in_bias)
y = relay.nn.conv2d(x, conv_weight,
channels=channels,
kernel_size=(3, 3),
padding=(1, 1))
return relay.Function(args, y)
def expected(x, conv_weight, in_bias, in_scale, channels):
# use a fixed order of args so alpha equal check can pass
args = [x, conv_weight, in_bias, in_scale]
in_scale = relay.expand_dims(in_scale, axis=1, num_newaxis=2)
in_bias = relay.expand_dims(in_bias, axis=1, num_newaxis=2)
squeezed_scale = relay.squeeze(in_scale, axis=[1,2])
x = relay.nn.relu(x)
in_bias = relay.divide(in_bias, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2))
x = relay.add(x, in_bias)
conv_weight = relay.multiply(
conv_weight , relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2))
y = relay.nn.conv2d(x, conv_weight,
channels=channels,
kernel_size=(3, 3),
padding=(1, 1))
return relay.Function(args, y)
def check(shape, channels):
x = relay.var("x", shape=shape)
in_channels = shape[1]
weight = relay.var("weight")
in_bias = relay.var("in_bias", shape=(in_channels,))
in_scale = relay.var("in_scale", shape=(in_channels,))
y1 = before(x, weight, in_bias, in_scale, channels)
y1 = relay.ir_pass.infer_type(y1)
type_dict = {x.name_hint:x.checked_type for x in y1.params}
weight = relay.var("weight", type_dict["weight"])
y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)
y1_expected = expected(x, weight, in_bias, in_scale, channels)
assert relay.ir_pass.alpha_equal(y1_folded, y1_expected)
check((2, 4, 10, 10), 2)
def test_fold_fwd_dual_path():
"""scale axis being consumed by two consumers"""
def before(x, conv_weight, in_bias, in_scale, channels):
args = [x, conv_weight, in_bias, in_scale]
x = relay.multiply(in_scale, x)
x = relay.nn.relu(x)
x = relay.subtract(x, in_bias)
y1 = relay.nn.conv2d(x, conv_weight,
channels=channels,
kernel_size=(3, 3),
data_layout="NHWC",
weight_layout="HWOI",
groups=channels,
padding=(1, 1))
y2 = relay.nn.conv2d(x, conv_weight,
channels=channels,
kernel_size=(3, 3),
data_layout="NHWC",
weight_layout="HWOI",
groups=channels,
padding=(1, 1))
z = relay.add(y1, y2)
return relay.Function(args, z)
def expected(x, conv_weight, in_bias, in_scale, channels):
args = [x, conv_weight, in_bias, in_scale]
x = relay.nn.relu(x)
in_bias = relay.divide(in_bias, in_scale)
x = relay.subtract(x, in_bias)
y1 = relay.nn.conv2d(x,
relay.multiply(conv_weight, in_scale),
channels=channels,
kernel_size=(3, 3),
data_layout="NHWC",
weight_layout="HWOI",
groups=channels,
padding=(1, 1))
y2 = relay.nn.conv2d(x,
relay.multiply(conv_weight, in_scale),
channels=channels,
kernel_size=(3, 3),
data_layout="NHWC",
weight_layout="HWOI",
groups=channels,
padding=(1, 1))
z = relay.add(y1, y2)
return relay.Function(args, z)
def check(shape, channels):
x = relay.var("x", shape=shape)
in_channels = shape[-1]
# test depthwise
assert in_channels == channels
weight = relay.var("weight")
in_bias = relay.var("in_bias", shape=(in_channels,))
in_scale = relay.var("in_scale", shape=(in_channels,))
y1 = before(x, weight, in_bias, in_scale, channels)
y1 = relay.ir_pass.infer_type(y1)
y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)
type_dict = {x.name_hint:x.checked_type for x in y1.params}
weight = relay.var("weight", type_dict["weight"])
y1_expected = expected(x, weight, in_bias, in_scale, channels)
assert relay.ir_pass.alpha_equal(y1_folded, y1_expected)
check((2, 4, 10, 3), 3)
def test_fold_fwd_fail():
"""testcase where we canont fold"""
def before(x, conv_weight, in_bias, in_scale, channels):
x = relay.multiply(x, in_scale)
xx = relay.nn.leaky_relu(x, alpha=0.1)
y1 = relay.nn.conv2d(xx, conv_weight,
channels=channels,
kernel_size=(3, 3),
data_layout="NHWC",
padding=(1, 1))
z = relay.add(y1, x)
return relay.Function(relay.ir_pass.free_vars(z), z)
def check(shape, channels):
x = relay.var("x", shape=shape)
in_channels = shape[-1]
# test depthwise
assert in_channels == channels
weight = relay.var("weight")
in_bias = relay.var("in_bias", shape=(in_channels,))
in_scale = relay.var("in_scale", shape=(in_channels,))
y1 = before(x, weight, in_bias, in_scale, channels)
y1 = relay.ir_pass.infer_type(y1)
y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)
assert relay.ir_pass.alpha_equal(y1, y1_folded)
check((2, 11, 10, 4), 4)
if __name__ == "__main__":
test_fold_fwd_simple()
test_fold_fwd_dual_path()
test_fold_fwd_fail()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment