Unverified Commit e0122c0e by Tianqi Chen Committed by GitHub

[REFACTOR][PY][API-Change] Polish tvm.runtime, tvm.runtime.module API update (#4837)

* [REFACTOR][PY-API] Polish tvm.runtime, tvm.runtime.module API update

This PR updates the tvm.runtime to use the new FFI style.

- Remove top-level tvm.module to avoid confusion between runtime.Module and IRModule
- API changes wrt to runtime.Module
  - tvm.module.load -> tvm.runtime.load_module
  - tvm.module.enabled -> tvm.runtime.enabled
  - tvm.module.system_lib -> tvm.runtime.system_lib
- Remove dep on api_internal from runtime.

* Update module.load in the latest API
parent 30b7d836
......@@ -34,7 +34,7 @@ TVM_BUNDLE_FUNCTION void *tvm_runtime_create() {
const std::string json_data(&build_graph_json[0],
&build_graph_json[0] + build_graph_json_len);
tvm::runtime::Module mod_syslib =
(*tvm::runtime::Registry::Get("module._GetSystemLib"))();
(*tvm::runtime::Registry::Get("runtime.SystemLib"))();
int device_type = kDLCPU;
int device_id = 0;
tvm::runtime::Module mod =
......
......@@ -19,7 +19,7 @@
Example Plugin Module
=====================
This folder contains an example that implements a C++ module
that can be directly loaded as TVM's DSOModule (via tvm.module.load)
that can be directly loaded as TVM's DSOModule (via tvm.runtime.load_module)
## Guideline
......
......@@ -19,7 +19,7 @@ import os
def test_plugin_module():
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
mod = tvm.module.load(os.path.join(curr_path, "lib", "plugin_module.so"))
mod = tvm.runtime.load_module(os.path.join(curr_path, "lib", "plugin_module.so"))
# NOTE: we need to make sure all managed resources returned
# from mod get destructed before mod get unloaded.
#
......
......@@ -30,7 +30,7 @@ def test_ext_dev():
B = tvm.compute((n,), lambda *i: A(*i) + 1.0, name='B')
s = tvm.create_schedule(B.op)
def check_llvm():
if not tvm.module.enabled("llvm"):
if not tvm.runtime.enabled("llvm"):
return
f = tvm.build(s, [A, B], "ext_dev", "llvm")
ctx = tvm.ext_dev(0)
......@@ -74,7 +74,7 @@ def test_extern_call():
s = tvm.create_schedule(B.op)
def check_llvm():
if not tvm.module.enabled("llvm"):
if not tvm.runtime.enabled("llvm"):
return
f = tvm.build(s, [A, B], "llvm")
ctx = tvm.cpu(0)
......
......@@ -79,7 +79,7 @@ int main(void) {
// For libraries that are directly packed as system lib and linked together with the app
// We can directly use GetSystemLib to get the system wide library.
LOG(INFO) << "Verify load function from system lib";
tvm::runtime::Module mod_syslib = (*tvm::runtime::Registry::Get("module._GetSystemLib"))();
tvm::runtime::Module mod_syslib = (*tvm::runtime::Registry::Get("runtime.SystemLib"))();
Verify(mod_syslib, "addonesys");
return 0;
}
......@@ -40,7 +40,7 @@ def verify(mod, fname):
if __name__ == "__main__":
# The normal dynamic loading method for deployment
mod_dylib = tvm.module.load("lib/test_addone_dll.so")
mod_dylib = tvm.runtime.load_module("lib/test_addone_dll.so")
print("Verify dynamic loading from test_addone_dll.so")
verify(mod_dylib, "addone")
# There might be methods to use the system lib way in
......
......@@ -23,7 +23,7 @@ CWD = osp.abspath(osp.dirname(__file__))
def main():
ctx = tvm.context('cpu', 0)
model = tvm.module.load(osp.join(CWD, 'build', 'enclave.signed.so'))
model = tvm.runtime.load_module(osp.join(CWD, 'build', 'enclave.signed.so'))
inp = tvm.nd.array(np.ones((1, 3, 224, 224), dtype='float32'), ctx)
out = tvm.nd.array(np.empty((1, 1000), dtype='float32'), ctx)
model(inp, out)
......
.. Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
.. http://www.apache.org/licenses/LICENSE-2.0
.. Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
Framework Bridge APIs
---------------------
tvm.contrib.mxnet
~~~~~~~~~~~~~~~~~
.. automodule:: tvm.contrib.mxnet
:members:
......@@ -15,8 +15,8 @@
specific language governing permissions and limitations
under the License.
Additional Contrib APIs
-----------------------
tvm.contrib
-----------
.. automodule:: tvm.contrib
tvm.contrib.cblas
......@@ -43,6 +43,11 @@ tvm.contrib.cublas
:members:
tvm.contrib.dlpack
~~~~~~~~~~~~~~~~~~
.. automodule:: tvm.contrib.dlpack
:members:
tvm.contrib.emscripten
~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: tvm.contrib.emscripten
......@@ -53,6 +58,11 @@ tvm.contrib.miopen
.. automodule:: tvm.contrib.miopen
:members:
tvm.contrib.mxnet
~~~~~~~~~~~~~~~~~
.. automodule:: tvm.contrib.mxnet
:members:
tvm.contrib.ndk
~~~~~~~~~~~~~~~
.. automodule:: tvm.contrib.ndk
......@@ -118,7 +128,6 @@ tvm.contrib.util
:members:
tvm.contrib.xcode
~~~~~~~~~~~~~~~~~
.. automodule:: tvm.contrib.xcode
......
......@@ -20,14 +20,7 @@ Developer API
This page contains modules that are used by developers of TVM.
Many of these APIs are PackedFunc registered in C++ backend.
tvm.object
~~~~~~~~~~
.. automodule:: tvm.object
.. autoclass:: tvm.object.Object
:members:
.. autofunction:: tvm.register_object
tvm.expr
~~~~~~~~
......
......@@ -22,6 +22,8 @@ Python API
:maxdepth: 2
tvm
runtime
ndarray
intrin
tensor
schedule
......@@ -29,7 +31,6 @@ Python API
build
module
error
ndarray
container
function
autotvm
......@@ -37,6 +38,7 @@ Python API
rpc
bridge
contrib
ffi
dev
topi
vta/index
......
.. Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
.. http://www.apache.org/licenses/LICENSE-2.0
.. Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
tvm.module
----------
.. automodule:: tvm.module
:members:
......@@ -15,22 +15,22 @@
specific language governing permissions and limitations
under the License.
tvm.ndarray
-----------
.. automodule:: tvm.ndarray
tvm.runtime.ndarray
-------------------
.. automodule:: tvm.runtime.ndarray
.. autoclass:: tvm.ndarray.TVMContext
.. autoclass:: tvm.nd.NDArray
:members:
:inherited-members:
.. autoclass:: tvm.ndarray.NDArray
.. autoclass:: tvm.runtime.TVMContext
:members:
:inherited-members:
.. autofunction:: tvm.context
.. autofunction:: tvm.cpu
.. autofunction:: tvm.gpu
.. autofunction:: tvm.opencl
.. autofunction:: tvm.metal
.. autofunction:: tvm.ndarray.array
.. autofunction:: tvm.ndarray.empty
.. autofunction:: tvm.register_extension
.. autofunction:: tvm.nd.array
.. autofunction:: tvm.nd.empty
......@@ -15,10 +15,31 @@
specific language governing permissions and limitations
under the License.
tvm.Function
------------
.. autoclass:: tvm.Function
tvm.runtime
-----------
.. automodule:: tvm.runtime
.. autoclass:: tvm.runtime.PackedFunc
:members:
.. autofunction:: tvm.register_func
.. autofunction:: tvm.get_global_func
.. autoclass:: tvm.runtime.Module
:members:
.. autofunction:: tvm.runtime.load_module
.. autofunction:: tvm.runtime.system_lib
.. autofunction:: tvm.runtime.enabled
.. autoclass:: tvm.runtime.Object
:members:
.. autofunction:: tvm.register_object
......@@ -57,8 +57,8 @@ import os
tgt="aocl_sw_emu"
fadd = tvm.module.load("myadd.so")
fadd_dev = tvm.module.load("myadd.aocx")
fadd = tvm.runtime.load("myadd.so")
fadd_dev = tvm.runtime.load("myadd.aocx")
fadd.import_module(fadd_dev)
ctx = tvm.context(tgt, 0)
......
......@@ -57,11 +57,11 @@ import os
tgt="sdaccel"
fadd = tvm.module.load("myadd.so")
fadd = tvm.runtime.load("myadd.so")
if os.environ.get("XCL_EMULATION_MODE"):
fadd_dev = tvm.module.load("myadd.xclbin")
fadd_dev = tvm.runtime.load("myadd.xclbin")
else:
fadd_dev = tvm.module.load("myadd.awsxclbin")
fadd_dev = tvm.runtime.load("myadd.awsxclbin")
fadd.import_module(fadd_dev)
ctx = tvm.context(tgt, 0)
......
......@@ -53,7 +53,7 @@ Let us build one ResNet-18 workload for GPU as an example first.
resnet18_lib.export_library(path_lib)
# load it back
loaded_lib = tvm.module.load(path_lib)
loaded_lib = tvm.runtime.load(path_lib)
assert loaded_lib.type_key == "library"
assert loaded_lib.imported_modules[0].type_key == "cuda"
......@@ -177,7 +177,7 @@ support arbitrary modules to import ideally.
Deserialization
****************
The entrance API is ``tvm.module.load``. This function
The entrance API is ``tvm.runtime.load``. This function
is to call ``_LoadFromFile`` in fact. If we dig it a little deeper, this is
``Module::LoadFromFile``. In our example, the file is ``deploy.so``,
according to the function logic, we will call ``module.loadfile_so`` in
......
......@@ -903,7 +903,7 @@ We also need to register this function to enable the corresponding Python API:
TVM_REGISTER_GLOBAL("module.loadbinary_examplejson")
.set_body_typed(ExampleJsonModule::LoadFromBinary);
The above registration means when users call ``tvm.module.load(lib_path)`` API and the exported library has an ExampleJSON stream, our ``LoadFromBinary`` will be invoked to create the same customized runtime module.
The above registration means when users call ``tvm.runtime.load(lib_path)`` API and the exported library has an ExampleJSON stream, our ``LoadFromBinary`` will be invoked to create the same customized runtime module.
In addition, if you want to support module creation directly from an ExampleJSON file, you can also implement a simple function and register a Python API as follows:
......@@ -928,7 +928,7 @@ In addition, if you want to support module creation directly from an ExampleJSON
*rv = ExampleJsonModule::Create(args[0]);
});
It means users can manually write/modify an ExampleJSON file, and use Python API ``tvm.module.load("mysubgraph.examplejson", "examplejson")`` to construct a customized module.
It means users can manually write/modify an ExampleJSON file, and use Python API ``tvm.runtime.load("mysubgraph.examplejson", "examplejson")`` to construct a customized module.
*******
Summary
......@@ -952,7 +952,7 @@ In summary, here is a checklist for you to refer:
* ``Run`` to execute a subgraph.
* Register a runtime creation API.
* ``SaveToBinary`` and ``LoadFromBinary`` to serialize/deserialize customized runtime module.
* Register ``LoadFromBinary`` API to support ``tvm.module.load(your_module_lib_path)``.
* Register ``LoadFromBinary`` API to support ``tvm.runtime.load(your_module_lib_path)``.
* (optional) ``Create`` to support customized runtime module construction from subgraph file in your representation.
* An annotator to annotate a user Relay program to make use of your compiler and runtime (TBA).
......@@ -211,6 +211,13 @@ class TVM_DLL ModuleNode : public Object {
std::shared_ptr<PackedFunc> > import_cache_;
};
/*!
* \brief Check if runtime module is enabled for target.
* \param target The target module name.
* \return Whether runtime is enabled.
*/
TVM_DLL bool RuntimeEnabled(const std::string& target);
/*! \brief namespace for constant symbols */
namespace symbol {
/*! \brief Global variable to store module context. */
......
......@@ -20,7 +20,7 @@ import tvm
from tvm.contrib import cc, util
def test_add(target_dir):
if not tvm.module.enabled("cuda"):
if not tvm.runtime.enabled("cuda"):
print("skip %s because cuda is not enabled..." % __file__)
return
n = tvm.var("n")
......
......@@ -29,12 +29,8 @@ from ._ffi.registry import register_object, register_func, register_extension
# top-level alias
# tvm.runtime
from .runtime.object import Object
from .runtime.packed_func import PackedFunc as Function
from .runtime.ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl
from .runtime.ndarray import vpi, rocm, opengl, ext_dev, micro_dev
from .runtime import module
from .runtime import ndarray
# pylint: disable=reimported
from .runtime import ndarray as nd
# others
......
......@@ -92,3 +92,22 @@ class ObjectBase(object):
if not isinstance(handle, ObjectHandle):
handle = ObjectHandle(handle)
self.handle = handle
def same_as(self, other):
"""Check object identity.
Parameters
----------
other : object
The other object to compare against.
Returns
-------
result : bool
The comparison result.
"""
if not isinstance(other, ObjectBase):
return False
if self.handle is None:
return other.handle is None
return self.handle.value == other.handle.value
......@@ -99,3 +99,20 @@ cdef class ObjectBase:
(<PackedFuncBase>fconstructor).chandle,
kTVMObjectHandle, args, &chandle)
self.chandle = chandle
def same_as(self, other):
"""Check object identity.
Parameters
----------
other : object
The other object to compare against.
Returns
-------
result : bool
The comparison result.
"""
if not isinstance(other, ObjectBase):
return False
return self.chandle == (<ObjectBase>other).chandle
......@@ -19,6 +19,7 @@
from numbers import Integral as _Integral
import tvm._ffi
import tvm.runtime._ffi_node_api
from tvm.runtime import convert, const, DataType
from ._ffi.base import string_types, TVMError
......@@ -108,10 +109,10 @@ def load_json(json_str):
"""
try:
return _api_internal._load_json(json_str)
return tvm.runtime._ffi_node_api.LoadJSON(json_str)
except TVMError:
json_str = json_compact.upgrade_json(json_str)
return _api_internal._load_json(json_str)
return tvm.runtime._ffi_node_api.LoadJSON(json_str)
def save_json(node):
......@@ -127,7 +128,7 @@ def save_json(node):
json_str : str
Saved json string.
"""
return _api_internal._save_json(node)
return tvm.runtime._ffi_node_api.SaveJSON(node)
def var(name="tindex", dtype=int32):
......
......@@ -21,6 +21,7 @@ LoweredFunc and compiled Module.
"""
import warnings
import tvm._ffi
import tvm.runtime
from tvm.runtime import Object, ndarray
from . import api
......@@ -31,7 +32,6 @@ from . import expr
from . import ir_pass
from . import stmt as _stmt
from . import container
from . import module
from . import codegen
from . import target as _target
from . import make
......@@ -628,7 +628,7 @@ def build(inputs,
target_host = tar
break
if not target_host:
target_host = "llvm" if module.enabled("llvm") else "stackvm"
target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm"
fhost_all = []
device_modules = []
......
......@@ -19,6 +19,7 @@ import tvm._ffi
from tvm.runtime import Object, ObjectTypes
from tvm.runtime.container import getitem_helper
from tvm.runtime import _ffi_node_api
from . import _api_internal
......@@ -33,10 +34,10 @@ class Array(Object):
"""
def __getitem__(self, idx):
return getitem_helper(
self, _api_internal._ArrayGetItem, len(self), idx)
self, _ffi_node_api.ArrayGetItem, len(self), idx)
def __len__(self):
return _api_internal._ArraySize(self)
return _ffi_node_api.ArraySize(self)
@tvm._ffi.register_object
......@@ -62,18 +63,18 @@ class Map(Object):
You can use convert to create a dict[Object-> Object] into a Map
"""
def __getitem__(self, k):
return _api_internal._MapGetItem(self, k)
return _ffi_node_api.MapGetItem(self, k)
def __contains__(self, k):
return _api_internal._MapCount(self, k) != 0
return _ffi_node_api.MapCount(self, k) != 0
def items(self):
"""Get the items from the map"""
akvs = _api_internal._MapItems(self)
akvs = _ffi_node_api.MapItems(self)
return [(akvs[i], akvs[i+1]) for i in range(0, len(akvs), 2)]
def __len__(self):
return _api_internal._MapSize(self)
return _ffi_node_api.MapSize(self)
@tvm._ffi.register_object
......@@ -84,7 +85,7 @@ class StrMap(Map):
"""
def items(self):
"""Get the items from the map"""
akvs = _api_internal._MapItems(self)
akvs = _ffi_node_api.MapItems(self)
return [(akvs[i].value, akvs[i+1]) for i in range(0, len(akvs), 2)]
......
......@@ -269,7 +269,7 @@ def save_tensors(params):
param_bytes: bytearray
Serialized parameters.
"""
_save_tensors = tvm.get_global_func("_save_param_dict")
_save_tensors = tvm.get_global_func("tvm.relay._save_param_dict")
args = []
for k, v in params.items():
......
......@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Wrapping functions to bridge frameworks with DLPack support to TVM"""
from .. import ndarray
from tvm.runtime import ndarray
def convert_func(tvm_func, tensor_type, to_dlpack_func):
"""Convert a tvm function into one that accepts a tensor from another
......
......@@ -17,8 +17,9 @@
"""MXNet bridge wrap Function MXNet's async function."""
from __future__ import absolute_import as _abs
from .. import api, _api_internal, ndarray
from ..module import Module
import tvm._ffi.registry
import tvm.runtime._ffi_api
from tvm.runtime import Module
# pylint: disable=invalid-name
_wrap_async = None
......@@ -60,7 +61,7 @@ def to_mxnet_func(func, const_loc=None):
"MXTVMBridge not exist in mxnet package,"
" please update to latest version")
fdict = api.extract_ext_funcs(mxnet.base._LIB.MXTVMBridge)
fdict = tvm._ffi.registry.extract_ext_funcs(mxnet.base._LIB.MXTVMBridge)
ret = fdict["WrapAsyncCall"]
ret.is_global = True
return ret
......@@ -69,7 +70,8 @@ def to_mxnet_func(func, const_loc=None):
if _wrap_async is None:
# Register extension type in first time
_wrap_async = _get_bridge_func()
ndarray.register_extension(mxnet.nd.NDArray)
tvm._ffi.registry.register_extension(mxnet.nd.NDArray)
const_loc = const_loc if const_loc else []
return _wrap_async(func, _api_internal._TVMSetStream, len(const_loc), *const_loc)
return _wrap_async(func, tvm.runtime._ffi_api.TVMSetStream,
len(const_loc), *const_loc)
......@@ -16,12 +16,13 @@
# under the License.
"""Tensor and Operation class for computation declaration."""
# pylint: disable=invalid-name
from __future__ import absolute_import as _abs
import numpy as _np
from tvm.runtime import ndarray as _nd
from .. import expr as _expr
from .. import api as _api
from .. import tensor as _tensor
from .. import ndarray as _nd
float32 = "float32"
itype = 'int32'
......
......@@ -32,7 +32,7 @@ For example, you can use addexp.a to get the left operand of an Add node.
"""
# pylint: disable=missing-docstring
import tvm._ffi
from tvm.runtime import Object, ObjectGeneric, DataType, TypeCode
from tvm.runtime import Object, ObjectGeneric, DataType, TypeCode, const
from . import make as _make
from . import generic as _generic
......@@ -101,7 +101,7 @@ class ExprOp(object):
return _make._OpFloorMod(self, other)
def __neg__(self):
neg_one = _api_internal._const(-1, self.dtype)
neg_one = const(-1, self.dtype)
return self.__mul__(neg_one)
def __lshift__(self, other):
......
......@@ -138,7 +138,7 @@ def create_micro_mod(c_mod, dev_config):
Parameters
----------
c_mod : tvm.module.Module
c_mod : tvm.runtime.Module
module with "c" as its target backend
dev_config : Dict[str, Any]
......@@ -146,7 +146,7 @@ def create_micro_mod(c_mod, dev_config):
Return
------
micro_mod : tvm.module.Module
micro_mod : tvm.runtim.Module
micro module for the target device
"""
temp_dir = _util.tempdir()
......@@ -154,14 +154,14 @@ def create_micro_mod(c_mod, dev_config):
c_mod.export_library(
lib_obj_path,
fcompile=cross_compiler(dev_config, LibType.OPERATOR))
micro_mod = tvm.module.load(lib_obj_path)
micro_mod = tvm.runtime.load_module(lib_obj_path)
return micro_mod
def cross_compiler(dev_config, lib_type):
"""Create a cross-compile function that wraps `create_lib` for a `Binutil` instance.
For use in `tvm.module.Module.export_library`.
For use in `tvm.runtime.Module.export_library`.
Parameters
----------
......
......@@ -104,7 +104,7 @@ class CompileEngine(Object):
return _backend._CompileEngineLowerShapeFunc(self, key)
def jit(self, source_func, target=None):
"""JIT a source_func to a tvm.Function.
"""JIT a source_func to a tvm.runtime.PackedFunc.
Parameters
----------
......@@ -116,7 +116,7 @@ class CompileEngine(Object):
Returns
-------
jited_func: tvm.Function
jited_func: tvm.runtime.PackedFunc
The result of jited function.
"""
key = _get_cache_key(source_func, target)
......
......@@ -84,14 +84,14 @@ class Executor(object):
expr: relay.Expr
The expression to evaluate
args: List[tvm.NDArray]
args: List[tvm.nd.NDArray]
The arguments to pass to the evaluator.
kwargs: Dict[str, tvm.NDArrray]
The keyword arguments to pass to the evaluator.
Returns:
args: List[tvm.NDArray]
args: List[tvm.nd.NDArray]
The new arguments with all keyword arguments placed in the correct slot.
"""
assert expr is not None
......
......@@ -85,7 +85,7 @@ class Executable(object):
can then be saved to disk and later deserialized into a new
Executable.
lib : :py:class:`~tvm.module.Module`
lib : :py:class:`~tvm.runtime.Module`
The runtime module that contains the generated code. It is
basically a library that is composed of hardware dependent code.
......@@ -125,7 +125,7 @@ class Executable(object):
lib.export_library(path_lib)
with open(tmp.relpath("code.ro"), "wb") as fo:
fo.write(code)
loaded_lib = tvm.module.load(path_lib)
loaded_lib = tvm.runtime.load_module(path_lib)
loaded_code = bytearray(open(tmp.relpath("code.ro"), "rb").read())
# deserialize.
des_exec = relay.vm.Executable.load_exec(loaded_code, loaded_code)
......@@ -147,7 +147,7 @@ class Executable(object):
bytecode : bytearray
The binary blob representing a the Relay VM bytecode.
lib : :py:class:`~tvm.module.Module`
lib : :py:class:`~tvm.runtime.Module`
The runtime module that contains the generated code.
Returns
......@@ -161,8 +161,8 @@ class Executable(object):
raise TypeError("bytecode is expected to be the type of bytearray " +
"or TVMByteArray, but received {}".format(type(code)))
if lib is not None and not isinstance(lib, tvm.module.Module):
raise TypeError("lib is expected to be the type of tvm.module.Module" +
if lib is not None and not isinstance(lib, tvm.runtime.Module):
raise TypeError("lib is expected to be the type of tvm.runtime.Module" +
", but received {}".format(type(lib)))
return Executable(_vm.Load_Executable(bytecode, lib))
......@@ -270,7 +270,7 @@ class Executable(object):
class VirtualMachine(object):
"""Relay VM runtime."""
def __init__(self, mod):
if not isinstance(mod, (Executable, tvm.module.Module)):
if not isinstance(mod, (Executable, tvm.runtime.Module)):
raise TypeError("mod is expected to be the type of Executable or " +
"tvm.Module, but received {}".format(type(mod)))
m = mod.module if isinstance(mod, Executable) else mod
......@@ -534,7 +534,7 @@ class VMCompiler(object):
target_host = tgt
break
if not target_host:
target_host = "llvm" if tvm.module.enabled("llvm") else "stackvm"
target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm"
if isinstance(target_host, str):
target_host = tvm.target.create(target_host)
return target_host
......
......@@ -567,8 +567,8 @@ def from_caffe2(init_net, predict_net, shape=None, dtype="float32"):
mod : tvm.relay.Module
The module that optimizations will be performed on.
params : dict of str to tvm.ndarray
Dict of converted parameters stored in tvm.ndarray format
params : dict of str to tvm.nd.NDArray
Dict of converted parameters stored in tvm.nd.NDArray format
"""
caffe2 = Caffe2NetDef(shape, dtype)
......
......@@ -455,7 +455,7 @@ def from_coreml(model, shape=None):
mod : tvm.relay.Module
The relay module for compilation.
params : dict of str to tvm.NDArray
params : dict of str to tvm.nd.NDArray
The parameter dict to be used by Relay.
"""
try:
......
......@@ -843,7 +843,7 @@ def from_darknet(net,
mod : tvm.relay.Module
The relay module for compilation.
params : dict of str to tvm.NDArray
params : dict of str to tvm.nd.NDArray
The parameter dict to be used by relay
"""
......
......@@ -756,7 +756,7 @@ def from_keras(model, shape=None):
mod : tvm.relay.Module
The relay module for compilation.
params : dict of str to tvm.NDArray
params : dict of str to tvm.nd.NDArray
The parameter dict to be used by Relay.
"""
def _check_model_is_tf_keras():
......
......@@ -2012,7 +2012,7 @@ def from_mxnet(symbol,
mod : tvm.relay.Module
The relay module for compilation
params : dict of str to tvm.NDArray
params : dict of str to tvm.nd.NDArray
The parameter dict to be used by nnvm
"""
try:
......
......@@ -1791,7 +1791,7 @@ def from_onnx(model,
mod : tvm.relay.Module
The relay module for compilation
params : dict of str to tvm.NDArray
params : dict of str to tvm.nd.NDArray
The parameter dict to be used by relay
"""
try:
......
......@@ -2655,8 +2655,8 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
mod : tvm.relay.Module
The module that optimizations will be performed on.
params : dict of str to tvm.ndarray
Dict of converted parameters stored in tvm.ndarray format
params : dict of str to tvm.nd.NDArray
Dict of converted parameters stored in tvm.nd.NDArray format
"""
g = GraphProto()
mod, params = g.from_tensorflow(graph, layout, shape, outputs)
......
......@@ -1896,7 +1896,7 @@ def from_tflite(model, shape_dict, dtype_dict):
mod : tvm.relay.Module
The relay module for compilation.
params : dict of str to tvm.NDArray
params : dict of str to tvm.nd.NDArray
The parameter dict to be used by relay
"""
try:
......
......@@ -41,7 +41,7 @@ import tvm._ffi
from tvm._ffi.base import py_str
from tvm._ffi.libinfo import find_lib_path
from tvm.runtime.module import load as _load_module
from tvm.runtime.module import load_module as _load_module
from tvm.contrib import util
from . import base
from . base import TrackerCode
......
......@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""TVM runtime."""
"""TVM runtime namespace."""
# class exposures
from .packed_func import PackedFunc
......@@ -27,6 +27,4 @@ from .module import Module
from .object_generic import convert_to_object, convert, const
from .ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl
from .ndarray import vpi, rocm, opengl, ext_dev, micro_dev
from .module import load as load_module
DataType = DataType
from .module import load_module, enabled, system_lib
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI APIs for tvm.runtime"""
import tvm._ffi
# Exports functions registered via TVM_REGISTER_GLOBAL with the "runtime" prefix.
# e.g. TVM_REGISTER_GLOBAL("runtime.ModuleLoadFromFile")
tvm._ffi._init_api("runtime", __name__)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument
"""FFI for tvm.runtime.extra"""
import tvm._ffi
# The implementations below are default ones when the corresponding
# functions are not available in the runtime only mode.
# They will be overriden via _init_api to the ones registered
# via TVM_REGISTER_GLOBAL in the compiler mode.
def AsRepr(obj):
return obj.type_key() + "(" + obj.handle.value + ")"
def NodeListAttrNames(obj):
return lambda x: 0
def NodeGetAttr(obj, name):
raise AttributeError()
def SaveJSON(obj):
raise RuntimeError(
"Do not support object serialization in runtime only mode")
def LoadJSON(json_str):
raise RuntimeError(
"Do not support object serialization in runtime only mode")
# Exports functions registered via TVM_REGISTER_GLOBAL with the "node" prefix.
# e.g. TVM_REGISTER_GLOBAL("node.AsRepr")
tvm._ffi._init_api("node", __name__)
......@@ -26,6 +26,8 @@ from tvm._ffi.base import _LIB, check_call, c_str, string_types, _RUNTIME_ONLY
from tvm._ffi.libinfo import find_include_path
from .packed_func import PackedFunc, PackedFuncHandle, _set_class_module
from . import _ffi_api
# profile result of time evaluator
ProfileResult = namedtuple("ProfileResult", ["mean", "results"])
......@@ -52,7 +54,7 @@ class Module(object):
Returns
-------
f : Function
f : tvm.runtime.PackedFunc
The entry function if exist
"""
if self._entry:
......@@ -73,7 +75,7 @@ class Module(object):
Returns
-------
f : Function
f : tvm.runtime.PackedFunc
The result function.
"""
ret_handle = PackedFuncHandle()
......@@ -91,7 +93,7 @@ class Module(object):
Parameters
----------
module : Module
module : tvm.runtime.Module
The other module.
"""
check_call(_LIB.TVMModImport(self.handle, module.handle))
......@@ -114,7 +116,7 @@ class Module(object):
@property
def type_key(self):
"""Get type key of the module."""
return _GetTypeKey(self)
return _ffi_api.ModuleGetTypeKey(self)
def get_source(self, fmt=""):
"""Get source code from module, if available.
......@@ -129,7 +131,7 @@ class Module(object):
source : str
The result source code.
"""
return _GetSource(self, fmt)
return _ffi_api.ModuleGetSource(self, fmt)
@property
def imported_modules(self):
......@@ -140,8 +142,8 @@ class Module(object):
modules : list of Module
The module
"""
nmod = _ImportsSize(self)
return [_GetImport(self, i) for i in range(nmod)]
nmod = _ffi_api.ModuleImportsSize(self)
return [_ffi_api.ModuleGetImport(self, i) for i in range(nmod)]
def save(self, file_name, fmt=""):
"""Save the module to file.
......@@ -158,9 +160,9 @@ class Module(object):
See Also
--------
Module.export_library : export the module to shared library.
runtime.Module.export_library : export the module to shared library.
"""
_SaveToFile(self, file_name, fmt)
_ffi_api.ModuleSaveToFile(self, file_name, fmt)
def time_evaluator(self, func_name, ctx, number=10, repeat=1, min_repeat_ms=0):
"""Get an evaluator that measures time cost of running function.
......@@ -199,13 +201,14 @@ class Module(object):
Returns
-------
ftimer : Function
ftimer : function
The function that takes same argument as func and returns a ProfileResult.
The ProfileResult reports `repeat` time costs in seconds.
"""
try:
feval = _RPCTimeEvaluator(
self, func_name, ctx.device_type, ctx.device_id, number, repeat, min_repeat_ms)
feval = _ffi_api.RPCTimeEvaluator(
self, func_name, ctx.device_type, ctx.device_id,
number, repeat, min_repeat_ms)
def evaluator(*args):
"""Internal wrapped evaluator."""
......@@ -314,13 +317,13 @@ class Module(object):
if self.imported_modules:
if enabled("llvm") and llvm_target_triple:
path_obj = temp.relpath("devc.o")
m = _PackImportsToLLVM(self, is_system_lib, llvm_target_triple)
m = _ffi_api.ModulePackImportsToLLVM(self, is_system_lib, llvm_target_triple)
m.save(path_obj)
files.append(path_obj)
else:
path_cc = temp.relpath("devc.cc")
with open(path_cc, "w") as f:
f.write(_PackImportsToC(self, is_system_lib))
f.write(_ffi_api.ModulePackImportsToC(self, is_system_lib))
files.append(path_cc)
if has_c_module:
......@@ -349,13 +352,13 @@ def system_lib():
Returns
-------
module : Module
module : runtime.Module
The system-wide library module.
"""
return _GetSystemLib()
return _ffi_api.SystemLib()
def load(path, fmt=""):
def load_module(path, fmt=""):
"""Load module from file.
Parameters
......@@ -369,7 +372,7 @@ def load(path, fmt=""):
Returns
-------
module : Module
module : runtime.Module
The loaded module
Note
......@@ -396,7 +399,7 @@ def load(path, fmt=""):
elif path.endswith(".obj"):
fmt = "micro_dev"
# Redirect to the load API
return _LoadFromFile(path, fmt)
return _ffi_api.ModuleLoadFromFile(path, fmt)
def enabled(target):
......@@ -416,11 +419,9 @@ def enabled(target):
--------
The following code checks if gpu is enabled.
>>> tvm.module.enabled("gpu")
>>> tvm.runtime.enabled("gpu")
"""
return _Enabled(target)
return _ffi_api.RuntimeEnabled(target)
_set_class_module(Module)
tvm._ffi._init_api("tvm.module", "tvm.runtime.module")
......@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-import
"""Runtime NDArray api"""
"""Runtime NDArray API"""
import ctypes
import numpy as np
import tvm._ffi
......@@ -146,7 +146,7 @@ class NDArray(NDArrayBase):
return self
def __repr__(self):
res = "<tvm.NDArray shape={0}, {1}>\n".format(self.shape, self.context)
res = "<tvm.nd.NDArray shape={0}, {1}>\n".format(self.shape, self.context)
res += self.asnumpy().__repr__()
return res
......@@ -203,7 +203,7 @@ def context(dev_type, dev_id=0):
Returns
-------
ctx: TVMContext
ctx: tvm.runtime.TVMContext
The corresponding context.
Examples
......
......@@ -19,7 +19,7 @@
import ctypes
from tvm._ffi.base import _FFI_MODE, _RUNTIME_ONLY, check_call, _LIB, c_str
from .. import _api_internal
from . import _ffi_api, _ffi_node_api
try:
# pylint: disable=wrong-import-position,unused-import
......@@ -41,22 +41,22 @@ def _new_object(cls):
class Object(ObjectBase):
"""Base class for all tvm's runtime objects."""
def __repr__(self):
return _api_internal._format_str(self)
return _ffi_node_api.AsRepr(self)
def __dir__(self):
fnames = _api_internal._NodeListAttrNames(self)
fnames = _ffi_node_api.NodeListAttrNames(self)
size = fnames(-1)
return [fnames(i) for i in range(size)]
def __getattr__(self, name):
try:
return _api_internal._NodeGetAttr(self, name)
return _ffi_node_api.NodeGetAttr(self, name)
except AttributeError:
raise AttributeError(
"%s has no attribute %s" % (str(type(self)), name))
def __hash__(self):
return _api_internal._raw_ptr(self)
return _ffi_api.ObjectHash(self)
def __eq__(self, other):
return self.same_as(other)
......@@ -71,25 +71,19 @@ class Object(ObjectBase):
def __getstate__(self):
handle = self.handle
if handle is not None:
return {'handle': _api_internal._save_json(self)}
return {'handle': _ffi_node_api.SaveJSON(self)}
return {'handle': None}
def __setstate__(self, state):
# pylint: disable=assigning-non-slot
# pylint: disable=assigning-non-slot, assignment-from-no-return
handle = state['handle']
if handle is not None:
json_str = handle
other = _api_internal._load_json(json_str)
other = _ffi_node_api.LoadJSON(json_str)
self.handle = other.handle
other.handle = None
else:
self.handle = None
def same_as(self, other):
"""check object identity equality"""
if not isinstance(other, Object):
return False
return self.__hash__() == other.__hash__()
_set_class_object(Object)
......@@ -19,7 +19,7 @@
from numbers import Number, Integral
from tvm._ffi.base import string_types
from .. import _api_internal
from . import _ffi_node_api
from .object import ObjectBase, _set_class_object_generic
from .ndarray import NDArrayBase
from .packed_func import PackedFuncBase, convert_to_tvm_func
......@@ -56,10 +56,10 @@ def convert_to_object(value):
if isinstance(value, Number):
return const(value)
if isinstance(value, string_types):
return _api_internal._str(value)
return _ffi_node_api.String(value)
if isinstance(value, (list, tuple)):
value = [convert_to_object(x) for x in value]
return _api_internal._Array(*value)
return _ffi_node_api.Array(*value)
if isinstance(value, dict):
vlist = []
for item in value.items():
......@@ -68,7 +68,7 @@ def convert_to_object(value):
raise ValueError("key of map must already been a container type")
vlist.append(item[0])
vlist.append(convert_to_object(item[1]))
return _api_internal._Map(*vlist)
return _ffi_node_api.Map(*vlist)
if isinstance(value, ObjectGeneric):
return value.asobject()
if value is None:
......@@ -133,9 +133,9 @@ def const(value, dtype=None):
if dtype is None:
dtype = _scalar_type_inference(value)
if dtype == "uint64" and value >= (1 << 63):
return _api_internal._LargeUIntImm(
return _ffi_node_api.LargeUIntImm(
dtype, value & ((1 << 32) - 1), value >> 32)
return _api_internal._const(value, dtype)
return _ffi_node_api._const(value, dtype)
_set_class_object_generic(ObjectGeneric, convert_to_object)
......@@ -47,7 +47,7 @@ class PackedFunc(PackedFuncBase):
For example, the developer function exposed in tvm.ir_pass are actually
C++ functions that are registered as PackedFunc
The following are list of common usage scenario of tvm.Function.
The following are list of common usage scenario of tvm.runtime.PackedFunc.
- Automatic exposure of C++ API into python
- To call PackedFunc from python side
......
......@@ -47,7 +47,7 @@ The list of options include:
Build TVM system library module. System lib is a global module that contains
self registered functions in program startup. User can get the module using
:any:`tvm.module.system_lib`.
:any:`tvm.runtime.system_lib`.
It is useful in environments where dynamic loading api like dlopen is banned.
The system lib will be available as long as the result code is linked by the program.
......
......@@ -125,7 +125,7 @@ import tvm
from tvm.contrib import cc
def test_add(target_dir):
if not tvm.module.enabled("cuda"):
if not tvm.runtime.enabled("cuda"):
print("skip {__file__} because cuda is not enabled...".format(__file__=__file__))
return
n = tvm.var("n")
......
......@@ -111,7 +111,7 @@ def download_img_labels():
def test_build(build_dir):
""" Sanity check with random input"""
graph = open(osp.join(build_dir, "deploy_graph.json")).read()
lib = tvm.module.load(osp.join(build_dir, "deploy_lib.so"))
lib = tvm.runtime.load(osp.join(build_dir, "deploy_lib.so"))
params = bytearray(open(osp.join(build_dir,"deploy_param.params"), "rb").read())
input_data = tvm.nd.array(np.random.uniform(size=data_shape).astype("float32"))
ctx = tvm.cpu()
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Implementation of basic API functions
* \file api_base.cc
*/
#include <dmlc/memory_io.h>
#include <tvm/tir/expr.h>
#include <tvm/te/tensor.h>
#include <tvm/runtime/registry.h>
#include <tvm/node/serialization.h>
namespace tvm {
TVM_REGISTER_GLOBAL("_format_str")
.set_body([](TVMArgs args, TVMRetValue *ret) {
CHECK(args[0].type_code() == kTVMObjectHandle);
std::ostringstream os;
os << args[0].operator ObjectRef();
*ret = os.str();
});
TVM_REGISTER_GLOBAL("_raw_ptr")
.set_body([](TVMArgs args, TVMRetValue *ret) {
CHECK(args[0].type_code() == kTVMObjectHandle);
*ret = reinterpret_cast<int64_t>(args[0].value().v_handle);
});
TVM_REGISTER_GLOBAL("_save_json")
.set_body_typed(SaveJSON);
TVM_REGISTER_GLOBAL("_load_json")
.set_body_typed(LoadJSON);
TVM_REGISTER_GLOBAL("_TVMSetStream")
.set_body_typed(TVMSetStream);
TVM_REGISTER_GLOBAL("_save_param_dict")
.set_body([](TVMArgs args, TVMRetValue *rv) {
CHECK_EQ(args.size() % 2, 0u);
constexpr uint64_t TVMNDArrayListMagic = 0xF7E58D4F05049CB7;
size_t num_params = args.size() / 2;
std::vector<std::string> names;
names.reserve(num_params);
std::vector<DLTensor*> arrays;
arrays.reserve(num_params);
for (size_t i = 0; i < num_params * 2; i += 2) {
names.emplace_back(args[i].operator std::string());
arrays.emplace_back(args[i + 1].operator DLTensor*());
}
std::string bytes;
dmlc::MemoryStringStream strm(&bytes);
dmlc::Stream* fo = &strm;
uint64_t header = TVMNDArrayListMagic, reserved = 0;
fo->Write(header);
fo->Write(reserved);
fo->Write(names);
{
uint64_t sz = static_cast<uint64_t>(arrays.size());
fo->Write(sz);
for (size_t i = 0; i < sz; ++i) {
tvm::runtime::SaveDLTensor(fo, arrays[i]);
}
}
TVMByteArray arr;
arr.data = bytes.c_str();
arr.size = bytes.length();
*rv = arr;
});
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Implementation of API functions related to Codegen
* \file c_api_codegen.cc
*/
#include <tvm/tir/expr.h>
#include <tvm/tir/expr.h>
#include <tvm/target/codegen.h>
#include <tvm/tir/lowered_func.h>
#include <tvm/runtime/registry.h>
namespace tvm {
namespace codegen {
TVM_REGISTER_GLOBAL("codegen._Build")
.set_body([](TVMArgs args, TVMRetValue *ret) {
if (args[0].IsObjectRef<tir::LoweredFunc>()) {
*ret = Build({args[0]}, args[1]);
} else {
*ret = Build(args[0], args[1]);
}
});
TVM_REGISTER_GLOBAL("module._PackImportsToC")
.set_body_typed(PackImportsToC);
TVM_REGISTER_GLOBAL("module._PackImportsToLLVM")
.set_body_typed(PackImportsToLLVM);
} // namespace codegen
} // namespace tvm
......@@ -21,7 +21,7 @@
* Implementation of API functions related to Higher DSL build.
* \file api_lang.cc
*/
#include <tvm/tir/expr.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/te/tensor.h>
#include <tvm/te/operation.h>
......@@ -32,7 +32,6 @@
#include <tvm/driver/driver_api.h>
#include <tvm/tir/data_layout.h>
namespace tvm {
TVM_REGISTER_GLOBAL("_min_value")
......@@ -41,172 +40,6 @@ TVM_REGISTER_GLOBAL("_min_value")
TVM_REGISTER_GLOBAL("_max_value")
.set_body_typed(max_value);
TVM_REGISTER_GLOBAL("_const")
.set_body([](TVMArgs args, TVMRetValue* ret) {
if (args[0].type_code() == kDLInt) {
*ret = tir::make_const(args[1], args[0].operator int64_t());
} else if (args[0].type_code() == kDLFloat) {
*ret = tir::make_const(args[1], args[0].operator double());
} else {
LOG(FATAL) << "only accept int or float";
}
});
TVM_REGISTER_GLOBAL("_LargeUIntImm")
.set_body_typed(LargeUIntImm);
TVM_REGISTER_GLOBAL("_str")
.set_body_typed(tir::StringImmNode::make);
TVM_REGISTER_GLOBAL("_Array")
.set_body([](TVMArgs args, TVMRetValue* ret) {
std::vector<ObjectRef> data;
for (int i = 0; i < args.size(); ++i) {
if (args[i].type_code() != kTVMNullptr) {
data.push_back(args[i].operator ObjectRef());
} else {
data.push_back(ObjectRef(nullptr));
}
}
auto node = make_object<ArrayNode>();
node->data = std::move(data);
*ret = Array<ObjectRef>(node);
});
TVM_REGISTER_GLOBAL("_ArrayGetItem")
.set_body([](TVMArgs args, TVMRetValue* ret) {
int64_t i = args[1];
CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
Object* ptr = static_cast<Object*>(args[0].value().v_handle);
CHECK(ptr->IsInstance<ArrayNode>());
auto* n = static_cast<const ArrayNode*>(ptr);
CHECK_LT(static_cast<size_t>(i), n->data.size())
<< "out of bound of array";
*ret = n->data[static_cast<size_t>(i)];
});
TVM_REGISTER_GLOBAL("_ArraySize")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
Object* ptr = static_cast<Object*>(args[0].value().v_handle);
CHECK(ptr->IsInstance<ArrayNode>());
*ret = static_cast<int64_t>(
static_cast<const ArrayNode*>(ptr)->data.size());
});
TVM_REGISTER_GLOBAL("_Map")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args.size() % 2, 0);
if (args.size() != 0 && args[0].type_code() == kTVMStr) {
// StrMap
StrMapNode::ContainerType data;
for (int i = 0; i < args.num_args; i += 2) {
CHECK(args[i].type_code() == kTVMStr)
<< "key of str map need to be str";
CHECK(args[i + 1].IsObjectRef<ObjectRef>())
<< "value of the map to be NodeRef";
data.emplace(std::make_pair(args[i].operator std::string(),
args[i + 1].operator ObjectRef()));
}
auto node = make_object<StrMapNode>();
node->data = std::move(data);
*ret = Map<ObjectRef, ObjectRef>(node);
} else {
// Container node.
MapNode::ContainerType data;
for (int i = 0; i < args.num_args; i += 2) {
CHECK(args[i].IsObjectRef<ObjectRef>())
<< "key of str map need to be object";
CHECK(args[i + 1].IsObjectRef<ObjectRef>())
<< "value of map to be NodeRef";
data.emplace(std::make_pair(args[i].operator ObjectRef(),
args[i + 1].operator ObjectRef()));
}
auto node = make_object<MapNode>();
node->data = std::move(data);
*ret = Map<ObjectRef, ObjectRef>(node);
}
});
TVM_REGISTER_GLOBAL("_MapSize")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
Object* ptr = static_cast<Object*>(args[0].value().v_handle);
if (ptr->IsInstance<MapNode>()) {
auto* n = static_cast<const MapNode*>(ptr);
*ret = static_cast<int64_t>(n->data.size());
} else {
CHECK(ptr->IsInstance<StrMapNode>());
auto* n = static_cast<const StrMapNode*>(ptr);
*ret = static_cast<int64_t>(n->data.size());
}
});
TVM_REGISTER_GLOBAL("_MapGetItem")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
Object* ptr = static_cast<Object*>(args[0].value().v_handle);
if (ptr->IsInstance<MapNode>()) {
CHECK(args[1].type_code() == kTVMObjectHandle);
auto* n = static_cast<const MapNode*>(ptr);
auto it = n->data.find(args[1].operator ObjectRef());
CHECK(it != n->data.end())
<< "cannot find the corresponding key in the Map";
*ret = (*it).second;
} else {
CHECK(ptr->IsInstance<StrMapNode>());
auto* n = static_cast<const StrMapNode*>(ptr);
auto it = n->data.find(args[1].operator std::string());
CHECK(it != n->data.end())
<< "cannot find the corresponding key in the Map";
*ret = (*it).second;
}
});
TVM_REGISTER_GLOBAL("_MapCount")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
Object* ptr = static_cast<Object*>(args[0].value().v_handle);
if (ptr->IsInstance<MapNode>()) {
auto* n = static_cast<const MapNode*>(ptr);
CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
*ret = static_cast<int64_t>(
n->data.count(args[1].operator ObjectRef()));
} else {
CHECK(ptr->IsInstance<StrMapNode>());
auto* n = static_cast<const StrMapNode*>(ptr);
*ret = static_cast<int64_t>(
n->data.count(args[1].operator std::string()));
}
});
TVM_REGISTER_GLOBAL("_MapItems")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
Object* ptr = static_cast<Object*>(args[0].value().v_handle);
if (ptr->IsInstance<MapNode>()) {
auto* n = static_cast<const MapNode*>(ptr);
auto rkvs = make_object<ArrayNode>();
for (const auto& kv : n->data) {
rkvs->data.push_back(kv.first);
rkvs->data.push_back(kv.second);
}
*ret = Array<ObjectRef>(rkvs);
} else {
auto* n = static_cast<const StrMapNode*>(ptr);
auto rkvs = make_object<ArrayNode>();
for (const auto& kv : n->data) {
rkvs->data.push_back(tir::StringImmNode::make(kv.first));
rkvs->data.push_back(kv.second);
}
*ret = Array<ObjectRef>(rkvs);
}
});
TVM_REGISTER_GLOBAL("Range")
.set_body([](TVMArgs args, TVMRetValue* ret) {
if (args.size() == 1) {
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Expose container API to frontend.
* \file src/node/container.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/node/container.h>
#include <tvm/tir/expr.h>
namespace tvm {
TVM_REGISTER_GLOBAL("node.Array")
.set_body([](TVMArgs args, TVMRetValue* ret) {
std::vector<ObjectRef> data;
for (int i = 0; i < args.size(); ++i) {
if (args[i].type_code() != kTVMNullptr) {
data.push_back(args[i].operator ObjectRef());
} else {
data.push_back(ObjectRef(nullptr));
}
}
auto node = make_object<ArrayNode>();
node->data = std::move(data);
*ret = Array<ObjectRef>(node);
});
TVM_REGISTER_GLOBAL("node.ArrayGetItem")
.set_body([](TVMArgs args, TVMRetValue* ret) {
int64_t i = args[1];
CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
Object* ptr = static_cast<Object*>(args[0].value().v_handle);
CHECK(ptr->IsInstance<ArrayNode>());
auto* n = static_cast<const ArrayNode*>(ptr);
CHECK_LT(static_cast<size_t>(i), n->data.size())
<< "out of bound of array";
*ret = n->data[static_cast<size_t>(i)];
});
TVM_REGISTER_GLOBAL("node.ArraySize")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
Object* ptr = static_cast<Object*>(args[0].value().v_handle);
CHECK(ptr->IsInstance<ArrayNode>());
*ret = static_cast<int64_t>(
static_cast<const ArrayNode*>(ptr)->data.size());
});
TVM_REGISTER_GLOBAL("node.Map")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args.size() % 2, 0);
if (args.size() != 0 && args[0].type_code() == kTVMStr) {
// StrMap
StrMapNode::ContainerType data;
for (int i = 0; i < args.num_args; i += 2) {
CHECK(args[i].type_code() == kTVMStr)
<< "key of str map need to be str";
CHECK(args[i + 1].IsObjectRef<ObjectRef>())
<< "value of the map to be NodeRef";
data.emplace(std::make_pair(args[i].operator std::string(),
args[i + 1].operator ObjectRef()));
}
auto node = make_object<StrMapNode>();
node->data = std::move(data);
*ret = Map<ObjectRef, ObjectRef>(node);
} else {
// Container node.
MapNode::ContainerType data;
for (int i = 0; i < args.num_args; i += 2) {
CHECK(args[i].IsObjectRef<ObjectRef>())
<< "key of str map need to be object";
CHECK(args[i + 1].IsObjectRef<ObjectRef>())
<< "value of map to be NodeRef";
data.emplace(std::make_pair(args[i].operator ObjectRef(),
args[i + 1].operator ObjectRef()));
}
auto node = make_object<MapNode>();
node->data = std::move(data);
*ret = Map<ObjectRef, ObjectRef>(node);
}
});
TVM_REGISTER_GLOBAL("node.MapSize")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
Object* ptr = static_cast<Object*>(args[0].value().v_handle);
if (ptr->IsInstance<MapNode>()) {
auto* n = static_cast<const MapNode*>(ptr);
*ret = static_cast<int64_t>(n->data.size());
} else {
CHECK(ptr->IsInstance<StrMapNode>());
auto* n = static_cast<const StrMapNode*>(ptr);
*ret = static_cast<int64_t>(n->data.size());
}
});
TVM_REGISTER_GLOBAL("node.MapGetItem")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
Object* ptr = static_cast<Object*>(args[0].value().v_handle);
if (ptr->IsInstance<MapNode>()) {
CHECK(args[1].type_code() == kTVMObjectHandle);
auto* n = static_cast<const MapNode*>(ptr);
auto it = n->data.find(args[1].operator ObjectRef());
CHECK(it != n->data.end())
<< "cannot find the corresponding key in the Map";
*ret = (*it).second;
} else {
CHECK(ptr->IsInstance<StrMapNode>());
auto* n = static_cast<const StrMapNode*>(ptr);
auto it = n->data.find(args[1].operator std::string());
CHECK(it != n->data.end())
<< "cannot find the corresponding key in the Map";
*ret = (*it).second;
}
});
TVM_REGISTER_GLOBAL("node.MapCount")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
Object* ptr = static_cast<Object*>(args[0].value().v_handle);
if (ptr->IsInstance<MapNode>()) {
auto* n = static_cast<const MapNode*>(ptr);
CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
*ret = static_cast<int64_t>(
n->data.count(args[1].operator ObjectRef()));
} else {
CHECK(ptr->IsInstance<StrMapNode>());
auto* n = static_cast<const StrMapNode*>(ptr);
*ret = static_cast<int64_t>(
n->data.count(args[1].operator std::string()));
}
});
TVM_REGISTER_GLOBAL("node.MapItems")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
Object* ptr = static_cast<Object*>(args[0].value().v_handle);
if (ptr->IsInstance<MapNode>()) {
auto* n = static_cast<const MapNode*>(ptr);
auto rkvs = make_object<ArrayNode>();
for (const auto& kv : n->data) {
rkvs->data.push_back(kv.first);
rkvs->data.push_back(kv.second);
}
*ret = Array<ObjectRef>(rkvs);
} else {
auto* n = static_cast<const StrMapNode*>(ptr);
auto rkvs = make_object<ArrayNode>();
for (const auto& kv : n->data) {
rkvs->data.push_back(tir::StringImmNode::make(kv.first));
rkvs->data.push_back(kv.second);
}
*ret = Array<ObjectRef>(rkvs);
}
});
} // namespace tvm
......@@ -298,13 +298,12 @@ void MakeNode(const TVMArgs& args, TVMRetValue* rv) {
}
TVM_REGISTER_GLOBAL("_NodeGetAttr")
TVM_REGISTER_GLOBAL("node.NodeGetAttr")
.set_body(NodeGetAttr);
TVM_REGISTER_GLOBAL("_NodeListAttrNames")
TVM_REGISTER_GLOBAL("node.NodeListAttrNames")
.set_body(NodeListAttrNames);
TVM_REGISTER_GLOBAL("make._Node")
.set_body(MakeNode);
} // namespace tvm
......@@ -21,6 +21,7 @@
* Printer utilities
* \file node/repr_printer.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/node/repr_printer.h>
namespace tvm {
......@@ -53,4 +54,11 @@ ReprPrinter::FType& ReprPrinter::vtable() {
void Dump(const ObjectRef& n) {
std::cerr << n << "\n";
}
TVM_REGISTER_GLOBAL("node.AsRepr")
.set_body_typed([](runtime::ObjectRef obj) {
std::ostringstream os;
os << obj;
return os.str();
});
} // namespace tvm
......@@ -23,7 +23,7 @@
*/
#include <dmlc/json.h>
#include <dmlc/memory_io.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/node/container.h>
......@@ -455,4 +455,10 @@ ObjectRef LoadJSON(std::string json_str) {
}
return ObjectRef(nodes.at(jgraph.root));
}
TVM_REGISTER_GLOBAL("node.SaveJSON")
.set_body_typed(SaveJSON);
TVM_REGISTER_GLOBAL("node.LoadJSON")
.set_body_typed(LoadJSON);
} // namespace tvm
......@@ -194,7 +194,7 @@ class CSourceCodegen : public CSourceModuleCodegenBase {
}
// Create a CSourceModule
const auto* pf = runtime::Registry::Get("module.csource_module_create");
const auto* pf = runtime::Registry::Get("runtime.CSourceModuleCreate");
CHECK(pf != nullptr) << "Cannot find csource module to create the external runtime module";
return (*pf)(code_stream_.str(), "cc");
}
......
......@@ -282,7 +282,7 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase {
}
// Create a CSourceModule
const auto* pf = runtime::Registry::Get("module.csource_module_create");
const auto* pf = runtime::Registry::Get("runtime.CSourceModuleCreate");
CHECK(pf != nullptr) << "Cannot find csource module to create the external runtime module";
return (*pf)(code_stream_.str(), "cc");
}
......
......@@ -630,3 +630,7 @@ TVM_REGISTER_GLOBAL("_GetDeviceAttr")
DeviceAPIManager::Get(ctx)->GetAttr(ctx, kind, ret);
}
});
TVM_REGISTER_GLOBAL("runtime.TVMSetStream")
.set_body_typed(TVMSetStream);
......@@ -332,12 +332,12 @@ class ExampleJsonModule : public ModuleNode {
std::vector<std::string> op_id_;
};
TVM_REGISTER_GLOBAL("module.loadfile_examplejson")
TVM_REGISTER_GLOBAL("runtime.module.loadfile_examplejson")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = ExampleJsonModule::Create(args[0]);
});
TVM_REGISTER_GLOBAL("module.loadbinary_examplejson")
TVM_REGISTER_GLOBAL("runtime.module.loadbinary_examplejson")
.set_body_typed(ExampleJsonModule::LoadFromBinary);
} // namespace runtime
......
......@@ -305,13 +305,13 @@ Module CUDAModuleLoadBinary(void* strm) {
return CUDAModuleCreate(data, fmt, fmap, std::string());
}
TVM_REGISTER_GLOBAL("module.loadfile_cubin")
TVM_REGISTER_GLOBAL("runtime.module.loadfile_cubin")
.set_body_typed(CUDAModuleLoadFile);
TVM_REGISTER_GLOBAL("module.loadfile_ptx")
TVM_REGISTER_GLOBAL("runtime.module.loadfile_ptx")
.set_body_typed(CUDAModuleLoadFile);
TVM_REGISTER_GLOBAL("module.loadbinary_cuda")
TVM_REGISTER_GLOBAL("runtime.module.loadbinary_cuda")
.set_body_typed(CUDAModuleLoadBinary);
} // namespace runtime
} // namespace tvm
......@@ -97,7 +97,7 @@ class DSOLibrary final : public Library {
#endif
};
TVM_REGISTER_GLOBAL("module.loadfile_so")
TVM_REGISTER_GLOBAL("runtime.module.loadfile_so")
.set_body([](TVMArgs args, TVMRetValue* rv) {
auto n = make_object<DSOLibrary>();
n->Init(args[0]);
......
......@@ -148,7 +148,7 @@ runtime::Module ProcessModuleBlob(const char* mblob, ObjectPtr<Library> lib) {
CHECK(stream->Read(&import_tree_row_ptr));
CHECK(stream->Read(&import_tree_child_indices));
} else {
std::string fkey = "module.loadbinary_" + tkey;
std::string fkey = "runtime.module.loadbinary_" + tkey;
const PackedFunc* f = Registry::Get(fkey);
CHECK(f != nullptr)
<< "Loader of " << tkey << "("
......
......@@ -307,10 +307,10 @@ Module MetalModuleLoadBinary(void* strm) {
return MetalModuleCreate(data, fmt, fmap, "");
}
TVM_REGISTER_GLOBAL("module.loadfile_metal")
TVM_REGISTER_GLOBAL("runtime.module.loadfile_metal")
.set_body_typed(MetalModuleLoadFile);
TVM_REGISTER_GLOBAL("module.loadbinary_metal")
TVM_REGISTER_GLOBAL("runtime.module.loadbinary_metal")
.set_body_typed(MetalModuleLoadBinary);
} // namespace runtime
} // namespace tvm
......@@ -101,7 +101,7 @@ PackedFunc MicroModuleNode::GetFunction(
}
// register loadfile function to load module from Python frontend
TVM_REGISTER_GLOBAL("module.loadfile_micro_dev")
TVM_REGISTER_GLOBAL("runtime.module.loadfile_micro_dev")
.set_body([](TVMArgs args, TVMRetValue* rv) {
auto n = make_object<MicroModuleNode>();
n->InitMicroModule(args[0]);
......
......@@ -84,7 +84,7 @@ Module Module::LoadFromFile(const std::string& file_name,
if (fmt == "dll" || fmt == "dylib" || fmt == "dso") {
fmt = "so";
}
std::string load_f_name = "module.loadfile_" + fmt;
std::string load_f_name = "runtime.module.loadfile_" + fmt;
const PackedFunc* f = Registry::Get(load_f_name);
CHECK(f != nullptr)
<< "Loader of " << format << "("
......@@ -164,42 +164,35 @@ bool RuntimeEnabled(const std::string& target) {
return runtime::Registry::Get(f_name) != nullptr;
}
TVM_REGISTER_GLOBAL("module._Enabled")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = RuntimeEnabled(args[0]);
});
TVM_REGISTER_GLOBAL("module._GetSource")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator Module()->GetSource(args[1]);
});
TVM_REGISTER_GLOBAL("module._ImportsSize")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = static_cast<int64_t>(
args[0].operator Module()->imports().size());
});
TVM_REGISTER_GLOBAL("module._GetImport")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator Module()->
imports().at(args[1].operator int());
});
TVM_REGISTER_GLOBAL("module._GetTypeKey")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = std::string(args[0].operator Module()->type_key());
});
TVM_REGISTER_GLOBAL("module._LoadFromFile")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = Module::LoadFromFile(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("module._SaveToFile")
.set_body([](TVMArgs args, TVMRetValue *ret) {
args[0].operator Module()->
SaveToFile(args[1], args[2]);
});
TVM_REGISTER_GLOBAL("runtime.RuntimeEnabled")
.set_body_typed(RuntimeEnabled);
TVM_REGISTER_GLOBAL("runtime.ModuleGetSource")
.set_body_typed([](Module mod, std::string fmt) {
return mod->GetSource(fmt);
});
TVM_REGISTER_GLOBAL("runtime.ModuleImportsSize")
.set_body_typed([](Module mod) {
return static_cast<int64_t>(mod->imports().size());
});
TVM_REGISTER_GLOBAL("runtime.ModuleGetImport")
.set_body_typed([](Module mod, int index) {
return mod->imports().at(index);
});
TVM_REGISTER_GLOBAL("runtime.ModuleGetTypeKey")
.set_body_typed([](Module mod) {
return std::string(mod->type_key());
});
TVM_REGISTER_GLOBAL("runtime.ModuleLoadFromFile")
.set_body_typed(Module::LoadFromFile);
TVM_REGISTER_GLOBAL("runtime.ModuleSaveToFile")
.set_body_typed([](Module mod, std::string name, std::string fmt) {
mod->SaveToFile(name, fmt);
});
} // namespace runtime
} // namespace tvm
......@@ -21,6 +21,7 @@
* \brief Object type management system.
*/
#include <dmlc/logging.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/object.h>
#include <mutex>
#include <string>
......@@ -202,6 +203,11 @@ uint32_t Object::TypeKey2Index(const std::string& key) {
return TypeContext::Global()->TypeKey2Index(key);
}
TVM_REGISTER_GLOBAL("runtime.ObjectHash")
.set_body_typed([](ObjectRef obj) {
return static_cast<int64_t>(ObjectHash()(obj));
});
} // namespace runtime
} // namespace tvm
......
......@@ -66,7 +66,7 @@ Module AOCLModuleLoadFile(const std::string& file_name,
return AOCLModuleCreate(data, fmt, fmap, std::string());
}
TVM_REGISTER_GLOBAL("module.loadfile_aocx")
TVM_REGISTER_GLOBAL("runtime.module.loadfile_aocx")
.set_body_typed(AOCLModuleLoadFile);
} // namespace runtime
......
......@@ -278,13 +278,13 @@ Module OpenCLModuleLoadBinary(void* strm) {
return OpenCLModuleCreate(data, fmt, fmap, std::string());
}
TVM_REGISTER_GLOBAL("module.loadfile_cl")
TVM_REGISTER_GLOBAL("runtime.module.loadfile_cl")
.set_body_typed(OpenCLModuleLoadFile);
TVM_REGISTER_GLOBAL("module.loadfile_clbin")
TVM_REGISTER_GLOBAL("runtime.module.loadfile_clbin")
.set_body_typed(OpenCLModuleLoadFile);
TVM_REGISTER_GLOBAL("module.loadbinary_opencl")
TVM_REGISTER_GLOBAL("runtime.module.loadbinary_opencl")
.set_body_typed(OpenCLModuleLoadBinary);
} // namespace runtime
} // namespace tvm
......@@ -77,10 +77,10 @@ Module SDAccelModuleLoadBinary(void* strm) {
return SDAccelModuleCreate(data, fmt, fmap, std::string());
}
TVM_REGISTER_GLOBAL("module.loadfile_xclbin")
TVM_REGISTER_GLOBAL("runtime.module.loadfile_xclbin")
.set_body_typed(SDAccelModuleLoadFile);
TVM_REGISTER_GLOBAL("module.loadfile_awsxclbin")
TVM_REGISTER_GLOBAL("runtime.module.loadfile_awsxclbin")
.set_body_typed(SDAccelModuleLoadFile);
} // namespace runtime
} // namespace tvm
......@@ -278,17 +278,17 @@ Module OpenGLModuleLoadBinary(void* strm) {
return OpenGLModuleCreate(FromJSON(data), fmt, fmap);
}
TVM_REGISTER_GLOBAL("module.loadfile_gl")
TVM_REGISTER_GLOBAL("runtime.module.loadfile_gl")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = OpenGLModuleLoadFile(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("module.loadfile_glbin")
TVM_REGISTER_GLOBAL("runtime.module.loadfile_glbin")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = OpenGLModuleLoadFile(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("module.loadbinary_opengl")
TVM_REGISTER_GLOBAL("runtime.module.loadbinary_opengl")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = OpenGLModuleLoadBinary(args[0]);
});
......
......@@ -254,18 +254,18 @@ Module ROCMModuleLoadBinary(void* strm) {
}
TVM_REGISTER_GLOBAL("module.loadbinary_hsaco")
TVM_REGISTER_GLOBAL("runtime.module.loadbinary_hsaco")
.set_body_typed(ROCMModuleLoadBinary);
TVM_REGISTER_GLOBAL("module.loadbinary_hip")
TVM_REGISTER_GLOBAL("runtime.module.loadbinary_hip")
.set_body_typed(ROCMModuleLoadBinary);
TVM_REGISTER_GLOBAL("module.loadfile_hsaco")
TVM_REGISTER_GLOBAL("runtime.module.loadfile_hsaco")
.set_body_typed(ROCMModuleLoadFile);
TVM_REGISTER_GLOBAL("module.loadfile_hip")
TVM_REGISTER_GLOBAL("runtime.module.loadfile_hip")
.set_body_typed(ROCMModuleLoadFile);
} // namespace runtime
} // namespace tvm
......@@ -234,7 +234,7 @@ Module CreateRPCModule(std::shared_ptr<RPCSession> sess) {
return Module(n);
}
TVM_REGISTER_GLOBAL("module._RPCTimeEvaluator")
TVM_REGISTER_GLOBAL("runtime.RPCTimeEvaluator")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Module m = args[0];
std::string tkey = m->type_key();
......
......@@ -84,7 +84,7 @@ void tvm_ecall_packed_func(int func_id,
TVM_REGISTER_ENCLAVE_FUNC("__tvm_main__")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Module mod = (*Registry::Get("module._GetSystemLib"))();
Module mod = (*Registry::Get("runtime.SystemLib"))();
mod.GetFunction("default_function").CallPacked(args, rv);
});
......
......@@ -243,7 +243,7 @@ TVM_REGISTER_GLOBAL("__sgx_reserve_space__")
} // extern "C"
} // namespace sgx
TVM_REGISTER_GLOBAL("module.loadfile_sgx")
TVM_REGISTER_GLOBAL("runtime.module.loadfile_sgx")
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::shared_ptr<SGXModuleNode> node = std::make_shared<SGXModuleNode>();
node->Init(args[0]);
......
......@@ -106,7 +106,7 @@ class StackVMModuleNode : public runtime::ModuleNode {
for (uint64_t i = 0; i < num_imports; ++i) {
std::string tkey;
CHECK(strm->Read(&tkey));
std::string fkey = "module.loadbinary_" + tkey;
std::string fkey = "runtime.module.loadbinary_" + tkey;
const PackedFunc* f = Registry::Get(fkey);
CHECK(f != nullptr)
<< "Loader of " << tkey << "("
......@@ -137,7 +137,7 @@ Module StackVMModuleCreate(std::unordered_map<std::string, StackVM> fmap,
return StackVMModuleNode::Create(fmap, entry_func);
}
TVM_REGISTER_GLOBAL("module.loadfile_stackvm")
TVM_REGISTER_GLOBAL("runtime.module.loadfile_stackvm")
.set_body_typed(StackVMModuleNode::LoadFromFile);
} // namespace runtime
......
......@@ -68,12 +68,12 @@ class SystemLibrary : public Library {
std::unordered_map<std::string, void*> tbl_;
};
TVM_REGISTER_GLOBAL("module._GetSystemLib")
.set_body([](TVMArgs args, TVMRetValue* rv) {
TVM_REGISTER_GLOBAL("runtime.SystemLib")
.set_body_typed([]() {
static auto mod = CreateModuleFromLibrary(
SystemLibrary::Global());
*rv = mod;
});
return mod;
});
} // namespace runtime
} // namespace tvm
......
......@@ -1143,9 +1143,9 @@ Module VulkanModuleLoadBinary(void* strm) {
return VulkanModuleCreate(smap, fmap, "");
}
TVM_REGISTER_GLOBAL("module.loadfile_vulkan").set_body_typed(VulkanModuleLoadFile);
TVM_REGISTER_GLOBAL("runtime.module.loadfile_vulkan").set_body_typed(VulkanModuleLoadFile);
TVM_REGISTER_GLOBAL("module.loadbinary_vulkan").set_body_typed(VulkanModuleLoadBinary);
TVM_REGISTER_GLOBAL("runtime.module.loadbinary_vulkan").set_body_typed(VulkanModuleLoadBinary);
TVM_REGISTER_GLOBAL("device_api.vulkan").set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = VulkanDeviceAPI::Global().get();
......
......@@ -244,5 +244,21 @@ runtime::Module PackImportsToLLVM(const runtime::Module& mod,
return (*codegen_f)(blob_byte_array, system_lib, target_triple);
}
TVM_REGISTER_GLOBAL("codegen._Build")
.set_body([](TVMArgs args, TVMRetValue *ret) {
if (args[0].IsObjectRef<tir::LoweredFunc>()) {
*ret = Build({args[0]}, args[1]);
} else {
*ret = Build(args[0], args[1]);
}
});
// Export two auxiliary function to the runtime namespace.
TVM_REGISTER_GLOBAL("runtime.ModulePackImportsToC")
.set_body_typed(PackImportsToC);
TVM_REGISTER_GLOBAL("runtime.ModulePackImportsToLLVM")
.set_body_typed(PackImportsToLLVM);
} // namespace codegen
} // namespace tvm
......@@ -368,7 +368,7 @@ TVM_REGISTER_GLOBAL("codegen.llvm_version_major")
*rv = major;
});
TVM_REGISTER_GLOBAL("module.loadfile_ll")
TVM_REGISTER_GLOBAL("runtime.module.loadfile_ll")
.set_body([](TVMArgs args, TVMRetValue* rv) {
auto n = make_object<LLVMModuleNode>();
n->LoadIR(args[0]);
......
......@@ -184,10 +184,10 @@ runtime::Module DeviceSourceModuleCreate(
return runtime::Module(n);
}
TVM_REGISTER_GLOBAL("module.source_module_create")
TVM_REGISTER_GLOBAL("runtime.SourceModuleCreate")
.set_body_typed(SourceModuleCreate);
TVM_REGISTER_GLOBAL("module.csource_module_create")
TVM_REGISTER_GLOBAL("runtime.CSourceModuleCreate")
.set_body_typed(CSourceModuleCreate);
} // namespace codegen
} // namespace tvm
......@@ -21,6 +21,7 @@
* \file expr_operator.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <cmath>
......@@ -632,4 +633,23 @@ PrimExpr trunc(PrimExpr x) {
return tir::CallNode::make(x.dtype(), "trunc", {x}, tir::CallNode::PureIntrinsic);
}
// expose basic functions to node namespace
TVM_REGISTER_GLOBAL("node._const")
.set_body([](TVMArgs args, TVMRetValue* ret) {
if (args[0].type_code() == kDLInt) {
*ret = tir::make_const(args[1], args[0].operator int64_t());
} else if (args[0].type_code() == kDLFloat) {
*ret = tir::make_const(args[1], args[0].operator double());
} else {
LOG(FATAL) << "only accept int or float";
}
});
TVM_REGISTER_GLOBAL("node.LargeUIntImm")
.set_body_typed(LargeUIntImm);
TVM_REGISTER_GLOBAL("node.String")
.set_body_typed(tir::StringImmNode::make);
} // namespace tvm
......@@ -76,8 +76,7 @@ TEST(BuildModule, Heterogeneous) {
using namespace tvm;
using namespace tvm::te;
const runtime::PackedFunc* pf = runtime::Registry::Get("module._Enabled");
bool enabled = (*pf)("cuda");
bool enabled = tvm::runtime::RuntimeEnabled("cuda");
if (!enabled) {
LOG(INFO) << "Skip heterogeneous test because cuda is not enabled."
<< "\n";
......
......@@ -235,7 +235,7 @@ TEST(PackedFunc, ObjectConversion) {
pf1(ObjectRef(x), NDArray());
// testcases for modules
auto* pf = tvm::runtime::Registry::Get("module.source_module_create");
auto* pf = tvm::runtime::Registry::Get("runtime.SourceModuleCreate");
CHECK(pf != nullptr);
Module m = (*pf)("", "xyz");
rv = m;
......
......@@ -37,7 +37,7 @@ def verify_matmul_add(m, l, n, transa=False, transb=False, dtype=tvm.float32):
return np.dot(a, b) + bb
def verify(target="llvm"):
if not tvm.module.enabled(target):
if not tvm.runtime.enabled(target):
print("skip because %s is not enabled..." % target)
return
if not tvm.get_global_func("tvm.contrib.cblas.matmul", True):
......@@ -81,7 +81,7 @@ def verify_batch_matmul(batch, m, l, n, transa=False, transb=False, iterative=Fa
return topi.testing.batch_matmul(a, b)
def verify(target="llvm"):
if not tvm.module.enabled(target):
if not tvm.runtime.enabled(target):
print("skip because %s is not enabled..." % target)
return
if not tvm.get_global_func("tvm.contrib.cblas.matmul", True):
......
......@@ -29,7 +29,7 @@ def verify_matmul_add(in_dtype, out_dtype, rtol=1e-5):
s = tvm.create_schedule(C.op)
def verify(target="cuda"):
if not tvm.module.enabled(target):
if not tvm.runtime.enabled(target):
print("skip because %s is not enabled..." % target)
return
if not tvm.get_global_func("tvm.contrib.cublas.matmul", True):
......@@ -63,7 +63,7 @@ def verify_matmul_add_igemm(in_dtype, out_dtype, rtol=1e-5):
s = tvm.create_schedule(C.op)
def verify(target="cuda"):
if not tvm.module.enabled(target):
if not tvm.runtime.enabled(target):
print("skip because %s is not enabled..." % target)
return
if not tvm.get_global_func("tvm.contrib.cublaslt.matmul", True):
......@@ -114,7 +114,7 @@ def verify_batch_matmul(in_dtype, out_dtype, rtol=1e-5):
s = tvm.create_schedule(C.op)
def verify(target="cuda"):
if not tvm.module.enabled(target):
if not tvm.runtime.enabled(target):
print("skip because %s is not enabled..." % target)
return
if not tvm.get_global_func("tvm.contrib.cublas.matmul", True):
......
......@@ -35,7 +35,7 @@ def verify_conv2d(data_dtype, conv_dtype, tensor_format=0):
height = 32
weight = 32
if not tvm.module.enabled("cuda"):
if not tvm.runtime.enabled("cuda"):
print("skip because cuda is not enabled...")
return
if not tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape", True):
......@@ -110,7 +110,7 @@ def verify_conv3d(data_dtype, conv_dtype, tensor_format=0):
height = 32
weight = 32
if not tvm.module.enabled("cuda"):
if not tvm.runtime.enabled("cuda"):
print("skip because cuda is not enabled...")
return
if not tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape", True):
......
......@@ -33,7 +33,7 @@ def benchmark_fc_int8_acc16():
print("Peak {} Gops/s \n".format(peak))
def verify(target="llvm -mcpu=skylake-avx512"):
if not tvm.module.enabled(target):
if not tvm.runtime.enabled(target):
print("skip because %s is not enabled..." % target)
return
......
......@@ -41,7 +41,7 @@ def test_fc_int8_acc32():
# (ignoring processor)" error with the following setting. After LLVM 8.0 is enabled in the
# test, we should use cascadelake setting.
def verify(target="llvm -mcpu=cascadelake"):
if not tvm.module.enabled(target):
if not tvm.runtime.enabled(target):
print("skip because %s is not enabled..." % target)
return
......
......@@ -32,7 +32,7 @@ def test_conv2d():
dilation_w = 1
xshape = [1, in_channel, 128, 128]
if not tvm.module.enabled("rocm"):
if not tvm.runtime.enabled("rocm"):
print("skip because rocm is not enabled...")
return
if not tvm.get_global_func("tvm.contrib.miopen.conv2d.setup", True):
......
......@@ -19,7 +19,7 @@ import numpy as np
from tvm.contrib import mps
def test_matmul():
if not tvm.module.enabled("metal"):
if not tvm.runtime.enabled("metal"):
print("skip because %s is not enabled..." % "metal")
return
n = 1024
......@@ -62,7 +62,7 @@ def test_matmul():
verify(A, B, D, s)
def test_conv2d():
if not tvm.module.enabled("metal"):
if not tvm.runtime.enabled("metal"):
print("skip because %s is not enabled..." % "metal")
return
n = 1
......
......@@ -34,7 +34,7 @@ def test_fully_connected_inference():
s = tvm.create_schedule(D.op)
def verify(target="llvm"):
if not tvm.module.enabled(target):
if not tvm.runtime.enabled(target):
pytest.skip("%s is not enabled..." % target)
if not tvm.get_global_func("tvm.contrib.nnpack.fully_connected_inference", True):
pytest.skip("extern function is not available")
......@@ -104,7 +104,7 @@ def test_convolution_inference():
def verify(target="llvm",
algorithm=nnpack.ConvolutionAlgorithm.AUTO,
with_bias=True):
if not tvm.module.enabled(target):
if not tvm.runtime.enabled(target):
pytest.skip("%s is not enabled..." % target)
if not tvm.get_global_func("tvm.contrib.nnpack.fully_connected_inference", True):
pytest.skip("extern function is not available")
......@@ -166,7 +166,7 @@ def test_convolution_inference_without_weight_transform():
def verify(target="llvm",
algorithm=nnpack.ConvolutionAlgorithm.AUTO,
with_bias=True):
if not tvm.module.enabled(target):
if not tvm.runtime.enabled(target):
pytest.skip("%s is not enabled..." % target)
if not tvm.get_global_func("tvm.contrib.nnpack.fully_connected_inference", True):
pytest.skip("extern function is not available")
......
......@@ -25,7 +25,7 @@ def test_randint():
s = tvm.create_schedule(A.op)
def verify(target="llvm"):
if not tvm.module.enabled(target):
if not tvm.runtime.enabled(target):
print("skip because %s is not enabled..." % target)
return
if not tvm.get_global_func("tvm.contrib.random.randint", True):
......@@ -49,7 +49,7 @@ def test_uniform():
s = tvm.create_schedule(A.op)
def verify(target="llvm"):
if not tvm.module.enabled(target):
if not tvm.runtime.enabled(target):
print("skip because %s is not enabled..." % target)
return
if not tvm.get_global_func("tvm.contrib.random.uniform", True):
......@@ -73,7 +73,7 @@ def test_normal():
s = tvm.create_schedule(A.op)
def verify(target="llvm"):
if not tvm.module.enabled(target):
if not tvm.runtime.enabled(target):
print("skip because %s is not enabled..." % target)
return
if not tvm.get_global_func("tvm.contrib.random.normal", True):
......
......@@ -28,7 +28,7 @@ def test_matmul_add():
s = tvm.create_schedule(C.op)
def verify(target="rocm"):
if not tvm.module.enabled(target):
if not tvm.runtime.enabled(target):
print("skip because %s is not enabled..." % target)
return
if not tvm.get_global_func("tvm.contrib.rocblas.matmul", True):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment