Commit 584a32ae by Balint Cristian Committed by Wuwei Lin

[Relay] Handle float16 constants & fix BatchNorm (#3260)

parent c8a0f524
......@@ -27,6 +27,7 @@
#ifndef TVM_RELAY_PASS_PATTERN_UTIL_H_
#define TVM_RELAY_PASS_PATTERN_UTIL_H_
#include <builtin_fp16.h>
#include <tvm/data_layout.h>
#include <tvm/relay/op.h>
#include <tvm/relay/expr.h>
......@@ -49,6 +50,9 @@ namespace relay {
} else if (type == Float(32)) { \
typedef float DType; \
{__VA_ARGS__} \
} else if (type == Float(16)) { \
typedef uint16_t DType; \
{__VA_ARGS__} \
} else if (type == Int(64)) { \
typedef int64_t DType; \
{__VA_ARGS__} \
......@@ -204,7 +208,14 @@ template<typename T>
/*!
 * \brief Build a Relay Constant wrapping a 0-d (scalar) NDArray of the given dtype.
 *
 * The scalar is allocated on the CPU and the C++ value is written into the
 * array's storage using the storage type selected by TVM_DTYPE_DISPATCH.
 *
 * \param dtype The Relay/TVM data type of the resulting constant.
 * \param value The scalar value to store (converted to the target dtype).
 * \return A ConstantNode holding the freshly written scalar NDArray.
 */
inline Constant MakeConstantScalar(DataType dtype, T value) {
// 0-d shape ({}) => scalar; kDLCPU device so we can write arr->data directly.
runtime::NDArray arr = runtime::NDArray::Empty({}, Type2TVMType(dtype), {kDLCPU, 0});
TVM_DTYPE_DISPATCH(dtype, DType, {
if (dtype == Float(16)) {
// convert to float16
// storage is uint16_t
// float16 has no native C++ type here: the dispatch macro maps it to a
// uint16_t storage type, so a plain assignment would store the integer
// bit-pattern of the (truncated) value, not an IEEE half. Instead round
// the value through float and re-encode it with the compiler-rt helper:
// source = float (uint32_t repr, 23-bit mantissa), destination = half
// (uint16_t repr, 10-bit mantissa).
*static_cast<DType*>(arr->data) =
__truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(static_cast<float>(value));
} else {
// All other dispatched dtypes (float/double/int*/uint*) have a matching
// native DType, so a direct converting assignment is correct.
*static_cast<DType*>(arr->data) = value;
}
})
return ConstantNode::make(arr);
}
......
......@@ -36,11 +36,13 @@ Expr BatchNormToInferUnpack(const Attrs attrs,
Expr moving_mean,
Expr moving_var,
Type tdata) {
auto ttype = tdata.as<TensorTypeNode>();
CHECK(ttype);
const auto param = attrs.as<BatchNormAttrs>();
Expr epsilon = MakeConstantScalar(Float(32), static_cast<float>(param->epsilon));
Expr epsilon = MakeConstantScalar(ttype->dtype, static_cast<float>(param->epsilon));
Expr var_add_eps = Add(moving_var, epsilon);
Expr sqrt_var = Sqrt(var_add_eps);
Expr scale = Divide(MakeConstantScalar(Float(32), 1.0f), sqrt_var);
Expr scale = Divide(MakeConstantScalar(ttype->dtype, 1.0f), sqrt_var);
if (param->scale) {
scale = Multiply(scale, gamma);
......@@ -52,8 +54,6 @@ Expr BatchNormToInferUnpack(const Attrs attrs,
}
int axis = param->axis;
auto ttype = tdata.as<TensorTypeNode>();
CHECK(ttype);
auto ndim = ttype->shape.size();
scale = ExpandBiasToMatchAxis(scale, ndim, {axis});
shift = ExpandBiasToMatchAxis(shift, ndim, {axis});
......
......@@ -17,12 +17,12 @@
from tvm import relay as rly
from tvm.relay.ir_pass import simplify_inference, alpha_equal
def test_simplify_batchnorm():
def test_simplify_batchnorm(dtype='float32'):
def simple_bn(x, gamma, beta, moving_mean, moving_var,
axis=1, epsilon=1e-5, shape=None):
# expect = (x - moving_mean) / sqrt(moving_var + eps) * gamma + beta
scale = rly.multiply(rly.const(1, 'float32') /
rly.sqrt(moving_var + rly.const(epsilon, 'float32')), gamma)
scale = rly.multiply(rly.const(1, dtype) /
rly.sqrt(moving_var + rly.const(epsilon, dtype)), gamma)
shift = rly.add(
rly.multiply(rly.negative(moving_mean), scale), beta)
num_newaxis = len(shape) - (axis + 1)
......@@ -33,8 +33,8 @@ def test_simplify_batchnorm():
def check(dim, axis, nstep):
eps = 0.01
ttype1 = rly.TensorType(tuple(10 for i in range(dim)), 'float32')
ttype2 = rly.TensorType((10,), 'float32')
ttype1 = rly.TensorType(tuple(10 for i in range(dim)), dtype)
ttype2 = rly.TensorType((10,), dtype)
x = rly.var("x", ttype1)
beta = rly.var("beta", ttype2)
gamma = rly.var("gamma", ttype2)
......@@ -43,10 +43,10 @@ def test_simplify_batchnorm():
y1, y2 = x, x
for _ in range(nstep):
y1, _, _ = rly.nn.batch_norm(y1 + rly.const(1, 'float32'),
y1, _, _ = rly.nn.batch_norm(y1 + rly.const(1, dtype),
gamma, beta, moving_mean, moving_var, epsilon=eps, axis=axis)
y1 = rly.nn.dropout(y1)
y2 = simple_bn(y2 + rly.const(1, 'float32'),
y2 = simple_bn(y2 + rly.const(1, dtype),
gamma, beta, moving_mean, moving_var,
epsilon=eps, axis=axis, shape=ttype1.shape)
y1 = rly.ir_pass.infer_type(y1)
......@@ -60,4 +60,5 @@ def test_simplify_batchnorm():
if __name__ == "__main__":
test_simplify_batchnorm()
test_simplify_batchnorm(dtype='float32')
test_simplify_batchnorm(dtype='float16')
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment