Unverified Commit 67b97e5a by Tianqi Chen Committed by GitHub

[TOOLS] JSON upgrader to upgrade serialized json. (#4730)

During Unified IR refactor we will change the structure of IRs.
This will cause certain historical modules stored via json no longer
able to be loaded by the current version.

This PR introduces a backward compatible layer to try its best effort
to upgrade json from previous version(this case 0.6) to the current version.
We mainly aim to support update of high-level ir(relay).
parent a5bb789a
...@@ -20,7 +20,7 @@ from __future__ import absolute_import as _abs ...@@ -20,7 +20,7 @@ from __future__ import absolute_import as _abs
from numbers import Integral as _Integral from numbers import Integral as _Integral
from ._ffi.base import string_types from ._ffi.base import string_types, TVMError
from ._ffi.object import register_object, Object from ._ffi.object import register_object, Object
from ._ffi.object import convert_to_object as _convert_to_object from ._ffi.object import convert_to_object as _convert_to_object
from ._ffi.object_generic import _scalar_type_inference from ._ffi.object_generic import _scalar_type_inference
...@@ -35,6 +35,7 @@ from . import tensor as _tensor ...@@ -35,6 +35,7 @@ from . import tensor as _tensor
from . import schedule as _schedule from . import schedule as _schedule
from . import container as _container from . import container as _container
from . import tag as _tag from . import tag as _tag
from . import json_compact
int8 = "int8" int8 = "int8"
int32 = "int32" int32 = "int32"
...@@ -154,7 +155,12 @@ def load_json(json_str): ...@@ -154,7 +155,12 @@ def load_json(json_str):
node : Object node : Object
The loaded tvm node. The loaded tvm node.
""" """
return _api_internal._load_json(json_str)
try:
return _api_internal._load_json(json_str)
except TVMError:
json_str = json_compact.upgrade_json(json_str)
return _api_internal._load_json(json_str)
def save_json(node): def save_json(node):
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tool to upgrade json from historical versions."""
import json
def create_updater(node_map, from_ver, to_ver):
"""Create an updater to update json loaded data.
Parameters
----------
node_map : Map[str, Function]
Map from type_key to updating function
from_ver : str
Prefix of version that we can accept,
to_ver : str
The target version.
Returns
-------
fupdater : function
The updater function
"""
def _updater(data):
assert data["attrs"]["tvm_version"].startswith(from_ver)
nodes = data["nodes"]
for idx, item in enumerate(nodes):
f = node_map.get(item["type_key"], None)
if f:
nodes[idx] = f(item, nodes)
data["attrs"]["tvm_version"] = to_ver
return data
return _updater
def create_updater_06_to_07():
"""Create an update to upgrade json from v0.6 to v0.7
Returns
-------
fupdater : function
The updater function
"""
def _ftype_var(item, nodes):
vindex = int(item["attrs"]["var"])
item["attrs"]["name_hint"] = nodes[vindex]["attrs"]["name"]
# set vindex to null
nodes[vindex]["type_key"] = ""
del item["attrs"]["var"]
return item
node_map = {
"relay.TypeVar": _ftype_var,
"relay.GlobalTypeVar": _ftype_var,
}
return create_updater(node_map, "0.6", "0.7")
def upgrade_json(json_str):
"""Update json from a historical version.
Parameters
----------
json_str : str
A historical json file.
Returns
-------
updated_json : str
The updated version.
"""
data = json.loads(json_str)
from_version = data["attrs"]["tvm_version"]
if from_version.startswith("0.6"):
data = create_updater_06_to_07()(data)
else:
raise ValueError("Cannot update from version %s" % from_version)
return json.dumps(data, indent=2)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import relay
import json
def test_type_var():
# type var in 0.6
nodes = [
{"type_key": ""},
{"type_key": "relay.TypeVar",
"attrs": {"kind": "0", "span": "0", "var": "2"}},
{"type_key": "Variable",
"attrs": {"dtype": "int32", "name": "in0"}},
]
data = {
"root" : 1,
"nodes": nodes,
"attrs": {"tvm_version": "0.6.0"},
"b64ndarrays": [],
}
tvar = tvm.load_json(json.dumps(data))
assert isinstance(tvar, relay.TypeVar)
assert tvar.name_hint == "in0"
nodes[1]["type_key"] = "relay.GlobalTypeVar"
tvar = tvm.load_json(json.dumps(data))
assert isinstance(tvar, relay.GlobalTypeVar)
assert tvar.name_hint == "in0"
if __name__ == "__main__":
test_type_var()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment