/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file tvm/expr.h * \brief The Expr and related elements in DataFlow construction. */ #ifndef TVM_EXPR_H_ #define TVM_EXPR_H_ #include <string> #include <algorithm> #include <unordered_map> #include <iostream> #include "base.h" #include "dtype.h" #include "node/node.h" #include "node/container.h" #include "node/ir_functor.h" #include "runtime/c_runtime_api.h" namespace tvm { /*! \brief Base node of all expressions. */ class ExprNode : public Node { public: /*! \brief The data type of the expression. */ DataType type; static constexpr const char* _type_key = "Expr"; TVM_DECLARE_BASE_NODE_INFO(ExprNode, Node); }; /*! \brief Container of all expressions. */ class Expr : public NodeRef { public: Expr() {} explicit Expr(ObjectPtr<Object> ptr) : NodeRef(ptr) {} /*! * \brief construct from integer. * \param value The value to be constructed. */ TVM_DLL Expr(int32_t value); // NOLINT(*) /*! * \brief construct from float. * \param value The value to be constructed. */ TVM_DLL Expr(float value); // NOLINT(*) /*! * \brief construct from string. * \param str The value to be constructed. */ TVM_DLL Expr(std::string str); // NOLINT(*) /*! \return the data type of this expression. */ DataType type() const { return static_cast<const ExprNode*>(get())->type; } /*! \brief type indicate the container type */ using ContainerType = ExprNode; }; /*! \brief Base node of all statements. */ class StmtNode : public Node { public: static constexpr const char* _type_key = "Stmt"; TVM_DECLARE_BASE_NODE_INFO(StmtNode, Node); }; /*! \brief Container of all statements */ class Stmt : public NodeRef { public: TVM_DEFINE_NODE_REF_METHODS(Stmt, NodeRef, StmtNode); }; class Var; /*! * \brief A variable node in the IR. * * A variable is uniquely identified by its address. * * Each variable is only binded once in the following nodes: * - Allocate * - For * - Let * - LetStmt */ class Variable : public ExprNode { public: /*! * \brief The hint to the variable name. * \note Each variable is uniquely identified by its address. */ std::string name_hint; static Var make(DataType dtype, std::string name_hint); void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &type); v->Visit("name", &name_hint); } static constexpr const char* _type_key = "Variable"; TVM_DECLARE_NODE_TYPE_INFO(Variable, ExprNode); }; /*! \brief a named variable in TVM */ class Var : public Expr { public: explicit Var(ObjectPtr<Object> n) : Expr(n) {} TVM_DLL explicit Var(std::string name_hint = "v", Type t = Int(32)); /*! * \brief Make a new copy of var with same type, append suffix * \param suffix The suffix to be appended. * \return the new Var copy */ Var copy_with_suffix(const std::string& suffix) const { return Var((*this)->name_hint + suffix, (*this)->type); } /*! * \brief Get pointer to the internal value. * \return the corresponding Variable. */ const Variable* operator->() const { return get(); } /*! * \brief Get pointer to the internal value. * \return the corresponding Variable. */ const Variable* get() const { return static_cast<const Variable*>(data_.get()); } /*! \brief type indicate the container type */ using ContainerType = Variable; }; // Backward compatibility, will be removed later. using VarExpr = Var; using BaseExprNode = ExprNode; using ExprHash = NodeHash; using ExprEqual = NodeEqual; class Integer; /*! \brief ExprNode: constant integer. */ class IntImm : public ExprNode { public: /*! \brief the Internal value. */ int64_t value; void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &type); v->Visit("value", &value); } TVM_DLL static Integer make(DataType t, int64_t value); static constexpr const char* _type_key = "IntImm"; TVM_DECLARE_NODE_TYPE_INFO(IntImm, ExprNode); }; /*! * \brief Container of constant integer (IntImm). * * This is used to store and automate type check * attributes that must be constant integer. */ class Integer : public Expr { public: Integer() : Expr() {} /*! * \brief constructor from node. */ explicit Integer(ObjectPtr<Object> node) : Expr(node) {} /*! * \brief Construct integer from int value. */ Integer(int value) : Expr(value) {} // NOLINT(*) /*! * \brief Assign an expression to integer. * \param other another expression. */ Integer& operator=(const Integer& other) { data_ = other.data_; return *this; } /*! * \brief Get pointer to the internal value. * \return the content of the integer. */ const IntImm* operator->() const { return static_cast<const IntImm*>(get()); } /*! * \brief convert to int64_t */ operator int64_t() const { CHECK(data_ != nullptr) << " Trying to reference a null Integer"; return (*this)->value; } /*! \brief type indicate the container type */ using ContainerType = IntImm; }; /*! \brief range over one dimension */ class RangeNode : public Node { public: /*! \brief beginning of the node */ Expr min; /*! \brief the extend of range */ Expr extent; /*! \brief constructor */ RangeNode() {} RangeNode(Expr min, Expr extent) : min(min), extent(extent) {} void VisitAttrs(AttrVisitor* v) { v->Visit("min", &min); v->Visit("extent", &extent); } static constexpr const char* _type_key = "Range"; TVM_DECLARE_NODE_TYPE_INFO(RangeNode, Node); }; /*! \brief Range constainer */ class Range : public NodeRef { public: /*! * \brief constructor by begin and end * \param begin The begin of the range. * \param end The end of the range. */ TVM_DLL Range(Expr begin, Expr end); /*! * \brief construct a new range with min and extent * The corresponding constructor is removed, * because that is counter convention of tradition meaning * of range(begin, end) * * \param min The minimum range. * \param extent The extent of the range. */ static Range make_by_min_extent(Expr min, Expr extent); // declare range. TVM_DEFINE_NODE_REF_METHODS(Range, NodeRef, RangeNode); }; /*! \brief container class of iteration variable. */ class IterVarNode; using Region = Array<Range>; /*! * \brief Type of iteration variable. * Each IterVar have a specific type. * * The type of iter var can be overriden via * stage.iter_var_attrs given they are compatible. */ enum IterVarType : int { /*! * \brief Data parallel iteration. * This normally corresponds to axis of Tensor. * Allow all IterVar manipulations. * * \note This does not mean the loop * have to be executed in parallel fashion. */ kDataPar = 0, /*! * \brief The IterVar itself is a thread-index * of a fixed thread launching group. * Note that this is already assumed to be paralellized. * * Disallow: split/fuse/vectorize/parallel */ kThreadIndex = 1, /*! * \brief Communicative reduction. * Cannot be directly parallelized. * * Disallow: parallel/vectorize */ kCommReduce = 2, /*! * \brief Serial loops with loop carry dependency, * the iteration must execute in order. * Cannot be re-ordered. * * Disallow: reorder/parallel/vectorize */ kOrdered = 3, /*! * \brief IterVar is opaque, * * May not corresponds to any generated loop * Disallow all IterVar manipulations and compute_at * * \note This is usually used to implement composite op * or external op, where the */ kOpaque = 4, // The following are possible additional // types that are provided during schedule /*! * \brief The execution is unrolled. */ kUnrolled = 5, /*! * \brief The loop is vectorized. */ kVectorized = 6, /*! * \brief The loop is parallelized. */ kParallelized = 7, /*! * \brief Marks boundary of tensorization intrinsic. */ kTensorized = 8 }; /*! * \brief Iteration Variable, * represents an iteration over an integer interval. */ class IterVar : public NodeRef { public: // construct a new iter var without a domain IterVar() {} // construct from shared ptr. explicit IterVar(ObjectPtr<Object> n) : NodeRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container */ inline const IterVarNode* operator->() const; /*! * \return the corresponding var in the IterVar. */ inline operator Expr() const; /*! \brief specify container node */ using ContainerType = IterVarNode; }; /*! * \brief Create a new IterVar that represents an axis in thread. * * \param dom Optional, domain of the thread axis. * \param tag The thread tag of the axis. */ TVM_DLL IterVar thread_axis(Range dom, std::string tag); /*! * \brief Create a new IterVar for reduction operations. * * \param dom The domain of the reduction axis. * \param name The name of the reduction axis. */ TVM_DLL IterVar reduce_axis(Range dom, std::string name = "rv"); using Domain = Array<Range>; /*! * \brief Dump the node to stderr, used for debug purposes. * \param node The input node */ TVM_DLL void Dump(const NodeRef& node); // definition of Node. /*! * \brief An iteration variable representing an iteration * over a one dimensional interval. */ class IterVarNode : public Node { public: /*! * \brief the domain of iteration, if known, can be None * For the intermediate schedule node, before schedule. */ Range dom; /*! \brief The looping variable */ Var var; /*! \brief The type of the IterVar */ IterVarType iter_type; /*! * \brief additional tag on the iteration variable, * set this if this is binded already to a known thread tag. */ std::string thread_tag; void VisitAttrs(AttrVisitor* v) { v->Visit("dom", &dom); v->Visit("var", &var); v->Visit("iter_type", &iter_type); v->Visit("thread_tag", &thread_tag); } TVM_DLL static IterVar make(Range dom, Var var, IterVarType iter_type, std::string thread_tag = ""); static constexpr const char* _type_key = "IterVar"; TVM_DECLARE_NODE_TYPE_INFO(IterVarNode, Node); }; // inline implementations inline const IterVarNode* IterVar::operator->() const { return static_cast<const IterVarNode*>(data_.get()); } inline IterVar::operator Expr() const { return (*this)->var; } inline const char* IterVarType2String(IterVarType t) { switch (t) { case kDataPar: return "DataPar"; case kThreadIndex: return "ThreadIndex"; case kCommReduce: return "CommReduce"; case kOrdered: return "Ordered"; case kOpaque: return "Opaque"; case kUnrolled: return "Unrolled"; case kVectorized: return "Vectorized"; case kParallelized: return "Parallelized"; case kTensorized: return "Tensorized"; } return "Unknown"; } /*! * \brief Construct a new Var expression * \param name_hint The name hint for the expression * \param t The type of the expression */ TVM_DLL Var var(std::string name_hint, Type t = Int(32)); /* * \brief Template function to convert Map to unordered_map * Sometimes useful for API gluing when internal uses unordered_map * \param dmap The container map * \return The corresponding unordered_map. * \tparam K the key of the Map. * \tparam V the value of the Map. */ template<typename K, typename V> inline std::unordered_map<K, V> as_unordered_map(const Map<K, V>& dmap) { std::unordered_map<K, V> ret; for (auto kv : dmap) { ret[kv.first] = kv.second; } return ret; } // Printer infra. /*! \brief A Pretty printer class to print the IR. */ class IRPrinter { public: /*! \brief The output stream */ std::ostream& stream; /*! \brief The indentation level. */ int indent{0}; explicit IRPrinter(std::ostream& stream) // NOLINT(*) : stream(stream) {} /*! \brief The node to be printed. */ TVM_DLL void Print(const ObjectRef& node); /*! \brief Print indent to the stream */ TVM_DLL void PrintIndent(); // Allow registration to be printer. using FType = IRFunctor<void(const ObjectRef&, IRPrinter *)>; TVM_DLL static FType& vtable(); }; // default print function for all nodes inline std::ostream& operator<<(std::ostream& os, const ObjectRef& n) { // NOLINT(*) IRPrinter(os).Print(n); return os; } } // namespace tvm namespace std { template <> struct hash<::tvm::IterVar> : public ::tvm::NodeHash { }; } #endif // TVM_EXPR_H_