/*!
 *  Copyright (c) 2016 by Contributors
 *  Implementation of API functions related to Higher DSL build.
 * \file api_lang.cc
 */
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/tensor.h>
#include <tvm/operation.h>
#include <tvm/buffer.h>
#include <tvm/schedule.h>
#include <tvm/api_registry.h>
#include <tvm/build_module.h>

namespace tvm {

TVM_REGISTER_API("_min_value")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    Type t = args[0].operator Type();
    *ret = t.min();
  });

TVM_REGISTER_API("_max_value")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    Type t = args[0].operator Type();
    *ret = t.max();
  });

TVM_REGISTER_API("_const")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    if (args[0].type_code() == kDLInt) {
      *ret = make_const(args[1], args[0].operator int64_t());
    } else if (args[0].type_code() == kDLFloat) {
      *ret = make_const(args[1], args[0].operator double());
    } else {
      LOG(FATAL) << "only accepts int or float";
    }
  });

TVM_REGISTER_API("_str")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = ir::StringImm::make(args[0]);
  });

TVM_REGISTER_API("_Array")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    std::vector<NodePtr<Node> > data;
    for (int i = 0; i < args.size(); ++i) {
      if (args[i].type_code() != kNull) {
        data.push_back(args[i].node_sptr());
      } else {
        data.push_back(NodePtr<Node>(nullptr));
      }
    }
    auto node = make_node<ArrayNode>();
    node->data = std::move(data);
    *ret = node;
  });

TVM_REGISTER_API("_ArrayGetItem")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    int64_t i = args[1];
    auto& sptr = args[0].node_sptr();
    CHECK(sptr->is_type<ArrayNode>());
    auto* n = static_cast<const ArrayNode*>(sptr.get());
    CHECK_LT(static_cast<size_t>(i), n->data.size())
        << "out of bounds of array";
    *ret = n->data[static_cast<size_t>(i)];
  });

TVM_REGISTER_API("_ArraySize")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    auto& sptr = args[0].node_sptr();
    CHECK(sptr->is_type<ArrayNode>());
    *ret = static_cast<int64_t>(
        static_cast<const ArrayNode*>(sptr.get())->data.size());
  });

TVM_REGISTER_API("_Map")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    CHECK_EQ(args.size() % 2, 0);
    if (args.size() != 0 && args[0].type_code() == kStr) {
      // StrMap
      StrMapNode::ContainerType data;
      for (int i = 0; i < args.num_args; i += 2) {
        CHECK(args[i].type_code() == kStr)
            << "key of str map needs to be str";
        CHECK(args[i + 1].type_code() == kNodeHandle)
            << "value of the map needs to be NodeRef";
        data.emplace(std::make_pair(args[i].operator std::string(),
                                    args[i + 1].node_sptr()));
      }
      auto node = make_node<StrMapNode>();
      node->data = std::move(data);
      *ret = node;
    } else {
      // Container node.
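      // Editor note: non-string keys fall through to this branch; both keys
      // and values are stored as generic node references in a MapNode.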
      MapNode::ContainerType data;
      for (int i = 0; i < args.num_args; i += 2) {
        CHECK(args[i].type_code() == kNodeHandle)
            << "key of the map needs to be NodeRef";
        CHECK(args[i + 1].type_code() == kNodeHandle)
            << "value of the map needs to be NodeRef";
        data.emplace(std::make_pair(args[i].node_sptr(),
                                    args[i + 1].node_sptr()));
      }
      auto node = make_node<MapNode>();
      node->data = std::move(data);
      *ret = node;
    }
  });

TVM_REGISTER_API("_MapSize")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    auto& sptr = args[0].node_sptr();
    if (sptr->is_type<MapNode>()) {
      auto* n = static_cast<const MapNode*>(sptr.get());
      *ret = static_cast<int64_t>(n->data.size());
    } else {
      CHECK(sptr->is_type<StrMapNode>());
      auto* n = static_cast<const StrMapNode*>(sptr.get());
      *ret = static_cast<int64_t>(n->data.size());
    }
  });

TVM_REGISTER_API("_MapGetItem")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    CHECK(args[0].type_code() == kNodeHandle);
    auto& sptr = args[0].node_sptr();
    if (sptr->is_type<MapNode>()) {
      CHECK(args[1].type_code() == kNodeHandle);
      auto* n = static_cast<const MapNode*>(sptr.get());
      auto it = n->data.find(args[1].node_sptr());
      CHECK(it != n->data.end())
          << "cannot find the corresponding key in the Map";
      *ret = (*it).second;
    } else {
      CHECK(sptr->is_type<StrMapNode>());
      auto* n = static_cast<const StrMapNode*>(sptr.get());
      auto it = n->data.find(args[1].operator std::string());
      CHECK(it != n->data.end())
          << "cannot find the corresponding key in the Map";
      *ret = (*it).second;
    }
  });

TVM_REGISTER_API("_MapCount")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    CHECK(args[0].type_code() == kNodeHandle);
    auto& sptr = args[0].node_sptr();
    if (sptr->is_type<MapNode>()) {
      auto* n = static_cast<const MapNode*>(sptr.get());
      CHECK(args[1].type_code() == kNodeHandle);
      *ret = static_cast<int64_t>(
          n->data.count(args[1].node_sptr()));
    } else {
      CHECK(sptr->is_type<StrMapNode>());
      auto* n = static_cast<const StrMapNode*>(sptr.get());
      *ret = static_cast<int64_t>(
          n->data.count(args[1].operator std::string()));
    }
  });

TVM_REGISTER_API("_MapItems")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    auto& sptr = args[0].node_sptr();
    if (sptr->is_type<MapNode>()) {
      auto* n = static_cast<const MapNode*>(sptr.get());
      auto rkvs = make_node<ArrayNode>();
      for (const auto& kv : n->data) {
        rkvs->data.push_back(kv.first);
        rkvs->data.push_back(kv.second);
      }
      *ret = rkvs;
    } else {
      auto* n = static_cast<const StrMapNode*>(sptr.get());
      auto rkvs = make_node<ArrayNode>();
      for (const auto& kv : n->data) {
        rkvs->data.push_back(ir::StringImm::make(kv.first).node_);
        rkvs->data.push_back(kv.second);
      }
      *ret = rkvs;
    }
  });

TVM_REGISTER_API("Range")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    if (args.size() == 1) {
      *ret = Range(0, args[0]);
    } else {
      *ret = Range(args[0], args[1]);
    }
  });

TVM_REGISTER_API("_Buffer")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = BufferNode::make(args[0], args[1], args[2], args[3], args[4],
                            args[5], args[6], args[7], args[8]);
  });

TVM_REGISTER_API("_BufferAccessPtr")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = args[0].operator Buffer()
        .access_ptr(args[1], args[2], args[3], args[4]);
  });

TVM_REGISTER_API("_BufferVLoad")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = args[0].operator Buffer()
        .vload(args[1], args[2]);
  });

TVM_REGISTER_API("_BufferVStore")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = args[0].operator Buffer()
        .vstore(args[1], args[2]);
  });

TVM_REGISTER_API("_Tensor")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = TensorNode::make(args[0], args[1], args[2], args[3]);
  });
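
/*
 * Illustrative sketch (editor note, not part of the original file): every
 * entry registered above via TVM_REGISTER_API lands in the global PackedFunc
 * registry, so the same entry points used by the Python frontend can also be
 * looked up from C++.  Assuming the registry API of this TVM version, the
 * names `f`, `arr`, and `n` below are purely illustrative:
 *
 *   const runtime::PackedFunc* f = runtime::Registry::Get("_ArraySize");
 *   CHECK(f != nullptr);
 *   Array<Expr> arr{make_const(Int(32), 1), make_const(Int(32), 2)};
 *   int64_t n = (*f)(arr);  // expected to yield 2
 */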
TVM_REGISTER_API("_TensorIntrin") .set_body([](TVMArgs args, TVMRetValue* ret) { *ret = TensorIntrinNode::make(args[0], args[1], args[2], args[3], args[4], args[5], args[6]); }); TVM_REGISTER_API("_TensorIntrinCall") .set_body([](TVMArgs args, TVMRetValue* ret) { *ret = TensorIntrinCallNode::make(args[0], args[1], args[2], args[3]); }); TVM_REGISTER_API("_TensorEqual") .set_body([](TVMArgs args, TVMRetValue* ret) { *ret = args[0].operator Tensor() == args[1].operator Tensor(); }); TVM_REGISTER_API("_TensorHash") .set_body([](TVMArgs args, TVMRetValue* ret) { *ret = static_cast<int64_t>( std::hash<Tensor>()(args[0].operator Tensor())); }); TVM_REGISTER_API("_Placeholder") .set_body([](TVMArgs args, TVMRetValue* ret) { *ret = placeholder(args[0], args[1], args[2]); }); TVM_REGISTER_API("_ComputeOp") .set_body([](TVMArgs args, TVMRetValue* ret) { *ret = ComputeOpNode::make(args[0], args[1], args[2], args[3], args[4]); }); TVM_REGISTER_API("_ScanOp") .set_body([](TVMArgs args, TVMRetValue* ret) { *ret = ScanOpNode::make(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7]); }); TVM_REGISTER_API("_TensorComputeOp") .set_body([](TVMArgs args, TVMRetValue* ret) { *ret = TensorComputeOpNode::make(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7]); }); TVM_REGISTER_API("_ExternOp") .set_body([](TVMArgs args, TVMRetValue* ret) { *ret = ExternOpNode::make(args[0], args[1], args[2], args[3], args[4], args[5], args[6]); }); TVM_REGISTER_API("_HybridOp") .set_body([](TVMArgs args, TVMRetValue* ret) { *ret = HybridOpNode::make(args[0], args[1], args[2], args[3], args[4], args[5]); }); TVM_REGISTER_API("_OpGetOutput") .set_body([](TVMArgs args, TVMRetValue* ret) { *ret = args[0].operator Operation().output( static_cast<size_t>(args[1].operator int64_t())); }); TVM_REGISTER_API("_OpNumOutputs") .set_body([](TVMArgs args, TVMRetValue* ret) { *ret = args[0].operator Operation()->num_outputs(); }); TVM_REGISTER_API("_OpInputTensors") .set_body([](TVMArgs args, TVMRetValue* ret) { *ret = args[0].operator Operation()->InputTensors(); }); TVM_REGISTER_API("_IterVar") .set_body([](TVMArgs args, TVMRetValue* ret) { *ret = IterVarNode::make( args[0], args[1], static_cast<IterVarType>(args[2].operator int()), args[3]); }); TVM_REGISTER_API("_CreateSchedule") .set_body([](TVMArgs args, TVMRetValue* ret) { *ret = create_schedule(args[0].operator Array<Operation>()); }); TVM_REGISTER_API("_StageSetScope") .set_body([](TVMArgs args, TVMRetValue* ret) { args[0].operator Stage() .set_scope(args[1]); }); TVM_REGISTER_API("_StageBind") .set_body([](TVMArgs args, TVMRetValue* ret) { args[0].operator Stage() .bind(args[1], args[2]); }); TVM_REGISTER_API("_StageSplitByFactor") .set_body([](TVMArgs args, TVMRetValue* ret) { IterVar outer, inner; args[0].operator Stage() .split(args[1], args[2], &outer, &inner); *ret = Array<IterVar>({outer, inner}); }); TVM_REGISTER_API("_StageSplitByNParts") .set_body([](TVMArgs args, TVMRetValue* ret) { IterVar outer, inner; args[0].operator Stage() .split_by_nparts(args[1], args[2], &outer, &inner); *ret = Array<IterVar>({outer, inner}); }); TVM_REGISTER_API("_StageFuse") .set_body([](TVMArgs args, TVMRetValue* ret) { IterVar fused; args[0].operator Stage() .fuse(args[1], &fused); *ret = fused; }); TVM_REGISTER_API("_StageComputeAt") .set_body([](TVMArgs args, TVMRetValue* ret) { args[0].operator Stage() .compute_at(args[1], args[2]); }); TVM_REGISTER_API("_StageComputeInline") .set_body([](TVMArgs args, TVMRetValue* ret) { args[0].operator Stage() 
TVM_REGISTER_API("_StageComputeRoot")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    args[0].operator Stage()
        .compute_root();
  });

TVM_REGISTER_API("_StageReorder")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    args[0].operator Stage()
        .reorder(args[1]);
  });

TVM_REGISTER_API("_StageTile")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    IterVar x_outer, y_outer, x_inner, y_inner;
    args[0].operator Stage()
        .tile(args[1], args[2], args[3], args[4],
              &x_outer, &y_outer, &x_inner, &y_inner);
    *ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
  });

TVM_REGISTER_API("_StageEnvThreads")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    args[0].operator Stage()
        .env_threads(args[1]);
  });

TVM_REGISTER_API("_StageSetStorePredicate")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    args[0].operator Stage()
        .set_store_predicate(args[1]);
  });

TVM_REGISTER_API("_StageUnroll")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    args[0].operator Stage()
        .unroll(args[1]);
  });

TVM_REGISTER_API("_StageVectorize")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    args[0].operator Stage()
        .vectorize(args[1]);
  });

TVM_REGISTER_API("_StageTensorize")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    args[0].operator Stage()
        .tensorize(args[1], args[2]);
  });

TVM_REGISTER_API("_StageParallel")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    args[0].operator Stage()
        .parallel(args[1]);
  });

TVM_REGISTER_API("_StagePragma")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    args[0].operator Stage()
        .pragma(args[1], args[2], args[3]);
  });

TVM_REGISTER_API("_StagePrefetch")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    args[0].operator Stage()
        .prefetch(args[1], args[2], args[3]);
  });

TVM_REGISTER_API("_StageStorageAlign")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    args[0].operator Stage()
        .storage_align(args[1], args[2], args[3]);
  });

TVM_REGISTER_API("_StageDoubleBuffer")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    args[0].operator Stage().double_buffer();
  });

TVM_REGISTER_API("_StageOpenGL")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    args[0].operator Stage().opengl();
  });

TVM_REGISTER_API("_ScheduleNormalize")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = args[0].operator Schedule()
        .normalize();
  });

TVM_REGISTER_API("_ScheduleCreateGroup")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = args[0].operator Schedule()
        .create_group(args[1], args[2], args[3]);
  });

TVM_REGISTER_API("_ScheduleCacheRead")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = args[0].operator Schedule()
        .cache_read(args[1], args[2], args[3]);
  });

TVM_REGISTER_API("_ScheduleCacheWrite")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    if (args[1].IsNodeType<Tensor>()) {
      *ret = args[0].operator Schedule()
          .cache_write(args[1].operator Tensor(), args[2]);
    } else {
      *ret = args[0].operator Schedule()
          .cache_write(args[1].operator Array<Tensor>(), args[2]);
    }
  });

TVM_REGISTER_API("_ScheduleRFactor")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = args[0].operator Schedule()
        .rfactor(args[1], args[2], args[3]);
  });

TVM_REGISTER_API("_CommReducerCombine")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    const ir::CommReducerNode* combiner =
        args[0].operator ir::CommReducer().as<ir::CommReducerNode>();
    *ret = (*combiner)(args[1], args[2]);
  });

}  // namespace tvm
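
/*
 * Illustrative sketch (editor note, not part of the original file): the names
 * registered in this file can be enumerated at runtime, which is useful when
 * checking that the frontend and the C++ side agree on the exposed API
 * surface.  Assuming the registry API of this TVM version:
 *
 *   for (const std::string& name : tvm::runtime::Registry::ListNames()) {
 *     if (name.rfind("_Stage", 0) == 0) LOG(INFO) << name;
 *   }
 */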