Unverified Commit 3aa103e7 by Tianqi Chen Committed by GitHub

[IR] Initial stab at std::string->String upgrade (#5438)

parent 9bbf58ab
...@@ -40,7 +40,7 @@ class SourceName; ...@@ -40,7 +40,7 @@ class SourceName;
class SourceNameNode : public Object { class SourceNameNode : public Object {
public: public:
/*! \brief The source name. */ /*! \brief The source name. */
std::string name; String name;
// override attr visitor // override attr visitor
void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); } void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); }
...@@ -64,7 +64,7 @@ class SourceName : public ObjectRef { ...@@ -64,7 +64,7 @@ class SourceName : public ObjectRef {
* \param name Name of the operator. * \param name Name of the operator.
* \return SourceName valid throughout program lifetime. * \return SourceName valid throughout program lifetime.
*/ */
TVM_DLL static SourceName Get(const std::string& name); TVM_DLL static SourceName Get(const String& name);
TVM_DEFINE_OBJECT_REF_METHODS(SourceName, ObjectRef, SourceNameNode); TVM_DEFINE_OBJECT_REF_METHODS(SourceName, ObjectRef, SourceNameNode);
}; };
......
...@@ -227,7 +227,7 @@ class TypeVarNode : public TypeNode { ...@@ -227,7 +227,7 @@ class TypeVarNode : public TypeNode {
* this only acts as a hint to the user, * this only acts as a hint to the user,
* and is not used for equality. * and is not used for equality.
*/ */
std::string name_hint; String name_hint;
/*! \brief The kind of type parameter */ /*! \brief The kind of type parameter */
TypeKind kind; TypeKind kind;
...@@ -263,7 +263,7 @@ class TypeVar : public Type { ...@@ -263,7 +263,7 @@ class TypeVar : public Type {
* \param name_hint The name of the type var. * \param name_hint The name of the type var.
* \param kind The kind of the type var. * \param kind The kind of the type var.
*/ */
TVM_DLL TypeVar(std::string name_hint, TypeKind kind); TVM_DLL TypeVar(String name_hint, TypeKind kind);
TVM_DEFINE_OBJECT_REF_METHODS(TypeVar, Type, TypeVarNode); TVM_DEFINE_OBJECT_REF_METHODS(TypeVar, Type, TypeVarNode);
}; };
......
...@@ -16,6 +16,9 @@ ...@@ -16,6 +16,9 @@
# under the License. # under the License.
"""Tool to upgrade json from historical versions.""" """Tool to upgrade json from historical versions."""
import json import json
import tvm.ir
import tvm.runtime
def create_updater(node_map, from_ver, to_ver): def create_updater(node_map, from_ver, to_ver):
"""Create an updater to update json loaded data. """Create an updater to update json loaded data.
...@@ -41,8 +44,12 @@ def create_updater(node_map, from_ver, to_ver): ...@@ -41,8 +44,12 @@ def create_updater(node_map, from_ver, to_ver):
nodes = data["nodes"] nodes = data["nodes"]
for idx, item in enumerate(nodes): for idx, item in enumerate(nodes):
f = node_map.get(item["type_key"], None) f = node_map.get(item["type_key"], None)
if f: if isinstance(f, list):
nodes[idx] = f(item, nodes) for fpass in f:
item = fpass(item, nodes)
elif f:
item = f(item, nodes)
nodes[idx] = item
data["attrs"]["tvm_version"] = to_ver data["attrs"]["tvm_version"] = to_ver
return data return data
return _updater return _updater
...@@ -84,12 +91,26 @@ def create_updater_06_to_07(): ...@@ -84,12 +91,26 @@ def create_updater_06_to_07():
del item["global_key"] del item["global_key"]
return item return item
def _update_from_std_str(key):
def _convert(item, nodes):
str_val = item["attrs"][key]
jdata = json.loads(tvm.ir.save_json(tvm.runtime.String(str_val)))
root_idx = jdata["root"]
val = jdata["nodes"][root_idx]
sidx = len(nodes)
nodes.append(val)
item["attrs"][key] = '%d' % sidx
return item
return _convert
node_map = { node_map = {
# Base IR # Base IR
"SourceName": _update_global_key, "SourceName": _update_global_key,
"EnvFunc": _update_global_key, "EnvFunc": _update_global_key,
"relay.Op": _update_global_key, "relay.Op": _update_global_key,
"relay.TypeVar": _ftype_var, "relay.TypeVar": [_ftype_var, _update_from_std_str("name_hint")],
"relay.GlobalTypeVar": _ftype_var, "relay.GlobalTypeVar": _ftype_var,
"relay.Type": _rename("Type"), "relay.Type": _rename("Type"),
"relay.TupleType": _rename("TupleType"), "relay.TupleType": _rename("TupleType"),
......
...@@ -25,10 +25,10 @@ ...@@ -25,10 +25,10 @@
namespace tvm { namespace tvm {
ObjectPtr<Object> GetSourceNameNode(const std::string& name) { ObjectPtr<Object> GetSourceNameNode(const String& name) {
// always return pointer as the reference can change as map re-allocate. // always return pointer as the reference can change as map re-allocate.
// or use another level of indirection by creating a unique_ptr // or use another level of indirection by creating a unique_ptr
static std::unordered_map<std::string, ObjectPtr<SourceNameNode> > source_map; static std::unordered_map<String, ObjectPtr<SourceNameNode> > source_map;
auto sn = source_map.find(name); auto sn = source_map.find(name);
if (sn == source_map.end()) { if (sn == source_map.end()) {
...@@ -41,7 +41,11 @@ ObjectPtr<Object> GetSourceNameNode(const std::string& name) { ...@@ -41,7 +41,11 @@ ObjectPtr<Object> GetSourceNameNode(const std::string& name) {
} }
} }
SourceName SourceName::Get(const std::string& name) { ObjectPtr<Object> GetSourceNameNodeByStr(const std::string& name) {
return GetSourceNameNode(name);
}
SourceName SourceName::Get(const String& name) {
return SourceName(GetSourceNameNode(name)); return SourceName(GetSourceNameNode(name));
} }
...@@ -55,10 +59,10 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) ...@@ -55,10 +59,10 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
}); });
TVM_REGISTER_NODE_TYPE(SourceNameNode) TVM_REGISTER_NODE_TYPE(SourceNameNode)
.set_creator(GetSourceNameNode) .set_creator(GetSourceNameNodeByStr)
.set_repr_bytes([](const Object* n) { .set_repr_bytes([](const Object* n) -> std::string {
return static_cast<const SourceNameNode*>(n)->name; return static_cast<const SourceNameNode*>(n)->name;
}); });
Span SpanNode::make(SourceName source, int lineno, int col_offset) { Span SpanNode::make(SourceName source, int lineno, int col_offset) {
auto n = make_object<SpanNode>(); auto n = make_object<SpanNode>();
......
...@@ -66,7 +66,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) ...@@ -66,7 +66,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
}); });
TypeVar::TypeVar(std::string name, TypeKind kind) { TypeVar::TypeVar(String name, TypeKind kind) {
ObjectPtr<TypeVarNode> n = make_object<TypeVarNode>(); ObjectPtr<TypeVarNode> n = make_object<TypeVarNode>();
n->name_hint = std::move(name); n->name_hint = std::move(name);
n->kind = std::move(kind); n->kind = std::move(kind);
...@@ -76,7 +76,7 @@ TypeVar::TypeVar(std::string name, TypeKind kind) { ...@@ -76,7 +76,7 @@ TypeVar::TypeVar(std::string name, TypeKind kind) {
TVM_REGISTER_NODE_TYPE(TypeVarNode); TVM_REGISTER_NODE_TYPE(TypeVarNode);
TVM_REGISTER_GLOBAL("ir.TypeVar") TVM_REGISTER_GLOBAL("ir.TypeVar")
.set_body_typed([](std::string name, int kind) { .set_body_typed([](String name, int kind) {
return TypeVar(name, static_cast<TypeKind>(kind)); return TypeVar(name, static_cast<TypeKind>(kind));
}); });
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment