Unverified Commit 06bbc7c9 by Zhi Committed by GitHub

Replace UseDefaultCompiler with GetAttr (#5088)

parent 4ae46748
......@@ -76,15 +76,6 @@ class FunctionNode : public BaseFuncNode {
*/
TVM_DLL FuncType func_type_annotation() const;
/*!
* \brief Check whether the function should use the TVM default compiler to build, or
* use other compilers.
*
* \return Whether the function will be compiled using the default compiler
* (e.g. those are used in the TVM stack).
*/
bool UseDefaultCompiler() const;
static constexpr const char* _type_key = "relay.Function";
TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, BaseFuncNode);
};
......
......@@ -616,7 +616,7 @@ class CompileEngineImpl : public CompileEngineNode {
for (const auto& it : cache_) {
auto src_func = it.first->source_func;
CHECK(src_func.defined());
if (!src_func->UseDefaultCompiler()) {
if (src_func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
auto code_gen = src_func->GetAttr<tir::StringImm>(attr::kCompiler);
CHECK(code_gen.defined()) << "No external codegen is set";
if (ext_mods.find(code_gen->value) == ext_mods.end()) {
......@@ -690,7 +690,7 @@ class CompileEngineImpl : public CompileEngineNode {
}
// No need to lower external functions for now. We will invoke the external
// codegen tool once and lower all functions together.
if (!key->source_func->UseDefaultCompiler()) {
if (key->source_func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
auto cache_node = make_object<CachedFuncNode>();
const auto name_node =
key->source_func->GetAttr<tir::StringImm>(attr::kExternalSymbol);
......
......@@ -424,7 +424,7 @@ class GraphRuntimeCodegen
auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower");
Target target;
// Handle external function
if (!func->UseDefaultCompiler()) {
if (func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
target = tvm::target::ext_dev();
CCacheKey key = (*pf0)(func, target);
CachedFunc ext_func = (*pf1)(compile_engine_, key);
......@@ -490,7 +490,8 @@ class GraphRuntimeCodegen
return {};
}
std::vector<GraphNodeRef> VisitExpr_(const FunctionNode* op) override {
CHECK(!op->UseDefaultCompiler()) << "Only functions supported by custom codegen";
CHECK(op->GetAttr<tir::StringImm>(attr::kCompiler).defined())
<< "Only functions supported by custom codegen";
return {};
}
std::vector<GraphNodeRef> VisitExpr_(const RefCreateNode* op) override {
......
......@@ -471,7 +471,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
Target target;
if (!func->UseDefaultCompiler()) {
if (func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
target = tvm::target::ext_dev();
} else {
// Next generate the invoke instruction.
......@@ -489,7 +489,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
auto cfunc = engine_->Lower(key);
auto op_index = -1;
if (!func->UseDefaultCompiler()) {
if (func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
op_index = context_->cached_funcs.size();
context_->cached_funcs.push_back(cfunc);
} else {
......
......@@ -122,7 +122,7 @@ struct PrimitiveInliner : ExprMutator {
auto global = pair.first;
auto base_func = pair.second;
if (auto* n = base_func.as<FunctionNode>()) {
if (!n->UseDefaultCompiler()) continue;
if (n->GetAttr<tir::StringImm>(attr::kCompiler).defined()) continue;
auto func = GetRef<Function>(n);
DLOG(INFO) << "Before inlining primitives: " << global
......
......@@ -187,7 +187,7 @@ class LambdaLifter : public ExprMutator {
auto glob_funcs = module_->functions;
for (auto pair : glob_funcs) {
if (auto* n = pair.second.as<FunctionNode>()) {
if (!n->UseDefaultCompiler()) continue;
if (n->GetAttr<tir::StringImm>(attr::kCompiler).defined()) continue;
auto func = GetRef<Function>(n);
func = Function(func->params,
VisitExpr(func->body),
......
......@@ -55,11 +55,6 @@ FuncType FunctionNode::func_type_annotation() const {
return FuncType(param_types, ret_type, this->type_params, {});
}
bool FunctionNode::UseDefaultCompiler() const {
tir::StringImm val = this->GetAttr<tir::StringImm>(attr::kCompiler);
return !val.defined() || val->value == "default";
}
TVM_REGISTER_NODE_TYPE(FunctionNode);
TVM_REGISTER_GLOBAL("relay.ir.Function")
......
......@@ -140,7 +140,7 @@ IRModule FunctionPassNode::operator()(const IRModule& mod,
bool FunctionPassNode::SkipFunction(const Function& func) const {
return func->GetAttr<Integer>(attr::kSkipOptimization, 0)->value != 0 ||
!(func->UseDefaultCompiler());
(func->GetAttr<tir::StringImm>(attr::kCompiler).defined());
}
Pass CreateFunctionPass(
......
......@@ -131,7 +131,7 @@ class Inliner : ExprMutator {
fn->attrs);
// Inline the function body to the caller if this function uses default
// compiler, i.e. no external codegen is needed.
if (func->UseDefaultCompiler()) {
if (!func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
CHECK_EQ(func->params.size(), args.size())
<< "Mismatch found in the number of parameters and call args";
// Bind the parameters with call args.
......
......@@ -299,7 +299,7 @@ IRModule ToANormalForm(const IRModule& m) {
for (const auto& it : funcs) {
CHECK_EQ(FreeVars(it.second).size(), 0);
if (const auto* n = it.second.as<FunctionNode>()) {
if (!n->UseDefaultCompiler()) continue;
if (n->GetAttr<tir::StringImm>(attr::kCompiler).defined()) continue;
}
Expr ret =
TransformF([&](const Expr& e) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment