Unverified Commit 502cf264 by Zhi Committed by GitHub

[Refactor] move vm.py under runtime and adt to runtime.container.py (#4855)

parent 4fce5137
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
"""Container data structures used in TVM DSL.""" """Container data structures used in TVM DSL."""
import tvm._ffi import tvm._ffi
from tvm.runtime import Object, ObjectTypes from tvm.runtime import Object
from tvm.runtime.container import getitem_helper from tvm.runtime.container import getitem_helper
from tvm.runtime import _ffi_node_api from tvm.runtime import _ffi_node_api
from . import _api_internal from . import _api_internal
...@@ -104,56 +104,3 @@ class LoweredFunc(Object): ...@@ -104,56 +104,3 @@ class LoweredFunc(Object):
MixedFunc = 0 MixedFunc = 0
HostFunc = 1 HostFunc = 1
DeviceFunc = 2 DeviceFunc = 2
@tvm._ffi.register_object("vm.ADT")
class ADT(Object):
"""Algebatic data type(ADT) object.
Parameters
----------
tag : int
The tag of ADT.
fields : list[Object] or tuple[Object]
The source tuple.
"""
def __init__(self, tag, fields):
for f in fields:
assert isinstance(f, ObjectTypes), "Expect object or " \
"tvm NDArray type, but received : {0}".format(type(f))
self.__init_handle_by_constructor__(_ADT, tag, *fields)
@property
def tag(self):
return _GetADTTag(self)
def __getitem__(self, idx):
return getitem_helper(
self, _GetADTFields, len(self), idx)
def __len__(self):
return _GetADTSize(self)
def tuple_object(fields=None):
"""Create a ADT object from source tuple.
Parameters
----------
fields : list[Object] or tuple[Object]
The source tuple.
Returns
-------
ret : ADT
The created object.
"""
fields = fields if fields else []
for f in fields:
assert isinstance(f, ObjectTypes), "Expect object or tvm " \
"NDArray type, but received : {0}".format(type(f))
return _Tuple(*fields)
tvm._ffi._init_api("tvm.container")
...@@ -37,7 +37,6 @@ from . import debug ...@@ -37,7 +37,6 @@ from . import debug
from . import param_dict from . import param_dict
from . import feature from . import feature
from .backend import vm from .backend import vm
from .backend import profiler_vm
# Root operators # Root operators
from .op import Op from .op import Op
......
...@@ -20,7 +20,7 @@ from __future__ import absolute_import ...@@ -20,7 +20,7 @@ from __future__ import absolute_import
import numpy as np import numpy as np
from tvm import container from tvm.runtime import container
from . import _backend from . import _backend
from .. import _make, analysis, transform from .. import _make, analysis, transform
from .. import module from .. import module
......
...@@ -32,15 +32,15 @@ OUTPUT_VAR_NAME = '_py_out' ...@@ -32,15 +32,15 @@ OUTPUT_VAR_NAME = '_py_out'
# import numpy # import numpy
# import tvm # import tvm
# from tvm import relay # from tvm import relay
# from tvm import import container as _container
# from tvm import nd # from tvm import nd
# from tvm.runtime import import container as _container
# from tvm.relay.backend.interpreter import RefValue, ConstructorValue # from tvm.relay.backend.interpreter import RefValue, ConstructorValue
PROLOGUE = [ PROLOGUE = [
ast.Import([alias('numpy', None)]), ast.Import([alias('numpy', None)]),
ast.Import([alias('tvm', None)]), ast.Import([alias('tvm', None)]),
ast.ImportFrom('tvm', [alias('relay', None)], 0), ast.ImportFrom('tvm', [alias('relay', None)], 0),
ast.ImportFrom('tvm', [alias('nd', None)], 0), ast.ImportFrom('tvm', [alias('nd', None)], 0),
ast.ImportFrom('tvm', [alias('container', '_container')], ast.ImportFrom('tvm.runtime', [alias('container', '_container')],
0), 0),
ast.ImportFrom('tvm.relay.backend.interpreter', ast.ImportFrom('tvm.relay.backend.interpreter',
[alias('RefValue', None), [alias('RefValue', None),
......
...@@ -15,6 +15,9 @@ ...@@ -15,6 +15,9 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""Runtime container structures.""" """Runtime container structures."""
import tvm._ffi
from tvm.runtime import Object, ObjectTypes
def getitem_helper(obj, elem_getter, length, idx): def getitem_helper(obj, elem_getter, length, idx):
"""Helper function to implement a pythonic getitem function. """Helper function to implement a pythonic getitem function.
...@@ -54,3 +57,56 @@ def getitem_helper(obj, elem_getter, length, idx): ...@@ -54,3 +57,56 @@ def getitem_helper(obj, elem_getter, length, idx):
if idx < 0: if idx < 0:
idx += length idx += length
return elem_getter(obj, idx) return elem_getter(obj, idx)
@tvm._ffi.register_object("vm.ADT")
class ADT(Object):
"""Algebatic data type(ADT) object.
Parameters
----------
tag : int
The tag of ADT.
fields : list[Object] or tuple[Object]
The source tuple.
"""
def __init__(self, tag, fields):
for f in fields:
assert isinstance(f, ObjectTypes), "Expect object or " \
"tvm NDArray type, but received : {0}".format(type(f))
self.__init_handle_by_constructor__(_ADT, tag, *fields)
@property
def tag(self):
return _GetADTTag(self)
def __getitem__(self, idx):
return getitem_helper(
self, _GetADTFields, len(self), idx)
def __len__(self):
return _GetADTSize(self)
def tuple_object(fields=None):
"""Create a ADT object from source tuple.
Parameters
----------
fields : list[Object] or tuple[Object]
The source tuple.
Returns
-------
ret : ADT
The created object.
"""
fields = fields if fields else []
for f in fields:
assert isinstance(f, ObjectTypes), "Expect object or tvm " \
"NDArray type, but received : {0}".format(type(f))
return _Tuple(*fields)
tvm._ffi._init_api("tvm.runtime.container")
...@@ -20,18 +20,19 @@ The Relay Virtual Machine profiler. ...@@ -20,18 +20,19 @@ The Relay Virtual Machine profiler.
Provides extra APIs for profiling vm execution. Provides extra APIs for profiling vm execution.
""" """
from . import vm, _vm from tvm.runtime import _ffi_api
from . import vm
def enabled(): def enabled():
"""Whether vm profiler is enabled.""" """Whether vm profiler is enabled."""
return hasattr(_vm, "_VirtualMachineDebug") return hasattr(_ffi_api, "_VirtualMachineDebug")
class VirtualMachineProfiler(vm.VirtualMachine): class VirtualMachineProfiler(vm.VirtualMachine):
"""Relay profile VM runtime.""" """Relay profile VM runtime."""
def __init__(self, mod): def __init__(self, mod):
super(VirtualMachineProfiler, self).__init__(mod) super(VirtualMachineProfiler, self).__init__(mod)
m = mod.module if isinstance(mod, vm.Executable) else mod m = mod.module if isinstance(mod, vm.Executable) else mod
self.mod = _vm._VirtualMachineDebug(m) self.mod = _ffi_api._VirtualMachineDebug(m)
self._init = self.mod["init"] self._init = self.mod["init"]
self._invoke = self.mod["invoke"] self._invoke = self.mod["invoke"]
self._get_stat = self.mod["get_stat"] self._get_stat = self.mod["get_stat"]
......
...@@ -32,14 +32,14 @@ namespace runtime { ...@@ -32,14 +32,14 @@ namespace runtime {
using namespace vm; using namespace vm;
TVM_REGISTER_GLOBAL("container._GetADTTag") TVM_REGISTER_GLOBAL("runtime.container._GetADTTag")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectRef obj = args[0]; ObjectRef obj = args[0];
const auto& adt = Downcast<ADT>(obj); const auto& adt = Downcast<ADT>(obj);
*rv = static_cast<int64_t>(adt.tag()); *rv = static_cast<int64_t>(adt.tag());
}); });
TVM_REGISTER_GLOBAL("container._GetADTSize") TVM_REGISTER_GLOBAL("runtime.container._GetADTSize")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectRef obj = args[0]; ObjectRef obj = args[0];
const auto& adt = Downcast<ADT>(obj); const auto& adt = Downcast<ADT>(obj);
...@@ -47,7 +47,7 @@ TVM_REGISTER_GLOBAL("container._GetADTSize") ...@@ -47,7 +47,7 @@ TVM_REGISTER_GLOBAL("container._GetADTSize")
}); });
TVM_REGISTER_GLOBAL("container._GetADTFields") TVM_REGISTER_GLOBAL("runtime.container._GetADTFields")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectRef obj = args[0]; ObjectRef obj = args[0];
int idx = args[1]; int idx = args[1];
...@@ -56,7 +56,7 @@ TVM_REGISTER_GLOBAL("container._GetADTFields") ...@@ -56,7 +56,7 @@ TVM_REGISTER_GLOBAL("container._GetADTFields")
*rv = adt[idx]; *rv = adt[idx];
}); });
TVM_REGISTER_GLOBAL("container._Tuple") TVM_REGISTER_GLOBAL("runtime.container._Tuple")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
std::vector<ObjectRef> fields; std::vector<ObjectRef> fields;
for (auto i = 0; i < args.size(); ++i) { for (auto i = 0; i < args.size(); ++i) {
...@@ -65,7 +65,7 @@ TVM_REGISTER_GLOBAL("container._Tuple") ...@@ -65,7 +65,7 @@ TVM_REGISTER_GLOBAL("container._Tuple")
*rv = ADT::Tuple(fields); *rv = ADT::Tuple(fields);
}); });
TVM_REGISTER_GLOBAL("container._ADT") TVM_REGISTER_GLOBAL("runtime.container._ADT")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
int itag = args[0]; int itag = args[0];
size_t tag = static_cast<size_t>(itag); size_t tag = static_cast<size_t>(itag);
......
...@@ -738,7 +738,7 @@ void Executable::LoadCodeSection(dmlc::Stream* strm) { ...@@ -738,7 +738,7 @@ void Executable::LoadCodeSection(dmlc::Stream* strm) {
} }
} }
TVM_REGISTER_GLOBAL("relay._vm.GetNumOfGlobals") TVM_REGISTER_GLOBAL("runtime.GetNumOfGlobals")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
runtime::Module mod = args[0]; runtime::Module mod = args[0];
const auto* exec = dynamic_cast<Executable*>(mod.operator->()); const auto* exec = dynamic_cast<Executable*>(mod.operator->());
...@@ -746,7 +746,7 @@ TVM_REGISTER_GLOBAL("relay._vm.GetNumOfGlobals") ...@@ -746,7 +746,7 @@ TVM_REGISTER_GLOBAL("relay._vm.GetNumOfGlobals")
*rv = static_cast<int>(exec->global_map.size()); *rv = static_cast<int>(exec->global_map.size());
}); });
TVM_REGISTER_GLOBAL("relay._vm.GetGlobalFields") TVM_REGISTER_GLOBAL("runtime.GetGlobalFields")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
runtime::Module mod = args[0]; runtime::Module mod = args[0];
const auto* exec = dynamic_cast<Executable*>(mod.operator->()); const auto* exec = dynamic_cast<Executable*>(mod.operator->());
...@@ -763,7 +763,7 @@ TVM_REGISTER_GLOBAL("relay._vm.GetGlobalFields") ...@@ -763,7 +763,7 @@ TVM_REGISTER_GLOBAL("relay._vm.GetGlobalFields")
*rv = globals[idx].first; *rv = globals[idx].first;
}); });
TVM_REGISTER_GLOBAL("relay._vm.GetNumOfPrimitives") TVM_REGISTER_GLOBAL("runtime.GetNumOfPrimitives")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
runtime::Module mod = args[0]; runtime::Module mod = args[0];
const auto* exec = dynamic_cast<Executable*>(mod.operator->()); const auto* exec = dynamic_cast<Executable*>(mod.operator->());
...@@ -772,7 +772,7 @@ TVM_REGISTER_GLOBAL("relay._vm.GetNumOfPrimitives") ...@@ -772,7 +772,7 @@ TVM_REGISTER_GLOBAL("relay._vm.GetNumOfPrimitives")
}); });
TVM_REGISTER_GLOBAL("relay._vm.GetPrimitiveFields") TVM_REGISTER_GLOBAL("runtime.GetPrimitiveFields")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
runtime::Module mod = args[0]; runtime::Module mod = args[0];
const auto* exec = dynamic_cast<Executable*>(mod.operator->()); const auto* exec = dynamic_cast<Executable*>(mod.operator->());
...@@ -789,7 +789,7 @@ TVM_REGISTER_GLOBAL("relay._vm.GetPrimitiveFields") ...@@ -789,7 +789,7 @@ TVM_REGISTER_GLOBAL("relay._vm.GetPrimitiveFields")
} }
}); });
TVM_REGISTER_GLOBAL("relay._vm.Load_Executable") TVM_REGISTER_GLOBAL("runtime.Load_Executable")
.set_body_typed([]( .set_body_typed([](
std::string code, std::string code,
runtime::Module lib) { runtime::Module lib) {
......
...@@ -133,7 +133,7 @@ runtime::Module CreateVirtualMachineDebug(const Executable* exec) { ...@@ -133,7 +133,7 @@ runtime::Module CreateVirtualMachineDebug(const Executable* exec) {
return runtime::Module(vm); return runtime::Module(vm);
} }
TVM_REGISTER_GLOBAL("relay._vm._VirtualMachineDebug") TVM_REGISTER_GLOBAL("runtime._VirtualMachineDebug")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
runtime::Module mod = args[0]; runtime::Module mod = args[0];
const auto* exec = dynamic_cast<Executable*>(mod.operator->()); const auto* exec = dynamic_cast<Executable*>(mod.operator->());
......
...@@ -1057,7 +1057,7 @@ runtime::Module CreateVirtualMachine(const Executable* exec) { ...@@ -1057,7 +1057,7 @@ runtime::Module CreateVirtualMachine(const Executable* exec) {
return runtime::Module(vm); return runtime::Module(vm);
} }
TVM_REGISTER_GLOBAL("relay._vm._VirtualMachine") TVM_REGISTER_GLOBAL("runtime._VirtualMachine")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
runtime::Module mod = args[0]; runtime::Module mod = args[0];
const auto* exec = dynamic_cast<Executable*>(mod.operator->()); const auto* exec = dynamic_cast<Executable*>(mod.operator->());
......
...@@ -62,7 +62,7 @@ tf_dtypes = { ...@@ -62,7 +62,7 @@ tf_dtypes = {
def vmobj_to_list(o): def vmobj_to_list(o):
if isinstance(o, tvm.nd.NDArray): if isinstance(o, tvm.nd.NDArray):
return [o.asnumpy().tolist()] return [o.asnumpy().tolist()]
elif isinstance(o, tvm.container.ADT): elif isinstance(o, tvm.runtime.container.ADT):
result = [] result = []
for f in o: for f in o:
result.extend(vmobj_to_list(f)) result.extend(vmobj_to_list(f))
......
...@@ -19,7 +19,9 @@ import numpy as np ...@@ -19,7 +19,9 @@ import numpy as np
import tvm import tvm
from tvm.contrib import graph_runtime from tvm.contrib import graph_runtime
from tvm import relay, container from tvm import relay
from tvm.runtime import container
from tvm.runtime import vm as vm_rt
from tvm.relay import testing from tvm.relay import testing
from tvm.relay import vm from tvm.relay import vm
...@@ -58,7 +60,7 @@ def benchmark_execution(mod, ...@@ -58,7 +60,7 @@ def benchmark_execution(mod,
number=2, repeat=20): number=2, repeat=20):
with relay.build_config(opt_level=3): with relay.build_config(opt_level=3):
exe = vm.compile(mod, target, params=params) exe = vm.compile(mod, target, params=params)
rly_vm = vm.VirtualMachine(exe) rly_vm = vm_rt.VirtualMachine(exe)
rly_vm.init(ctx) rly_vm.init(ctx)
result = rly_vm.run(data) result = rly_vm.run(data)
......
...@@ -117,7 +117,7 @@ def tree_to_dict(t): ...@@ -117,7 +117,7 @@ def tree_to_dict(t):
def vmobj_to_list(o, dtype="float32"): def vmobj_to_list(o, dtype="float32"):
if isinstance(o, tvm.nd.NDArray): if isinstance(o, tvm.nd.NDArray):
return [o.asnumpy().tolist()] return [o.asnumpy().tolist()]
elif isinstance(o, tvm.container.ADT): elif isinstance(o, tvm.runtime.container.ADT):
if len(o) == 0: if len(o) == 0:
tensor_nil = p.get_var("tensor_nil", dtype=dtype) tensor_nil = p.get_var("tensor_nil", dtype=dtype)
if tensor_nil.tag == o.tag: if tensor_nil.tag == o.tag:
......
...@@ -18,7 +18,8 @@ import numpy as np ...@@ -18,7 +18,8 @@ import numpy as np
import tvm import tvm
import tvm.testing import tvm.testing
from tvm import nd from tvm import nd
from tvm import relay, container from tvm import relay
from tvm.runtime import container
from tvm.relay.backend.interpreter import RefValue, ConstructorValue from tvm.relay.backend.interpreter import RefValue, ConstructorValue
from tvm.relay.scope_builder import ScopeBuilder from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay import testing, create_executor from tvm.relay import testing, create_executor
......
...@@ -18,12 +18,12 @@ ...@@ -18,12 +18,12 @@
import os import os
import sys import sys
import numpy as np import numpy as np
import pytest
import tvm import tvm
import tvm.relay.testing import tvm.relay.testing
import tvm.relay.transform import tvm.relay.transform
from tvm import relay from tvm import relay
from tvm import runtime
from tvm.contrib import util from tvm.contrib import util
def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
...@@ -52,8 +52,8 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", ...@@ -52,8 +52,8 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
exe = relay.vm.compile(mod, target=target) exe = relay.vm.compile(mod, target=target)
code, lib = exe.save() code, lib = exe.save()
lib = update_lib(lib) lib = update_lib(lib)
exe = relay.vm.Executable.load_exec(code, lib) exe = runtime.vm.Executable.load_exec(code, lib)
vm = relay.vm.VirtualMachine(exe) vm = runtime.vm.VirtualMachine(exe)
vm.init(ctx) vm.init(ctx)
out = vm.run(**map_inputs) out = vm.run(**map_inputs)
tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol) tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
......
...@@ -24,6 +24,7 @@ import tvm ...@@ -24,6 +24,7 @@ import tvm
import tvm.relay.testing import tvm.relay.testing
import tvm.relay.transform as transform import tvm.relay.transform as transform
from tvm import relay from tvm import relay
from tvm import runtime
from tvm.contrib import util from tvm.contrib import util
from tvm.relay.annotation import compiler_begin, compiler_end from tvm.relay.annotation import compiler_begin, compiler_end
from tvm.relay.expr_functor import ExprMutator from tvm.relay.expr_functor import ExprMutator
...@@ -182,7 +183,7 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", ...@@ -182,7 +183,7 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
lib_name = 'lib.so' lib_name = 'lib.so'
lib_path = tmp_path.relpath(lib_name) lib_path = tmp_path.relpath(lib_name)
lib.export_library(lib_path, fcompile=False, **kwargs) lib.export_library(lib_path, fcompile=False, **kwargs)
lib = tvm.runtime.load_module(lib_path) lib = runtime.load_module(lib_path)
return lib return lib
...@@ -191,8 +192,8 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", ...@@ -191,8 +192,8 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
exe = relay.vm.compile(mod, target=target, params=params) exe = relay.vm.compile(mod, target=target, params=params)
code, lib = exe.save() code, lib = exe.save()
lib = update_lib(lib) lib = update_lib(lib)
exe = relay.vm.Executable.load_exec(code, lib) exe = runtime.vm.Executable.load_exec(code, lib)
vm = relay.vm.VirtualMachine(exe) vm = runtime.vm.VirtualMachine(exe)
vm.init(ctx) vm.init(ctx)
out = vm.run(**map_inputs) out = vm.run(**map_inputs)
tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol) tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
......
...@@ -19,7 +19,7 @@ import tvm ...@@ -19,7 +19,7 @@ import tvm
from tvm import relay from tvm import relay
from tvm.relay.testing import to_python, run_as_python from tvm.relay.testing import to_python, run_as_python
from tvm.relay.prelude import Prelude from tvm.relay.prelude import Prelude
from tvm.container import ADT from tvm.runtime.container import ADT
from tvm.relay.backend.interpreter import RefValue, ConstructorValue from tvm.relay.backend.interpreter import RefValue, ConstructorValue
# helper: uses a dummy let binding to sequence a list # helper: uses a dummy let binding to sequence a list
......
...@@ -14,16 +14,16 @@ ...@@ -14,16 +14,16 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import os import numpy as np
import pytest
import tvm import tvm
import numpy as np from tvm import runtime
from tvm import relay from tvm import relay
from tvm.relay.scope_builder import ScopeBuilder from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay.testing.config import ctx_list from tvm.relay.testing.config import ctx_list
from tvm.relay.prelude import Prelude from tvm.relay.prelude import Prelude
from tvm.relay import testing from tvm.relay import testing
import pytest
def check_result(args, expected_result, mod=None): def check_result(args, expected_result, mod=None):
""" """
...@@ -52,14 +52,14 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"): ...@@ -52,14 +52,14 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"):
assert isinstance(f, relay.Module), "expected expression or module" assert isinstance(f, relay.Module), "expected expression or module"
mod = f mod = f
exe = relay.vm.compile(mod, target) exe = relay.vm.compile(mod, target)
vm = relay.vm.VirtualMachine(exe) vm = runtime.vm.VirtualMachine(exe)
vm.init(ctx) vm.init(ctx)
return vm.invoke("main", *args) return vm.invoke("main", *args)
def vmobj_to_list(o): def vmobj_to_list(o):
if isinstance(o, tvm.nd.NDArray): if isinstance(o, tvm.nd.NDArray):
return [o.asnumpy().tolist()] return [o.asnumpy().tolist()]
elif isinstance(o, tvm.container.ADT): elif isinstance(o, tvm.runtime.container.ADT):
result = [] result = []
for f in o: for f in o:
result.extend(vmobj_to_list(f)) result.extend(vmobj_to_list(f))
...@@ -573,7 +573,7 @@ def test_add_op_broadcast(): ...@@ -573,7 +573,7 @@ def test_add_op_broadcast():
def test_vm_optimize(): def test_vm_optimize():
mod, params = testing.resnet.get_workload(batch_size=1, num_layers=18) mod, params = testing.resnet.get_workload(batch_size=1, num_layers=18)
comp = relay.backend.vm.VMCompiler() comp = relay.vm.VMCompiler()
opt_mod, _ = comp.optimize(mod, "llvm", params) opt_mod, _ = comp.optimize(mod, "llvm", params)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -19,9 +19,10 @@ ...@@ -19,9 +19,10 @@
import numpy as np import numpy as np
import tvm import tvm
from tvm.runtime import vm as _vm
from tvm.relay import vm as rly_vm
from tvm import relay from tvm import relay
from tvm.relay.module import Module as rly_module from tvm.relay.module import Module as rly_module
from tvm.relay import vm as _vm
from tvm.relay.scope_builder import ScopeBuilder from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay.prelude import Prelude from tvm.relay.prelude import Prelude
from tvm.contrib import util from tvm.contrib import util
...@@ -31,11 +32,11 @@ def create_exec(f, target="llvm", params=None): ...@@ -31,11 +32,11 @@ def create_exec(f, target="llvm", params=None):
if isinstance(f, relay.Expr): if isinstance(f, relay.Expr):
mod = relay.Module() mod = relay.Module()
mod["main"] = f mod["main"] = f
executable = _vm.compile(mod, target=target, params=params) executable = rly_vm.compile(mod, target=target, params=params)
return executable return executable
else: else:
assert isinstance(f, relay.Module), "expected mod as relay.Module" assert isinstance(f, relay.Module), "expected mod as relay.Module"
executable = _vm.compile(f, target=target, params=params) executable = rly_vm.compile(f, target=target, params=params)
return executable return executable
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
import numpy as np import numpy as np
import tvm import tvm
from tvm import nd, relay from tvm import nd, relay
from tvm import container as _container from tvm.runtime import container as _container
def test_adt_constructor(): def test_adt_constructor():
......
...@@ -14,11 +14,10 @@ ...@@ -14,11 +14,10 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import os
import tvm
import numpy as np import numpy as np
import pytest import tvm
from tvm.runtime import profiler_vm
from tvm import relay from tvm import relay
from tvm.relay.testing import resnet from tvm.relay.testing import resnet
...@@ -26,10 +25,10 @@ def test_basic(): ...@@ -26,10 +25,10 @@ def test_basic():
mod, params = resnet.get_workload() mod, params = resnet.get_workload()
target = 'llvm' target = 'llvm'
ctx = tvm.cpu() ctx = tvm.cpu()
if not relay.profiler_vm.enabled(): if not profiler_vm.enabled():
return return
exe = relay.vm.compile(mod, target, params=params) exe = relay.vm.compile(mod, target, params=params)
vm = relay.profiler_vm.VirtualMachineProfiler(exe) vm = profiler_vm.VirtualMachineProfiler(exe)
vm.init(ctx) vm.init(ctx)
data = np.random.rand(1, 3, 224, 224).astype('float32') data = np.random.rand(1, 3, 224, 224).astype('float32')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment