tensor.cc 2.43 KB
Newer Older
1 2 3 4 5
/*!
 *  Copyright (c) 2016 by Contributors
 * \file tensor.cc
 */
#include <tvm/tensor.h>
6
#include <tvm/operation.h>
7
#include <tvm/tensor_intrin.h>
8 9 10 11 12
#include <ir/IR.h>
#include <memory>

namespace tvm {

13 14 15 16 17
Expr Tensor::operator()(Array<Var> indices) const {
  Array<Expr> arr(indices.begin(), indices.end());
  return operator()(arr);
}

18
Expr Tensor::operator()(Array<Expr> indices) const {
19
  using HalideIR::Internal::Call;
20 21 22
  CHECK_EQ(ndim(), indices.size())
      << "Tensor dimension mismatch in read"
      << "ndim = " << ndim() << ", indices.size=" << indices.size();
tqchen committed
23
  auto n = Call::make(
24 25
      (*this)->dtype, (*this)->op->name, indices, Call::Halide,
      (*this)->op, (*this)->value_index);
tqchen committed
26
  return n;
27 28
}

tqchen committed
29 30
Tensor TensorNode::make(Array<Expr> shape,
                        Type dtype,
tqchen committed
31 32
                        Operation op,
                        int value_index) {
tqchen committed
33
  auto n = std::make_shared<TensorNode>();
34
  n->shape = std::move(shape);
tqchen committed
35
  n->dtype = dtype;
tqchen committed
36 37
  n->op = op;
  n->value_index = value_index;
tqchen committed
38 39 40
  return Tensor(n);
}

tqchen committed
41 42 43
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TensorNode>([](const TensorNode *t, IRPrinter *p) {
    p->stream << "Tensor(shape=" << t->shape
44
              << ", op.name=" << t->op->name << ')';
tqchen committed
45 46
  });

47 48
TVM_REGISTER_NODE_TYPE(TensorNode);

49 50 51 52 53 54 55 56 57
Tensor Operation::output(size_t i) const {
  auto node = std::make_shared<TensorNode>();
  node->op = *this;
  node->value_index = i;
  node->dtype = (*this)->output_dtype(i);
  node->shape = (*this)->output_shape(i);
  return Tensor(node);
}

58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81
TensorIntrin TensorIntrinNode::make(std::string name,
                                    Operation op,
                                    Array<Tensor> inputs,
                                    Array<Buffer> buffers,
                                    Stmt body,
                                    Stmt reduce_init,
                                    Stmt reduce_update) {
  auto n = std::make_shared<TensorIntrinNode>();
  n->name = std::move(name);
  n->op = std::move(op);
  n->inputs = std::move(inputs);
  n->buffers = std::move(buffers);
  n->body = std::move(body);
  n->reduce_init = std::move(reduce_init);
  n->reduce_update = std::move(reduce_update);
  return TensorIntrin(n);
}

// Printer for TensorIntrinNode: writes the intrinsic's name followed by the
// raw node pointer value, as "TensorIntrin(name=<name>, <pointer>)".
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TensorIntrinNode>([](const TensorIntrinNode *intrin,
                                   IRPrinter *printer) {
    auto &os = printer->stream;
    os << "TensorIntrin(name=" << intrin->name << ", ";
    os << intrin << ")";
  });

// Register TensorIntrinNode through the TVM_REGISTER_NODE_TYPE macro.
TVM_REGISTER_NODE_TYPE(TensorIntrinNode);
}  // namespace tvm