/*!
 *  Copyright (c) 2016 by Contributors
 * \file tvm/operation.h
 * \brief Operation node can generate one or multiple Tensors
 */
#ifndef TVM_OPERATION_H_
#define TVM_OPERATION_H_

#include <functional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "expr.h"
#include "ir_operator.h"
#include "tensor.h"
#include "schedule.h"
#include "arithmetic.h"
#include "buffer.h"

namespace tvm {

21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41
using arith::IntSet;

/*!
 * \brief Scratch structure that collects, for every axis of a Tensor,
 *  the IntSets whose union bounds that axis.
 */
struct TensorDom {
  /*! \brief Per-axis domain data: data[d] holds the IntSets recorded for axis d. */
  std::vector<std::vector<IntSet> > data;
  /*!
   * \brief Create a domain with ndim axes, each starting with no IntSets.
   * \param ndim Number of axes of the tensor.
   */
  explicit TensorDom(int ndim) : data(static_cast<size_t>(ndim)) {}
};

/*!
 * \brief Base class of all operation nodes
 */
class OperationNode : public FunctionBaseNode {
 public:
  /*! \brief optional name of the operation */
  std::string name;
42 43
  /*! \brief optional tag of the operation */
  std::string tag;
44 45
  /*! \brief addtitional attributes of the operation*/
  Map<std::string, NodeRef> attrs;
46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68
  /*! \return name of the operation */
  const std::string& func_name() const final {
    return name;
  }
  /*!
   * \return The list of iteration variable at root
   * \note root_iter_vars dedides the shape of the outputs.
   */
  virtual Array<IterVar> root_iter_vars() const = 0;
  /*!
   * \brief Get data type. i-th output tensor.
   * \param i The output index.
   * \return type of i-th output.
   */
  virtual Type output_dtype(size_t i) const = 0;
  /*!
   * \brief Get shape of i-th output tensor.
   * \param i The output index.
   * \return shape of i-th output.
   */
  virtual Array<Expr> output_shape(size_t i) const = 0;
  /*!
   * \brief List all the input Tensors.
69
   * \return List of input tensors.
70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108
   */
  virtual Array<Tensor> InputTensors() const = 0;
  /*!
   * \brief Replace the input of the operation by pattern specified by rmap.
   *
   * \param self The reference to self.
   * \param rmap The replacement map.
   * \return self if nothing is replaced, otherwise return replaced op.
   */
  virtual Operation ReplaceInputs(
      const Operation& self,
      const std::unordered_map<Tensor, Tensor>& rmap) const = 0;
  /*!
   * \brief Propagate the bounds to inputs
   * \param self The reference to self.
   * \param dom_map the domain map of Variables(corresponds to root_iter_vars)
   * \param out_dom_map The output domain.
   *  The function is only asked to fill the bounds for Tensors that
   *  is already in the out_dom_map
   */
  virtual void PropBoundToInputs(
      const Operation& self,
      const std::unordered_map<const Variable*, IntSet>& dom_map,
      std::unordered_map<Tensor, TensorDom>* out_dom_map) const = 0;
  /*!
   * \brief Gather the bound from output tensor.
   *  Set the range of each root_iter_vars in the op to out_dom_map
   *
   * \param self The reference to self.
   * \param tensor_dom Domain map of Tensor->access set of each dimension.
   * \param out_dom_map The output domain map of each IterVar to be setted.
   */
  virtual void GatherBound(
      const Operation& self,
      const std::unordered_map<Tensor, TensorDom>& tensor_dom,
      std::unordered_map<IterVar, Range>* out_dom_map) const = 0;
  /*!
   * \brief Build the Realize statement that realizes
   *   the op's output tensors.
109
   * \param stage the op's stage.
110 111 112 113 114
   * \param realize_map The realization domain map of the operators.
   * \param body The body that is going to get
   * \return A realization statement that wraps body.
   */
  virtual Stmt BuildRealize(
115
      const Stage& stage,
116 117 118 119 120 121
      const std::unordered_map<IterVar, Range>& realize_map,
      const Stmt& body) const = 0;
  /*!
   * \brief Build the statement that provide the output tensors.
   * \param stage The schedule stage of the op.
   * \param dom_map The domain map of all iteration domains.
122
   * \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1
123 124 125 126
   * \return A statement that add production and wraps consumer.
   */
  virtual Stmt BuildProvide(
      const Stage& stage,
127
      const std::unordered_map<IterVar, Range>& dom_map,
128
      bool debug_keep_trivial_loop) const = 0;
129 130 131 132 133 134

  static constexpr const char* _type_key = "Operation";

  TVM_DECLARE_BASE_NODE_INFO(OperationNode, Node);
};

tqchen committed
135
/*!
136 137 138 139 140 141 142 143
 * \brief A placeholder op represents an input placeholder.
 */
class PlaceholderOpNode : public OperationNode {
 public:
  /*! \brief The shape of the input */
  Array<Expr> shape;
  /*! \brief The data type of the input. */
  Type dtype;
144 145
  // override behavior.
  int num_outputs() const final;
146 147 148
  Array<IterVar> root_iter_vars() const final;
  Type output_dtype(size_t i) const final;
  Array<Expr> output_shape(size_t i) const final;
149 150 151 152 153 154 155 156 157 158 159 160 161
  Array<Tensor> InputTensors() const final;
  Operation ReplaceInputs(
      const Operation& self,
      const std::unordered_map<Tensor, Tensor>& rmap) const final;
  void PropBoundToInputs(
      const Operation& self,
      const std::unordered_map<const Variable*, IntSet>& dom_map,
      std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
  void GatherBound(
      const Operation& self,
      const std::unordered_map<Tensor, TensorDom>& tensor_dom,
      std::unordered_map<IterVar, Range>* out_dom_map) const final;
  Stmt BuildRealize(
162
      const Stage& stage,
163 164 165 166
      const std::unordered_map<IterVar, Range>& realize_map,
      const Stmt& body) const final;
  Stmt BuildProvide(
      const Stage& stage,
167
      const std::unordered_map<IterVar, Range>& dom_map,
168
      bool debug_keep_trivial_loop) const final;
169 170 171

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("name", &name);
172 173
    v->Visit("tag", &tag);
    v->Visit("attrs", &attrs);
174 175 176 177 178 179 180 181
    v->Visit("shape", &shape);
    v->Visit("dtype", &dtype);
  }
  static Operation make(std::string name,
                        Array<Expr> shape,
                        Type dtype);

  static constexpr const char* _type_key = "PlaceholderOp";
182
  TVM_DECLARE_NODE_TYPE_INFO(PlaceholderOpNode, OperationNode);
183 184 185
};

/*!
tqchen committed
186
 * \brief A Compute op that compute a tensor on certain domain.
tqchen committed
187
 */
webberg committed
188
class TVM_DLL ComputeOpNode : public OperationNode {
tqchen committed
189
 public:
190 191
  /*! \brief IterVar on each axis */
  Array<IterVar> axis;
192 193
  /*! \brief IterVar on each reduction axis, if the body is a Reduce */
  Array<IterVar> reduce_axis;
tqchen committed
194
  /*! \brief the compute expression */
195
  Array<Expr> body;
tqchen committed
196 197
  /*! \brief constructor */
  ComputeOpNode() {}
198 199
  // override functions
  int num_outputs() const final;
200
  Array<IterVar> root_iter_vars() const final;
tqchen committed
201 202
  Type output_dtype(size_t i) const final;
  Array<Expr> output_shape(size_t i) const final;
203 204 205 206 207 208 209 210 211 212 213 214 215
  Array<Tensor> InputTensors() const final;
  Operation ReplaceInputs(
      const Operation& self,
      const std::unordered_map<Tensor, Tensor>& rmap) const final;
  void PropBoundToInputs(
      const Operation& self,
      const std::unordered_map<const Variable*, IntSet>& dom_map,
      std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
  void GatherBound(
      const Operation& self,
      const std::unordered_map<Tensor, TensorDom>& tensor_dom,
      std::unordered_map<IterVar, Range>* out_dom_map) const final;
  Stmt BuildRealize(
216
      const Stage& stage,
217 218 219 220
      const std::unordered_map<IterVar, Range>& realize_map,
      const Stmt& body) const final;
  Stmt BuildProvide(
      const Stage& stage,
221
      const std::unordered_map<IterVar, Range>& dom_map,
222
      bool debug_keep_trivial_loop) const final;
tqchen committed
223

tqchen committed
224 225
  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("name", &name);
226
    v->Visit("tag", &tag);
227
    v->Visit("attrs", &attrs);
228
    v->Visit("axis", &axis);
229
    v->Visit("reduce_axis", &reduce_axis);
tqchen committed
230 231
    v->Visit("body", &body);
  }
232
  static Operation make(std::string name,
233
                        std::string tag,
234
                        Map<std::string, NodeRef> attrs,
235
                        Array<IterVar> axis,
236
                        Array<Expr> body);
237 238

  static constexpr const char* _type_key = "ComputeOp";
239
  TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode, OperationNode);
tqchen committed
240 241
};

242 243 244 245 246 247 248 249 250 251 252 253 254 255
/*!
 * \brief Symbolic scan.
 */
class ScanOpNode : public OperationNode {
 public:
  /*! \brief IterVar to scan over */
  IterVar scan_axis;
  /*! \brief the initialization tensors */
  Array<Tensor> init;
  /*! \brief the update function represented by tensor */
  Array<Tensor> update;
  /*! \brief The placeholder to refer as states in update. */
  Array<Tensor> state_placeholder;
  /*!
256 257 258 259 260
   * \brief the inputs to the scan, these are optionally provided
   *  But they can be helpful to provide hints to speedup get of scan body.
   */
  Array<Tensor> inputs;
  /*!
261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276
   * \brief Spatial axis to indicate spatial dimension of each output.
   *  They corresponds to flattened spatial axis of the outputs.
   *
   *  [output[0].axis[1], output[0].axis[2]... output[k].axis[j]...]
   *  These are auxiliary data structure for storing result of bound inference.
   *  They do not corresponds to splittable iterations, thus the name comes
   *  with underscore.
   */
  Array<IterVar> spatial_axis_;
  /*! \brief constructor */
  ScanOpNode() {}
  // override behavior.
  int num_outputs() const final;
  Array<IterVar> root_iter_vars() const final;
  Type output_dtype(size_t i) const final;
  Array<Expr> output_shape(size_t i) const final;
277 278 279 280 281 282 283 284 285 286 287 288 289
  Array<Tensor> InputTensors() const final;
  Operation ReplaceInputs(
      const Operation& self,
      const std::unordered_map<Tensor, Tensor>& rmap) const final;
  void PropBoundToInputs(
      const Operation& self,
      const std::unordered_map<const Variable*, IntSet>& dom_map,
      std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
  void GatherBound(
      const Operation& self,
      const std::unordered_map<Tensor, TensorDom>& tensor_dom,
      std::unordered_map<IterVar, Range>* out_dom_map) const final;
  Stmt BuildRealize(
290
      const Stage& stage,
291 292 293 294
      const std::unordered_map<IterVar, Range>& realize_map,
      const Stmt& body) const final;
  Stmt BuildProvide(
      const Stage& stage,
295
      const std::unordered_map<IterVar, Range>& dom_map,
296
      bool debug_keep_trivial_loop) const final;
297 298 299

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("name", &name);
300
    v->Visit("tag", &tag);
301
    v->Visit("attrs", &attrs);
302 303 304 305
    v->Visit("scan_axis", &scan_axis);
    v->Visit("init", &init);
    v->Visit("update", &update);
    v->Visit("state_placeholder", &state_placeholder);
306
    v->Visit("inputs", &inputs);
307 308 309
    v->Visit("spatial_axis_", &spatial_axis_);
  }
  static Operation make(std::string name,
310
                        std::string tag,
311
                        Map<std::string, NodeRef> attrs,
312 313 314
                        IterVar axis,
                        Array<Tensor> init,
                        Array<Tensor> update,
315 316
                        Array<Tensor> state_placeholder,
                        Array<Tensor> input);
317 318

  static constexpr const char* _type_key = "ScanOp";
319
  TVM_DECLARE_NODE_TYPE_INFO(ScanOpNode, OperationNode);
320 321
};

322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355
/*!
 * \brief External computation that cannot be splitted.
 */
class ExternOpNode : public OperationNode {
 public:
  /*! \brief The input tensors */
  Array<Tensor> inputs;
  /*! \brief Symbolic placeholder representationinputs */
  Array<Buffer> input_placeholders;
  /*! \brief Symbolic placeholder representation of outputs */
  Array<Buffer> output_placeholders;
  /*! \brief the statement that generates the computation. */
  Stmt body;

  /*! \brief constructor */
  ExternOpNode() {}
  // override functions
  int num_outputs() const final;
  Array<IterVar> root_iter_vars() const final;
  Type output_dtype(size_t i) const final;
  Array<Expr> output_shape(size_t i) const final;
  Array<Tensor> InputTensors() const final;
  Operation ReplaceInputs(
      const Operation& self,
      const std::unordered_map<Tensor, Tensor>& rmap) const final;
  void PropBoundToInputs(
      const Operation& self,
      const std::unordered_map<const Variable*, IntSet>& dom_map,
      std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
  void GatherBound(
      const Operation& self,
      const std::unordered_map<Tensor, TensorDom>& tensor_dom,
      std::unordered_map<IterVar, Range>* out_dom_map) const final;
  Stmt BuildRealize(
356
      const Stage& stage,
357 358 359 360
      const std::unordered_map<IterVar, Range>& realize_map,
      const Stmt& body) const final;
  Stmt BuildProvide(
      const Stage& stage,
361
      const std::unordered_map<IterVar, Range>& dom_map,
362
      bool debug_keep_trivial_loop) const final;
363 364 365

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("name", &name);
366
    v->Visit("tag", &tag);
367
    v->Visit("attrs", &attrs);
368
    v->Visit("inputs", &inputs);
369 370
    v->Visit("input_placeholders", &input_placeholders);
    v->Visit("output_placeholders", &output_placeholders);
371 372
    v->Visit("body", &body);
  }
373
  EXPORT static Operation make(std::string name,
374 375 376 377 378 379
                               std::string tag,
                               Map<std::string, NodeRef> attrs,
                               Array<Tensor> inputs,
                               Array<Buffer> input_placeholders,
                               Array<Buffer> output_placeholders,
                               Stmt body);
380 381 382 383

  static constexpr const char* _type_key = "ExternOp";
  TVM_DECLARE_NODE_TYPE_INFO(ExternOpNode, OperationNode);
};
tqchen committed
384 385 386 387

/*! \brief The compute function to specify the input source of a Tensor */
using FCompute = std::function<Expr (const Array<Var>& i)>;

388 389 390
/*! \brief The compute function to specify the inputs source of Tensors */
using FBatchCompute = std::function<Array<Expr> (const Array<Var>& i)>;

tqchen committed
391
/*!
392 393 394 395 396
 * \brief create a place holder tensor.
 * \param shape The shape of the tensor.
 * \param dtype the data type of the tensor.
 * \param name The name of the Tensor.
 */
397 398 399
TVM_DLL Tensor placeholder(Array<Expr> shape,
                           Type dtype = Float(32),
                           std::string name = "placeholder");
400 401

/*!
tqchen committed
402 403 404 405 406
 * \brief Construct a new tensor by computing over shape,
 *  using the computation rule: result_tensor[axis] = fcompute(axis)
 * \param shape Shape of the tensor.
 * \param fcompute The compute function to create the tensor.
 * \param name The optional name of the tensor.
407
 * \param tag The optional tag of the tensor.
408
 * \param attrs Optional additional attributes of the compute.
tqchen committed
409
 */
410 411 412
TVM_DLL Tensor compute(Array<Expr> shape,
                       FCompute fcompute,
                       std::string name = "tensor",
413 414
                       std::string tag = "",
                       Map<std::string, NodeRef> attrs = {});
tqchen committed
415

416
/*!
417 418 419 420 421
 * \brief Construct a new tensor by computing over shape,
 *  using the computation rule: result_tensor[axis] = fcompute(axis)
 * \param shape Shape of the tensor.
 * \param fcompute The compute function to create the tensors.
 * \param name The optional name of the tensor.
422
 * \param tag The optional tag of the tensor.
423
 * \param attrs Optional additional attributes of the compute.
424
 */
425 426 427
TVM_DLL Array<Tensor> compute(Array<Expr> shape,
                              FBatchCompute fcompute,
                              std::string name = "tensor",
428 429
                              std::string tag = "",
                              Map<std::string, NodeRef> attrs = {});
430 431

/*!
432
 * \brief Construct new tensors by scan.
433 434 435 436
 *
 * \param init The intialize tensor of first K steps.
 * \param update The update tensor indicated the updated result after each timestamp.
 * \param state_placeholder The placeholder for the states.
437 438
 * \param inputs The inputs to the scan body, this is optional,
 *    but recommended to provide concrete information about scan body.
439
 * \param name The optional name of the tensor.
440
 * \param tag The optional tag of the tensor.
441
 * \param attrs Optional additional attributes of the compute.
442
 */
443 444 445 446 447
TVM_DLL Array<Tensor> scan(Array<Tensor> init,
                           Array<Tensor> update,
                           Array<Tensor> state_placeholder,
                           Array<Tensor> inputs = Array<Tensor>(),
                           std::string name = "scan",
448 449
                           std::string tag = "",
                           Map<std::string, NodeRef> attrs = {});
450

tqchen committed
451
// same as compute, specialized for different fcompute function
452
inline Tensor compute(Array<Expr> shape,
tqchen committed
453
                      std::function<Expr(Var)> f,
454
                      std::string name = "tensor",
455 456
                      std::string tag = "",
                      Map<std::string, NodeRef> attrs = {}) {
tqchen committed
457
  FCompute fc = [f] (const Array<Var>& i) { return f(i[0]); };
458
  return compute(shape, fc, name, tag, attrs);
tqchen committed
459
}
460
inline Tensor compute(Array<Expr> shape,
tqchen committed
461
                      std::function<Expr(Var, Var)> f,
462
                      std::string name = "tensor",
463 464
                      std::string tag = "",
                      Map<std::string, NodeRef> attrs = {}) {
tqchen committed
465
  FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1]); };
466
  return compute(shape, fc, name, tag, attrs);
tqchen committed
467
}
468
inline Tensor compute(Array<Expr> shape,
tqchen committed
469
                      std::function<Expr(Var, Var, Var)> f,
470
                      std::string name = "tensor",
471 472
                      std::string tag = "",
                      Map<std::string, NodeRef> attrs = {}) {
tqchen committed
473
  FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2]); };
474
  return  compute(shape, fc, name, tag, attrs);
tqchen committed
475
}
476
inline Tensor compute(Array<Expr> shape,
tqchen committed
477
                      std::function<Expr(Var, Var, Var, Var)> f,
478
                      std::string name = "tensor",
479 480
                      std::string tag = "",
                      Map<std::string, NodeRef> attrs = {}) {
tqchen committed
481
  FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2], i[3]); };
482
  return compute(shape, fc, name, tag, attrs);
tqchen committed
483 484
}

485 486 487 488
// inline function.
inline const OperationNode* Operation::operator->() const {
  return static_cast<const OperationNode*>(node_.get());
}
tqchen committed
489 490
}  // namespace tvm
#endif  // TVM_OPERATION_H_