Unverified Commit d2f9af78 by Zhi Committed by GitHub

[REFACTOR][IR] kExternalSymbol -> kGlobalSymbol (#5211)

* expose runtime::String to Python

* kExternalSymbol -> kGlobalSymbol
parent 03cbf78e
...@@ -512,12 +512,12 @@ class String : public ObjectRef { ...@@ -512,12 +512,12 @@ class String : public ObjectRef {
#endif #endif
} }
TVM_DEFINE_OBJECT_REF_METHODS(String, ObjectRef, StringObj);
private:
/*! \return the internal StringObj pointer */ /*! \return the internal StringObj pointer */
const StringObj* get() const { return operator->(); } const StringObj* get() const { return operator->(); }
TVM_DEFINE_OBJECT_REF_METHODS(String, ObjectRef, StringObj);
private:
/*! /*!
* \brief Compare two char sequence * \brief Compare two char sequence
* *
......
...@@ -109,4 +109,22 @@ def tuple_object(fields=None): ...@@ -109,4 +109,22 @@ def tuple_object(fields=None):
return _Tuple(*fields) return _Tuple(*fields)
@tvm._ffi.register_object("runtime.String")
class String(Object):
"""The string object.
Parameters
----------
string : Str
The string used to construct a runtime String object
Returns
-------
ret : String
The created object.
"""
def __init__(self, string):
self.__init_handle_by_constructor__(_String, string)
tvm._ffi._init_api("tvm.runtime.container") tvm._ffi._init_api("tvm.runtime.container")
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <tvm/te/operation.h> #include <tvm/te/operation.h>
#include <tvm/te/schedule_pass.h> #include <tvm/te/schedule_pass.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/container.h>
#include <tvm/relay/attrs/device_copy.h> #include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/analysis.h> #include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
...@@ -622,10 +623,10 @@ class CompileEngineImpl : public CompileEngineNode { ...@@ -622,10 +623,10 @@ class CompileEngineImpl : public CompileEngineNode {
if (ext_mods.find(code_gen->value) == ext_mods.end()) { if (ext_mods.find(code_gen->value) == ext_mods.end()) {
ext_mods[code_gen->value] = IRModule({}, {}); ext_mods[code_gen->value] = IRModule({}, {});
} }
auto symbol_name = src_func->GetAttr<tir::StringImm>(attr::kExternalSymbol); auto symbol_name = src_func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
CHECK(symbol_name.defined()) << "No external symbol is set for:\n" CHECK(symbol_name.defined()) << "No external symbol is set for:\n"
<< AsText(src_func, false); << AsText(src_func, false);
auto gv = GlobalVar(symbol_name->value); auto gv = GlobalVar(std::string(symbol_name));
ext_mods[code_gen->value]->Add(gv, src_func); ext_mods[code_gen->value]->Add(gv, src_func);
cached_ext_funcs.push_back(it.first); cached_ext_funcs.push_back(it.first);
} }
...@@ -693,10 +694,10 @@ class CompileEngineImpl : public CompileEngineNode { ...@@ -693,10 +694,10 @@ class CompileEngineImpl : public CompileEngineNode {
if (key->source_func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) { if (key->source_func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
auto cache_node = make_object<CachedFuncNode>(); auto cache_node = make_object<CachedFuncNode>();
const auto name_node = const auto name_node =
key->source_func->GetAttr<tir::StringImm>(attr::kExternalSymbol); key->source_func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
CHECK(name_node.defined()) CHECK(name_node.defined())
<< "External function has not been attached a name yet."; << "External function has not been attached a name yet.";
cache_node->func_name = name_node->value; cache_node->func_name = std::string(name_node);
cache_node->target = tvm::target::ext_dev(); cache_node->target = tvm::target::ext_dev();
value->cached_func = CachedFunc(cache_node); value->cached_func = CachedFunc(cache_node);
return value; return value;
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/function.h> #include <tvm/relay/function.h>
#include <tvm/runtime/container.h>
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <utility> #include <utility>
...@@ -69,10 +70,9 @@ class CSourceModuleCodegenBase { ...@@ -69,10 +70,9 @@ class CSourceModuleCodegenBase {
*/ */
std::string GetExtSymbol(const Function& func) const { std::string GetExtSymbol(const Function& func) const {
const auto name_node = const auto name_node =
func->GetAttr<tir::StringImm>(attr::kExternalSymbol); func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
CHECK(name_node.defined()) << "Fail to retrieve external symbol."; CHECK(name_node.defined()) << "Fail to retrieve external symbol.";
std::string ext_symbol = name_node->value; return std::string(name_node);
return ext_symbol;
} }
}; };
......
...@@ -35,6 +35,7 @@ ...@@ -35,6 +35,7 @@
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h> #include <tvm/relay/transform.h>
#include <tvm/runtime/container.h>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
...@@ -239,8 +240,8 @@ class Partitioner : public ExprMutator { ...@@ -239,8 +240,8 @@ class Partitioner : public ExprMutator {
std::string target = call->attrs.as<CompilerAttrs>()->compiler; std::string target = call->attrs.as<CompilerAttrs>()->compiler;
std::string name = target + "_" + std::to_string(region->GetID()); std::string name = target + "_" + std::to_string(region->GetID());
global_region_func = WithAttr(std::move(global_region_func), attr::kExternalSymbol, global_region_func = WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol,
tir::StringImmNode::make(name)); runtime::String(name));
global_region_func = global_region_func =
WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1)); WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1));
global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler, global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler,
......
...@@ -76,7 +76,13 @@ TVM_REGISTER_GLOBAL("runtime.container._ADT") ...@@ -76,7 +76,13 @@ TVM_REGISTER_GLOBAL("runtime.container._ADT")
*rv = ADT(tag, fields); *rv = ADT(tag, fields);
}); });
TVM_REGISTER_GLOBAL("runtime.container._String")
.set_body_typed([](std::string str) {
return String(std::move(str));
});
TVM_REGISTER_OBJECT_TYPE(ADTObj); TVM_REGISTER_OBJECT_TYPE(ADTObj);
TVM_REGISTER_OBJECT_TYPE(StringObj);
TVM_REGISTER_OBJECT_TYPE(ClosureObj); TVM_REGISTER_OBJECT_TYPE(ClosureObj);
} // namespace runtime } // namespace runtime
......
...@@ -80,7 +80,8 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", ...@@ -80,7 +80,8 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
def set_external_func_attr(func, compiler, ext_symbol): def set_external_func_attr(func, compiler, ext_symbol):
func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func = func.with_attr("Compiler", tvm.tir.StringImm(compiler)) func = func.with_attr("Compiler", tvm.tir.StringImm(compiler))
func = func.with_attr("ExternalSymbol", tvm.tir.StringImm(ext_symbol)) func = func.with_attr("global_symbol",
runtime.container.String(ext_symbol))
return func return func
......
...@@ -23,6 +23,7 @@ import tvm ...@@ -23,6 +23,7 @@ import tvm
import tvm.relay.testing import tvm.relay.testing
from tvm import relay from tvm import relay
from tvm import runtime from tvm import runtime
from tvm.runtime import container
from tvm.relay import transform from tvm.relay import transform
from tvm.contrib import util from tvm.contrib import util
from tvm.relay.op.annotation import compiler_begin, compiler_end from tvm.relay.op.annotation import compiler_begin, compiler_end
...@@ -305,10 +306,8 @@ def test_extern_ccompiler_default_ops(): ...@@ -305,10 +306,8 @@ def test_extern_ccompiler_default_ops():
func = relay.Function([x0, y0], add) func = relay.Function([x0, y0], add)
func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1)) func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func = func.with_attr("Compiler", func = func.with_attr("Compiler", tvm.tir.StringImm("ccompiler"))
tvm.tir.StringImm("ccompiler")) func = func.with_attr("global_symbol", container.String("ccompiler_0"))
func = func.with_attr("ExternalSymbol",
tvm.tir.StringImm("ccompiler_0"))
glb_0 = relay.GlobalVar("ccompiler_0") glb_0 = relay.GlobalVar("ccompiler_0")
mod[glb_0] = func mod[glb_0] = func
add_call = relay.Call(glb_0, [x, y]) add_call = relay.Call(glb_0, [x, y])
...@@ -319,7 +318,7 @@ def test_extern_ccompiler_default_ops(): ...@@ -319,7 +318,7 @@ def test_extern_ccompiler_default_ops():
concat = relay.concatenate([log, exp], axis=0) concat = relay.concatenate([log, exp], axis=0)
fused_func = relay.Function([p0], concat) fused_func = relay.Function([p0], concat)
fused_func = fused_func.with_attr("Primitive", fused_func = fused_func.with_attr("Primitive",
tvm.tir.IntImm("int32", 1)) tvm.tir.IntImm("int32", 1))
fused_call = relay.Call(fused_func, [add_call]) fused_call = relay.Call(fused_func, [add_call])
main = relay.Function([x, y], fused_call) main = relay.Function([x, y], fused_call)
mod["main"] = main mod["main"] = main
...@@ -393,8 +392,7 @@ def test_extern_dnnl(): ...@@ -393,8 +392,7 @@ def test_extern_dnnl():
func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1)) func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func = func.with_attr("Compiler", tvm.tir.StringImm("dnnl")) func = func.with_attr("Compiler", tvm.tir.StringImm("dnnl"))
func = func.with_attr("ExternalSymbol", func = func.with_attr("global_symbol", container.String("dnnl_0"))
tvm.tir.StringImm("dnnl_0"))
glb_var = relay.GlobalVar("dnnl_0") glb_var = relay.GlobalVar("dnnl_0")
mod = tvm.IRModule() mod = tvm.IRModule()
mod[glb_var] = func mod[glb_var] = func
...@@ -520,8 +518,8 @@ def test_function_lifting(): ...@@ -520,8 +518,8 @@ def test_function_lifting():
func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1)) func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Compiler", func0 = func0.with_attr("Compiler",
tvm.tir.StringImm("test_compiler")) tvm.tir.StringImm("test_compiler"))
func0 = func0.with_attr("ExternalSymbol", func0 = func0.with_attr("global_symbol",
tvm.tir.StringImm("test_compiler_0")) container.String("test_compiler_0"))
gv0 = relay.GlobalVar("test_compiler_0") gv0 = relay.GlobalVar("test_compiler_0")
mod[gv0] = func0 mod[gv0] = func0
...@@ -539,8 +537,8 @@ def test_function_lifting(): ...@@ -539,8 +537,8 @@ def test_function_lifting():
func1 = func1.with_attr("Inline", tvm.tir.IntImm("int32", 1)) func1 = func1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func1 = func1.with_attr("Compiler", func1 = func1.with_attr("Compiler",
tvm.tir.StringImm("test_compiler")) tvm.tir.StringImm("test_compiler"))
func1 = func1.with_attr("ExternalSymbol", func1 = func1.with_attr("global_symbol",
tvm.tir.StringImm("test_compiler_1")) container.String("test_compiler_1"))
gv1 = relay.GlobalVar("test_compiler_1") gv1 = relay.GlobalVar("test_compiler_1")
mod[gv1] = func1 mod[gv1] = func1
...@@ -613,8 +611,8 @@ def test_function_lifting_inline(): ...@@ -613,8 +611,8 @@ def test_function_lifting_inline():
func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1)) func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Compiler", func0 = func0.with_attr("Compiler",
tvm.tir.StringImm("test_compiler")) tvm.tir.StringImm("test_compiler"))
func0 = func0.with_attr("ExternalSymbol", func0 = func0.with_attr("global_symbol",
tvm.tir.StringImm("test_compiler_0")) container.String("test_compiler_0"))
# main function # main function
data = relay.var("data", relay.TensorType((1, 16, 224, 224), "float32")) data = relay.var("data", relay.TensorType((1, 16, 224, 224), "float32"))
...@@ -649,8 +647,7 @@ def test_constant_propagation(): ...@@ -649,8 +647,7 @@ def test_constant_propagation():
func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1)) func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func = func.with_attr("Compiler", tvm.tir.StringImm("ccompiler")) func = func.with_attr("Compiler", tvm.tir.StringImm("ccompiler"))
func = func.with_attr("ExternalSymbol", func = func.with_attr("global_symbol", container.String("ccompiler_0"))
tvm.tir.StringImm("ccompiler_0"))
glb_0 = relay.GlobalVar("ccompiler_0") glb_0 = relay.GlobalVar("ccompiler_0")
mod[glb_0] = func mod[glb_0] = func
add_call = relay.Call(glb_0, [y]) add_call = relay.Call(glb_0, [y])
...@@ -751,8 +748,8 @@ def test_multiple_outputs(): ...@@ -751,8 +748,8 @@ def test_multiple_outputs():
func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1)) func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Compiler", func0 = func0.with_attr("Compiler",
tvm.tir.StringImm("test_target")) tvm.tir.StringImm("test_target"))
func0 = func0.with_attr("ExternalSymbol", func0 = func0.with_attr("global_symbol",
tvm.tir.StringImm("test_target_2")) container.String("test_target_2"))
gv0 = relay.GlobalVar("test_target_2") gv0 = relay.GlobalVar("test_target_2")
mod[gv0] = func0 mod[gv0] = func0
...@@ -819,8 +816,8 @@ def test_mixed_single_multiple_outputs(): ...@@ -819,8 +816,8 @@ def test_mixed_single_multiple_outputs():
func1 = func1.with_attr("Inline", tvm.tir.IntImm("int32", 1)) func1 = func1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func1 = func1.with_attr("Compiler", func1 = func1.with_attr("Compiler",
tvm.tir.StringImm("test_target")) tvm.tir.StringImm("test_target"))
func1 = func1.with_attr("ExternalSymbol", func1 = func1.with_attr("global_symbol",
tvm.tir.StringImm("test_target_1")) container.String("test_target_1"))
gv1 = relay.GlobalVar("test_target_1") gv1 = relay.GlobalVar("test_target_1")
mod[gv1] = func1 mod[gv1] = func1
...@@ -834,8 +831,8 @@ def test_mixed_single_multiple_outputs(): ...@@ -834,8 +831,8 @@ def test_mixed_single_multiple_outputs():
func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1)) func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Compiler", func0 = func0.with_attr("Compiler",
tvm.tir.StringImm("test_target")) tvm.tir.StringImm("test_target"))
func0 = func0.with_attr("ExternalSymbol", func0 = func0.with_attr("global_symbol",
tvm.tir.StringImm("test_target_0")) container.String("test_target_0"))
gv0 = relay.GlobalVar("test_target_0") gv0 = relay.GlobalVar("test_target_0")
mod[gv0] = func0 mod[gv0] = func0
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment