# Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. """Backend compiler related feature registration""" # pylint: disable=invalid-name,unused-argument, len-as-condition, too-many-nested-blocks from __future__ import absolute_import import tvm import topi from topi.util import get_const_int, get_const_tuple from . import op as _reg from ._reduce import _schedule_reduce from .op import OpPattern from ...hybrid import script from ...api import convert schedule_injective = _reg.schedule_injective schedule_broadcast = _reg.schedule_injective schedule_concatenate = _reg.schedule_concatenate _reg.register_schedule("collapse_sum_like", _schedule_reduce) _reg.register_schedule("broadcast_to", schedule_broadcast) _reg.register_schedule("broadcast_to_like", schedule_broadcast) _reg.register_schedule("expand_dims", schedule_broadcast) _reg.register_schedule("squeeze", schedule_injective) _reg.register_schedule("reshape", schedule_injective) _reg.register_schedule("reshape_like", schedule_injective) _reg.register_schedule("full", schedule_injective) _reg.register_schedule("full_like", schedule_injective) _reg.register_schedule("arange", schedule_injective) _reg.register_schedule("reverse", schedule_injective) _reg.register_schedule("repeat", schedule_broadcast) _reg.register_schedule("tile", schedule_broadcast) _reg.register_schedule("cast", schedule_injective) _reg.register_schedule("cast_like", schedule_injective) _reg.register_schedule("reinterpret", schedule_injective) _reg.register_schedule("strided_slice", schedule_injective) _reg.register_schedule("slice_like", schedule_injective) _reg.register_schedule("split", schedule_injective) _reg.register_schedule("take", schedule_injective) _reg.register_schedule("transpose", schedule_injective) _reg.register_schedule("where", schedule_broadcast) _reg.register_schedule("stack", schedule_injective) _reg.register_schedule("concatenate", schedule_concatenate) _reg.register_schedule("_contrib_reverse_reshape", schedule_injective) _reg.register_schedule("gather_nd", schedule_injective) _reg.register_schedule("sequence_mask", schedule_injective) _reg.register_schedule("one_hot", schedule_injective) # layout_transform _reg.register_schedule("layout_transform", schedule_injective) _reg.register_pattern("layout_transform", OpPattern.INJECTIVE) # shape func @script def _arange_shape_func(start, stop, step): out = output_tensor((1,), "int64") out[0] = int64(ceil_div((int64(stop[0]) - int64(start[0])), int64(step[0]))) return out @_reg.register_shape_func("arange", True) def arange_shape_func(attrs, inputs, _): return [_arange_shape_func(*inputs)] @script def _concatenate_shape_func(inputs, axis): ndim = inputs[0].shape[0] out = output_tensor((ndim,), "int64") for i in const_range(ndim): if i != axis: out[i] = inputs[0][i] for j in const_range(1, len(inputs)): assert out[i] == inputs[j][i], \ "Dims mismatch in the inputs of concatenate." else: out[i] = int64(0) for j in const_range(len(inputs)): out[i] += inputs[j][i] return out @_reg.register_shape_func("concatenate", False) def concatenate_shape_func(attrs, inputs, _): axis = get_const_int(attrs.axis) return [_concatenate_shape_func(inputs, convert(axis))] @script def _reshape_shape_func(data_shape, newshape, ndim): out = output_tensor((ndim,), "int64") src_idx = 0 dst_idx = 0 infer_idx = -1 copy = False skip = 0 for i in const_range(len(newshape)): if skip > 0: skip -= 1 elif newshape[i] > 0: out[dst_idx] = int64(newshape[i]) src_idx += 1 dst_idx += 1 elif newshape[i] == 0: out[dst_idx] = data_shape[src_idx] src_idx += 1 dst_idx += 1 elif newshape[i] == -1: assert infer_idx < 0, "One and only one dim can be inferred" out[dst_idx] = int64(1) infer_idx = i dst_idx += 1 elif newshape[i] == -2: copy = True elif newshape[i] == -3: assert data_shape.shape[0] - src_idx > 1, \ "Not enough dims in input shape for -3" out[dst_idx] = data_shape[src_idx] * data_shape[src_idx+1] src_idx += 2 dst_idx += 1 elif newshape[i] == -4: assert len(newshape) - i > 2, "Not enough dims in new shape for -4" if newshape[i+1] == -1: assert newshape[i+2] != -1, "Split dims cannot both be -1." out[dst_idx] = data_shape[src_idx] // int64(newshape[i+2]) out[dst_idx+1] = int64(newshape[i+2]) else: out[dst_idx] = int64(newshape[i+1]) if newshape[i+2] == -1: out[dst_idx+1] = data_shape[src_idx] // int64(newshape[i+1]) else: out[dst_idx+1] = int64(newshape[i+2]) assert data_shape[src_idx] == out[dst_idx] * out[dst_idx+1],\ "Product of split dims doesn't match to input dim" src_idx += 1 dst_idx += 2 skip = 2 else: assert False, "Invalid special values in new shape" if len(data_shape.shape) > 0: # if data is not constant, we can then handle -1 and -2 if copy: for i in range(src_idx, data_shape.shape[0]): out[dst_idx] = data_shape[i] dst_idx += 1 if infer_idx >= 0: old_size = int64(1) for i in const_range(data_shape.shape[0]): old_size *= data_shape[i] new_size = int64(1) for i in const_range(out.shape[0]): new_size *= out[i] out[infer_idx] = old_size // new_size return out @_reg.register_shape_func("reshape", False) def reshape_shape_func(attrs, inputs, out_ndims): newshape = get_const_tuple(attrs.newshape) return [_reshape_shape_func(inputs[0], convert(newshape), out_ndims[0])] @script def _take_no_axis_shape_func(indices_shape, out_ndim): out = output_tensor((out_ndim,), "int64") for i in const_range(out_ndim): out[i] = indices_shape[i] return out @script def _take_with_axis_shape_func(data_shape, indices_shape, axis, out_ndim): out = output_tensor((out_ndim,), "int64") for i in const_range(axis): out[i] = data_shape[i] if len(indices_shape.shape) == 0: # indices is constant for i in const_range(axis+1, len(data_shape)): out[i-1] = data_shape[i] else: for i in const_range(len(indices_shape)): out[axis+i] = indices_shape[i] for i in const_range(axis+1, len(data_shape)): out[len(indices_shape)+i-1] = data_shape[i] return out @_reg.register_shape_func("take", False) def take_shape_func(attrs, inputs, out_ndims): """ Shape function for take op. """ if attrs.axis is None: return [_take_no_axis_shape_func(inputs[1], out_ndims[0])] else: axis = get_const_int(attrs.axis) data_ndim = int(inputs[0].shape[0]) if axis < 0: axis += data_ndim assert 0 <= axis < data_ndim return [_take_with_axis_shape_func(*inputs, convert(axis), out_ndims[0])] @script def _argwhere_shape_func_1d(condition): out = output_tensor((2, ), "int64") out[0] = int64(0) out[1] = int64(1) for i1 in range(condition.shape[0]): if condition[i1] != 0: out[0] += int64(1) return out @script def _argwhere_shape_func_2d(condition): out = output_tensor((2, ), "int64") out[0] = int64(0) out[1] = int64(2) for i1 in range(condition.shape[0]): for i2 in range(condition.shape[1]): if condition[i1, i2] != 0: out[0] += int64(1) return out @script def _argwhere_shape_func_3d(condition): out = output_tensor((2, ), "int64") out[0] = int64(0) out[1] = int64(3) for i1 in range(condition.shape[0]): for i2 in range(condition.shape[1]): for i3 in range(condition.shape[2]): if condition[i1, i2, i3] != 0: out[0] += int64(1) return out @script def _argwhere_shape_func_4d(condition): out = output_tensor((2, ), "int64") out[0] = int64(0) out[1] = int64(4) for i1 in range(condition.shape[0]): for i2 in range(condition.shape[1]): for i3 in range(condition.shape[2]): for i4 in range(condition.shape[3]): if condition[i1, i2, i3, i4] != 0: out[0] += int64(1) return out @script def _argwhere_shape_func_5d(condition): out = output_tensor((2, ), "int64") out[0] = int64(0) out[1] = int64(5) for i1 in range(condition.shape[0]): for i2 in range(condition.shape[1]): for i3 in range(condition.shape[2]): for i4 in range(condition.shape[3]): for i5 in range(condition.shape[4]): if condition[i1, i2, i3, i4, i5] != 0: out[0] += int64(1) return out @_reg.register_shape_func("argwhere", True) def argwhere_shape_func(attrs, inputs, out_ndims): """ Shape function for argwhere. """ if len(inputs[0].shape) == 1: return [_argwhere_shape_func_1d(inputs[0])] elif len(inputs[0].shape) == 2: return [_argwhere_shape_func_2d(inputs[0])] elif len(inputs[0].shape) == 3: return [_argwhere_shape_func_3d(inputs[0])] elif len(inputs[0].shape) == 4: return [_argwhere_shape_func_4d(inputs[0])] elif len(inputs[0].shape) == 5: return [_argwhere_shape_func_5d(inputs[0])] return ValueError("Does not support rank higher than 5 in argwhere") @_reg.register_schedule("argwhere") def schedule_argwhere(_, outs, target): """Schedule definition of argwhere""" with target: return topi.generic.schedule_argwhere(outs) @_reg.register_compute("argwhere") def compute_argwhere(attrs, inputs, output_type, _): """Compute definition of argwhere""" output_shape = [] for s in output_type.shape: if hasattr(s, "value"): output_shape.append(s) else: # see Any, replace it with a var output_shape.append(tvm.var("any_dim", "int32")) new_output_type = tvm.relay.ty.TensorType(output_shape, "int32") return [topi.argwhere(new_output_type, inputs[0])]