Unverified Commit 3aa103e7 by Tianqi Chen Committed by GitHub

[IR] Initial stab at std::string->String upgrade (#5438)

parent 9bbf58ab
......@@ -40,7 +40,7 @@ class SourceName;
class SourceNameNode : public Object {
public:
/*! \brief The source name. */
std::string name;
String name;
// override attr visitor
void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); }
......@@ -64,7 +64,7 @@ class SourceName : public ObjectRef {
* \param name Name of the operator.
* \return SourceName valid throughout program lifetime.
*/
TVM_DLL static SourceName Get(const std::string& name);
TVM_DLL static SourceName Get(const String& name);
TVM_DEFINE_OBJECT_REF_METHODS(SourceName, ObjectRef, SourceNameNode);
};
......
......@@ -227,7 +227,7 @@ class TypeVarNode : public TypeNode {
* this only acts as a hint to the user,
* and is not used for equality.
*/
std::string name_hint;
String name_hint;
/*! \brief The kind of type parameter */
TypeKind kind;
......@@ -263,7 +263,7 @@ class TypeVar : public Type {
* \param name_hint The name of the type var.
* \param kind The kind of the type var.
*/
TVM_DLL TypeVar(std::string name_hint, TypeKind kind);
TVM_DLL TypeVar(String name_hint, TypeKind kind);
TVM_DEFINE_OBJECT_REF_METHODS(TypeVar, Type, TypeVarNode);
};
......
......@@ -16,6 +16,9 @@
# under the License.
"""Tool to upgrade json from historical versions."""
import json
import tvm.ir
import tvm.runtime
def create_updater(node_map, from_ver, to_ver):
"""Create an updater to update json loaded data.
......@@ -41,8 +44,12 @@ def create_updater(node_map, from_ver, to_ver):
nodes = data["nodes"]
for idx, item in enumerate(nodes):
f = node_map.get(item["type_key"], None)
if f:
nodes[idx] = f(item, nodes)
if isinstance(f, list):
for fpass in f:
item = fpass(item, nodes)
elif f:
item = f(item, nodes)
nodes[idx] = item
data["attrs"]["tvm_version"] = to_ver
return data
return _updater
......@@ -84,12 +91,26 @@ def create_updater_06_to_07():
del item["global_key"]
return item
def _update_from_std_str(key):
def _convert(item, nodes):
str_val = item["attrs"][key]
jdata = json.loads(tvm.ir.save_json(tvm.runtime.String(str_val)))
root_idx = jdata["root"]
val = jdata["nodes"][root_idx]
sidx = len(nodes)
nodes.append(val)
item["attrs"][key] = '%d' % sidx
return item
return _convert
node_map = {
# Base IR
"SourceName": _update_global_key,
"EnvFunc": _update_global_key,
"relay.Op": _update_global_key,
"relay.TypeVar": _ftype_var,
"relay.TypeVar": [_ftype_var, _update_from_std_str("name_hint")],
"relay.GlobalTypeVar": _ftype_var,
"relay.Type": _rename("Type"),
"relay.TupleType": _rename("TupleType"),
......
......@@ -25,10 +25,10 @@
namespace tvm {
ObjectPtr<Object> GetSourceNameNode(const std::string& name) {
ObjectPtr<Object> GetSourceNameNode(const String& name) {
// always return pointer as the reference can change as map re-allocate.
// or use another level of indirection by creating a unique_ptr
static std::unordered_map<std::string, ObjectPtr<SourceNameNode> > source_map;
static std::unordered_map<String, ObjectPtr<SourceNameNode> > source_map;
auto sn = source_map.find(name);
if (sn == source_map.end()) {
......@@ -41,7 +41,11 @@ ObjectPtr<Object> GetSourceNameNode(const std::string& name) {
}
}
SourceName SourceName::Get(const std::string& name) {
ObjectPtr<Object> GetSourceNameNodeByStr(const std::string& name) {
return GetSourceNameNode(name);
}
SourceName SourceName::Get(const String& name) {
return SourceName(GetSourceNameNode(name));
}
......@@ -55,10 +59,10 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
});
TVM_REGISTER_NODE_TYPE(SourceNameNode)
.set_creator(GetSourceNameNode)
.set_repr_bytes([](const Object* n) {
.set_creator(GetSourceNameNodeByStr)
.set_repr_bytes([](const Object* n) -> std::string {
return static_cast<const SourceNameNode*>(n)->name;
});
});
Span SpanNode::make(SourceName source, int lineno, int col_offset) {
auto n = make_object<SpanNode>();
......
......@@ -66,7 +66,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
});
TypeVar::TypeVar(std::string name, TypeKind kind) {
TypeVar::TypeVar(String name, TypeKind kind) {
ObjectPtr<TypeVarNode> n = make_object<TypeVarNode>();
n->name_hint = std::move(name);
n->kind = std::move(kind);
......@@ -76,7 +76,7 @@ TypeVar::TypeVar(std::string name, TypeKind kind) {
TVM_REGISTER_NODE_TYPE(TypeVarNode);
TVM_REGISTER_GLOBAL("ir.TypeVar")
.set_body_typed([](std::string name, int kind) {
.set_body_typed([](String name, int kind) {
return TypeVar(name, static_cast<TypeKind>(kind));
});
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment