/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/backend/contrib/codegen_c/codegen_c.h
 * \brief The base class for external codegen tools.
 */
#ifndef TVM_RELAY_BACKEND_CONTRIB_CODEGEN_C_CODEGEN_C_H_
#define TVM_RELAY_BACKEND_CONTRIB_CODEGEN_C_CODEGEN_C_H_

#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/relay/function.h>
#include <tvm/runtime/container.h>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

namespace tvm {
namespace relay {
namespace contrib {

/*!
 * \brief Description of the (single) output produced by a generated external function.
 *
 * Members carry default initializers so a default-constructed Output never holds
 * indeterminate values; aggregate initialization (e.g. Output{name, dtype, size, copy})
 * keeps working.
 */
struct Output {
  /*! \brief Name of the variable/buffer holding the result in the generated code. */
  std::string name;
  /*! \brief C type name of the output elements (used verbatim in generated code). */
  std::string dtype;
  /*! \brief Number of elements in the output. */
  int size{0};
  /*! \brief Whether the result must be copied from `name` into the output pointer. */
  bool need_copy{false};
};

/*!
 * \brief Abstract interface for codegen tools that lower an external Relay
 * function/module into a runtime module.
 */
class CSourceModuleCodegenBase {
 public:
  CSourceModuleCodegenBase() = default;
  virtual ~CSourceModuleCodegenBase() = default;

  /*!
   * \brief Build a runtime module for the external library.
   *
   * The result might be, for instance, a CSourceModule that is compiled and
   * linked together with a DSOModule, or a JSON-style module that emits a JSON
   * artifact executed by a customized JSON runtime.
   *
   * \param ref The ext_func Relay expression/module to be executed using extern ops.
   *
   * \return The created runtime module.
   */
  virtual runtime::Module CreateCSourceModule(const ObjectRef& ref) = 0;

  /*!
   * \brief Read the global symbol attribute attached to a Relay function.
   *
   * Fails a CHECK when the function carries no global symbol attribute.
   *
   * \param func The function whose symbol is requested.
   *
   * \return The external symbol name as a std::string.
   */
  std::string GetExtSymbol(const Function& func) const {
    const auto symbol = func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
    CHECK(symbol.defined()) << "Fail to retrieve external symbol.";
    return std::string(symbol);
  }
};

// The base class to generate the declaration functions in C.
class CodegenCBase {
 public:
  virtual ~CodegenCBase() {}

 protected:
  /*! \brief Print indents using spaces. */
  void PrintIndents() {
    for (int i = 0; i < indent_; i++) {
      code_stream_ << ' ';
    }
  }

  /*!
   * \brief Enter a new scope.
   */
  void EnterScope() { indent_ += 2; }

  /*!
   * \brief Exit a scope.
   */
  void ExitScope() {
    CHECK_GE(indent_, 2U) << "Wrong ident found.";
    indent_ -= 2;
  }

  /*!
   * \brief Gerenate C code for the external function.
   *
   * \param func_name The name of the external function.
   * \param args arguments to the external function.
   *
   * \code
   *
   * // An example code for the generated C function.
   * extern "C" void foo_wrapper_(DLTensor* arg0,
   *                              DLTensor* arg1,
   *                              DLTensor* out) {
   *   foo_(static_cast<float*>(arg0->data),
   *        static_cast<float*>(arg1->data),
   *        static_cast<float*>(out->data));
   *   return 0;
   * }
   *
   * TVM_DLL_EXPORT_TYPED_FUNC(foo, foo_wrapper_);
   *
   * \endcode
   */
  void GenerateBackendCFunc(const std::string& func_name,
                            const Array<Var>& args,
                            const Output& out) {
    // Print signature
    code_stream_ << "\n";
    code_stream_ << "extern \"C\" int " << func_name << "_wrapper_(";
    for (size_t i = 0; i < args.size(); i++) {
      code_stream_ << "DLTensor* arg" << i << ",\n";
      code_stream_ << "\t";
    }
    if (args.size() > 0) {
      code_stream_ << "DLTensor* arg" << args.size() << ") {\n";
    }

    EnterScope();

    // Generate the internal call.
    PrintIndents();
    code_stream_ << func_name << "_(";
    for (size_t i = 0; i < args.size(); i++) {
      const auto& dtype_str = GetDtypeString(args[i]);
      code_stream_ << "static_cast<" << dtype_str << "*>(arg" << i << "->data),\n";
      PrintIndents();
    }
    if (args.size() > 0) {
      code_stream_ << "static_cast<" << out.dtype << "*>(arg" << args.size() << "->data)";
    }
    code_stream_ << ");\n";
    PrintIndents();
    code_stream_ << "return 0;\n";
    ExitScope();
    code_stream_ << "}\n\n";

    // Generate the macro
    code_stream_ << "TVM_DLL_EXPORT_TYPED_FUNC(" << func_name << ", "
                 << func_name << "_wrapper_);\n\n";
  }

  /*!
   * \brief Emit the code for external runtime.
   *
   * \return The code string.
   */
  virtual std::string JIT() = 0;

  /*!
   * \brief Extract the shape from a Relay tensor type.
   *
   * \param type The provided type.
   *
   * \return The extracted shape in a list.
   */
  std::vector<int> GetShape(const Type& type) const {
    const auto* ttype = type.as<TensorTypeNode>();
    CHECK(ttype) << "Expect TensorTypeNode";
    std::vector<int> shape;
    for (size_t i = 0; i < ttype->shape.size(); ++i) {
      auto* val = ttype->shape[i].as<IntImmNode>();
      CHECK(val);
      shape.push_back(val->value);
    }
    return shape;
  }

  /*!
   * \brief Check if a call has the provided name.
   *
   * \param call A Relay call node.
   * \param op_name The name of the expected call.
   *
   * \return true if the call's name is equivalent to the given name. Otherwise,
   * false.
   */
  bool IsOp(const CallNode* call, const std::string& op_name) const {
    const auto* op_node = call->op.as<OpNode>();
    CHECK(op_node) << "Expects a single op.";
    Op op = GetRef<Op>(op_node);
    return op == Op::Get(op_name);
  }

  /*!
   * \brief A common interface that is used by various external runtime to
   * generate the wrapper to invoke external kernels.
   *
   * \param ext_func_id The unique id of an external function. It will be used
   * during runtime to pick the correct external function.
   * \param args The arguments used by the external function.
   * \param buf_decl The declaration of temporary buffers that used to store the
   * intermeidate of each external kernel.
   * \param body The statements of the external function.
   * \param out The name and id pairs for output.
   *
   * \return The emitted code string.
   */
  std::string JitImpl(const std::string& ext_func_id, const Array<Var>& args,
                      const std::vector<std::string>& buf_decl,
                      const std::vector<std::string>& body,
                      const std::vector<Output>& out) {
    // Create the signature. For example, it could be:
    // extern "C" void dnnl_0_(float* input0, float* input1, float* out, int M, int N) {}
    code_stream_ << "extern \"C\" void " << ext_func_id << "_(";

    CHECK_EQ(out.size(), 1U) << "Internal error: only single output is support.";

    for (const auto& arg : args) {
      const auto& dtype_str = GetDtypeString(arg);
      code_stream_ << dtype_str << "* " << arg->name_hint() << ", ";
    }
    code_stream_ << out[0].dtype << "* out) {\n";
    this->EnterScope();

    // Function body
    for (auto decl : buf_decl) {
      this->PrintIndents();
      code_stream_ << decl << "\n";
    }
    code_stream_ << "\n";
    for (auto stmt : body) {
      this->PrintIndents();
      code_stream_ << stmt << "\n";
    }

    // Copy output
    if (out[0].need_copy) {
      this->PrintIndents();
      code_stream_ << "std::memcpy(out, " << out[0].name << ", 4 * " << out[0].size << ");\n";

      // Free buffers
      for (size_t i = 0; i < buf_decl.size(); i++) {
        this->PrintIndents();
        code_stream_ << "std::free(buf_" << i << ");\n";
      }
    }

    this->ExitScope();
    code_stream_ << "}\n";

    // Create the wrapper to call the ext_func
    this->GenerateBackendCFunc(ext_func_id, args, out[0]);
    return code_stream_.str();
  }

  /*!
   * \brief Returns dtype string
   *
   * \param var Var to get the dtype of
   *
   * \return The dtype string.
   */
  std::string GetDtypeString(const Var& var) {
    auto ttype = var->checked_type().as<TensorTypeNode>();
    CHECK(ttype) << "Expect TensorTypeNode";
    return GetDtypeString(ttype);
  }

  /*!
   * \brief Returns dtype string
   *
   * \param ttype TensorTypeNode* to get the dtype of
   *
   * \return The dtype string.
   */
  std::string GetDtypeString(const TensorTypeNode* ttype) {
    std::string dtype;
    if (runtime::TypeMatch(ttype->dtype, kDLFloat, 32)) {
      dtype = "float";
    } else if (runtime::TypeMatch(ttype->dtype, kDLInt, 32)) {
      dtype = "int";
    } else if (runtime::TypeMatch(ttype->dtype, kDLInt, 64)) {
      dtype = "int64_t";
    } else {
      LOG(FATAL) << "Unsupported dtype " << ttype->dtype;
    }

    return dtype;
  }

  /*! \brief The external function source code stream. */
  std::ostringstream code_stream_;

 private:
  /*! \brief Indent of the source code. */
  int indent_{0};
};

}  // namespace contrib
}  // namespace relay
}  // namespace tvm
#endif  // TVM_RELAY_BACKEND_CONTRIB_CODEGEN_C_CODEGEN_C_H_