# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Definition of classic algorithms"""
# pylint: disable=invalid-name,unused-argument
from __future__ import absolute_import

import topi
from topi.util import get_const_int
from ..op import OpPattern, register_compute, register_schedule, register_pattern


@register_schedule("argsort")
def schedule_argsort(_, outs, target):
    """Schedule definition of argsort.

    Parameters
    ----------
    _ :
        Unused. Presumably the op attributes; confirm against the
        ``register_schedule`` calling convention.
    outs :
        Output tensor(s) of the argsort compute. They are forwarded unchanged
        to ``topi.generic.schedule_argsort``.
    target :
        Entered as a context manager while the schedule is built.
        NOTE(review): presumably this is what lets the generic TOPI schedule
        resolve to a target-specific one; confirm.

    Returns
    -------
    Whatever ``topi.generic.schedule_argsort(outs)`` returns (the schedule).
    """
    with target:
        return topi.generic.schedule_argsort(outs)


@register_compute("argsort")
def compute_argsort(attrs, inputs, _, target):
    """Compute definition of argsort.

    Parameters
    ----------
    attrs :
        Op attributes. This function reads ``axis``, ``is_ascend`` and
        ``dtype`` from it.
    inputs :
        Input tensors. Only ``inputs[0]`` (the data to sort) is used.
    _ :
        Unused. Presumably the output type; confirm against the
        ``register_compute`` calling convention.
    target :
        Unused by this function.

    Returns
    -------
    list
        A one-element list holding the result of ``topi.argsort``.
    """
    axis = get_const_int(attrs.axis)
    # The attribute is stored as an integer constant; normalise it to a bool.
    is_ascend = bool(get_const_int(attrs.is_ascend))
    dtype = str(attrs.dtype)
    # The second positional argument is passed as None.
    # NOTE(review): this is presumably topi.argsort's `valid_count`, meaning
    # "no per-row valid-count limit"; confirm against the topi.argsort signature.
    # flag=False is passed explicitly; confirm its meaning in the targeted
    # topi version.
    return [
        topi.argsort(inputs[0], None, axis=axis, is_ascend=is_ascend,
                     dtype=dtype, flag=False)
    ]


# Argsort is registered as OPAQUE, so it is marked as not fusible with
# neighbouring ops.
register_pattern("argsort", OpPattern.OPAQUE)