tensor_intrin.h 2.77 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93
/*!
 *  Copyright (c) 2017 by Contributors
 * \file tensor_intrin.h
 * \brief Tensor intrinsic operations.
 */
#ifndef TVM_TENSOR_INTRIN_H_
#define TVM_TENSOR_INTRIN_H_

#include <string>
#include "./tensor.h"
#include "./buffer.h"

namespace tvm {

// Internal node container of tensor intrinsics.
class TensorIntrinNode;

/*! \brief Tensor intrinsic node. */
class TensorIntrin : public NodeRef {
 public:
  TensorIntrin() {}
  explicit TensorIntrin(std::shared_ptr<Node> n) : NodeRef(n) {}
  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline const TensorIntrinNode* operator->() const;

  /*! \brief specify container node */
  using ContainerType = TensorIntrinNode;
};

/*! \brief Node to represent a Tensor intrinsic operator */
class TensorIntrinNode : public Node {
 public:
  /*! \brief The name of the intrinsic */
  std::string name;
  /*! \brief The operation this intrinsics is carrying out */
  Operation op;
  /*! \brief List of inputs of operator, placeholder in postdfs order */
  Array<Tensor> inputs;
  /*!
   * \brief Symbolic buffers of each output/input tensor
   *  buffers[0:len(inputs)] are buffers of the inputs.
   *  buffers[len(inputs):] are buffers of each output.
   *
   * \note When a field in Buffer is Var, it means we can be flexible
   *  wrt that field and Var can occur in body.
   *  When it is a constant, it means we can only take data in that shape.
   */
  Array<Buffer> buffers;
  /*! \brief The normal statement to execute the intrinsic */
  Stmt body;
  /*!
   * \brief Special statement for reduction op, can be None
   *  reset the value of output buffer to identity value.
   */
  Stmt reduce_init;
  /*!
   * \brief Special statement for reduction op, can be None
   *  Reduce: do a reduction of current output buffer with the result.
   */
  Stmt reduce_update;
  /*! \brief constructor */
  TensorIntrinNode() {}

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("name", &name);
    v->Visit("op", &op);
    v->Visit("inputs", &inputs);
    v->Visit("buffers", &buffers);
    v->Visit("body", &body);
    v->Visit("reduce_init", &reduce_init);
    v->Visit("reduce_update", &reduce_update);
  }

  static TensorIntrin make(std::string name,
                           Operation op,
                           Array<Tensor> inputs,
                           Array<Buffer> buffers,
                           Stmt body,
                           Stmt reduce_init,
                           Stmt reduce_update);

  static constexpr const char* _type_key = "TensorIntrin";
  TVM_DECLARE_NODE_TYPE_INFO(TensorIntrinNode, Node);
};

inline const TensorIntrinNode* TensorIntrin::operator->() const {
  return static_cast<const TensorIntrinNode*>(node_.get());
}
}  // namespace tvm
#endif  // TVM_TENSOR_INTRIN_H_