/*! * Copyright (c) 2017 by Contributors * \file lowered_func.h * \brief Information about a lowered TVM function. * This data structure is final step toward codegen. */ #ifndef TVM_LOWERED_FUNC_H_ #define TVM_LOWERED_FUNC_H_ #include <tvm/container.h> #include <ir/FunctionBase.h> #include <string> #include "./base.h" #include "./expr.h" #include "./tensor.h" namespace tvm { // Internal node container of lowered function. class LoweredFuncNode; /*! * \brief LoweredFunc represents function after lowering. * This is the final IR representation before codegen. */ class LoweredFunc : public FunctionRef { public: LoweredFunc() {} explicit LoweredFunc(std::shared_ptr<Node> n) : FunctionRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container */ inline const LoweredFuncNode* operator->() const; /*! \brief specify container node */ using ContainerType = LoweredFuncNode; }; /*! \brief Node container of LoweredFunc */ class LoweredFuncNode : public FunctionBaseNode { public: /*! \brief The name of the function */ std::string name; /*! * \brief The arguments of the function * This function can only take pod type(int, float) and void* as arguments. */ Array<Var> args; /*! * \brief The IterVar axis of threads * Each axis need host function to specify a size. * \note Calling convention into LoweredFunc * * Assume we have a LoweredFunc f, a call into f * Call(f, arg1, arg2, ..., arg_n, * size_axis_1, size_axis_2, ... size_axis_m) * * Here n = len(args), m = len(thread_axis) * * The CodeGen should take this and translate this call * to corresponding API specific kernel launchs or function calls. */ Array<IterVar> thread_axis; /*! * \brief The hint data type of Var handles defined in LetStmt * Can be used as hint when generating type signiture. * The creation rule is given by * handle_data_type[var_handle] = make_const(the_type, 0); * * \note Expr is used instead Type, because Type cannot be hold by Map. * constant Expr of given type is used. */ Map<Var, Expr> handle_data_type; /*! \brief Whether this function is packed function */ bool is_packed_func{true}; /*! \brief The body statment of the function */ Stmt body; /*! \return name of the operation */ const std::string& func_name() const final { return name; } // there is no return value, but return 1 // to enable Call into this function. int num_outputs() const final { return 1; } void VisitAttrs(AttrVisitor* v) final { v->Visit("name", &name); v->Visit("args", &args); v->Visit("thread_axis", &thread_axis); v->Visit("handle_data_type", &handle_data_type); v->Visit("is_packed_func", &is_packed_func); v->Visit("body", &body); } static constexpr const char* _type_key = "LoweredFunc"; TVM_DECLARE_NODE_TYPE_INFO(LoweredFuncNode, Node); }; // Implementations of inline functions inline const LoweredFuncNode* LoweredFunc::operator->() const { return static_cast<const LoweredFuncNode*>(node_.get()); } } // namespace tvm namespace std { template <> struct hash<::tvm::LoweredFunc> { std::size_t operator()(const ::tvm::LoweredFunc& k) const { return k.hash(); } }; } #endif // TVM_LOWERED_FUNC_H_