Unverified Commit 502cf264 by Zhi Committed by GitHub

[Refactor] move vm.py under runtime and adt to runtime.container.py (#4855)

parent 4fce5137
......@@ -17,7 +17,7 @@
"""Container data structures used in TVM DSL."""
import tvm._ffi
from tvm.runtime import Object, ObjectTypes
from tvm.runtime import Object
from tvm.runtime.container import getitem_helper
from tvm.runtime import _ffi_node_api
from . import _api_internal
......@@ -104,56 +104,3 @@ class LoweredFunc(Object):
MixedFunc = 0
HostFunc = 1
DeviceFunc = 2
@tvm._ffi.register_object("vm.ADT")
class ADT(Object):
"""Algebatic data type(ADT) object.
Parameters
----------
tag : int
The tag of ADT.
fields : list[Object] or tuple[Object]
The source tuple.
"""
def __init__(self, tag, fields):
for f in fields:
assert isinstance(f, ObjectTypes), "Expect object or " \
"tvm NDArray type, but received : {0}".format(type(f))
self.__init_handle_by_constructor__(_ADT, tag, *fields)
@property
def tag(self):
return _GetADTTag(self)
def __getitem__(self, idx):
return getitem_helper(
self, _GetADTFields, len(self), idx)
def __len__(self):
return _GetADTSize(self)
def tuple_object(fields=None):
"""Create a ADT object from source tuple.
Parameters
----------
fields : list[Object] or tuple[Object]
The source tuple.
Returns
-------
ret : ADT
The created object.
"""
fields = fields if fields else []
for f in fields:
assert isinstance(f, ObjectTypes), "Expect object or tvm " \
"NDArray type, but received : {0}".format(type(f))
return _Tuple(*fields)
tvm._ffi._init_api("tvm.container")
......@@ -37,7 +37,6 @@ from . import debug
from . import param_dict
from . import feature
from .backend import vm
from .backend import profiler_vm
# Root operators
from .op import Op
......
......@@ -20,7 +20,7 @@ from __future__ import absolute_import
import numpy as np
from tvm import container
from tvm.runtime import container
from . import _backend
from .. import _make, analysis, transform
from .. import module
......
......@@ -32,15 +32,15 @@ OUTPUT_VAR_NAME = '_py_out'
# import numpy
# import tvm
# from tvm import relay
# from tvm import import container as _container
# from tvm import nd
# from tvm.runtime import import container as _container
# from tvm.relay.backend.interpreter import RefValue, ConstructorValue
PROLOGUE = [
ast.Import([alias('numpy', None)]),
ast.Import([alias('tvm', None)]),
ast.ImportFrom('tvm', [alias('relay', None)], 0),
ast.ImportFrom('tvm', [alias('nd', None)], 0),
ast.ImportFrom('tvm', [alias('container', '_container')],
ast.ImportFrom('tvm.runtime', [alias('container', '_container')],
0),
ast.ImportFrom('tvm.relay.backend.interpreter',
[alias('RefValue', None),
......
......@@ -15,6 +15,9 @@
# specific language governing permissions and limitations
# under the License.
"""Runtime container structures."""
import tvm._ffi
from tvm.runtime import Object, ObjectTypes
def getitem_helper(obj, elem_getter, length, idx):
"""Helper function to implement a pythonic getitem function.
......@@ -54,3 +57,56 @@ def getitem_helper(obj, elem_getter, length, idx):
if idx < 0:
idx += length
return elem_getter(obj, idx)
@tvm._ffi.register_object("vm.ADT")
class ADT(Object):
"""Algebatic data type(ADT) object.
Parameters
----------
tag : int
The tag of ADT.
fields : list[Object] or tuple[Object]
The source tuple.
"""
def __init__(self, tag, fields):
for f in fields:
assert isinstance(f, ObjectTypes), "Expect object or " \
"tvm NDArray type, but received : {0}".format(type(f))
self.__init_handle_by_constructor__(_ADT, tag, *fields)
@property
def tag(self):
return _GetADTTag(self)
def __getitem__(self, idx):
return getitem_helper(
self, _GetADTFields, len(self), idx)
def __len__(self):
return _GetADTSize(self)
def tuple_object(fields=None):
"""Create a ADT object from source tuple.
Parameters
----------
fields : list[Object] or tuple[Object]
The source tuple.
Returns
-------
ret : ADT
The created object.
"""
fields = fields if fields else []
for f in fields:
assert isinstance(f, ObjectTypes), "Expect object or tvm " \
"NDArray type, but received : {0}".format(type(f))
return _Tuple(*fields)
tvm._ffi._init_api("tvm.runtime.container")
......@@ -20,18 +20,19 @@ The Relay Virtual Machine profiler.
Provides extra APIs for profiling vm execution.
"""
from . import vm, _vm
from tvm.runtime import _ffi_api
from . import vm
def enabled():
"""Whether vm profiler is enabled."""
return hasattr(_vm, "_VirtualMachineDebug")
return hasattr(_ffi_api, "_VirtualMachineDebug")
class VirtualMachineProfiler(vm.VirtualMachine):
"""Relay profile VM runtime."""
def __init__(self, mod):
super(VirtualMachineProfiler, self).__init__(mod)
m = mod.module if isinstance(mod, vm.Executable) else mod
self.mod = _vm._VirtualMachineDebug(m)
self.mod = _ffi_api._VirtualMachineDebug(m)
self._init = self.mod["init"]
self._invoke = self.mod["invoke"]
self._get_stat = self.mod["get_stat"]
......
......@@ -32,14 +32,14 @@ namespace runtime {
using namespace vm;
TVM_REGISTER_GLOBAL("container._GetADTTag")
TVM_REGISTER_GLOBAL("runtime.container._GetADTTag")
.set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectRef obj = args[0];
const auto& adt = Downcast<ADT>(obj);
*rv = static_cast<int64_t>(adt.tag());
});
TVM_REGISTER_GLOBAL("container._GetADTSize")
TVM_REGISTER_GLOBAL("runtime.container._GetADTSize")
.set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectRef obj = args[0];
const auto& adt = Downcast<ADT>(obj);
......@@ -47,7 +47,7 @@ TVM_REGISTER_GLOBAL("container._GetADTSize")
});
TVM_REGISTER_GLOBAL("container._GetADTFields")
TVM_REGISTER_GLOBAL("runtime.container._GetADTFields")
.set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectRef obj = args[0];
int idx = args[1];
......@@ -56,7 +56,7 @@ TVM_REGISTER_GLOBAL("container._GetADTFields")
*rv = adt[idx];
});
TVM_REGISTER_GLOBAL("container._Tuple")
TVM_REGISTER_GLOBAL("runtime.container._Tuple")
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::vector<ObjectRef> fields;
for (auto i = 0; i < args.size(); ++i) {
......@@ -65,7 +65,7 @@ TVM_REGISTER_GLOBAL("container._Tuple")
*rv = ADT::Tuple(fields);
});
TVM_REGISTER_GLOBAL("container._ADT")
TVM_REGISTER_GLOBAL("runtime.container._ADT")
.set_body([](TVMArgs args, TVMRetValue* rv) {
int itag = args[0];
size_t tag = static_cast<size_t>(itag);
......
......@@ -738,7 +738,7 @@ void Executable::LoadCodeSection(dmlc::Stream* strm) {
}
}
TVM_REGISTER_GLOBAL("relay._vm.GetNumOfGlobals")
TVM_REGISTER_GLOBAL("runtime.GetNumOfGlobals")
.set_body([](TVMArgs args, TVMRetValue* rv) {
runtime::Module mod = args[0];
const auto* exec = dynamic_cast<Executable*>(mod.operator->());
......@@ -746,7 +746,7 @@ TVM_REGISTER_GLOBAL("relay._vm.GetNumOfGlobals")
*rv = static_cast<int>(exec->global_map.size());
});
TVM_REGISTER_GLOBAL("relay._vm.GetGlobalFields")
TVM_REGISTER_GLOBAL("runtime.GetGlobalFields")
.set_body([](TVMArgs args, TVMRetValue* rv) {
runtime::Module mod = args[0];
const auto* exec = dynamic_cast<Executable*>(mod.operator->());
......@@ -763,7 +763,7 @@ TVM_REGISTER_GLOBAL("relay._vm.GetGlobalFields")
*rv = globals[idx].first;
});
TVM_REGISTER_GLOBAL("relay._vm.GetNumOfPrimitives")
TVM_REGISTER_GLOBAL("runtime.GetNumOfPrimitives")
.set_body([](TVMArgs args, TVMRetValue* rv) {
runtime::Module mod = args[0];
const auto* exec = dynamic_cast<Executable*>(mod.operator->());
......@@ -772,7 +772,7 @@ TVM_REGISTER_GLOBAL("relay._vm.GetNumOfPrimitives")
});
TVM_REGISTER_GLOBAL("relay._vm.GetPrimitiveFields")
TVM_REGISTER_GLOBAL("runtime.GetPrimitiveFields")
.set_body([](TVMArgs args, TVMRetValue* rv) {
runtime::Module mod = args[0];
const auto* exec = dynamic_cast<Executable*>(mod.operator->());
......@@ -789,7 +789,7 @@ TVM_REGISTER_GLOBAL("relay._vm.GetPrimitiveFields")
}
});
TVM_REGISTER_GLOBAL("relay._vm.Load_Executable")
TVM_REGISTER_GLOBAL("runtime.Load_Executable")
.set_body_typed([](
std::string code,
runtime::Module lib) {
......
......@@ -133,7 +133,7 @@ runtime::Module CreateVirtualMachineDebug(const Executable* exec) {
return runtime::Module(vm);
}
TVM_REGISTER_GLOBAL("relay._vm._VirtualMachineDebug")
TVM_REGISTER_GLOBAL("runtime._VirtualMachineDebug")
.set_body([](TVMArgs args, TVMRetValue* rv) {
runtime::Module mod = args[0];
const auto* exec = dynamic_cast<Executable*>(mod.operator->());
......
......@@ -1057,7 +1057,7 @@ runtime::Module CreateVirtualMachine(const Executable* exec) {
return runtime::Module(vm);
}
TVM_REGISTER_GLOBAL("relay._vm._VirtualMachine")
TVM_REGISTER_GLOBAL("runtime._VirtualMachine")
.set_body([](TVMArgs args, TVMRetValue* rv) {
runtime::Module mod = args[0];
const auto* exec = dynamic_cast<Executable*>(mod.operator->());
......
......@@ -62,7 +62,7 @@ tf_dtypes = {
def vmobj_to_list(o):
if isinstance(o, tvm.nd.NDArray):
return [o.asnumpy().tolist()]
elif isinstance(o, tvm.container.ADT):
elif isinstance(o, tvm.runtime.container.ADT):
result = []
for f in o:
result.extend(vmobj_to_list(f))
......
......@@ -19,7 +19,9 @@ import numpy as np
import tvm
from tvm.contrib import graph_runtime
from tvm import relay, container
from tvm import relay
from tvm.runtime import container
from tvm.runtime import vm as vm_rt
from tvm.relay import testing
from tvm.relay import vm
......@@ -58,7 +60,7 @@ def benchmark_execution(mod,
number=2, repeat=20):
with relay.build_config(opt_level=3):
exe = vm.compile(mod, target, params=params)
rly_vm = vm.VirtualMachine(exe)
rly_vm = vm_rt.VirtualMachine(exe)
rly_vm.init(ctx)
result = rly_vm.run(data)
......
......@@ -117,7 +117,7 @@ def tree_to_dict(t):
def vmobj_to_list(o, dtype="float32"):
if isinstance(o, tvm.nd.NDArray):
return [o.asnumpy().tolist()]
elif isinstance(o, tvm.container.ADT):
elif isinstance(o, tvm.runtime.container.ADT):
if len(o) == 0:
tensor_nil = p.get_var("tensor_nil", dtype=dtype)
if tensor_nil.tag == o.tag:
......
......@@ -18,7 +18,8 @@ import numpy as np
import tvm
import tvm.testing
from tvm import nd
from tvm import relay, container
from tvm import relay
from tvm.runtime import container
from tvm.relay.backend.interpreter import RefValue, ConstructorValue
from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay import testing, create_executor
......
......@@ -18,12 +18,12 @@
import os
import sys
import numpy as np
import pytest
import tvm
import tvm.relay.testing
import tvm.relay.transform
from tvm import relay
from tvm import runtime
from tvm.contrib import util
def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
......@@ -52,8 +52,8 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
exe = relay.vm.compile(mod, target=target)
code, lib = exe.save()
lib = update_lib(lib)
exe = relay.vm.Executable.load_exec(code, lib)
vm = relay.vm.VirtualMachine(exe)
exe = runtime.vm.Executable.load_exec(code, lib)
vm = runtime.vm.VirtualMachine(exe)
vm.init(ctx)
out = vm.run(**map_inputs)
tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
......
......@@ -24,6 +24,7 @@ import tvm
import tvm.relay.testing
import tvm.relay.transform as transform
from tvm import relay
from tvm import runtime
from tvm.contrib import util
from tvm.relay.annotation import compiler_begin, compiler_end
from tvm.relay.expr_functor import ExprMutator
......@@ -182,7 +183,7 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
lib_name = 'lib.so'
lib_path = tmp_path.relpath(lib_name)
lib.export_library(lib_path, fcompile=False, **kwargs)
lib = tvm.runtime.load_module(lib_path)
lib = runtime.load_module(lib_path)
return lib
......@@ -191,8 +192,8 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
exe = relay.vm.compile(mod, target=target, params=params)
code, lib = exe.save()
lib = update_lib(lib)
exe = relay.vm.Executable.load_exec(code, lib)
vm = relay.vm.VirtualMachine(exe)
exe = runtime.vm.Executable.load_exec(code, lib)
vm = runtime.vm.VirtualMachine(exe)
vm.init(ctx)
out = vm.run(**map_inputs)
tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
......
......@@ -19,7 +19,7 @@ import tvm
from tvm import relay
from tvm.relay.testing import to_python, run_as_python
from tvm.relay.prelude import Prelude
from tvm.container import ADT
from tvm.runtime.container import ADT
from tvm.relay.backend.interpreter import RefValue, ConstructorValue
# helper: uses a dummy let binding to sequence a list
......
......@@ -14,16 +14,16 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
import numpy as np
import pytest
import tvm
import numpy as np
from tvm import runtime
from tvm import relay
from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay.testing.config import ctx_list
from tvm.relay.prelude import Prelude
from tvm.relay import testing
import pytest
def check_result(args, expected_result, mod=None):
"""
......@@ -52,14 +52,14 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"):
assert isinstance(f, relay.Module), "expected expression or module"
mod = f
exe = relay.vm.compile(mod, target)
vm = relay.vm.VirtualMachine(exe)
vm = runtime.vm.VirtualMachine(exe)
vm.init(ctx)
return vm.invoke("main", *args)
def vmobj_to_list(o):
if isinstance(o, tvm.nd.NDArray):
return [o.asnumpy().tolist()]
elif isinstance(o, tvm.container.ADT):
elif isinstance(o, tvm.runtime.container.ADT):
result = []
for f in o:
result.extend(vmobj_to_list(f))
......@@ -573,7 +573,7 @@ def test_add_op_broadcast():
def test_vm_optimize():
mod, params = testing.resnet.get_workload(batch_size=1, num_layers=18)
comp = relay.backend.vm.VMCompiler()
comp = relay.vm.VMCompiler()
opt_mod, _ = comp.optimize(mod, "llvm", params)
if __name__ == "__main__":
......
......@@ -19,9 +19,10 @@
import numpy as np
import tvm
from tvm.runtime import vm as _vm
from tvm.relay import vm as rly_vm
from tvm import relay
from tvm.relay.module import Module as rly_module
from tvm.relay import vm as _vm
from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay.prelude import Prelude
from tvm.contrib import util
......@@ -31,11 +32,11 @@ def create_exec(f, target="llvm", params=None):
if isinstance(f, relay.Expr):
mod = relay.Module()
mod["main"] = f
executable = _vm.compile(mod, target=target, params=params)
executable = rly_vm.compile(mod, target=target, params=params)
return executable
else:
assert isinstance(f, relay.Module), "expected mod as relay.Module"
executable = _vm.compile(f, target=target, params=params)
executable = rly_vm.compile(f, target=target, params=params)
return executable
......
......@@ -18,7 +18,7 @@
import numpy as np
import tvm
from tvm import nd, relay
from tvm import container as _container
from tvm.runtime import container as _container
def test_adt_constructor():
......
......@@ -14,11 +14,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
import tvm
import numpy as np
import pytest
import tvm
from tvm.runtime import profiler_vm
from tvm import relay
from tvm.relay.testing import resnet
......@@ -26,10 +25,10 @@ def test_basic():
mod, params = resnet.get_workload()
target = 'llvm'
ctx = tvm.cpu()
if not relay.profiler_vm.enabled():
if not profiler_vm.enabled():
return
exe = relay.vm.compile(mod, target, params=params)
vm = relay.profiler_vm.VirtualMachineProfiler(exe)
vm = profiler_vm.VirtualMachineProfiler(exe)
vm.init(ctx)
data = np.random.rand(1, 3, 224, 224).astype('float32')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment