Commit 05ea6018 by Tianqi Chen Committed by ziheng

[RELAY][PASS] FoldScaleAxis Forward (#2020)

* [RELAY][PASS] FoldScaleAxis Forward

* Introduce helper function type_as

* Update per review comment

* Fix according to comments
parent 26e3aa19
......@@ -94,15 +94,16 @@ struct InitOpAttrs : public tvm::AttrsNode<InitOpAttrs> {
/*! \brief Attributes used in squeeze operators */
struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
Array<IndexExpr> axes;
// use axis to make the name numpy compatible.
Array<Integer> axis;
TVM_DECLARE_ATTRS(SqueezeAttrs, "relay.attrs.SqueezeAttrs") {
TVM_ATTR_FIELD(axes)
.describe("The axes to squeeze in the input tensor."
"If `axes = []`, all axis of dimension 1 get squeezed;"
TVM_ATTR_FIELD(axis)
.describe("The axis to squeeze in the input tensor."
"If `axis = None`, all axis of dimension 1 get squeezed;"
"Else, the dimension in axes get squeezed."
"It is an error if an axes does not has dimension 1.")
.set_default(Array<IndexExpr>({}));
"It is an error if an axis does not has dimension 1.")
.set_default(NullValue<Array<Integer> >());
}
}; // struct SqueezeAttrs
......
......@@ -40,6 +40,18 @@ class ExprNode : public RelayNode {
"field for this node";
return this->checked_type_;
}
/*!
* \brief Check if the inferred(checked) type of the Expr
* is backed by a TTypeNode and return it.
*
* \note This function will thrown an error if the node type
* of this Expr is not TTypeNode.
*
* \return The corresponding TTypeNode pointer.
* \tparam The specific TypeNode we look for.
*/
template<typename TTypeNode>
inline const TTypeNode* type_as() const;
static constexpr const char* _type_key = "relay.Expr";
TVM_DECLARE_BASE_NODE_INFO(ExprNode, RelayNode);
......@@ -391,6 +403,20 @@ class TupleGetItemNode : public ExprNode {
RELAY_DEFINE_NODE_REF(TupleGetItem, TupleGetItemNode, Expr);
// implementataions
template<typename TTypeNode>
inline const TTypeNode* ExprNode::type_as() const {
static_assert(std::is_base_of<TypeNode, TTypeNode>::value,
"TType must be a special case of type");
CHECK(checked_type_.defined())
<< "Type inference for this Expr has not completed";
const TTypeNode* node = checked_type_.as<TTypeNode>();
CHECK(node != nullptr)
<< "Expected type to be " << TTypeNode::_type_key
<< ", but get " << checked_type_->type_key();
return node;
}
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_EXPR_H_
......@@ -150,7 +150,14 @@ class ExprVisitor
class ExprMutator
: public ::tvm::relay::ExprFunctor<Expr(const Expr&)> {
public:
Expr Mutate(const Expr& expr);
/*!
* \brief Mutate is alias for VisitExpr
* \return expr.
*/
Expr Mutate(const Expr& expr) {
return this->VisitExpr(expr);
}
Expr VisitExpr(const Expr& expr) override;
Expr VisitExpr_(const VarNode* op) override;
Expr VisitExpr_(const ConstantNode* op) override;
Expr VisitExpr_(const GlobalVarNode* op) override;
......@@ -161,7 +168,8 @@ class ExprMutator
Expr VisitExpr_(const LetNode* op) override;
Expr VisitExpr_(const IfNode* op) override;
Expr VisitExpr_(const TupleGetItemNode* op) override;
/*! \brief Used to visit the types inside of expressions.
/*!
* \brief Used to visit the types inside of expressions.
*
* Can be overloaded to transform the types in arbitrary
* ways, one way would be to define a sub-class of type
......@@ -169,7 +177,7 @@ class ExprMutator
*/
virtual Type VisitType(const Type& t);
private:
protected:
/*! \brief Internal map used for memoization. */
std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo_;
};
......
......@@ -74,6 +74,17 @@ class OpNode : public relay::ExprNode {
v->Visit("support_level", &support_level);
}
/*!
* \brief Check that if current op is a "primtive operator".
* That is the arguments are all type variables, and there is a single
* type relation applied to the input and output types.
*/
bool IsPrimitiveOp() const {
if (is_primitive_ != -1) return is_primitive_ != 0;
is_primitive_ = this->IsPrimitiveOp_() ? 1 : 0;
return is_primitive_ != 0;
}
static constexpr const char* _type_key = "relay.Op";
TVM_DECLARE_NODE_TYPE_INFO(OpNode, ExprNode);
......@@ -81,9 +92,24 @@ class OpNode : public relay::ExprNode {
// friend class
friend class GenericOpMap;
friend class OpRegistry;
friend bool IsPrimitiveOp(const Expr&);
// Program internal unique index of operator.
// Used to help index the program.
uint32_t index_{0};
// whether this is a primitive op. -1 means unknown.
mutable int is_primitive_{-1};
// Internal function to compute if it is primitive op
bool IsPrimitiveOp_() const {
const auto& fn_ty = this->op_type;
if (fn_ty->type_constraints.size() != 1) return false;
const TypeRelationNode* rel = fn_ty->type_constraints[0].as<TypeRelationNode>();
if (rel == nullptr) return false;
// validate if the type parameter matches up
for (size_t i = 0; i < fn_ty->type_params.size(); ++i) {
if (!fn_ty->type_params[i].same_as(rel->args[i])) return false;
}
return true;
}
};
/*!
......@@ -497,22 +523,7 @@ inline ValueType OpMap<ValueType>::get(const Op& op,
*/
inline bool IsPrimitiveOp(const Expr& expr) {
const auto* op = expr.as<OpNode>();
if (!op) {
return false;
}
const auto& fn_ty = op->op_type;
if (fn_ty->type_constraints.size() != 1) return false;
const TypeRelationNode* rel = fn_ty->type_constraints[0].as<TypeRelationNode>();
if (rel == nullptr) return false;
// validate if the type parameter matches up
for (size_t i = 0; i < fn_ty->type_params.size(); ++i) {
if (!fn_ty->type_params[i].same_as(rel->args[i])) return false;
}
return true;
return op != nullptr && op->IsPrimitiveOp();
}
} // namespace relay
......
......@@ -10,6 +10,7 @@ from . import _make
from .expr import Expr
from .ty import Type
def infer_type(expr, env=None):
"""Infer the type of expr under the context of env.
......@@ -30,6 +31,23 @@ def infer_type(expr, env=None):
return _ir_pass.infer_type(expr, env)
def forward_fold_scale_axis(expr):
"""Fold the scaling of axis into weights of conv2d/dense.
Parameters
----------
expr : tvm.relay.Expr
The input expression, we expect that expr's types
should be fully inferred by infer_type.
Returns
-------
folded_expr : tvm.relay.Expr
The folded expression after transformation.
"""
return _ir_pass.forward_fold_scale_axis(expr)
def well_formed(expr):
"""Check that each Var is only bound once (well formed).
......@@ -149,6 +167,7 @@ def alpha_equal(lhs, rhs):
"""
return bool(_make._alpha_equal(lhs, rhs))
def graph_equal(lhs, rhs):
"""Compare two Relay expr for data-flow equivalence.
The difference between this and alpha-equality is that
......@@ -170,6 +189,7 @@ def graph_equal(lhs, rhs):
"""
return bool(_make._graph_equal(lhs, rhs))
def structural_hash(value):
"""Hash a Relay expression structurally.
......
......@@ -49,27 +49,25 @@ def transpose(data, axes=None):
return _make.transpose(data, list(axes))
def squeeze(data, axes=None):
def squeeze(data, axis=None):
"""Squeeze axes in the array.
Parameters
----------
data : relay.Expr
data : tvm.relay.Expr
The input data to the operator.
axes : None or List[int]
Axes to remove.
If axes = [] or = None, remove all axis of dimensions 1.
Otherwise, remove all axis in axes.
If any axis in axes has dimension that does not equal 1, it is an error.
axis : None or List[int]
The set of axes to remove.
If axis = None, remove all axis of dimensions 1.
If any specified axis has dimension that does not equal 1, it is an error.
Returns
-------
result : relay.Expr
result : tvm.relay.Expr
The squeezed result.
"""
axes = axes or []
return _make.squeeze(data, list(axes))
return _make.squeeze(data, axis)
def reshape(data, newshape):
......
......@@ -296,13 +296,23 @@ class AlphaEqualHandler:
if (const CallNode* rhs = other.as<CallNode>()) {
if (!ExprEqual(lhs->op, rhs->op)) return false;
if (lhs->args.size() != rhs->args.size()) return false;
if (lhs->type_args.size() != rhs->type_args.size()) return false;
// skip type_args check for primitive ops.
bool is_primitive = IsPrimitiveOp(lhs->op);
if (!is_primitive) {
if (lhs->type_args.size() != rhs->type_args.size()) {
return false;
}
}
for (size_t i = 0; i < lhs->args.size(); ++i) {
if (!ExprEqual(lhs->args[i], rhs->args[i])) return false;
if (!ExprEqual(lhs->args[i], rhs->args[i])) {
return false;
}
}
for (size_t i = 0; i < lhs->type_args.size(); ++i) {
if (!TypeEqual(lhs->type_args[i], rhs->type_args[i])) return false;
if (!is_primitive) {
for (size_t i = 0; i < lhs->type_args.size(); ++i) {
if (!TypeEqual(lhs->type_args[i], rhs->type_args[i])) return false;
}
}
return AttrEqual(lhs->attrs, rhs->attrs);
} else {
......
......@@ -12,12 +12,12 @@
namespace tvm {
namespace relay {
Expr ExprMutator::Mutate(const Expr& expr) {
Expr ExprMutator::VisitExpr(const Expr& expr) {
auto it = this->memo_.find(expr);
if (it != this->memo_.end()) {
return it->second;
} else {
Expr new_expr = ExprMutator::VisitExpr(expr);
Expr new_expr = ExprFunctor::VisitExpr(expr);
memo_[expr] = new_expr;
return new_expr;
}
......
......@@ -761,9 +761,9 @@ Examples::
TVM_REGISTER_NODE_TYPE(SqueezeAttrs);
Expr MakeSqueeze(Expr data,
Array<IndexExpr> axes) {
Array<Integer> axis) {
auto attrs = make_node<SqueezeAttrs>();
attrs->axes = std::move(axes);
attrs->axis = std::move(axis);
static const Op& op = Op::Get("squeeze");
return CallNode::make(op, {data}, Attrs(attrs), {});
}
......@@ -785,8 +785,8 @@ bool SqueezeRel(const Array<Type>& types,
const auto* param = attrs.as<SqueezeAttrs>();
CHECK(param != nullptr);
std::vector<IndexExpr> result_shape;
// if axes is empty, squeeze all axes of dimension 1
if (param->axes.size() == 0) {
// if axes is None, squeeze all axes of dimension 1
if (!param->axis.defined()) {
for (const auto& e : data->shape) {
const int64_t* axis_ptr = as_const_int(e);
CHECK(axis_ptr != nullptr) << "the axes attribute must be concrete";
......@@ -800,10 +800,8 @@ bool SqueezeRel(const Array<Type>& types,
for (const auto& e : data->shape) {
original_shape.push_back(std::pair<IndexExpr, bool>(e, true));
}
for (const auto& e : param->axes) {
const int64_t* axis_ptr = as_const_int(e);
CHECK(axis_ptr != nullptr);
original_shape.at(*axis_ptr).second = false;
for (const auto& e : param->axis) {
original_shape.at(e->value).second = false;
}
for (const auto p : original_shape) {
if (p.second) {
......
/*!
* Copyright (c) 2018 by Contributors.
*
* \file tvm/relay/pass/pattern_util.h
* \brief Header of internal operator functions
* These can be used for writing passes.
*/
#ifndef TVM_RELAY_PASS_PATTERN_UTIL_H_
#define TVM_RELAY_PASS_PATTERN_UTIL_H_
#include <tvm/relay/op.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/attrs/transform.h>
namespace tvm {
namespace relay {
/*!
* \brief Try to match lhs and rhs via broadcasting rule, such that:
*
* rhs matches the dimension of lhs specified by lhs_axes
* rhs's value equals 1 on rest of dimensions.
*
* \param tlhs The type of left operand (data)
* \param trhs The type right operand (bias)
* \param lhs_axes The axes on lhs to match.
* \param rhs_value A squeezed version of rhs which only contains matched dimension.
* \return Whether match is successful.
*/
inline bool MatchBroadcastToLeftAxes(const TensorTypeNode* tlhs,
const TensorTypeNode* trhs,
const Array<Integer>& lhs_axes,
Expr* rhs_value = nullptr) {
if (tlhs->shape.size() < trhs->shape.size()) return false;
AttrsEqual equal;
size_t base = tlhs->shape.size() - trhs->shape.size();
size_t j = 0;
NodePtr<SqueezeAttrs> squeeze_attrs;
if (rhs_value != nullptr) {
squeeze_attrs = make_node<SqueezeAttrs>();
}
for (size_t i = 0; i < tlhs->shape.size(); ++i) {
if (j < lhs_axes.size() && i == static_cast<size_t>(lhs_axes[j]->value)) {
if (i < base || !equal(tlhs->shape[i], trhs->shape[i - base])) {
return false;
}
++j;
} else if (i >= base) {
if (!is_const_int(trhs->shape[i - base], 1)) {
return false;
}
if (rhs_value != nullptr) {
squeeze_attrs->axis.push_back(static_cast<int>(i - base));
}
}
}
if (rhs_value != nullptr && squeeze_attrs->axis.size() != 0) {
static const Op& squeeze_op = Op::Get("squeeze");
*rhs_value = CallNode::make(squeeze_op, {rhs_value[0]}, Attrs(squeeze_attrs), {});
}
return true;
}
/*!
* \brief Expand 1D Tensor to match axis.
*
* The result bias can be used to add or multiply to
* the target Tensor on the specified axis via broadcasting rule.
*
* \param bias The bias.
* \param target_ndim target dimension.
* \param axes The axis on the output we want to match on.
*/
inline Expr ExpandBiasToMatchAxis(Expr bias,
int target_ndim,
const Array<Integer>& axes) {
static const Op& expand_dims = Op::Get("expand_dims");
for (size_t i = axes.size(); i != 0; --i) {
if (i == axes.size()) {
int64_t num_pad_axis = target_ndim - axes[i - 1]->value - 1;
if (num_pad_axis > 0) {
auto attrs = make_node<ExpandDimsAttrs>();
attrs->axis = i;
attrs->num_newaxis = static_cast<int>(num_pad_axis);
bias = CallNode::make(expand_dims, {bias}, Attrs(attrs), {});
}
} else {
int64_t diff = axes[i]->value - axes[i - 1]->value;
CHECK_GE(diff, 0L);
if (diff > 0) {
auto attrs = make_node<ExpandDimsAttrs>();
attrs->axis = i;
attrs->num_newaxis = static_cast<int>(diff);
bias = CallNode::make(expand_dims, {bias}, Attrs(attrs), {});
}
}
}
return bias;
}
inline Expr Multiply(Expr lhs, Expr rhs) {
static const Op& op = Op::Get("multiply");
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
}
inline Expr Divide(Expr lhs, Expr rhs) {
static const Op& op = Op::Get("divide");
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
}
inline Expr ReshapeLike(Expr lhs, Expr rhs) {
static const Op& op = Op::Get("reshape_like");
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
}
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_PATTERN_UTIL_H_
......@@ -406,28 +406,57 @@ class TypeInferencer::Resolver : public ExprMutator {
CHECK(checked_type.as<IncompleteTypeNode>() == nullptr)
<< "Cannot resolve type of " << GetRef<Expr>(op)
<< " at " << op->span;
Expr new_e = ExprMutator::VisitExpr_(op);
if (!checked_type.same_as(new_e->checked_type_)) {
// new_call and new_var's code is only going to be valid for VarNode/CallNode.
// Compiler optimization will likely fold these away for other nodes.
CallNode* new_call =(
std::is_base_of<CallNode, T>::value ?
static_cast<CallNode*>(new_e.node_.get()) : nullptr);
VarNode* new_var =(
std::is_base_of<VarNode, T>::value ?
static_cast<VarNode*>(new_e.node_.get()) : nullptr);
// check if we need update the new_e
bool need_update_type = !checked_type.same_as(new_e->checked_type_);
bool need_update_call = (
std::is_base_of<CallNode, T>::value &&
it->second.type_args.defined() &&
!it->second.type_args.same_as(new_call->type_args));
bool need_update_var = (
std::is_base_of<VarNode, T>::value &&
update_missing_type_annotation_ &&
!new_var->type_annotation.defined());
if (!need_update_type && !need_update_var && !need_update_call) return new_e;
if (!new_e.node_.unique()) {
// Copy on write optimization
// If new_e is an old expression,
// we make a copy mutating an existing reference.
if (!new_e.node_.unique()) {
new_e = Expr(make_node<T>(*new_e.as<T>()));
}
new_e->checked_type_ = checked_type;
new_e = Expr(make_node<T>(*new_e.as<T>()));
new_call = (
std::is_base_of<CallNode, T>::value ?
static_cast<CallNode*>(new_e.node_.get()) : nullptr);
new_var = (
std::is_base_of<VarNode, T>::value ?
static_cast<VarNode*>(new_e.node_.get()) : nullptr);
}
if (it->second.type_args.defined()) {
Call call = Downcast<Call>(new_e);
const CallNode* const_call_ref = call.operator->();
CallNode* call_ref = const_cast<CallNode*>(const_call_ref);
call_ref->type_args = it->second.type_args;
// attach the information.
if (need_update_type) {
new_e->checked_type_ = checked_type;
}
for (size_t i = 0; i < call->type_args.size(); i++) {
call_ref->type_args.Set(i, solver_->Resolve(call->type_args[i]));
if (need_update_call) {
new_call->type_args = it->second.type_args;
for (size_t i = 0; i < new_call->type_args.size(); i++) {
new_call->type_args.Set(i, solver_->Resolve(new_call->type_args[i]));
}
}
if (need_update_var) {
new_var->type_annotation = checked_type;
}
return new_e;
}
......@@ -438,6 +467,9 @@ class TypeInferencer::Resolver : public ExprMutator {
private:
const std::unordered_map<Expr, ResolvedTypeInfo, NodeHash, NodeEqual>& tmap_;
TypeSolver* solver_;
// whether attach the checked type as type_annotation
// if original type anntation is missing.
bool update_missing_type_annotation_{true};
};
......
......@@ -55,8 +55,8 @@ def test_transpose_infer_type():
def test_squeeze_infer_type():
n, t, d = 1, 4, 1
x = relay.var("x", relay.TensorType((n, t, d), "float32"))
y = relay.squeeze(x, axes=(2,))
assert "axes=" in y.astext()
y = relay.squeeze(x, axis=(2,))
assert "axis=" in y.astext()
yy = relay.ir_pass.infer_type(y)
assert yy.checked_type == relay.TensorType(
(1, 4), "float32")
......@@ -64,7 +64,7 @@ def test_squeeze_infer_type():
n, t, d = 1, 4, 1
x = relay.var("x", relay.TensorType((n, t, d), "float32"))
y = relay.squeeze(x)
assert "axes=" not in y.astext()
assert "axis=" not in y.astext()
yy = relay.ir_pass.infer_type(y)
assert yy.checked_type == relay.TensorType(
(4,), "float32")
......@@ -74,7 +74,7 @@ def test_squeeze_infer_type():
def test_squeeze_bad_axes_infer_type():
n, t, d = 1, 4, 1
x = relay.var("x", relay.TensorType((n, t, d), "float32"))
y = relay.squeeze(x, axes=(1,))
y = relay.squeeze(x, axis=(1,))
yy = relay.ir_pass.infer_type(y)
......
from tvm import relay
def test_fold_fwd_simple():
"""Simple testcase."""
def before(x, conv_weight, in_bias, in_scale, channels):
args = [x, conv_weight, in_bias, in_scale]
in_scale = relay.expand_dims(in_scale, axis=1, num_newaxis=2)
in_bias = relay.expand_dims(in_bias, axis=1, num_newaxis=2)
x = relay.multiply(x, in_scale)
x = relay.nn.relu(x)
x = relay.add(x, in_bias)
y = relay.nn.conv2d(x, conv_weight,
channels=channels,
kernel_size=(3, 3),
padding=(1, 1))
return relay.Function(args, y)
def expected(x, conv_weight, in_bias, in_scale, channels):
# use a fixed order of args so alpha equal check can pass
args = [x, conv_weight, in_bias, in_scale]
in_scale = relay.expand_dims(in_scale, axis=1, num_newaxis=2)
in_bias = relay.expand_dims(in_bias, axis=1, num_newaxis=2)
squeezed_scale = relay.squeeze(in_scale, axis=[1,2])
x = relay.nn.relu(x)
in_bias = relay.divide(in_bias, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2))
x = relay.add(x, in_bias)
conv_weight = relay.multiply(
conv_weight , relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2))
y = relay.nn.conv2d(x, conv_weight,
channels=channels,
kernel_size=(3, 3),
padding=(1, 1))
return relay.Function(args, y)
def check(shape, channels):
x = relay.var("x", shape=shape)
in_channels = shape[1]
weight = relay.var("weight")
in_bias = relay.var("in_bias", shape=(in_channels,))
in_scale = relay.var("in_scale", shape=(in_channels,))
y1 = before(x, weight, in_bias, in_scale, channels)
y1 = relay.ir_pass.infer_type(y1)
type_dict = {x.name_hint:x.checked_type for x in y1.params}
weight = relay.var("weight", type_dict["weight"])
y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)
y1_expected = expected(x, weight, in_bias, in_scale, channels)
assert relay.ir_pass.alpha_equal(y1_folded, y1_expected)
check((2, 4, 10, 10), 2)
def test_fold_fwd_dual_path():
"""scale axis being consumed by two consumers"""
def before(x, conv_weight, in_bias, in_scale, channels):
args = [x, conv_weight, in_bias, in_scale]
x = relay.multiply(in_scale, x)
x = relay.nn.relu(x)
x = relay.subtract(x, in_bias)
y1 = relay.nn.conv2d(x, conv_weight,
channels=channels,
kernel_size=(3, 3),
data_layout="NHWC",
weight_layout="HWOI",
groups=channels,
padding=(1, 1))
y2 = relay.nn.conv2d(x, conv_weight,
channels=channels,
kernel_size=(3, 3),
data_layout="NHWC",
weight_layout="HWOI",
groups=channels,
padding=(1, 1))
z = relay.add(y1, y2)
return relay.Function(args, z)
def expected(x, conv_weight, in_bias, in_scale, channels):
args = [x, conv_weight, in_bias, in_scale]
x = relay.nn.relu(x)
in_bias = relay.divide(in_bias, in_scale)
x = relay.subtract(x, in_bias)
y1 = relay.nn.conv2d(x,
relay.multiply(conv_weight, in_scale),
channels=channels,
kernel_size=(3, 3),
data_layout="NHWC",
weight_layout="HWOI",
groups=channels,
padding=(1, 1))
y2 = relay.nn.conv2d(x,
relay.multiply(conv_weight, in_scale),
channels=channels,
kernel_size=(3, 3),
data_layout="NHWC",
weight_layout="HWOI",
groups=channels,
padding=(1, 1))
z = relay.add(y1, y2)
return relay.Function(args, z)
def check(shape, channels):
x = relay.var("x", shape=shape)
in_channels = shape[-1]
# test depthwise
assert in_channels == channels
weight = relay.var("weight")
in_bias = relay.var("in_bias", shape=(in_channels,))
in_scale = relay.var("in_scale", shape=(in_channels,))
y1 = before(x, weight, in_bias, in_scale, channels)
y1 = relay.ir_pass.infer_type(y1)
y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)
type_dict = {x.name_hint:x.checked_type for x in y1.params}
weight = relay.var("weight", type_dict["weight"])
y1_expected = expected(x, weight, in_bias, in_scale, channels)
assert relay.ir_pass.alpha_equal(y1_folded, y1_expected)
check((2, 4, 10, 3), 3)
def test_fold_fwd_fail():
"""testcase where we canont fold"""
def before(x, conv_weight, in_bias, in_scale, channels):
x = relay.multiply(x, in_scale)
xx = relay.nn.leaky_relu(x, alpha=0.1)
y1 = relay.nn.conv2d(xx, conv_weight,
channels=channels,
kernel_size=(3, 3),
data_layout="NHWC",
padding=(1, 1))
z = relay.add(y1, x)
return relay.Function(relay.ir_pass.free_vars(z), z)
def check(shape, channels):
x = relay.var("x", shape=shape)
in_channels = shape[-1]
# test depthwise
assert in_channels == channels
weight = relay.var("weight")
in_bias = relay.var("in_bias", shape=(in_channels,))
in_scale = relay.var("in_scale", shape=(in_channels,))
y1 = before(x, weight, in_bias, in_scale, channels)
y1 = relay.ir_pass.infer_type(y1)
y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)
assert relay.ir_pass.alpha_equal(y1, y1_folded)
check((2, 11, 10, 4), 4)
if __name__ == "__main__":
test_fold_fwd_simple()
test_fold_fwd_dual_path()
test_fold_fwd_fail()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment