Unverified Commit d2f9af78 by Zhi Committed by GitHub

[REFACTOR][IR] kExternalSymbol -> kGlobalSymbol (#5211)

* expose runtime::String to Python

* kExternalSymbol -> kGlobalSymbol
parent 03cbf78e
......@@ -512,12 +512,12 @@ class String : public ObjectRef {
#endif
}
TVM_DEFINE_OBJECT_REF_METHODS(String, ObjectRef, StringObj);
private:
/*! \return the internal StringObj pointer */
const StringObj* get() const { return operator->(); }
TVM_DEFINE_OBJECT_REF_METHODS(String, ObjectRef, StringObj);
private:
/*!
* \brief Compare two char sequence
*
......
......@@ -109,4 +109,22 @@ def tuple_object(fields=None):
return _Tuple(*fields)
@tvm._ffi.register_object("runtime.String")
class String(Object):
"""The string object.
Parameters
----------
string : Str
The string used to construct a runtime String object
Returns
-------
ret : String
The created object.
"""
def __init__(self, string):
self.__init_handle_by_constructor__(_String, string)
tvm._ffi._init_api("tvm.runtime.container")
......@@ -26,6 +26,7 @@
#include <tvm/te/operation.h>
#include <tvm/te/schedule_pass.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/container.h>
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
......@@ -622,10 +623,10 @@ class CompileEngineImpl : public CompileEngineNode {
if (ext_mods.find(code_gen->value) == ext_mods.end()) {
ext_mods[code_gen->value] = IRModule({}, {});
}
auto symbol_name = src_func->GetAttr<tir::StringImm>(attr::kExternalSymbol);
auto symbol_name = src_func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
CHECK(symbol_name.defined()) << "No external symbol is set for:\n"
<< AsText(src_func, false);
auto gv = GlobalVar(symbol_name->value);
auto gv = GlobalVar(std::string(symbol_name));
ext_mods[code_gen->value]->Add(gv, src_func);
cached_ext_funcs.push_back(it.first);
}
......@@ -693,10 +694,10 @@ class CompileEngineImpl : public CompileEngineNode {
if (key->source_func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
auto cache_node = make_object<CachedFuncNode>();
const auto name_node =
key->source_func->GetAttr<tir::StringImm>(attr::kExternalSymbol);
key->source_func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
CHECK(name_node.defined())
<< "External function has not been attached a name yet.";
cache_node->func_name = name_node->value;
cache_node->func_name = std::string(name_node);
cache_node->target = tvm::target::ext_dev();
value->cached_func = CachedFunc(cache_node);
return value;
......
......@@ -27,6 +27,7 @@
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/relay/function.h>
#include <tvm/runtime/container.h>
#include <sstream>
#include <string>
#include <utility>
......@@ -69,10 +70,9 @@ class CSourceModuleCodegenBase {
*/
std::string GetExtSymbol(const Function& func) const {
const auto name_node =
func->GetAttr<tir::StringImm>(attr::kExternalSymbol);
func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
CHECK(name_node.defined()) << "Fail to retrieve external symbol.";
std::string ext_symbol = name_node->value;
return ext_symbol;
return std::string(name_node);
}
};
......
......@@ -35,6 +35,7 @@
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/container.h>
#include <unordered_map>
#include <unordered_set>
......@@ -239,8 +240,8 @@ class Partitioner : public ExprMutator {
std::string target = call->attrs.as<CompilerAttrs>()->compiler;
std::string name = target + "_" + std::to_string(region->GetID());
global_region_func = WithAttr(std::move(global_region_func), attr::kExternalSymbol,
tir::StringImmNode::make(name));
global_region_func = WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol,
runtime::String(name));
global_region_func =
WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1));
global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler,
......
......@@ -76,7 +76,13 @@ TVM_REGISTER_GLOBAL("runtime.container._ADT")
*rv = ADT(tag, fields);
});
TVM_REGISTER_GLOBAL("runtime.container._String")
.set_body_typed([](std::string str) {
return String(std::move(str));
});
TVM_REGISTER_OBJECT_TYPE(ADTObj);
TVM_REGISTER_OBJECT_TYPE(StringObj);
TVM_REGISTER_OBJECT_TYPE(ClosureObj);
} // namespace runtime
......
......@@ -80,7 +80,8 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
def set_external_func_attr(func, compiler, ext_symbol):
func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func = func.with_attr("Compiler", tvm.tir.StringImm(compiler))
func = func.with_attr("ExternalSymbol", tvm.tir.StringImm(ext_symbol))
func = func.with_attr("global_symbol",
runtime.container.String(ext_symbol))
return func
......
......@@ -23,6 +23,7 @@ import tvm
import tvm.relay.testing
from tvm import relay
from tvm import runtime
from tvm.runtime import container
from tvm.relay import transform
from tvm.contrib import util
from tvm.relay.op.annotation import compiler_begin, compiler_end
......@@ -305,10 +306,8 @@ def test_extern_ccompiler_default_ops():
func = relay.Function([x0, y0], add)
func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func = func.with_attr("Compiler",
tvm.tir.StringImm("ccompiler"))
func = func.with_attr("ExternalSymbol",
tvm.tir.StringImm("ccompiler_0"))
func = func.with_attr("Compiler", tvm.tir.StringImm("ccompiler"))
func = func.with_attr("global_symbol", container.String("ccompiler_0"))
glb_0 = relay.GlobalVar("ccompiler_0")
mod[glb_0] = func
add_call = relay.Call(glb_0, [x, y])
......@@ -319,7 +318,7 @@ def test_extern_ccompiler_default_ops():
concat = relay.concatenate([log, exp], axis=0)
fused_func = relay.Function([p0], concat)
fused_func = fused_func.with_attr("Primitive",
tvm.tir.IntImm("int32", 1))
tvm.tir.IntImm("int32", 1))
fused_call = relay.Call(fused_func, [add_call])
main = relay.Function([x, y], fused_call)
mod["main"] = main
......@@ -393,8 +392,7 @@ def test_extern_dnnl():
func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func = func.with_attr("Compiler", tvm.tir.StringImm("dnnl"))
func = func.with_attr("ExternalSymbol",
tvm.tir.StringImm("dnnl_0"))
func = func.with_attr("global_symbol", container.String("dnnl_0"))
glb_var = relay.GlobalVar("dnnl_0")
mod = tvm.IRModule()
mod[glb_var] = func
......@@ -520,8 +518,8 @@ def test_function_lifting():
func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Compiler",
tvm.tir.StringImm("test_compiler"))
func0 = func0.with_attr("ExternalSymbol",
tvm.tir.StringImm("test_compiler_0"))
func0 = func0.with_attr("global_symbol",
container.String("test_compiler_0"))
gv0 = relay.GlobalVar("test_compiler_0")
mod[gv0] = func0
......@@ -539,8 +537,8 @@ def test_function_lifting():
func1 = func1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func1 = func1.with_attr("Compiler",
tvm.tir.StringImm("test_compiler"))
func1 = func1.with_attr("ExternalSymbol",
tvm.tir.StringImm("test_compiler_1"))
func1 = func1.with_attr("global_symbol",
container.String("test_compiler_1"))
gv1 = relay.GlobalVar("test_compiler_1")
mod[gv1] = func1
......@@ -613,8 +611,8 @@ def test_function_lifting_inline():
func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Compiler",
tvm.tir.StringImm("test_compiler"))
func0 = func0.with_attr("ExternalSymbol",
tvm.tir.StringImm("test_compiler_0"))
func0 = func0.with_attr("global_symbol",
container.String("test_compiler_0"))
# main function
data = relay.var("data", relay.TensorType((1, 16, 224, 224), "float32"))
......@@ -649,8 +647,7 @@ def test_constant_propagation():
func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func = func.with_attr("Compiler", tvm.tir.StringImm("ccompiler"))
func = func.with_attr("ExternalSymbol",
tvm.tir.StringImm("ccompiler_0"))
func = func.with_attr("global_symbol", container.String("ccompiler_0"))
glb_0 = relay.GlobalVar("ccompiler_0")
mod[glb_0] = func
add_call = relay.Call(glb_0, [y])
......@@ -751,8 +748,8 @@ def test_multiple_outputs():
func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Compiler",
tvm.tir.StringImm("test_target"))
func0 = func0.with_attr("ExternalSymbol",
tvm.tir.StringImm("test_target_2"))
func0 = func0.with_attr("global_symbol",
container.String("test_target_2"))
gv0 = relay.GlobalVar("test_target_2")
mod[gv0] = func0
......@@ -819,8 +816,8 @@ def test_mixed_single_multiple_outputs():
func1 = func1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func1 = func1.with_attr("Compiler",
tvm.tir.StringImm("test_target"))
func1 = func1.with_attr("ExternalSymbol",
tvm.tir.StringImm("test_target_1"))
func1 = func1.with_attr("global_symbol",
container.String("test_target_1"))
gv1 = relay.GlobalVar("test_target_1")
mod[gv1] = func1
......@@ -834,8 +831,8 @@ def test_mixed_single_multiple_outputs():
func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Compiler",
tvm.tir.StringImm("test_target"))
func0 = func0.with_attr("ExternalSymbol",
tvm.tir.StringImm("test_target_0"))
func0 = func0.with_attr("global_symbol",
container.String("test_target_0"))
gv0 = relay.GlobalVar("test_target_0")
mod[gv0] = func0
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment