Commit d08ec106 by Haichen Shen Committed by Wuwei Lin

[Fix] Fix a few bugs when dtype is fp16 (#4088)

* Fix layer norm for fp16

* [Fix] Fix arange for fp16

* [Fix] Fix mxnet frontend for fp16

* [Fix] Fix arange for fp16

* remove comments

* x

* fix nnvm
parent 9d5cba20
...@@ -615,12 +615,17 @@ def _mx_arange(inputs, attrs): ...@@ -615,12 +615,17 @@ def _mx_arange(inputs, attrs):
if attrs.get_int("repeat", 1) != 1: if attrs.get_int("repeat", 1) != 1:
raise tvm.error.OpAttributeUnimplemented( raise tvm.error.OpAttributeUnimplemented(
'Attribute "repeat" is not supported in operator arange.') 'Attribute "repeat" is not supported in operator arange.')
new_attrs = {} dtype = attrs.get_str("dtype", "float32")
new_attrs["start"] = _expr.const(attrs.get_float("start", 0.0))
stop = attrs.get_str("stop", "None") stop = attrs.get_str("stop", "None")
new_attrs["stop"] = None if stop == "None" else _expr.const(float(stop)) if stop == "None":
new_attrs["step"] = _expr.const(attrs.get_float("step", 1.0)) stop = None
new_attrs["dtype"] = attrs.get_str("dtype", "float32") else:
stop = _expr.const(float(stop), dtype=dtype)
new_attrs = {}
new_attrs["start"] = _expr.const(attrs.get_float("start", 0.0), dtype=dtype)
new_attrs["stop"] = stop
new_attrs["step"] = _expr.const(attrs.get_float("step", 1.0), dtype=dtype)
new_attrs["dtype"] = dtype
return _op.arange(**new_attrs) return _op.arange(**new_attrs)
...@@ -863,7 +868,8 @@ def _mx_contrib_div_sqrt_dim(inputs, _): ...@@ -863,7 +868,8 @@ def _mx_contrib_div_sqrt_dim(inputs, _):
assert len(inputs) == 1 assert len(inputs) == 1
ndim = len(_infer_type(inputs[0]).checked_type.shape) ndim = len(_infer_type(inputs[0]).checked_type.shape)
dim = _op.take(_op.shape_of(inputs[0]), _expr.const(ndim-1, dtype="int32")) dim = _op.take(_op.shape_of(inputs[0]), _expr.const(ndim-1, dtype="int32"))
sqrt_dim = _op.sqrt(dim.astype('float32')) dtype = _infer_type(inputs[0]).checked_type.dtype
sqrt_dim = _op.sqrt(dim.astype(dtype))
out = inputs[0] / sqrt_dim out = inputs[0] / sqrt_dim
return out return out
......
...@@ -21,6 +21,7 @@ from __future__ import absolute_import as _abs ...@@ -21,6 +21,7 @@ from __future__ import absolute_import as _abs
from .. import expr as _expr from .. import expr as _expr
from .. import op as _op from .. import op as _op
from .common import get_relay_op from .common import get_relay_op
from .common import infer_type as _infer_type
def _warn_not_used(attr, op='nnvm'): def _warn_not_used(attr, op='nnvm'):
import warnings import warnings
...@@ -123,20 +124,22 @@ def _elemwise_sum(inputs, _, _dtype='float32'): ...@@ -123,20 +124,22 @@ def _elemwise_sum(inputs, _, _dtype='float32'):
def _binop_scalar(new_op): def _binop_scalar(new_op):
def _impl(inputs, attrs, odtype='float32'): def _impl(inputs, attrs, odtype=None):
assert len(inputs) == 1 assert len(inputs) == 1
scalar = attrs.get_float("scalar") scalar = attrs.get_float("scalar")
# Note: binary scalar only works for float op for now if odtype is None:
odtype = _infer_type(inputs[0]).checked_type.dtype
scalar = _expr.const(scalar, dtype=odtype) scalar = _expr.const(scalar, dtype=odtype)
return new_op(inputs[0], scalar) return new_op(inputs[0], scalar)
return _impl return _impl
def _rbinop_scalar(new_op): def _rbinop_scalar(new_op):
def _impl(inputs, attrs, odtype='float32'): def _impl(inputs, attrs, odtype=None):
assert len(inputs) == 1 assert len(inputs) == 1
scalar = attrs.get_float("scalar") scalar = attrs.get_float("scalar")
# Note: binary scalar only works for float op for now if odtype is None:
odtype = _infer_type(inputs[0]).checked_type.dtype
scalar = _expr.const(scalar, dtype=odtype) scalar = _expr.const(scalar, dtype=odtype)
return new_op(scalar, inputs[0]) return new_op(scalar, inputs[0])
return _impl return _impl
......
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include <tvm/expr_operator.h> #include <tvm/expr_operator.h>
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/data_layout.h> #include <tvm/data_layout.h>
#include <tvm/runtime/packed_func.h>
#include <topi/transform.h> #include <topi/transform.h>
#include <topi/elemwise.h> #include <topi/elemwise.h>
#include <topi/broadcast.h> #include <topi/broadcast.h>
...@@ -1139,11 +1140,41 @@ and type as the input array. ...@@ -1139,11 +1140,41 @@ and type as the input array.
TVM_REGISTER_NODE_TYPE(ArangeAttrs); TVM_REGISTER_NODE_TYPE(ArangeAttrs);
double ToScalar(const runtime::NDArray& array) { double ToScalar(const runtime::NDArray& array) {
if (array->dtype.code == kDLInt || array->dtype.code == kDLUInt) { if (array->dtype.code == kDLInt) {
return reinterpret_cast<int32_t*>(array->data)[0]; if (array->dtype.bits == 8) {
} else { return reinterpret_cast<int8_t*>(array->data)[0];
return reinterpret_cast<float*>(array->data)[0]; } else if (array->dtype.bits == 16) {
return reinterpret_cast<int16_t*>(array->data)[0];
} else if (array->dtype.bits == 32) {
return reinterpret_cast<int32_t*>(array->data)[0];
} else if (array->dtype.bits == 64) {
return reinterpret_cast<int64_t*>(array->data)[0];
}
} else if (array->dtype.code == kDLUInt) {
if (array->dtype.bits == 8) {
return reinterpret_cast<uint8_t*>(array->data)[0];
} else if (array->dtype.bits == 16) {
return reinterpret_cast<uint16_t*>(array->data)[0];
} else if (array->dtype.bits == 32) {
return reinterpret_cast<uint32_t*>(array->data)[0];
} else if (array->dtype.bits == 64) {
return reinterpret_cast<uint64_t*>(array->data)[0];
}
} else if (array->dtype.code == kDLFloat) {
#if (__ARM_FP16_FORMAT_IEEE == 1)
if (array->dtype.bits == 16) {
return reinterpret_cast<__fp16*>(array->data)[0];
}
#endif
if (array->dtype.bits == 32) {
return reinterpret_cast<float*>(array->data)[0];
} else if (array->dtype.bits == 64) {
return reinterpret_cast<double*>(array->data)[0];
}
} }
LOG(FATAL) << "Unknown data type: " << tvm::runtime::TVMType2String(array->dtype);
// make compiler happy
return -std::numeric_limits<double>::infinity();
} }
bool ArangeRel(const Array<Type>& types, bool ArangeRel(const Array<Type>& types,
......
...@@ -75,7 +75,7 @@ Expr LayerNormToInferUnpack(const Attrs attrs, ...@@ -75,7 +75,7 @@ Expr LayerNormToInferUnpack(const Attrs attrs,
const auto param = attrs.as<LayerNormAttrs>(); const auto param = attrs.as<LayerNormAttrs>();
CHECK(param); CHECK(param);
Expr epsilon = MakeConstantScalar(Float(32), static_cast<float>(param->epsilon)); Expr epsilon = MakeConstantScalar(ttype->dtype, static_cast<float>(param->epsilon));
Expr mean = Mean(data, {param->axis}, true, false); Expr mean = Mean(data, {param->axis}, true, false);
Expr var = Variance(data, mean, {param->axis}, true, false); Expr var = Variance(data, mean, {param->axis}, true, false);
Expr denom = Sqrt(Add(var, epsilon)); Expr denom = Sqrt(Add(var, epsilon));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment