/*! * Copyright (c) 2017 by Contributors * \file tensor_intrin.h * \brief Tensor intrinsic operations. */ #ifndef TVM_TENSOR_INTRIN_H_ #define TVM_TENSOR_INTRIN_H_ #include <string> #include "./tensor.h" #include "./buffer.h" namespace tvm { // Internal node container of tensor intrinsics. class TensorIntrinNode; /*! \brief Tensor intrinsic node. */ class TensorIntrin : public NodeRef { public: TensorIntrin() {} explicit TensorIntrin(std::shared_ptr<Node> n) : NodeRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container */ inline const TensorIntrinNode* operator->() const; /*! \brief specify container node */ using ContainerType = TensorIntrinNode; }; /*! \brief Node to represent a Tensor intrinsic operator */ class TensorIntrinNode : public Node { public: /*! \brief The name of the intrinsic */ std::string name; /*! \brief The operation this intrinsics is carrying out */ Operation op; /*! \brief List of inputs of operator, placeholder in postdfs order */ Array<Tensor> inputs; /*! * \brief Symbolic buffers of each output/input tensor * buffers[0:len(inputs)] are buffers of the inputs. * buffers[len(inputs):] are buffers of each output. * * \note When a field in Buffer is Var, it means we can be flexible * wrt that field and Var can occur in body. * When it is a constant, it means we can only take data in that shape. */ Array<Buffer> buffers; /*! \brief The normal statement to execute the intrinsic */ Stmt body; /*! * \brief Special statement for reduction op, can be None * reset the value of output buffer to identity value. */ Stmt reduce_init; /*! * \brief Special statement for reduction op, can be None * Reduce: do a reduction of current output buffer with the result. */ Stmt reduce_update; /*! \brief constructor */ TensorIntrinNode() {} void VisitAttrs(AttrVisitor* v) final { v->Visit("name", &name); v->Visit("op", &op); v->Visit("inputs", &inputs); v->Visit("buffers", &buffers); v->Visit("body", &body); v->Visit("reduce_init", &reduce_init); v->Visit("reduce_update", &reduce_update); } static TensorIntrin make(std::string name, Operation op, Array<Tensor> inputs, Array<Buffer> buffers, Stmt body, Stmt reduce_init, Stmt reduce_update); static constexpr const char* _type_key = "TensorIntrin"; TVM_DECLARE_NODE_TYPE_INFO(TensorIntrinNode, Node); }; inline const TensorIntrinNode* TensorIntrin::operator->() const { return static_cast<const TensorIntrinNode*>(node_.get()); } } // namespace tvm #endif // TVM_TENSOR_INTRIN_H_