Commit 584a32ae by Balint Cristian, committed by Wuwei Lin

[Relay] Handle float16 constants & fix BatchNorm (#3260)

parent c8a0f524
@@ -27,6 +27,7 @@
 #ifndef TVM_RELAY_PASS_PATTERN_UTIL_H_
 #define TVM_RELAY_PASS_PATTERN_UTIL_H_
+#include <builtin_fp16.h>
 #include <tvm/data_layout.h>
 #include <tvm/relay/op.h>
 #include <tvm/relay/expr.h>
@@ -49,6 +50,9 @@ namespace relay {
   } else if (type == Float(32)) {   \
     typedef float DType;            \
     {__VA_ARGS__}                   \
+  } else if (type == Float(16)) {   \
+    typedef uint16_t DType;         \
+    {__VA_ARGS__}                   \
   } else if (type == Int(64)) {     \
     typedef int64_t DType;          \
     {__VA_ARGS__}                   \
@@ -204,7 +208,14 @@ template<typename T>
 inline Constant MakeConstantScalar(DataType dtype, T value) {
   runtime::NDArray arr = runtime::NDArray::Empty({}, Type2TVMType(dtype), {kDLCPU, 0});
   TVM_DTYPE_DISPATCH(dtype, DType, {
-    *static_cast<DType*>(arr->data) = value;
+    if (dtype == Float(16)) {
+      // convert to float16
+      // storage is uint16_t
+      *static_cast<DType*>(arr->data) =
+        __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(static_cast<float>(value));
+    } else {
+      *static_cast<DType*>(arr->data) = value;
+    }
   })
   return ConstantNode::make(arr);
 }
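For context (not part of the patch): __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10> from builtin_fp16.h converts a float32 value to its IEEE half-precision bit pattern, which is why the new Float(16) branch of TVM_DTYPE_DISPATCH types DType as uint16_t. A minimal Python sketch of an equivalent check (illustration only, assuming numpy; the helper name is not a TVM API):

import numpy as np

def float_to_fp16_bits(value):
    # Convert a float32 scalar to float16 and return the raw bit pattern,
    # i.e. the uint16_t payload MakeConstantScalar stores for Float(16).
    return int(np.array([value], dtype=np.float32)
                 .astype(np.float16)
                 .view(np.uint16)[0])

assert float_to_fp16_bits(1.0) == 0x3C00    # 1.0 in IEEE binary16
assert float_to_fp16_bits(-2.0) == 0xC000   # -2.0 in IEEE binary16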
@@ -36,11 +36,13 @@ Expr BatchNormToInferUnpack(const Attrs attrs,
                             Expr moving_mean,
                             Expr moving_var,
                             Type tdata) {
+  auto ttype = tdata.as<TensorTypeNode>();
+  CHECK(ttype);
   const auto param = attrs.as<BatchNormAttrs>();
-  Expr epsilon = MakeConstantScalar(Float(32), static_cast<float>(param->epsilon));
+  Expr epsilon = MakeConstantScalar(ttype->dtype, static_cast<float>(param->epsilon));
   Expr var_add_eps = Add(moving_var, epsilon);
   Expr sqrt_var = Sqrt(var_add_eps);
-  Expr scale = Divide(MakeConstantScalar(Float(32), 1.0f), sqrt_var);
+  Expr scale = Divide(MakeConstantScalar(ttype->dtype, 1.0f), sqrt_var);
   if (param->scale) {
     scale = Multiply(scale, gamma);
@@ -52,8 +54,6 @@ Expr BatchNormToInferUnpack(const Attrs attrs,
   }
   int axis = param->axis;
-  auto ttype = tdata.as<TensorTypeNode>();
-  CHECK(ttype);
   auto ndim = ttype->shape.size();
   scale = ExpandBiasToMatchAxis(scale, ndim, {axis});
   shift = ExpandBiasToMatchAxis(shift, ndim, {axis});
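BatchNormToInferUnpack folds batch_norm at inference time into a per-channel scale and shift; the change above hoists the ttype lookup ahead of its first use and creates the epsilon and 1.0 constants in ttype->dtype instead of a hard-coded Float(32), so the rewrite also type-checks for float16 inputs. A rough numpy sketch of the algebra being produced (illustration only; the function name is not a TVM API):

import numpy as np

def batch_norm_inference(x, gamma, beta, mean, var, eps=1e-5, axis=1):
    # Inference-time batch_norm folded into x * scale + shift,
    # the same form simple_bn() builds in the test below.
    scale = gamma / np.sqrt(var + eps)
    shift = beta - mean * scale
    bshape = [1] * x.ndim
    bshape[axis] = -1          # broadcast per-channel vectors along `axis`
    return x * scale.reshape(bshape) + shift.reshape(bshape)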
@@ -17,12 +17,12 @@
 from tvm import relay as rly
 from tvm.relay.ir_pass import simplify_inference, alpha_equal
-def test_simplify_batchnorm():
+def test_simplify_batchnorm(dtype='float32'):
     def simple_bn(x, gamma, beta, moving_mean, moving_var,
                   axis=1, epsilon=1e-5, shape=None):
         # expect = (x - moving_mean) / sqrt(moving_var + eps) * gamma + beta
-        scale = rly.multiply(rly.const(1, 'float32') /
-                             rly.sqrt(moving_var + rly.const(epsilon, 'float32')), gamma)
+        scale = rly.multiply(rly.const(1, dtype) /
+                             rly.sqrt(moving_var + rly.const(epsilon, dtype)), gamma)
         shift = rly.add(
             rly.multiply(rly.negative(moving_mean), scale), beta)
         num_newaxis = len(shape) - (axis + 1)
@@ -33,8 +33,8 @@ def test_simplify_batchnorm():
     def check(dim, axis, nstep):
         eps = 0.01
-        ttype1 = rly.TensorType(tuple(10 for i in range(dim)), 'float32')
-        ttype2 = rly.TensorType((10,), 'float32')
+        ttype1 = rly.TensorType(tuple(10 for i in range(dim)), dtype)
+        ttype2 = rly.TensorType((10,), dtype)
         x = rly.var("x", ttype1)
         beta = rly.var("beta", ttype2)
         gamma = rly.var("gamma", ttype2)
@@ -43,10 +43,10 @@ def test_simplify_batchnorm():
         y1, y2 = x, x
         for _ in range(nstep):
-            y1, _, _ = rly.nn.batch_norm(y1 + rly.const(1, 'float32'),
+            y1, _, _ = rly.nn.batch_norm(y1 + rly.const(1, dtype),
                 gamma, beta, moving_mean, moving_var, epsilon=eps, axis=axis)
             y1 = rly.nn.dropout(y1)
-            y2 = simple_bn(y2 + rly.const(1, 'float32'),
+            y2 = simple_bn(y2 + rly.const(1, dtype),
                            gamma, beta, moving_mean, moving_var,
                            epsilon=eps, axis=axis, shape=ttype1.shape)
         y1 = rly.ir_pass.infer_type(y1)
@@ -60,4 +60,5 @@ def test_simplify_batchnorm():
 if __name__ == "__main__":
-    test_simplify_batchnorm()
+    test_simplify_batchnorm(dtype='float32')
+    test_simplify_batchnorm(dtype='float16')
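As a quick interactive check (not part of the patch; assumes a TVM build containing this commit), the Python API can materialize float16 scalar constants directly, and with this fix the constants that simplify_inference introduces (epsilon, 1.0) are created in the same dtype:

from tvm import relay as rly

c = rly.const(1, 'float16')
print(c.data.dtype)   # float16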