/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file relay/backend/compile_engine.cc
 * \brief Internal compilation engine.
 */
#include <tvm/ir/type_functor.h>
#include <tvm/te/schedule.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule_pass.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/container.h>
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/driver/driver_api.h>
#include <topi/tags.h>

#include <utility>
#include <limits>
#include <mutex>
#include <functional>
#include <vector>
#include <unordered_map>

#include "compile_engine.h"

namespace tvm {
namespace relay {

TVM_REGISTER_NODE_TYPE(LoweredOutputNode);
TVM_REGISTER_NODE_TYPE(CachedFuncNode);
TVM_REGISTER_NODE_TYPE(CCacheKeyNode);
TVM_REGISTER_NODE_TYPE(CCacheValueNode);
TVM_REGISTER_OBJECT_TYPE(CompileEngineNode);

LoweredOutput::LoweredOutput(tvm::Array<te::Tensor> outputs, OpImplementation impl) {
  auto n = make_object<LoweredOutputNode>();
  n->outputs = std::move(outputs);
  n->implementation = std::move(impl);
  data_ = std::move(n);
}

CCacheKey::CCacheKey(Function source_func, Target target) {
  auto n = make_object<CCacheKeyNode>();
  n->source_func = std::move(source_func);
  n->target = std::move(target);
  data_ = std::move(n);
}

struct IsDynamicVisitor : public TypeVisitor {
  bool is_dyn{false};
  void VisitType_(const TensorTypeNode* tt) {
    for (auto dim : tt->shape) {
      if (dim.as<Any>()) {
        is_dyn = true;
        break;
      }
    }
  }
};

bool IsDynamic(const Type& ty) {
  IsDynamicVisitor v;
  v.VisitType(ty);
  return v.is_dyn;
}

// TODO(@jroesch): MOVE ME
TVM_REGISTER_GLOBAL("relay.ir.IsDynamic")
.set_body_typed(IsDynamic);

Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
  // for now, we always use int32 shape when possible
  // even if the result of shape inference becomes int64.
  Array<IndexExpr> res;
  for (IndexExpr val : shape) {
    const int64_t* pval = tir::as_const_int(val);
    if (pval != nullptr) {
      CHECK_LE(pval[0], std::numeric_limits<int32_t>::max());
      CHECK_GE(pval[0], std::numeric_limits<int32_t>::min());
      res.push_back(IntImm(DataType::Int(32), *pval));
    } else if (val->IsInstance<tir::AnyNode>()) {
      res.push_back(val.as<tir::AnyNode>()->ToVar());
    } else {
      res.push_back(val);
    }
  }
  return res;
}

// The getter to get schedule from compile engine.
// Get schedule from functor.
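// ScheduleGetter walks a fused primitive Relay function and
//  * creates te::placeholder inputs for each parameter (flattening tuples),
//  * lowers every call through relay.backend.lower_call into TE compute ops,
//  * tracks the op with the highest TOpPattern as the "master" op, and
//  * asks the master op's OpImplementation for the final schedule.
// Illustrative example (not taken from a real trace): lowering
//   fn (%x: Tensor[(10, 10), float32]) { add(%x, %x) }
// yields one placeholder, one broadcast compute op, a schedule from the
// selected add implementation, and the function name "fused_add".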
class ScheduleGetter : public ExprFunctor<Array<te::Tensor>(const Expr&)> {
 public:
  explicit ScheduleGetter(Target target)
      : target_(target), device_copy_op_(Op::Get("device_copy")) {}

  CachedFunc Create(const Function& prim_func) {
    auto cache_node = make_object<CachedFuncNode>();
    cache_node->target = target_;
    for (Var param : prim_func->params) {
      Array<tvm::te::Tensor> inputs;
      if (const auto* ttype = param->checked_type().as<TensorTypeNode>()) {
        tvm::te::Tensor tensor = tvm::te::placeholder(
            GetShape(ttype->shape), ttype->dtype);
        cache_node->inputs.push_back(tensor);
        inputs.push_back(tensor);
      } else {
        // flatten tuple of tensor type.
        const auto* tuple_type = param->type_as<TupleTypeNode>();
        for (Type field : tuple_type->fields) {
          const auto* ttype = field.as<TensorTypeNode>();
          // TODO(@icemelon): Allow recursive tuple
          CHECK(ttype != nullptr);
          tvm::te::Tensor tensor = tvm::te::placeholder(
              GetShape(ttype->shape), ttype->dtype);
          cache_node->inputs.push_back(tensor);
          inputs.push_back(tensor);
        }
      }
      memo_[param] = inputs;
    }
    readable_name_stream_ << "fused";
    cache_node->outputs = this->VisitExpr(prim_func->body);
    auto candidate_name = readable_name_stream_.str();
    constexpr static size_t kMaxFuncNameLength = 80;
    if (candidate_name.size() > kMaxFuncNameLength) {
      std::stringstream truncated_name;
      truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
      truncated_name << "_" << std::hash<std::string>{}(candidate_name) << "_";
      candidate_name = truncated_name.str();
    }
    cache_node->func_name = candidate_name;

    CHECK(master_op_.defined());
    // Fusion over tupled results may leave identity relationships
    // between inputs and outputs, and those should not be scheduled.
    // Hence schedule only non PlaceholderOp outputs.
    tvm::Array<te::Tensor> tensor_outs;
    for (const auto& tensor : cache_node->outputs) {
      if (!tensor->op.as<te::PlaceholderOpNode>()) {
        tensor_outs.push_back(tensor);
      }
    }
    te::Schedule schedule;
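    // At this point master_op_ / master_implementation_ hold the anchor op
    // chosen while visiting the body (the call with the highest TOpPattern);
    // its implementation supplies the schedule for the non-placeholder outputs.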
    // No need to register schedule for device copy op.
    if (master_attrs_.as<DeviceCopyAttrs>() == nullptr) {
      CHECK(master_implementation_.defined());
      schedule = master_implementation_.Schedule(master_attrs_, tensor_outs, target_);
      for (const auto& scalar : scalars_) {
        if (schedule->Contain(scalar)) {
          schedule[scalar].compute_inline();
        }
      }
    }
    cache_node->schedule = std::move(schedule);
    return CachedFunc(cache_node);
  }

  Array<te::Tensor> VisitExpr(const Expr& expr) {
    auto it = memo_.find(expr);
    if (it != memo_.end()) {
      return it->second;
    } else {
      Array<te::Tensor> res = ExprFunctor::VisitExpr(expr);
      memo_[expr] = res;
      return res;
    }
  }

  Array<te::Tensor> VisitExpr_(const VarNode* op) final {
    LOG(FATAL) << "Free variable " << op->name_hint();
    return {};
  }

  Array<te::Tensor> VisitExpr_(const ConstantNode* op) final {
    using tir::make_const;
    CHECK(op->is_scalar());
    void* data = op->data->data;
    DataType dtype = DataType(op->data->dtype);
    auto value = te::compute({}, [&](const Array<tvm::tir::Var>&) {
        if (dtype == DataType::Int(32)) {
          return make_const(dtype, static_cast<const int32_t*>(data)[0]);
        } else if (dtype == DataType::Int(64)) {
          return make_const(dtype, static_cast<const int64_t*>(data)[0]);
        } else if (dtype == DataType::Float(32)) {
          return make_const(dtype, static_cast<const float*>(data)[0]);
        } else if (dtype == DataType::Float(64)) {
          return make_const(dtype, static_cast<const double*>(data)[0]);
        } else if (dtype == DataType::Bool()) {
          return make_const(dtype, static_cast<const uint8_t*>(data)[0]);
        } else {
          LOG(FATAL) << "not handled";
          return tvm::PrimExpr();
        }
      }, "compile_engine_const", topi::kBroadcast);
    scalars_.push_back(value->op);
    return {value};
  }

  Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
    static auto fpattern = Op::GetAttr<TOpPattern>("TOpPattern");
    static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call");
    CHECK(flower_call) << "relay.backend.lower_call is not registered.";

    Array<te::Tensor> inputs;
    int count_tuple = 0;
    for (Expr arg : call_node->args) {
      if (arg->checked_type().as<TupleTypeNode>()) {
        ++count_tuple;
      }
      for (te::Tensor tensor : VisitExpr(arg)) {
        inputs.push_back(tensor);
      }
    }
    if (count_tuple) {
      CHECK_EQ(call_node->args.size(), 1U)
          << "Only allow function with a single tuple input";
    }

    CHECK(call_node->op.as<OpNode>())
        << "Primitive function only allows call into primitive ops";
    Op op = Downcast<Op>(call_node->op);

    Array<te::Tensor> outputs;
    OpImplementation impl;
    // Skip fcompute for device copy operators as it is not registered.
    if (op == device_copy_op_) {
      const auto* copy_input = inputs[0].operator->();
      outputs.push_back(te::TensorNode::make(copy_input->shape, copy_input->dtype,
                                             te::Operation(), 0));
    } else {
      LoweredOutput lowered_out = (*flower_call)(GetRef<Call>(call_node), inputs, target_);
      outputs = lowered_out->outputs;
      impl = lowered_out->implementation;
    }

    int op_pattern = fpattern[op];
    if (op_pattern >= kCommReduce) {
      CHECK(!master_op_.defined() || master_op_pattern_ < kCommReduce)
          << "Two complicated ops in a primitive function"
          << " master=" << master_op_ << " current=" << op;
    }
    if (op_pattern >= master_op_pattern_) {
      master_op_ = op;
      master_attrs_ = call_node->attrs;
      master_op_pattern_ = op_pattern;
      master_implementation_ = impl;
    }
    if (outputs.size() != 1) {
      const auto* tuple_type = call_node->checked_type().as<TupleTypeNode>();
      CHECK(tuple_type) << "Expect output to be a tuple type";
      CHECK_EQ(tuple_type->fields.size(), outputs.size());
    }
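    // The call with the highest TOpPattern seen so far is remembered as the
    // master op; its attrs and OpImplementation later drive the schedule that
    // Create() builds for the whole fused function.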
    // Set the name to `__copy`. It will be detected in graph runtime to perform
    // data copy across devices.
    if (op == device_copy_op_) {
      readable_name_stream_.str(std::string());
      readable_name_stream_ << "__copy";
    } else {
      readable_name_stream_ << '_' << op->name;
    }
    return outputs;
  }

  Array<te::Tensor> VisitExpr_(const FunctionNode* op) final {
    LOG(FATAL) << "Do not support sub function";
    return Array<te::Tensor>();
  }

  Array<te::Tensor> VisitExpr_(const LetNode* op) final {
    Array<te::Tensor> val = VisitExpr(op->value);
    CHECK(!memo_.count(op->var));
    memo_[op->var] = val;
    return VisitExpr(op->body);
  }

  Array<te::Tensor> VisitExpr_(const TupleNode* op) final {
    Array<te::Tensor> fields;
    for (Expr field : op->fields) {
      CHECK(field->checked_type().as<TensorTypeNode>())
          << "Only allow Tuple of Tensor";
      Array<te::Tensor> res = VisitExpr(field);
      CHECK_EQ(res.size(), 1);
      fields.push_back(res[0]);
    }
    return fields;
  }

  Array<te::Tensor> VisitExpr_(const TupleGetItemNode* op) final {
    const auto* tuple_type = op->tuple->type_as<TupleTypeNode>();
    Array<te::Tensor> tuple = VisitExpr(op->tuple);
    CHECK_EQ(tuple_type->fields.size(), tuple.size());
    CHECK_GE(op->index, 0);
    CHECK_LT(static_cast<size_t>(op->index), tuple.size());
    return {tuple[op->index]};
  }

 private:
  tvm::Target target_;
  Op master_op_;
  Attrs master_attrs_;
  int master_op_pattern_{0};
  OpImplementation master_implementation_;
  std::ostringstream readable_name_stream_;
  std::unordered_map<Expr, Array<te::Tensor>, ObjectHash, ObjectEqual> memo_;
  Array<te::Operation> scalars_;
  // Cache device copy op for equivalence checking to reduce registry lookup
  // overhead for each invocation of call node when retrieving schedules.
  const Op& device_copy_op_;
};

// Creates shape function from functor.
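// MakeShapeFunc builds, for a primitive function that needs dynamic shape
// handling, a TE shape function together with its schedule. For every
// parameter it creates both data and shape placeholders, then keeps only the
// ones the shape function actually reads, as recorded per parameter in
// param_states_ (kNeedInputData / kNeedInputShape).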
class MakeShapeFunc : public ExprFunctor<Array<te::Tensor>(const Expr&)> {
 public:
  MakeShapeFunc() {}

  std::pair<te::Schedule, CachedFunc> Create(const Function& prim_func) {
    for (auto param : prim_func->params) {
      param_states_[param] = kNoNeed;
      Array<tvm::te::Tensor> data_inputs;
      Array<tvm::te::Tensor> shape_inputs;

      auto add_placeholder = [&data_inputs, &shape_inputs](const TensorTypeNode* ttype) {
        // Add data placeholder
        Shape shape = GetShape(ttype->shape);
        tvm::te::Tensor data_tensor = tvm::te::placeholder(shape, ttype->dtype);
        data_inputs.push_back(data_tensor);
        // Add shape placeholder
        int64_t ndim = shape.size();
        Shape sshape;
        if (ndim > 0) {
          sshape.push_back(tvm::Integer(ndim));
        }
        tvm::te::Tensor shape_tensor = tvm::te::placeholder(sshape, DataType::Int(64));
        shape_inputs.push_back(shape_tensor);
      };

      if (const auto* ttype = param->checked_type().as<TensorTypeNode>()) {
        add_placeholder(ttype);
      } else {
        // flatten tuple of tensor type.
        const auto* tuple_type = param->type_as<TupleTypeNode>();
        // TODO(@icemelon): Support recursive tuple
        CHECK(tuple_type);
        for (Type field : tuple_type->fields) {
          const auto* ttype = field.as<TensorTypeNode>();
          CHECK(ttype);
          add_placeholder(ttype);
        }
      }
      param_data_[param] = data_inputs;
      param_shapes_[param] = shape_inputs;
    }
    readable_name_stream_ << "shape_func";
    auto cache_node = make_object<CachedFuncNode>();
    cache_node->outputs = VisitExpr(prim_func->body);
    auto candidate_name = readable_name_stream_.str();
    constexpr static size_t kMaxFuncNameLength = 80;
    if (candidate_name.size() > kMaxFuncNameLength) {
      std::stringstream truncated_name;
      truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
      truncated_name << "_" << std::hash<std::string>{}(candidate_name) << "_";
      candidate_name = truncated_name.str();
    }
    cache_node->func_name = candidate_name;

    // set inputs
    for (auto param : prim_func->params) {
      int state = param_states_[param];
      cache_node->shape_func_param_states.push_back(IntImm(DataType::Int(32), state));
      if (state & kNeedInputData) {
        for (auto t : param_data_[param]) {
          cache_node->inputs.push_back(t);
        }
      }
      if (state & kNeedInputShape) {
        for (auto t : param_shapes_[param]) {
          cache_node->inputs.push_back(t);
        }
      }
    }

    CachedFunc cfunc(cache_node);
    // generate schedule for shape func
    Array<te::Operation> out_ops;
    for (auto t : cache_node->outputs) {
      out_ops.push_back(t->op);
    }
    auto schedule = te::create_schedule(out_ops);
    tvm::te::AutoInlineInjective(schedule);
    for (const auto& scalar : scalars_) {
      auto scalar_op = scalar->op;
      if (schedule->Contain(scalar_op)) {
        schedule[scalar_op].compute_inline();
      }
    }
    return std::make_pair(schedule, cfunc);
  }
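  // Memoized visit: results are cached per expression, except for vars, which
  // are resolved again on every visit because the same var may be consumed as
  // data by one op and as a shape by another.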
  Array<te::Tensor> VisitExpr(const Expr& expr) {
    auto it = memo_.find(expr);
    if (it != memo_.end()) {
      return it->second;
    } else {
      Array<te::Tensor> res = ExprFunctor::VisitExpr(expr);
      if (expr.as<VarNode>() == nullptr) {
        // Do not memoize vars because shape functions could use either the data
        // or the shape of a var each time.
        memo_[expr] = res;
      }
      return res;
    }
  }

  Array<te::Tensor> VisitExpr_(const VarNode* var_node) final {
    auto var = GetRef<Var>(var_node);
    auto it = param_states_.find(var);
    if (it == param_states_.end()) {
      LOG(FATAL) << "Free variable " << var->name_hint();
      return {};
    } else {
      CHECK(data_dependants_.size());
      bool data_dependant = data_dependants_.back();
      if (data_dependant) {
        param_states_[var] |= kNeedInputData;
        return param_data_[var];
      } else {
        param_states_[var] |= kNeedInputShape;
        return param_shapes_[var];
      }
    }
  }

  Array<te::Tensor> VisitExpr_(const ConstantNode* op) final {
    using tir::make_const;
    CHECK(data_dependants_.size());
    CHECK(op->is_scalar());
    bool data_dependant = data_dependants_.back();
    if (data_dependant) {
      void* data = op->data->data;
      DataType dtype = DataType(op->data->dtype);
      auto value = tvm::te::compute({}, [&](const Array<tvm::tir::Var>&) {
          if (dtype == DataType::Int(32)) {
            return make_const(dtype, static_cast<const int32_t*>(data)[0]);
          } else if (dtype == DataType::Int(64)) {
            return make_const(dtype, static_cast<const int64_t*>(data)[0]);
          } else if (dtype == DataType::Float(32)) {
            return make_const(dtype, static_cast<const float*>(data)[0]);
          } else if (dtype == DataType::Float(64)) {
            return make_const(dtype, static_cast<const double*>(data)[0]);
          } else if (dtype == DataType::Bool()) {
            return make_const(dtype, static_cast<const uint8_t*>(data)[0]);
          } else {
            LOG(FATAL) << "not handled";
            return tvm::PrimExpr();
          }
        }, "data_const", topi::kBroadcast);
      scalars_.push_back(value);
      return {value};
    } else {
      auto value = tvm::te::compute({}, [&](const Array<tvm::tir::Var>&) {
          return tir::make_const(DataType::Int(64), 0);
        }, "shape_const", topi::kBroadcast);
      scalars_.push_back(value);
      return {value};
    }
  }

  Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
    static auto fshape_func = Op::GetAttr<FShapeFunc>("FShapeFunc");
    static auto tshape_data_dependant = Op::GetAttr<TShapeDataDependant>(
        "TShapeDataDependant");
    CHECK(call_node->op.as<OpNode>())
        << "Primitive function only allows call into primitive ops";
    Op op = Downcast<Op>(call_node->op);
    CHECK(data_dependants_.empty() || !data_dependants_.back())
        << "Error in op fusion: output of the shape func is fed to a "
        << "data-dependant shape func";
    CHECK_GT(fshape_func.count(op), 0)
        << "Internal error, cannot find ShapeFunc for " << op->name;
    CHECK_GT(tshape_data_dependant.count(op), 0)
        << "Internal error, cannot find TShapeDataDependant for " << op->name;

    data_dependants_.push_back(tshape_data_dependant[op]);
    // Visit all inputs
    Array<te::Tensor> inputs;
    int count_tuple = 0;
    for (Expr arg : call_node->args) {
      if (arg->checked_type().as<TupleTypeNode>()) {
        ++count_tuple;
      }
      for (te::Tensor tensor : VisitExpr(arg)) {
        inputs.push_back(tensor);
      }
    }
    if (count_tuple) {
      CHECK_EQ(call_node->args.size(), 1U)
          << "Only allow function with a single tuple input";
    }
    // Get output ndims
    auto ret_type = call_node->checked_type();
    Array<IndexExpr> out_ndims;
    if (const auto* ttype = ret_type.as<TensorTypeNode>()) {
      out_ndims.push_back(IntImm(DataType::Int(32), ttype->shape.size()));
    } else {
      auto rtype = ret_type.as<TupleTypeNode>();
      // TODO(@icemelon): Allow recursive tuple
      CHECK(rtype);
      for (size_t i = 0; i < rtype->fields.size(); ++i) {
        auto ttype = rtype->fields[i].as<TensorTypeNode>();
        CHECK(ttype);
        out_ndims.push_back(IntImm(DataType::Int(32), ttype->shape.size()));
      }
    }
    // Call shape function
    auto outputs = fshape_func[op](call_node->attrs, inputs, out_ndims);
    data_dependants_.pop_back();
    readable_name_stream_ << "_" << op->name;
    return outputs;
  }
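  // data_dependants_ acts as a stack: before a call's arguments are visited we
  // push whether that op's shape function reads input data; the VarNode and
  // ConstantNode visitors then consult the top of the stack to decide between
  // data placeholders and shape placeholders.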
  Array<te::Tensor> VisitExpr_(const FunctionNode* op) final {
    LOG(FATAL) << "Do not support sub function";
    return Array<te::Tensor>();
  }

  Array<te::Tensor> VisitExpr_(const LetNode* op) final {
    Array<te::Tensor> val = VisitExpr(op->value);
    CHECK(!memo_.count(op->var));
    memo_[op->var] = val;
    return VisitExpr(op->body);
  }

  Array<te::Tensor> VisitExpr_(const TupleNode* op) final {
    Array<te::Tensor> fields;
    for (Expr field : op->fields) {
      CHECK(field->checked_type().as<TensorTypeNode>())
          << "Only allow Tuple of Tensor";
      Array<te::Tensor> res = VisitExpr(field);
      CHECK_EQ(res.size(), 1);
      fields.push_back(res[0]);
    }
    return fields;
  }

 private:
  /*! \brief String stream for function name */
  std::ostringstream readable_name_stream_;
  /*! \brief Map from parameter to its shape function usage state */
  std::unordered_map<Expr, int, ObjectHash, ObjectEqual> param_states_;
  /*! \brief Map from parameter to list of data placeholder */
  std::unordered_map<Expr, Array<te::Tensor>, ObjectHash, ObjectEqual> param_data_;
  /*! \brief Map from parameter to list of shape placeholder */
  std::unordered_map<Expr, Array<te::Tensor>, ObjectHash, ObjectEqual> param_shapes_;
  /*! \brief Memoized visit result */
  std::unordered_map<Expr, Array<te::Tensor>, ObjectHash, ObjectEqual> memo_;
  /*! \brief Stack of data dependencies for shape function */
  std::vector<bool> data_dependants_;
  /*! \brief Scalars used in the shape function */
  Array<te::Tensor> scalars_;
};
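// CompileEngineImpl caches lowered functions per (source_func, target) key and
// JIT-compiles them on demand.
//
// Illustrative usage sketch (not part of this file; assumes `f` is a
// type-checked primitive Relay Function):
//
//   CompileEngine engine = CompileEngine::Global();
//   CCacheKey key(f, Target::Create("llvm"));
//   CachedFunc lowered = engine->Lower(key);  // lowered funcs, cached by key
//   PackedFunc packed = engine->JIT(key);     // compiled, directly callable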
class CompileEngineImpl : public CompileEngineNode {
 public:
  // Lower the function.
  CachedFunc Lower(const CCacheKey& key) {
    return LowerInternal(key)->cached_func;
  }

  // For now, build one module per function.
  PackedFunc JIT(const CCacheKey& key) final {
    CCacheValue value = LowerInternal(key);
    if (value->packed_func != nullptr) return value->packed_func;
    // build the function.
    tvm::runtime::Module m;
    if (const auto* f = runtime::Registry::Get("relay.backend.build")) {
      m = (*f)(value->cached_func->funcs, key->target);
    } else {
      m = build(value->cached_func->funcs, key->target, Target(nullptr),
                BuildConfig::Current());
    }
    value->packed_func = m.GetFunction(value->cached_func->func_name);
    return value->packed_func;
  }

  CachedFunc LowerShapeFunc(const CCacheKey& key) final {
    return LowerShapeFuncInternal(key)->cached_func;
  }

  Array<tvm::runtime::Module> LowerExternalFunctions() {
    std::unordered_map<std::string, IRModule> ext_mods;
    std::vector<CCacheKey> cached_ext_funcs;
    for (const auto& it : cache_) {
      auto src_func = it.first->source_func;
      CHECK(src_func.defined());
      if (src_func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
        auto code_gen = src_func->GetAttr<tir::StringImm>(attr::kCompiler);
        CHECK(code_gen.defined()) << "No external codegen is set";
        if (ext_mods.find(code_gen->value) == ext_mods.end()) {
          ext_mods[code_gen->value] = IRModule({}, {});
        }
        auto symbol_name = src_func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
        CHECK(symbol_name.defined()) << "No external symbol is set for:\n"
                                     << AsText(src_func, false);
        auto gv = GlobalVar(std::string(symbol_name));
        ext_mods[code_gen->value]->Add(gv, src_func);
        cached_ext_funcs.push_back(it.first);
      }
    }

    Array<tvm::runtime::Module> ret;
    for (const auto& it : ext_mods) {
      std::string ext_name = "relay.ext." + it.first;
      auto pf = tvm::runtime::Registry::Get(ext_name);
      CHECK(pf) << "Failed to find the codegen tool for " << ext_name << "\n";
      runtime::Module ext_mod = (*pf)(it.second);
      CHECK(ext_mod.defined()) << "No external runtime is generated.";
      ret.push_back(ext_mod);
    }

    // No need to cache external functions as we collected them all to create
    // external runtime modules.
    for (const auto& it : cached_ext_funcs) {
      cache_.erase(it);
    }
    return ret;
  }

  void Clear() final {
    cache_.clear();
  }

  // List all items in the cache.
  Array<ObjectRef> ListItems() {
    std::lock_guard<std::mutex> lock(mutex_);
    Array<ObjectRef> items;
    for (auto& kv : cache_) {
      items.push_back(kv.first);
      items.push_back(kv.second);
    }
    return items;
  }

  /*!
   * \brief Create schedule for target.
   * \param source_func The primitive function to be lowered.
   * \param target The target we want to create schedule for.
   * \return Pair of schedule and cache.
   *  The funcs field in cache is not yet populated.
   */
  CachedFunc CreateSchedule(const Function& source_func, const Target& target) {
    return ScheduleGetter(target).Create(source_func);
  }

 private:
  // implement lowered func
  CCacheValue LowerInternal(const CCacheKey& key) {
    std::lock_guard<std::mutex> lock(mutex_);
    CCacheValue value;
    auto it = cache_.find(key);
    if (it != cache_.end()) {
      it->second->use_count += 1;
      if (it->second->cached_func.defined()) return it->second;
      value = it->second;
    } else {
      value = CCacheValue(make_object<CCacheValueNode>());
      value->use_count = 0;
      cache_[key] = value;
    }
    // No need to lower external functions for now. We will invoke the external
    // codegen tool once and lower all functions together.
    if (key->source_func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
      auto cache_node = make_object<CachedFuncNode>();
      const auto name_node =
          key->source_func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
      CHECK(name_node.defined()) << "External function has not been attached a name yet.";
      cache_node->func_name = std::string(name_node);
      cache_node->target = tvm::target::ext_dev();
      value->cached_func = CachedFunc(cache_node);
      return value;
    }
    // Enforce use of the target.
    With<Target> target_scope(key->target);

    CHECK(!value->cached_func.defined());
    auto cfunc = CreateSchedule(key->source_func, key->target);
    auto cache_node = make_object<CachedFuncNode>(*(cfunc.operator->()));

    // Skip lowering for device copy node.
    const Expr body = (key->source_func)->body;
    if (const CallNode* call_node = body.as<CallNode>()) {
      if (call_node->attrs.as<DeviceCopyAttrs>()) {
        value->cached_func = CachedFunc(cache_node);
        return value;
      }
    }
    cache_node->func_name = GetUniqueName(cache_node->func_name);
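    // From here on, inputs and outputs are gathered into a single argument
    // list and the schedule is lowered: through the "relay.backend.lower"
    // packed function when one is registered (so the Python side can hook
    // into lowering), otherwise directly via tvm::lower.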
    // NOTE: array will copy on write.
    Array<te::Tensor> all_args = cache_node->inputs;
    for (te::Tensor arg : cache_node->outputs) {
      all_args.push_back(arg);
    }
    // lower the function
    if (const auto* f = runtime::Registry::Get("relay.backend.lower")) {
      cache_node->funcs = (*f)(cfunc->schedule, all_args, cache_node->func_name,
                               key->source_func);
    } else {
      tvm::BuildConfig bcfg = BuildConfig::Create();
      std::unordered_map<te::Tensor, tir::Buffer> binds;
      cache_node->funcs = tvm::lower(cfunc->schedule, all_args, cache_node->func_name,
                                     binds, bcfg);
    }
    value->cached_func = CachedFunc(cache_node);
    return value;
  }

  // implement lowered shape func
  CCacheValue LowerShapeFuncInternal(const CCacheKey& key) {
    std::lock_guard<std::mutex> lock(mutex_);
    CCacheValue value;
    auto it = shape_func_cache_.find(key);
    if (it != shape_func_cache_.end()) {
      it->second->use_count += 1;
      if (it->second->cached_func.defined()) return it->second;
      value = it->second;
    } else {
      value = CCacheValue(make_object<CCacheValueNode>());
      value->use_count = 0;
      shape_func_cache_[key] = value;
    }
    // Enforce use of the target.
    With<Target> target_scope(key->target);

    CHECK(!value->cached_func.defined());
    auto spair = MakeShapeFunc().Create(key->source_func);
    auto cache_node = make_object<CachedFuncNode>(*(spair.second.operator->()));
    cache_node->func_name = GetUniqueName(cache_node->func_name);
    cache_node->target = key->target;

    Array<te::Tensor> all_args = cache_node->inputs;
    for (te::Tensor arg : cache_node->outputs) {
      all_args.push_back(arg);
    }

    tvm::BuildConfig bcfg = BuildConfig::Create();
    std::unordered_map<te::Tensor, tir::Buffer> binds;
    cache_node->funcs = tvm::lower(spair.first, all_args, cache_node->func_name,
                                   binds, bcfg);
    value->cached_func = CachedFunc(cache_node);
    return value;
  }

  /*!
   * \brief Get unique name from name.
   * \param name The original name.
   * \return Updated name which is unique.
   */
  std::string GetUniqueName(std::string name) {
    for (size_t i = 0; i < name.length(); ++i) {
      if (name[i] == '.') name[i] = '_';
    }
    while (true) {
      auto it = name_map_.find(name);
      if (it == name_map_.end()) {
        name_map_[name] = 1;
        return name;
      } else {
        std::ostringstream os;
        os << name << "_" << it->second;
        ++(it->second);
        name = os.str();
      }
    }
    return name;
  }

  /*! \brief compiler cache lock */
  std::mutex mutex_;
  /*! \brief internal name map to get a unique name */
  std::unordered_map<std::string, int> name_map_;
  /*! \brief internal compiler cache */
  std::unordered_map<CCacheKey, CCacheValue> cache_;
  /*! \brief internal compiler cache for shape funcs */
  std::unordered_map<CCacheKey, CCacheValue> shape_func_cache_;
};
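// The engine is process-wide: Global() below returns a lazily constructed
// singleton, and the relay.backend._CompileEngine* packed functions registered
// at the end of this file expose that same instance to the Python frontend.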
/*! \brief The global compile engine */
const CompileEngine& CompileEngine::Global() {
  // intentionally allocate raw pointer to avoid
  // free during destruction.
  static CompileEngine* inst = new CompileEngine(make_object<CompileEngineImpl>());
  return *inst;
}

TVM_REGISTER_GLOBAL("relay.backend._make_LoweredOutput")
.set_body_typed([](tvm::Array<te::Tensor> outputs, OpImplementation impl) {
  return LoweredOutput(outputs, impl);
});

TVM_REGISTER_GLOBAL("relay.backend._make_CCacheKey")
.set_body_typed([](Function source_func, Target target) {
  return CCacheKey(source_func, target);
});

TVM_REGISTER_GLOBAL("relay.backend._CompileEngineGlobal")
.set_body_typed([]() {
  return CompileEngine::Global();
});

TVM_REGISTER_GLOBAL("relay.backend._CompileEngineClear")
.set_body_typed([](CompileEngine self) {
  self->Clear();
});

TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLower")
.set_body_typed([](CompileEngine self, CCacheKey key) {
  return self->Lower(key);
});

TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLowerShapeFunc")
.set_body_typed([](CompileEngine self, CCacheKey key) {
  return self->LowerShapeFunc(key);
});

TVM_REGISTER_GLOBAL("relay.backend._CompileLowerExternalFunctions")
.set_body_typed([](CompileEngine self) {
  return self->LowerExternalFunctions();
});

TVM_REGISTER_GLOBAL("relay.backend._CompileEngineJIT")
.set_body_typed([](CompileEngine self, CCacheKey key) {
  return self->JIT(key);
});

TVM_REGISTER_GLOBAL("relay.backend._CompileEngineListItems")
.set_body_typed([](CompileEngine self) {
  return static_cast<CompileEngineImpl*>(self.operator->())->ListItems();
});

}  // namespace relay
}  // namespace tvm