/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/operation.h
 * \brief Operation nodes; each operation can generate one or multiple Tensors.
 */
#ifndef TVM_OPERATION_H_
#define TVM_OPERATION_H_

#include <functional>
#include <string>
#include <unordered_map>
#include <vector>

#include "expr.h"
#include "expr_operator.h"
#include "tensor.h"
#include "schedule.h"
#include "arithmetic.h"
#include "buffer.h"

namespace tvm {

39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59
using arith::IntSet;

/*!
 * \brief Temporary data structure to store union
 *  of bounds of each axis of Tensor.
 */
struct TensorDom {
  // constructor
  explicit TensorDom(int ndim)
      : data(ndim) {}
  /*! \brief The domain data */
  std::vector<std::vector<IntSet> > data;
};

/*!
 * \brief Base class of all operation nodes
 */
class OperationNode : public FunctionBaseNode {
 public:
  /*! \brief optional name of the operation */
  std::string name;
60 61
  /*! \brief optional tag of the operation */
  std::string tag;
62 63
  /*! \brief addtitional attributes of the operation*/
  Map<std::string, NodeRef> attrs;
64 65 66 67 68 69
  /*! \return name of the operation */
  const std::string& func_name() const final {
    return name;
  }
  /*!
   * \return The list of iteration variable at root
70
   * \note root_iter_vars decides the shape of the outputs.
71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86
   */
  virtual Array<IterVar> root_iter_vars() const = 0;
  /*!
   * \brief Get data type. i-th output tensor.
   * \param i The output index.
   * \return type of i-th output.
   */
  virtual Type output_dtype(size_t i) const = 0;
  /*!
   * \brief Get shape of i-th output tensor.
   * \param i The output index.
   * \return shape of i-th output.
   */
  virtual Array<Expr> output_shape(size_t i) const = 0;
  /*!
   * \brief List all the input Tensors.
87
   * \return List of input tensors.
88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126
   */
  virtual Array<Tensor> InputTensors() const = 0;
  /*!
   * \brief Replace the input of the operation by pattern specified by rmap.
   *
   * \param self The reference to self.
   * \param rmap The replacement map.
   * \return self if nothing is replaced, otherwise return replaced op.
   */
  virtual Operation ReplaceInputs(
      const Operation& self,
      const std::unordered_map<Tensor, Tensor>& rmap) const = 0;
  /*!
   * \brief Propagate the bounds to inputs
   * \param self The reference to self.
   * \param dom_map the domain map of Variables(corresponds to root_iter_vars)
   * \param out_dom_map The output domain.
   *  The function is only asked to fill the bounds for Tensors that
   *  is already in the out_dom_map
   */
  virtual void PropBoundToInputs(
      const Operation& self,
      const std::unordered_map<const Variable*, IntSet>& dom_map,
      std::unordered_map<Tensor, TensorDom>* out_dom_map) const = 0;
  /*!
   * \brief Gather the bound from output tensor.
   *  Set the range of each root_iter_vars in the op to out_dom_map
   *
   * \param self The reference to self.
   * \param tensor_dom Domain map of Tensor->access set of each dimension.
   * \param out_dom_map The output domain map of each IterVar to be setted.
   */
  virtual void GatherBound(
      const Operation& self,
      const std::unordered_map<Tensor, TensorDom>& tensor_dom,
      std::unordered_map<IterVar, Range>* out_dom_map) const = 0;
  /*!
   * \brief Build the Realize statement that realizes
   *   the op's output tensors.
127
   * \param stage the op's stage.
128 129 130 131 132
   * \param realize_map The realization domain map of the operators.
   * \param body The body that is going to get
   * \return A realization statement that wraps body.
   */
  virtual Stmt BuildRealize(
133
      const Stage& stage,
134 135 136 137 138 139
      const std::unordered_map<IterVar, Range>& realize_map,
      const Stmt& body) const = 0;
  /*!
   * \brief Build the statement that provide the output tensors.
   * \param stage The schedule stage of the op.
   * \param dom_map The domain map of all iteration domains.
140
   * \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1
141 142 143 144
   * \return A statement that add production and wraps consumer.
   */
  virtual Stmt BuildProvide(
      const Stage& stage,
145
      const std::unordered_map<IterVar, Range>& dom_map,
146
      bool debug_keep_trivial_loop) const = 0;
147 148 149 150 151 152

  static constexpr const char* _type_key = "Operation";

  TVM_DECLARE_BASE_NODE_INFO(OperationNode, Node);
};

tqchen committed
153
/*!
154 155 156 157 158 159 160 161
 * \brief A placeholder op represents an input placeholder.
 */
class PlaceholderOpNode : public OperationNode {
 public:
  /*! \brief The shape of the input */
  Array<Expr> shape;
  /*! \brief The data type of the input. */
  Type dtype;
162 163
  // override behavior.
  int num_outputs() const final;
164 165 166
  Array<IterVar> root_iter_vars() const final;
  Type output_dtype(size_t i) const final;
  Array<Expr> output_shape(size_t i) const final;
167 168 169 170 171 172 173 174 175 176 177 178 179
  Array<Tensor> InputTensors() const final;
  Operation ReplaceInputs(
      const Operation& self,
      const std::unordered_map<Tensor, Tensor>& rmap) const final;
  void PropBoundToInputs(
      const Operation& self,
      const std::unordered_map<const Variable*, IntSet>& dom_map,
      std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
  void GatherBound(
      const Operation& self,
      const std::unordered_map<Tensor, TensorDom>& tensor_dom,
      std::unordered_map<IterVar, Range>* out_dom_map) const final;
  Stmt BuildRealize(
180
      const Stage& stage,
181 182 183 184
      const std::unordered_map<IterVar, Range>& realize_map,
      const Stmt& body) const final;
  Stmt BuildProvide(
      const Stage& stage,
185
      const std::unordered_map<IterVar, Range>& dom_map,
186
      bool debug_keep_trivial_loop) const final;
187 188 189

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("name", &name);
190 191
    v->Visit("tag", &tag);
    v->Visit("attrs", &attrs);
192 193 194 195 196 197 198 199
    v->Visit("shape", &shape);
    v->Visit("dtype", &dtype);
  }
  static Operation make(std::string name,
                        Array<Expr> shape,
                        Type dtype);

  static constexpr const char* _type_key = "PlaceholderOp";
200
  TVM_DECLARE_NODE_TYPE_INFO(PlaceholderOpNode, OperationNode);
201 202 203
};

/*!
tqchen committed
204
 * \brief A Compute op that compute a tensor on certain domain.
205 206
 * This is the base class for ComputeOp (operating on a scalar at a time) and
 * TensorComputeOp (operating on a TensorSlice at a time)
tqchen committed
207
 */
208
class TVM_DLL BaseComputeOpNode : public OperationNode {
tqchen committed
209
 public:
210 211
  /*! \brief IterVar on each axis */
  Array<IterVar> axis;
212 213
  /*! \brief IterVar on each reduction axis, if the body is a Reduce */
  Array<IterVar> reduce_axis;
214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236
  // override functions
  Array<IterVar> root_iter_vars() const final;
  Array<Expr> output_shape(size_t idx) const final;
  void GatherBound(
          const Operation& self,
          const std::unordered_map<Tensor, TensorDom>& tensor_dom,
          std::unordered_map<IterVar, Range>* out_dom_map) const final;
  Stmt BuildRealize(
          const Stage& stage,
          const std::unordered_map<IterVar, Range>& realize_map,
          const Stmt& body) const final;
  virtual size_t num_schedulable_dims() const = 0;

  static constexpr const char* _type_key = "BaseComputeOp";
  TVM_DECLARE_BASE_NODE_INFO(BaseComputeOpNode, OperationNode);
};


/*!
 * \brief A Compute op that computes a tensor on a certain domain.
 */
class TVM_DLL ComputeOpNode : public BaseComputeOpNode {
 public:
  /*! \brief the compute expression */
  Array<Expr> body;
  /*! \brief constructor */
  ComputeOpNode() {}
  // override functions
  int num_outputs() const final;
  Type output_dtype(size_t i) const final;
  Array<Tensor> InputTensors() const final;
  Operation ReplaceInputs(
      const Operation& self,
      const std::unordered_map<Tensor, Tensor>& rmap) const final;
  void PropBoundToInputs(
      const Operation& self,
      const std::unordered_map<const Variable*, IntSet>& dom_map,
      std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
  Stmt BuildProvide(
      const Stage& stage,
      const std::unordered_map<IterVar, Range>& dom_map,
      bool debug_keep_trivial_loop) const final;
  size_t num_schedulable_dims() const final;

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("name", &name);
    v->Visit("tag", &tag);
    v->Visit("attrs", &attrs);
    v->Visit("axis", &axis);
    v->Visit("reduce_axis", &reduce_axis);
    v->Visit("body", &body);
  }
  /*!
   * \brief Create a ComputeOp.
   * \param name The name of the operation.
   * \param tag The tag of the operation.
   * \param attrs Additional attributes of the operation.
   * \param axis The IterVar on each output axis.
   * \param body The compute expressions.
   * \return The created operation.
   */
  static Operation make(std::string name,
                        std::string tag,
                        Map<std::string, NodeRef> attrs,
                        Array<IterVar> axis,
                        Array<Expr> body);

  static constexpr const char* _type_key = "ComputeOp";
  TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode, BaseComputeOpNode);
};

276
/*!
277 278
 * \brief A TenorCompute op that compute a tensor with an tensor intrinsic.
 */
279
class TensorComputeOpNode : public BaseComputeOpNode {
280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305
 public:
  /*! \brief number of axes that can be scheduled */
  int schedulable_ndim;
  /*! \brief TensorIntrin used to compute */
  TensorIntrin intrin;
  /*! \brief input tensors of intrin */
  Array<Tensor> inputs;
  /*! \brief region of input tensors */
  Array<Region> input_regions;
  /*! \brief constructor */
  TensorComputeOpNode() {}
  // override functions
  int num_outputs() const final;
  Type output_dtype(size_t i) const final;
  Array<Tensor> InputTensors() const final;
  Operation ReplaceInputs(
      const Operation& self,
      const std::unordered_map<Tensor, Tensor>& rmap) const final;
  void PropBoundToInputs(
      const Operation& self,
      const std::unordered_map<const Variable*, IntSet>& dom_map,
      std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
  Stmt BuildProvide(
      const Stage& stage,
      const std::unordered_map<IterVar, Range>& dom_map,
      bool debug_keep_trivial_loop) const final;
306
  size_t num_schedulable_dims() const final;
307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("name", &name);
    v->Visit("tag", &tag);
    v->Visit("axis", &axis);
    v->Visit("reduce_axis", &reduce_axis);
    v->Visit("schedulable_ndim", &schedulable_ndim);
    v->Visit("intrin", &intrin);
    v->Visit("inputs", &inputs);
    v->Visit("input_regions", &input_regions);
  }
  static Operation make(std::string name,
                        std::string tag,
                        Array<IterVar> axis,
                        Array<IterVar> reduce_axis,
                        int schedulable_ndim,
                        TensorIntrin intrin,
                        Array<Tensor> tensors,
                        Array<Region> regions);

  static constexpr const char* _type_key = "TensorComputeOp";
328
  TVM_DECLARE_NODE_TYPE_INFO(TensorComputeOpNode, BaseComputeOpNode);
329 330 331
};

/*!
332 333 334 335 336 337 338 339 340 341 342 343 344
 * \brief Symbolic scan.
 */
class ScanOpNode : public OperationNode {
 public:
  /*! \brief IterVar to scan over */
  IterVar scan_axis;
  /*! \brief the initialization tensors */
  Array<Tensor> init;
  /*! \brief the update function represented by tensor */
  Array<Tensor> update;
  /*! \brief The placeholder to refer as states in update. */
  Array<Tensor> state_placeholder;
  /*!
345 346 347 348 349
   * \brief the inputs to the scan, these are optionally provided
   *  But they can be helpful to provide hints to speedup get of scan body.
   */
  Array<Tensor> inputs;
  /*!
350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365
   * \brief Spatial axis to indicate spatial dimension of each output.
   *  They corresponds to flattened spatial axis of the outputs.
   *
   *  [output[0].axis[1], output[0].axis[2]... output[k].axis[j]...]
   *  These are auxiliary data structure for storing result of bound inference.
   *  They do not corresponds to splittable iterations, thus the name comes
   *  with underscore.
   */
  Array<IterVar> spatial_axis_;
  /*! \brief constructor */
  ScanOpNode() {}
  // override behavior.
  int num_outputs() const final;
  Array<IterVar> root_iter_vars() const final;
  Type output_dtype(size_t i) const final;
  Array<Expr> output_shape(size_t i) const final;
366 367 368 369 370 371 372 373 374 375 376 377 378
  Array<Tensor> InputTensors() const final;
  Operation ReplaceInputs(
      const Operation& self,
      const std::unordered_map<Tensor, Tensor>& rmap) const final;
  void PropBoundToInputs(
      const Operation& self,
      const std::unordered_map<const Variable*, IntSet>& dom_map,
      std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
  void GatherBound(
      const Operation& self,
      const std::unordered_map<Tensor, TensorDom>& tensor_dom,
      std::unordered_map<IterVar, Range>* out_dom_map) const final;
  Stmt BuildRealize(
379
      const Stage& stage,
380 381 382 383
      const std::unordered_map<IterVar, Range>& realize_map,
      const Stmt& body) const final;
  Stmt BuildProvide(
      const Stage& stage,
384
      const std::unordered_map<IterVar, Range>& dom_map,
385
      bool debug_keep_trivial_loop) const final;
386 387 388

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("name", &name);
389
    v->Visit("tag", &tag);
390
    v->Visit("attrs", &attrs);
391 392 393 394
    v->Visit("scan_axis", &scan_axis);
    v->Visit("init", &init);
    v->Visit("update", &update);
    v->Visit("state_placeholder", &state_placeholder);
395
    v->Visit("inputs", &inputs);
396 397 398
    v->Visit("spatial_axis_", &spatial_axis_);
  }
  static Operation make(std::string name,
399
                        std::string tag,
400
                        Map<std::string, NodeRef> attrs,
401 402 403
                        IterVar axis,
                        Array<Tensor> init,
                        Array<Tensor> update,
404 405
                        Array<Tensor> state_placeholder,
                        Array<Tensor> input);
406 407

  static constexpr const char* _type_key = "ScanOp";
408
  TVM_DECLARE_NODE_TYPE_INFO(ScanOpNode, OperationNode);
409 410
};

411 412 413 414 415 416 417
/*!
 * \brief External computation that cannot be splitted.
 */
class ExternOpNode : public OperationNode {
 public:
  /*! \brief The input tensors */
  Array<Tensor> inputs;
418
  /*! \brief Symbolic placeholder representation of inputs */
419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444
  Array<Buffer> input_placeholders;
  /*! \brief Symbolic placeholder representation of outputs */
  Array<Buffer> output_placeholders;
  /*! \brief the statement that generates the computation. */
  Stmt body;

  /*! \brief constructor */
  ExternOpNode() {}
  // override functions
  int num_outputs() const final;
  Array<IterVar> root_iter_vars() const final;
  Type output_dtype(size_t i) const final;
  Array<Expr> output_shape(size_t i) const final;
  Array<Tensor> InputTensors() const final;
  Operation ReplaceInputs(
      const Operation& self,
      const std::unordered_map<Tensor, Tensor>& rmap) const final;
  void PropBoundToInputs(
      const Operation& self,
      const std::unordered_map<const Variable*, IntSet>& dom_map,
      std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
  void GatherBound(
      const Operation& self,
      const std::unordered_map<Tensor, TensorDom>& tensor_dom,
      std::unordered_map<IterVar, Range>* out_dom_map) const final;
  Stmt BuildRealize(
445
      const Stage& stage,
446 447 448 449
      const std::unordered_map<IterVar, Range>& realize_map,
      const Stmt& body) const final;
  Stmt BuildProvide(
      const Stage& stage,
450
      const std::unordered_map<IterVar, Range>& dom_map,
451
      bool debug_keep_trivial_loop) const final;
452 453 454

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("name", &name);
455
    v->Visit("tag", &tag);
456
    v->Visit("attrs", &attrs);
457
    v->Visit("inputs", &inputs);
458 459
    v->Visit("input_placeholders", &input_placeholders);
    v->Visit("output_placeholders", &output_placeholders);
460 461
    v->Visit("body", &body);
  }
462
  EXPORT static Operation make(std::string name,
463 464 465 466 467 468
                               std::string tag,
                               Map<std::string, NodeRef> attrs,
                               Array<Tensor> inputs,
                               Array<Buffer> input_placeholders,
                               Array<Buffer> output_placeholders,
                               Stmt body);
469 470 471 472

  static constexpr const char* _type_key = "ExternOp";
  TVM_DECLARE_NODE_TYPE_INFO(ExternOpNode, OperationNode);
};
tqchen committed
473

474 475 476 477 478 479 480 481 482
/*!
 * \brief A computation operator that generated by hybrid script.
 */
class HybridOpNode : public OperationNode {
 public:
  /*! \brief The input tensors */
  Array<Tensor> inputs;
  /*! \brief Symbolic placeholder representation of outputs */
  Array<Tensor> outputs;
483 484
  /*! \brief The axis of iterations */
  Array<IterVar> axis;
485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525
  /*! \brief the statement that generates the computation. This is
   * slightly different from the body in ExternOpNode. All the output
   * tensors keep its own name specified by users in the script.
   * However, when compilation, these tensors will be placed by those
   * actual output tensors. */
  Stmt body;

  /*! \brief constructor */
  HybridOpNode() {}
  // override functions
  int num_outputs() const final;
  Array<IterVar> root_iter_vars() const final;
  Type output_dtype(size_t i) const final;
  Array<Expr> output_shape(size_t i) const final;
  Array<Tensor> InputTensors() const final;
  Operation ReplaceInputs(
      const Operation& self,
      const std::unordered_map<Tensor, Tensor>& rmap) const final;
  void PropBoundToInputs(
      const Operation& self,
      const std::unordered_map<const Variable*, IntSet>& dom_map,
      std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
  void GatherBound(
      const Operation& self,
      const std::unordered_map<Tensor, TensorDom>& tensor_dom,
      std::unordered_map<IterVar, Range>* out_dom_map) const final;
  Stmt BuildRealize(
      const Stage& stage,
      const std::unordered_map<IterVar, Range>& realize_map,
      const Stmt& body) const final;
  Stmt BuildProvide(
      const Stage& stage,
      const std::unordered_map<IterVar, Range>& dom_map,
      bool debug_keep_trivial_loop) const final;

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("name", &name);
    v->Visit("tag", &tag);
    v->Visit("attrs", &attrs);
    v->Visit("inputs", &inputs);
    v->Visit("outputs", &outputs);
526
    v->Visit("axis", &axis);
527 528 529 530 531 532 533 534 535 536 537 538 539
    v->Visit("body", &body);
  }
  EXPORT static Operation make(std::string name,
                               std::string tag,
                               Map<std::string, NodeRef> attrs,
                               Array<Tensor> inputs,
                               Array<Tensor> outputs,
                               Stmt body);

  static constexpr const char* _type_key = "HybridOp";
  TVM_DECLARE_NODE_TYPE_INFO(HybridOpNode, OperationNode);
};

tqchen committed
540 541 542
/*! \brief The compute function to specify the input source of a Tensor */
using FCompute = std::function<Expr (const Array<Var>& i)>;

543 544 545
/*! \brief The compute function to specify the inputs source of Tensors */
using FBatchCompute = std::function<Array<Expr> (const Array<Var>& i)>;

tqchen committed
546
/*!
547 548 549 550 551
 * \brief create a place holder tensor.
 * \param shape The shape of the tensor.
 * \param dtype the data type of the tensor.
 * \param name The name of the Tensor.
 */
552 553 554
TVM_DLL Tensor placeholder(Array<Expr> shape,
                           Type dtype = Float(32),
                           std::string name = "placeholder");
555 556

/*!
tqchen committed
557 558 559 560 561
 * \brief Construct a new tensor by computing over shape,
 *  using the computation rule: result_tensor[axis] = fcompute(axis)
 * \param shape Shape of the tensor.
 * \param fcompute The compute function to create the tensor.
 * \param name The optional name of the tensor.
562
 * \param tag The optional tag of the tensor.
563
 * \param attrs Optional additional attributes of the compute.
tqchen committed
564
 */
565 566 567
TVM_DLL Tensor compute(Array<Expr> shape,
                       FCompute fcompute,
                       std::string name = "tensor",
568 569
                       std::string tag = "",
                       Map<std::string, NodeRef> attrs = {});
tqchen committed
570

571
/*!
572 573 574 575 576
 * \brief Construct a new tensor by computing over shape,
 *  using the computation rule: result_tensor[axis] = fcompute(axis)
 * \param shape Shape of the tensor.
 * \param fcompute The compute function to create the tensors.
 * \param name The optional name of the tensor.
577
 * \param tag The optional tag of the tensor.
578
 * \param attrs Optional additional attributes of the compute.
579
 */
580 581 582
TVM_DLL Array<Tensor> compute(Array<Expr> shape,
                              FBatchCompute fcompute,
                              std::string name = "tensor",
583 584
                              std::string tag = "",
                              Map<std::string, NodeRef> attrs = {});
585 586

/*!
587
 * \brief Construct new tensors by scan.
588 589 590 591
 *
 * \param init The intialize tensor of first K steps.
 * \param update The update tensor indicated the updated result after each timestamp.
 * \param state_placeholder The placeholder for the states.
592 593
 * \param inputs The inputs to the scan body, this is optional,
 *    but recommended to provide concrete information about scan body.
594
 * \param name The optional name of the tensor.
595
 * \param tag The optional tag of the tensor.
596
 * \param attrs Optional additional attributes of the compute.
597
 */
598 599 600 601 602
TVM_DLL Array<Tensor> scan(Array<Tensor> init,
                           Array<Tensor> update,
                           Array<Tensor> state_placeholder,
                           Array<Tensor> inputs = Array<Tensor>(),
                           std::string name = "scan",
603 604
                           std::string tag = "",
                           Map<std::string, NodeRef> attrs = {});
605

tqchen committed
606
// same as compute, specialized for different fcompute function
607
inline Tensor compute(Array<Expr> shape,
tqchen committed
608
                      std::function<Expr(Var)> f,
609
                      std::string name = "tensor",
610 611
                      std::string tag = "",
                      Map<std::string, NodeRef> attrs = {}) {
tqchen committed
612
  FCompute fc = [f] (const Array<Var>& i) { return f(i[0]); };
613
  return compute(shape, fc, name, tag, attrs);
tqchen committed
614
}
615
inline Tensor compute(Array<Expr> shape,
tqchen committed
616
                      std::function<Expr(Var, Var)> f,
617
                      std::string name = "tensor",
618 619
                      std::string tag = "",
                      Map<std::string, NodeRef> attrs = {}) {
tqchen committed
620
  FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1]); };
621
  return compute(shape, fc, name, tag, attrs);
tqchen committed
622
}
623
inline Tensor compute(Array<Expr> shape,
tqchen committed
624
                      std::function<Expr(Var, Var, Var)> f,
625
                      std::string name = "tensor",
626 627
                      std::string tag = "",
                      Map<std::string, NodeRef> attrs = {}) {
tqchen committed
628
  FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2]); };
629
  return  compute(shape, fc, name, tag, attrs);
tqchen committed
630
}
631
inline Tensor compute(Array<Expr> shape,
tqchen committed
632
                      std::function<Expr(Var, Var, Var, Var)> f,
633
                      std::string name = "tensor",
634 635
                      std::string tag = "",
                      Map<std::string, NodeRef> attrs = {}) {
tqchen committed
636
  FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2], i[3]); };
637
  return compute(shape, fc, name, tag, attrs);
tqchen committed
638 639
}

640 641 642 643
// inline function.
inline const OperationNode* Operation::operator->() const {
  return static_cast<const OperationNode*>(node_.get());
}
tqchen committed
644 645
}  // namespace tvm
#endif  // TVM_OPERATION_H_