Commit dd9d76ac by ziheng Committed by Tianqi Chen

[RELAY/PASS] Simplify inference. (#2033)

parent 2f9ab71e
......@@ -10,7 +10,6 @@ def test_simplify_batchnorm():
scale = sym.elemwise_mul(1 / sym.sqrt(moving_var + epsilon), gamma)
shift = sym.elemwise_add(
sym.elemwise_mul(sym.negative(moving_mean), scale), beta)
shape = [-1 if i == axis else 1 for i in range(len(shape))]
# for 2D
num_newaxis=len(shape) - axis - 1
if num_newaxis:
......
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""The expression nodes of Relay."""
from __future__ import absolute_import
from numbers import Number as _Number
import numpy as _np
from .base import RelayNode, register_relay_node
......@@ -11,6 +12,8 @@ from .._ffi import base as _base
from .. import nd as _nd
from .. import convert
# will be registered afterwards
_op_make = None
class Expr(RelayNode):
"""The base type for all Relay expressions."""
......@@ -48,6 +51,62 @@ class Expr(RelayNode):
"""
return _make.dtype_cast(self, dtype)
def __add__(self, other):
    """Build the expression for ``self + other``.

    Parameters
    ----------
    other : Expr
        The right-hand operand. Plain Python numbers are rejected; the
        error message asks the caller to convert them with `const` first.

    Returns
    -------
    result
        The value returned by ``_op_make.add(self, other)``.

    Raises
    ------
    TypeError
        If ``other`` is a number or any other non-Expr object.
    """
    if isinstance(other, Expr):
        return _op_make.add(self, other)
    # Numbers get a dedicated hint instead of the generic message.
    if isinstance(other, _Number):
        raise TypeError('convert "%s" with `const` first' % str(other))
    raise TypeError("type %s not supported" % str(type(other)))
def __radd__(self, other):
# Reflected addition (``other + self``): forwards to __add__ with the same
# argument, so any TypeError raised there propagates unchanged.
return self.__add__(other)
def __sub__(self, other):
    """Build the expression for ``self - other``.

    Parameters
    ----------
    other : Expr
        The right-hand operand. Plain Python numbers are rejected; the
        error message asks the caller to convert them with `const` first.

    Returns
    -------
    result
        The value returned by ``_op_make.subtract(self, other)``.

    Raises
    ------
    TypeError
        If ``other`` is a number or any other non-Expr object.
    """
    if isinstance(other, Expr):
        return _op_make.subtract(self, other)
    # Numbers get a dedicated hint instead of the generic message.
    if isinstance(other, _Number):
        raise TypeError('convert "%s" with `const` first' % str(other))
    raise TypeError("type %s not supported" % str(type(other)))
def __rsub__(self, other):
if isinstance(other, _Number):
raise TypeError('convert "%s" with `const` first' % str(other))
else:
raise TypeError("type %s not supported" % str(type(other)))
def __mul__(self, other):
    """Build the expression for ``self * other``.

    Parameters
    ----------
    other : Expr
        The right-hand operand. Plain Python numbers are rejected; the
        error message asks the caller to convert them with `const` first.

    Returns
    -------
    result
        The value returned by ``_op_make.multiply(self, other)``.

    Raises
    ------
    TypeError
        If ``other`` is a number or any other non-Expr object.
    """
    if isinstance(other, Expr):
        return _op_make.multiply(self, other)
    # Numbers get a dedicated hint instead of the generic message.
    if isinstance(other, _Number):
        raise TypeError('convert "%s" with `const` first' % str(other))
    raise TypeError("type %s not supported" % str(type(other)))
def __rmul__(self, other):
# Reflected multiplication (``other * self``): forwards to __mul__ with the
# same argument, so any TypeError raised there propagates unchanged.
return self.__mul__(other)
def __div__(self, other):
    """Build the expression for ``self / other``.

    Parameters
    ----------
    other : Expr
        The right-hand operand. Plain Python numbers are rejected; the
        error message asks the caller to convert them with `const` first.

    Returns
    -------
    result
        The value returned by ``_op_make.divide(self, other)``.

    Raises
    ------
    TypeError
        If ``other`` is a number or any other non-Expr object.
    """
    if isinstance(other, Expr):
        return _op_make.divide(self, other)
    # Numbers get a dedicated hint instead of the generic message.
    if isinstance(other, _Number):
        raise TypeError('convert "%s" with `const` first' % str(other))
    raise TypeError("type %s not supported" % str(type(other)))
def __rdiv__(self, other):
if isinstance(other, _Number):
raise TypeError('convert "%s" with `const` first' % str(other))
else:
raise TypeError("type %s not supported" % str(type(other)))
def __truediv__(self, other):
# True division forwards to __div__, so both division hooks share one
# implementation and one set of error messages.
return self.__div__(other)
def __rtruediv__(self, other):
# Reflected true division forwards to __rdiv__, which only raises TypeError.
return self.__rdiv__(other)
@register_relay_node
class Constant(Expr):
......@@ -305,7 +364,7 @@ class TupleWrapper(object):
def __repr__(self):
return ("TupleWrapper(" + self.tuple_value.__repr__() +
", " + self.size + ")")
", " + str(self.size) + ")")
def astype(self, _):
# Always raises: the argument is ignored and casting is refused outright.
raise TypeError("astype cannot be used on tuple")
......
......@@ -160,6 +160,21 @@ def free_type_vars(expr):
"""
return _ir_pass.free_type_vars(expr)
def simplify_inference(expr):
    """Simplify the data-flow graph for the inference phase.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression.

    Returns
    -------
    result : tvm.relay.Expr
        An expression which is semantically equal to the input expression,
        but with some simplification.
    """
    simplified = _ir_pass.simplify_inference(expr)
    return simplified
def dead_code_elimination(expr):
""" Remove expressions which do not affect the program result (dead code).
......
......@@ -15,3 +15,11 @@ from . import vision
from . import _tensor
from ..expr import Expr
from ..base import register_relay_node
def _register_op_make():
    """Install this package's ``_make`` module as ``expr._op_make``.

    Both modules are imported inside the function body and the attribute
    is assigned on the ``expr`` module object.
    """
    from . import _make as make_module
    from .. import expr as expr_module
    expr_module._op_make = make_module


# Run the registration once, when this module is imported.
_register_op_make()
......@@ -120,6 +120,40 @@ inline bool IsDepthwiseConv2D(const Call& call,
}
/*!
 * \brief Create a Constant that holds a single scalar value.
 *
 * The size of T in bits must equal dtype.bits(); otherwise the check
 * fails with "data type mismatch".
 *
 * \param dtype The data type.
 * \param value The value of the scalar.
 * \return A Constant.
 */
template<typename T>
inline Constant MakeConstantScalar(DataType dtype, T value) {
  CHECK_EQ(sizeof(T) * 8, dtype.bits()) << "data type mismatch";
  // Empty shape list; the array is created with context {kDLCPU, 0}.
  runtime::NDArray scalar = runtime::NDArray::Empty({}, Type2TVMType(dtype), {kDLCPU, 0});
  T* slot = static_cast<T*>(scalar->data);
  *slot = value;
  return ConstantNode::make(scalar);
}
/*! \brief Build a call to the operator registered as "negative" on x. */
inline Expr Negative(Expr x) {
  // Function-local static: the operator lookup runs once.
  static const Op& negative_op = Op::Get("negative");
  return CallNode::make(negative_op, {x}, Attrs(), {});
}
/*! \brief Build a call to the operator registered as "sqrt" on x. */
inline Expr Sqrt(Expr x) {
  // Function-local static: the operator lookup runs once.
  static const Op& sqrt_op = Op::Get("sqrt");
  return CallNode::make(sqrt_op, {x}, Attrs(), {});
}
/*! \brief Build a call to the operator registered as "add" on (lhs, rhs). */
inline Expr Add(Expr lhs, Expr rhs) {
  // Function-local static: the operator lookup runs once.
  static const Op& add_op = Op::Get("add");
  return CallNode::make(add_op, {lhs, rhs}, Attrs(), {});
}
inline Expr Multiply(Expr lhs, Expr rhs) {
static const Op& op = Op::Get("multiply");
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
......
/*!
* Copyright (c) 2018 by Contributors
* \file simplify_inference.cc
*/
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
#include "./pattern_util.h"
namespace tvm {
namespace relay {
/*!
 * \brief Expand an inference-time batch_norm into explicit arithmetic.
 *
 * Builds data * scale + shift, where
 *   scale = 1 / sqrt(moving_var + epsilon)   (times gamma if attrs->scale)
 *   shift = -moving_mean * scale             (plus beta if attrs->center)
 * and both are expanded along attrs->axis to the rank of data.
 *
 * \param attrs The call attributes, read as BatchNormAttrs.
 * \param data The input tensor expression.
 * \param gamma The scale parameter.
 * \param beta The offset parameter.
 * \param moving_mean The running mean.
 * \param moving_var The running variance.
 * \return The expanded expression.
 */
Expr BatchNormToInferUnpack(const Attrs attrs,
                            Expr data,
                            Expr gamma,
                            Expr beta,
                            Expr moving_mean,
                            Expr moving_var) {
  const auto param = attrs.as<BatchNormAttrs>();

  // Both constants are created as float32 scalars.
  Expr eps = MakeConstantScalar(Float(32), static_cast<float>(param->epsilon));
  Expr scale = Divide(MakeConstantScalar(Float(32), 1.0f), Sqrt(Add(moving_var, eps)));
  if (param->scale) {
    scale = Multiply(scale, gamma);
  }

  Expr shift = Multiply(Negative(moving_mean), scale);
  if (param->center) {
    shift = Add(shift, beta);
  }

  // The rank comes from the tensor type attached to data.
  // NOTE(review): presumably requires type inference to have run - confirm.
  int axis = param->axis;
  const auto* data_type = data->type_as<TensorTypeNode>();
  scale = ExpandBiasToMatchAxis(scale, data_type->shape.size(), {axis});
  shift = ExpandBiasToMatchAxis(shift, data_type->shape.size(), {axis});

  return Add(Multiply(data, scale), shift);
}
/*!
 * \brief Mutator that rewrites output 0 of nn.batch_norm and nn.dropout.
 *
 * Only TupleGetItem nodes with index 0 are rewritten:
 *  - over a nn.batch_norm call, the node becomes the expression built by
 *    BatchNormToInferUnpack from the call's first five arguments;
 *  - over a nn.dropout call, the node becomes the call's first argument.
 * Every other node is returned as the base ExprMutator produced it.
 */
class InferenceSimplifier : public ExprMutator {
 public:
  Expr VisitExpr_(const TupleGetItemNode* n) final {
    static const Op& batch_norm = Op::Get("nn.batch_norm");
    static const Op& dropout = Op::Get("nn.dropout");

    // Let the base mutator rewrite the children first.
    Expr mutated = ExprMutator::VisitExpr_(n);
    const auto* item = mutated.as<TupleGetItemNode>();
    if (item->index != 0) {
      return mutated;
    }

    const auto* call = item->tuple.as<CallNode>();
    if (!call) {
      return mutated;
    }

    if (call->op.same_as(batch_norm)) {
      return BatchNormToInferUnpack(call->attrs,
          call->args[0], call->args[1], call->args[2], call->args[3], call->args[4]);
    }
    if (call->op.same_as(dropout)) {
      return call->args[0];
    }
    return mutated;
  }
};
/*!
 * \brief Run InferenceSimplifier over an expression.
 * \param e The expression to rewrite.
 * \return The result of mutating e with a fresh InferenceSimplifier.
 */
Expr SimplifyInference(const Expr& e) {
  InferenceSimplifier simplifier;
  return simplifier.Mutate(e);
}
// Registers SimplifyInference under the name "relay._ir_pass.simplify_inference".
// The body passes the first packed argument to the pass and stores the
// result in the return slot.
TVM_REGISTER_API("relay._ir_pass.simplify_inference")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = SimplifyInference(args[0]);
});
} // namespace relay
} // namespace tvm
from tvm import relay as rly
from tvm.relay.ir_pass import simplify_inference, alpha_equal
# Compares the output of simplify_inference on batch_norm + dropout graphs
# with a reference expression built by hand in simple_bn.
def test_simplify_batchnorm():
# Reference implementation: builds the scale/shift form of batch norm
# directly from Relay operators.
# NOTE(review): shape defaults to None but len(shape) is taken below, so
# callers must always pass it.
def simple_bn(x, gamma, beta, moving_mean, moving_var,
axis=1, epsilon=1e-5, shape=None):
# expect = (x - moving_mean) / sqrt(moving_var + eps) * gamma + beta
scale = rly.multiply(rly.const(1, 'float32') /
rly.sqrt(moving_var + rly.const(epsilon, 'float32')), gamma)
shift = rly.add(
rly.multiply(rly.negative(moving_mean), scale), beta)
# Number of dimensions that follow `axis` in the input shape.
num_newaxis = len(shape) - (axis + 1)
if num_newaxis:
scale = rly.expand_dims(scale, axis=1, num_newaxis=num_newaxis)
shift = rly.expand_dims(shift, axis=1, num_newaxis=num_newaxis)
return x * scale + shift
# Builds both graphs for a `dim`-dimensional input with every extent 10,
# normalising along `axis`, over `nstep` steps, then compares them.
def check(dim, axis, nstep):
eps = 0.01
ttype1 = rly.TensorType(tuple(10 for i in range(dim)), 'float32')
ttype2 = rly.TensorType((10,), 'float32')
x = rly.var("x", ttype1)
beta = rly.var("beta", ttype2)
gamma = rly.var("gamma", ttype2)
moving_var = rly.var("moving_var", ttype2)
moving_mean = rly.var("moving_mean", ttype2)
# y1 is the graph under test; y2 is the hand-built reference.
y1, y2 = x, x
for _ in range(nstep):
# Only the first of the three batch_norm outputs is kept.
y1, _, _ = rly.nn.batch_norm(y1 + rly.const(1, 'float32'),
gamma, beta, moving_mean, moving_var, epsilon=eps, axis=axis)
y1 = rly.nn.dropout(y1)
# infer_type is run on the graph before it is simplified.
y1 = rly.ir_pass.infer_type(y1)
y1 = simplify_inference(y1)
y2 = simple_bn(y2 + rly.const(1, 'float32'),
gamma, beta, moving_mean, moving_var,
epsilon=eps, axis=axis, shape=ttype1.shape)
assert rly.ir_pass.graph_equal(y1, y2)
# Arguments are (dim, axis, nstep).
check(2, 1, 1)
check(4, 1, 1)
check(4, 0, 3)
if __name__ == "__main__":
test_simplify_batchnorm()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment