/*! * Copyright (c) 2016 by Contributors * \file tvm/expr.h * \brief The Expr and related elements in DataFlow construction. */ #ifndef TVM_EXPR_H_ #define TVM_EXPR_H_ #include <ir/Expr.h> #include <ir/IROperator.h> #include <ir/IRPrinter.h> #include <string> #include <algorithm> #include "base.h" #include "runtime/c_runtime_api.h" namespace tvm { using HalideIR::Type; using HalideIR::Float; using HalideIR::Bool; using HalideIR::Int; using HalideIR::UInt; using HalideIR::Handle; using HalideIR::ExprHash; using HalideIR::ExprEqual; using HalideIR::Expr; using HalideIR::VarExpr; using HalideIR::IR::RangeNode; using HalideIR::IR::FunctionRef; using HalideIR::IR::FunctionBaseNode; using HalideIR::Internal::Stmt; using HalideIR::Internal::IRPrinter; using HalideIR::Internal::Variable; using HalideIR::Internal::make_const; using HalideIR::Internal::make_zero; using HalideIR::Internal::make_one; using HalideIR::Internal::as_const_int; using HalideIR::Internal::as_const_uint; using HalideIR::Internal::const_true; using HalideIR::Internal::const_false; using HalideIR::Internal::is_no_op; inline Type TVMShapeIndexType() { if (std::is_signed<tvm_index_t>::value) { return Int(sizeof(tvm_index_t) * 8); } else { return UInt(sizeof(tvm_index_t) * 8); } } inline Type TVMType2Type(TVMType t) { return Type(static_cast<halideir_type_code_t>(t.code), t.bits, t.lanes); } inline TVMType Type2TVMType(Type t) { TVMType ret; ret.code = static_cast<uint8_t>(t.code()); ret.bits = static_cast<uint8_t>(t.bits()); ret.lanes = static_cast<uint16_t>(t.lanes()); return ret; } // Get number of bytes considering vector type. inline int GetVectorBytes(Type dtype) { int data_bits = dtype.bits() * dtype.lanes(); CHECK_EQ(data_bits % 8, 0U) << "Need to load/store by multiple of bytes"; return data_bits / 8; } /*! \brief a named variable in TVM */ class Var : public HalideIR::VarExpr { public: EXPORT explicit Var(const std::string& name_hint = "v", Type t = Int(32)) : VarExpr(name_hint, t) {} explicit Var(NodePtr<Node> n) : VarExpr(n) {} explicit Var(VarExpr v) : VarExpr(v) {} /*! * \brief Make a new copy of var with same type, append suffix * \param suffix The suffix to be appended. * \return the new Var copy */ Var copy_with_suffix(const std::string& suffix) const { return Var((*this)->name_hint + suffix, (*this)->type); } /*! \brief type indicate the container type */ using ContainerType = Variable; }; /*! \brief container class of iteration variable. */ class IterVarNode; /*! * \brief same as HalideIR::IR::Range * except it provide an constructor with (begin, end) * * \note Traditional Halide's Range have a constructor with * (begin, extent), which does not match the convention in e.g. python. * We decided to correct it by removing the constructor in HalideIR, * and add it back in TVM's range. */ class Range : public HalideIR::IR::Range { public: /*! \brief constructor */ Range() {} explicit Range(NodePtr<Node> n) : HalideIR::IR::Range(n) {} /*! * \brief constructor by begin and end * \param begin The begin of the range. * \param end The end of the range. */ TVM_DLL Range(Expr begin, Expr end); TVM_DLL static Range make_by_min_extent(Expr min, Expr extent); }; /*! * \brief Type of iteration variable. * Each IterVar have a specific type. * * The type of iter var can be overriden via * stage.iter_var_attrs given they are compatible. */ enum IterVarType : int { /*! * \brief Data parallel iteration. * This normally corresponds to axis of Tensor. * Allow all IterVar manipulations. * * \note This does not mean the loop * have to be executed in parallel fashion. */ kDataPar = 0, /*! * \brief The IterVar itself is a thread-index * of a fixed thread launching group. * Note that this is already assumed to be paralellized. * * Disallow: split/fuse/vectorize/parallel */ kThreadIndex = 1, /*! * \brief Communicative reduction. * Cannot be directly parallelized. * * Disallow: parallel/vectorize */ kCommReduce = 2, /*! * \brief Serial loops with loop carry dependency, * the iteration must execute in order. * Cannot be re-ordered. * * Disallow: reorder/parallel/vectorize */ kOrdered = 3, /*! * \brief IterVar is opaque, * * May not corresponds to any generated loop * Disallow all IterVar manipulations and compute_at * * \note This is usually used to implement composite op * or external op, where the */ kOpaque = 4, // The following are possible additional // types that are provided during schedule /*! * \brief The execution is unrolled. */ kUnrolled = 5, /*! * \brief The loop is vectorized. */ kVectorized = 6, /*! * \brief The loop is parallelized. */ kParallelized = 7, /*! * \brief Marks boundary of tensorization intrinsic. */ kTensorized = 8 }; /*! * \brief Iteration Variable, * represents an iteration over an integer interval. */ class IterVar : public NodeRef { public: // construct a new iter var without a domain IterVar() {} // construct from shared ptr. explicit IterVar(NodePtr<Node> n) : NodeRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container */ inline const IterVarNode* operator->() const; /*! * \return the corresponding var in the IterVar. */ inline operator Expr() const; /*! \brief specify container node */ using ContainerType = IterVarNode; }; /*! * \brief Create a new IterVar that represents an axis in thread. * * \param dom Optional, domain of the thread axis. * \param tag The thread tag of the axis. */ TVM_DLL IterVar thread_axis(Range dom, std::string tag); /*! * \brief Create a new IterVar for reduction operations. * * \param dom The domain of the reduction axis. * \param name The name of the reduction axis. */ TVM_DLL IterVar reduce_axis(Range dom, std::string name = "rv"); using Domain = Array<Range>; // print functions for expr TVM_DLL std::ostream& operator<<(std::ostream& os, const NodeRef& n); // NOLINT(*) /*! * \brief Dump the node to stderr, used for debug purposes. * \param node The input node */ TVM_DLL void Dump(const NodeRef& node); // definition of Node. /*! * \brief An iteration variable representing an iteration * over a one dimensional interval. */ class IterVarNode : public Node { public: /*! * \brief the domain of iteration, if known, can be None * For the intermediate schedule node, before schedule. */ Range dom; /*! \brief The looping variable */ Var var; /*! \brief The type of the IterVar */ IterVarType iter_type; /*! * \brief additional tag on the iteration variable, * set this if this is binded already to a known thread tag. */ std::string thread_tag; void VisitAttrs(AttrVisitor* v) final { v->Visit("dom", &dom); v->Visit("var", &var); v->Visit("iter_type", &iter_type); v->Visit("thread_tag", &thread_tag); } TVM_DLL static IterVar make(Range dom, Var var, IterVarType iter_type, std::string thread_tag = ""); static constexpr const char* _type_key = "IterVar"; TVM_DECLARE_NODE_TYPE_INFO(IterVarNode, Node); }; // inline implementations inline const IterVarNode* IterVar::operator->() const { return static_cast<const IterVarNode*>(node_.get()); } inline IterVar::operator Expr() const { return (*this)->var; } inline const char* IterVarType2String(IterVarType t) { switch (t) { case kDataPar: return "DataPar"; case kThreadIndex: return "ThreadIndex"; case kCommReduce: return "CommReduce"; case kOrdered: return "Ordered"; case kOpaque: return "Opaque"; case kUnrolled: return "Unrolled"; case kVectorized: return "Vectorized"; case kParallelized: return "Parallelized"; case kTensorized: return "Tensorized"; } return "Unknown"; } /*! * \brief Construct a new Var expression * \param name_hint The name hint for the expression * \param t The type of the expression */ TVM_DLL Var var(const std::string& name_hint, Type t = Int(32)); /* * \brief Template function to convert Map to unordered_map * Sometimes useful for API gluing when internal uses unordered_map * \param dmap The container map * \return The corresponding unordered_map. * \tparam K the key of the Map. * \tparam V the value of the Map. */ template<typename K, typename V> inline std::unordered_map<K, V> as_unordered_map(const Map<K, V>& dmap) { std::unordered_map<K, V> ret; for (auto kv : dmap) { ret[kv.first] = kv.second; } return ret; } } // namespace tvm namespace std { template <> struct hash<::tvm::IterVar> { std::size_t operator()(const ::tvm::IterVar& k) const { return k.hash(); } }; } #endif // TVM_EXPR_H_