/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
/*!
 * \file tvm/ir.h
 * \brief Additional high level nodes in the IR
 */
// Acknowledgement: Most low-level IR nodes originate from Halide.

#ifndef TVM_IR_H_
#define TVM_IR_H_

#include <type_traits>
#include <string>
#include <vector>
#include <utility>
#include "base.h"
#include "expr.h"
#include "runtime/util.h"

namespace tvm {
namespace ir {

// Make the IntImm and Variable node types from the enclosing tvm namespace
// available under tvm::ir as well.
using IntImm = tvm::IntImm;
using Variable = tvm::Variable;

/*! \brief constant unsigned integer. */
class UIntImm : public ExprNode {
 public:
  /*! \brief The constant value content. */
  uint64_t value;

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("dtype", &type);
    v->Visit("value", &value);
  }

  TVM_DLL static Expr make(Type t, uint64_t value);

  static constexpr const char* _type_key = "UIntImm";
  TVM_DECLARE_NODE_TYPE_INFO(UIntImm, ExprNode);
};

/*! \brief Floating point constants. */
class FloatImm : public ExprNode {
 public:
  /*! \brief The constant value content. */
  double value;

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("dtype", &type);
    v->Visit("value", &value);
  }

  TVM_DLL static Expr make(Type t, double value);

  static constexpr const char* _type_key = "FloatImm";
  TVM_DECLARE_NODE_TYPE_INFO(FloatImm, ExprNode);
};

/*! \brief String constants, only used in asserts. */
class StringImm : public ExprNode {
 public:
  /*! \brief The string content of the constant. */
  std::string value;

  /*! \brief Pass each field to the visitor under its attribute name. */
  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("dtype", &type);
    v->Visit("value", &value);
  }

  /*!
   * \brief Build a StringImm expression from a string (declared only here).
   * \note The specifier order is `static Expr`, matching every other
   *       make() declaration in this header.
   */
  TVM_DLL static Expr make(std::string value);

  /*! \brief Type key string of this node type. */
  static constexpr const char* _type_key = "StringImm";
  TVM_DECLARE_NODE_TYPE_INFO(StringImm, ExprNode);
};

/*!
 * \brief Cast value from one data type to another.
 * \note The lanes of value should keep fixed.
 */
class Cast : public ExprNode {
 public:
  /*! \brief Original data type. */
  Expr value;

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("dtype", &type);
    v->Visit("value", &value);
  }

  TVM_DLL static Expr make(Type t, Expr v);

  static constexpr const char* _type_key = "Cast";
  TVM_DECLARE_NODE_TYPE_INFO(Cast, ExprNode);
};

/*!
 * \brief Base template to implement binary ops.
 * \tparam T The type of the child class.
 */
template<typename T>
class BinaryOpNode : public ExprNode {
 public:
  /*! \brief The left operand. */
  Expr a;
  /*! \brief The right operand. */
  Expr b;

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("dtype", &(this->type));
    v->Visit("a", &a);
    v->Visit("b", &b);
  }

  static Expr make(Expr a, Expr b) {
    CHECK(a.defined()) << "ValueError: a is undefined\n";
    CHECK(b.defined()) << "ValueError: b is undefined\n";
    CHECK(a.type() == b.type()) << "TypeError: mismatched types\n";
    NodePtr<T> node = make_node<T>();
    node->type = a.type();
    node->a = std::move(a);
    node->b = std::move(b);
    return Expr(node);
  }

  TVM_DECLARE_NODE_TYPE_INFO(T, ExprNode);
};

/*!
 * \brief Addition node, a + b.
 *  Operands, attribute visiting and make() come from BinaryOpNode<Add>.
 */
class Add : public BinaryOpNode<Add> {
 public:
  /*! \brief Type key string of this node type. */
  static constexpr const char* _type_key = "Add";
};

/*!
 * \brief Subtraction node, a - b.
 *  Operands, attribute visiting and make() come from BinaryOpNode<Sub>.
 */
class Sub : public BinaryOpNode<Sub> {
 public:
  /*! \brief Type key string of this node type. */
  static constexpr const char* _type_key = "Sub";
};

/*!
 * \brief Multiplication node, a * b.
 *  Operands, attribute visiting and make() come from BinaryOpNode<Mul>.
 */
class Mul : public BinaryOpNode<Mul> {
 public:
  /*! \brief Type key string of this node type. */
  static constexpr const char* _type_key = "Mul";
};

/*!
 * \brief a / b in the C semantics.
 * \note For integer division, C standard uses trunc div.
 *  Operands, attribute visiting and make() come from BinaryOpNode<Div>.
 */
class Div : public BinaryOpNode<Div> {
 public:
  /*! \brief Type key string of this node type. */
  static constexpr const char* _type_key = "Div";
};

/*!
 * \brief a % b in the C semantics.
 * \note For integer division, C standard uses trunc div.
 *  Operands, attribute visiting and make() come from BinaryOpNode<Mod>.
 */
class Mod : public BinaryOpNode<Mod> {
 public:
  /*! \brief Type key string of this node type. */
  static constexpr const char* _type_key = "Mod";
};

/*!
 * \brief Floor division, floor(a/b).
 *  Operands, attribute visiting and make() come from BinaryOpNode<FloorDiv>.
 */
class FloorDiv : public BinaryOpNode<FloorDiv> {
 public:
  /*! \brief Type key string of this node type. */
  static constexpr const char* _type_key = "FloorDiv";
};

/*!
 * \brief The remainder of the floor division (see FloorDiv).
 *  Operands, attribute visiting and make() come from BinaryOpNode<FloorMod>.
 */
class FloorMod : public BinaryOpNode<FloorMod> {
 public:
  /*! \brief Type key string of this node type. */
  static constexpr const char* _type_key = "FloorMod";
};

/*!
 * \brief Minimum node, min(a, b).
 *  Operands, attribute visiting and make() come from BinaryOpNode<Min>.
 */
class Min : public BinaryOpNode<Min> {
 public:
  /*! \brief Type key string of this node type. */
  static constexpr const char* _type_key = "Min";
};

/*!
 * \brief Maximum node, max(a, b).
 *  Operands, attribute visiting and make() come from BinaryOpNode<Max>.
 */
class Max : public BinaryOpNode<Max> {
 public:
  /*! \brief Type key string of this node type. */
  static constexpr const char* _type_key = "Max";
};

/*!
 * \brief Base template to implement comparison ops.
 * \tparam T The type of the child class.
 */
template<typename T>
class CmpOpNode : public ExprNode {
 public:
  /*! \brief The left operand. */
  Expr a;
  /*! \brief The right operand. */
  Expr b;

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("dtype", &(this->type));
    v->Visit("a", &a);
    v->Visit("b", &b);
  }

  static Expr make(Expr a, Expr b) {
    CHECK(a.defined()) << "ValueError: a is undefined\n";
    CHECK(b.defined()) << "ValueError: b is undefined\n";
    CHECK(a.type() == b.type()) << "TypeError: mismatched types\n";
    NodePtr<T> node = make_node<T>();
    node->type = Bool(a.type().lanes());
    node->a = std::move(a);
    node->b = std::move(b);
    return Expr(node);
  }

  TVM_DECLARE_NODE_TYPE_INFO(T, ExprNode);
};

/*!
 * \brief Equality comparison node, a == b.
 *  Operands, attribute visiting and make() come from CmpOpNode<EQ>.
 */
class EQ : public CmpOpNode<EQ> {
 public:
  /*! \brief Type key string of this node type. */
  static constexpr const char* _type_key = "EQ";
};

/*!
 * \brief Inequality comparison node, a != b.
 *  Operands, attribute visiting and make() come from CmpOpNode<NE>.
 */
class NE : public CmpOpNode<NE> {
 public:
  /*! \brief Type key string of this node type. */
  static constexpr const char* _type_key = "NE";
};

/*!
 * \brief Less-than comparison node, a < b.
 *  Operands, attribute visiting and make() come from CmpOpNode<LT>.
 */
class LT : public CmpOpNode<LT> {
 public:
  /*! \brief Type key string of this node type. */
  static constexpr const char* _type_key = "LT";
};

/*!
 * \brief Less-or-equal comparison node, a <= b.
 *  Operands, attribute visiting and make() come from CmpOpNode<LE>.
 * \note Declared with the `class` key like every other node in this header,
 *       so that `class LE` forward declarations use a matching key.
 */
class LE : public CmpOpNode<LE> {
 public:
  /*! \brief Type key string of this node type. */
  static constexpr const char* _type_key = "LE";
};

/*!
 * \brief Greater-than comparison node, a > b.
 *  Operands, attribute visiting and make() come from CmpOpNode<GT>.
 */
class GT : public CmpOpNode<GT> {
 public:
  /*! \brief Type key string of this node type. */
  static constexpr const char* _type_key = "GT";
};

/*!
 * \brief Greater-or-equal comparison node, a >= b.
 *  Operands, attribute visiting and make() come from CmpOpNode<GE>.
 */
class GE : public CmpOpNode<GE> {
 public:
  /*! \brief Type key string of this node type. */
  static constexpr const char* _type_key = "GE";
};

/*! \brief a && b */
class And : public ExprNode {
 public:
  /*! \brief The left operand. */
  Expr a;
  /*! \brief The right operand. */
  Expr b;

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("dtype", &(this->type));
    v->Visit("a", &a);
    v->Visit("b", &b);
  }

  TVM_DLL static Expr make(Expr a, Expr b);

  static constexpr const char* _type_key = "And";
  TVM_DECLARE_NODE_TYPE_INFO(And, ExprNode);
};

/*! \brief a || b */
class Or : public ExprNode {
 public:
  /*! \brief The left operand. */
  Expr a;
  /*! \brief The right operand. */
  Expr b;

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("dtype", &type);
    v->Visit("a", &a);
    v->Visit("b", &b);
  }

  TVM_DLL static Expr make(Expr a, Expr b);

  static constexpr const char* _type_key = "Or";
  TVM_DECLARE_NODE_TYPE_INFO(Or, ExprNode);
};

/*! \brief !a */
class Not : public ExprNode {
 public:
  /*! \brief The input operand. */
  Expr a;

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("dtype", &type);
    v->Visit("a", &a);
  }

  TVM_DLL static Expr make(Expr a);

  static constexpr const char* _type_key = "Not";
  TVM_DECLARE_NODE_TYPE_INFO(Not, ExprNode);
};

/*!
 * \brief return true_value if condition is true, otherwise return false_value.
 * \note Both true_value and false_value could be evaluated
 *       regardless of the condition value.
 *       Do not use it to guard against out of bound access,
 *       please use if_then_else instead.
 */
class Select : public ExprNode {
 public:
  /*! \brief The condition */
  Expr condition;
  /*! \brief value to be returned when condition is true. */
  Expr true_value;
  /*! \brief value to be returned when condition is false. */
  Expr false_value;

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("dtype", &type);
    v->Visit("condition", &condition);
    v->Visit("true_value", &true_value);
    v->Visit("false_value", &false_value);
  }

  TVM_DLL static Expr make(Expr condition, Expr true_value, Expr false_value);

  static constexpr const char* _type_key = "Select";
  TVM_DECLARE_NODE_TYPE_INFO(Select, ExprNode);
};

/*!
 * \brief Load the value from buffer_var.
 *
 *  Equivalent to ((DType*)buffer_var)[index]
 *  where DType is the type specified by type().element_of().
 *
 *  For example, if type = float32x3, then the load will corresponds to
 *
 * \code
 *
 *  auto buffer = static_cast<float*>(buffer_var);
 *  auto loaded_val = float32x3(buffer[index.v0], buffer[index.v1], buffer[index.v2]);
 *
 * \endcode
 */
class Load : public ExprNode {
 public:
  /*! \brief The buffer variable. */
  Var buffer_var;
  /*! \brief The index locations to be loaded. */
  Expr index;
  /*! \brief The predicate to mask which lanes would be loaded. */
  Expr predicate;

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("dtype", &type);
    v->Visit("buffer_var", &buffer_var);
    v->Visit("index", &index);
    v->Visit("predicate", &predicate);
  }

  TVM_DLL static Expr make(Type type, Var buffer_var, Expr index, Expr predicate);

  static constexpr const char* _type_key = "Load";
  TVM_DECLARE_NODE_TYPE_INFO(Load, ExprNode);
};

/*!
 * \brief Construct a vector with lanes elements
 *        where its i-th element equals base + i * stride.
 *  This is useful to construct a index for a continuous vector load.
 *
 *  Examples:
 *  - ramp(0, 1, 3) = [0, 1, 2]
 *  - ramp(1, 2, 4) = [1, 3, 5, 7]
 */
class Ramp : public ExprNode {
 public:
  /*! \brief The base value. */
  Expr base;
  /*! \brief The stride of each step. */
  Expr stride;
  /*! \brief Total number of lanes. */
  int lanes;

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("dtype", &type);
    v->Visit("base", &base);
    v->Visit("stride", &stride);
    v->Visit("lanes", &lanes);
  }

  TVM_DLL static Expr make(Expr base, Expr stride, int lanes);

  static constexpr const char* _type_key = "Ramp";
  TVM_DECLARE_NODE_TYPE_INFO(Ramp, ExprNode);
};

/*! \brief Create a vector where all the elements are value. */
class Broadcast : public ExprNode {
 public:
  /*! \brief The base value. */
  Expr value;
  /*! \brief The numerb of lanes. */
  int lanes;

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("dtype", &type);
    v->Visit("value", &value);
    v->Visit("lanes", &lanes);
  }

  TVM_DLL static Expr make(Expr value, int lanes);

  static constexpr const char* _type_key = "Broadcast";
  TVM_DECLARE_NODE_TYPE_INFO(Broadcast, ExprNode);
};

/*!
 * \brief Let binding. Bind var to value then evaluate body.
 */
class Let : public ExprNode {
 public:
  /*! \brief The variable. */
  Var var;
  /*! \brief The value to be binded. */
  Expr value;
  /*! \brief The result expression. */
  Expr body;

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("dtype", &type);
    v->Visit("var", &var);
    v->Visit("value", &value);
    v->Visit("body", &body);
  }

  TVM_DLL static Expr make(Var var, Expr value, Expr body);

  static constexpr const char* _type_key = "Let";
  TVM_DECLARE_NODE_TYPE_INFO(Let, ExprNode);
};

// Call node, represent a function call or a multi-dimensional array load.
//
// TODO(tvm-team):
// Refactor call with more explicit property registrations.
// rather than calling a string symbol.
// We should move most information into function itself and remove name.

/*!
 * \brief Base node of internal functions.
 *  Abstract interface: derived nodes must provide a name and an output count.
 */
class FunctionBaseNode : public Node {
 public:
  /*! \return the name of the function */
  virtual const std::string& func_name() const = 0;
  /*! \return the number of outputs of this function */
  virtual int num_outputs() const = 0;
};

/*!
 * \brief Reference to a function.
 *  A NodeRef subclass whose container type is FunctionBaseNode; its members
 *  are generated by TVM_DEFINE_NODE_REF_METHODS.
 */
class FunctionRef : public NodeRef {
 public:
  TVM_DEFINE_NODE_REF_METHODS(FunctionRef, NodeRef, FunctionBaseNode);
};

/*!
 * \brief Call node.
 */
class Call : public ExprNode {
 public:
  /*! \brief Possible types of calls. */
  enum CallType : int {
    /*! \brief Extern "C" function. */
    Extern = 0,
    /*! \brief Extern CXX function. */
    ExternCPlusPlus = 1,
    /*! \brief Extern "C" without side-effect. */
    PureExtern = 2,
    /*! \brief Halide-style call, evaluates func(args). */
    Halide = 3,
    /*! \brief Intrinsic functions. */
    Intrinsic = 4,
    /*! \brief Intrinsic functions that are pure. */
    PureIntrinsic = 5
  };
  /*! \brief The name of the function/intrinsic. */
  std::string name;
  /*! \brief The arguments. */
  Array<Expr> args;
  /*! \brief Type of calls. */
  CallType call_type;
  /*! \brief The function to be called. */
  FunctionRef func;
  /*! \brief The output value index if func's value is a tuple. */
  int value_index{0};

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("dtype", &type);
    v->Visit("name", &name);
    v->Visit("args", &args);
    v->Visit("call_type", &call_type);
    v->Visit("func", &func);
    v->Visit("value_index", &value_index);
  }

  TVM_DLL static Expr make(Type type,
                           std::string name,
                           Array<Expr> args,
                           CallType call_type,
                           FunctionRef func = FunctionRef(),
                           int value_index = 0);

  /*! \return Whether call node is pure. */
  bool is_pure() const {
    return (call_type == PureExtern ||
            call_type == PureIntrinsic ||
            call_type == Halide);
  }

  /*!
   * \return Whether call node corresponds to a defined intrinsic.
   * \param intrin_name The name of the intrinsic.
   */
  bool is_intrinsic(const char* intrin_name) const {
    return
        ((call_type == Intrinsic ||
          call_type == PureIntrinsic) &&
         name == intrin_name);
  }

  /*! \return Whether call node can be vectorized. */
  bool is_vectorizable() const;

  static constexpr const char* _type_key = "Call";
  TVM_DECLARE_NODE_TYPE_INFO(Call, ExprNode);

  // Build-in intrinsics
  static constexpr const char* reinterpret = "reinterpret";
  static constexpr const char* bitwise_and = "bitwise_and";
  static constexpr const char* bitwise_not = "bitwise_not";
  static constexpr const char* bitwise_xor = "bitwise_xor";
  static constexpr const char* bitwise_or = "bitwise_or";
  static constexpr const char* shift_left = "shift_left";
  static constexpr const char* shift_right = "shift_right";
  static constexpr const char* popcount = "popcount";
  static constexpr const char* likely = "likely";
  static constexpr const char* glsl_texture_store = "glsl_texture_store";
  static constexpr const char* prefetch = "prefetch";
  static constexpr const char* isnan = "isnan";

  /*! \brief Vectorizable intrinsic list. */
  static const char* vectorizable_intrinsics[];
};

/*!
 * \brief Shuffle instruction.
 *  vec = concat(vectors)
 *  result = (vec[indices[0]], vec[indices[1]] ...)
 */
class Shuffle : public ExprNode {
 public:
  /*! \brief the input vectors. */
  Array<Expr> vectors;
  /*! \brief The indices of each element. */
  Array<Expr> indices;

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("vectors", &vectors);
    v->Visit("indices", &indices);
  }

  TVM_DLL static Expr make(Array<Expr> vectors, Array<Expr> indices);
  TVM_DLL static Expr make_concat(Array<Expr> vectors);
  TVM_DLL static Expr make_extract_element(Expr vector, int index);

  static constexpr const char* _type_key = "Shuffle";
  TVM_DECLARE_NODE_TYPE_INFO(Shuffle, ExprNode);
};

// Reduce operator
// Forward declaration: CommReducer below refers to its container type
// before CommReducerNode is defined.
class CommReducerNode;

/*! \brief Reference class whose container type is CommReducerNode. */
class CommReducer : public NodeRef {
 public:
  /*! \brief Default constructor. */
  CommReducer() {}
  /*! \brief Construct from a node pointer, forwarded to NodeRef. */
  explicit CommReducer(NodePtr<Node> n) : NodeRef(n) {}
  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline const CommReducerNode* get() const;
  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline const CommReducerNode* operator->() const;
  /*! \brief type indicate the container type */
  using ContainerType = CommReducerNode;
};

/*!
 * \brief A commutative reducer node to represent a commutative
 *  binary operator with identity element
 */
class CommReducerNode : public Node {
 public:
  /*! \brief The left argument of reducer */
  Array<Var> lhs;
  /*! \brief The right argument of reducer */
  Array<Var> rhs;
  /*! \brief The result of reducer */
  Array<Expr> result;
  /*!
   * \brief The identity element of reducer, which leaves other
   *  elements unchanged when combined with it, with respect to
   *  the binary operation of this reducer uses.
   */
  Array<Expr> identity_element;
  /*! \brief Function call operator to combine a and b */
  Array<Expr> operator()(Array<Expr> a, Array<Expr> b) const;
  /*! \brief construct CommReducer from args, result and identity_element */
  TVM_DLL static CommReducer make(Array<Var> lhs,
                                  Array<Var> rhs,
                                  Array<Expr> result,
                                  Array<Expr> identity_element);

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("lhs", &lhs);
    v->Visit("rhs", &rhs);
    v->Visit("result", &result);
    v->Visit("identity_element", &identity_element);
  }

  static constexpr const char* _type_key = "CommReducer";
  TVM_DECLARE_NODE_TYPE_INFO(CommReducerNode, Node);
};

/*! \brief Return the held node pointer, cast to the CommReducerNode container type. */
inline const CommReducerNode* CommReducer::get() const {
  const CommReducerNode* container =
      static_cast<CommReducerNode*>(node_.get());
  return container;
}
/*! \brief Member access; yields the same pointer as get(). */
inline const CommReducerNode* CommReducer::operator->() const {
  return get();
}

/*! \brief Reduction operator */
class Reduce : public ExprNode {
 public:
  /*! \brief The commutative combiner */
  CommReducer combiner;
  /*! \brief The source operand */
  Array<Expr> source;
  /*! \brief The reduction axis */
  Array<IterVar> axis;
  /*!
   * \brief Predicate on the reduction
   *  Only add the body to reduction if condition is true.
   */
  Expr condition;
  /*!
   * \brief The index of this reduce node.
   *  Initialized to 0, like value_index in Call and Provide, so a node
   *  never carries an indeterminate index.
   */
  int value_index{0};

  /*! \brief construct expr from op and rdom */
  TVM_DLL static Expr make(CommReducer combiner,
                           Array<Expr> src,
                           Array<IterVar> rdom,
                           Expr condition,
                           int value_index);

  /*! \brief Pass each field to the visitor under its attribute name. */
  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("dtype", &type);
    v->Visit("combiner", &combiner);
    v->Visit("source", &source);
    v->Visit("axis", &axis);
    v->Visit("condition", &condition);
    v->Visit("value_index", &value_index);
  }

  /*! \brief Type key string of this node type. */
  static constexpr const char* _type_key = "Reduce";
  TVM_DECLARE_NODE_TYPE_INFO(Reduce, ExprNode);
};

/*! \brief Any shape. */
class Any : public ExprNode {
 public:
  /*! \brief No fields are passed to the visitor. */
  void VisitAttrs(AttrVisitor* visitor) final {}
  /*!
   * \brief Convert to var.
   * \return A new Int(32) variable named "any_dim".
   */
  Var ToVar() const {
    Var dim_var = Variable::make(Int(32), "any_dim");
    return dim_var;
  }

  /*! \brief Build an Any expression (declared only here). */
  TVM_DLL static Expr make();

  /*! \brief Type key string of this node type. */
  static constexpr const char* _type_key = "Any";
  TVM_DECLARE_NODE_TYPE_INFO(Any, ExprNode);
};

// Statements
/*!
 * \brief Let binding, bind var to value, then run body.
 */
class LetStmt : public StmtNode {
 public:
  /*! \brief The variable. */
  Var var;
  /*! \brief The value to be binded. */
  Expr value;
  /*! \brief The body block. */
  Stmt body;

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("var", &var);
    v->Visit("value", &value);
    v->Visit("body", &body);
  }

  TVM_DLL static Stmt make(Var var, Expr value, Stmt body);

  static constexpr const char* _type_key = "LetStmt";
  TVM_DECLARE_NODE_TYPE_INFO(LetStmt, StmtNode);
};

/*!
 * \brief Define certain auxiliary attribute for the body to be a symbolic value.
 *  This provide auxiliary information for IR passes that transforms body.
 *
 *  In terms of effect, this is equivalent to Block(Evaluate(value), body).
 *
 *  Examples of possible usage:
 *    - Bound of function, variables.
 *    - Hint which block corresponds to a parallel region.
 */
class AttrStmt : public StmtNode {
 public:
  /*! \brief this is attribute about certain node */
  NodeRef node;
  /*! \brief the type key of the attribute */
  std::string attr_key;
  /*! \brief The attribute value, value is well defined at current scope. */
  Expr value;
  /*! \brief The body statement to be executed */
  Stmt body;

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("node", &node);
    v->Visit("attr_key", &attr_key);
    v->Visit("value", &value);
    v->Visit("body", &body);
  }

  TVM_DLL static Stmt make(NodeRef node,
                           std::string type_key,
                           Expr value,
                           Stmt body);

  static constexpr const char* _type_key = "AttrStmt";
  TVM_DECLARE_NODE_TYPE_INFO(AttrStmt, StmtNode);
};

/*!
 * \brief Assert condition, if an error occurs, return the error message.
 */
class AssertStmt : public StmtNode {
 public:
  /*! \brief Condition to be checked. */
  Expr condition;
  /*! \brief Error message when assertion failed. */
  Expr message;
  /*!
   * \brief Body which this assertion holds true.
   *  Will be executed after the assertion.
   */
  Stmt body;

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("condition", &condition);
    v->Visit("message", &message);
    v->Visit("body", &body);
  }

  TVM_DLL static Stmt make(Expr condition, Expr message, Stmt body);

  static constexpr const char* _type_key = "AssertStmt";
  TVM_DECLARE_NODE_TYPE_INFO(AssertStmt, StmtNode);
};

// TODO(tvm-team): consider consolidate with AttrStmt.
/*! \brief annotation node of producer/consumer relation. */
class ProducerConsumer : public StmtNode {
 public:
  /*! \brief The corresponding tensor. */
  FunctionRef func;
  /*! \brief Whether the relation is producer. */
  bool is_producer;
  /*! \brief Body to be executed. */
  Stmt body;

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("func", &func);
    v->Visit("is_producer", &is_producer);
    v->Visit("body", &body);
  }

  TVM_DLL static Stmt make(FunctionRef func, bool is_producer, Stmt body);

  static constexpr const char* _type_key = "ProducerConsumer";
  TVM_DECLARE_NODE_TYPE_INFO(ProducerConsumer, StmtNode);
};

/*!
 * \brief Store value to the buffer.
 *
 *  Equivalent to ((DType*)buffer_var)[index] = value.
 *  where DType is the type specified by type().element_of().
 *
 *  For example, if type = float32x3, then the load will corresponds to
 *
 * \code
 *
 *  auto buffer = static_cast<float*>(buffer_var);
 *  buffer[index.v0] = value.v0;
 *  buffer[index.v1] = value.v1;
 *  buffer[index.v2] = value.v2;
 *
 * \endcode
 * \sa Load
 */
class Store : public StmtNode {
 public:
  /*! \brief The buffer variable. */
  Var buffer_var;
  /*! \brief The value to be stored. */
  Expr value;
  /*! \brief The index locations to be stored. */
  Expr index;
  /*! \brief The predicate to mask which lanes would be stored. */
  Expr predicate;

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("buffer_var", &buffer_var);
    v->Visit("value", &value);
    v->Visit("index", &index);
    v->Visit("predicate", &predicate);
  }

  TVM_DLL static Stmt make(Var buffer_var,
                           Expr value,
                           Expr index,
                           Expr predicate);

  static constexpr const char* _type_key = "Store";
  TVM_DECLARE_NODE_TYPE_INFO(Store, StmtNode);
};

/*!
 * \brief Store value into mult-dimensional array defined by func.
 */
class Provide : public StmtNode {
 public:
  /*! \brief The function to be updated. */
  FunctionRef func;
  /*! \brief The output value index if func's value is a tuple. */
  int value_index{0};
  /*! \brief The value to be stored. */
  Expr value;
  /*! \brief The index arguments of the function. */
  Array<Expr> args;

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("func", &func);
    v->Visit("value_index", &value_index);
    v->Visit("value", &value);
    v->Visit("args", &args);
  }

  TVM_DLL static Stmt make(FunctionRef func,
                           int value_index,
                           Expr value,
                           Array<Expr> args);

  static constexpr const char* _type_key = "Provide";
  TVM_DECLARE_NODE_TYPE_INFO(Provide, StmtNode);
};

/*!
 * \brief Allocate a buffer that can be used in body.
 */
class Allocate : public StmtNode {
 public:
  /*! \brief The buffer variable. */
  Var buffer_var;
  /*! \brief The type of the buffer. */
  DataType type;
  /*! \brief The extents of the buffer. */
  Array<Expr> extents;
  /*! \brief Only allocate buffer when condition is satisfied. */
  Expr condition;
  /*! \brief The body to be executed. */
  Stmt body;
  // The following two fields are deprecated
  // kept for backward compatibility and will be refactored later.
  Expr new_expr;
  std::string free_function;

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("buffer_var", &buffer_var);
    v->Visit("dtype", &type);
    v->Visit("extents", &extents);
    v->Visit("condition", &condition);
    v->Visit("body", &body);
  }

  TVM_DLL static Stmt make(Var buffer_var,
                           DataType type,
                           Array<Expr> extents,
                           Expr condition,
                           Stmt body,
                           Expr new_expr = Expr(),
                           std::string free_function = std::string());

  /*!
   * \brief If the buffer size is constant, return the size.
   *        Otherwise return 0.
   * \return The result.
   */
  int32_t constant_allocation_size() const {
    return constant_allocation_size(extents);
  }
  /*!
   * \brief If the buffer size is constant, return the size.
   *        Otherwise return 0.
   * \param extents The extents of the buffer.
   * \return The result.
   */
  TVM_DLL static int32_t constant_allocation_size(
      const Array<Expr>& extents);

  static constexpr const char* _type_key = "Allocate";
  TVM_DECLARE_NODE_TYPE_INFO(Allocate, StmtNode);
};

/*! \brief Free the resources in the buffer before the scope ends. */
class Free : public StmtNode {
 public:
  /*! \brief The buffer variable. */
  Var buffer_var;

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("buffer_var", &buffer_var);
  }

  TVM_DLL static Stmt make(Var buffer_var);

  static constexpr const char* _type_key = "Free";
  TVM_DECLARE_NODE_TYPE_INFO(Free, StmtNode);
};

/*!
 * \brief Annotate the bounds where func need to be written and read in body.
 *  We will need to allocate space for the corresponding regions.
 */
class Realize : public StmtNode {
 public:
  /*! \brief The function to be realized. */
  FunctionRef func;
  /*!
   * \brief The output value index if func's value is a tuple.
   *  Initialized to 0, like value_index in Call and Provide, so a node
   *  never carries an indeterminate index.
   */
  int value_index{0};
  /*! \brief The data type of the array. */
  DataType type;
  /*! \brief Bounds to be realized. */
  Region bounds;
  /*! \brief Only realize if condition holds. */
  Expr condition;
  /*! \brief The body of realization. */
  Stmt body;

  /*! \brief Pass each field to the visitor under its attribute name. */
  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("func", &func);
    v->Visit("value_index", &value_index);
    v->Visit("dtype", &type);
    v->Visit("bounds", &bounds);
    v->Visit("condition", &condition);
    v->Visit("body", &body);
  }

  /*! \brief Build a Realize statement (declared only here). */
  TVM_DLL static Stmt make(FunctionRef func,
                           int value_index,
                           DataType type,
                           Region bounds,
                           Expr condition,
                           Stmt body);

  /*! \brief Type key string of this node type. */
  static constexpr const char* _type_key = "Realize";
  TVM_DECLARE_NODE_TYPE_INFO(Realize, StmtNode);
};

/*!
 * \brief A sequence of statements.
 */
class Block : public StmtNode {
 public:
  /*! \brief The first statement. */
  Stmt first;
  /*! \brief The restof statments. */
  Stmt rest;

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("first", &first);
    v->Visit("rest", &rest);
  }

  TVM_DLL static Stmt make(Stmt first, Stmt rest);
  TVM_DLL static Stmt make(const std::vector<Stmt> &stmts);

  static constexpr const char* _type_key = "Block";
  TVM_DECLARE_NODE_TYPE_INFO(Block, StmtNode);
};

/*!
 * \brief IfThenElse statment.
 */
class IfThenElse : public StmtNode {
 public:
  /*! \brief The condition. */
  Expr condition;
  /*! \brief The branch to be executed when condition is true. */
  Stmt then_case;
  /*! \brief The branch to be executed when condition is false, can be null. */
  Stmt else_case;

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("condition", &condition);
    v->Visit("then_case", &then_case);
    v->Visit("else_case", &else_case);
  }

  TVM_DLL static Stmt make(Expr condition, Stmt then_case, Stmt else_case = Stmt());

  static constexpr const char* _type_key = "IfThenElse";
  TVM_DECLARE_NODE_TYPE_INFO(IfThenElse, StmtNode);
};

/*!
 * \brief Evaluates an expression.
 *  This is mostly used for putting a Call node into Stmt.
 *
 *  If value do not have side-effect, this node can be safely removed.
 */
class Evaluate : public StmtNode {
 public:
  /*! \brief The expression to be evaluated. */
  Expr value;

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("value", &value);
  }

  TVM_DLL static Stmt make(Expr v);

  static constexpr const char* _type_key = "Evaluate";
  TVM_DECLARE_NODE_TYPE_INFO(Evaluate, StmtNode);
};

/*! \brief Additional annotation of for loop. */
enum class ForType : int {
  /*! \brief serial execution. */
  Serial = 0,
  /*! \brief parallel execution on CPU. */
  Parallel = 1,
  /*! \brief Vector SIMD loop annotation. */
  Vectorized = 2,
  /*! \brief Unroll annotation. */
  Unrolled = 3
};

// Device api of for loop
// kept for backward compatibility
// consider refactor and remove later.
enum class DeviceAPI: int {
  None = 0
};

/*!
 * \brief A for loop, with poissible type annotations.
 *
 * \code
 *
 *  for (loop_var = min; loop_var < min + extent; ++loop_var) {
 *    // body
 *  }
 * \endcode
 */
class For : public StmtNode {
 public:
  /*! \brief The loop variable. */
  Var loop_var;
  /*! \brief The minimum value of iteration. */
  Expr min;
  /*! \brief The extent of the iteration. */
  Expr extent;
  /*! \brief The type of the for loop. */
  ForType for_type;
  /*!
   * \brief Deprecated, reserved for backward compatibility.
   *  Consider refactor and remove later.
   */
  DeviceAPI device_api;
  /*! \brief The body of the for loop. */
  Stmt body;

  TVM_DLL static Stmt make(Var loop_var,
                           Expr min,
                           Expr extent,
                           ForType for_type,
                           DeviceAPI device_api,
                           Stmt body);

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("loop_var", &loop_var);
    v->Visit("min", &min);
    v->Visit("extent", &extent);
    v->Visit("for_type", &for_type);
    v->Visit("device_api", &device_api);
    v->Visit("body", &body);
  }

  static constexpr const char* _type_key = "For";
  TVM_DECLARE_NODE_TYPE_INFO(For, StmtNode);
};

/*!
 * \brief A prefetch hint of func.
 */
class Prefetch : public StmtNode {
 public:
  /*! \brief The function to be prefetched. */
  FunctionRef func;
  /*!
   * \brief The output value index if func's value is a tuple.
   *  Initialized to 0, like value_index in Call and Provide, so a node
   *  never carries an indeterminate index.
   */
  int value_index{0};
  /*! \brief The data type of the array. */
  DataType type;
  /*! \brief Bounds to be prefetched. */
  Region bounds;

  /*!
   * \brief Pass each field to the visitor under its attribute name.
   * \note NOTE(review): the data type is exposed as "type" here while
   *       Realize and Allocate use "dtype"; the name is kept as is because
   *       renaming it would change the attribute seen by existing users.
   */
  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("func", &func);
    v->Visit("value_index", &value_index);
    v->Visit("type", &type);
    v->Visit("bounds", &bounds);
  }

  /*! \brief Build a Prefetch statement (declared only here). */
  TVM_DLL static Stmt make(FunctionRef func,
                           int value_index,
                           DataType type,
                           Region bounds);

  /*! \brief Type key string of this node type. */
  static constexpr const char* _type_key = "Prefetch";
  TVM_DECLARE_NODE_TYPE_INFO(Prefetch, StmtNode);
};

/*!
 * \brief Auxiliary data structure used in IR Pass to indicate a tensor.
 *  Pairs a function reference with one of its output indices.
 */
struct TensorKey {
  /*! \brief The function producing the tensor. */
  FunctionRef f;
  /*! \brief Which output of f this key refers to. */
  int value_index;

  /*! \brief Two keys are equal when both the function and the index match. */
  inline bool operator==(const TensorKey& other) const {
    return f == other.f && value_index == other.value_index;
  }
  /*!
   * \brief Name of the tensor this key refers to.
   * \return The function name when f has exactly one output, otherwise
   *         the function name followed by ".v" and value_index.
   */
  inline std::string GetName() const {
    if (f->num_outputs() == 1) return f->func_name();
    // Built with std::to_string so that only <string> is needed; this header
    // does not include <sstream>.
    return f->func_name() + ".v" + std::to_string(value_index);
  }
};

/*! \brief namespace of possible attributes in AttrStmt.attr_key */
namespace attr {
// The above attr does not pass to ir stage.
/*! \brief Mark launching extent of thread, used by device API. */
constexpr const char* thread_extent = "thread_extent";
/*! \brief Mark launching of a virtual thread. */
constexpr const char* virtual_thread = "virtual_thread";
/*! \brief Mark region is processed by a co-processor */
constexpr const char* coproc_scope = "coproc_scope";
/*!
 * \brief Mark region creates coprocessor micro ops,
 *  can be reused if corresponding variable is independent.
 */
constexpr const char* coproc_uop_scope = "coproc_uop_scope";
/*! \brief Mark the scope as volatile access for certain handle. */
constexpr const char* volatile_scope = "volatile_scope";
/*!
 * \brief Mark the scope as generated by extern primitive.
 *  Such scope can contain arbitrary ir program and we need to be careful
 *  when making certain assumptions about the structure of the program.
 */
constexpr const char* extern_scope = "extern_scope";
/*!
 * \brief Mark the scope as when computation start to happen
 *  This can hint some code generator to create a new function for compute.
 */
constexpr const char* compute_scope = "compute_scope";
/*! \brief Mark storage scope of buffers */
constexpr const char* storage_scope = "storage_scope";
/*! \brief Mark storage alignment requirement of buffers */
constexpr const char* storage_alignment = "storage_alignment";
/*! \brief Mark storage scope of realization */
constexpr const char* realize_scope = "realize_scope";
/*! \brief The allocation context for global malloc in host. */
constexpr const char* device_context_id = "device_context_id";
/*! \brief The device type. */
constexpr const char* device_context_type = "device_context_type";
/*! \brief Mark of loop scope */
constexpr const char* loop_scope = "loop_scope";
/*! \brief Mark of reduce scope */
constexpr const char* reduce_scope = "reduce_scope";
/*!
 * \brief Mark region is guarded by the pragma extension.
 *  Every pragma attr key starts with this prefix, see IsPragmaKey.
 */
constexpr const char* pragma_scope_prefix = "pragma_";
/*! \brief Import llvm source or file into the final code gen module */
constexpr const char* pragma_import_llvm = "pragma_import_llvm";
/*!
 * \brief Mark of prefetch scope, value=offset,
 *  run prefetch of Tensor on the current loop scope
 */
constexpr const char* prefetch_scope = "prefetch_scope";
/*!
 * \brief Marks production of double buffer data
 */
constexpr const char* double_buffer_scope = "double_buffer_scope";
/*!
 * \brief Marks region used by double buffer write
 */
constexpr const char* double_buffer_write = "double_buffer_write";
/*! \brief Mark of scan update scope */
constexpr const char* scan_update_scope = "scan_update_scope";
/*! \brief Mark of scan init scope */
constexpr const char* scan_init_scope = "scan_init_scope";
/*!
 * \brief Mark alignment of buffer dimension
 *  stmt.node is Tensor
 *  stmt.value is tvm_tuple(dim, align, offset)
 *  This gives hint to require stride of dim to be k * align + offset.
 */
constexpr const char* buffer_dim_align = "buffer_dim_align";
/*! \brief Mark stores/loads with their bounds.  */
constexpr const char* buffer_bound = "buffer_bound";
/*!
 * \brief Bind the buffer specification to the region of the op
 *  When this scope occurs, the stmt.node is a Array<NodeRef> = [buffer, tensor]
 *  stmt.value is a tvm_tuple(min0, extent0, min1, extent1, ...).
 *  The scope represents that we need to bind the storage region of tensor to buffer.
 *  This will affect replacement of some variables inside the scope that
 *  corresponds to field of buffer to be the actual expressions of tensor during
 *  storage flattening phase.
 */
constexpr const char* buffer_bind_scope = "buffer_bind_scope";
// Pipeline related attributes
/*! \brief channel read scope */
constexpr const char* channel_read_scope = "channel_read_scope";
/*! \brief Advance step of channel after end of scope */
constexpr const char* channel_read_advance = "channel_read_advance";
/*! \brief channel write scope */
constexpr const char* channel_write_scope = "channel_write_scope";
/*! \brief Advance step of channel after end of scope */
constexpr const char* channel_write_advance = "channel_write_advance";
/*! \brief pipeline stage scope, implies always execution */
constexpr const char* pipeline_stage_scope = "pipeline_stage_scope";
/*! \brief pipeline execution scope, implies the scope can be pipelined. */
constexpr const char* pipeline_exec_scope = "pipeline_exec_scope";
/*!
 * \brief Mark that this stage is an OpenGL shader. Since OpenGL shader only
 * allows writing out to one element of the output texture, the Provide node
 * gets translated to a special Call::glsl_texture_store statement instead of a
 * Store statement.
 */
constexpr const char* opengl_stage_scope = "opengl_stage_scope";

/*!
 * \brief Mark that it is in the device scope.
 */
constexpr const char* device_scope = "device_scope";

/*!
 * \brief Check if attr_key is a pragma key extension
 * \param attr_key The attr key to be compared
 * \return true if attr_key starts with pragma_scope_prefix
 */
inline bool IsPragmaKey(const std::string& attr_key) {
  // A reverse search anchored at position 0 can only match at the very start,
  // so this is a starts-with test that reuses the prefix constant instead of
  // repeating the literal and its length.
  return attr_key.rfind(pragma_scope_prefix, 0) == 0;
}

}  // namespace attr

/*! \brief namespace of TVM Intrinsic functions */
namespace intrinsic {
/*!
 * \brief See pseudo code
 *
 *  Handle tvm_address_of(Load *op) {
 *     return &op->buffer_var[index];
 *  }
 */
constexpr const char* tvm_address_of = "tvm_address_of";
/*!
 * \brief Same as select, used for unsafe memory access.
 *
 *  Type tvm_if_then_else(cond, a, b) {
 *    return cond ? a : b;
 *  }
 */
constexpr const char* tvm_if_then_else = "tvm_if_then_else";
/*!
 * \brief Get head access address with memory access pattern info.
 *
 *  This operator also marks range of the memory access
 *  The offset and extent are in unit of the DType(including vectorization factor).
 *  rw_mask is a bit_mask setting whether the access is a read(1) or write(2).
 *  The access is assumed to happen in the current expression.
 *
 *  PtrType tvm_access_ptr(Expr dtype, DType* data,
 *                         int offset, int extent,
 *                         int rw_mask) {
 *    // DType == dtype.type();
 *    return &data[offset];
 *  }
 */
constexpr const char* tvm_access_ptr = "tvm_access_ptr";
/*!
 * \brief Create a function local static handle that initializes to nullptr.
 *  can be used to cache function local static resources.
 */
constexpr const char* tvm_static_handle = "tvm_static_handle";
/*!
 * \brief Return a unique context id, used for hint of workspace separation.
 *  Different context id guarantees not having overlapping workspace.
 */
constexpr const char* tvm_context_id = "tvm_context_id";
/*!
 * \brief tvm_tuple is not an actual function and cannot codegen.
 *  It is used to represent tuple structure in value field of AttrStmt,
 *  for the sake of giving hint to optimization.
 *
 *  Handle tvm_tuple(value0, value1, ..., value_n);
 */
constexpr const char* tvm_tuple = "tvm_tuple";
/*!
 * \brief See pseudo code
 *
 *  Type tvm_struct_get(StructType* arr, int index, int field_id) {
 *     return arr[index]->field;
 *  }
 * \sa TVMStructFieldKind
 */
constexpr const char* tvm_struct_get = "tvm_struct_get";
/*!
 * \brief See pseudo code
 *
 *  Handle tvm_struct_set(StructType* arr, int index, int field_id, value) {
 *     arr[index]->field = value;
 *  }
 * \sa TVMStructFieldKind
 */
constexpr const char* tvm_struct_set = "tvm_struct_set";
/*!
 * \brief See pseudo code
 *
 *  bool tvm_handle_is_null(void* handle) {
 *     return handle == nullptr
 *  }
 */
constexpr const char* tvm_handle_is_null = "tvm_handle_is_null";
/*!
 * \brief See pseudo code
 *
 *  void tvm_throw_last_error() {
 *    throw TVMGetLastError();
 *  }
 */
constexpr const char* tvm_throw_last_error = "tvm_throw_last_error";
/*!
 * \brief See pseudo code
 *
 *  dtype in {shape, array, arg_value, arg_tcode}
 *
 *  Handle tvm_stack_alloca(string dtype, int num) {
 *     return new on stack dtype[num];
 *  }
 */
constexpr const char* tvm_stack_alloca = "tvm_stack_alloca";
/*!
 * \brief Allocate a shape tuple on stack, return the handle.
 *
 *  Handle tvm_stack_make_shape(list args) {
 *     ret = alloca stack int64_t[len(args)];
 *     for i in range(len(args)):
 *        ret[i] = args[i]
 *     return &ret[0];
 *  }
 */
constexpr const char* tvm_stack_make_shape = "tvm_stack_make_shape";
/*!
 * \brief Allocate a NDArray(DLTensor) on stack, return the handle.
 *
 *  Type tvm_stack_make_array(Expr data,
 *                            Expr shape,
 *                            Expr strides,
 *                            Expr ndim,
 *                            Expr dtype,
 *                            Expr elem_offset) {
 *     ret = alloca stack DLTensor();
 *     ret->data = data;
 *     ret->shape = shape;
 *     ret->strides = strides != 0 ? strides : nullptr;
 *     ret->ndim = ndim;
 *     ret->dtype = dtype.type();
 *     ret->byte_offset = elem_offset * sizeof(dtype);
 *     return ret;
 *  }
 */
constexpr const char* tvm_stack_make_array = "tvm_stack_make_array";
/*!
 * \brief See pseudo code
 *
 *  int tvm_call_packed(name, TVMValue* args) {
 *     ModuleNode* env = GetCurrentEnv();
 *     const PackedFunc* f = env->GetFuncFromEnv(name);
 *     (*f)(args, type_code_of(args), len(args));
 *     return 0;
 *  }
 */
constexpr const char* tvm_call_packed = "tvm_call_packed";
/*!
 * \brief See pseudo code
 *
 *  NOTE(review): this pseudo code is identical to the one of tvm_call_packed,
 *  while the lowered trace variant documents a different return value.
 *  Confirm the intended return value against the implementation.
 *
 *  int tvm_call_trace_packed(name, TVMValue* args) {
 *     ModuleNode* env = GetCurrentEnv();
 *     const PackedFunc* f = env->GetFuncFromEnv(name);
 *     (*f)(args, type_code_of(args), len(args));
 *     return 0;
 *  }
 */
constexpr const char *tvm_call_trace_packed = "tvm_call_trace_packed";
/*!
 * \brief See pseudo code
 *  Mark the content as thread local context, can get optimized
 *  by only calling the call once at thread start.
 *
 *  Do not allow nesting(getting a thread context from another).
 *
 *  Handle tvm_thread_context(Expr call) {
 *     return call;
 *  }
 */
constexpr const char* tvm_thread_context = "tvm_thread_context";
/*!
 * \brief Lowered version of call packed, the space of value and
 *  type codes are explicitly allocated.
 *
 *  int tvm_call_packed_lowered(name,
 *                              TVMValue* value_stack,
 *                              int* tcode_stack,
 *                              int begin,
 *                              int end) {
 *     ModuleNode* env = GetCurrentEnv();
 *     const PackedFunc* f = env->GetFuncFromEnv(name);
 *     f->CallPacked(TVMArgs(value_stack[begin:end],
 *                           tcode_stack[begin:end]),
 *                   TVMRetValue(value_stack + end, tcode_stack + end));
 *  }
 */
constexpr const char* tvm_call_packed_lowered = "tvm_call_packed_lowered";
/*!
 * \brief Lowered version of trace intrinsic, the space of value and
 *  type codes are explicitly allocated. The return value is the
 *  (end - 1) value on the stack.
 *
 *  int tvm_call_trace_packed_lowered(name,
 *                                    TVMValue* value_stack,
 *                                    int* tcode_stack,
 *                                    int begin,
 *                                    int end) {
 *     ModuleNode* env = GetCurrentEnv();
 *     const PackedFunc* f = env->GetFuncFromEnv(name);
 *     f->CallPacked(TVMArgs(value_stack[begin:end],
 *                           tcode_stack[begin:end]),
 *                   TVMRetValue(value_stack + end, tcode_stack + end));
 *  }
 */
constexpr const char *tvm_call_trace_packed_lowered =
    "tvm_call_trace_packed_lowered";
/*!
 * \brief See pseudo code
 *
 *  int tvm_storage_sync(std::string storage_scope) {
 *     __sync(storage_scope);
 *     return 0;
 *  }
 */
constexpr const char* tvm_storage_sync = "tvm_storage_sync";
/*!
 * \brief See pseudo code
 *
 *  Type tvm_warp_shuffle(Type value, warp_id) {
 *     return (value passed in by warp indicated by warp_id);
 *  }
 */
constexpr const char* tvm_warp_shuffle = "tvm_warp_shuffle";
/*!
 * \brief Initialize the global barrier.
 *  Call this at beginning of kernel that need global barrier.
 */
constexpr const char* tvm_global_barrier_kinit = "tvm_global_barrier_kinit";
/*!
 * \brief See pseudo code
 *
 *  void tvm_thread_allreduce(UIntImm size, Expr source0, ..., Expr cond,
 *                            Var reduce_temp0, .., Var thread_idx1, ...) {
 *     // constraint by the other thread_idx remain the same.
 *     // reduce_temp is used to save intermediate result.
 *     reduce_temp0, ... = reduce(combiner, source0, ..., cond
 *       over [thread_idx1, thread_idx2] passed by any caller)
 *  }
 */
constexpr const char* tvm_thread_allreduce = "tvm_thread_allreduce";

}   // namespace intrinsic

/*!
 * \brief Build an expression whose only purpose is to carry a data type.
 *
 *  The result is a pure intrinsic call named "type_annotation"
 *  that takes no arguments.
 *
 * \param dtype The data type to annotate.
 * \return An expression with type dtype.
 */
inline Expr TypeAnnotation(Type dtype) {
  Expr annotation = ir::Call::make(dtype, "type_annotation", {}, ir::Call::PureIntrinsic);
  return annotation;
}

/*!
 * \brief Overload stream printing of a for type.
 * \param os The output stream.
 * \param for_type The for type to be printed.
 * \return A reference to an output stream; presumably os, so calls can be chained.
 */
TVM_DLL std::ostream& operator<<(std::ostream& os, ForType for_type);

}  // namespace ir
}  // namespace tvm

namespace std {
/*!
 * \brief Hash of a TensorKey, combining the hash of the function
 *  reference with the value index.
 */
template <>
struct hash<::tvm::ir::TensorKey> {
  std::size_t operator()(const ::tvm::ir::TensorKey& k) const {
    const size_t func_hash = k.f.hash();
    const size_t index_bits = static_cast<size_t>(k.value_index);
    // Fold the value index into the function hash in a single expression.
    return func_hash ^ (index_bits + 0x9e3779b9 + (func_hash << 6) + (func_hash >> 2));
  }
};
}  // namespace std

#endif  // TVM_IR_H_