/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \brief Tensor Compute Op.
 * \file tensor_compute_op.cc
 */
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>

#include <unordered_set>

#include "./op_util.h"
#include "./compute_op.h"
#include "../../arith/compute_expr.h"

namespace tvm {
namespace te {

using namespace tir;

// TensorComputeOpNode
// Pretty-printer hook: render a TensorComputeOpNode as
// "tensor_compute_op(<name>, <node address>)".
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TensorComputeOpNode>([](const ObjectRef& node, ReprPrinter* p) {
    auto* op = static_cast<const TensorComputeOpNode*>(node.get());
    p->stream << "tensor_compute_op(" << op->name << ", " << op << ")";
  });

// Register the node type with the object system (enables reflection/serialization).
TVM_REGISTER_NODE_TYPE(TensorComputeOpNode);

// Number of outputs: every intrinsic buffer that is not bound to an input
// tensor corresponds to one output of this operation.
int TensorComputeOpNode::num_outputs() const {
  size_t num_input_buffers = this->inputs.size();
  return static_cast<int>(this->intrin->buffers.size() - num_input_buffers);
}

50
DataType TensorComputeOpNode::output_dtype(size_t i) const {
51 52 53 54 55 56 57 58 59 60
  return this->intrin->buffers[this->inputs.size() + i]->dtype;
}

// Construct a TensorComputeOp operation.
//
// \param name/tag        operator identification strings.
// \param axis            all iteration axes; the first `schedulable_ndim`
//                        of them are schedulable.
// \param reduce_axis     reduction iteration axes.
// \param schedulable_ndim number of leading schedulable dimensions.
// \param intrin          tensor intrinsic implementing the computation.
// \param tensors         input tensors consumed by the intrinsic.
// \param regions         region of each input tensor that is accessed.
// \param scalar_inputs   expressions bound to the intrinsic's scalar params.
// \return the wrapped Operation.
Operation TensorComputeOpNode::make(std::string name,
                                    std::string tag,
                                    Array<IterVar> axis,
                                    Array<IterVar> reduce_axis,
                                    int schedulable_ndim,
                                    TensorIntrin intrin,
                                    Array<Tensor> tensors,
                                    Array<Region> regions,
                                    Array<PrimExpr> scalar_inputs) {
  auto n = make_object<TensorComputeOpNode>();
  n->name = std::move(name);
  n->tag = std::move(tag);
  n->axis = std::move(axis);
  n->reduce_axis = std::move(reduce_axis);
  // Plain int: std::move would be a no-op (performance-move-const-arg),
  // so assign directly.
  n->schedulable_ndim = schedulable_ndim;
  n->intrin = std::move(intrin);
  n->inputs = std::move(tensors);
  n->input_regions = std::move(regions);
  n->scalar_inputs = std::move(scalar_inputs);
  return Operation(n);
}

76 77 78 79
TVM_REGISTER_GLOBAL("te.TensorComputeOp")
.set_body_typed(TensorComputeOpNode::make);


80 81 82 83 84 85 86 87
Array<Tensor> TensorComputeOpNode::InputTensors() const {
  return inputs;
}

// Substitute input tensors of this op according to `rmap`.
//
// Both the direct `inputs` array and any tensors referenced inside the
// intrinsic's body / reduce_init / reduce_update statements are remapped.
// Returns `self` unchanged when nothing was actually replaced; otherwise
// returns a fresh Operation wrapping copies of the node and its intrinsic.
Operation TensorComputeOpNode::ReplaceInputs(
    const Operation& self,
    const std::unordered_map<Tensor, Tensor>& rmap) const {
  CHECK_EQ(self.operator->(), this);
  // Work on copies so the original node/intrinsic are never mutated.
  auto n = make_object<TensorComputeOpNode>(*this);
  auto intrin = make_object<TensorIntrinNode>(*(this->intrin.operator->()));
  intrin->body = ReplaceTensor(this->intrin->body, rmap);
  if (intrin->reduce_init.defined()) {
    intrin->reduce_init = ReplaceTensor(this->intrin->reduce_init, rmap);
  }
  if (intrin->reduce_update.defined()) {
    intrin->reduce_update = ReplaceTensor(this->intrin->reduce_update, rmap);
  }
  for (size_t i = 0; i < n->inputs.size(); ++i) {
    Tensor t = n->inputs[i];
    if (rmap.count(t)) {
      n->inputs.Set(i, rmap.at(t));
    }
  }

  // If every piece is still the same object, no replacement happened.
  if (intrin->body.same_as(n->intrin->body) &&
      intrin->reduce_init.same_as(n->intrin->reduce_init) &&
      intrin->reduce_update.same_as(n->intrin->reduce_update) &&
      inputs.same_as(n->inputs)) {
    return self;
  } else {
    n->intrin = TensorIntrin(intrin);
    return Operation(n);
  }
}

void TensorComputeOpNode::PropBoundToInputs(
    const Operation& self,
117
    arith::Analyzer* analyzer,
118
    const std::unordered_map<const VarNode*, IntSet>& dom_map,
119 120 121 122 123 124 125 126 127 128 129 130 131 132
    std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
  for (size_t i = 0; i < this->inputs.size(); ++i) {
    Tensor t = this->inputs[i];
    Region region = input_regions[i];

    auto it = out_dom_map->find(t);
    if (it == out_dom_map->end()) continue;
    TensorDom& dom = it->second;
    for (size_t j = 0; j < t.ndim(); ++j) {
      dom.data[j].emplace_back(EvalSet(region[j], dom_map));
    }
  }
}

133 134
size_t TensorComputeOpNode::num_schedulable_dims() const {
  return schedulable_ndim;
135 136 137 138 139 140 141 142 143
}

// Build the provide statement for this stage: bind input/output buffers to
// the tensor intrinsic's buffer placeholders, bind scalar parameters, then
// wrap the intrinsic body (or init/update pair for reductions) in the loop
// nest computed for the stage.
Stmt TensorComputeOpNode::BuildProvide(
    const Stage& stage,
    const std::unordered_map<IterVar, Range>& dom_map,
    bool debug_keep_trivial_loop) const {
  CHECK_EQ(stage->op.operator->(), this);

  // Start bind data.
  Stmt nop = EvaluateNode::make(0);
  std::vector<Stmt> input_bind_nest, output_bind_nest;
  Array<Tensor> inputs = this->InputTensors();

  // input binding: each input region becomes a buffer_bind_scope attribute
  // carrying a tvm_tuple of (min, extent) pairs per dimension.
  size_t num_inputs = inputs.size();
  for (size_t i = 0; i < num_inputs; ++i) {
    Tensor tensor = inputs[i];
    Region region = this->input_regions[i];
    Buffer buffer = this->intrin->buffers[i];
    Array<ObjectRef> bind_spec{buffer, tensor};

    Array<PrimExpr> tuple;
    for (size_t j = 0; j < region.size(); ++j) {  // j: avoid shadowing i
      tuple.push_back(region[j]->min);
      tuple.push_back(region[j]->extent);
    }
    input_bind_nest.emplace_back(AttrStmtNode::make(
        bind_spec, tir::attr::buffer_bind_scope,
        CallNode::make(DataType::Handle(),
                       tir::intrinsic::tvm_tuple,
                       tuple, CallNode::Intrinsic), nop));
  }

  // output binding: intrinsic output buffers follow the input buffers.
  for (int i = 0; i < this->num_outputs(); ++i) {
    Tensor tensor = stage->op.output(i);
    Buffer buffer = this->intrin->buffers[num_inputs + i];
    Array<ObjectRef> bind_spec{buffer, tensor};

    Array<PrimExpr> tuple;
    for (size_t j = 0; j < this->axis.size(); ++j) {  // j: avoid shadowing i
      auto ivar = this->axis[j];
      if (j < static_cast<size_t>(this->schedulable_ndim)) {
        // Schedulable dims are pinned to the loop var with extent 1.
        tuple.push_back(ivar->var);
        tuple.push_back(1);
      } else {
        // Remaining dims cover their full domain inside the intrinsic.
        Range dom = ivar->dom;
        tuple.push_back(dom->min);
        tuple.push_back(dom->extent);
      }
    }

    output_bind_nest.emplace_back(AttrStmtNode::make(
        bind_spec, tir::attr::buffer_bind_scope,
        CallNode::make(DataType::Handle(),
                       tir::intrinsic::tvm_tuple,
                       tuple, CallNode::Intrinsic), nop));
  }

  // Check variable remap
  std::unordered_map<const VarNode*, PrimExpr> vmap;
  tir::ArgBinder binder(&vmap);

  // Map the expressions passed in the call to the TensorIntrin, to the placeholder
  // variables
  Array<PrimExpr> user_expr = this->scalar_inputs;
  Array<Var> scalar_params = this->intrin->scalar_params;
  Array<PrimExpr> sp_expr;
  for (auto sp : scalar_params) {
    PrimExpr esp = sp;
    sp_expr.push_back(esp);
  }
  CHECK_EQ(sp_expr.size(), user_expr.size());
  // TODO(jdavies-huawei): what name should be used here?
  binder.BindArray(sp_expr, user_expr, this->name);

  size_t tloc = stage->leaf_iter_vars.size();
  ComputeLoopNest n = ComputeLoopNest::make(this, stage, dom_map, debug_keep_trivial_loop);

  if (this->reduce_axis.size() == 0) {
    // No reduction: one nest around the intrinsic body.
    std::vector<std::vector<Stmt> > nest(
        n.main_nest.begin(), n.main_nest.begin() + tloc + 1);
    nest.emplace_back(MakeIfNest(n.main_predicates));
    CHECK_EQ(n.init_predicates.size(), 0U);
    CHECK(this->intrin->body.defined())
        << "Normal store op for intrin " << this << " is not defined";
    Stmt body = MergeNest(output_bind_nest, this->intrin->body);
    body = MergeNest(input_bind_nest, body);
    body = tir::Substitute(body, vmap);
    body = MergeNest(binder.asserts(), body);
    body = te::Substitute(body, n.main_vmap);
    Stmt ret = MergeNest(nest, body);
    return ret;
  } else {
    // Need to split reduction
    CHECK(this->intrin->reduce_update.defined())
        << "Reduction update op is not defined";
    // Need init and update steps
    CHECK_NE(this->reduce_axis.size(), 0U);
    std::vector<std::vector<Stmt> > common(
        n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1);
    std::vector<std::vector<Stmt> > update_nest(
        n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.begin() + tloc + 1);
    update_nest.emplace_back(MakeIfNest(n.main_predicates));

    if (this->intrin->reduce_init.defined()) {
      // init nest
      std::vector<std::vector<Stmt> > init_nest(
          n.init_nest.begin(), n.init_nest.begin() + tloc + 1);
      init_nest.emplace_back(MakeIfNest(n.init_predicates));
      Stmt init = MergeNest(output_bind_nest, this->intrin->reduce_init);
      init = te::Substitute(init, n.init_vmap);
      init = MergeNest(init_nest, init);
      // The update
      Stmt update = MergeNest(output_bind_nest, this->intrin->reduce_update);
      update = MergeNest(input_bind_nest, update);
      update = tir::Substitute(update, vmap);
      update = MergeNest(binder.asserts(), update);
      update = te::Substitute(update, n.main_vmap);
      update = MergeNest(update_nest, update);
      return MergeNest(common, SeqStmt::Flatten(init, update));
    } else {
      // When init op is not available, use body op for reset in the first iter.
      CHECK(this->intrin->body.defined())
          << "Normal body op is not defined";
      Stmt update = TransformUpdate(stage, dom_map, n,
                                    this->intrin->body,
                                    this->intrin->reduce_update);
      update = MergeNest(output_bind_nest, update);
      update = MergeNest(input_bind_nest, update);
      update = tir::Substitute(update, vmap);
      update = MergeNest(binder.asserts(), update);
      update = te::Substitute(update, n.main_vmap);
      update = MergeNest(update_nest, update);
      return MergeNest(common, update);
    }
  }
}
}  // namespace te
}  // namespace tvm