Unverified Commit 06bbc7c9 by Zhi Committed by GitHub

Replace UseDefaultCompiler with GetAttr (#5088)

parent 4ae46748
...@@ -76,15 +76,6 @@ class FunctionNode : public BaseFuncNode { ...@@ -76,15 +76,6 @@ class FunctionNode : public BaseFuncNode {
*/ */
TVM_DLL FuncType func_type_annotation() const; TVM_DLL FuncType func_type_annotation() const;
/*!
* \brief Check whether the function should use the TVM default compiler to build, or
* use other compilers.
*
* \return Whether the function will be compiled using the default compiler
* (e.g. those are used in the TVM stack).
*/
bool UseDefaultCompiler() const;
static constexpr const char* _type_key = "relay.Function"; static constexpr const char* _type_key = "relay.Function";
TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, BaseFuncNode); TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, BaseFuncNode);
}; };
......
...@@ -616,7 +616,7 @@ class CompileEngineImpl : public CompileEngineNode { ...@@ -616,7 +616,7 @@ class CompileEngineImpl : public CompileEngineNode {
for (const auto& it : cache_) { for (const auto& it : cache_) {
auto src_func = it.first->source_func; auto src_func = it.first->source_func;
CHECK(src_func.defined()); CHECK(src_func.defined());
if (!src_func->UseDefaultCompiler()) { if (src_func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
auto code_gen = src_func->GetAttr<tir::StringImm>(attr::kCompiler); auto code_gen = src_func->GetAttr<tir::StringImm>(attr::kCompiler);
CHECK(code_gen.defined()) << "No external codegen is set"; CHECK(code_gen.defined()) << "No external codegen is set";
if (ext_mods.find(code_gen->value) == ext_mods.end()) { if (ext_mods.find(code_gen->value) == ext_mods.end()) {
...@@ -690,7 +690,7 @@ class CompileEngineImpl : public CompileEngineNode { ...@@ -690,7 +690,7 @@ class CompileEngineImpl : public CompileEngineNode {
} }
// No need to lower external functions for now. We will invoke the external // No need to lower external functions for now. We will invoke the external
// codegen tool once and lower all functions together. // codegen tool once and lower all functions together.
if (!key->source_func->UseDefaultCompiler()) { if (key->source_func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
auto cache_node = make_object<CachedFuncNode>(); auto cache_node = make_object<CachedFuncNode>();
const auto name_node = const auto name_node =
key->source_func->GetAttr<tir::StringImm>(attr::kExternalSymbol); key->source_func->GetAttr<tir::StringImm>(attr::kExternalSymbol);
......
...@@ -424,7 +424,7 @@ class GraphRuntimeCodegen ...@@ -424,7 +424,7 @@ class GraphRuntimeCodegen
auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower"); auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower");
Target target; Target target;
// Handle external function // Handle external function
if (!func->UseDefaultCompiler()) { if (func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
target = tvm::target::ext_dev(); target = tvm::target::ext_dev();
CCacheKey key = (*pf0)(func, target); CCacheKey key = (*pf0)(func, target);
CachedFunc ext_func = (*pf1)(compile_engine_, key); CachedFunc ext_func = (*pf1)(compile_engine_, key);
...@@ -490,7 +490,8 @@ class GraphRuntimeCodegen ...@@ -490,7 +490,8 @@ class GraphRuntimeCodegen
return {}; return {};
} }
std::vector<GraphNodeRef> VisitExpr_(const FunctionNode* op) override { std::vector<GraphNodeRef> VisitExpr_(const FunctionNode* op) override {
CHECK(!op->UseDefaultCompiler()) << "Only functions supported by custom codegen"; CHECK(op->GetAttr<tir::StringImm>(attr::kCompiler).defined())
<< "Only functions supported by custom codegen";
return {}; return {};
} }
std::vector<GraphNodeRef> VisitExpr_(const RefCreateNode* op) override { std::vector<GraphNodeRef> VisitExpr_(const RefCreateNode* op) override {
......
...@@ -471,7 +471,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> { ...@@ -471,7 +471,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
Target target; Target target;
if (!func->UseDefaultCompiler()) { if (func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
target = tvm::target::ext_dev(); target = tvm::target::ext_dev();
} else { } else {
// Next generate the invoke instruction. // Next generate the invoke instruction.
...@@ -489,7 +489,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> { ...@@ -489,7 +489,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
auto cfunc = engine_->Lower(key); auto cfunc = engine_->Lower(key);
auto op_index = -1; auto op_index = -1;
if (!func->UseDefaultCompiler()) { if (func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
op_index = context_->cached_funcs.size(); op_index = context_->cached_funcs.size();
context_->cached_funcs.push_back(cfunc); context_->cached_funcs.push_back(cfunc);
} else { } else {
......
...@@ -122,17 +122,17 @@ struct PrimitiveInliner : ExprMutator { ...@@ -122,17 +122,17 @@ struct PrimitiveInliner : ExprMutator {
auto global = pair.first; auto global = pair.first;
auto base_func = pair.second; auto base_func = pair.second;
if (auto* n = base_func.as<FunctionNode>()) { if (auto* n = base_func.as<FunctionNode>()) {
if (!n->UseDefaultCompiler()) continue; if (n->GetAttr<tir::StringImm>(attr::kCompiler).defined()) continue;
auto func = GetRef<Function>(n); auto func = GetRef<Function>(n);
DLOG(INFO) << "Before inlining primitives: " << global DLOG(INFO) << "Before inlining primitives: " << global
<< std::endl << AsText(func, false); << std::endl << AsText(func, false);
func = Function(func->params, func = Function(func->params,
VisitExpr(func->body), VisitExpr(func->body),
func->ret_type, func->ret_type,
func->type_params, func->type_params,
func->attrs); func->attrs);
module_->Add(global, func, true); module_->Add(global, func, true);
DLOG(INFO) << "After inlining primitives: " << global DLOG(INFO) << "After inlining primitives: " << global
......
...@@ -187,13 +187,13 @@ class LambdaLifter : public ExprMutator { ...@@ -187,13 +187,13 @@ class LambdaLifter : public ExprMutator {
auto glob_funcs = module_->functions; auto glob_funcs = module_->functions;
for (auto pair : glob_funcs) { for (auto pair : glob_funcs) {
if (auto* n = pair.second.as<FunctionNode>()) { if (auto* n = pair.second.as<FunctionNode>()) {
if (!n->UseDefaultCompiler()) continue; if (n->GetAttr<tir::StringImm>(attr::kCompiler).defined()) continue;
auto func = GetRef<Function>(n); auto func = GetRef<Function>(n);
func = Function(func->params, func = Function(func->params,
VisitExpr(func->body), VisitExpr(func->body),
func->ret_type, func->ret_type,
func->type_params, func->type_params,
func->attrs); func->attrs);
module_->Add(pair.first, func, true); module_->Add(pair.first, func, true);
} }
} }
......
...@@ -55,11 +55,6 @@ FuncType FunctionNode::func_type_annotation() const { ...@@ -55,11 +55,6 @@ FuncType FunctionNode::func_type_annotation() const {
return FuncType(param_types, ret_type, this->type_params, {}); return FuncType(param_types, ret_type, this->type_params, {});
} }
bool FunctionNode::UseDefaultCompiler() const {
tir::StringImm val = this->GetAttr<tir::StringImm>(attr::kCompiler);
return !val.defined() || val->value == "default";
}
TVM_REGISTER_NODE_TYPE(FunctionNode); TVM_REGISTER_NODE_TYPE(FunctionNode);
TVM_REGISTER_GLOBAL("relay.ir.Function") TVM_REGISTER_GLOBAL("relay.ir.Function")
......
...@@ -140,7 +140,7 @@ IRModule FunctionPassNode::operator()(const IRModule& mod, ...@@ -140,7 +140,7 @@ IRModule FunctionPassNode::operator()(const IRModule& mod,
bool FunctionPassNode::SkipFunction(const Function& func) const { bool FunctionPassNode::SkipFunction(const Function& func) const {
return func->GetAttr<Integer>(attr::kSkipOptimization, 0)->value != 0 || return func->GetAttr<Integer>(attr::kSkipOptimization, 0)->value != 0 ||
!(func->UseDefaultCompiler()); (func->GetAttr<tir::StringImm>(attr::kCompiler).defined());
} }
Pass CreateFunctionPass( Pass CreateFunctionPass(
......
...@@ -125,13 +125,13 @@ class Inliner : ExprMutator { ...@@ -125,13 +125,13 @@ class Inliner : ExprMutator {
CHECK(fn) << "Expected to work on a Relay function."; CHECK(fn) << "Expected to work on a Relay function.";
auto func = Function(fn->params, auto func = Function(fn->params,
fn->body, fn->body,
fn->ret_type, fn->ret_type,
fn->type_params, fn->type_params,
fn->attrs); fn->attrs);
// Inline the function body to the caller if this function uses default // Inline the function body to the caller if this function uses default
// compiler, i.e. no external codegen is needed. // compiler, i.e. no external codegen is needed.
if (func->UseDefaultCompiler()) { if (!func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
CHECK_EQ(func->params.size(), args.size()) CHECK_EQ(func->params.size(), args.size())
<< "Mismatch found in the number of parameters and call args"; << "Mismatch found in the number of parameters and call args";
// Bind the parameters with call args. // Bind the parameters with call args.
......
...@@ -299,7 +299,7 @@ IRModule ToANormalForm(const IRModule& m) { ...@@ -299,7 +299,7 @@ IRModule ToANormalForm(const IRModule& m) {
for (const auto& it : funcs) { for (const auto& it : funcs) {
CHECK_EQ(FreeVars(it.second).size(), 0); CHECK_EQ(FreeVars(it.second).size(), 0);
if (const auto* n = it.second.as<FunctionNode>()) { if (const auto* n = it.second.as<FunctionNode>()) {
if (!n->UseDefaultCompiler()) continue; if (n->GetAttr<tir::StringImm>(attr::kCompiler).defined()) continue;
} }
Expr ret = Expr ret =
TransformF([&](const Expr& e) { TransformF([&](const Expr& e) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment