/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file codegen_c.h
 * \brief Common utilities to generate C-style code.
 */
#ifndef TVM_TARGET_SOURCE_CODEGEN_C_H_
#define TVM_TARGET_SOURCE_CODEGEN_C_H_

#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/target/codegen.h>
#include <tvm/runtime/container.h>
#include <sstream>
#include <string>
#include <vector>
#include <unordered_map>
#include <unordered_set>
#include "codegen_source_base.h"

namespace tvm {
namespace codegen {

using namespace tir;
/*!
 * \brief A base class to generate C code.
 *
 *  CodeGenC has two modes: generate SSA-form C code or normal form.
 *
 * **NOTE** CodeGenC does not aim at generating C code consumed by MSVC or GCC.
 * Rather, it provides an infrastructural abstraction for C variants like CUDA
 * and OpenCL-C. You might find some odd variant features, e.g., type `int3` for
 * a vector of 3 `int`s. For a native C code generator, see `CodeGenLLVM`.
 */
class CodeGenC :
      public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
      public StmtFunctor<void(const Stmt&)>,
      public CodeGenSourceBase {
 public:
  /*!
   * \brief Initialize the code generator.
   * \param output_ssa Whether to output SSA form.
   */
  void Init(bool output_ssa);
  /*!
   * \brief Add the function to the generated module.
   * \param f The function to be compiled.
   */
  void AddFunction(const PrimFunc& f);
  /*!
   * \brief Finalize the compilation and return the code.
   * \return The code.
   */
  std::string Finish();
  /*!
   * \brief Print the Stmt n to CodeGenC->stream
   * \param n The statement to be printed.
   */
  void PrintStmt(const Stmt& n) {
    VisitStmt(n);
  }
  /*!
   * \brief Print the expression n (or its SSA id if in SSA mode) into os.
   * \param n The expression to be printed.
   * \param os The output stream.
   */
  void PrintExpr(const PrimExpr& n, std::ostream& os);
  /*!
   * \brief Same as PrintExpr, but returns the printed text as a string.
   * \param n The expression to be printed.
   * \return The printed expression.
   */
  std::string PrintExpr(const PrimExpr& n) {
    std::ostringstream os;
    PrintExpr(n, os);
    return os.str();
  }
  // The following parts are overloadable print operations.
  /*!
   * \brief Print the function header before the argument list
   *
   *  Example: stream << "void";
   */
  virtual void PrintFuncPrefix(); // NOLINT(*)
  /*!
   * \brief Print the final return at the end of the function.
   */
  virtual void PrintFinalReturn(); // NOLINT(*)
  /*!
   * \brief Insert statement before function body.
   *
   *  The default implementation emits nothing.
   * \param f The function to be compiled.
   */
  virtual void PreFunctionBody(const PrimFunc& f) {}
  /*!
   * \brief Initialize codegen state for generating f.
   * \param f The function to be compiled.
   */
  virtual void InitFuncState(const PrimFunc& f);
  // expressions
  void VisitExpr_(const VarNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const LoadNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const LetNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const CallNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const AddNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const SubNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const MulNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const DivNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const ModNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const MinNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const MaxNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const EQNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const NENode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const LTNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const LENode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const GTNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const GENode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const AndNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const OrNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const CastNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const NotNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const SelectNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const RampNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const ShuffleNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const BroadcastNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const IntImmNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const FloatImmNode* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const StringImmNode* op, std::ostream& os) override;  // NOLINT(*)
  // statements
  void VisitStmt_(const LetStmtNode* op) override;
  void VisitStmt_(const StoreNode* op) override;
  void VisitStmt_(const ForNode* op) override;
  void VisitStmt_(const IfThenElseNode* op) override;
  void VisitStmt_(const AllocateNode* op) override;
  void VisitStmt_(const AttrStmtNode* op) override;
  void VisitStmt_(const AssertStmtNode* op) override;
  void VisitStmt_(const EvaluateNode* op) override;
  void VisitStmt_(const SeqStmtNode* op) override;
  void VisitStmt_(const ProducerConsumerNode* op) override;
  /*!
   * \brief Print the type representation of data type t.
   * \param t The data type to print.
   * \param os The stream to print the ctype into.
   */
  virtual void PrintType(DataType t, std::ostream& os); // NOLINT(*)
  /*!
   * \brief Print the type representation of type.
   * \param type The type to print.
   * \param os The stream to print the ctype into.
   */
  virtual void PrintType(const Type& type, std::ostream& os); // NOLINT(*)
  /*!
   * \brief Print expr representing the thread tag.
   * \param iv The thread index to be bound.
   */
  virtual void BindThreadIndex(const IterVar& iv); // NOLINT(*)
  /*!
   * \brief Print the storage scope qualifier for scope into os.
   * \param scope The storage scope name.
   * \param os The output stream.
   */
  virtual void PrintStorageScope(const std::string& scope, std::ostream& os); // NOLINT(*)
  /*!
   * \brief Print a storage synchronization for the given call.
   * \param op The call node requesting the sync.
   */
  virtual void PrintStorageSync(const CallNode* op);  // NOLINT(*)
  // Binary vector op.
  virtual void PrintVecBinaryOp(
      const std::string&op, DataType op_type,
      PrimExpr lhs, PrimExpr rhs, std::ostream& os);  // NOLINT(*)
  // Print vector load.
  virtual std::string GetVecLoad(DataType t, const VarNode* buffer, PrimExpr base);
  // Print vector store.
  virtual void PrintVecStore(const VarNode* buffer,
                             DataType t, PrimExpr base,
                             const std::string& value);  // NOLINT(*)
  // Print load of a single element.
  virtual void PrintVecElemLoad(
      const std::string& vec, DataType t, int i, std::ostream& os);  // NOLINT(*)
  // Print store of a single element.
  virtual void PrintVecElemStore(
      const std::string& vec, DataType t, int i, const std::string& value);
  // Get an expression casting value from type `from` to type `target`.
  virtual std::string CastFromTo(std::string value, DataType from, DataType target);

 protected:
  // Print reference to struct location.
  std::string GetStructRef(
      DataType t, const PrimExpr& buffer, const PrimExpr& index, int kind);
  // Print reference to a buffer as type t at index.
  virtual std::string GetBufferRef(
      DataType t, const VarNode* buffer, PrimExpr index);

  /*!
   * \brief Handle volatile loads.
   *
   * This is to work around a bug in CUDA cuda_fp16.h. Volatile accesses
   * to shared memory are required for reductions. However, the __half class
   * does not implement volatile member functions. CUDA codegen will cast
   * away the volatile qualifier from CUDA __half types.
   *
   * \param value The printed text of the loaded value.
   * \param op The load node being emitted.
   * \param os The output stream.
   */
  virtual void HandleVolatileLoads(const std::string& value, const LoadNode* op,
                                   std::ostream& os) {
    // By default, do nothing but print the loaded value.
    os << value;
  }

  /*!
   * \brief Check if scope is part of type in the target language.
   *
   * **NOTE** In OpenCL, __local is part of type, so "__local int *"
   * is legal. This is not the case for CUDA, where "__shared__"
   * or "__constant__" is not part of type but a storage class (like
   * C/C++ static).
   *
   * \return true by default.
   */
  virtual bool IsScopePartOfType() const {
    return true;
  }

  /*!
   * \brief Check whether buffer buf_var is allocated as type t.
   * \param buf_var The buffer variable.
   * \param t The type to be checked.
   * \return Whether the registered type of buf_var matches t.
   */
  bool HandleTypeMatch(const VarNode* buf_var, DataType t) const;
  /*!
   * \brief Register the data type of buf_var.
   * \param buf_var The buffer variable.
   * \param t The data type to associate with buf_var.
   */
  void RegisterHandleType(const VarNode* buf_var, DataType t);
  // override
  void PrintSSAAssign(
      const std::string& target, const std::string& src, DataType t) final;
  /*! \brief Reserve common C keywords as unique names. */
  void ReserveKeywordsAsUnique();

  /*! \brief Check if buf_var is volatile or not. */
  bool IsVolatile(const VarNode *buf_var) const {
    return volatile_buf_.count(buf_var) != 0;
  }

  /*! \brief Restrict keyword emitted for pointers (empty by default). */
  std::string restrict_keyword_{""};
  /*! \brief The storage scope of each allocation. */
  std::unordered_map<const VarNode*, std::string> alloc_storage_scope_;
  /*! \brief The data type of allocated buffers. */
  std::unordered_map<const VarNode*, DataType> handle_data_type_;

  /*!
   * \brief A RAII utility class for emitting code in a scoped region.
   *
   *  The constructor prints an indented "{", then opens a new scope.
   *  The destructor closes that scope, then prints an indented "}".
   *  The class is non-copyable: a copy would close the scope twice and
   *  emit an unmatched closing brace.
   */
  class EnterScopeRAII {
    // The codegen context.
    CodeGenC* cg;

    // The new scope level.
    int scope;

   public:
    explicit EnterScopeRAII(CodeGenC* cg) : cg(cg) {
      cg->PrintIndent();
      cg->stream << "{\n";
      scope = cg->BeginScope();
    }
    ~EnterScopeRAII() {
      cg->EndScope(scope);
      cg->PrintIndent();
      cg->stream << "}\n";
    }
    // Rule of Five: the destructor has side effects, so forbid copies
    // (this also suppresses the implicit move operations).
    EnterScopeRAII(const EnterScopeRAII&) = delete;
    EnterScopeRAII& operator=(const EnterScopeRAII&) = delete;
  };

 private:
  /*! \brief Whether to print in SSA form. */
  bool print_ssa_form_{false};
  /*! \brief Set of volatile buffer accesses. */
  std::unordered_set<const VarNode*> volatile_buf_;
};

}  // namespace codegen
}  // namespace tvm
#endif  // TVM_TARGET_SOURCE_CODEGEN_C_H_