/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/ir/op_strategy.cc
 * \brief The Relay operator Strategy and related data structure.
 */

#include <tvm/relay/op_strategy.h>

namespace tvm {
namespace relay {

// Register the three node types used by operator strategies
// (implementation, specialization, strategy) via TVM_REGISTER_NODE_TYPE.
TVM_REGISTER_NODE_TYPE(OpImplementationNode);
TVM_REGISTER_NODE_TYPE(OpSpecializationNode);
TVM_REGISTER_NODE_TYPE(OpStrategyNode);

/*!
 * \brief Run this implementation's compute function.
 * \param attrs The operator attributes.
 * \param inputs The input tensors.
 * \param out_type The output type of the operator.
 * \return The tensors produced by the stored fcompute.
 */
Array<te::Tensor> OpImplementation::Compute(const Attrs& attrs,
                                            const Array<te::Tensor>& inputs,
                                            const Type& out_type) {
  const auto* node = operator->();
  return node->fcompute(attrs, inputs, out_type);
}

/*!
 * \brief Run this implementation's schedule function.
 * \param attrs The operator attributes.
 * \param outs The output tensors to schedule.
 * \param target The build target.
 * \return The schedule produced by the stored fschedule.
 */
te::Schedule OpImplementation::Schedule(const Attrs& attrs,
                                        const Array<te::Tensor> &outs,
                                        const Target& target) {
  const auto* node = operator->();
  return node->fschedule(attrs, outs, target);
}

/*!
 * \brief Build an OpImplementation and append it to this specialization.
 * \param fcompute The compute function; taken by value and moved into the node.
 * \param fschedule The schedule function; taken by value and moved into the node.
 * \param name The implementation name; moved into the node.
 * \param plevel The priority level stored on the implementation.
 */
void OpSpecialization::AddImplementation(tvm::relay::FTVMCompute fcompute,
                                         tvm::relay::FTVMSchedule fschedule,
                                         std::string name,
                                         int plevel) {
  auto n = make_object<OpImplementationNode>();
  // The by-value parameters are sinks: move them in rather than copying again.
  n->fcompute = std::move(fcompute);
  n->fschedule = std::move(fschedule);
  n->name = std::move(name);
  n->plevel = plevel;
  (*this)->implementations.push_back(OpImplementation(n));
}

/*!
 * \brief Add an implementation under the currently active specialized condition.
 *
 * Looks for an existing specialization whose condition compares equal
 * (ObjectRef operator==) to te::SpecializedCondition::Current().
 * - If one exists, the implementation is appended to it.
 * - Otherwise a new specialization for that condition is created, given the
 *   implementation, and appended to this strategy.
 *
 * \param fcompute The compute function of the implementation.
 * \param fschedule The schedule function of the implementation.
 * \param name The name of the implementation.
 * \param plevel The priority level of the implementation.
 */
void OpStrategy::AddImplementation(FTVMCompute fcompute,
                                   FTVMSchedule fschedule,
                                   std::string name,
                                   int plevel) {
  auto curr_cond = te::SpecializedCondition::Current();
  auto self = this->operator->();
  // Iterate the node's array directly instead of through a local copy. A live
  // local reference would keep the copy-on-write Array shared, forcing the
  // push_back below to duplicate the whole array.
  for (OpSpecialization op_spec : self->specializations) {
    // NOTE(review): ObjectRef equality is presumably reference identity, not a
    // structural comparison of the conditions -- confirm against the header.
    if (op_spec->condition == curr_cond) {
      op_spec.AddImplementation(std::move(fcompute), std::move(fschedule),
                                std::move(name), plevel);
      return;
    }
  }
  // No specialization exists yet for the current condition: create one.
  ObjectPtr<OpSpecializationNode> n = make_object<OpSpecializationNode>();
  n->condition = std::move(curr_cond);
  OpSpecialization new_spec(n);
  new_spec.AddImplementation(std::move(fcompute), std::move(fschedule),
                             std::move(name), plevel);
  self->specializations.push_back(new_spec);
}

// Global "relay.op._OpImplementationCompute":
// args = (implementation, attrs, inputs, out_type); returns the tensors
// produced by OpImplementation::Compute.
TVM_REGISTER_GLOBAL("relay.op._OpImplementationCompute")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    OpImplementation implementation = args[0];
    Attrs op_attrs = args[1];
    Array<te::Tensor> input_tensors = args[2];
    Type output_type = args[3];
    Array<te::Tensor> computed = implementation.Compute(op_attrs, input_tensors, output_type);
    *rv = computed;
});

// Global "relay.op._OpImplementationSchedule":
// args = (implementation, attrs, outs, target); returns the schedule
// produced by OpImplementation::Schedule.
TVM_REGISTER_GLOBAL("relay.op._OpImplementationSchedule")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    OpImplementation implementation = args[0];
    Attrs op_attrs = args[1];
    Array<te::Tensor> output_tensors = args[2];
    Target build_target = args[3];
    te::Schedule sch = implementation.Schedule(op_attrs, output_tensors, build_target);
    *rv = sch;
});

// Global "relay.op._make.OpStrategy": returns an OpStrategy wrapping a freshly
// constructed OpStrategyNode. Any arguments passed in are ignored.
TVM_REGISTER_GLOBAL("relay.op._make.OpStrategy")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    *rv = OpStrategy(make_object<OpStrategyNode>());
});

// Global "relay.op._OpStrategyAddImplementation":
// args = (strategy, compute, schedule, name, plevel); forwards to
// OpStrategy::AddImplementation and returns nothing.
TVM_REGISTER_GLOBAL("relay.op._OpStrategyAddImplementation")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    OpStrategy strategy = args[0];
    FTVMCompute compute = args[1];
    FTVMSchedule schedule = args[2];
    std::string name = args[3];
    int plevel = args[4];
    // These locals are on their last use and AddImplementation takes them by
    // value, so move them instead of copying.
    strategy.AddImplementation(std::move(compute), std::move(schedule),
                               std::move(name), plevel);
});

}  // namespace relay
}  // namespace tvm