api_lang.cc 15.8 KB
Newer Older
1 2 3 4 5 6 7 8
/*!
 *  Copyright (c) 2016 by Contributors
 *  Implementation of API functions related to Higher DSL build.
 * \file api_lang.cc
 */
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/tensor.h>
9
#include <tvm/operation.h>
10 11 12
#include <tvm/buffer.h>
#include <tvm/schedule.h>
#include <tvm/api_registry.h>
13
#include <tvm/build_module.h>
14 15 16

namespace tvm {

ziheng committed
17 18 19 20 21 22 23 24 25 26 27 28
// "_min_value": returns Type::min() for the Type passed in args[0].
TVM_REGISTER_API("_min_value")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    Type dtype = args[0].operator Type();
    *ret = dtype.min();
  });

// "_max_value": returns Type::max() for the Type passed in args[0].
TVM_REGISTER_API("_max_value")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    Type dtype = args[0].operator Type();
    *ret = dtype.max();
  });

29
// "_const": builds a constant via make_const(args[1], value), where value is
// the scalar in args[0]. args[1] is forwarded as make_const's first argument
// (presumably the target Type; confirm against make_const).
// kDLInt scalars use the int64_t overload and kDLFloat scalars the double
// overload; any other type code is a fatal error.
TVM_REGISTER_API("_const")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    if (args[0].type_code() == kDLInt) {
      *ret = make_const(args[1], args[0].operator int64_t());
    } else if (args[0].type_code() == kDLFloat) {
      *ret = make_const(args[1], args[0].operator double());
    } else {
      LOG(FATAL) << "only accept int or float";
    }
  });

40
// "_str": wraps the string argument args[0] in an ir::StringImm node.
TVM_REGISTER_API("_str")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = ir::StringImm::make(args[0]);
  });


46
// "_Array": packs every argument, in order, into a new ArrayNode.
// A kNull argument is stored as an empty NodePtr, so the array always has
// exactly one slot per argument.
TVM_REGISTER_API("_Array")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    std::vector<NodePtr<Node> > data;
    // Exactly args.size() elements are pushed; reserve to avoid regrowth.
    data.reserve(static_cast<size_t>(args.size()));
    for (int i = 0; i < args.size(); ++i) {
      if (args[i].type_code() != kNull) {
        data.push_back(args[i].node_sptr());
      } else {
        data.push_back(NodePtr<Node>(nullptr));
      }
    }
    auto node = make_node<ArrayNode>();
    node->data = std::move(data);
    *ret = node;
  });

61
// "_ArrayGetItem": returns element args[1] of the ArrayNode in args[0].
// args[0] must be an ArrayNode. A negative or out-of-range index is a
// fatal error.
TVM_REGISTER_API("_ArrayGetItem")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    int64_t i = args[1];
    auto& sptr = args[0].node_sptr();
    CHECK(sptr->is_type<ArrayNode>());
    auto* n = static_cast<const ArrayNode*>(sptr.get());
    // Reject negative indices explicitly rather than relying on the
    // size_t wrap-around below to trip the upper-bound check.
    CHECK_GE(i, 0) << "out of bound of array";
    CHECK_LT(static_cast<size_t>(i), n->data.size())
        << "out of bound of array";
    *ret = n->data[static_cast<size_t>(i)];
  });

72
// "_ArraySize": returns the number of elements of the ArrayNode in args[0]
// as an int64_t. args[0] must be an ArrayNode.
TVM_REGISTER_API("_ArraySize")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    auto& sptr = args[0].node_sptr();
    CHECK(sptr->is_type<ArrayNode>());
    *ret = static_cast<int64_t>(
        static_cast<const ArrayNode*>(sptr.get())->data.size());
  });

80
// "_Map": builds a map from alternating key/value arguments
// (args[0]=key0, args[1]=value0, args[2]=key1, ...); the argument count must
// be even.
// - If the first key is a string (kStr), a StrMapNode is built and every key
//   must be a string.
// - Otherwise a MapNode is built and every key must be a node.
// In both cases every value must be a node (kNodeHandle).
TVM_REGISTER_API("_Map")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    CHECK_EQ(args.size() % 2, 0);
    if (args.size() != 0 && args[0].type_code() == kStr) {
      // StrMap
      StrMapNode::ContainerType data;
      for (int i = 0; i < args.size(); i += 2) {
        CHECK(args[i].type_code() == kStr)
            << "key of str map need to be str";
        CHECK(args[i + 1].type_code() == kNodeHandle)
            << "value of the map to be NodeRef";
        data.emplace(std::make_pair(args[i].operator std::string(),
                                    args[i + 1].node_sptr()));
      }
      auto node = make_node<StrMapNode>();
      node->data = std::move(data);
      *ret = node;
    } else {
      // Container node: keys are nodes, not strings.
      MapNode::ContainerType data;
      for (int i = 0; i < args.size(); i += 2) {
        CHECK(args[i].type_code() == kNodeHandle)
            << "key of map need to be NodeRef";
        CHECK(args[i + 1].type_code() == kNodeHandle)
            << "value of map to be NodeRef";
        data.emplace(std::make_pair(args[i].node_sptr(),
                                    args[i + 1].node_sptr()));
      }
      auto node = make_node<MapNode>();
      node->data = std::move(data);
      *ret = node;
    }
  });

114
// "_MapSize": returns the number of entries in the MapNode or StrMapNode
// held by args[0], as an int64_t. Any other argument is a fatal error.
TVM_REGISTER_API("_MapSize")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    // Same argument validation as _MapGetItem / _MapCount.
    CHECK(args[0].type_code() == kNodeHandle);
    auto& sptr = args[0].node_sptr();
    if (sptr->is_type<MapNode>()) {
      auto* n = static_cast<const MapNode*>(sptr.get());
      *ret = static_cast<int64_t>(n->data.size());
    } else {
      CHECK(sptr->is_type<StrMapNode>());
      auto* n = static_cast<const StrMapNode*>(sptr.get());
      *ret = static_cast<int64_t>(n->data.size());
    }
  });

127
// "_MapGetItem": looks up key args[1] in the map held by args[0] and returns
// the stored value.
// - MapNode: the key must be a node.
// - StrMapNode: the key is converted to std::string.
// A missing key, or an args[0] that is not one of these map types, is fatal.
TVM_REGISTER_API("_MapGetItem")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    CHECK(args[0].type_code() == kNodeHandle);
    auto& sptr = args[0].node_sptr();
    if (sptr->is_type<MapNode>()) {
      CHECK(args[1].type_code() == kNodeHandle);
      auto* n = static_cast<const MapNode*>(sptr.get());
      auto it = n->data.find(args[1].node_sptr());
      CHECK(it != n->data.end())
          << "cannot find the corresponding key in the Map";
      *ret = (*it).second;
    } else {
      CHECK(sptr->is_type<StrMapNode>());
      auto* n = static_cast<const StrMapNode*>(sptr.get());
      auto it = n->data.find(args[1].operator std::string());
      CHECK(it != n->data.end())
          << "cannot find the corresponding key in the Map";
      *ret = (*it).second;
    }
  });

148
// "_MapCount": returns (as int64_t) how many entries of the map in args[0]
// have key args[1]. For a MapNode the key must be a node; for a StrMapNode
// it is converted to std::string.
TVM_REGISTER_API("_MapCount")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    CHECK(args[0].type_code() == kNodeHandle);
    auto& sptr = args[0].node_sptr();
    if (sptr->is_type<MapNode>()) {
      auto* n = static_cast<const MapNode*>(sptr.get());
      CHECK(args[1].type_code() == kNodeHandle);
      *ret = static_cast<int64_t>(
          n->data.count(args[1].node_sptr()));
    } else {
      CHECK(sptr->is_type<StrMapNode>());
      auto* n = static_cast<const StrMapNode*>(sptr.get());
      *ret = static_cast<int64_t>(
          n->data.count(args[1].operator std::string()));
    }
  });

165
// "_MapItems": flattens the map in args[0] into an ArrayNode laid out as
// [key0, value0, key1, value1, ...], in the map's iteration order.
// StrMapNode keys are wrapped in ir::StringImm nodes so that every element
// of the array is a node.
TVM_REGISTER_API("_MapItems")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    CHECK(args[0].type_code() == kNodeHandle);
    auto& sptr = args[0].node_sptr();
    if (sptr->is_type<MapNode>()) {
      auto* n = static_cast<const MapNode*>(sptr.get());
      auto rkvs = make_node<ArrayNode>();
      rkvs->data.reserve(n->data.size() * 2);
      for (const auto& kv : n->data) {
        rkvs->data.push_back(kv.first);
        rkvs->data.push_back(kv.second);
      }
      *ret = rkvs;
    } else {
      // Validate before the downcast: without this check, a non-map node
      // would be reinterpreted as a StrMapNode (undefined behavior).
      CHECK(sptr->is_type<StrMapNode>());
      auto* n = static_cast<const StrMapNode*>(sptr.get());
      auto rkvs = make_node<ArrayNode>();
      rkvs->data.reserve(n->data.size() * 2);
      for (const auto& kv : n->data) {
        rkvs->data.push_back(ir::StringImm::make(kv.first).node_);
        rkvs->data.push_back(kv.second);
      }
      *ret = rkvs;
    }
  });

187
// "Range": with a single argument n returns Range(0, n); otherwise returns
// Range(args[0], args[1]).
TVM_REGISTER_API("Range")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    if (args.size() == 1) {
      *ret = Range(0, args[0]);
    } else {
      *ret = Range(args[0], args[1]);
    }
  });

196
// "_Buffer": forwards args[0]..args[8], in order, to BufferNode::make and
// returns its result.
TVM_REGISTER_API("_Buffer")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = BufferNode::make(args[0],
                            args[1],
                            args[2],
                            args[3],
                            args[4],
                            args[5],
                            args[6],
                            args[7],
                            args[8]);
  });

209 210 211
// "_BufferAccessPtr": calls Buffer::access_ptr on the Buffer in args[0] with
// args[1]..args[4] and returns the result.
TVM_REGISTER_API("_BufferAccessPtr")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = args[0].operator Buffer()
        .access_ptr(args[1], args[2], args[3], args[4]);
  });

215 216 217 218 219 220 221 222 223 224 225 226
// "_BufferVLoad": calls Buffer::vload(args[1], args[2]) on the Buffer held
// by args[0] and returns its result.
TVM_REGISTER_API("_BufferVLoad")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    Buffer buffer = args[0].operator Buffer();
    *ret = buffer.vload(args[1], args[2]);
  });

// "_BufferVStore": calls Buffer::vstore(args[1], args[2]) on the Buffer held
// by args[0] and returns its result.
TVM_REGISTER_API("_BufferVStore")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    Buffer buffer = args[0].operator Buffer();
    *ret = buffer.vstore(args[1], args[2]);
  });

227
// "_Tensor": forwards args[0]..args[3], in order, to TensorNode::make and
// returns its result.
TVM_REGISTER_API("_Tensor")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = TensorNode::make(args[0],
                            args[1],
                            args[2],
                            args[3]);
  });

235 236 237 238 239 240 241 242 243 244 245
// "_TensorIntrin": passes the seven positional arguments args[0]..args[6]
// straight through to TensorIntrinNode::make and returns its result.
TVM_REGISTER_API("_TensorIntrin")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = TensorIntrinNode::make(
        args[0], args[1], args[2], args[3],
        args[4], args[5], args[6]);
  });

246 247 248 249 250 251 252 253
// "_TensorIntrinCall": passes args[0]..args[3] straight through to
// TensorIntrinCallNode::make and returns its result.
TVM_REGISTER_API("_TensorIntrinCall")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = TensorIntrinCallNode::make(
        args[0], args[1], args[2], args[3]);
  });

254
// "_TensorEqual": returns the result of Tensor::operator== applied to the
// Tensors in args[0] and args[1].
TVM_REGISTER_API("_TensorEqual")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = args[0].operator Tensor() == args[1].operator Tensor();
  });

259
// "_TensorHash": returns std::hash<Tensor> of the Tensor in args[0], cast to
// int64_t so that it fits the packed-function return value.
TVM_REGISTER_API("_TensorHash")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = static_cast<int64_t>(
        std::hash<Tensor>()(args[0].operator Tensor()));
  });

265
// "_Placeholder": forwards args[0]..args[2] to placeholder() and returns its
// result.
TVM_REGISTER_API("_Placeholder")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = placeholder(args[0],
                       args[1],
                       args[2]);
  });

272
// "_ComputeOp": forwards args[0]..args[4], in order, to ComputeOpNode::make
// and returns its result.
TVM_REGISTER_API("_ComputeOp")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = ComputeOpNode::make(args[0],
                               args[1],
                               args[2],
                               args[3],
                               args[4]);
  });

281
// "_ScanOp": forwards args[0]..args[7], in order, to ScanOpNode::make and
// returns its result.
TVM_REGISTER_API("_ScanOp")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = ScanOpNode::make(args[0],
                            args[1],
                            args[2],
                            args[3],
                            args[4],
                            args[5],
                            args[6],
                            args[7]);
  });

293 294 295 296 297 298 299 300 301 302 303 304
// "_TensorComputeOp": passes args[0]..args[7] straight through to
// TensorComputeOpNode::make and returns its result.
TVM_REGISTER_API("_TensorComputeOp")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = TensorComputeOpNode::make(
        args[0], args[1], args[2], args[3],
        args[4], args[5], args[6], args[7]);
  });

305
// "_ExternOp": forwards args[0]..args[6], in order, to ExternOpNode::make and
// returns its result.
TVM_REGISTER_API("_ExternOp")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = ExternOpNode::make(args[0],
                              args[1],
                              args[2],
                              args[3],
                              args[4],
                              args[5],
                              args[6]);
  });

316 317 318 319 320 321 322 323 324 325
// "_HybridOp": passes args[0]..args[5] straight through to
// HybridOpNode::make and returns its result.
TVM_REGISTER_API("_HybridOp")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = HybridOpNode::make(
        args[0], args[1], args[2],
        args[3], args[4], args[5]);
  });

326
// "_OpGetOutput": returns Operation::output(i) for the Operation in args[0],
// where i is the integer in args[1]. A negative i is a fatal error.
TVM_REGISTER_API("_OpGetOutput")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    int64_t index = args[1].operator int64_t();
    // Without this check a negative index would wrap to a huge size_t.
    CHECK_GE(index, 0) << "output index must be non-negative";
    *ret = args[0].operator Operation().output(
        static_cast<size_t>(index));
  });

332
// "_OpNumOutputs": returns num_outputs() of the Operation in args[0].
TVM_REGISTER_API("_OpNumOutputs")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = args[0].operator Operation()->num_outputs();
  });

337 338 339 340 341
// "_OpInputTensors": returns InputTensors() of the Operation held by args[0].
TVM_REGISTER_API("_OpInputTensors")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    Operation op = args[0].operator Operation();
    *ret = op->InputTensors();
  });

342
// "_IterVar": builds an IterVar via IterVarNode::make. args[2] is an int
// that is cast to IterVarType; the other arguments are forwarded unchanged.
TVM_REGISTER_API("_IterVar")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = IterVarNode::make(
        args[0], args[1],
        static_cast<IterVarType>(args[2].operator int()),
        args[3]);
  });

350
// "_CreateSchedule": returns create_schedule() for the Array<Operation> in
// args[0].
TVM_REGISTER_API("_CreateSchedule")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = create_schedule(args[0].operator Array<Operation>());
  });

355
// "_StageSetScope": calls Stage::set_scope(args[1]) on the Stage in args[0].
// Nothing is returned.
TVM_REGISTER_API("_StageSetScope")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    args[0].operator Stage()
        .set_scope(args[1]);
  });

361
// "_StageBind": calls Stage::bind(args[1], args[2]) on the Stage in args[0].
// Nothing is returned.
TVM_REGISTER_API("_StageBind")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    args[0].operator Stage()
        .bind(args[1], args[2]);
  });

367
// "_StageSplitByFactor": calls Stage::split(args[1], args[2], ...) on the
// Stage in args[0] and returns the two IterVars it produces as [outer, inner].
TVM_REGISTER_API("_StageSplitByFactor")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    IterVar outer, inner;
    args[0].operator Stage()
        .split(args[1], args[2], &outer, &inner);
    *ret = Array<IterVar>({outer, inner});
  });

375
// "_StageSplitByNParts": calls Stage::split_by_nparts(args[1], args[2], ...)
// on the Stage in args[0] and returns the produced IterVars as [outer, inner].
TVM_REGISTER_API("_StageSplitByNParts")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    IterVar outer, inner;
    args[0].operator Stage()
        .split_by_nparts(args[1], args[2], &outer, &inner);
    *ret = Array<IterVar>({outer, inner});
  });

383
// "_StageFuse": calls Stage::fuse(args[1], ...) on the Stage in args[0] and
// returns the IterVar it produces.
TVM_REGISTER_API("_StageFuse")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    IterVar fused;
    args[0].operator Stage()
        .fuse(args[1], &fused);
    *ret = fused;
  });

391
// "_StageComputeAt": calls Stage::compute_at(args[1], args[2]) on the Stage
// in args[0]. Nothing is returned.
TVM_REGISTER_API("_StageComputeAt")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    args[0].operator Stage()
        .compute_at(args[1], args[2]);
  });

397
// "_StageComputeInline": calls Stage::compute_inline() on the Stage in
// args[0]. Nothing is returned.
TVM_REGISTER_API("_StageComputeInline")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    args[0].operator Stage()
        .compute_inline();
  });

403
// "_StageComputeRoot": calls Stage::compute_root() on the Stage in args[0].
// Nothing is returned.
TVM_REGISTER_API("_StageComputeRoot")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    args[0].operator Stage()
        .compute_root();
  });

409
// "_StageReorder": calls Stage::reorder(args[1]) on the Stage in args[0].
// Nothing is returned.
TVM_REGISTER_API("_StageReorder")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    args[0].operator Stage()
        .reorder(args[1]);
  });

415
// "_StageTile": calls Stage::tile(args[1], args[2], args[3], args[4], ...) on
// the Stage in args[0]. Returns the four produced IterVars as
// [x_outer, y_outer, x_inner, y_inner].
TVM_REGISTER_API("_StageTile")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    IterVar x_outer, y_outer, x_inner, y_inner;
    args[0].operator Stage()
        .tile(args[1], args[2],
              args[3], args[4],
              &x_outer, &y_outer,
              &x_inner, &y_inner);
    *ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
  });

426
// "_StageEnvThreads": calls Stage::env_threads(args[1]) on the Stage in
// args[0]. Nothing is returned.
TVM_REGISTER_API("_StageEnvThreads")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    args[0].operator Stage()
        .env_threads(args[1]);
  });

432 433 434 435 436 437
// "_StageSetStorePredicate": applies Stage::set_store_predicate(args[1]) to
// the Stage handle in args[0]; returns nothing.
TVM_REGISTER_API("_StageSetStorePredicate")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    Stage stage = args[0].operator Stage();
    stage.set_store_predicate(args[1]);
  });

438
// "_StageUnroll": calls Stage::unroll(args[1]) on the Stage in args[0].
// Nothing is returned.
TVM_REGISTER_API("_StageUnroll")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    args[0].operator Stage()
        .unroll(args[1]);
  });

444
// "_StageVectorize": calls Stage::vectorize(args[1]) on the Stage in args[0].
// Nothing is returned.
TVM_REGISTER_API("_StageVectorize")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    args[0].operator Stage()
        .vectorize(args[1]);
  });

450 451 452 453 454 455
// "_StageTensorize": applies Stage::tensorize(args[1], args[2]) to the Stage
// handle in args[0]; returns nothing.
TVM_REGISTER_API("_StageTensorize")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    Stage stage = args[0].operator Stage();
    stage.tensorize(args[1], args[2]);
  });

456
// "_StageParallel": calls Stage::parallel(args[1]) on the Stage in args[0].
// Nothing is returned.
TVM_REGISTER_API("_StageParallel")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    args[0].operator Stage()
        .parallel(args[1]);
  });

462 463 464
// "_StagePragma": calls Stage::pragma(args[1], args[2], args[3]) on the Stage
// in args[0]. Nothing is returned.
TVM_REGISTER_API("_StagePragma")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    args[0].operator Stage()
        .pragma(args[1], args[2], args[3]);
  });

468 469 470
// "_StagePrefetch": calls Stage::prefetch(args[1], args[2], args[3]) on the
// Stage in args[0]. Nothing is returned.
TVM_REGISTER_API("_StagePrefetch")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    args[0].operator Stage()
        .prefetch(args[1], args[2], args[3]);
  });

474 475 476
// "_StageStorageAlign": calls Stage::storage_align(args[1], args[2], args[3])
// on the Stage in args[0]. Nothing is returned.
TVM_REGISTER_API("_StageStorageAlign")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    args[0].operator Stage()
        .storage_align(args[1], args[2], args[3]);
  });

// "_StageDoubleBuffer": calls Stage::double_buffer() on the Stage in args[0].
// Nothing is returned.
TVM_REGISTER_API("_StageDoubleBuffer")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    args[0].operator Stage().double_buffer();
  });

485 486 487 488 489
// "_StageOpenGL": applies Stage::opengl() to the Stage handle in args[0];
// returns nothing.
TVM_REGISTER_API("_StageOpenGL")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    Stage stage = args[0].operator Stage();
    stage.opengl();
  });

490
// "_ScheduleNormalize": returns Schedule::normalize() of the Schedule in
// args[0].
TVM_REGISTER_API("_ScheduleNormalize")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = args[0].operator Schedule()
        .normalize();
  });

496
// "_ScheduleCreateGroup": returns Schedule::create_group(args[1], args[2],
// args[3]) for the Schedule in args[0].
TVM_REGISTER_API("_ScheduleCreateGroup")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = args[0].operator Schedule()
        .create_group(args[1], args[2], args[3]);
  });

502
// "_ScheduleCacheRead": returns Schedule::cache_read(args[1], args[2],
// args[3]) for the Schedule in args[0].
TVM_REGISTER_API("_ScheduleCacheRead")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = args[0].operator Schedule()
        .cache_read(args[1], args[2], args[3]);
  });

508
// "_ScheduleCacheWrite": returns Schedule::cache_write(target, args[2]) for
// the Schedule in args[0]. target is args[1] converted to a single Tensor
// when it is a Tensor node, and to an Array<Tensor> otherwise.
TVM_REGISTER_API("_ScheduleCacheWrite")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    if (args[1].IsNodeType<Tensor>()) {
      *ret = args[0].operator Schedule()
          .cache_write(args[1].operator Tensor(), args[2]);
    } else {
      *ret = args[0].operator Schedule()
          .cache_write(args[1].operator Array<Tensor>(), args[2]);
    }
  });

519
// "_ScheduleRFactor": returns Schedule::rfactor(args[1], args[2], args[3])
// for the Schedule in args[0].
TVM_REGISTER_API("_ScheduleRFactor")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = args[0].operator Schedule()
        .rfactor(args[1], args[2], args[3]);
  });

ziheng committed
525 526 527 528 529 530 531
// "_CommReducerCombine": applies the CommReducerNode held by args[0] to
// (args[1], args[2]) through its operator() and returns the result.
TVM_REGISTER_API("_CommReducerCombine")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    const ir::CommReducerNode* combiner =
      args[0].operator ir::CommReducer().as<ir::CommReducerNode>();
    // as<>() yields nullptr for an empty handle (e.g. a None argument);
    // fail loudly instead of dereferencing it.
    CHECK(combiner != nullptr)
        << "expect a CommReducer as the first argument";
    *ret = (*combiner)(args[1], args[2]);
  });

532
}  // namespace tvm