/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \brief Logic related to tensorize, used by ComputeOpNode.
 * \file tensorize.cc
 */
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/runtime/registry.h>

#include "op_util.h"
#include "compute_op.h"
#include "../schedule/message_passing.h"

namespace tvm {
namespace te {

using namespace tir;
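
// Tensorize replaces the loops at and below a tensorized leaf iter var with
// the body of a TensorIntrin.  The flow implemented in this file:
//   1. InferTensorizeRegion deduces the output domain and the input regions
//      covered by the tensorized scope.
//   2. VerifyTensorizeLoopNest and VerifyTensorizeBody check that the loop
//      nest and the compute body structurally match the intrinsic.
//   3. MakeTensorize binds the intrinsic buffers to tensor sub-regions and
//      splices in the intrinsic's body / reduce_init / reduce_update.
// From Python this is driven by the schedule primitive, roughly
// (a sketch; `gemm_intrin` is a hypothetical TensorIntrin):
//   s[C].tensorize(yi, gemm_intrin)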

// Detect the regions of input and output tensors that are to be tensorized.
// out_dom: the domain of root iter vars in the output op.
// in_region: the region of each input tensor.
// Returns the index of the leaf iter var at which the tensorized scope starts.
size_t InferTensorizeRegion(
    const ComputeOpNode* self,
    const Stage& stage,
    const std::unordered_map<IterVar, Range>& dom_map,
    std::unordered_map<IterVar, Range>* out_dom,
    std::unordered_map<Tensor, Array<Range> >* in_region) {
  // Get the bound of the tensorized scope.
  bool found_point = false;
  size_t loc_scope = 0;
  std::unordered_map<IterVar, IntSet> up_state;
  // Walk the leaf iter vars from innermost to outermost.
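  // up_state records the integer set each leaf var spans at this point:
  // extent-1 vars collapse to the single point `min`; vars outside the
  // tensorized scope (visited after the tensorized point is found, since we
  // walk inner to outer) stay symbolic as `iv->var`; vars inside the scope
  // keep their full range, which the intrinsic will cover.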
  for (size_t i = stage->leaf_iter_vars.size(); i != 0; --i) {
    IterVar iv = stage->leaf_iter_vars[i - 1];
    CHECK(iv->iter_type == kDataPar ||
          iv->iter_type == kCommReduce);
    auto vit = dom_map.find(iv);
    CHECK(vit != dom_map.end());
    const Range& vrange = vit->second;
    if (is_one(vrange->extent)) {
      up_state[iv] = IntSet::single_point(vrange->min);
    } else if (found_point) {
      CHECK(is_zero(vrange->min));
      up_state[iv] = IntSet::single_point(iv->var);
    } else {
      up_state[iv] = IntSet::range(vrange);
    }
    auto iit = stage->iter_var_attrs.find(iv);
    if (iit != stage->iter_var_attrs.end()) {
      const IterVarAttr& attr = (*iit).second;
      if (!found_point) {
        CHECK(!attr->bind_thread.defined())
            << "Do not allow thread in tensorize scope";
      }
      if (attr->iter_type == kTensorized) {
        CHECK(!found_point) << "Do not allow two tensorized points";
        found_point = true;
        loc_scope = i - 1;
      }
    }
  }
  CHECK(found_point);
  // Get domain of the tensorized scope.
  te::PassUpDomain(stage, dom_map, &up_state);
  // Get the domains of the inputs.
  std::unordered_map<Tensor, TensorDom> in_dom;
  std::unordered_map<const VarNode*, IntSet> temp_dmap;
  arith::Analyzer analyzer;
  Array<Tensor> inputs = self->InputTensors();
  for (Tensor t : inputs) {
    in_dom.emplace(t, TensorDom(t.ndim()));
  }
  for (IterVar iv : self->root_iter_vars()) {
    IntSet iset = up_state.at(iv);
    Range iv_range = iset.cover_range(dom_map.at(iv));
    (*out_dom)[iv] = iv_range;
    analyzer.Bind(iv->var, iv_range);
    temp_dmap[iv->var.get()] = iset;
  }
  // Input domains
  self->PropBoundToInputs(stage->op, &analyzer, temp_dmap, &in_dom);
  Range none;  // Deliberately undefined; used as the fallback for cover_range below.
  for (const auto& kv : in_dom) {
    Array<Range> vec;
    const Tensor& t = kv.first;
    for (size_t i = 0; i < t.ndim(); ++i) {
      Range r = arith::Union(kv.second.data.at(i)).cover_range(none);
      CHECK(r.defined()) << "cannot deduce region of tensorized scope for input " << t;
      vec.push_back(std::move(r));
    }
    (*in_region)[t] = std::move(vec);
  }
  return loc_scope;
}

void VerifyTensorizeLoopNest(const ComputeOpNode* self,
                             const Stage& stage,
                             const ComputeLoopNest& n,
                             size_t tloc) {
  // Verification step.
  std::unordered_set<const VarNode*> banned;
  CHECK_EQ(n.main_nest.size(), stage->leaf_iter_vars.size() + 1);
  CHECK(n.init_nest.size() == stage->leaf_iter_vars.size() + 1 ||
        n.init_nest.size() == 0);
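  // main_nest/init_nest have one more level than leaf_iter_vars: level 0
  // holds statements emitted before any loop, and level i + 1 belongs to
  // leaf iter var i.  Collect every var defined at or below the tensorized
  // scope; the loop predicates must not reference them, because those loops
  // are replaced by the intrinsic.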
  auto f_push_banned = [&banned](const Stmt& s) {
    if (const ForNode* op = s.as<ForNode>()) {
      banned.insert(op->loop_var.get());
    } else if (const AttrStmtNode* op = s.as<AttrStmtNode>()) {
      if (const IterVarNode* iv = op->node.as<IterVarNode>()) {
        banned.insert(iv->var.get());
      }
    } else if (const LetStmtNode* op = s.as<LetStmtNode>()) {
      banned.insert(op->var.get());
    }
  };
  for (size_t i = tloc; i < stage->leaf_iter_vars.size(); ++i) {
    for (const Stmt& s : n.main_nest[i + 1]) {
      f_push_banned(s);
    }
    if (n.init_nest.size() != 0) {
      for (const Stmt& s : n.init_nest[i + 1]) {
        f_push_banned(s);
      }
    }
  }
  for (const PrimExpr& pred : n.main_predicates) {
    if (tir::ExprUseVar(pred, banned)) {
      LOG(FATAL) << "Tensorize failed, split condition "
                 << pred << " relies on var defined inside tensorize scope";
    }
  }
  for (const PrimExpr& pred : n.init_predicates) {
    if (tir::ExprUseVar(pred, banned)) {
      LOG(FATAL) << "Tensorize failed, split condition "
                 << pred << " relies on var defined inside tensorize scope";
    }
  }
}

// Remap tensor placeholders and access indices from the stage's coordinate
// space into the tensor intrinsic's coordinate space.
class TensorIntrinMatcher final : public StmtExprMutator {
 public:
  PrimExpr VisitExpr_(const CallNode* op) final {
    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
    op = expr.as<CallNode>();
    if (op->call_type == CallNode::Halide) {
      Tensor t = Downcast<Operation>(op->func).output(op->value_index);
      auto it = in_remap_.find(t);
      if (it != in_remap_.end()) {
        const InputEntry& e = it->second;
        CHECK_EQ(op->args.size(), e.region.size());
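        // Skip the leading fuzzy-matched dimensions (the first e.start of
        // them) and shift each remaining index so the region minimum maps
        // to zero in the intrinsic tensor's coordinate space.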
        Array<PrimExpr> args;
        for (size_t i = e.start; i < e.region.size(); ++i) {
          args.push_back(op->args[i] - e.region[i]->min);
        }
        return CallNode::make(
            op->dtype, e.tensor->op->name, args,
            op->call_type, e.tensor->op, e.tensor->value_index);
      }
    }
    return expr;
  }

  PrimExpr VisitExpr_(const VarNode* op) final {
    auto it = var_remap_.find(op);
    if (it != var_remap_.end()) {
      return it->second;
    } else {
      return GetRef<PrimExpr>(op);
    }
  }

  PrimExpr VisitExpr_(const ReduceNode* op) final {
    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
    op = expr.as<ReduceNode>();
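    // Keep only the reduction axes that map onto an intrinsic axis; axes
    // dropped here had extent 1 and their vars were already substituted
    // with the axis minimum through var_remap_.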
    Array<IterVar> axis;
    for (size_t i = 0; i < op->axis.size(); ++i) {
      auto it = axis_remap_.find(op->axis[i]);
      if (it != axis_remap_.end()) {
        axis.push_back(it->second);
      }
    }
    return ReduceNode::make(
        op->combiner, op->source, axis, op->condition, op->value_index);
  }

  void Init(const ComputeOpNode* self,
            const Stage& stage,
            const std::unordered_map<IterVar, Range>& dom_map,
            const std::unordered_map<IterVar, Range>& out_dom,
            const std::unordered_map<Tensor, Array<Range> >& in_region,
            const TensorIntrin& intrin,
            Map<Var, Range>* compute_intrin_iter_space) {
    CHECK(self == stage->op.get());

    for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
      IterVar iv = stage->leaf_iter_vars[i];
      auto vit = dom_map.find(iv);
      if (vit != dom_map.end()) {
        const Range vrange = vit->second;
        compute_intrin_iter_space->Set(iv->var, vrange);
      }
    }
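    // Record the domain of every leaf iter var; Simplify relies on this map
    // below, e.g. to prove that fuzzy-matched dimensions have extent 1.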

    // input remap.
    Array<Tensor> inputs = self->InputTensors();
    CHECK_EQ(inputs.size(), intrin->inputs.size());
    for (size_t i = 0; i < inputs.size(); ++i) {
      InputEntry e;
      e.tensor = intrin->inputs[i];
      e.region = Array<Range>(in_region.at(inputs[i]));
      CHECK_GE(e.region.size(), e.tensor.ndim());
      // Enable fuzzy matching, to match [1, n, m] to [n, m]
      e.start = e.region.size() - e.tensor.ndim();
      for (size_t j = 0; j < e.start; ++j) {
        auto canonical_extent = Simplify(e.region[j]->extent, *compute_intrin_iter_space);
        CHECK(is_one(canonical_extent))
            << "Tensorize " << intrin->name << ":"
            << " Input dimension mismatch with tensor intrin "
            << " expected shape=" << e.tensor->shape
            << ", given region=" << e.region;
      }
      in_remap_[inputs[i]] = e;
    }
    // output remap
    const ComputeOpNode* intrin_compute = intrin->op.as<ComputeOpNode>();
    CHECK(intrin_compute) << "Only support compute intrinsic for now";
    CHECK_GE(self->axis.size(), intrin_compute->axis.size())
        << "Tensorize: Output mismatch with tensor intrin ";
    // Enable fuzzy matching, to match [1, n, m] to [n, m]
    size_t axis_start = self->axis.size() - intrin_compute->axis.size();
    for (size_t i = 0; i < axis_start; ++i) {
      Range r = out_dom.at(self->axis[i]);
      CHECK(is_one(r->extent))
          << "Tensorize: Output mismatch with tensor intrin "
          << " intrin-dim=" << intrin_compute->axis.size()
          << ", tensorize-dim=" << self->axis.size();
      var_remap_[self->axis[i]->var.get()] = r->min;
    }
    // Assume we tensorize at region axis i, which spans [min, min + extent).
    // The corresponding intrinsic axis is j, spanning [0, extent).
    // Remap index i to j + min.
    for (size_t i = axis_start; i < self->axis.size(); ++i) {
      IterVar iv = self->axis[i];
      IterVar target_iv = intrin_compute->axis[i - axis_start];
      Range r = out_dom.at(iv);
      var_remap_[iv->var.get()] = target_iv->var + r->min;
      axis_remap_[iv] = target_iv;
      compute_intrin_iter_space->Set(target_iv->var, target_iv->dom);
    }
    // Remap reduction axis
    CHECK_GE(self->reduce_axis.size(), intrin_compute->reduce_axis.size())
        << "Tensorize: Reduction dimension mismatch with tensor intrin";
    axis_start = self->reduce_axis.size() - intrin_compute->reduce_axis.size();
    for (size_t i = 0; i < axis_start; ++i) {
      Range r = out_dom.at(self->reduce_axis[i]);
      CHECK(is_one(r->extent))
          << "Tensorize: Reduction mismatch with tensor intrin "
          << " intrin-dim=" << intrin_compute->reduce_axis.size()
          << ", tensorize-dim=" << self->reduce_axis.size();
      var_remap_[self->reduce_axis[i]->var.get()] = r->min;
    }
    for (size_t i = axis_start; i < self->reduce_axis.size(); ++i) {
      IterVar iv = self->reduce_axis[i];
      IterVar target_iv = intrin_compute->reduce_axis[i - axis_start];
      Range r = out_dom.at(iv);
      var_remap_[iv->var.get()] = target_iv->var + r->min;
      axis_remap_[iv] = target_iv;
      compute_intrin_iter_space->Set(target_iv->var, target_iv->dom);
    }
  }

 private:
  // Input entry
  struct InputEntry {
    Tensor tensor;
    size_t start;
    Array<Range> region;
  };
  // input data remap
  std::unordered_map<Tensor, InputEntry> in_remap_;
  // variable remap.
  std::unordered_map<const VarNode*, PrimExpr> var_remap_;
  // IterVar remap.
  std::unordered_map<IterVar, IterVar> axis_remap_;
};

// Try to match the tensor dataflow of the stage with that of the intrinsic.
Array<PrimExpr> MatchTensorizeBody(
    const ComputeOpNode* self,
    const Stage& stage,
    const std::unordered_map<IterVar, Range>& dom_map,
    const std::unordered_map<IterVar, Range>& out_dom,
    const std::unordered_map<Tensor, Array<Range> >& in_region,
    const TensorIntrin& intrin,
    Map<Var, Range>* compute_intrin_iter_space) {
  TensorIntrinMatcher matcher;
  matcher.Init(self, stage, dom_map, out_dom, in_region, intrin, compute_intrin_iter_space);
  Array<PrimExpr> ret;
  for (PrimExpr expr : self->body) {
    ret.push_back(matcher(expr));
  }
  return ret;
}

void VerifyTensorizeBody(
    const ComputeOpNode* self,
    const Stage& stage,
    const std::unordered_map<IterVar, Range>& dom_map,
    const std::unordered_map<IterVar, Range>& out_dom,
    const std::unordered_map<Tensor, Array<Range> >& in_region,
    const TensorIntrin& intrin) {
  Map<Var, Range> compute_intrin_iter_space;
  Array<PrimExpr> body = MatchTensorizeBody(self, stage, dom_map, out_dom, in_region, intrin,
                                            &compute_intrin_iter_space);
  const ComputeOpNode* intrin_compute = intrin->op.as<ComputeOpNode>();
  CHECK(intrin_compute) << "Only support compute intrinsic for now";
  CHECK_EQ(body.size(), intrin_compute->body.size())
      << "Tensorize failed: body size mismatch";
  for (size_t i = 0; i < body.size(); ++i) {
    PrimExpr lhs = Simplify(body[i], compute_intrin_iter_space);
    lhs = CanonicalSimplify(lhs, compute_intrin_iter_space);
    PrimExpr rhs = Simplify(intrin_compute->body[i], compute_intrin_iter_space);
    rhs = CanonicalSimplify(rhs, compute_intrin_iter_space);
    if (lhs.dtype() != rhs.dtype()) {
      LOG(FATAL)
          << "Failed to match the data type with TensorIntrin "
          << intrin->name << "'s declaration "
          << " provided=" << lhs.dtype()
          << ", intrin=" << rhs.dtype();
    }
    CHECK(Equal(lhs, rhs))
        << "Failed to match the compute with TensorIntrin "
        << intrin->name << "'s declaration "
        << " provided= " << lhs
        << ", intrin=  " << rhs;
  }
}

Stmt MakeTensorize(const ComputeOpNode* self,
                   const Stage& stage,
                   const std::unordered_map<IterVar, Range>& dom_map,
                   bool debug_keep_trivial_loop) {
  std::unordered_map<IterVar, Range> out_dom;
  std::unordered_map<Tensor, Array<Range> > in_region;
  size_t tloc = InferTensorizeRegion(self, stage, dom_map, &out_dom, &in_region);
  TensorIntrin intrin = stage->iter_var_attrs.at(
      stage->leaf_iter_vars[tloc])->tensor_intrin;
  CHECK(intrin.defined());
  ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map, debug_keep_trivial_loop);
  VerifyTensorizeLoopNest(self, stage, n, tloc);
  VerifyTensorizeBody(self, stage, dom_map, out_dom, in_region, intrin);
  // Start binding the input and output buffers.
  Stmt nop = EvaluateNode::make(0);
  std::vector<Stmt> input_bind_nest, output_bind_nest;
  Array<Tensor> inputs = self->InputTensors();
  CHECK_EQ(inputs.size(), intrin->inputs.size())
      << "Tensorize failed: input size mismatch ";
  // input binding
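  // Each binding attaches a buffer_bind_scope attribute whose value is a
  // tvm_tuple(min0, extent0, min1, extent1, ...) call that describes the
  // tensor sub-region the intrinsic buffer maps onto.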
  for (size_t i = 0; i < intrin->inputs.size(); ++i) {
    Tensor tensor = inputs[i];
    Buffer buffer = intrin->buffers[i];
    Array<ObjectRef> bind_spec{buffer, tensor};
    auto it = in_region.find(tensor);
    CHECK(it != in_region.end());
    const Array<Range>& region = it->second;
    Array<PrimExpr> tuple;
    for (const Range& r : region) {
      tuple.push_back(r->min);
      tuple.push_back(r->extent);
    }
    input_bind_nest.emplace_back(AttrStmtNode::make(
        bind_spec, tir::attr::buffer_bind_scope,
        CallNode::make(DataType::Handle(),
                       tir::intrinsic::tvm_tuple,
                       tuple, CallNode::Intrinsic), nop));
  }
  // output binding
  const ComputeOpNode* intrin_compute = intrin->op.as<ComputeOpNode>();
  CHECK(intrin_compute) << "Only support compute intrinsic for now";
  CHECK_EQ(intrin->inputs.size() + intrin_compute->body.size(), intrin->buffers.size());
  CHECK_EQ(intrin_compute->body.size(), self->body.size());
  Array<PrimExpr> tuple;
  for (IterVar iv : self->axis) {
    auto it = out_dom.find(iv);
    CHECK(it != out_dom.end());
    tuple.push_back(it->second->min);
    tuple.push_back(it->second->extent);
  }
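  // intrin->buffers lists the input buffers first, followed by one buffer
  // per output; i - intrin->inputs.size() selects the matching output.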
  for (size_t i = intrin->inputs.size(); i < intrin->buffers.size(); ++i) {
    Tensor tensor = stage->op.output(i - intrin->inputs.size());
    Buffer buffer = intrin->buffers[i];
    Array<ObjectRef> bind_spec{buffer, tensor};
    output_bind_nest.emplace_back(AttrStmtNode::make(
        bind_spec, tir::attr::buffer_bind_scope,
        CallNode::make(DataType::Handle(),
                       tir::intrinsic::tvm_tuple,
                       tuple, CallNode::Intrinsic), nop));
  }
  // Check variable remap
  std::unordered_map<const VarNode*, PrimExpr> vmap;
  tir::ArgBinder binder(&vmap);
  CHECK_GE(self->reduce_axis.size(), intrin_compute->reduce_axis.size())
      << "Tensorization fail: reduction axis size do not match";
  size_t start = self->reduce_axis.size() - intrin_compute->reduce_axis.size();
  for (size_t i = 0; i < start; ++i) {
    IterVar iv = self->reduce_axis[i];
    auto it = out_dom.find(iv);
    CHECK(it != out_dom.end());
    CHECK(is_one(it->second->extent))
        << "Tensorization fail: reduction axis size do not match";
  }
  for (size_t i = start; i < self->reduce_axis.size(); ++i) {
    IterVar iv = self->reduce_axis[i];
    IterVar target = intrin_compute->reduce_axis[i - start];
    auto it = out_dom.find(iv);
    CHECK(it != out_dom.end());
    binder.Bind(target->dom->min, make_const(iv->dom->min.dtype(), 0),
                "tensor_intrin.reduction.min");
    binder.Bind(target->dom->extent, it->second->extent,
                "tensor_intrin.reduction.extent");
  }
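  // If the tensorized scope begins within the common loops, the intrinsic's
  // normal body covers the whole computation.  Otherwise the reduction is
  // split into init and update steps (falling back to the body op for the
  // first iteration when no reduce_init is declared).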
  if (tloc <= n.num_common_loop) {
    // No need to split the reduction.
    std::vector<std::vector<Stmt> > nest(
        n.main_nest.begin(), n.main_nest.begin() + tloc + 1);
    nest.emplace_back(MakeIfNest(n.main_predicates));
    CHECK_EQ(n.init_predicates.size(), 0U);
    CHECK(intrin->body.defined())
        << "Normal store op for intrin " << intrin << " is not defined";
    Stmt body = MergeNest(output_bind_nest, intrin->body);
    body = MergeNest(input_bind_nest, body);
    body = tir::Substitute(body, vmap);
    body = MergeNest(binder.asserts(), body);
    body = te::Substitute(body, n.main_vmap);
    return MergeNest(nest, body);
  } else {
    // Need to split reduction
    CHECK(intrin->reduce_update.defined())
        << "Reduction update op for intrin " << intrin << " is not defined";
    // Need init and update steps
    CHECK_NE(self->reduce_axis.size(), 0U);
    std::vector<std::vector<Stmt> > common(
        n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1);
    std::vector<std::vector<Stmt> > update_nest(
        n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.begin() + tloc + 1);
    update_nest.emplace_back(MakeIfNest(n.main_predicates));

    if (intrin->reduce_init.defined()) {
      // init nest
      std::vector<std::vector<Stmt> > init_nest(
          n.init_nest.begin(), n.init_nest.begin() + tloc + 1);
      init_nest.emplace_back(MakeIfNest(n.init_predicates));
      Stmt init = MergeNest(output_bind_nest, intrin->reduce_init);
      init = te::Substitute(init, n.init_vmap);
      init = MergeNest(init_nest, init);
      // The update
      Stmt update = MergeNest(output_bind_nest, intrin->reduce_update);
      update = MergeNest(input_bind_nest, update);
      update = tir::Substitute(update, vmap);
      update = MergeNest(binder.asserts(), update);
      update = te::Substitute(update, n.main_vmap);
      update = MergeNest(update_nest, update);
      return MergeNest(common, SeqStmt::Flatten(init, update));
    } else {
      // When no init op is available, use the body op to reset in the first iteration.
      CHECK(intrin->body.defined())
          << "Normal body op for intrin " << intrin << " is not defined";
      Stmt update = TransformUpdate(stage, dom_map, n,
                                    intrin->body,
                                    intrin->reduce_update);
      update = MergeNest(output_bind_nest, update);
      update = MergeNest(input_bind_nest, update);
      update = tir::Substitute(update, vmap);
      update = MergeNest(binder.asserts(), update);
      update = te::Substitute(update, n.main_vmap);
      update = MergeNest(update_nest, update);
      return MergeNest(common, update);
    }
  }
}

// Register functions for unittests
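// These globals are reachable from the Python FFI; a rough sketch:
//   finfer = tvm.get_global_func("test.op.InferTensorizeRegion")
//   out_dom, in_region = finfer(stage, dom_map)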
TVM_REGISTER_GLOBAL("test.op.InferTensorizeRegion")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    Stage stage = args[0];
    Map<IterVar, Range> dmap = args[1];
    std::unordered_map<IterVar, Range> out_dom;
    std::unordered_map<Tensor, Array<Range> > in_region;
    CHECK(stage->op.as<ComputeOpNode>());
    InferTensorizeRegion(stage->op.as<ComputeOpNode>(),
                         stage,
                         as_unordered_map(dmap),
                         &out_dom, &in_region);
    *ret = Array<ObjectRef>{Map<IterVar, Range>(out_dom),
                            Map<Tensor, Array<Range> >(in_region)};
  });

TVM_REGISTER_GLOBAL("test.op.MatchTensorizeBody")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    Stage stage = args[0];
    Map<IterVar, Range> out_dom = args[1];
    Map<Tensor, Array<Range> > in_region = args[2];
    TensorIntrin intrin = args[3];
    Map<Var, Range> vrange;
    CHECK(stage->op.as<ComputeOpNode>());
    *ret = MatchTensorizeBody(stage->op.as<ComputeOpNode>(),
                              stage,
                              {{}},
                              as_unordered_map(out_dom),
                              as_unordered_map(in_region),
                              intrin,
                              &vrange);
  });
}  // namespace te
}  // namespace tvm