Commit 63a91ebf by Xingjian Shi Committed by Yizhi Liu

Numpy compatible dtype inference for `tvm.convert` and `tvm.const` (#3861)

* numpy compatible type inference

* update

* try to fix

* fix

* try to fix

* fix lint

* Update nn.h

* cast to int32

* try to fix

* fix again

* retrigger ci
parent 2f5b155a
...@@ -30,6 +30,23 @@ def _set_class_node_base(cls): ...@@ -30,6 +30,23 @@ def _set_class_node_base(cls):
_CLASS_NODE_BASE = cls _CLASS_NODE_BASE = cls
def _scalar_type_inference(value):
if hasattr(value, 'dtype'):
dtype = str(value.dtype)
elif isinstance(value, bool):
dtype = 'bool'
elif isinstance(value, float):
# We intentionally convert the float to float32 since it's more common in DL.
dtype = 'float32'
elif isinstance(value, int):
# We intentionally convert the python int to int32 since it's more common in DL.
dtype = 'int32'
else:
raise NotImplementedError('Cannot automatically inference the type.'
' value={}'.format(value))
return dtype
class NodeGeneric(object): class NodeGeneric(object):
"""Base class for all classes that can be converted to node.""" """Base class for all classes that can be converted to node."""
def asnode(self): def asnode(self):
...@@ -86,7 +103,7 @@ def const(value, dtype=None): ...@@ -86,7 +103,7 @@ def const(value, dtype=None):
value : int or float value : int or float
The input value The input value
dtype : str dtype : str or None, optional
The data type. The data type.
Returns Returns
...@@ -95,8 +112,5 @@ def const(value, dtype=None): ...@@ -95,8 +112,5 @@ def const(value, dtype=None):
Constant expression corresponds to the value. Constant expression corresponds to the value.
""" """
if dtype is None: if dtype is None:
if isinstance(value, Integral): dtype = _scalar_type_inference(value)
dtype = 'int32'
else:
dtype = 'float32'
return _api_internal._const(value, dtype) return _api_internal._const(value, dtype)
...@@ -23,6 +23,7 @@ from numbers import Integral as _Integral ...@@ -23,6 +23,7 @@ from numbers import Integral as _Integral
from ._ffi.base import string_types from ._ffi.base import string_types
from ._ffi.node import register_node, NodeBase from ._ffi.node import register_node, NodeBase
from ._ffi.node import convert_to_node as _convert_to_node from ._ffi.node import convert_to_node as _convert_to_node
from ._ffi.node_generic import _scalar_type_inference
from ._ffi.function import Function from ._ffi.function import Function
from ._ffi.function import _init_api, register_func, get_global_func, extract_ext_funcs from ._ffi.function import _init_api, register_func, get_global_func, extract_ext_funcs
from ._ffi.function import convert_to_tvm_func as _convert_tvm_func from ._ffi.function import convert_to_tvm_func as _convert_tvm_func
...@@ -73,7 +74,7 @@ def max_value(dtype): ...@@ -73,7 +74,7 @@ def max_value(dtype):
return _api_internal._max_value(dtype) return _api_internal._max_value(dtype)
def const(value, dtype): def const(value, dtype=None):
"""construct a constant """construct a constant
Parameters Parameters
...@@ -81,7 +82,7 @@ def const(value, dtype): ...@@ -81,7 +82,7 @@ def const(value, dtype):
value : number value : number
The content of the constant number. The content of the constant number.
dtype : str dtype : str or None, optional
The data type. The data type.
Returns Returns
...@@ -89,6 +90,8 @@ def const(value, dtype): ...@@ -89,6 +90,8 @@ def const(value, dtype):
const_val: tvm.Expr const_val: tvm.Expr
The result expression. The result expression.
""" """
if dtype is None:
dtype = _scalar_type_inference(value)
return _api_internal._const(value, dtype) return _api_internal._const(value, dtype)
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import tvm import tvm
import numpy as np
def test_const(): def test_const():
x = tvm.const(1, "int32") x = tvm.const(1, "int32")
...@@ -22,6 +23,22 @@ def test_const(): ...@@ -22,6 +23,22 @@ def test_const():
assert x.dtype == tvm.int32 assert x.dtype == tvm.int32
assert isinstance(x, tvm.expr.IntImm) assert isinstance(x, tvm.expr.IntImm)
def test_scalar_dtype_inference():
    """Check tvm.const and tvm.convert infer NumPy-compatible dtypes."""
    # np.bool_ (not the removed np.bool alias) so this runs on NumPy >= 1.24.
    scalars = [True, np.bool_(1), np.uint8(1), np.uint16(1), np.uint32(1),
               np.uint64(1), np.int8(1), np.int16(1), np.int32(1),
               np.int64(1), np.float16(1), np.float32(1), np.float64(1)]
    for data in scalars:
        # NumPy scalars must round-trip to their own dtype name.
        expected = str(np.array(data).dtype)
        assert tvm.const(data).dtype == expected
        assert tvm.convert(data).dtype == expected
    # Plain Python int/float follow the DL-friendly defaults.
    assert tvm.const(1).dtype == 'int32'
    assert tvm.const(1.0).dtype == 'float32'
    assert tvm.convert(1).dtype == 'int32'
    assert tvm.convert(1.0).dtype == 'float32'
def test_make(): def test_make():
x = tvm.const(1, "int32") x = tvm.const(1, "int32")
y = tvm.var("x") y = tvm.var("x")
...@@ -175,6 +192,7 @@ if __name__ == "__main__": ...@@ -175,6 +192,7 @@ if __name__ == "__main__":
test_cast() test_cast()
test_attr() test_attr()
test_const() test_const()
test_scalar_dtype_inference()
test_make() test_make()
test_ir() test_ir()
test_basic() test_basic()
......
...@@ -97,8 +97,8 @@ inline Tensor resize_nearest_neighbor_nhwc(const Tensor& input, ...@@ -97,8 +97,8 @@ inline Tensor resize_nearest_neighbor_nhwc(const Tensor& input,
std::string tag = kInjective) { std::string tag = kInjective) {
Array<Expr> out_shape; Array<Expr> out_shape;
out_shape.push_back(input->shape[0]); out_shape.push_back(input->shape[0]);
out_shape.push_back(shape[0]); out_shape.push_back(cast(Int(32), shape[0]));
out_shape.push_back(shape[1]); out_shape.push_back(cast(Int(32), shape[1]));
out_shape.push_back(input->shape[3]); out_shape.push_back(input->shape[3]);
return compute( return compute(
...@@ -132,8 +132,8 @@ inline Tensor resize_nearest_neighbor_nchw(const Tensor& input, ...@@ -132,8 +132,8 @@ inline Tensor resize_nearest_neighbor_nchw(const Tensor& input,
Array<Expr> out_shape; Array<Expr> out_shape;
out_shape.push_back(input->shape[0]); out_shape.push_back(input->shape[0]);
out_shape.push_back(input->shape[1]); out_shape.push_back(input->shape[1]);
out_shape.push_back(shape[0]); out_shape.push_back(cast(Int(32), shape[0]));
out_shape.push_back(shape[1]); out_shape.push_back(cast(Int(32), shape[1]));
return compute( return compute(
out_shape, [&](const Array<Var>& indices) { out_shape, [&](const Array<Var>& indices) {
...@@ -166,8 +166,8 @@ inline Tensor resize_nearest_neighbor_nchwc(const Tensor& input, ...@@ -166,8 +166,8 @@ inline Tensor resize_nearest_neighbor_nchwc(const Tensor& input,
Array<Expr> out_shape; Array<Expr> out_shape;
out_shape.push_back(input->shape[0]); out_shape.push_back(input->shape[0]);
out_shape.push_back(input->shape[1]); out_shape.push_back(input->shape[1]);
out_shape.push_back(shape[0]); out_shape.push_back(cast(Int(32), shape[0]));
out_shape.push_back(shape[1]); out_shape.push_back(cast(Int(32), shape[1]));
out_shape.push_back(input->shape[4]); out_shape.push_back(input->shape[4]);
return compute( return compute(
...@@ -233,8 +233,8 @@ inline Tensor resize_bilinear_nhwc(const Tensor& input, ...@@ -233,8 +233,8 @@ inline Tensor resize_bilinear_nhwc(const Tensor& input,
std::string tag = kInjective) { std::string tag = kInjective) {
Array<Expr> out_shape; Array<Expr> out_shape;
out_shape.push_back(input->shape[0]); out_shape.push_back(input->shape[0]);
out_shape.push_back(shape[0]); out_shape.push_back(cast(Int(32), shape[0]));
out_shape.push_back(shape[1]); out_shape.push_back(cast(Int(32), shape[1]));
out_shape.push_back(input->shape[3]); out_shape.push_back(input->shape[3]);
Expr cone = make_const(Int(32), 1); Expr cone = make_const(Int(32), 1);
...@@ -311,8 +311,8 @@ inline Tensor resize_bilinear_nchw(const Tensor& input, ...@@ -311,8 +311,8 @@ inline Tensor resize_bilinear_nchw(const Tensor& input,
Array<Expr> out_shape; Array<Expr> out_shape;
out_shape.push_back(input->shape[0]); out_shape.push_back(input->shape[0]);
out_shape.push_back(input->shape[1]); out_shape.push_back(input->shape[1]);
out_shape.push_back(shape[0]); out_shape.push_back(cast(Int(32), shape[0]));
out_shape.push_back(shape[1]); out_shape.push_back(cast(Int(32), shape[1]));
Expr cone = make_const(Int(32), 1); Expr cone = make_const(Int(32), 1);
......
...@@ -182,12 +182,20 @@ inline tvm::Tensor pad(const tvm::Tensor& t, ...@@ -182,12 +182,20 @@ inline tvm::Tensor pad(const tvm::Tensor& t,
CHECK_GE(pad_before.size(), 1); CHECK_GE(pad_before.size(), 1);
CHECK_EQ(pad_before.size(), pad_after.size()); CHECK_EQ(pad_before.size(), pad_after.size());
tvm::Array<tvm::Expr> output_shape; tvm::Array<tvm::Expr> output_shape;
tvm::Array<tvm::Expr> pad_before_int32;
tvm::Array<tvm::Expr> pad_after_int32;
for (const auto &ele : pad_before) {
pad_before_int32.push_back(tvm::cast(tvm::Int(32), ele));
}
for (const auto &ele : pad_after) {
pad_after_int32.push_back(tvm::cast(tvm::Int(32), ele));
}
for (size_t i = 0; i < t->shape.size(); ++i) { for (size_t i = 0; i < t->shape.size(); ++i) {
if (i >= pad_before.size()) { if (i >= pad_before.size()) {
output_shape.push_back(t->shape[i]); output_shape.push_back(t->shape[i]);
} else { } else {
output_shape.push_back( output_shape.push_back(
tvm::ir::Simplify(t->shape[i] + pad_before[i] + pad_after[i])); tvm::ir::Simplify(t->shape[i] + pad_before_int32[i] + pad_after_int32[i]));
} }
} }
...@@ -199,18 +207,18 @@ inline tvm::Tensor pad(const tvm::Tensor& t, ...@@ -199,18 +207,18 @@ inline tvm::Tensor pad(const tvm::Tensor& t,
tvm::Array<tvm::Expr> indices; tvm::Array<tvm::Expr> indices;
tvm::Array<tvm::Expr> sel; tvm::Array<tvm::Expr> sel;
for (size_t i = 0; i < t->shape.size(); ++i) { for (size_t i = 0; i < t->shape.size(); ++i) {
if (i >= pad_before.size()) { if (i >= pad_before_int32.size()) {
indices.push_back(ovars[i]); indices.push_back(ovars[i]);
continue; continue;
} }
if (!topi::detail::EqualCheck(pad_before[i], 0)) { if (!topi::detail::EqualCheck(pad_before_int32[i], 0)) {
sel.push_back(ovars[i] >= pad_before[i]); sel.push_back(ovars[i] >= pad_before_int32[i]);
indices.push_back(ovars[i] - pad_before[i]); indices.push_back(ovars[i] - pad_before_int32[i]);
} else { } else {
indices.push_back(ovars[i]); indices.push_back(ovars[i]);
} }
if (!topi::detail::EqualCheck(pad_after[i], 0)) { if (!topi::detail::EqualCheck(pad_after_int32[i], 0)) {
sel.push_back(tvm::ir::Simplify(ovars[i] < pad_before[i] + t->shape[i])); sel.push_back(tvm::ir::Simplify(ovars[i] < pad_before_int32[i] + t->shape[i]));
} }
} }
if (sel.size() != 0) { if (sel.size() != 0) {
......
...@@ -77,7 +77,7 @@ inline Tensor dilate(const Tensor& x, ...@@ -77,7 +77,7 @@ inline Tensor dilate(const Tensor& x,
Array<Expr> out_shape; Array<Expr> out_shape;
for (size_t i = 0; i < n; ++i) { for (size_t i = 0; i < n; ++i) {
out_shape.push_back(tvm::ir::Simplify( out_shape.push_back(tvm::ir::Simplify(
(x->shape[i] - 1) * strides[i] + 1)); (x->shape[i] - 1) * cast(Int(32), strides[i] + 1)));
} }
return tvm::compute( return tvm::compute(
......
...@@ -73,18 +73,18 @@ inline Tensor pool_impl(const Tensor& x, ...@@ -73,18 +73,18 @@ inline Tensor pool_impl(const Tensor& x,
CHECK_EQ(stride_size.size(), 2) << "Pooling stride_size must have 2 elements"; CHECK_EQ(stride_size.size(), 2) << "Pooling stride_size must have 2 elements";
CHECK_EQ(padding_size.size(), 4) << "Pooling padding_size must have 4 elements"; CHECK_EQ(padding_size.size(), 4) << "Pooling padding_size must have 4 elements";
auto kernel_height = kernel_size[0]; auto kernel_height = cast(Int(32), kernel_size[0]);
auto kernel_width = kernel_size[1]; auto kernel_width = cast(Int(32), kernel_size[1]);
auto stride_height = stride_size[0]; auto stride_height = cast(Int(32), stride_size[0]);
auto stride_width = stride_size[1]; auto stride_width = cast(Int(32), stride_size[1]);
auto height = x->shape[height_axis]; auto height = x->shape[height_axis];
auto width = x->shape[width_axis]; auto width = x->shape[width_axis];
auto pad_top = padding_size[0]; auto pad_top = cast(Int(32), padding_size[0]);
auto pad_left = padding_size[1]; auto pad_left = cast(Int(32), padding_size[1]);
auto pad_bottom = padding_size[2]; auto pad_bottom = cast(Int(32), padding_size[2]);
auto pad_right = padding_size[3]; auto pad_right = cast(Int(32), padding_size[3]);
if (ceil_mode) { if (ceil_mode) {
// Additional padding to ensure we do ceil instead of floor when // Additional padding to ensure we do ceil instead of floor when
...@@ -179,18 +179,18 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, ...@@ -179,18 +179,18 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
CHECK_EQ(stride_size.size(), 2) << "Pooling stride_size must have 2 elements"; CHECK_EQ(stride_size.size(), 2) << "Pooling stride_size must have 2 elements";
CHECK_EQ(padding_size.size(), 4) << "Pooling padding_size must have 4 elements"; CHECK_EQ(padding_size.size(), 4) << "Pooling padding_size must have 4 elements";
auto kernel_height = kernel_size[0]; auto kernel_height = cast(Int(32), kernel_size[0]);
auto kernel_width = kernel_size[1]; auto kernel_width = cast(Int(32), kernel_size[1]);
auto stride_height = stride_size[0]; auto stride_height = cast(Int(32), stride_size[0]);
auto stride_width = stride_size[1]; auto stride_width = cast(Int(32), stride_size[1]);
auto height = x->shape[height_axis]; auto height = x->shape[height_axis];
auto width = x->shape[width_axis]; auto width = x->shape[width_axis];
auto pad_top = padding_size[0]; auto pad_top = cast(Int(32), padding_size[0]);
auto pad_left = padding_size[1]; auto pad_left = cast(Int(32), padding_size[1]);
auto pad_bottom = padding_size[2]; auto pad_bottom = cast(Int(32), padding_size[2]);
auto pad_right = padding_size[3]; auto pad_right = cast(Int(32), padding_size[3]);
if (ceil_mode) { if (ceil_mode) {
// Additional padding to ensure we do ceil instead of floor when // Additional padding to ensure we do ceil instead of floor when
...@@ -471,8 +471,8 @@ inline Tensor adaptive_pool_impl(const Tensor& x, ...@@ -471,8 +471,8 @@ inline Tensor adaptive_pool_impl(const Tensor& x,
auto height = x->shape[height_axis]; auto height = x->shape[height_axis];
auto width = x->shape[width_axis]; auto width = x->shape[width_axis];
auto out_height = output_size[0]; auto out_height = cast(Int(32), output_size[0]);
auto out_width = output_size[1]; auto out_width = cast(Int(32), output_size[1]);
Array<Expr> out_shape = x->shape; Array<Expr> out_shape = x->shape;
out_shape.Set(height_axis, out_height); out_shape.Set(height_axis, out_height);
out_shape.Set(width_axis, out_width); out_shape.Set(width_axis, out_width);
......
...@@ -208,9 +208,14 @@ inline Tensor reshape(const Tensor& x, ...@@ -208,9 +208,14 @@ inline Tensor reshape(const Tensor& x,
std::string name = "T_reshape", std::string name = "T_reshape",
std::string tag = kInjective) { std::string tag = kInjective) {
auto x_shape = x->shape; auto x_shape = x->shape;
Array<Expr> newshape_int32;
for (const auto &ele : newshape) {
newshape_int32.push_back(cast(Int(32), ele));
}
return compute( return compute(
newshape, [&](const Array<Var>& indices) { newshape_int32, [&](const Array<Var>& indices) {
return x(UnravelIndex(RavelIndex(Array<Expr>{indices.begin(), indices.end()}, newshape), return x(UnravelIndex(RavelIndex(Array<Expr>{indices.begin(), indices.end()}, newshape_int32),
x_shape)); x_shape));
}, name, tag); }, name, tag);
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment