Commit 63a91ebf by Xingjian Shi, committed by Yizhi Liu

Numpy compatible dtype inference for `tvm.convert` and `tvm.const` (#3861)

* numpy compatible type inference

* update

* try to fix

* fix

* try to fix

* fix lint

* Update nn.h

* cast to int32

* try to fix

* fix again

* retrigger ci
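
A minimal usage sketch of the behavior this change enables (illustrative only, assuming a TVM build that includes this patch; the assertions mirror the new tests added below):

```python
import numpy as np
import tvm

# With an explicit dtype, behavior is unchanged.
x = tvm.const(1, "int32")

# With dtype=None, the dtype is inferred from the value:
# Python bool -> 'bool', Python int -> 'int32', Python float -> 'float32',
# and NumPy scalars keep their own dtype.
assert tvm.const(1).dtype == 'int32'
assert tvm.const(1.0).dtype == 'float32'
assert tvm.const(np.float16(1)).dtype == 'float16'
assert tvm.convert(np.int64(1)).dtype == 'int64'
```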
parent 2f5b155a
......@@ -30,6 +30,23 @@ def _set_class_node_base(cls):
_CLASS_NODE_BASE = cls
def _scalar_type_inference(value):
if hasattr(value, 'dtype'):
dtype = str(value.dtype)
elif isinstance(value, bool):
dtype = 'bool'
elif isinstance(value, float):
# We intentionally convert the float to float32 since it's more common in DL.
dtype = 'float32'
elif isinstance(value, int):
# We intentionally convert the python int to int32 since it's more common in DL.
dtype = 'int32'
else:
raise NotImplementedError('Cannot automatically infer the type.'
' value={}'.format(value))
return dtype
class NodeGeneric(object):
"""Base class for all classes that can be converted to node."""
def asnode(self):
......@@ -86,7 +103,7 @@ def const(value, dtype=None):
value : int or float
The input value
dtype : str
dtype : str or None, optional
The data type.
Returns
......@@ -95,8 +112,5 @@ def const(value, dtype=None):
Constant expression corresponds to the value.
"""
if dtype is None:
if isinstance(value, Integral):
dtype = 'int32'
else:
dtype = 'float32'
dtype = _scalar_type_inference(value)
return _api_internal._const(value, dtype)
......@@ -23,6 +23,7 @@ from numbers import Integral as _Integral
from ._ffi.base import string_types
from ._ffi.node import register_node, NodeBase
from ._ffi.node import convert_to_node as _convert_to_node
from ._ffi.node_generic import _scalar_type_inference
from ._ffi.function import Function
from ._ffi.function import _init_api, register_func, get_global_func, extract_ext_funcs
from ._ffi.function import convert_to_tvm_func as _convert_tvm_func
......@@ -73,7 +74,7 @@ def max_value(dtype):
return _api_internal._max_value(dtype)
def const(value, dtype):
def const(value, dtype=None):
"""construct a constant
Parameters
......@@ -81,7 +82,7 @@ def const(value, dtype):
value : number
The content of the constant number.
dtype : str
dtype : str or None, optional
The data type.
Returns
......@@ -89,6 +90,8 @@ def const(value, dtype):
const_val: tvm.Expr
The result expression.
"""
if dtype is None:
dtype = _scalar_type_inference(value)
return _api_internal._const(value, dtype)
......
......@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import tvm
import numpy as np
def test_const():
x = tvm.const(1, "int32")
......@@ -22,6 +23,22 @@ def test_const():
assert x.dtype == tvm.int32
assert isinstance(x, tvm.expr.IntImm)
def test_scalar_dtype_inference():
for data in [True, np.bool(1), np.uint8(1), np.uint16(1), np.uint32(1), np.uint64(1),
np.int8(1), np.int16(1), np.int32(1), np.int64(1),
np.float16(1), np.float32(1), np.float64(1)]:
assert tvm.const(data).dtype == str(np.array(data).dtype)
assert tvm.const(1).dtype == 'int32'
assert tvm.const(1.0).dtype == 'float32'
for data in [True, np.bool(1), np.uint8(1), np.uint16(1), np.uint32(1), np.uint64(1),
np.int8(1), np.int16(1), np.int32(1), np.int64(1),
np.float16(1), np.float32(1), np.float64(1)]:
assert tvm.convert(data).dtype == str(np.array(data).dtype)
assert tvm.convert(1).dtype == 'int32'
assert tvm.convert(1.0).dtype == 'float32'
def test_make():
x = tvm.const(1, "int32")
y = tvm.var("x")
......@@ -175,6 +192,7 @@ if __name__ == "__main__":
test_cast()
test_attr()
test_const()
test_scalar_dtype_inference()
test_make()
test_ir()
test_basic()
......
......@@ -97,8 +97,8 @@ inline Tensor resize_nearest_neighbor_nhwc(const Tensor& input,
std::string tag = kInjective) {
Array<Expr> out_shape;
out_shape.push_back(input->shape[0]);
out_shape.push_back(shape[0]);
out_shape.push_back(shape[1]);
out_shape.push_back(cast(Int(32), shape[0]));
out_shape.push_back(cast(Int(32), shape[1]));
out_shape.push_back(input->shape[3]);
return compute(
......@@ -132,8 +132,8 @@ inline Tensor resize_nearest_neighbor_nchw(const Tensor& input,
Array<Expr> out_shape;
out_shape.push_back(input->shape[0]);
out_shape.push_back(input->shape[1]);
out_shape.push_back(shape[0]);
out_shape.push_back(shape[1]);
out_shape.push_back(cast(Int(32), shape[0]));
out_shape.push_back(cast(Int(32), shape[1]));
return compute(
out_shape, [&](const Array<Var>& indices) {
......@@ -166,8 +166,8 @@ inline Tensor resize_nearest_neighbor_nchwc(const Tensor& input,
Array<Expr> out_shape;
out_shape.push_back(input->shape[0]);
out_shape.push_back(input->shape[1]);
out_shape.push_back(shape[0]);
out_shape.push_back(shape[1]);
out_shape.push_back(cast(Int(32), shape[0]));
out_shape.push_back(cast(Int(32), shape[1]));
out_shape.push_back(input->shape[4]);
return compute(
......@@ -233,8 +233,8 @@ inline Tensor resize_bilinear_nhwc(const Tensor& input,
std::string tag = kInjective) {
Array<Expr> out_shape;
out_shape.push_back(input->shape[0]);
out_shape.push_back(shape[0]);
out_shape.push_back(shape[1]);
out_shape.push_back(cast(Int(32), shape[0]));
out_shape.push_back(cast(Int(32), shape[1]));
out_shape.push_back(input->shape[3]);
Expr cone = make_const(Int(32), 1);
......@@ -311,8 +311,8 @@ inline Tensor resize_bilinear_nchw(const Tensor& input,
Array<Expr> out_shape;
out_shape.push_back(input->shape[0]);
out_shape.push_back(input->shape[1]);
out_shape.push_back(shape[0]);
out_shape.push_back(shape[1]);
out_shape.push_back(cast(Int(32), shape[0]));
out_shape.push_back(cast(Int(32), shape[1]));
Expr cone = make_const(Int(32), 1);
......
......@@ -182,12 +182,20 @@ inline tvm::Tensor pad(const tvm::Tensor& t,
CHECK_GE(pad_before.size(), 1);
CHECK_EQ(pad_before.size(), pad_after.size());
tvm::Array<tvm::Expr> output_shape;
tvm::Array<tvm::Expr> pad_before_int32;
tvm::Array<tvm::Expr> pad_after_int32;
for (const auto &ele : pad_before) {
pad_before_int32.push_back(tvm::cast(tvm::Int(32), ele));
}
for (const auto &ele : pad_after) {
pad_after_int32.push_back(tvm::cast(tvm::Int(32), ele));
}
for (size_t i = 0; i < t->shape.size(); ++i) {
if (i >= pad_before.size()) {
output_shape.push_back(t->shape[i]);
} else {
output_shape.push_back(
tvm::ir::Simplify(t->shape[i] + pad_before[i] + pad_after[i]));
tvm::ir::Simplify(t->shape[i] + pad_before_int32[i] + pad_after_int32[i]));
}
}
......@@ -199,18 +207,18 @@ inline tvm::Tensor pad(const tvm::Tensor& t,
tvm::Array<tvm::Expr> indices;
tvm::Array<tvm::Expr> sel;
for (size_t i = 0; i < t->shape.size(); ++i) {
if (i >= pad_before.size()) {
if (i >= pad_before_int32.size()) {
indices.push_back(ovars[i]);
continue;
}
if (!topi::detail::EqualCheck(pad_before[i], 0)) {
sel.push_back(ovars[i] >= pad_before[i]);
indices.push_back(ovars[i] - pad_before[i]);
if (!topi::detail::EqualCheck(pad_before_int32[i], 0)) {
sel.push_back(ovars[i] >= pad_before_int32[i]);
indices.push_back(ovars[i] - pad_before_int32[i]);
} else {
indices.push_back(ovars[i]);
}
if (!topi::detail::EqualCheck(pad_after[i], 0)) {
sel.push_back(tvm::ir::Simplify(ovars[i] < pad_before[i] + t->shape[i]));
if (!topi::detail::EqualCheck(pad_after_int32[i], 0)) {
sel.push_back(tvm::ir::Simplify(ovars[i] < pad_before_int32[i] + t->shape[i]));
}
}
if (sel.size() != 0) {
......
......@@ -77,7 +77,7 @@ inline Tensor dilate(const Tensor& x,
Array<Expr> out_shape;
for (size_t i = 0; i < n; ++i) {
out_shape.push_back(tvm::ir::Simplify(
(x->shape[i] - 1) * strides[i] + 1));
(x->shape[i] - 1) * cast(Int(32), strides[i]) + 1));
}
return tvm::compute(
......
......@@ -73,18 +73,18 @@ inline Tensor pool_impl(const Tensor& x,
CHECK_EQ(stride_size.size(), 2) << "Pooling stride_size must have 2 elements";
CHECK_EQ(padding_size.size(), 4) << "Pooling padding_size must have 4 elements";
auto kernel_height = kernel_size[0];
auto kernel_width = kernel_size[1];
auto stride_height = stride_size[0];
auto stride_width = stride_size[1];
auto kernel_height = cast(Int(32), kernel_size[0]);
auto kernel_width = cast(Int(32), kernel_size[1]);
auto stride_height = cast(Int(32), stride_size[0]);
auto stride_width = cast(Int(32), stride_size[1]);
auto height = x->shape[height_axis];
auto width = x->shape[width_axis];
auto pad_top = padding_size[0];
auto pad_left = padding_size[1];
auto pad_bottom = padding_size[2];
auto pad_right = padding_size[3];
auto pad_top = cast(Int(32), padding_size[0]);
auto pad_left = cast(Int(32), padding_size[1]);
auto pad_bottom = cast(Int(32), padding_size[2]);
auto pad_right = cast(Int(32), padding_size[3]);
if (ceil_mode) {
// Additional padding to ensure we do ceil instead of floor when
......@@ -179,18 +179,18 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
CHECK_EQ(stride_size.size(), 2) << "Pooling stride_size must have 2 elements";
CHECK_EQ(padding_size.size(), 4) << "Pooling padding_size must have 4 elements";
auto kernel_height = kernel_size[0];
auto kernel_width = kernel_size[1];
auto stride_height = stride_size[0];
auto stride_width = stride_size[1];
auto kernel_height = cast(Int(32), kernel_size[0]);
auto kernel_width = cast(Int(32), kernel_size[1]);
auto stride_height = cast(Int(32), stride_size[0]);
auto stride_width = cast(Int(32), stride_size[1]);
auto height = x->shape[height_axis];
auto width = x->shape[width_axis];
auto pad_top = padding_size[0];
auto pad_left = padding_size[1];
auto pad_bottom = padding_size[2];
auto pad_right = padding_size[3];
auto pad_top = cast(Int(32), padding_size[0]);
auto pad_left = cast(Int(32), padding_size[1]);
auto pad_bottom = cast(Int(32), padding_size[2]);
auto pad_right = cast(Int(32), padding_size[3]);
if (ceil_mode) {
// Additional padding to ensure we do ceil instead of floor when
......@@ -471,8 +471,8 @@ inline Tensor adaptive_pool_impl(const Tensor& x,
auto height = x->shape[height_axis];
auto width = x->shape[width_axis];
auto out_height = output_size[0];
auto out_width = output_size[1];
auto out_height = cast(Int(32), output_size[0]);
auto out_width = cast(Int(32), output_size[1]);
Array<Expr> out_shape = x->shape;
out_shape.Set(height_axis, out_height);
out_shape.Set(width_axis, out_width);
......
......@@ -208,9 +208,14 @@ inline Tensor reshape(const Tensor& x,
std::string name = "T_reshape",
std::string tag = kInjective) {
auto x_shape = x->shape;
Array<Expr> newshape_int32;
for (const auto &ele : newshape) {
newshape_int32.push_back(cast(Int(32), ele));
}
return compute(
newshape, [&](const Array<Var>& indices) {
return x(UnravelIndex(RavelIndex(Array<Expr>{indices.begin(), indices.end()}, newshape),
newshape_int32, [&](const Array<Var>& indices) {
return x(UnravelIndex(RavelIndex(Array<Expr>{indices.begin(), indices.end()}, newshape_int32),
x_shape));
}, name, tag);
}
......