/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
/*!
 * \file src/target/generic_func.cc
 */
#include <dmlc/thread_local.h>

#include <tvm/runtime/registry.h>
#include <tvm/node/node.h>
#include <tvm/node/repr_printer.h>
#include <tvm/target/target.h>
#include <tvm/target/generic_func.h>
#include <tvm/tir/expr.h>

#include <algorithm>
#include <mutex>
#include <stack>
#include <string>
#include <unordered_map>
#include <vector>

namespace tvm {

TVM_REGISTER_NODE_TYPE(GenericFuncNode);

/*!
 * \brief Name -> GenericFunc registry shared by the whole process.
 *
 * Every access to \p fmap must be made while holding \p mutex.
 */
struct GenericFunc::Manager {
  /*! \brief Registered generic functions, keyed by name. */
  std::unordered_map<std::string, GenericFunc> fmap;
  /*! \brief Guards fmap. */
  std::mutex mutex;

  Manager() {}

  /*! \brief Return the single Manager instance (function-local static). */
  static Manager* Global() {
    static Manager inst;
    return &inst;
  }
};

/*!
 * \brief Look up the generic function registered under \p name, creating and
 *        registering a fresh empty one if no entry exists yet.
 * \param name The registry key.
 * \return The registered (possibly newly created) GenericFunc.
 */
GenericFunc GenericFunc::Get(const std::string& name) {
  Manager* m = Manager::Global();
  // The guard has to be a named object: an unnamed temporary would be
  // destroyed at the end of its own statement, releasing the mutex before
  // the map is touched.
  std::lock_guard<std::mutex> lock(m->mutex);
  auto it = m->fmap.find(name);
  if (it == m->fmap.end()) {
    auto f = make_object<GenericFuncNode>();
    f->name_ = name;
    auto gf = GenericFunc(f);
    m->fmap[name] = gf;
    return gf;
  } else {
    return it->second;
  }
}

/*!
 * \brief Register an existing generic function under \p name.
 *
 * Fails a CHECK if \p name is already present in the registry. On success the
 * function's name_ is overwritten with \p name.
 *
 * \param func The generic function to register.
 * \param name The registry key.
 */
void GenericFunc::RegisterGenericFunc(GenericFunc func, const std::string& name) {
  Manager* m = Manager::Global();
  // Named guard: keeps the mutex held for the whole find/insert sequence.
  std::lock_guard<std::mutex> lock(m->mutex);
  auto it = m->fmap.find(name);
  CHECK(it == m->fmap.end())
      << "GenericFunc already registered " << name;
  func->name_ = name;
  m->fmap[name] = func;
}

/*!
 * \brief Set the fallback implementation used when no tag-specific one matches.
 * \param value The fallback function.
 * \param allow_override When false, a CHECK fails if a fallback is already set.
 * \return Reference to this object, to allow call chaining.
 */
GenericFunc& GenericFunc::set_default(const PackedFunc value,
                                      bool allow_override) {
  auto node = static_cast<GenericFuncNode*>(operator->());
  if (!allow_override) {
    CHECK(node->generic_func_ == nullptr)
        << "Generic function already registered for " << node->name_;
  }
  node->generic_func_ = value;
  return *this;
}

/*!
 * \brief Register \p value as the implementation for every tag in \p tags.
 * \param tags Dispatch keys to associate with \p value.
 * \param value The implementation to store for each tag.
 * \param allow_override When false, a CHECK fails on the first tag that
 *        already has an entry; tags processed before it stay registered.
 * \return Reference to this object, to allow call chaining.
 */
GenericFunc& GenericFunc::register_func(const std::vector<std::string>& tags,
                                        const PackedFunc value,
                                        bool allow_override) {
  for (const auto& t : tags) {
    if (!allow_override) {
      auto iter = (*this)->dispatch_dict_.find(t);
      CHECK(iter == (*this)->dispatch_dict_.end())
          << "Tag " << t << " already registered for schedule factory " << (*this)->name_;
    }
    (*this)->dispatch_dict_[t] = value;
  }
  return *this;
}

/*!
 * \brief Dispatch a call to the implementation selected by the current target.
 *
 * If Target::Current(true) yields a defined target, its keys are scanned in
 * order and the first key with an entry in dispatch_dict_ selects the
 * implementation. Otherwise, or when no key matches, the fallback set through
 * set_default is used; a CHECK fails if there is none.
 *
 * \param args Arguments forwarded unchanged to the selected function.
 * \param ret Receives the selected function's return value.
 */
void GenericFunc::CallPacked(TVMArgs args, TVMRetValue* ret) const {
  auto node = static_cast<const GenericFuncNode*>(get());
  auto target = Target::Current(true);
  PackedFunc func;

  if (target.defined()) {
    for (const auto& k : target->keys()) {
      auto iter = node->dispatch_dict_.find(k);
      if (iter != node->dispatch_dict_.end()) {
        func = iter->second;
        break;
      }
    }
  }

  if (func == nullptr) {
    CHECK(node->generic_func_ != nullptr)
        << "No generic function registered for " << node->name_;
    func = node->generic_func_;
  }

  func.CallPacked(args, ret);
}

// Creates a new, unnamed generic function that is not added to the registry.
TVM_REGISTER_GLOBAL("target.GenericFuncCreate")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  *ret = GenericFunc(make_object<GenericFuncNode>());
});

// args[0]: name. Returns the registry entry for that name (see GenericFunc::Get).
TVM_REGISTER_GLOBAL("target.GenericFuncGetGlobal")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  std::string func_name = args[0];
  *ret = GenericFunc::Get(func_name);
});

// args[0]: generic function, args[1]: fallback PackedFunc, args[2]: allow_override.
TVM_REGISTER_GLOBAL("target.GenericFuncSetDefault")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  GenericFunc generic_func = args[0];
  // Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown
  PackedFunc* func = new PackedFunc(args[1].operator PackedFunc());
  bool allow_override = args[2];

  generic_func
    .set_default(*func, allow_override);
});

TVM_REGISTER_GLOBAL("target.GenericFuncRegisterFunc") .set_body([](TVMArgs args, TVMRetValue* ret) { GenericFunc generic_func = args[0]; // Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown PackedFunc* func = new PackedFunc(args[1].operator PackedFunc()); Array<PrimExpr> tags = args[2]; bool allow_override = args[3]; std::vector<std::string> tags_vector; for (auto& tag : tags) { tags_vector.push_back(tag.as<tvm::tir::StringImmNode>()->value); } generic_func .register_func(tags_vector, *func, allow_override); }); TVM_REGISTER_GLOBAL("target.GenericFuncCallFunc") .set_body([](TVMArgs args, TVMRetValue* ret) { GenericFunc generic_func = args[0]; TVMArgs func_args(&args.values[1], &args.type_codes[1], args.num_args - 1); generic_func .CallPacked(func_args, ret); }); } // namespace tvm