# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument
"""Definition of vision ops"""
from __future__ import absolute_import

import topi
from topi.util import get_const_int, get_const_float, get_float_tuple
from .. import op as reg
from ..op import OpPattern


@reg.register_schedule("vision.multibox_prior")
def schedule_multibox_prior(_, outs, target):
    """Create the schedule for vision.multibox_prior.

    Parameters
    ----------
    _ : object
        Unused.
    outs : object
        Forwarded unchanged to ``topi.generic.schedule_multibox_prior``.
    target : object
        Entered as a context manager while the schedule is built.

    Returns
    -------
    object
        Whatever ``topi.generic.schedule_multibox_prior`` returns.
    """
    # The generic schedule is looked up while `target` is the active context.
    with target:
        sched = topi.generic.schedule_multibox_prior(outs)
    return sched


@reg.register_compute("vision.multibox_prior")
def compute_multibox_prior(attrs, inputs, _, target):
    """Compute vision.multibox_prior through topi.

    Parameters
    ----------
    attrs : object
        Supplies ``sizes``, ``ratios``, ``steps`` and ``offsets`` (each read
        as a float tuple) and ``clip`` (read as a constant int, then turned
        into a bool).
    inputs : sequence
        Only ``inputs[0]`` is used.
    _ : object
        Unused.
    target : object
        Unused in this function.

    Returns
    -------
    list
        One-element list holding the ``topi.vision.ssd.multibox_prior`` result.
    """
    # Parse the four float-tuple attributes in their declared order.
    sizes, ratios, steps, offsets = [
        get_float_tuple(getattr(attrs, field))
        for field in ("sizes", "ratios", "steps", "offsets")
    ]
    do_clip = bool(get_const_int(attrs.clip))
    priors = topi.vision.ssd.multibox_prior(
        inputs[0], sizes, ratios, steps, offsets, do_clip)
    return [priors]


# Tag vision.multibox_prior with the OPAQUE op pattern.
# NOTE(review): presumably this keeps the op out of operator fusion —
# confirm against the OpPattern definition.
reg.register_pattern("vision.multibox_prior", OpPattern.OPAQUE)


# multibox_transform_loc
@reg.register_schedule("vision.multibox_transform_loc")
def schedule_multibox_transform_loc(_, outs, target):
    """Create the schedule for vision.multibox_transform_loc.

    Parameters
    ----------
    _ : object
        Unused.
    outs : object
        Forwarded unchanged to ``topi.generic.schedule_multibox_transform_loc``.
    target : object
        Entered as a context manager while the schedule is built.

    Returns
    -------
    object
        Whatever ``topi.generic.schedule_multibox_transform_loc`` returns.
    """
    # The generic schedule is looked up while `target` is the active context.
    with target:
        sched = topi.generic.schedule_multibox_transform_loc(outs)
    return sched


@reg.register_compute("vision.multibox_transform_loc")
def compute_multibox_transform_loc(attrs, inputs, _, target):
    """Compute vision.multibox_transform_loc through topi.

    Parameters
    ----------
    attrs : object
        Supplies ``clip`` (constant int turned into a bool), ``threshold``
        (constant float) and ``variances`` (float tuple).
    inputs : sequence
        The first three entries are passed on, in order.
    _ : object
        Unused.
    target : object
        Unused in this function.

    Returns
    -------
    object
        The result of ``topi.vision.ssd.multibox_transform_loc``, returned
        as-is (not wrapped in a list here).
    """
    do_clip = bool(get_const_int(attrs.clip))
    score_cutoff = get_const_float(attrs.threshold)
    variance_vals = get_float_tuple(attrs.variances)
    first, second, third = inputs[0], inputs[1], inputs[2]
    return topi.vision.ssd.multibox_transform_loc(
        first, second, third, do_clip, score_cutoff, variance_vals)


# Tag both ops with the OPAQUE op pattern.
# NOTE(review): presumably this keeps the ops out of operator fusion —
# confirm against the OpPattern definition.
# NOTE(review): no compute/schedule for "vision.multibox_detection" is
# registered in the visible part of this file — confirm that op exists.
reg.register_pattern("vision.multibox_transform_loc", OpPattern.OPAQUE)
reg.register_pattern("vision.multibox_detection", OpPattern.OPAQUE)


# Get counts of valid boxes
@reg.register_schedule("vision.get_valid_counts")
def schedule_get_valid_counts(_, outs, target):
    """Create the schedule for vision.get_valid_counts.

    Parameters
    ----------
    _ : object
        Unused.
    outs : object
        Forwarded unchanged to ``topi.generic.schedule_get_valid_counts``.
    target : object
        Entered as a context manager while the schedule is built.

    Returns
    -------
    object
        Whatever ``topi.generic.schedule_get_valid_counts`` returns.
    """
    # The generic schedule is looked up while `target` is the active context.
    with target:
        sched = topi.generic.schedule_get_valid_counts(outs)
    return sched


@reg.register_compute("vision.get_valid_counts")
def compute_get_valid_counts(attrs, inputs, _, target):
    """Compute vision.get_valid_counts through topi.

    Parameters
    ----------
    attrs : object
        Supplies ``score_threshold`` (constant float) plus ``id_index`` and
        ``score_index`` (constant ints).
    inputs : sequence
        Only ``inputs[0]`` is used.
    _ : object
        Unused.
    target : object
        Unused in this function.

    Returns
    -------
    object
        The result of ``topi.vision.get_valid_counts``, returned as-is
        (not wrapped in a list here).
    """
    min_score = get_const_float(attrs.score_threshold)
    id_pos = get_const_int(attrs.id_index)
    score_pos = get_const_int(attrs.score_index)
    counted = topi.vision.get_valid_counts(inputs[0], min_score, id_pos, score_pos)
    return counted

# Tag vision.get_valid_counts with the OPAQUE op pattern.
# NOTE(review): presumably this keeps the op out of operator fusion —
# confirm against the OpPattern definition.
reg.register_pattern("vision.get_valid_counts", OpPattern.OPAQUE)


# non-maximum suppression
@reg.register_schedule("vision.non_max_suppression")
def schedule_nms(_, outs, target):
    """Create the schedule for vision.non_max_suppression.

    Parameters
    ----------
    _ : object
        Unused.
    outs : object
        Forwarded unchanged to ``topi.generic.schedule_nms``.
    target : object
        Entered as a context manager while the schedule is built.

    Returns
    -------
    object
        Whatever ``topi.generic.schedule_nms`` returns.
    """
    # The generic schedule is looked up while `target` is the active context.
    with target:
        sched = topi.generic.schedule_nms(outs)
    return sched


@reg.register_compute("vision.non_max_suppression")
def compute_nms(attrs, inputs, _, target):
    """Compute vision.non_max_suppression through topi.

    Parameters
    ----------
    attrs : object
        Supplies the boolean flags ``return_indices``, ``force_suppress`` and
        ``invalid_to_bottom`` (each read as a constant int, then turned into
        a bool), the constant ints ``max_output_size``, ``top_k``,
        ``coord_start``, ``score_index`` and ``id_index``, and the constant
        float ``iou_threshold``.
    inputs : sequence
        ``inputs[0]`` and ``inputs[1]`` are passed on, in order.
    _ : object
        Unused.
    target : object
        Unused in this function.

    Returns
    -------
    list
        One-element list holding the ``topi.vision.non_max_suppression``
        result.
    """
    want_indices = bool(get_const_int(attrs.return_indices))
    max_out = get_const_int(attrs.max_output_size)
    iou_cutoff = get_const_float(attrs.iou_threshold)
    suppress_all = bool(get_const_int(attrs.force_suppress))
    keep_top = get_const_int(attrs.top_k)
    coord_pos = get_const_int(attrs.coord_start)
    score_pos = get_const_int(attrs.score_index)
    id_pos = get_const_int(attrs.id_index)
    sink_invalid = bool(get_const_int(attrs.invalid_to_bottom))

    # Positional order below follows the topi call, not the parsing order:
    # the return-indices flag is parsed first but passed second to last.
    suppressed = topi.vision.non_max_suppression(
        inputs[0], inputs[1], max_out, iou_cutoff, suppress_all,
        keep_top, coord_pos, score_pos, id_pos, want_indices, sink_invalid)
    return [suppressed]


# Tag vision.non_max_suppression with the OPAQUE op pattern.
# NOTE(review): presumably this keeps the op out of operator fusion —
# confirm against the OpPattern definition.
reg.register_pattern("vision.non_max_suppression", OpPattern.OPAQUE)