Commit d08ec106 by Haichen Shen, committed by Wuwei Lin

[Fix] Fix a few bugs when dtype is fp16 (#4088)

* Fix layer norm for fp16

* [Fix] Fix arange for fp16

* [Fix] Fix mxnet frontend for fp16

* [Fix] Fix arange for fp16

* remove comments

* x

* fix nnvm
parent 9d5cba20
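
Context for the fixes below: Relay turns a bare Python float into a float32 constant, so frontend converters that do not pass an explicit dtype end up mixing float32 constants into float16 graphs, which fails type checking. A minimal sketch of the mismatch (not part of this commit; assumes a recent tvm/relay Python API, so exact module names may differ from the version this commit targets):

import tvm
from tvm import relay

x = relay.var("x", shape=(4,), dtype="float16")

eps32 = relay.const(1e-5)                   # a bare float defaults to float32
eps16 = relay.const(1e-5, dtype="float16")  # the dtype has to be threaded through for fp16

mod = tvm.IRModule.from_expr(relay.Function([x], x + eps16))
relay.transform.InferType()(mod)            # type-checks: float16 + float16

mod = tvm.IRModule.from_expr(relay.Function([x], x + eps32))
relay.transform.InferType()(mod)            # fails: float16 vs float32 dtype mismatch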
@@ -615,12 +615,17 @@ def _mx_arange(inputs, attrs):
     if attrs.get_int("repeat", 1) != 1:
         raise tvm.error.OpAttributeUnimplemented(
             'Attribute "repeat" is not supported in operator arange.')
-    new_attrs = {}
-    new_attrs["start"] = _expr.const(attrs.get_float("start", 0.0))
+    dtype = attrs.get_str("dtype", "float32")
     stop = attrs.get_str("stop", "None")
-    new_attrs["stop"] = None if stop == "None" else _expr.const(float(stop))
-    new_attrs["step"] = _expr.const(attrs.get_float("step", 1.0))
-    new_attrs["dtype"] = attrs.get_str("dtype", "float32")
+    if stop == "None":
+        stop = None
+    else:
+        stop = _expr.const(float(stop), dtype=dtype)
+    new_attrs = {}
+    new_attrs["start"] = _expr.const(attrs.get_float("start", 0.0), dtype=dtype)
+    new_attrs["stop"] = stop
+    new_attrs["step"] = _expr.const(attrs.get_float("step", 1.0), dtype=dtype)
+    new_attrs["dtype"] = dtype
     return _op.arange(**new_attrs)
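
For illustration, what the fixed converter effectively emits for an fp16 arange; a hedged sketch using the public relay API rather than code from this diff:

import tvm
from tvm import relay

dtype = "float16"
out = relay.arange(
    relay.const(0.0, dtype),   # start
    relay.const(8.0, dtype),   # stop
    relay.const(1.0, dtype),   # step
    dtype=dtype,
)
mod = relay.transform.InferType()(tvm.IRModule.from_expr(relay.Function([], out)))
print(mod["main"].ret_type)    # expected: Tensor[(8,), float16]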
@@ -863,7 +868,8 @@ def _mx_contrib_div_sqrt_dim(inputs, _):
     assert len(inputs) == 1
     ndim = len(_infer_type(inputs[0]).checked_type.shape)
     dim = _op.take(_op.shape_of(inputs[0]), _expr.const(ndim-1, dtype="int32"))
-    sqrt_dim = _op.sqrt(dim.astype('float32'))
+    dtype = _infer_type(inputs[0]).checked_type.dtype
+    sqrt_dim = _op.sqrt(dim.astype(dtype))
     out = inputs[0] / sqrt_dim
     return out
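
A NumPy sketch of the operator's semantics (not this file's code): _contrib_div_sqrt_dim divides by the square root of the last-axis size, and the fix keeps that scale factor in the input's own dtype:

import numpy as np

def div_sqrt_dim(x):
    # divide by sqrt(size of the last axis), staying in x's dtype
    d = np.asarray(x.shape[-1], dtype=x.dtype)
    return x / np.sqrt(d)

y = div_sqrt_dim(np.ones((2, 64), dtype="float16"))
assert y.dtype == np.float16   # no silent upcast to float32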
@@ -21,6 +21,7 @@ from __future__ import absolute_import as _abs
 from .. import expr as _expr
 from .. import op as _op
 from .common import get_relay_op
+from .common import infer_type as _infer_type
 
 def _warn_not_used(attr, op='nnvm'):
     import warnings
@@ -123,20 +124,22 @@ def _elemwise_sum(inputs, _, _dtype='float32'):
 def _binop_scalar(new_op):
-    def _impl(inputs, attrs, odtype='float32'):
+    def _impl(inputs, attrs, odtype=None):
         assert len(inputs) == 1
         scalar = attrs.get_float("scalar")
-        # Note: binary scalar only works for float op for now
+        if odtype is None:
+            odtype = _infer_type(inputs[0]).checked_type.dtype
         scalar = _expr.const(scalar, dtype=odtype)
         return new_op(inputs[0], scalar)
     return _impl
 
 def _rbinop_scalar(new_op):
-    def _impl(inputs, attrs, odtype='float32'):
+    def _impl(inputs, attrs, odtype=None):
         assert len(inputs) == 1
         scalar = attrs.get_float("scalar")
-        # Note: binary scalar only works for float op for now
+        if odtype is None:
+            odtype = _infer_type(inputs[0]).checked_type.dtype
         scalar = _expr.const(scalar, dtype=odtype)
         return new_op(scalar, inputs[0])
     return _impl
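
The pattern both converters now use: when no output dtype is given, read it off the operand via type inference so the scalar constant matches fp16 inputs. A standalone sketch (infer_type below is the same internal frontend helper imported in these diffs; its module path may vary across TVM versions):

from tvm import relay
from tvm.relay.frontend.common import infer_type

x = relay.var("x", shape=(8,), dtype="float16")
odtype = infer_type(x).checked_type.dtype       # "float16"
y = x + relay.const(2.0, dtype=odtype)          # scalar constant matches the operand dtype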
@@ -28,6 +28,7 @@
 #include <tvm/expr_operator.h>
 #include <tvm/ir.h>
 #include <tvm/data_layout.h>
+#include <tvm/runtime/packed_func.h>
 #include <topi/transform.h>
 #include <topi/elemwise.h>
 #include <topi/broadcast.h>
@@ -1139,11 +1140,41 @@ and type as the input array.
 TVM_REGISTER_NODE_TYPE(ArangeAttrs);
 
 double ToScalar(const runtime::NDArray& array) {
-  if (array->dtype.code == kDLInt || array->dtype.code == kDLUInt) {
-    return reinterpret_cast<int32_t*>(array->data)[0];
-  } else {
-    return reinterpret_cast<float*>(array->data)[0];
+  if (array->dtype.code == kDLInt) {
+    if (array->dtype.bits == 8) {
+      return reinterpret_cast<int8_t*>(array->data)[0];
+    } else if (array->dtype.bits == 16) {
+      return reinterpret_cast<int16_t*>(array->data)[0];
+    } else if (array->dtype.bits == 32) {
+      return reinterpret_cast<int32_t*>(array->data)[0];
+    } else if (array->dtype.bits == 64) {
+      return reinterpret_cast<int64_t*>(array->data)[0];
+    }
+  } else if (array->dtype.code == kDLUInt) {
+    if (array->dtype.bits == 8) {
+      return reinterpret_cast<uint8_t*>(array->data)[0];
+    } else if (array->dtype.bits == 16) {
+      return reinterpret_cast<uint16_t*>(array->data)[0];
+    } else if (array->dtype.bits == 32) {
+      return reinterpret_cast<uint32_t*>(array->data)[0];
+    } else if (array->dtype.bits == 64) {
+      return reinterpret_cast<uint64_t*>(array->data)[0];
+    }
+  } else if (array->dtype.code == kDLFloat) {
+#if (__ARM_FP16_FORMAT_IEEE == 1)
+    if (array->dtype.bits == 16) {
+      return reinterpret_cast<__fp16*>(array->data)[0];
+    }
+#endif
+    if (array->dtype.bits == 32) {
+      return reinterpret_cast<float*>(array->data)[0];
+    } else if (array->dtype.bits == 64) {
+      return reinterpret_cast<double*>(array->data)[0];
+    }
   }
+  LOG(FATAL) << "Unknown data type: " << tvm::runtime::TVMType2String(array->dtype);
+  // make compiler happy
+  return -std::numeric_limits<double>::infinity();
 }
 
 bool ArangeRel(const Array<Type>& types,
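
A Python analog (not TVM code) of the widened ToScalar dispatch: pick the element type from the DLPack (code, bits) pair before reading the scalar, instead of assuming int32/float32 as the old code did:

import numpy as np

# kDLInt = 0, kDLUInt = 1, kDLFloat = 2 (DLPack type codes)
_DTYPE_BY_CODE_BITS = {
    (0, 8): np.int8,   (0, 16): np.int16,   (0, 32): np.int32,   (0, 64): np.int64,
    (1, 8): np.uint8,  (1, 16): np.uint16,  (1, 32): np.uint32,  (1, 64): np.uint64,
    (2, 16): np.float16, (2, 32): np.float32, (2, 64): np.float64,
}

def to_scalar(raw: bytes, code: int, bits: int) -> float:
    try:
        dtype = _DTYPE_BY_CODE_BITS[(code, bits)]
    except KeyError:
        raise ValueError("Unknown data type: code=%d, bits=%d" % (code, bits))
    return float(np.frombuffer(raw, dtype=dtype, count=1)[0])

# fp16 bytes are now read back as fp16 rather than reinterpreted as float32 bits
print(to_scalar(np.float16(2.5).tobytes(), code=2, bits=16))   # 2.5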
@@ -75,7 +75,7 @@ Expr LayerNormToInferUnpack(const Attrs attrs,
   const auto param = attrs.as<LayerNormAttrs>();
   CHECK(param);
 
-  Expr epsilon = MakeConstantScalar(Float(32), static_cast<float>(param->epsilon));
+  Expr epsilon = MakeConstantScalar(ttype->dtype, static_cast<float>(param->epsilon));
   Expr mean = Mean(data, {param->axis}, true, false);
   Expr var = Variance(data, mean, {param->axis}, true, false);
   Expr denom = Sqrt(Add(var, epsilon));
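
A NumPy sketch of the unpacked layer norm this pass produces, with epsilon created in the data's dtype (the one-line fix above):

import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5, axis=-1):
    eps = np.asarray(eps, dtype=x.dtype)          # epsilon in the data's dtype
    mean = x.mean(axis=axis, keepdims=True)
    var = x.var(axis=axis, keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

x = np.random.rand(2, 8).astype("float16")
y = layer_norm(x, np.ones(8, "float16"), np.zeros(8, "float16"))
assert y.dtype == np.float16                      # stays fp16 end to end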