/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * Implementation of API functions related to Higher DSL build.
 * \file api_lang.cc
 */
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/tensor.h>
#include <tvm/operation.h>
#include <tvm/buffer.h>
#include <tvm/schedule.h>
#include <tvm/api_registry.h>
#include <tvm/build_module.h>
#include <tvm/data_layout.h>

#include <functional>
#include <string>
#include <utility>
#include <vector>

namespace tvm {

// ---------------------------------------------------------------------------
// Data types and constants
// ---------------------------------------------------------------------------

TVM_REGISTER_API("_min_value")
.set_body_method(&DataType::min);

TVM_REGISTER_API("_max_value")
.set_body_method(&DataType::max);

// Build a constant of type args[1] from the scalar payload args[0].
// Only int and float payloads are accepted; anything else is fatal.
TVM_REGISTER_API("_const")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    if (args[0].type_code() == kDLInt) {
      *ret = make_const(args[1], args[0].operator int64_t());
    } else if (args[0].type_code() == kDLFloat) {
      *ret = make_const(args[1], args[0].operator double());
    } else {
      LOG(FATAL) << "only accept int or float";
    }
  });

TVM_REGISTER_API("_str")
.set_body_typed(ir::StringImm::make);

// ---------------------------------------------------------------------------
// Array container
// ---------------------------------------------------------------------------

// Pack every positional argument into a new ArrayNode. Arguments whose type
// code is kNull are stored as null ObjectRefs.
TVM_REGISTER_API("_Array")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    std::vector<ObjectRef> data;
    data.reserve(static_cast<size_t>(args.size()));
    for (int i = 0; i < args.size(); ++i) {
      if (args[i].type_code() != kNull) {
        data.push_back(args[i].operator ObjectRef());
      } else {
        data.push_back(ObjectRef(nullptr));
      }
    }
    auto node = make_node<ArrayNode>();
    node->data = std::move(data);
    *ret = runtime::ObjectRef(node);
  });

// Return element args[1] of the ArrayNode passed as args[0].
TVM_REGISTER_API("_ArrayGetItem")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    int64_t i = args[1];
    CHECK_EQ(args[0].type_code(), kObjectHandle);
    Object* ptr = static_cast<Object*>(args[0].value().v_handle);
    CHECK(ptr->IsInstance<ArrayNode>());
    auto* n = static_cast<const ArrayNode*>(ptr);
    // A negative index converts to a very large size_t, so it is rejected
    // by this bound check as well.
    CHECK_LT(static_cast<size_t>(i), n->data.size())
        << "out of bound of array";
    *ret = n->data[static_cast<size_t>(i)];
  });

TVM_REGISTER_API("_ArraySize")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    CHECK_EQ(args[0].type_code(), kObjectHandle);
    Object* ptr = static_cast<Object*>(args[0].value().v_handle);
    CHECK(ptr->IsInstance<ArrayNode>());
    *ret = static_cast<int64_t>(
        static_cast<const ArrayNode*>(ptr)->data.size());
  });

// ---------------------------------------------------------------------------
// Map containers (MapNode keyed by ObjectRef, StrMapNode keyed by string)
// ---------------------------------------------------------------------------

// Build a map from alternating key/value arguments. If the first key is a
// string, a StrMapNode is built and every key must be a string. Otherwise a
// MapNode keyed by ObjectRef is built. An empty argument list yields an empty
// MapNode. On a duplicate key, emplace keeps the first value seen.
TVM_REGISTER_API("_Map")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    CHECK_EQ(args.size() % 2, 0);
    if (args.size() != 0 && args[0].type_code() == kStr) {
      // StrMap
      StrMapNode::ContainerType data;
      for (int i = 0; i < args.num_args; i += 2) {
        CHECK(args[i].type_code() == kStr)
            << "key of str map need to be str";
        CHECK(args[i + 1].type_code() == kObjectHandle)
            << "value of the map to be NodeRef";
        data.emplace(args[i].operator std::string(),
                     args[i + 1].operator ObjectRef());
      }
      auto node = make_node<StrMapNode>();
      node->data = std::move(data);
      *ret = node;
    } else {
      // Container node.
      MapNode::ContainerType data;
      for (int i = 0; i < args.num_args; i += 2) {
        // Fixed: this branch previously reported the str-map message.
        CHECK(args[i].type_code() == kObjectHandle)
            << "key of map need to be NodeRef";
        CHECK(args[i + 1].type_code() == kObjectHandle)
            << "value of map to be NodeRef";
        data.emplace(args[i].operator ObjectRef(),
                     args[i + 1].operator ObjectRef());
      }
      auto node = make_node<MapNode>();
      node->data = std::move(data);
      *ret = node;
    }
  });

TVM_REGISTER_API("_MapSize")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    CHECK_EQ(args[0].type_code(), kObjectHandle);
    Object* ptr = static_cast<Object*>(args[0].value().v_handle);
    if (ptr->IsInstance<MapNode>()) {
      auto* n = static_cast<const MapNode*>(ptr);
      *ret = static_cast<int64_t>(n->data.size());
    } else {
      CHECK(ptr->IsInstance<StrMapNode>());
      auto* n = static_cast<const StrMapNode*>(ptr);
      *ret = static_cast<int64_t>(n->data.size());
    }
  });

// Look up key args[1] in the map args[0]; fails if the key is absent.
TVM_REGISTER_API("_MapGetItem")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    CHECK_EQ(args[0].type_code(), kObjectHandle);
    Object* ptr = static_cast<Object*>(args[0].value().v_handle);

    if (ptr->IsInstance<MapNode>()) {
      CHECK(args[1].type_code() == kObjectHandle);
      auto* n = static_cast<const MapNode*>(ptr);
      auto it = n->data.find(args[1].operator ObjectRef());
      CHECK(it != n->data.end())
          << "cannot find the corresponding key in the Map";
      *ret = (*it).second;
    } else {
      CHECK(ptr->IsInstance<StrMapNode>());
      auto* n = static_cast<const StrMapNode*>(ptr);
      auto it = n->data.find(args[1].operator std::string());
      CHECK(it != n->data.end())
          << "cannot find the corresponding key in the Map";
      *ret = (*it).second;
    }
  });

// Return how many entries of map args[0] have key args[1] (0 or 1).
TVM_REGISTER_API("_MapCount")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    CHECK_EQ(args[0].type_code(), kObjectHandle);
    Object* ptr = static_cast<Object*>(args[0].value().v_handle);

    if (ptr->IsInstance<MapNode>()) {
      auto* n = static_cast<const MapNode*>(ptr);
      // Validate the key. The previous code re-checked args[0], which had
      // already been checked above. A null key is still accepted and is
      // looked up as a null ObjectRef.
      CHECK(args[1].type_code() == kObjectHandle ||
            args[1].type_code() == kNull)
          << "key of map need to be NodeRef";
      *ret = static_cast<int64_t>(
          n->data.count(args[1].operator ObjectRef()));
    } else {
      CHECK(ptr->IsInstance<StrMapNode>());
      auto* n = static_cast<const StrMapNode*>(ptr);
      *ret = static_cast<int64_t>(
          n->data.count(args[1].operator std::string()));
    }
  });

// Flatten map args[0] into an Array [k0, v0, k1, v1, ...]. StrMap keys are
// wrapped as StringImm expressions.
TVM_REGISTER_API("_MapItems")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    CHECK_EQ(args[0].type_code(), kObjectHandle);
    Object* ptr = static_cast<Object*>(args[0].value().v_handle);

    if (ptr->IsInstance<MapNode>()) {
      auto* n = static_cast<const MapNode*>(ptr);
      auto rkvs = make_node<ArrayNode>();
      rkvs->data.reserve(n->data.size() * 2);
      for (const auto& kv : n->data) {
        rkvs->data.push_back(kv.first);
        rkvs->data.push_back(kv.second);
      }
      *ret = rkvs;
    } else {
      // Guard the downcast, as every other map accessor does; previously a
      // non-map object reached static_cast unchecked (undefined behavior).
      CHECK(ptr->IsInstance<StrMapNode>());
      auto* n = static_cast<const StrMapNode*>(ptr);
      auto rkvs = make_node<ArrayNode>();
      rkvs->data.reserve(n->data.size() * 2);
      for (const auto& kv : n->data) {
        rkvs->data.push_back(ir::StringImm::make(kv.first));
        rkvs->data.push_back(kv.second);
      }
      *ret = rkvs;
    }
  });

// ---------------------------------------------------------------------------
// Range and Buffer
// ---------------------------------------------------------------------------

// Range(extent) starts at 0; Range(begin, end) forwards both bounds.
TVM_REGISTER_API("Range")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    if (args.size() == 1) {
      *ret = Range(0, args[0]);
    } else {
      *ret = Range(args[0], args[1]);
    }
  });

// Expects exactly 10 arguments; the last is the buffer type string, where
// "auto_broadcast" selects kAutoBroadcast and anything else kDefault.
TVM_REGISTER_API("_Buffer")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    CHECK_EQ(args.size(), 10);
    auto buffer_type = args[9].operator std::string();
    BufferType type = (buffer_type == "auto_broadcast") ? kAutoBroadcast
                                                        : kDefault;
    *ret = BufferNode::make(args[0], args[1], args[2], args[3], args[4],
                            args[5], args[6], args[7], args[8], type);
  });

TVM_REGISTER_API("_BufferAccessPtr")
.set_body_method(&Buffer::access_ptr);

TVM_REGISTER_API("_BufferVLoad")
.set_body_method(&Buffer::vload);

TVM_REGISTER_API("_BufferVStore")
.set_body_method(&Buffer::vstore);

// ---------------------------------------------------------------------------
// Layout
// ---------------------------------------------------------------------------

TVM_REGISTER_API("_Layout")
.set_body_typed(LayoutNode::make);

TVM_REGISTER_API("_LayoutIndexOf")
.set_body_typed<int(Layout, std::string)>([](Layout layout, std::string axis) {
  return layout.IndexOf(LayoutAxis::make(axis));
});

TVM_REGISTER_API("_LayoutFactorOf")
.set_body_typed<int(Layout, std::string)>([](Layout layout, std::string axis) {
  return layout.FactorOf(LayoutAxis::make(axis));
});

TVM_REGISTER_API("_LayoutNdim")
.set_body_typed<int(Layout)>([](Layout layout) {
  return layout.ndim();
});

TVM_REGISTER_API("_LayoutGetItem")
.set_body_typed<std::string(Layout, int)>([](Layout layout, int idx) {
  const LayoutAxis& axis = layout[idx];
  return axis.name();
});

TVM_REGISTER_API("_BijectiveLayout")
.set_body_typed(BijectiveLayoutNode::make);

TVM_REGISTER_API("_BijectiveLayoutForwardIndex")
.set_body_method(&BijectiveLayout::ForwardIndex);

TVM_REGISTER_API("_BijectiveLayoutBackwardIndex")
.set_body_method(&BijectiveLayout::BackwardIndex);

TVM_REGISTER_API("_BijectiveLayoutForwardShape")
.set_body_method(&BijectiveLayout::ForwardShape);

TVM_REGISTER_API("_BijectiveLayoutBackwardShape")
.set_body_method(&BijectiveLayout::BackwardShape);

// ---------------------------------------------------------------------------
// Tensors and operations
// ---------------------------------------------------------------------------

TVM_REGISTER_API("_Tensor")
.set_body_typed(TensorNode::make);

TVM_REGISTER_API("_TensorIntrin")
.set_body_typed(TensorIntrinNode::make);

TVM_REGISTER_API("_TensorIntrinCall")
.set_body_typed(TensorIntrinCallNode::make);

TVM_REGISTER_API("_TensorEqual")
.set_body_method(&Tensor::operator==);

TVM_REGISTER_API("_TensorHash")
.set_body_typed<int64_t(Tensor)>([](Tensor tensor) {
  return static_cast<int64_t>(std::hash<Tensor>()(tensor));
});

TVM_REGISTER_API("_Placeholder")
.set_body_typed<Tensor(Array<Expr>, Type, std::string)>([](
  Array<Expr> shape, Type dtype, std::string name
) {
  return placeholder(shape, dtype, name);
});

TVM_REGISTER_API("_ComputeOp")
.set_body_typed(ComputeOpNode::make);

TVM_REGISTER_API("_ScanOp")
.set_body_typed(ScanOpNode::make);

TVM_REGISTER_API("_TensorComputeOp")
.set_body_typed(TensorComputeOpNode::make);

TVM_REGISTER_API("_ExternOp")
.set_body_typed(ExternOpNode::make);

TVM_REGISTER_API("_HybridOp")
.set_body_typed(HybridOpNode::make);

TVM_REGISTER_API("_OpGetOutput")
.set_body_typed<Tensor(Operation, int64_t)>([](Operation op, int64_t output) {
  return op.output(static_cast<size_t>(output));
});

TVM_REGISTER_API("_OpNumOutputs")
.set_body_method<Operation>(&OperationNode::num_outputs);

TVM_REGISTER_API("_OpInputTensors")
.set_body_method<Operation>(&OperationNode::InputTensors);

// The iteration type arrives as a plain int and is cast to IterVarType.
TVM_REGISTER_API("_IterVar")
.set_body_typed<IterVar(Range, Var, int, std::string)>([](
  Range dom, Var var, int iter_type, std::string thread_tag
) {
  return IterVarNode::make(
      dom, var,
      static_cast<IterVarType>(iter_type),
      thread_tag);
});

// ---------------------------------------------------------------------------
// Schedule construction and stage transformations (part 1)
// ---------------------------------------------------------------------------

TVM_REGISTER_API("_CreateSchedule")
.set_body_typed(create_schedule);

TVM_REGISTER_API("_StageSetScope")
.set_body_method(&Stage::set_scope);

TVM_REGISTER_API("_StageBind")
.set_body_method(&Stage::bind);

// Split `parent` by an inner factor; returns [outer, inner].
TVM_REGISTER_API("_StageSplitByFactor")
.set_body_typed<Array<IterVar>(Stage, IterVar, Expr)>([](
  Stage stage, IterVar parent, Expr factor
) {
  IterVar outer, inner;
  stage.split(parent, factor, &outer, &inner);
  return Array<IterVar>({outer, inner});
});

// Split `parent` into `nparts` outer parts; returns [outer, inner].
TVM_REGISTER_API("_StageSplitByNParts")
.set_body_typed<Array<IterVar>(Stage, IterVar, Expr)>([](
  Stage stage, IterVar parent, Expr nparts
) {
  IterVar outer, inner;
  stage.split_by_nparts(parent, nparts, &outer, &inner);
  return Array<IterVar>({outer, inner});
});

// Fuse `axes` into a single IterVar and return it.
TVM_REGISTER_API("_StageFuse")
.set_body_typed<IterVar(Stage, Array<IterVar>)>([](Stage stage, Array<IterVar> axes) {
  IterVar fused;
  stage.fuse(axes, &fused);
  return fused;
});
TVM_REGISTER_API("_StageComputeAt") .set_body_method(&Stage::compute_at); TVM_REGISTER_API("_StageComputeInline") .set_body_method(&Stage::compute_inline); TVM_REGISTER_API("_StageComputeRoot") .set_body_method(&Stage::compute_root); TVM_REGISTER_API("_StageReorder") .set_body_method(&Stage::reorder); TVM_REGISTER_API("_StageTile") .set_body_typed<Array<IterVar>(Stage, IterVar, IterVar, Expr, Expr)>([]( Stage stage, IterVar x_parent, IterVar y_parent, Expr x_factor, Expr y_factor ) { IterVar x_outer, y_outer, x_inner, y_inner; stage.tile(x_parent, y_parent, x_factor, y_factor, &x_outer, &y_outer, &x_inner, &y_inner); return Array<IterVar>({x_outer, y_outer, x_inner, y_inner}); }); TVM_REGISTER_API("_StageEnvThreads") .set_body_method(&Stage::env_threads); TVM_REGISTER_API("_StageSetStorePredicate") .set_body_method(&Stage::set_store_predicate); TVM_REGISTER_API("_StageUnroll") .set_body_method(&Stage::unroll); TVM_REGISTER_API("_StageVectorize") .set_body_method(&Stage::vectorize); TVM_REGISTER_API("_StageTensorize") .set_body_method(&Stage::tensorize); TVM_REGISTER_API("_StageParallel") .set_body_method(&Stage::parallel); TVM_REGISTER_API("_StagePragma") .set_body_method(&Stage::pragma); TVM_REGISTER_API("_StagePrefetch") .set_body_method(&Stage::prefetch); TVM_REGISTER_API("_StageStorageAlign") .set_body_method(&Stage::storage_align); TVM_REGISTER_API("_StageDoubleBuffer") .set_body_method(&Stage::double_buffer); TVM_REGISTER_API("_StageOpenGL") .set_body_method(&Stage::opengl); TVM_REGISTER_API("_ScheduleNormalize") .set_body_method(&Schedule::normalize); TVM_REGISTER_API("_ScheduleCreateGroup") .set_body_method(&Schedule::create_group); TVM_REGISTER_API("_ScheduleCacheRead") .set_body_method(&Schedule::cache_read); TVM_REGISTER_API("_ScheduleCacheWrite") .set_body([](TVMArgs args, TVMRetValue* ret) { if (args[1].IsObjectRef<Tensor>()) { *ret = args[0].operator Schedule() .cache_write(args[1].operator Tensor(), args[2]); } else { *ret = args[0].operator 
Schedule() .cache_write(args[1].operator Array<Tensor>(), args[2]); } }); TVM_REGISTER_API("_ScheduleRFactor") .set_body_method(&Schedule::rfactor); TVM_REGISTER_API("_CommReducerCombine") .set_body_method<ir::CommReducer>(&ir::CommReducerNode::operator()); } // namespace tvm