/*!
 *  Copyright (c) 2016 by Contributors
 *  Implementation of API functions related to the higher-level DSL build.
 * \file api_lang.cc
 */
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/tensor.h>
#include <tvm/operation.h>
#include <tvm/buffer.h>
#include <tvm/schedule.h>
#include <tvm/api_registry.h>

namespace tvm {

ziheng committed
16 17 18 19 20 21 22 23 24 25 26 27
TVM_REGISTER_API("_min_value")
.set_body([](TVMArgs args,  TVMRetValue* ret) {
    Type t = args[0].operator Type();
    *ret = t.min();
  });

// Packed function "_max_value": converts args[0] to a Type and returns the
// expression produced by Type::max() for that type.
TVM_REGISTER_API("_max_value")
.set_body([](TVMArgs in, TVMRetValue* out) {
  *out = in[0].operator Type().max();
});

28
TVM_REGISTER_API("_const")
29 30 31 32 33 34 35 36 37 38
.set_body([](TVMArgs args,  TVMRetValue* ret) {
    if (args[0].type_code() == kInt) {
      *ret = make_const(args[1], args[0].operator int64_t());
    } else if (args[0].type_code() == kFloat) {
      *ret = make_const(args[1], args[0].operator double());
    } else {
      LOG(FATAL) << "only accept int or float";
    }
  });

39
TVM_REGISTER_API("_str")
40 41 42 43 44
.set_body([](TVMArgs args,  TVMRetValue* ret) {
  *ret = ir::StringImm::make(args[0]);
});


45
TVM_REGISTER_API("_Array")
46 47 48 49 50 51 52 53 54 55
.set_body([](TVMArgs args,  TVMRetValue* ret) {
    std::vector<std::shared_ptr<Node> > data;
    for (int i = 0; i < args.size(); ++i) {
      data.push_back(args[i].node_sptr());
    }
    auto node = std::make_shared<ArrayNode>();
    node->data = std::move(data);
    *ret = node;
  });

56
TVM_REGISTER_API("_ArrayGetItem")
57 58 59 60 61 62 63
.set_body([](TVMArgs args,  TVMRetValue* ret) {
    int64_t i = args[1];
    auto& sptr = args[0].node_sptr();
    CHECK(sptr->is_type<ArrayNode>());
    auto* n = static_cast<const ArrayNode*>(sptr.get());
    CHECK_LT(static_cast<size_t>(i), n->data.size())
        << "out of bound of array";
64
    *ret = n->data[static_cast<size_t>(i)];
65 66
  });

67
TVM_REGISTER_API("_ArraySize")
68 69 70 71 72 73 74
.set_body([](TVMArgs args,  TVMRetValue* ret) {
    auto& sptr = args[0].node_sptr();
    CHECK(sptr->is_type<ArrayNode>());
    *ret = static_cast<int64_t>(
        static_cast<const ArrayNode*>(sptr.get())->data.size());
  });

75
TVM_REGISTER_API("_Map")
76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91
.set_body([](TVMArgs args,  TVMRetValue* ret) {
    CHECK_EQ(args.size() % 2, 0);
    MapNode::ContainerType data;
    for (int i = 0; i < args.num_args; i += 2) {
      CHECK(args[i].type_code() == kNodeHandle)
          << "need content of array to be NodeBase";
      CHECK(args[i + 1].type_code() == kNodeHandle)
          << "need content of array to be NodeBase";
      data.emplace(std::make_pair(args[i].node_sptr(),
                                  args[i + 1].node_sptr()));
    }
    auto node = std::make_shared<MapNode>();
    node->data = std::move(data);
    *ret = node;
  });

92
TVM_REGISTER_API("_MapSize")
93 94 95 96 97 98 99
.set_body([](TVMArgs args,  TVMRetValue* ret) {
    auto& sptr = args[0].node_sptr();
    CHECK(sptr->is_type<MapNode>());
    auto* n = static_cast<const MapNode*>(sptr.get());
    *ret = static_cast<int64_t>(n->data.size());
  });

100
TVM_REGISTER_API("_MapGetItem")
101 102 103 104 105 106 107 108 109 110 111 112
.set_body([](TVMArgs args,  TVMRetValue* ret) {
    CHECK(args[0].type_code() == kNodeHandle);
    CHECK(args[1].type_code() == kNodeHandle);
    auto& sptr = args[0].node_sptr();
    CHECK(sptr->is_type<MapNode>());
    auto* n = static_cast<const MapNode*>(sptr.get());
    auto it = n->data.find(args[1].node_sptr());
    CHECK(it != n->data.end())
        << "cannot find the corresponding key in the Map";
    *ret = (*it).second;
  });

113
TVM_REGISTER_API("_MapCount")
114 115 116 117 118 119 120 121 122 123
.set_body([](TVMArgs args,  TVMRetValue* ret) {
    CHECK(args[0].type_code() == kNodeHandle);
    CHECK(args[1].type_code() == kNodeHandle);
    auto& sptr = args[0].node_sptr();
    CHECK(sptr->is_type<MapNode>());
    auto* n = static_cast<const MapNode*>(sptr.get());
    *ret = static_cast<int64_t>(
        n->data.count(args[1].node_sptr()));
  });

124
TVM_REGISTER_API("_MapItems")
125 126 127 128 129 130 131 132 133 134 135 136
.set_body([](TVMArgs args,  TVMRetValue* ret) {
    auto& sptr = args[0].node_sptr();
    CHECK(sptr->is_type<MapNode>());
    auto* n = static_cast<const MapNode*>(sptr.get());
    auto rkvs = std::make_shared<ArrayNode>();
    for (const auto& kv : n->data) {
      rkvs->data.push_back(kv.first);
      rkvs->data.push_back(kv.second);
    }
    *ret = rkvs;
  });

137
TVM_REGISTER_API("Range")
138 139 140 141 142 143 144 145
.set_body([](TVMArgs args,  TVMRetValue* ret) {
    if (args.size() == 1) {
      *ret = Range(0, args[0]);
    } else {
      *ret = Range(args[0], args[1]);
    }
  });

146
TVM_REGISTER_API("_Buffer")
147 148 149 150 151
.set_body([](TVMArgs args,  TVMRetValue* ret) {
    *ret = BufferNode::make(args[0],
                            args[1],
                            args[2],
                            args[3],
152 153
                            args[4],
                            args[5],
154
                            args[6],
155 156
                            args[7],
                            args[8]);
157 158
  });

159
TVM_REGISTER_API("_Tensor")
160 161 162 163 164 165 166
.set_body([](TVMArgs args,  TVMRetValue* ret) {
    *ret = TensorNode::make(args[0],
                            args[1],
                            args[2],
                            args[3]);
  });

167 168 169 170 171 172 173 174 175 176 177
TVM_REGISTER_API("_TensorIntrin")
.set_body([](TVMArgs args,  TVMRetValue* ret) {
    *ret = TensorIntrinNode::make(args[0],
                                  args[1],
                                  args[2],
                                  args[3],
                                  args[4],
                                  args[5],
                                  args[6]);
  });

178
TVM_REGISTER_API("_TensorEqual")
179 180 181 182
.set_body([](TVMArgs args,  TVMRetValue* ret) {
    *ret = args[0].operator Tensor() == args[1].operator Tensor();
  });

183
TVM_REGISTER_API("_TensorHash")
184 185 186 187 188
.set_body([](TVMArgs args,  TVMRetValue* ret) {
    *ret = static_cast<int64_t>(
        std::hash<Tensor>()(args[0].operator Tensor()));
  });

189
TVM_REGISTER_API("_Placeholder")
190
.set_body([](TVMArgs args,  TVMRetValue* ret) {
191
    *ret = placeholder(args[0],
192 193 194 195
                       args[1],
                       args[2]);
  });

196
TVM_REGISTER_API("_ComputeOp")
197 198 199
.set_body([](TVMArgs args,  TVMRetValue* ret) {
    *ret = ComputeOpNode::make(args[0],
                               args[1],
200 201
                               args[2],
                               args[3]);
202 203
  });

204
TVM_REGISTER_API("_ScanOp")
205 206 207 208 209
.set_body([](TVMArgs args,  TVMRetValue* ret) {
    *ret = ScanOpNode::make(args[0],
                            args[1],
                            args[2],
                            args[3],
210
                            args[4],
211 212
                            args[5],
                            args[6]);
213 214
  });

215
TVM_REGISTER_API("_ExternOp")
216 217 218 219 220
.set_body([](TVMArgs args,  TVMRetValue* ret) {
    *ret = ExternOpNode::make(args[0],
                              args[1],
                              args[2],
                              args[3],
221 222
                              args[4],
                              args[5]);
223 224
  });

225
TVM_REGISTER_API("_OpGetOutput")
226 227
.set_body([](TVMArgs args,  TVMRetValue* ret) {
    *ret = args[0].operator Operation().output(
228
        static_cast<size_t>(args[1].operator int64_t()));
229 230
  });

231
TVM_REGISTER_API("_OpNumOutputs")
232 233 234 235
.set_body([](TVMArgs args,  TVMRetValue* ret) {
    *ret = args[0].operator Operation()->num_outputs();
  });

236 237 238 239 240
TVM_REGISTER_API("_OpInputTensors")
.set_body([](TVMArgs args,  TVMRetValue* ret) {
    *ret = args[0].operator Operation()->InputTensors();
  });

241
TVM_REGISTER_API("_IterVar")
242
.set_body([](TVMArgs args,  TVMRetValue* ret) {
243 244 245 246
    *ret = IterVarNode::make(
        args[0], args[1],
        static_cast<IterVarType>(args[2].operator int()),
        args[3]);
247 248
  });

249
TVM_REGISTER_API("_CreateSchedule")
250
.set_body([](TVMArgs args, TVMRetValue* ret) {
251
    *ret = create_schedule(args[0].operator Array<Operation>());
252 253
  });

254
TVM_REGISTER_API("_StageSetScope")
255 256 257 258 259
.set_body([](TVMArgs args, TVMRetValue* ret) {
    args[0].operator Stage()
        .set_scope(args[1]);
  });

260
TVM_REGISTER_API("_StageBind")
261 262
.set_body([](TVMArgs args, TVMRetValue* ret) {
    args[0].operator Stage()
263
        .bind(args[1], args[2]);
264 265
  });

266
TVM_REGISTER_API("_StageSplitByFactor")
267 268 269
.set_body([](TVMArgs args, TVMRetValue* ret) {
    IterVar outer, inner;
    args[0].operator Stage()
270
        .split(args[1], args[2], &outer, &inner);
271 272 273
    *ret = Array<IterVar>({outer, inner});
  });

274
TVM_REGISTER_API("_StageSplitByNParts")
275
.set_body([](TVMArgs args, TVMRetValue* ret) {
276
    IterVar outer, inner;
277
    args[0].operator Stage()
278 279
        .split_by_nparts(args[1], args[2], &outer, &inner);
    *ret = Array<IterVar>({outer, inner});
280 281
  });

282
TVM_REGISTER_API("_StageFuse")
283 284 285
.set_body([](TVMArgs args, TVMRetValue* ret) {
    IterVar fused;
    args[0].operator Stage()
Ziheng Jiang committed
286
        .fuse(args[1], args[2], &fused);
287 288 289
    *ret = fused;
  });

290
TVM_REGISTER_API("_StageComputeAt")
291 292 293 294 295
.set_body([](TVMArgs args, TVMRetValue* ret) {
    args[0].operator Stage()
        .compute_at(args[1], args[2]);
  });

296
TVM_REGISTER_API("_StageComputeInline")
297 298 299 300 301
.set_body([](TVMArgs args, TVMRetValue* ret) {
    args[0].operator Stage()
        .compute_inline();
  });

302
TVM_REGISTER_API("_StageComputeRoot")
303 304 305 306 307
.set_body([](TVMArgs args, TVMRetValue* ret) {
    args[0].operator Stage()
        .compute_root();
  });

308
TVM_REGISTER_API("_StageReorder")
309 310 311 312 313
.set_body([](TVMArgs args, TVMRetValue* ret) {
    args[0].operator Stage()
        .reorder(args[1]);
  });

314
TVM_REGISTER_API("_StageTile")
315 316 317
  .set_body([](TVMArgs args, TVMRetValue* ret) {
    IterVar x_outer, y_outer, x_inner, y_inner;
    args[0].operator Stage()
318 319 320 321
        .tile(args[1], args[2],
              args[3], args[4],
              &x_outer, &y_outer,
              &x_inner, &y_inner);
322 323 324
    *ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
  });

325
TVM_REGISTER_API("_StageEnvThreads")
326 327
  .set_body([](TVMArgs args, TVMRetValue* ret) {
    args[0].operator Stage()
328
        .env_threads(args[1]);
329 330
  });

331 332 333 334 335 336
TVM_REGISTER_API("_StageSetStorePredicate")
  .set_body([](TVMArgs args, TVMRetValue* ret) {
    args[0].operator Stage()
        .set_store_predicate(args[1]);
  });

337
TVM_REGISTER_API("_StageUnroll")
338 339 340 341 342
  .set_body([](TVMArgs args, TVMRetValue* ret) {
    args[0].operator Stage()
        .unroll(args[1]);
  });

343
TVM_REGISTER_API("_StageVectorize")
344 345 346 347 348
  .set_body([](TVMArgs args, TVMRetValue* ret) {
    args[0].operator Stage()
        .vectorize(args[1]);
  });

349 350 351 352 353 354
TVM_REGISTER_API("_StageTensorize")
  .set_body([](TVMArgs args, TVMRetValue* ret) {
    args[0].operator Stage()
        .tensorize(args[1], args[2]);
  });

355
TVM_REGISTER_API("_StageParallel")
356 357 358 359 360
  .set_body([](TVMArgs args, TVMRetValue* ret) {
    args[0].operator Stage()
        .parallel(args[1]);
  });

361
TVM_REGISTER_API("_ScheduleNormalize")
362
.set_body([](TVMArgs args, TVMRetValue* ret) {
363
    *ret = args[0].operator Schedule()
364 365 366
        .normalize();
  });

367
TVM_REGISTER_API("_ScheduleCreateGroup")
368 369 370 371 372
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = args[0].operator Schedule()
        .create_group(args[1], args[2], args[3]);
  });

373
TVM_REGISTER_API("_ScheduleCacheRead")
374 375 376 377 378
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = args[0].operator Schedule()
        .cache_read(args[1], args[2], args[3]);
  });

379
TVM_REGISTER_API("_ScheduleCacheWrite")
380 381 382 383 384
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = args[0].operator Schedule()
        .cache_write(args[1], args[2]);
  });

385
TVM_REGISTER_API("_ScheduleRFactor")
386 387 388 389 390
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = args[0].operator Schedule()
        .rfactor(args[1], args[2]);
  });

ziheng committed
391 392 393 394 395 396 397
TVM_REGISTER_API("_CommReducerCombine")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    const ir::CommReducerNode* combiner =
      args[0].operator ir::CommReducer().as<ir::CommReducerNode>();
    *ret = (*combiner)(args[1], args[2]);
  });

}  // namespace tvm