Commit 62521453 by Yao Wang Committed by Haichen Shen

Add More Shape Functions (#4179)

* Add shape functions

* Fix get_const_tuple

* Fix cpplint

* Fix pylint

* Fix pylint

* rebase and fix

* Check Any for infer type

* Fix expand_dim shape func for zero rank input

* Fix pooling infer type

* Address comment

* Register layout transform attr
parent 10b77ef3
......@@ -219,15 +219,15 @@ def args_to_workload(x, topi_compute_func=None):
workload = get_const_tuple(x.shape) + (x.dtype, )
elif isinstance(x, (tuple, list, container.Array)):
workload = tuple([args_to_workload(a) for a in x])
elif isinstance(x, (str, int, float, np.int, np.float)):
elif isinstance(x, (str, int, float, np.int, np.float, expr.Var)):
workload = x
elif isinstance(x, (expr.StringImm, expr.UIntImm, expr.IntImm, expr.FloatImm)):
workload = x.value
elif x is None:
workload = 0
raise RuntimeError('Do not support type "%s" in argument. Consider to use '
'primitive types only' % type(x))
raise RuntimeError('Do not support type "%s" in argument. Consider to use'
'primitive types or tvm.expr.Var only' % type(x))
return (get_func_name(topi_compute_func), ) + workload if topi_compute_func else workload
def template(func):
......@@ -163,7 +163,7 @@ def get_const_int(exp):
def get_const_tuple(in_tuple):
"""Verifies input tuple is IntImm, returns tuple of int.
"""Verifies input tuple is IntImm or Var, returns tuple of int or Var.
......@@ -175,4 +175,14 @@ def get_const_tuple(in_tuple):
out_tuple : tuple of int
The output.
return tuple(get_const_int(x) for x in in_tuple)
ret = []
for elem in in_tuple:
if isinstance(elem, expr.Var):
elif not isinstance(elem, (expr.IntImm, expr.UIntImm, int)):
elem = ir_pass.Simplify(elem)
if not isinstance(elem, (expr.IntImm, expr.UIntImm)):
return tuple(ret)
......@@ -18,7 +18,11 @@
from __future__ import absolute_import
import topi
from topi.util import get_const_int, get_const_tuple
from . import op as _reg
from ...api import convert
from ...hybrid import script
def _schedule_reduce(_, outs, target):
......@@ -39,3 +43,67 @@ _reg.register_schedule("mean", _schedule_reduce)
_reg.register_schedule("variance", _schedule_reduce)
_reg.register_schedule("nn.cross_entropy", _schedule_reduce)
_reg.register_schedule("nn.cross_entropy_with_logits", _schedule_reduce)
def _create_axis_record(attrs, inputs):
axes = attrs.axis if attrs.axis is None else list(get_const_tuple(attrs.axis))
exclude = get_const_int(attrs.exclude) > 0
keepdims = get_const_int(attrs.keepdims) > 0
data_shape = inputs[0]
shape_size = data_shape.shape[0].value
axis_record = [-1] * shape_size
if axes is None:
axes = list(range(shape_size))
for i, axis in enumerate(axes):
if axis < 0:
axes[i] = shape_size + axis
if exclude:
ex_axes = []
for i in range(shape_size):
if i not in axes:
axes = ex_axes
for i in range(shape_size):
if i not in axes:
axis_record[i] = i
if not keepdims:
tmp = []
for i in axis_record:
if i >= 0:
axis_record = tmp
return axis_record
def _reduce_shape_func(data_shape, axis_record):
out = output_tensor((len(axis_record),), "int64")
for i in const_range(len(axis_record)):
if axis_record[i] >= 0:
out[i] = data_shape[axis_record[i]]
out[i] = int64(1)
return out
def reduce_shape_func(attrs, inputs, _):
"""Shape function for reduce op."""
axis_record = _create_axis_record(attrs, inputs)
return [_reduce_shape_func(inputs[0], convert(axis_record))]
_reg.register_shape_func("argmax", False, reduce_shape_func)
_reg.register_shape_func("argmin", False, reduce_shape_func)
_reg.register_shape_func("all", False, reduce_shape_func)
_reg.register_shape_func("sum", False, reduce_shape_func)
_reg.register_shape_func("max", False, reduce_shape_func)
_reg.register_shape_func("min", False, reduce_shape_func)
_reg.register_shape_func("prod", False, reduce_shape_func)
_reg.register_shape_func("mean", False, reduce_shape_func)
_reg.register_shape_func("variance", False, reduce_shape_func)
......@@ -119,18 +119,6 @@ def _cast_shape_function(x):
def cast_shape_func(attrs, inputs, out_ndims):
return [_cast_shape_function(*inputs)]
def _expand_dims_shape_func(x):
ndim = len(x.shape)
out = output_tensor((ndim+1,), "int64")
out[0] = int64(1)
for i in const_range(0, ndim):
out[i+1] = int64(x.shape[i])
return out
def expand_dims_shape_func(attrs, inputs, out_ndims):
return [_expand_dims_shape_func(*inputs)]
# shape func
def _broadcast_shape_func(x, y, ndim):
......@@ -161,9 +149,17 @@ def _broadcast_shape_func(x, y, ndim):
return out
def broadcast_shape_func(attrs, inputs, out_ndims):
"""Shape function for broadcast op."""
return [_broadcast_shape_func(*inputs, out_ndims[0])]
register_shape_func("expand_dims", False, expand_dims_shape_func)
def elemwise_shape_func(attrs, inputs, _):
"""Shape function for elemwise op."""
return [topi.math.identity(inputs[0])]
register_shape_func("cast", False, cast_shape_func)
register_shape_func("add", False, broadcast_shape_func)
......@@ -179,3 +175,6 @@ register_shape_func("less", False, broadcast_shape_func)
register_shape_func("less_equal", False, broadcast_shape_func)
register_shape_func("greater", False, broadcast_shape_func)
register_shape_func("greater_equal", False, broadcast_shape_func)
register_shape_func("sqrt", False, elemwise_shape_func)
register_shape_func("negative", False, elemwise_shape_func)
......@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Backend compiler related feature registration"""
# pylint: disable=invalid-name,unused-argument, len-as-condition, too-many-nested-blocks
# pylint: disable=invalid-name,unused-argument, len-as-condition, too-many-nested-blocks, too-many-local-variables, too-many-arguments
from __future__ import absolute_import
import tvm
import topi
......@@ -303,3 +303,195 @@ def compute_argwhere(attrs, inputs, output_type, _):
output_shape.append(tvm.var("any_dim", "int32"))
new_output_type = tvm.relay.ty.TensorType(output_shape, "int32")
return [topi.argwhere(new_output_type, inputs[0])]
def _layout_transform_shape_func(data_shape,
out = output_tensor((out_layout_len,), "int64")
for i in const_range(len(dst_equal_list)):
out[dst_equal_list[i][0]] = data_shape[dst_equal_list[i][1]]
for i in const_range(len(dst_mul_list)):
out[dst_mul_list[i][0]] = data_shape[dst_mul_list[i][1]] * \
for i in const_range(len(dst_div_list)):
out[dst_div_list[i][0]] = data_shape[dst_div_list[i][1]] \
// dst_div_list[i][3]
out[dst_div_list[i][2]] = int64(dst_div_list[i][3])
for i in const_range(len(dst_mix_list)):
out[dst_mix_list[i][0]] = data_shape[dst_mix_list[i][1]] * \
dst_mix_list[i][2] // dst_mix_list[i][4]
out[dst_mix_list[i][3]] = int64(dst_mix_list[i][4])
return out
@_reg.register_shape_func("layout_transform", False)
def layout_transform_shape_func(attrs, inputs, _):
"""Shape function for layout_transform op."""
def _fetch_axis(layout):
major_axes = []
minor_axes = {}
num_start = -1
for i, item in enumerate(layout):
if "A" <= item <= "Z":
elif "a" <= item <= "z":
last_num = int(layout[num_start:i])
minor_axes[item] = last_num
num_start = -1
elif num_start < 0:
num_start = i
return major_axes, minor_axes
_, src_minor_axes = _fetch_axis(attrs.src_layout)
dst_major_axes, dst_minor_axes = _fetch_axis(attrs.dst_layout)
src_letter_list = []
dst_letter_list = []
for item in attrs.src_layout:
if "A" <= item <= "Z" or "a" <= item <= "z":
for item in attrs.dst_layout:
if "A" <= item <= "Z" or "a" <= item <= "z":
out_layout_len = len(dst_major_axes) + len(dst_minor_axes)
dst_equal_list = []
dst_mul_list = []
dst_div_list = []
dst_mix_list = []
for key in dst_major_axes:
if key.lower() not in dst_minor_axes:
if key.lower() not in src_minor_axes:
if key.lower() not in src_minor_axes:
return [_layout_transform_shape_func(inputs[0],
def _expand_dim_shape_func(data_shape, ndim, axis, num_newaxis):
out = output_tensor((ndim + num_newaxis,), "int64")
for i in const_range(out.shape[0]):
if i < axis:
out[i] = data_shape[i]
elif i < axis + num_newaxis:
out[i] = int64(1)
out[i] = data_shape[i - num_newaxis]
return out
@_reg.register_shape_func("expand_dims", False)
def expand_dim_shape_func(attrs, inputs, _):
"""Shape function for expand_dim op."""
axis = get_const_int(attrs.axis)
num_newaxis = get_const_int(attrs.num_newaxis)
if axis < 0:
axis = inputs[0].shape[0] + axis + 1
ndim = inputs[0].shape[0] if inputs[0].shape else 0
return [_expand_dim_shape_func(inputs[0],
# Shape-function kernel used by transpose_shape_func.
#   data_shape : tensor indexed with a single index; its entries are the input's
#                dimensions.
#   axes       : sequence of indices into data_shape; output entry i is read
#                from data_shape[axes[i]].
# Returns an int64 tensor with data_shape.shape[0] entries.
def _transpose_shape_func(data_shape, axes):
out = output_tensor((data_shape.shape[0],), "int64")
# Permute: entry i of the result is the input dimension selected by axes[i].
for i in const_range(len(axes)):
out[i] = data_shape[axes[i]]
return out
@_reg.register_shape_func("transpose", False)
def transpose_shape_func(attrs, inputs, _):
"""Shape function for transpose op."""
axes = attrs.axes if attrs.axes is None else get_const_tuple(attrs.axes)
if axes is None:
axes = list(range(inputs[0].shape[0].value))
for i, axis in enumerate(axes):
if axis < 0:
axes[i] = inputs[0].shape[0] - axis
return [_transpose_shape_func(inputs[0], convert(axes))]
def _squeeze_shape_func(data_shape, keep_axes):
out = output_tensor((len(keep_axes),), "int64")
if len(keep_axes) == 0:
out_size = 0
for i in const_range(data_shape.shape[0]):
if data_shape[i] != 1:
out_size += 1
if out_size == 0:
out_size = 1
out = output_tensor((out_size,), "int64")
out[0] = int64(1)
pos = 0
for i in const_range(data_shape.shape[0]):
if data_shape[i] != 1:
out[pos] = data_shape[i]
pos += 1
for i in const_range(len(keep_axes)):
out[i] = data_shape[keep_axes[i]]
return out
@_reg.register_shape_func("squeeze", False)
def squeeze_shape_func(attrs, inputs, _):
"""Shape function for squeeze op."""
axis = attrs.axis if attrs.axis is None else get_const_tuple(attrs.axis)
keep_axes = []
if axis is not None:
for i in range(inputs[0].shape[0].value):
if i not in axis:
return [_squeeze_shape_func(inputs[0], convert(keep_axes))]
# Shape-function kernel used by reshape_like_shape_func.
#   target_shape : tensor indexed with a single index, holding the dimensions
#                  to reproduce.
# Returns a new int64 tensor with the same number of entries as target_shape,
# each entry copied unchanged.
def _reshape_like_shape_func(target_shape):
out = output_tensor((target_shape.shape[0],), "int64")
# Element-by-element copy of the target shape into the output tensor.
for i in const_range(target_shape.shape[0]):
out[i] = target_shape[i]
return out
@_reg.register_shape_func("reshape_like", False)
def reshape_like_shape_func(attrs, inputs, _):
"""Shape function for reshape_like op."""
return [_reshape_like_shape_func(inputs[1])]
......@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument
# pylint: disable=invalid-name, unused-argument, too-many-arguments
"""Backend compiler related feature registration"""
from __future__ import absolute_import
......@@ -22,6 +22,9 @@ import topi
from topi.util import get_const_tuple
from .. import op as reg
from ..op import OpPattern, schedule_injective
from .._tensor import elemwise_shape_func
from ....api import convert
from ....hybrid import script
# relu
reg.register_schedule("nn.relu", schedule_injective)
......@@ -766,7 +769,6 @@ reg.register_pattern("nn.bitserial_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
reg.register_pattern("nn.cross_entropy", OpPattern.OPAQUE)
def compute_cross_entropy(attrs, inputs, out_dtype, target):
x, y = inputs
......@@ -775,8 +777,170 @@ def compute_cross_entropy(attrs, inputs, out_dtype, target):
reg.register_pattern("nn.cross_entropy_with_logits", OpPattern.OPAQUE)
def compute_cross_entropy_with_logits(attrs, inputs, out_dtype, target):
x, y = inputs
return [-topi.sum(x * y) / x.shape[0]]
# shape func
# Shape-function kernel used by conv2d_NCHWc_shape_func.
#   dshape   : data shape tensor; entries 0..4 are read, entries 1..4 as
#              ic_chunk, height, width and ic_bn respectively.
#   kshape   : kernel shape tensor; entries 2 and 3 are read as kernel height
#              and width, and all entries are multiplied together.
#   strides, padding, dilation : indexable with 0 (height) and 1 (width).
#   oc_bn    : output-channel block size.
# Returns an int64 tensor holding [dshape[0], oc_chunk, out_height, out_width, oc_bn].
# NOTE(review): the output is allocated with dshape.shape[0] entries but
# indices 0..4 are written, so dshape presumably always has 5 entries — confirm.
def _conv2d_NCHWc_shape_func(dshape, kshape, strides, padding, dilation, oc_bn):
out = output_tensor((dshape.shape[0],), "int64")
ic_chunk = dshape[1]
height = dshape[2]
width = dshape[3]
ic_bn = dshape[4]
kheight = kshape[2]
kwidth = kshape[3]
# Effective kernel extent once dilation gaps are included.
dilated_kh = (kheight - 1) * dilation[0] + 1
dilated_kw = (kwidth - 1) * dilation[1] + 1
# Total number of kernel elements: product of every entry of kshape.
kflatten = int64(1)
for i in const_range(kshape.shape[0]):
kflatten *= kshape[i]
# Recover the output-channel count by dividing out the spatial and
# input-channel factors, then split it into chunks of oc_bn.
oc = kflatten // (kheight * kwidth * ic_chunk * ic_bn)
oc_chunk = oc // oc_bn
# Output spatial extent; padding[0]/padding[1] are applied on both sides.
out_height = (height + 2 * padding[0] - dilated_kh) // strides[0] + 1
out_width = (width + 2 * padding[1] - dilated_kw) // strides[1] + 1
out[0] = dshape[0]
out[1] = oc_chunk
out[2] = out_height
out[3] = out_width
out[4] = int64(oc_bn)
return out
@reg.register_shape_func("nn.contrib_conv2d_NCHWc", False)
def conv2d_NCHWc_shape_func(attrs, inputs, _):
"""Shape function for contrib_conv2d_NCHWc op."""
strides = get_const_tuple(attrs.strides)
padding = get_const_tuple(attrs.padding)
dilation = get_const_tuple(attrs.dilation)
out_layout = attrs.out_layout
oc_bn = int(out_layout[4:-1])
return [_conv2d_NCHWc_shape_func(inputs[0], inputs[1],
convert(strides), convert(padding),
convert(dilation), convert(oc_bn))]
def _pool2d_shape_func(data_shape, pool_size, strides,
padding, height_axis, width_axis):
out = output_tensor((data_shape.shape[0],), "int64")
for i in const_range(data_shape.shape[0]):
if i == height_axis:
out[i] = (data_shape[i] + padding[0] + padding[2] - pool_size[0]) // strides[0] + 1
elif i == width_axis:
out[i] = (data_shape[i] + padding[1] + padding[3] - pool_size[1]) // strides[1] + 1
out[i] = data_shape[i]
return out
def pool2d_shape_func(attrs, inputs, _):
"""Shape function for pool2d op."""
pool_size = get_const_tuple(attrs.pool_size)
strides = get_const_tuple(attrs.strides)
padding = get_const_tuple(attrs.padding)
layout = attrs.layout
height_axis = layout.index("H")
width_axis = layout.index("W")
if len(padding) == 1:
padding = [padding[0]] * 4
elif len(padding) == 2:
padding = [padding[0], padding[1], padding[0], padding[1]]
return [_pool2d_shape_func(inputs[0], convert(pool_size),
convert(strides), convert(padding),
convert(height_axis), convert(width_axis))]
reg.register_shape_func("nn.max_pool2d", False, pool2d_shape_func)
reg.register_shape_func("nn.avg_pool2d", False, pool2d_shape_func)
def _global_pool2d_shape_func(data_shape, height_axis, width_axis):
out = output_tensor((data_shape.shape[0],), "int64")
for i in const_range(out.shape[0]):
if i == height_axis or i == width_axis:
out[i] = int64(1)
out[i] = data_shape[i]
return out
def global_pool2d_shape_func(attrs, inputs, _):
"""Shape function for global pool2d op."""
layout = attrs.layout
height_axis = width_axis = 1
for i, letter in enumerate(layout):
if letter == "H":
height_axis = i
if letter == "W":
width_axis = i
return [_global_pool2d_shape_func(inputs[0], convert(height_axis), convert(width_axis))]
reg.register_shape_func("nn.global_max_pool2d", False, global_pool2d_shape_func)
reg.register_shape_func("nn.global_avg_pool2d", False, global_pool2d_shape_func)
# Shape-function kernel used by batch_flatten_shape_func.
#   data_shape : tensor indexed with a single index, holding the input's
#                dimensions.
# Returns a 2-entry int64 tensor: entry 0 is data_shape[0] and entry 1 is the
# product of all remaining entries of data_shape.
def _batch_flatten_shape_func(data_shape):
out = output_tensor((2,), "int64")
out[0] = data_shape[0]
# Start the product at 1 so an input with a single dimension yields 1.
out[1] = int64(1)
# Multiply in entries 1 .. data_shape.shape[0] - 1.
for i in const_range(data_shape.shape[0] - 1):
out[1] *= data_shape[i + 1]
return out
@reg.register_shape_func("nn.batch_flatten", False)
def batch_flatten_shape_func(attrs, inputs, _):
"""Shape function for batch_flatten op."""
return [_batch_flatten_shape_func(inputs[0])]
# Shape-function kernel used by dense_shape_func.
#   data_shape   : tensor indexed with a single index, holding the data
#                  input's dimensions.
#   weight_shape : tensor whose entry 0 is read.
# Returns an int64 tensor with as many entries as data_shape: every entry but
# the last is copied from data_shape, and the last is weight_shape[0].
def _dense_shape_func(data_shape, weight_shape):
out = output_tensor((data_shape.shape[0],), "int64")
# Leading entries are taken unchanged from the data shape.
for i in const_range(out.shape[0] - 1):
out[i] = data_shape[i]
# The final entry comes from the first entry of the weight shape.
out[out.shape[0] - 1] = weight_shape[0]
return out
@reg.register_shape_func("nn.dense", False)
def dense_shape_func(attrs, inputs, _):
"""Shape function for dense op."""
ret = [_dense_shape_func(inputs[0], inputs[1])]
return ret
# Shape-function kernel used by pad_shape_func.
#   data_shape : tensor indexed with a single index, holding the input's
#                dimensions.
#   pad_width  : per-dimension pairs; pad_width[i][0] and pad_width[i][1] are
#                both added to dimension i.
# Returns an int64 tensor with the same number of entries as data_shape.
def _pad_shape_func(data_shape, pad_width):
out = output_tensor((data_shape.shape[0],), "int64")
# Each output dimension grows by the two pad amounts given for that axis.
for i in const_range(out.shape[0]):
out[i] = data_shape[i] + pad_width[i][0] + pad_width[i][1]
return out
@reg.register_shape_func("nn.pad", False)
def pad_shape_func(attrs, inputs, _):
"""Shape function for pad op."""
pad_width = []
for pair in attrs.pad_width:
return [_pad_shape_func(inputs[0], convert(pad_width))]
reg.register_shape_func("nn.bias_add", False, elemwise_shape_func)
reg.register_shape_func("nn.softmax", False, elemwise_shape_func)
reg.register_shape_func("nn.relu", False, elemwise_shape_func)
......@@ -289,16 +289,22 @@ inline Array<Expr> TransformShape(const Array<Expr>& src_shape,
// for minor-axis, simply bind it as 0, so that we can reuse forward/backward_rule,
// e.g., (C * 16 + c) / 32
std::unordered_map<const Variable*, Expr> bind_map;
std::unordered_set<size_t> symbolic_var_set;
for (size_t i = 0; i < src_shape.size(); ++i) {
Expr orig_shape = src_shape[i];
IterVar orig_axis = src_axis[i];
if (orig_shape.as<ir::Any>()) {
if (!LayoutAxis::Get(orig_axis).IsPrimal()) {
if (orig_shape.defined()) {
const auto* orig_shape_const = orig_shape.as<IntImm>();
const auto* orig_axis_extent = orig_axis->dom->extent.as<IntImm>();
CHECK_EQ(orig_shape_const->value, orig_axis_extent->value)
<< "Input shape mismatch at index " << i << ". Expected "
<< orig_axis->dom->extent << ", get " << orig_shape;
if (orig_shape_const) {
CHECK_EQ(orig_shape_const->value, orig_axis_extent->value)
<< "Input shape mismatch at index " << i << ". Expected "
<< orig_axis->dom->extent << ", get " << orig_shape;
bind_map[orig_axis->var.get()] = Expr(0);
} else {
......@@ -316,7 +322,11 @@ inline Array<Expr> TransformShape(const Array<Expr>& src_shape,
if (!LayoutAxis::Get(axis).IsPrimal()) {
} else {
result.push_back(ir::Simplify(ir::Substitute(rule, bind_map)));
if (symbolic_var_set.count(i)) {
} else {
result.push_back(ir::Simplify(ir::Substitute(rule, bind_map)));
return result;
......@@ -330,8 +330,19 @@ bool Conv2DWinogradRel(const Array<Type>& types,
// dilation
Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
oshape.Set(2, (dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y) / param->strides[0] + 1);
oshape.Set(3, (dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x) / param->strides[1] + 1);
if (!dshape_nchw[2].as<ir::Any>()) {
oshape.Set(2, (dshape_nchw[2] + param->padding[0] * 2
- dilated_ksize_y) / param->strides[0] + 1);
} else {
oshape.Set(2, dshape_nchw[2]);
if (!dshape_nchw[3].as<ir::Any>()) {
oshape.Set(3, (dshape_nchw[3] + param->padding[1] * 2
- dilated_ksize_x) / param->strides[1] + 1);
} else {
oshape.Set(3, dshape_nchw[3]);
DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
......@@ -116,10 +116,19 @@ bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
// dilation
Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
oshape.Set(2, indexdiv(dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y,
param->strides[0]) + 1);
oshape.Set(3, indexdiv(dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x,
param->strides[1]) + 1);
if (!dshape_nchw[2].as<ir::Any>()) {
oshape.Set(2, indexdiv(dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y,
param->strides[0]) + 1);
} else {
oshape.Set(2, dshape_nchw[2]);
if (!dshape_nchw[3].as<ir::Any>()) {
oshape.Set(3, indexdiv(dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x,
param->strides[1]) + 1);
} else {
oshape.Set(3, dshape_nchw[3]);
DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
......@@ -408,7 +408,12 @@ bool BatchFlattenRel(const Array<Type>& types,
auto target_dim = make_const(Int(32), 1);
for (uint32_t i = 1; i < data->shape.size(); ++i) {
target_dim = target_dim * data->shape[i];
if (!data->shape[i].as<ir::Any>()) {
target_dim = target_dim * data->shape[i];
} else {
target_dim = data->shape[i];
std::vector<IndexExpr> oshape({data->shape[0], target_dim});
......@@ -148,8 +148,12 @@ bool PadRel(const Array<Type>& types,
<< "Param width elements should be positive but first pad width at "
<< "index " << i << " is " << *width2 << ".";
auto padding = make_const(data->shape[i].type(), *width1 + *width2);
oshape.push_back(data->shape[i] + padding);
if (!data->shape[i].as<ir::Any>()) {
auto padding = make_const(data->shape[i].type(), *width1 + *width2);
oshape.push_back(data->shape[i] + padding);
} else {
reporter->Assign(types[1], TensorTypeNode::make(Array<IndexExpr>(oshape),
......@@ -102,14 +102,25 @@ bool Pool2DRel(const Array<Type>& types,
if (param->ceil_mode) {
oshape[hidx] = ((dshape[hidx] + pad_h - param->pool_size[0] +
param->strides[0] - 1) / param->strides[0]) + 1;
oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[1] +
param->strides[1] - 1) / param->strides[1]) + 1;
if (dshape[hidx].as<ir::Any>()) {
oshape[hidx] = dshape[hidx];
} else {
oshape[hidx] = ((dshape[hidx] + pad_h - param->pool_size[0]) / param->strides[0]) + 1;
oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[1]) / param->strides[1]) + 1;
if (param->ceil_mode) {
oshape[hidx] = ((dshape[hidx] + pad_h - param->pool_size[0] +
param->strides[0] - 1) / param->strides[0]) + 1;
} else {
oshape[hidx] = ((dshape[hidx] + pad_h - param->pool_size[0]) / param->strides[0]) + 1;
if (dshape[widx].as<ir::Any>()) {
oshape[widx] = dshape[widx];
} else {
if (param->ceil_mode) {
oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[1] +
param->strides[1] - 1) / param->strides[1]) + 1;
} else {
oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[1]) / param->strides[1]) + 1;
// assign output type
......@@ -211,11 +211,20 @@ inline std::vector<IndexExpr> ReduceShapeImpl(const std::vector<IndexExpr> &in_s
auto max_shape = make_const(Int(64), 1);
bool is_dynamic_input = false;
for (int64_t axis : r_axes) {
max_shape *= in_shape[axis];
if (in_shape[axis].as<IntImm>()) {
max_shape *= in_shape[axis];
} else {
is_dynamic_input = true;
if (is_dynamic_input) {
CHECK(reporter->Assert(max_shape < make_const(Int(64), std::numeric_limits<int32_t>::max())))
<< "The maximum possible index of reduced shape cannot be more than int32 max.";
CHECK(reporter->Assert(max_shape < make_const(Int(64), std::numeric_limits<int32_t>::max())))
<< "The maximum possible index of reduced shape cannot be more than int32 max.";
if (param->keepdims) {
std::vector<IndexExpr> oshape(in_shape);
......@@ -797,8 +797,18 @@ bool ReshapeLikeRel(const Array<Type>& types,
if (reshape_like == nullptr) {
return false;
CHECK(reporter->AssertEQ(data->Size(), reshape_like->Size()))
<< "Reshape inputs size should be compatible.";
// Only check when the input data has a static shape.
bool is_static_shape = true;
for (size_t i = 0; i < data->shape.size(); ++i) {
if (!data->shape[i].as<IntImm>()) {
is_static_shape = false;
if (is_static_shape) {
CHECK(reporter->AssertEQ(data->Size(), reshape_like->Size()))
<< "Reshape inputs size should be compatible.";
reporter->Assign(types[2], TensorTypeNode::make(reshape_like->shape, data->dtype));
return true;
......@@ -2292,6 +2302,8 @@ RELAY_REGISTER_OP("slice_like")
.set_attr<TOpPattern>("TOpPattern", kInjective);
// relay.layout_transform
Array<Tensor> LayoutTransformCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
......@@ -52,9 +52,9 @@ inline Tensor flatten(const Tensor& x,
std::string name = "tensor",
std::string tag = kInjective) {
auto ishape = x->shape;
int dim = 1;
Expr dim = 1;
for (size_t i = 1; i < ishape.size(); ++i) {
dim = dim * static_cast<int>(topi::detail::GetConstInt(ishape[i]));
dim = dim * ishape[i];
Array<Expr> oshape({ ishape[0], dim });
......@@ -144,7 +144,7 @@ def equal_const_int(expr, value):
def get_const_tuple(in_tuple):
"""Verifies input tuple is IntImm, returns tuple of int.
"""Verifies input tuple is IntImm or Var, returns tuple of int or Var.
......@@ -156,7 +156,17 @@ def get_const_tuple(in_tuple):
out_tuple : tuple of int
The output.
return tuple(get_const_int(elem) for elem in in_tuple)
ret = []
for elem in in_tuple:
if isinstance(elem, tvm.expr.Var):
elif not isinstance(elem, (tvm.expr.IntImm, tvm.expr.UIntImm, int)):
elem = tvm.ir_pass.Simplify(elem)
if not isinstance(elem, (tvm.expr.IntImm, tvm.expr.UIntImm)):
return tuple(ret)
def get_float_tuple(in_tuple):
......@@ -41,6 +41,13 @@ def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depth
"""Get default schedule config for the workload"""
static_data_shape = []
for dim in get_const_tuple(data.shape):
if isinstance(dim, tvm.expr.Var):
data = tvm.placeholder(static_data_shape, dtype=data.dtype)
if is_depthwise:
wkl = _get_depthwise_conv2d_workload(data, kernel, strides, padding, out_dtype)
from .depthwise_conv2d import _fallback_schedule
......@@ -37,6 +37,12 @@ def _declaration_dense(cfg, data, weight, bias=None, out_dtype=None):
return C
M, _ = get_const_tuple(data.shape)
# Always use dense_nopack for dynamic input.
# This is a temporary workaround for CV models.
# TODO(kevinthesun): use kernel dispatcher instead.
if isinstance(M, tvm.expr.Var):
return _declaration_dense_nopack(cfg, data, weight, bias, out_dtype)
# For small batch sizes, don't pack weight into cache-friendly layout
# because of overhead in packing and limited reuse from batch dimension
# TODO(icemelon9): use a more systematic way to determine which schedule to use
......@@ -53,9 +59,9 @@ def _declaration_dense_pack(cfg, data, weight, bias=None, out_dtype=None):
M, K = get_const_tuple(data.shape) # batch, in_dim
N, _ = get_const_tuple(weight.shape) # out_dim
# create tuning space
cfg.define_split("tile_y", M, num_outputs=3)
cfg.define_split("tile_x", N, num_outputs=3)
cfg.define_split("tile_k", K, num_outputs=2)
cfg.define_split("tile_y", 32 if isinstance(M, tvm.expr.Var) else M, num_outputs=2)
cfg.define_split("tile_x", 32 if isinstance(N, tvm.expr.Var) else N, num_outputs=2)
cfg.define_split("tile_k", 32 if isinstance(K, tvm.expr.Var) else K, num_outputs=2)
if cfg.is_fallback:
_default_dense_pack_config(cfg, M, N, K)
......@@ -87,9 +93,9 @@ def _declaration_dense_nopack(cfg, data, weight, bias=None, out_dtype=None):
M, K = get_const_tuple(data.shape)
N, _ = get_const_tuple(weight.shape)
# create tuning space
cfg.define_split("tile_y", M, num_outputs=2)
cfg.define_split("tile_x", N, num_outputs=2)
cfg.define_split("tile_k", K, num_outputs=2)
cfg.define_split("tile_y", 32 if isinstance(M, tvm.expr.Var) else M, num_outputs=2)
cfg.define_split("tile_x", 32 if isinstance(N, tvm.expr.Var) else N, num_outputs=2)
cfg.define_split("tile_k", 32 if isinstance(K, tvm.expr.Var) else K, num_outputs=2)
if cfg.is_fallback:
_default_dense_nopack_config(cfg, M, N, K)
......@@ -211,8 +217,15 @@ def _schedule_dense_nopack_template(cfg, s, C):
def _default_dense_pack_config(cfg, M, N, K):
vec_width = get_fp32_len()
# Generate default schedule for dynamic shape.
if isinstance(M, tvm.expr.Var):
M = 16
if isinstance(N, tvm.expr.Var):
N = 16
if isinstance(K, tvm.expr.Var):
K = 16
vec_width = get_fp32_len()
tilex_ii = 1
for bn in range(vec_width*2, 0, -1):
if N % bn == 0:
......@@ -241,6 +254,14 @@ def _default_dense_pack_config(cfg, M, N, K):
def _default_dense_nopack_config(cfg, M, N, K):
# Generate default schedule for dynamic shape.
if isinstance(M, tvm.expr.Var):
M = 16
if isinstance(N, tvm.expr.Var):
N = 16
if isinstance(K, tvm.expr.Var):
K = 16
vec_width = get_fp32_len()
tilek_bn = 1
for bn in range(vec_width*2, 0, -1):
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment