Unverified Commit fc7dd6d7 by Tianqi Chen Committed by GitHub

[REFACTOR][PY] Establish tvm.runtime (#4818)

* [REFACTOR][PY] Establish tvm.runtime

This PR establishes the tvm.runtime namespace that contains the core runtime data structures.
The top-level API are kept inact for now via re-exporting.

We will followup later to cleanup some of the top-level APIs.

* Fix ndarray name
parent 7d263c31
......@@ -15,19 +15,29 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=redefined-builtin, wildcard-import
"""TVM: Low level DSL/IR stack for tensor computation."""
"""TVM: Open Deep Learning Compiler Stack."""
import multiprocessing
import sys
import traceback
# import ffi related features
# top-level alias
# tvm._ffi
from ._ffi.base import TVMError, __version__
from ._ffi.runtime_ctypes import TypeCode, TVMType
from ._ffi.ndarray import TVMContext
from ._ffi.packed_func import PackedFunc as Function
from ._ffi.runtime_ctypes import TypeCode, DataType
from ._ffi.registry import register_object, register_func, register_extension
from ._ffi.object import Object
# top-level alias
# tvm.runtime
from .runtime.object import Object
from .runtime.packed_func import PackedFunc as Function
from .runtime.ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl
from .runtime.ndarray import vpi, rocm, opengl, ext_dev, micro_dev
from .runtime import module
from .runtime import ndarray
# pylint: disable=reimported
from .runtime import ndarray as nd
# others
from . import tensor
from . import arith
from . import expr
......@@ -37,7 +47,7 @@ from . import ir_pass
from . import codegen
from . import container
from . import schedule
from . import module
from . import attrs
from . import ir_builder
from . import target
......@@ -47,9 +57,6 @@ from . import testing
from . import error
from . import datatype
from . import ndarray as nd
from .ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl
from .ndarray import vpi, rocm, opengl, ext_dev, micro_dev
from .api import *
from .intrin import *
......
......@@ -23,7 +23,7 @@ from numbers import Number, Integral
from ..base import _LIB, get_last_ffi_error, py2cerror, check_call
from ..base import c_str, string_types
from ..runtime_ctypes import TVMType, TVMByteArray, TVMContext
from ..runtime_ctypes import DataType, TVMByteArray, TVMContext
from . import ndarray as _nd
from .ndarray import NDArrayBase, _make_array
from .types import TVMValue, TypeCode
......@@ -132,7 +132,7 @@ def _make_tvm_args(args, temp_args):
elif isinstance(arg, Number):
values[i].v_float64 = arg
type_codes[i] = TypeCode.FLOAT
elif isinstance(arg, TVMType):
elif isinstance(arg, DataType):
values[i].v_str = c_str(str(arg))
type_codes[i] = TypeCode.STR
elif isinstance(arg, TVMContext):
......
......@@ -20,7 +20,7 @@ import traceback
from cpython cimport Py_INCREF, Py_DECREF
from numbers import Number, Integral
from ..base import string_types, py2cerror
from ..runtime_ctypes import TVMType, TVMContext, TVMByteArray
from ..runtime_ctypes import DataType, TVMContext, TVMByteArray
cdef void tvm_callback_finalize(void* fhandle):
......@@ -129,7 +129,7 @@ cdef inline int make_arg(object arg,
elif isinstance(arg, Number):
value[0].v_float64 = arg
tcode[0] = kFloat
elif isinstance(arg, TVMType):
elif isinstance(arg, DataType):
tstr = c_str(str(arg))
value[0].v_str = tstr
tcode[0] = kTVMStr
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-import
"""Runtime Module namespace."""
import ctypes
from .base import _LIB, check_call, c_str, string_types
from .packed_func import PackedFunc, PackedFuncHandle, _set_class_module
class ModuleBase(object):
"""Base class for module"""
__slots__ = ["handle", "_entry", "entry_name"]
def __init__(self, handle):
self.handle = handle
self._entry = None
self.entry_name = "__tvm_main__"
def __del__(self):
check_call(_LIB.TVMModFree(self.handle))
def __hash__(self):
return ctypes.cast(self.handle, ctypes.c_void_p).value
@property
def entry_func(self):
"""Get the entry function
Returns
-------
f : Function
The entry function if exist
"""
if self._entry:
return self._entry
self._entry = self.get_function(self.entry_name)
return self._entry
def get_function(self, name, query_imports=False):
"""Get function from the module.
Parameters
----------
name : str
The name of the function
query_imports : bool
Whether also query modules imported by this module.
Returns
-------
f : Function
The result function.
"""
ret_handle = PackedFuncHandle()
check_call(_LIB.TVMModGetFunction(
self.handle, c_str(name),
ctypes.c_int(query_imports),
ctypes.byref(ret_handle)))
if not ret_handle.value:
raise AttributeError(
"Module has no function '%s'" % name)
return PackedFunc(ret_handle, False)
def import_module(self, module):
"""Add module to the import list of current one.
Parameters
----------
module : Module
The other module.
"""
check_call(_LIB.TVMModImport(self.handle, module.handle))
def __getitem__(self, name):
if not isinstance(name, string_types):
raise ValueError("Can only take string as function name")
return self.get_function(name)
def __call__(self, *args):
if self._entry:
return self._entry(*args)
f = self.entry_func
return f(*args)
......@@ -48,7 +48,7 @@ class TVMByteArray(ctypes.Structure):
_fields_ = [("data", ctypes.POINTER(ctypes.c_byte)),
("size", ctypes.c_size_t)]
class TVMType(ctypes.Structure):
class DataType(ctypes.Structure):
"""TVM datatype structure"""
_fields_ = [("type_code", ctypes.c_uint8),
("bits", ctypes.c_uint8),
......@@ -60,7 +60,7 @@ class TVMType(ctypes.Structure):
4 : 'handle'
}
def __init__(self, type_str):
super(TVMType, self).__init__()
super(DataType, self).__init__()
if isinstance(type_str, np.dtype):
type_str = str(type_str)
......@@ -104,8 +104,8 @@ class TVMType(ctypes.Structure):
def __repr__(self):
if self.bits == 1 and self.lanes == 1:
return "bool"
if self.type_code in TVMType.CODE2STR:
type_name = TVMType.CODE2STR[self.type_code]
if self.type_code in DataType.CODE2STR:
type_name = DataType.CODE2STR[self.type_code]
else:
type_name = "custom[%s]" % \
_api_internal._datatype_get_type_name(self.type_code)
......@@ -263,7 +263,7 @@ class TVMArray(ctypes.Structure):
_fields_ = [("data", ctypes.c_void_p),
("ctx", TVMContext),
("ndim", ctypes.c_int),
("dtype", TVMType),
("dtype", DataType),
("shape", ctypes.POINTER(tvm_shape_index_t)),
("strides", ctypes.POINTER(tvm_shape_index_t)),
("byte_offset", ctypes.c_uint64)]
......
......@@ -20,10 +20,10 @@ from numbers import Integral as _Integral
import tvm._ffi
from tvm.runtime import convert, const, DataType
from ._ffi.base import string_types, TVMError
from ._ffi.object_generic import convert, const
from ._ffi.registry import register_func, get_global_func, extract_ext_funcs
from ._ffi.runtime_ctypes import TVMType
from . import _api_internal
from . import make as _make
from . import expr as _expr
......
......@@ -15,10 +15,9 @@
# specific language governing permissions and limitations
# under the License.
"""Arithmetic data structure and utility"""
from __future__ import absolute_import as _abs
import tvm._ffi
from tvm.runtime import Object
from ._ffi.object import Object
from . import _api_internal
class IntSet(Object):
......
......@@ -17,7 +17,7 @@
""" TVM Attribute module, which is mainly used for defining attributes of operators"""
import tvm._ffi
from ._ffi.object import Object
from tvm.runtime import Object
from . import _api_internal
......
......@@ -22,7 +22,7 @@ LoweredFunc and compiled Module.
import warnings
import tvm._ffi
from ._ffi.object import Object
from tvm.runtime import Object, ndarray
from . import api
from . import _api_internal
from . import tensor
......@@ -33,7 +33,6 @@ from . import stmt as _stmt
from . import container
from . import module
from . import codegen
from . import ndarray
from . import target as _target
from . import make
......
......@@ -17,9 +17,9 @@
"""Container data structures used in TVM DSL."""
import tvm._ffi
from tvm import ndarray as _nd
from tvm.runtime import Object, ObjectTypes
from tvm.runtime.container import getitem_helper
from . import _api_internal
from ._ffi.object import Object, getitem_helper
@tvm._ffi.register_object
......@@ -31,23 +31,9 @@ class Array(Object):
to Array during tvm function call.
You may get Array in return values of TVM function call.
"""
def __getitem__(self, i):
if isinstance(i, slice):
start = i.start if i.start is not None else 0
stop = i.stop if i.stop is not None else len(self)
step = i.step if i.step is not None else 1
if start < 0:
start += len(self)
if stop < 0:
stop += len(self)
return [self[idx] for idx in range(start, stop, step)]
if i < -len(self) or i >= len(self):
raise IndexError("Array index out of range. Array size: {}, got index {}"
.format(len(self), i))
if i < 0:
i += len(self)
return _api_internal._ArrayGetItem(self, i)
def __getitem__(self, idx):
return getitem_helper(
self, _api_internal._ArrayGetItem, len(self), idx)
def __len__(self):
return _api_internal._ArraySize(self)
......@@ -133,7 +119,7 @@ class ADT(Object):
"""
def __init__(self, tag, fields):
for f in fields:
assert isinstance(f, (Object, _nd.NDArray)), "Expect object or " \
assert isinstance(f, ObjectTypes), "Expect object or " \
"tvm NDArray type, but received : {0}".format(type(f))
self.__init_handle_by_constructor__(_ADT, tag, *fields)
......@@ -164,7 +150,7 @@ def tuple_object(fields=None):
"""
fields = fields if fields else []
for f in fields:
assert isinstance(f, (Object, _nd.NDArray)), "Expect object or tvm " \
assert isinstance(f, ObjectTypes), "Expect object or tvm " \
"NDArray type, but received : {0}".format(type(f))
return _Tuple(*fields)
......
......@@ -23,7 +23,7 @@ import tvm._ffi
from tvm._ffi.base import string_types
from tvm.contrib import graph_runtime
from tvm.ndarray import array
from tvm.runtime.ndarray import array
from . import debug_result
_DUMP_ROOT_PREFIX = "tvmdbg_"
......
......@@ -21,8 +21,9 @@ from __future__ import absolute_import as _abs
import subprocess
import os
import warnings
from tvm.runtime import ndarray as nd
from . import util
from .. import ndarray as nd
from ..api import register_func
from .._ffi.base import py_str
......
......@@ -20,7 +20,7 @@ import tvm._ffi
from . import make as _make
from .api import convert
from .expr import Call as _Call, Cast as _Cast, FloatImm as _FloatImm
from ._ffi.runtime_ctypes import TVMType as _TVMType
from ._ffi.runtime_ctypes import DataType
from . import _api_internal
......@@ -131,7 +131,7 @@ def create_lower_func(extern_func_name):
width as the custom type is returned. Otherwise, the type is
unchanged."""
dtype = op.dtype
t = _TVMType(dtype)
t = DataType(dtype)
if get_type_registered(t.type_code):
dtype = "uint" + str(t.bits)
if t.lanes > 1:
......
......@@ -31,12 +31,9 @@ For example, you can use addexp.a to get the left operand of an Add node.
assert(y.a == x)
"""
# pylint: disable=missing-docstring
from __future__ import absolute_import as _abs
import tvm._ffi
from tvm.runtime import Object, ObjectGeneric, DataType, TypeCode
from ._ffi.object import Object
from ._ffi.object_generic import ObjectGeneric
from ._ffi.runtime_ctypes import TVMType, TypeCode
from . import make as _make
from . import generic as _generic
from . import _api_internal
......@@ -52,7 +49,7 @@ def _dtype_is_int(value):
if isinstance(value, int):
return True
return (isinstance(value, ExprOp) and
TVMType(value.dtype).type_code == TypeCode.INT)
DataType(value.dtype).type_code == TypeCode.INT)
class ExprOp(object):
......
......@@ -15,7 +15,9 @@
# specific language governing permissions and limitations
# under the License.
"""Developer API of IR node builder make function."""
from __future__ import absolute_import as _abs
from tvm.runtime import ObjectGeneric, DataType
from ._ffi.base import string_types
from . import api as _api
from . import stmt as _stmt
......@@ -23,9 +25,6 @@ from . import expr as _expr
from . import make as _make
from . import ir_pass as _pass
from . import container as _container
from ._ffi.base import string_types
from ._ffi.object_generic import ObjectGeneric
from ._ffi.runtime_ctypes import TVMType
from .expr import Call as _Call
class WithScope(object):
......@@ -78,7 +77,7 @@ class BufferVar(ObjectGeneric):
return self._content_type
def __getitem__(self, index):
t = TVMType(self._content_type)
t = DataType(self._content_type)
if t.lanes > 1:
index = _make.Ramp(index * t.lanes, 1, t.lanes)
return _make.Load(self._content_type, self._buffer_var, index)
......@@ -89,7 +88,7 @@ class BufferVar(ObjectGeneric):
raise ValueError(
"data type does not match content type %s vs %s" % (
value.dtype, self._content_type))
t = TVMType(self._content_type)
t = DataType(self._content_type)
if t.lanes > 1:
index = _make.Ramp(index * t.lanes, 1, t.lanes)
self._builder.emit(_make.Store(self._buffer_var, value, index))
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""TVM Runtime NDArray API.
tvm.ndarray provides a minimum runtime array API to test
the correctness of the program.
"""
# pylint: disable=invalid-name,unused-import
import tvm._ffi
import numpy as _np
from ._ffi.ndarray import TVMContext, TVMType, NDArrayBase
from ._ffi.ndarray import context, empty, from_dlpack
from ._ffi.ndarray import _set_class_ndarray
@tvm._ffi.register_object
class NDArray(NDArrayBase):
"""Lightweight NDArray class of TVM runtime.
Strictly this is only an Array Container (a buffer object)
No arthimetic operations are defined.
All operations are performed by TVM functions.
The goal is not to re-build yet another array library.
Instead, this is a minimal data structure to demonstrate
how can we use TVM in existing project which might have their own array containers.
"""
def cpu(dev_id=0):
"""Construct a CPU device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(1, dev_id)
def gpu(dev_id=0):
"""Construct a CPU device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(2, dev_id)
def rocm(dev_id=0):
"""Construct a ROCM device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(10, dev_id)
def opencl(dev_id=0):
"""Construct a OpenCL device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(4, dev_id)
def metal(dev_id=0):
"""Construct a metal device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(8, dev_id)
def vpi(dev_id=0):
"""Construct a VPI simulated device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(9, dev_id)
def vulkan(dev_id=0):
"""Construct a Vulkan device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(7, dev_id)
def opengl(dev_id=0):
"""Construct a OpenGL device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(11, dev_id)
def ext_dev(dev_id=0):
"""Construct a extension device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
Note
----
This API is reserved for quick testing of new
device by plugin device API as ext_dev.
"""
return TVMContext(12, dev_id)
def micro_dev(dev_id=0):
"""Construct a micro device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(13, dev_id)
cl = opencl
mtl = metal
def array(arr, ctx=cpu(0)):
"""Create an array from source arr.
Parameters
----------
arr : numpy.ndarray
The array to be copied from
ctx : TVMContext, optional
The device context to create the array
Returns
-------
ret : NDArray
The created array
"""
if not isinstance(arr, (_np.ndarray, NDArray)):
arr = _np.array(arr)
return empty(arr.shape, arr.dtype, ctx).copyfrom(arr)
_set_class_ndarray(NDArray)
......@@ -33,9 +33,7 @@ To connect to the graph runtime, we use a printer that converts our graph format
into TVM's JSON format. The resulting string can be loaded by
contrib.graph_runtime or any other TVM runtime compatible systems.
"""
from __future__ import absolute_import
from tvm.ndarray import empty
from tvm.runtime.ndarray import empty
from tvm.relay import _build_module
from tvm import target as _target
from tvm import expr as _expr
......
......@@ -23,9 +23,9 @@ Implements a Python interface to compiling and executing on the Relay VM.
import numpy as np
import tvm
import tvm.ndarray as _nd
import tvm.runtime.ndarray as _nd
from tvm.runtime import Object
from tvm import autotvm, container
from tvm._ffi.object import Object
from tvm.relay import expr as _expr
from tvm._ffi.runtime_ctypes import TVMByteArray
from tvm._ffi import base as _base
......
......@@ -18,12 +18,11 @@
"""The base node types for the Relay language."""
import tvm._ffi
from .._ffi.object import Object
from tvm.runtime import Object
from . import _make
from . import _expr
from . import _base
Object = Object
def register_relay_node(type_key=None):
"""Register a Relay node type.
......
......@@ -20,14 +20,14 @@ from __future__ import absolute_import
from numbers import Number as _Number
import numpy as _np
from tvm._ffi import base as _base
from tvm.runtime import NDArray, convert, ndarray as _nd
from .base import RelayNode, register_relay_node
from . import _make
from . import _expr
from . import ty as _ty
from .._ffi import base as _base
from .. import nd as _nd
from .. import convert
from ..ndarray import NDArray
# will be registered afterwards
_op_make = None
......
......@@ -23,7 +23,7 @@ from .expr_functor import ExprMutator
from .scope_builder import ScopeBuilder
from . import transform
from . import op, ty, expr
from .. import TVMType, register_func
from .. import DataType, register_func
from .backend import compile_engine
......@@ -109,7 +109,7 @@ class ManifestAllocPass(ExprMutator):
return expr.Tuple(new_fields)
def compute_alignment(self, dtype):
dtype = TVMType(dtype)
dtype = DataType(dtype)
align = (dtype.bits // 8) * dtype.lanes
# MAGIC CONSTANT FROM device_api.h
if align < 64:
......@@ -118,7 +118,7 @@ class ManifestAllocPass(ExprMutator):
return expr.const(align, dtype="int64")
def compute_storage_in_relay(self, shape, dtype):
dtype = TVMType(dtype)
dtype = DataType(dtype)
els = op.prod(shape)
num = expr.const(dtype.bits * dtype.lanes, self.compute_dtype)
num = num + expr.const(7, self.compute_dtype)
......@@ -126,7 +126,7 @@ class ManifestAllocPass(ExprMutator):
return els * (num / div)
def compute_storage(self, tensor_type):
dtype = TVMType(tensor_type.dtype)
dtype = DataType(tensor_type.dtype)
shape = [int(sh) for sh in tensor_type.shape]
size = 1
for sh in shape:
......
......@@ -15,11 +15,12 @@
# specific language governing permissions and limitations
# under the License.
"""Annotation operations."""
from __future__ import absolute_import as _abs
from tvm.runtime import ndarray as _nd
from tvm.runtime import TVMContext as _TVMContext
from . import _make
from ..op import register_schedule, schedule_injective
from .... import nd as _nd
from .... import TVMContext as _TVMContext
def on_device(data, device):
"""Annotate an expression with a certain device type.
......
......@@ -16,11 +16,12 @@
# under the License.
"""Basic tensor operations."""
# pylint: disable=redefined-builtin
from __future__ import absolute_import as _abs
from tvm.runtime import ndarray as _nd
from tvm.runtime import TVMContext as _TVMContext
from . import _make
from ..expr import Tuple
from ... import nd as _nd
from ... import TVMContext as _TVMContext
# We create a wrapper function for each operator in the
# python side to call into the positional _make.OpName function.
......
......@@ -22,12 +22,12 @@ import socket
import struct
import time
import tvm._ffi
from tvm.contrib import util
from tvm._ffi.base import TVMError
from tvm.runtime import ndarray as nd
from tvm.runtime import load_module as _load_module
from . import base
from ..contrib import util
from .._ffi.base import TVMError
from .._ffi import ndarray as nd
from ..module import load as _load_module
class RPCSession(object):
......
......@@ -38,10 +38,10 @@ import sys
import signal
import tvm._ffi
from .._ffi.base import py_str
from .._ffi.libinfo import find_lib_path
from ..module import load as _load_module
from ..contrib import util
from tvm._ffi.base import py_str
from tvm._ffi.libinfo import find_lib_path
from tvm.runtime.module import load as _load_module
from tvm.contrib import util
from . import base
from . base import TrackerCode
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""TVM runtime."""
# class exposures
from .packed_func import PackedFunc
from .object import Object
from .object_generic import ObjectGeneric, ObjectTypes
from .ndarray import NDArray, DataType, TypeCode, TVMContext
from .module import Module
# function exposures
from .object_generic import convert_to_object, convert, const
from .ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl
from .ndarray import vpi, rocm, opengl, ext_dev, micro_dev
from .module import load as load_module
DataType = DataType
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Runtime container structures."""
def getitem_helper(obj, elem_getter, length, idx):
"""Helper function to implement a pythonic getitem function.
Parameters
----------
obj: object
The original object
elem_getter : function
A simple function that takes index and return a single element.
length : int
The size of the array
idx : int or slice
The argument passed to getitem
Returns
-------
result : object
The result of getitem
"""
if isinstance(idx, slice):
start = idx.start if idx.start is not None else 0
stop = idx.stop if idx.stop is not None else length
step = idx.step if idx.step is not None else 1
if start < 0:
start += length
if stop < 0:
stop += length
return [elem_getter(obj, i) for i in range(start, stop, step)]
if idx < -length or idx >= length:
raise IndexError("Index out of range. size: {}, got index {}"
.format(length, idx))
if idx < 0:
idx += length
return elem_getter(obj, idx)
......@@ -14,22 +14,99 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Container of compiled functions of TVM."""
from __future__ import absolute_import as _abs
# pylint: disable=invalid-name, unused-import
"""Runtime Module namespace."""
import ctypes
import struct
from collections import namedtuple
import tvm._ffi
from tvm._ffi.base import _LIB, check_call, c_str, string_types, _RUNTIME_ONLY
from tvm._ffi.libinfo import find_include_path
from .packed_func import PackedFunc, PackedFuncHandle, _set_class_module
from ._ffi.module import ModuleBase, _set_class_module
from ._ffi.libinfo import find_include_path
from .contrib import cc as _cc, tar as _tar, util as _util
# profile result of time evaluator
ProfileResult = namedtuple("ProfileResult", ["mean", "results"])
class Module(ModuleBase):
"""Module container of all TVM generated functions"""
class Module(object):
"""Runtime Module."""
__slots__ = ["handle", "_entry", "entry_name"]
def __init__(self, handle):
self.handle = handle
self._entry = None
self.entry_name = "__tvm_main__"
def __del__(self):
check_call(_LIB.TVMModFree(self.handle))
def __hash__(self):
return ctypes.cast(self.handle, ctypes.c_void_p).value
@property
def entry_func(self):
"""Get the entry function
Returns
-------
f : Function
The entry function if exist
"""
if self._entry:
return self._entry
self._entry = self.get_function(self.entry_name)
return self._entry
def get_function(self, name, query_imports=False):
"""Get function from the module.
Parameters
----------
name : str
The name of the function
query_imports : bool
Whether also query modules imported by this module.
Returns
-------
f : Function
The result function.
"""
ret_handle = PackedFuncHandle()
check_call(_LIB.TVMModGetFunction(
self.handle, c_str(name),
ctypes.c_int(query_imports),
ctypes.byref(ret_handle)))
if not ret_handle.value:
raise AttributeError(
"Module has no function '%s'" % name)
return PackedFunc(ret_handle, False)
def import_module(self, module):
"""Add module to the import list of current one.
Parameters
----------
module : Module
The other module.
"""
check_call(_LIB.TVMModImport(self.handle, module.handle))
def __getitem__(self, name):
if not isinstance(name, string_types):
raise ValueError("Can only take string as function name")
return self.get_function(name)
def __call__(self, *args):
if self._entry:
return self._entry(*args)
# pylint: disable=not-callable
return self.entry_func(*args)
def __repr__(self):
return "Module(%s, %x)" % (self.type_key, self.handle.value)
......@@ -85,6 +162,83 @@ class Module(ModuleBase):
"""
_SaveToFile(self, file_name, fmt)
def time_evaluator(self, func_name, ctx, number=10, repeat=1, min_repeat_ms=0):
"""Get an evaluator that measures time cost of running function.
Parameters
----------
func_name: str
The name of the function in the module.
ctx: TVMContext
The context we should run this function on.
number: int
The number of times to run this function for taking average.
We call these runs as one `repeat` of measurement.
repeat: int, optional
The number of times to repeat the measurement.
In total, the function will be invoked (1 + number x repeat) times,
where the first one is warm up and will be discarded.
The returned result contains `repeat` costs,
each of which is an average of `number` costs.
min_repeat_ms: int, optional
The minimum duration of one `repeat` in milliseconds.
By default, one `repeat` contains `number` runs. If this parameter is set,
the parameters `number` will be dynamically adjusted to meet the
minimum duration requirement of one `repeat`.
i.e., When the run time of one `repeat` falls below this time, the `number` parameter
will be automatically increased.
Note
----
The function will be invoked (1 + number x repeat) times,
with the first call discarded in case there is lazy initialization.
Returns
-------
ftimer : Function
The function that takes same argument as func and returns a ProfileResult.
The ProfileResult reports `repeat` time costs in seconds.
"""
try:
feval = _RPCTimeEvaluator(
self, func_name, ctx.device_type, ctx.device_id, number, repeat, min_repeat_ms)
def evaluator(*args):
"""Internal wrapped evaluator."""
# Wrap feval so we can add more stats in future.
blob = feval(*args)
fmt = "@" + ("d" * repeat)
results = struct.unpack(fmt, blob)
mean = sum(results) / float(repeat)
return ProfileResult(mean=mean, results=results)
return evaluator
except NameError:
raise NameError("time_evaluate is only supported when RPC is enabled")
def _collect_dso_modules(self):
"""Helper function to collect dso modules, then return it."""
visited, stack, dso_modules = set(), [], []
# append root module
visited.add(self)
stack.append(self)
while stack:
module = stack.pop()
if module._dso_exportable():
dso_modules.append(module)
for m in module.imported_modules:
if m not in visited:
visited.add(m)
stack.append(m)
return dso_modules
def _dso_exportable(self):
return self.type_key == "llvm" or self.type_key == "c"
def export_library(self,
file_name,
fcompile=None,
......@@ -107,7 +261,14 @@ class Module(ModuleBase):
kwargs : dict, optional
Additional arguments passed to fcompile
"""
# NOTE: this function depends on contrib library features
# which are only available in when TVM function is available.
if _RUNTIME_ONLY:
raise RuntimeError("Cannot call export_library in runtime only mode")
# Extra dependencies during runtime.
from pathlib import Path
from tvm.contrib import cc as _cc, tar as _tar, util as _util
if isinstance(file_name, Path):
file_name = str(file_name)
......@@ -172,83 +333,6 @@ class Module(ModuleBase):
fcompile(file_name, files, **kwargs)
def time_evaluator(self, func_name, ctx, number=10, repeat=1, min_repeat_ms=0):
"""Get an evaluator that measures time cost of running function.
Parameters
----------
func_name: str
The name of the function in the module.
ctx: TVMContext
The context we should run this function on.
number: int
The number of times to run this function for taking average.
We call these runs as one `repeat` of measurement.
repeat: int, optional
The number of times to repeat the measurement.
In total, the function will be invoked (1 + number x repeat) times,
where the first one is warm up and will be discarded.
The returned result contains `repeat` costs,
each of which is an average of `number` costs.
min_repeat_ms: int, optional
The minimum duration of one `repeat` in milliseconds.
By default, one `repeat` contains `number` runs. If this parameter is set,
the parameters `number` will be dynamically adjusted to meet the
minimum duration requirement of one `repeat`.
i.e., When the run time of one `repeat` falls below this time, the `number` parameter
will be automatically increased.
Note
----
The function will be invoked (1 + number x repeat) times,
with the first call discarded in case there is lazy initialization.
Returns
-------
ftimer : Function
The function that takes same argument as func and returns a ProfileResult.
The ProfileResult reports `repeat` time costs in seconds.
"""
try:
feval = _RPCTimeEvaluator(
self, func_name, ctx.device_type, ctx.device_id, number, repeat, min_repeat_ms)
def evaluator(*args):
"""Internal wrapped evaluator."""
# Wrap feval so we can add more stats in future.
blob = feval(*args)
fmt = "@" + ("d" * repeat)
results = struct.unpack(fmt, blob)
mean = sum(results) / float(repeat)
return ProfileResult(mean=mean, results=results)
return evaluator
except NameError:
raise NameError("time_evaluate is only supported when RPC is enabled")
def _collect_dso_modules(self):
"""Helper function to collect dso modules, then return it."""
visited, stack, dso_modules = set(), [], []
# append root module
visited.add(self)
stack.append(self)
while stack:
module = stack.pop()
if module._dso_exportable():
dso_modules.append(module)
for m in module.imported_modules:
if m not in visited:
visited.add(m)
stack.append(m)
return dso_modules
def _dso_exportable(self):
return self.type_key == "llvm" or self.type_key == "c"
def system_lib():
"""Get system-wide library module singleton.
......@@ -296,9 +380,13 @@ def load(path, fmt=""):
# High level handling for .o and .tar file.
# We support this to be consistent with RPC module load.
if path.endswith(".o"):
# Extra dependencies during runtime.
from tvm.contrib import cc as _cc
_cc.create_shared(path + ".so", path)
path += ".so"
elif path.endswith(".tar"):
# Extra dependencies during runtime.
from tvm.contrib import cc as _cc, util as _util, tar as _tar
tar_temp = _util.tempdir(custom_path=path.replace('.tar', ''))
_tar.untar(path, tar_temp.temp_dir)
files = [tar_temp.relpath(x) for x in tar_temp.listdir()]
......@@ -333,5 +421,6 @@ def enabled(target):
return _Enabled(target)
tvm._ffi._init_api("tvm.module")
_set_class_module(Module)
tvm._ffi._init_api("tvm.module", "tvm.runtime.module")
......@@ -18,20 +18,176 @@
"""Runtime NDArray api"""
import ctypes
import numpy as np
from .base import _LIB, check_call, c_array, string_types, _FFI_MODE, c_str
from .runtime_ctypes import TVMType, TVMContext, TVMArray, TVMArrayHandle
from .runtime_ctypes import TypeCode, tvm_shape_index_t
import tvm._ffi
from tvm._ffi.base import _LIB, check_call, c_array, string_types, _FFI_MODE
from tvm._ffi.runtime_ctypes import DataType, TVMContext, TVMArray, TVMArrayHandle
from tvm._ffi.runtime_ctypes import TypeCode, tvm_shape_index_t
try:
# pylint: disable=wrong-import-position
if _FFI_MODE == "ctypes":
raise ImportError()
from ._cy3.core import _set_class_ndarray, _make_array, _from_dlpack
from ._cy3.core import NDArrayBase as _NDArrayBase
from tvm._ffi._cy3.core import _set_class_ndarray, _make_array, _from_dlpack
from tvm._ffi._cy3.core import NDArrayBase
except (RuntimeError, ImportError):
# pylint: disable=wrong-import-position
from ._ctypes.ndarray import _set_class_ndarray, _make_array, _from_dlpack
from ._ctypes.ndarray import NDArrayBase as _NDArrayBase
from tvm._ffi._ctypes.ndarray import _set_class_ndarray, _make_array, _from_dlpack
from tvm._ffi._ctypes.ndarray import NDArrayBase
@tvm._ffi.register_object
class NDArray(NDArrayBase):
"""Lightweight NDArray class of TVM runtime.
Strictly this is only an Array Container (a buffer object)
No arthimetic operations are defined.
All operations are performed by TVM functions.
The goal is not to re-build yet another array library.
Instead, this is a minimal data structure to demonstrate
how can we use TVM in existing project which might have their own array containers.
"""
@property
def dtype(self):
"""Type of this array"""
return str(self.handle.contents.dtype)
@property
def ctx(self):
"""context of this array"""
return self.handle.contents.ctx
@property
def context(self):
"""context of this array"""
return self.ctx
def __hash__(self):
return ctypes.cast(self.handle, ctypes.c_void_p).value
def __eq__(self, other):
return self.same_as(other)
def __ne__(self, other):
return not self.__eq__(other)
def same_as(self, other):
"""Check object identity equality
Parameters
----------
other : object
The other object to compare to
Returns
-------
same : bool
Whether other is same as self.
"""
if not isinstance(other, NDArrayBase):
return False
return self.__hash__() == other.__hash__()
def __setitem__(self, in_slice, value):
"""Set ndarray value"""
if (not isinstance(in_slice, slice) or
in_slice.start is not None
or in_slice.stop is not None):
raise ValueError('Array only support set from numpy array')
if isinstance(value, NDArrayBase):
if value.handle is not self.handle:
value.copyto(self)
elif isinstance(value, (np.ndarray, np.generic)):
self.copyfrom(value)
else:
raise TypeError('type %s not supported' % str(type(value)))
def copyfrom(self, source_array):
"""Peform an synchronize copy from the array.
Parameters
----------
source_array : array_like
The data source we should like to copy from.
Returns
-------
arr : NDArray
Reference to self.
"""
if isinstance(source_array, NDArrayBase):
source_array.copyto(self)
return self
if not isinstance(source_array, np.ndarray):
try:
source_array = np.array(source_array, dtype=self.dtype)
except:
raise TypeError('array must be an array_like data,' +
'type %s is not supported' % str(type(source_array)))
t = DataType(self.dtype)
shape, dtype = self.shape, self.dtype
if t.lanes > 1:
shape = shape + (t.lanes,)
t.lanes = 1
dtype = str(t)
if source_array.shape != shape:
raise ValueError("array shape do not match the shape of NDArray {0} vs {1}".format(
source_array.shape, shape))
source_array = np.ascontiguousarray(source_array, dtype=dtype)
assert source_array.flags['C_CONTIGUOUS']
data = source_array.ctypes.data_as(ctypes.c_void_p)
nbytes = ctypes.c_size_t(source_array.size * source_array.dtype.itemsize)
check_call(_LIB.TVMArrayCopyFromBytes(self.handle, data, nbytes))
return self
def __repr__(self):
res = "<tvm.NDArray shape={0}, {1}>\n".format(self.shape, self.context)
res += self.asnumpy().__repr__()
return res
def __str__(self):
return str(self.asnumpy())
def asnumpy(self):
"""Convert this array to numpy array
Returns
-------
np_arr : numpy.ndarray
The corresponding numpy array.
"""
t = DataType(self.dtype)
shape, dtype = self.shape, self.dtype
if t.lanes > 1:
shape = shape + (t.lanes,)
t.lanes = 1
dtype = str(t)
np_arr = np.empty(shape, dtype=dtype)
assert np_arr.flags['C_CONTIGUOUS']
data = np_arr.ctypes.data_as(ctypes.c_void_p)
nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize)
check_call(_LIB.TVMArrayCopyToBytes(self.handle, data, nbytes))
return np_arr
def copyto(self, target):
"""Copy array to target
Parameters
----------
target : NDArray
The target array to be copied, must have same shape as this array.
"""
if isinstance(target, NDArrayBase):
return self._copyto(target)
elif isinstance(target, TVMContext):
res = empty(self.shape, self.dtype, target)
return self._copyto(res)
raise ValueError("Unsupported target type %s" % str(type(target)))
def context(dev_type, dev_id=0):
......@@ -82,7 +238,7 @@ def numpyasarray(np_data):
arr.data = data.ctypes.data_as(ctypes.c_void_p)
arr.shape = shape
arr.strides = None
arr.dtype = TVMType(np.dtype(data.dtype).name)
arr.dtype = DataType(np.dtype(data.dtype).name)
arr.ndim = data.ndim
# CPU device
arr.ctx = context(1, 0)
......@@ -111,7 +267,7 @@ def empty(shape, dtype="float32", ctx=context(1, 0)):
shape = c_array(tvm_shape_index_t, shape)
ndim = ctypes.c_int(len(shape))
handle = TVMArrayHandle()
dtype = TVMType(dtype)
dtype = DataType(dtype)
check_call(_LIB.TVMArrayAlloc(
shape, ndim,
ctypes.c_int(dtype.type_code),
......@@ -142,145 +298,193 @@ def from_dlpack(dltensor):
return _from_dlpack(dltensor)
class NDArrayBase(_NDArrayBase):
"""A simple Device/CPU Array object in runtime."""
def cpu(dev_id=0):
"""Construct a CPU device
@property
def dtype(self):
"""Type of this array"""
return str(self.handle.contents.dtype)
Parameters
----------
dev_id : int, optional
The integer device id
@property
def ctx(self):
"""context of this array"""
return self.handle.contents.ctx
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(1, dev_id)
@property
def context(self):
"""context of this array"""
return self.ctx
def __hash__(self):
return ctypes.cast(self.handle, ctypes.c_void_p).value
def gpu(dev_id=0):
"""Construct a CPU device
def __eq__(self, other):
return self.same_as(other)
Parameters
----------
dev_id : int, optional
The integer device id
def __ne__(self, other):
return not self.__eq__(other)
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(2, dev_id)
def same_as(self, other):
"""Check object identity equality
def rocm(dev_id=0):
"""Construct a ROCM device
Parameters
----------
other : object
The other object to compare to
dev_id : int, optional
The integer device id
Returns
-------
same : bool
Whether other is same as self.
ctx : TVMContext
The created context
"""
if not isinstance(other, NDArrayBase):
return False
return self.__hash__() == other.__hash__()
return TVMContext(10, dev_id)
def __setitem__(self, in_slice, value):
"""Set ndarray value"""
if (not isinstance(in_slice, slice) or
in_slice.start is not None
or in_slice.stop is not None):
raise ValueError('Array only support set from numpy array')
if isinstance(value, NDArrayBase):
if value.handle is not self.handle:
value.copyto(self)
elif isinstance(value, (np.ndarray, np.generic)):
self.copyfrom(value)
else:
raise TypeError('type %s not supported' % str(type(value)))
def copyfrom(self, source_array):
"""Peform an synchronize copy from the array.
def opencl(dev_id=0):
"""Construct a OpenCL device
Parameters
----------
source_array : array_like
The data source we should like to copy from.
dev_id : int, optional
The integer device id
Returns
-------
arr : NDArray
Reference to self.
ctx : TVMContext
The created context
"""
if isinstance(source_array, NDArrayBase):
source_array.copyto(self)
return self
return TVMContext(4, dev_id)
if not isinstance(source_array, np.ndarray):
try:
source_array = np.array(source_array, dtype=self.dtype)
except:
raise TypeError('array must be an array_like data,' +
'type %s is not supported' % str(type(source_array)))
t = TVMType(self.dtype)
shape, dtype = self.shape, self.dtype
if t.lanes > 1:
shape = shape + (t.lanes,)
t.lanes = 1
dtype = str(t)
def metal(dev_id=0):
"""Construct a metal device
if source_array.shape != shape:
raise ValueError("array shape do not match the shape of NDArray {0} vs {1}".format(
source_array.shape, shape))
source_array = np.ascontiguousarray(source_array, dtype=dtype)
assert source_array.flags['C_CONTIGUOUS']
data = source_array.ctypes.data_as(ctypes.c_void_p)
nbytes = ctypes.c_size_t(source_array.size * source_array.dtype.itemsize)
check_call(_LIB.TVMArrayCopyFromBytes(self.handle, data, nbytes))
return self
Parameters
----------
dev_id : int, optional
The integer device id
def __repr__(self):
res = "<tvm.NDArray shape={0}, {1}>\n".format(self.shape, self.context)
res += self.asnumpy().__repr__()
return res
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(8, dev_id)
def __str__(self):
return str(self.asnumpy())
def asnumpy(self):
"""Convert this array to numpy array
def vpi(dev_id=0):
"""Construct a VPI simulated device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
np_arr : numpy.ndarray
The corresponding numpy array.
ctx : TVMContext
The created context
"""
t = TVMType(self.dtype)
shape, dtype = self.shape, self.dtype
if t.lanes > 1:
shape = shape + (t.lanes,)
t.lanes = 1
dtype = str(t)
np_arr = np.empty(shape, dtype=dtype)
assert np_arr.flags['C_CONTIGUOUS']
data = np_arr.ctypes.data_as(ctypes.c_void_p)
nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize)
check_call(_LIB.TVMArrayCopyToBytes(self.handle, data, nbytes))
return np_arr
return TVMContext(9, dev_id)
def copyto(self, target):
"""Copy array to target
def vulkan(dev_id=0):
"""Construct a Vulkan device
Parameters
----------
target : NDArray
The target array to be copied, must have same shape as this array.
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
if isinstance(target, NDArrayBase):
return self._copyto(target)
elif isinstance(target, TVMContext):
res = empty(self.shape, self.dtype, target)
return self._copyto(res)
raise ValueError("Unsupported target type %s" % str(type(target)))
return TVMContext(7, dev_id)
def opengl(dev_id=0):
"""Construct a OpenGL device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(11, dev_id)
def ext_dev(dev_id=0):
"""Construct a extension device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
Note
----
This API is reserved for quick testing of new
device by plugin device API as ext_dev.
"""
return TVMContext(12, dev_id)
def micro_dev(dev_id=0):
"""Construct a micro device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(13, dev_id)
cl = opencl
mtl = metal
def array(arr, ctx=cpu(0)):
"""Create an array from source arr.
Parameters
----------
arr : numpy.ndarray
The array to be copied from
ctx : TVMContext, optional
The device context to create the array
Returns
-------
ret : NDArray
The created array
"""
if not isinstance(arr, (np.ndarray, NDArray)):
arr = np.array(arr)
return empty(arr.shape, arr.dtype, ctx).copyfrom(arr)
# Register back to FFI
_set_class_ndarray(NDArray)
......@@ -17,19 +17,20 @@
# pylint: disable=invalid-name, unused-import
"""Runtime Object API"""
import ctypes
from tvm._ffi.base import _FFI_MODE, _RUNTIME_ONLY, check_call, _LIB, c_str
from .. import _api_internal
from .base import _FFI_MODE, _RUNTIME_ONLY, check_call, _LIB, c_str
try:
# pylint: disable=wrong-import-position,unused-import
if _FFI_MODE == "ctypes":
raise ImportError()
from ._cy3.core import _set_class_object, _set_class_object_generic
from ._cy3.core import ObjectBase
from tvm._ffi._cy3.core import _set_class_object, _set_class_object_generic
from tvm._ffi._cy3.core import ObjectBase
except (RuntimeError, ImportError):
# pylint: disable=wrong-import-position,unused-import
from ._ctypes.packed_func import _set_class_object, _set_class_object_generic
from ._ctypes.object import ObjectBase
from tvm._ffi._ctypes.packed_func import _set_class_object, _set_class_object_generic
from tvm._ffi._ctypes.object import ObjectBase
def _new_object(cls):
......@@ -91,44 +92,4 @@ class Object(ObjectBase):
return self.__hash__() == other.__hash__()
def getitem_helper(obj, elem_getter, length, idx):
"""Helper function to implement a pythonic getitem function.
Parameters
----------
obj: object
The original object
elem_getter : function
A simple function that takes index and return a single element.
length : int
The size of the array
idx : int or slice
The argument passed to getitem
Returns
-------
result : object
The result of getitem
"""
if isinstance(idx, slice):
start = idx.start if idx.start is not None else 0
stop = idx.stop if idx.stop is not None else length
step = idx.step if idx.step is not None else 1
if start < 0:
start += length
if stop < 0:
stop += length
return [elem_getter(obj, i) for i in range(start, stop, step)]
if idx < -length or idx >= length:
raise IndexError("Index out of range. size: {}, got index {}"
.format(length, idx))
if idx < 0:
idx += length
return elem_getter(obj, idx)
_set_class_object(Object)
......@@ -15,15 +15,15 @@
# specific language governing permissions and limitations
# under the License.
"""Common implementation of object generic related logic"""
# pylint: disable=unused-import
# pylint: disable=unused-import, invalid-name
from numbers import Number, Integral
from .. import _api_internal
from tvm._ffi.base import string_types
from .base import string_types
from .. import _api_internal
from .object import ObjectBase, _set_class_object_generic
from .ndarray import NDArrayBase
from .packed_func import PackedFuncBase, convert_to_tvm_func
from .module import ModuleBase
from .module import Module
class ObjectGeneric(object):
......@@ -33,7 +33,7 @@ class ObjectGeneric(object):
raise NotImplementedError()
_CLASS_OBJECTS = (ObjectBase, NDArrayBase, ModuleBase)
ObjectTypes = (ObjectBase, NDArrayBase, Module)
def convert_to_object(value):
......@@ -49,7 +49,7 @@ def convert_to_object(value):
obj : Object
The corresponding object value.
"""
if isinstance(value, _CLASS_OBJECTS):
if isinstance(value, ObjectTypes):
return value
if isinstance(value, bool):
return const(value, 'uint1x1')
......@@ -63,7 +63,7 @@ def convert_to_object(value):
if isinstance(value, dict):
vlist = []
for item in value.items():
if (not isinstance(item[0], _CLASS_OBJECTS) and
if (not isinstance(item[0], ObjectTypes) and
not isinstance(item[0], string_types)):
raise ValueError("key of map must already been a container type")
vlist.append(item[0])
......
......@@ -18,20 +18,20 @@
# pylint: disable=invalid-name, unused-import
"""Packed Function namespace."""
import ctypes
from .base import _LIB, check_call, c_str, string_types, _FFI_MODE
from tvm._ffi.base import _LIB, check_call, c_str, string_types, _FFI_MODE
try:
# pylint: disable=wrong-import-position
if _FFI_MODE == "ctypes":
raise ImportError()
from ._cy3.core import _set_class_packed_func, _set_class_module
from ._cy3.core import PackedFuncBase
from ._cy3.core import convert_to_tvm_func
from tvm._ffi._cy3.core import _set_class_packed_func, _set_class_module
from tvm._ffi._cy3.core import PackedFuncBase
from tvm._ffi._cy3.core import convert_to_tvm_func
except (RuntimeError, ImportError):
# pylint: disable=wrong-import-position
from ._ctypes.packed_func import _set_class_packed_func, _set_class_module
from ._ctypes.packed_func import PackedFuncBase
from ._ctypes.packed_func import convert_to_tvm_func
from tvm._ffi._ctypes.packed_func import _set_class_packed_func, _set_class_module
from tvm._ffi._ctypes.packed_func import PackedFuncBase
from tvm._ffi._ctypes.packed_func import convert_to_tvm_func
PackedFuncHandle = ctypes.c_void_p
......
......@@ -17,9 +17,8 @@
"""The computation schedule api of TVM."""
import tvm._ffi
from ._ffi.base import string_types
from ._ffi.object import Object
from ._ffi.object_generic import convert
from tvm._ffi.base import string_types
from tvm.runtime import Object, convert
from . import _api_internal
from . import tensor as _tensor
......
......@@ -30,7 +30,8 @@ Each statement node have subfields that can be visited from python side.
assert(st.buffer_var == a)
"""
import tvm._ffi
from ._ffi.object import Object
from tvm.runtime import Object
from . import make as _make
......
......@@ -57,8 +57,8 @@ We can also use other specific function in this module to create specific target
import warnings
import tvm._ffi
from tvm.runtime import Object
from ._ffi.base import _LIB_NAME
from ._ffi.object import Object
from . import _api_internal
try:
......
......@@ -18,8 +18,7 @@
# pylint: disable=invalid-name
import tvm._ffi
from ._ffi.object import Object
from ._ffi.object_generic import ObjectGeneric, convert_to_object
from tvm.runtime import Object, ObjectGeneric, convert_to_object
from . import _api_internal
from . import make as _make
......@@ -129,7 +128,6 @@ class Tensor(Object, _expr.ExprOp):
return "%s.v%d" % (op.name, self.value_index)
class Operation(Object):
"""Represent an operation that generates a tensor"""
......
......@@ -17,6 +17,7 @@
"""Tensor intrinsics"""
import tvm._ffi
from tvm.runtime import Object
from . import _api_internal
from . import api as _api
from . import expr as _expr
......@@ -25,7 +26,6 @@ from . import make as _make
from . import tensor as _tensor
from . import schedule as _schedule
from .build_module import current_build_config
from ._ffi.object import Object
def _get_region(tslice):
......
......@@ -16,7 +16,7 @@
# under the License.
import tvm
import tvm.contrib.sparse as tvmsp
import tvm.ndarray as _nd
import tvm.runtime.ndarray as _nd
import numpy as np
from collections import namedtuple
......
......@@ -132,7 +132,7 @@ def test_comments():
def test_int_literal():
assert isinstance(parse_text("1"), relay.Constant)
assert isinstance(parse_text("1").data, tvm.ndarray.NDArray)
assert isinstance(parse_text("1").data, tvm.nd.NDArray)
assert get_scalar(parse_text("1")) == 1
assert get_scalar(parse_text("10")) == 10
......
......@@ -207,7 +207,7 @@ def test_cuda_shuffle():
b_ = np.array((list(range(4))[::-1]) * 16, dtype='int32')
c_ = np.zeros((64, ), dtype='int32')
ref = a_ + np.array((list(range(4))) * 16, dtype='int32')
nda, ndb, ndc = [tvm.ndarray.array(i, tvm.gpu(0)) for i in [a_, b_, c_]]
nda, ndb, ndc = [tvm.nd.array(i, tvm.gpu(0)) for i in [a_, b_, c_]]
module(nda, ndb, ndc)
tvm.testing.assert_allclose(ndc.asnumpy(), ref)
......
......@@ -657,9 +657,9 @@ def test_llvm_shuffle():
with tvm.build_config(add_lower_pass=[(1, my_vectorize)]):
ir = tvm.lower(sch, [a, b, c], simple_mode=True)
module = tvm.build(sch, [a, b, c])
a_ = tvm.ndarray.array(np.arange(1, 9, dtype='int32'))
b_ = tvm.ndarray.array(np.arange(8, 0, -1, dtype='int32'))
c_ = tvm.ndarray.array(np.zeros((8, ), dtype='int32'))
a_ = tvm.nd.array(np.arange(1, 9, dtype='int32'))
b_ = tvm.nd.array(np.arange(8, 0, -1, dtype='int32'))
c_ = tvm.nd.array(np.zeros((8, ), dtype='int32'))
module(a_, b_, c_)
tvm.testing.assert_allclose(c_.asnumpy(), (a_.asnumpy() * 2).astype('int32'))
......
......@@ -405,8 +405,8 @@ def test_math_intrin():
func = tvm.build(sch, [a8, b8])
assert func
a = numpy.arange(2, 10).astype('float32')
tvm_a = tvm.ndarray.array(a)
tvm_b = tvm.ndarray.array(numpy.zeros((8, ), dtype='float32'))
tvm_a = tvm.nd.array(a)
tvm_b = tvm.nd.array(numpy.zeros((8, ), dtype='float32'))
b = intrin_real(a)
func(tvm_a, tvm_b)
tvm.testing.assert_allclose(b, tvm_b.asnumpy(), rtol=1e-5)
......@@ -423,8 +423,8 @@ def test_math_intrin():
func = tvm.build(sch, [a1, b1])
assert func
a = numpy.array([114514]).astype('int32')
tvm_a = tvm.ndarray.array(a)
tvm_b = tvm.ndarray.array(numpy.array([0]).astype('int32'))
tvm_a = tvm.nd.array(a)
tvm_b = tvm.nd.array(numpy.array([0]).astype('int32'))
b = intrin_int(a)
func(tvm_a, tvm_b)
assert tvm_b.asnumpy()[0] == b[0]
......@@ -578,8 +578,8 @@ def test_const_param():
np_b = 11
np_c = numpy.zeros((11, )).astype('int32')
nd_a = tvm.ndarray.array(np_a)
nd_c = tvm.ndarray.array(numpy.zeros((11, )).astype('int32'))
nd_a = tvm.nd.array(np_a)
nd_c = tvm.nd.array(numpy.zeros((11, )).astype('int32'))
module(nd_a, nd_c)
ref = add_something(np_a, 11)
......@@ -614,8 +614,8 @@ def test_value_index():
np_b, np_c = kernel_a(np_a)
ref = kernel_b(np_c, np_b)
res = tvm.ndarray.array(numpy.zeros((4, 4)).astype('int32'))
module(tvm.ndarray.array(np_a), res)
res = tvm.nd.array(numpy.zeros((4, 4)).astype('int32'))
module(tvm.nd.array(np_a), res)
tvm.testing.assert_allclose(res.asnumpy(), ref)
def test_func_call():
......
......@@ -28,7 +28,7 @@ def test_shared_memory():
N = 1024
M = 128
tvm_type = tvm.datatype._TVMType(dtype)
tvm_type = tvm.runtime.DataType(dtype)
type_size = tvm_type.bits // 8 * tvm_type.lanes
A = tvm.placeholder((N,), name='A', dtype=dtype)
......
......@@ -444,7 +444,7 @@ def test_reduction_and_dummy_fuse_split():
axo, axi = s[Y.op].split(ax, nparts=20)
f = tvm.build(s, [Y, X])
args = [tvm.nd.empty((), 'int32')] + [tvm.ndarray.array(np.ones((n,), dtype='int32'))]
args = [tvm.nd.empty((), 'int32')] + [tvm.nd.array(np.ones((n,), dtype='int32'))]
f(*args)
assert args[0].asnumpy() == n
......@@ -456,8 +456,8 @@ def test_reduction_and_dummy_fuse_split():
ax = s[Y.op].fuse(*(list(Y.op.axis) + list(Y.op.reduce_axis)))
f = tvm.build(s, [Y, X])
args = [tvm.ndarray.array(np.ones((n,), dtype='int32'))] + \
[tvm.ndarray.array(np.ones((n,), dtype='int32'))]
args = [tvm.nd.array(np.ones((n,), dtype='int32'))] + \
[tvm.nd.array(np.ones((n,), dtype='int32'))]
f(*args)
assert np.all(args[0].asnumpy() == n)
......
......@@ -231,8 +231,8 @@ def test_sparse_dense_csr():
Y = topi.nn.sparse_dense(X, W_data, W_indices, W_indptr)
s = tvm.create_schedule(Y.op)
func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y])
Y_tvm = tvm.ndarray.array(np.zeros(Y_np.shape, dtype=Y_np.dtype))
func(tvm.ndarray.array(X_np), tvm.ndarray.array(W_sp_np.data), tvm.ndarray.array(W_sp_np.indices), tvm.ndarray.array(W_sp_np.indptr), Y_tvm)
Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype))
func(tvm.nd.array(X_np), tvm.nd.array(W_sp_np.data), tvm.nd.array(W_sp_np.indices), tvm.nd.array(W_sp_np.indptr), Y_tvm)
tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-4, rtol=1e-4)
def test_sparse_transpose_csr():
......@@ -252,11 +252,11 @@ def test_sparse_transpose_csr():
func = tvm.build(s, [X_data, X_indices, X_indptr, X_T_data, X_T_indices, X_T_indptr])
X_T_data_tvm = tvm.ndarray.array(np.zeros(X_sp_T.data.shape, dtype=X_sp_T.data.dtype))
X_T_indices_tvm = tvm.ndarray.array(np.zeros(X_sp_T.indices.shape, dtype=X_sp_T.indices.dtype))
X_T_indptr_tvm = tvm.ndarray.array(np.zeros(X_sp_T.indptr.shape, dtype=X_sp_T.indptr.dtype))
X_T_data_tvm = tvm.nd.array(np.zeros(X_sp_T.data.shape, dtype=X_sp_T.data.dtype))
X_T_indices_tvm = tvm.nd.array(np.zeros(X_sp_T.indices.shape, dtype=X_sp_T.indices.dtype))
X_T_indptr_tvm = tvm.nd.array(np.zeros(X_sp_T.indptr.shape, dtype=X_sp_T.indptr.dtype))
func(tvm.ndarray.array(X_sp.data), tvm.ndarray.array(X_sp.indices), tvm.ndarray.array(X_sp.indptr),
func(tvm.nd.array(X_sp.data), tvm.nd.array(X_sp.indices), tvm.nd.array(X_sp.indptr),
X_T_data_tvm, X_T_indices_tvm, X_T_indptr_tvm)
X_T_out = sp.csr_matrix((X_T_data_tvm.asnumpy(), X_T_indices_tvm.asnumpy(), X_T_indptr_tvm.asnumpy()), shape=(N,N)).todense()
......@@ -295,11 +295,11 @@ def test_sparse_dense_bsr():
Y = topi.nn.sparse_dense(X, W_data, W_indices, W_indptr)
s = tvm.create_schedule(Y.op)
func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y])
Y_tvm = tvm.ndarray.array(np.zeros(Y_np.shape, dtype=Y_np.dtype))
func(tvm.ndarray.array(X_np),
tvm.ndarray.array(W_sp_np.data),
tvm.ndarray.array(W_sp_np.indices),
tvm.ndarray.array(W_sp_np.indptr),
Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype))
func(tvm.nd.array(X_np),
tvm.nd.array(W_sp_np.data),
tvm.nd.array(W_sp_np.indices),
tvm.nd.array(W_sp_np.indptr),
Y_tvm)
tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-4, rtol=1e-4)
......@@ -324,11 +324,11 @@ def test_sparse_dense_bsr_randomized():
Y = topi.nn.sparse_dense(X, W_data, W_indices, W_indptr)
s = tvm.create_schedule(Y.op)
func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y])
Y_tvm = tvm.ndarray.array(np.zeros(Y_np.shape, dtype=Y_np.dtype))
func(tvm.ndarray.array(X_np),
tvm.ndarray.array(W_sp_np.data),
tvm.ndarray.array(W_sp_np.indices),
tvm.ndarray.array(W_sp_np.indptr),
Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype))
func(tvm.nd.array(X_np),
tvm.nd.array(W_sp_np.data),
tvm.nd.array(W_sp_np.indices),
tvm.nd.array(W_sp_np.indptr),
Y_tvm)
tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-5, rtol=1e-5)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment