Unverified Commit fc7dd6d7 by Tianqi Chen Committed by GitHub

[REFACTOR][PY] Establish tvm.runtime (#4818)

* [REFACTOR][PY] Establish tvm.runtime

This PR establishes the tvm.runtime namespace that contains the core runtime data structures.
The top-level API are kept inact for now via re-exporting.

We will followup later to cleanup some of the top-level APIs.

* Fix ndarray name
parent 7d263c31
......@@ -15,19 +15,29 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=redefined-builtin, wildcard-import
"""TVM: Low level DSL/IR stack for tensor computation."""
"""TVM: Open Deep Learning Compiler Stack."""
import multiprocessing
import sys
import traceback
# import ffi related features
# top-level alias
# tvm._ffi
from ._ffi.base import TVMError, __version__
from ._ffi.runtime_ctypes import TypeCode, TVMType
from ._ffi.ndarray import TVMContext
from ._ffi.packed_func import PackedFunc as Function
from ._ffi.runtime_ctypes import TypeCode, DataType
from ._ffi.registry import register_object, register_func, register_extension
from ._ffi.object import Object
# top-level alias
# tvm.runtime
from .runtime.object import Object
from .runtime.packed_func import PackedFunc as Function
from .runtime.ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl
from .runtime.ndarray import vpi, rocm, opengl, ext_dev, micro_dev
from .runtime import module
from .runtime import ndarray
# pylint: disable=reimported
from .runtime import ndarray as nd
# others
from . import tensor
from . import arith
from . import expr
......@@ -37,7 +47,7 @@ from . import ir_pass
from . import codegen
from . import container
from . import schedule
from . import module
from . import attrs
from . import ir_builder
from . import target
......@@ -47,9 +57,6 @@ from . import testing
from . import error
from . import datatype
from . import ndarray as nd
from .ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl
from .ndarray import vpi, rocm, opengl, ext_dev, micro_dev
from .api import *
from .intrin import *
......
......@@ -23,7 +23,7 @@ from numbers import Number, Integral
from ..base import _LIB, get_last_ffi_error, py2cerror, check_call
from ..base import c_str, string_types
from ..runtime_ctypes import TVMType, TVMByteArray, TVMContext
from ..runtime_ctypes import DataType, TVMByteArray, TVMContext
from . import ndarray as _nd
from .ndarray import NDArrayBase, _make_array
from .types import TVMValue, TypeCode
......@@ -132,7 +132,7 @@ def _make_tvm_args(args, temp_args):
elif isinstance(arg, Number):
values[i].v_float64 = arg
type_codes[i] = TypeCode.FLOAT
elif isinstance(arg, TVMType):
elif isinstance(arg, DataType):
values[i].v_str = c_str(str(arg))
type_codes[i] = TypeCode.STR
elif isinstance(arg, TVMContext):
......
......@@ -20,7 +20,7 @@ import traceback
from cpython cimport Py_INCREF, Py_DECREF
from numbers import Number, Integral
from ..base import string_types, py2cerror
from ..runtime_ctypes import TVMType, TVMContext, TVMByteArray
from ..runtime_ctypes import DataType, TVMContext, TVMByteArray
cdef void tvm_callback_finalize(void* fhandle):
......@@ -129,7 +129,7 @@ cdef inline int make_arg(object arg,
elif isinstance(arg, Number):
value[0].v_float64 = arg
tcode[0] = kFloat
elif isinstance(arg, TVMType):
elif isinstance(arg, DataType):
tstr = c_str(str(arg))
value[0].v_str = tstr
tcode[0] = kTVMStr
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-import
"""Runtime Module namespace."""
import ctypes
from .base import _LIB, check_call, c_str, string_types
from .packed_func import PackedFunc, PackedFuncHandle, _set_class_module
class ModuleBase(object):
"""Base class for module"""
__slots__ = ["handle", "_entry", "entry_name"]
def __init__(self, handle):
self.handle = handle
self._entry = None
self.entry_name = "__tvm_main__"
def __del__(self):
check_call(_LIB.TVMModFree(self.handle))
def __hash__(self):
return ctypes.cast(self.handle, ctypes.c_void_p).value
@property
def entry_func(self):
"""Get the entry function
Returns
-------
f : Function
The entry function if exist
"""
if self._entry:
return self._entry
self._entry = self.get_function(self.entry_name)
return self._entry
def get_function(self, name, query_imports=False):
"""Get function from the module.
Parameters
----------
name : str
The name of the function
query_imports : bool
Whether also query modules imported by this module.
Returns
-------
f : Function
The result function.
"""
ret_handle = PackedFuncHandle()
check_call(_LIB.TVMModGetFunction(
self.handle, c_str(name),
ctypes.c_int(query_imports),
ctypes.byref(ret_handle)))
if not ret_handle.value:
raise AttributeError(
"Module has no function '%s'" % name)
return PackedFunc(ret_handle, False)
def import_module(self, module):
"""Add module to the import list of current one.
Parameters
----------
module : Module
The other module.
"""
check_call(_LIB.TVMModImport(self.handle, module.handle))
def __getitem__(self, name):
if not isinstance(name, string_types):
raise ValueError("Can only take string as function name")
return self.get_function(name)
def __call__(self, *args):
if self._entry:
return self._entry(*args)
f = self.entry_func
return f(*args)
......@@ -48,7 +48,7 @@ class TVMByteArray(ctypes.Structure):
_fields_ = [("data", ctypes.POINTER(ctypes.c_byte)),
("size", ctypes.c_size_t)]
class TVMType(ctypes.Structure):
class DataType(ctypes.Structure):
"""TVM datatype structure"""
_fields_ = [("type_code", ctypes.c_uint8),
("bits", ctypes.c_uint8),
......@@ -60,7 +60,7 @@ class TVMType(ctypes.Structure):
4 : 'handle'
}
def __init__(self, type_str):
super(TVMType, self).__init__()
super(DataType, self).__init__()
if isinstance(type_str, np.dtype):
type_str = str(type_str)
......@@ -104,8 +104,8 @@ class TVMType(ctypes.Structure):
def __repr__(self):
if self.bits == 1 and self.lanes == 1:
return "bool"
if self.type_code in TVMType.CODE2STR:
type_name = TVMType.CODE2STR[self.type_code]
if self.type_code in DataType.CODE2STR:
type_name = DataType.CODE2STR[self.type_code]
else:
type_name = "custom[%s]" % \
_api_internal._datatype_get_type_name(self.type_code)
......@@ -263,7 +263,7 @@ class TVMArray(ctypes.Structure):
_fields_ = [("data", ctypes.c_void_p),
("ctx", TVMContext),
("ndim", ctypes.c_int),
("dtype", TVMType),
("dtype", DataType),
("shape", ctypes.POINTER(tvm_shape_index_t)),
("strides", ctypes.POINTER(tvm_shape_index_t)),
("byte_offset", ctypes.c_uint64)]
......
......@@ -20,10 +20,10 @@ from numbers import Integral as _Integral
import tvm._ffi
from tvm.runtime import convert, const, DataType
from ._ffi.base import string_types, TVMError
from ._ffi.object_generic import convert, const
from ._ffi.registry import register_func, get_global_func, extract_ext_funcs
from ._ffi.runtime_ctypes import TVMType
from . import _api_internal
from . import make as _make
from . import expr as _expr
......
......@@ -15,10 +15,9 @@
# specific language governing permissions and limitations
# under the License.
"""Arithmetic data structure and utility"""
from __future__ import absolute_import as _abs
import tvm._ffi
from tvm.runtime import Object
from ._ffi.object import Object
from . import _api_internal
class IntSet(Object):
......
......@@ -17,7 +17,7 @@
""" TVM Attribute module, which is mainly used for defining attributes of operators"""
import tvm._ffi
from ._ffi.object import Object
from tvm.runtime import Object
from . import _api_internal
......
......@@ -22,7 +22,7 @@ LoweredFunc and compiled Module.
import warnings
import tvm._ffi
from ._ffi.object import Object
from tvm.runtime import Object, ndarray
from . import api
from . import _api_internal
from . import tensor
......@@ -33,7 +33,6 @@ from . import stmt as _stmt
from . import container
from . import module
from . import codegen
from . import ndarray
from . import target as _target
from . import make
......
......@@ -17,9 +17,9 @@
"""Container data structures used in TVM DSL."""
import tvm._ffi
from tvm import ndarray as _nd
from tvm.runtime import Object, ObjectTypes
from tvm.runtime.container import getitem_helper
from . import _api_internal
from ._ffi.object import Object, getitem_helper
@tvm._ffi.register_object
......@@ -31,23 +31,9 @@ class Array(Object):
to Array during tvm function call.
You may get Array in return values of TVM function call.
"""
def __getitem__(self, i):
if isinstance(i, slice):
start = i.start if i.start is not None else 0
stop = i.stop if i.stop is not None else len(self)
step = i.step if i.step is not None else 1
if start < 0:
start += len(self)
if stop < 0:
stop += len(self)
return [self[idx] for idx in range(start, stop, step)]
if i < -len(self) or i >= len(self):
raise IndexError("Array index out of range. Array size: {}, got index {}"
.format(len(self), i))
if i < 0:
i += len(self)
return _api_internal._ArrayGetItem(self, i)
def __getitem__(self, idx):
return getitem_helper(
self, _api_internal._ArrayGetItem, len(self), idx)
def __len__(self):
return _api_internal._ArraySize(self)
......@@ -133,7 +119,7 @@ class ADT(Object):
"""
def __init__(self, tag, fields):
for f in fields:
assert isinstance(f, (Object, _nd.NDArray)), "Expect object or " \
assert isinstance(f, ObjectTypes), "Expect object or " \
"tvm NDArray type, but received : {0}".format(type(f))
self.__init_handle_by_constructor__(_ADT, tag, *fields)
......@@ -164,7 +150,7 @@ def tuple_object(fields=None):
"""
fields = fields if fields else []
for f in fields:
assert isinstance(f, (Object, _nd.NDArray)), "Expect object or tvm " \
assert isinstance(f, ObjectTypes), "Expect object or tvm " \
"NDArray type, but received : {0}".format(type(f))
return _Tuple(*fields)
......
......@@ -23,7 +23,7 @@ import tvm._ffi
from tvm._ffi.base import string_types
from tvm.contrib import graph_runtime
from tvm.ndarray import array
from tvm.runtime.ndarray import array
from . import debug_result
_DUMP_ROOT_PREFIX = "tvmdbg_"
......
......@@ -21,8 +21,9 @@ from __future__ import absolute_import as _abs
import subprocess
import os
import warnings
from tvm.runtime import ndarray as nd
from . import util
from .. import ndarray as nd
from ..api import register_func
from .._ffi.base import py_str
......
......@@ -20,7 +20,7 @@ import tvm._ffi
from . import make as _make
from .api import convert
from .expr import Call as _Call, Cast as _Cast, FloatImm as _FloatImm
from ._ffi.runtime_ctypes import TVMType as _TVMType
from ._ffi.runtime_ctypes import DataType
from . import _api_internal
......@@ -131,7 +131,7 @@ def create_lower_func(extern_func_name):
width as the custom type is returned. Otherwise, the type is
unchanged."""
dtype = op.dtype
t = _TVMType(dtype)
t = DataType(dtype)
if get_type_registered(t.type_code):
dtype = "uint" + str(t.bits)
if t.lanes > 1:
......
......@@ -31,12 +31,9 @@ For example, you can use addexp.a to get the left operand of an Add node.
assert(y.a == x)
"""
# pylint: disable=missing-docstring
from __future__ import absolute_import as _abs
import tvm._ffi
from tvm.runtime import Object, ObjectGeneric, DataType, TypeCode
from ._ffi.object import Object
from ._ffi.object_generic import ObjectGeneric
from ._ffi.runtime_ctypes import TVMType, TypeCode
from . import make as _make
from . import generic as _generic
from . import _api_internal
......@@ -52,7 +49,7 @@ def _dtype_is_int(value):
if isinstance(value, int):
return True
return (isinstance(value, ExprOp) and
TVMType(value.dtype).type_code == TypeCode.INT)
DataType(value.dtype).type_code == TypeCode.INT)
class ExprOp(object):
......
......@@ -15,7 +15,9 @@
# specific language governing permissions and limitations
# under the License.
"""Developer API of IR node builder make function."""
from __future__ import absolute_import as _abs
from tvm.runtime import ObjectGeneric, DataType
from ._ffi.base import string_types
from . import api as _api
from . import stmt as _stmt
......@@ -23,9 +25,6 @@ from . import expr as _expr
from . import make as _make
from . import ir_pass as _pass
from . import container as _container
from ._ffi.base import string_types
from ._ffi.object_generic import ObjectGeneric
from ._ffi.runtime_ctypes import TVMType
from .expr import Call as _Call
class WithScope(object):
......@@ -78,7 +77,7 @@ class BufferVar(ObjectGeneric):
return self._content_type
def __getitem__(self, index):
t = TVMType(self._content_type)
t = DataType(self._content_type)
if t.lanes > 1:
index = _make.Ramp(index * t.lanes, 1, t.lanes)
return _make.Load(self._content_type, self._buffer_var, index)
......@@ -89,7 +88,7 @@ class BufferVar(ObjectGeneric):
raise ValueError(
"data type does not match content type %s vs %s" % (
value.dtype, self._content_type))
t = TVMType(self._content_type)
t = DataType(self._content_type)
if t.lanes > 1:
index = _make.Ramp(index * t.lanes, 1, t.lanes)
self._builder.emit(_make.Store(self._buffer_var, value, index))
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""TVM Runtime NDArray API.
tvm.ndarray provides a minimum runtime array API to test
the correctness of the program.
"""
# pylint: disable=invalid-name,unused-import
import tvm._ffi
import numpy as _np
from ._ffi.ndarray import TVMContext, TVMType, NDArrayBase
from ._ffi.ndarray import context, empty, from_dlpack
from ._ffi.ndarray import _set_class_ndarray
@tvm._ffi.register_object
class NDArray(NDArrayBase):
"""Lightweight NDArray class of TVM runtime.
Strictly this is only an Array Container (a buffer object)
No arthimetic operations are defined.
All operations are performed by TVM functions.
The goal is not to re-build yet another array library.
Instead, this is a minimal data structure to demonstrate
how can we use TVM in existing project which might have their own array containers.
"""
def cpu(dev_id=0):
"""Construct a CPU device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(1, dev_id)
def gpu(dev_id=0):
"""Construct a CPU device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(2, dev_id)
def rocm(dev_id=0):
"""Construct a ROCM device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(10, dev_id)
def opencl(dev_id=0):
"""Construct a OpenCL device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(4, dev_id)
def metal(dev_id=0):
"""Construct a metal device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(8, dev_id)
def vpi(dev_id=0):
"""Construct a VPI simulated device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(9, dev_id)
def vulkan(dev_id=0):
"""Construct a Vulkan device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(7, dev_id)
def opengl(dev_id=0):
"""Construct a OpenGL device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(11, dev_id)
def ext_dev(dev_id=0):
"""Construct a extension device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
Note
----
This API is reserved for quick testing of new
device by plugin device API as ext_dev.
"""
return TVMContext(12, dev_id)
def micro_dev(dev_id=0):
"""Construct a micro device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(13, dev_id)
cl = opencl
mtl = metal
def array(arr, ctx=cpu(0)):
"""Create an array from source arr.
Parameters
----------
arr : numpy.ndarray
The array to be copied from
ctx : TVMContext, optional
The device context to create the array
Returns
-------
ret : NDArray
The created array
"""
if not isinstance(arr, (_np.ndarray, NDArray)):
arr = _np.array(arr)
return empty(arr.shape, arr.dtype, ctx).copyfrom(arr)
_set_class_ndarray(NDArray)
......@@ -33,9 +33,7 @@ To connect to the graph runtime, we use a printer that converts our graph format
into TVM's JSON format. The resulting string can be loaded by
contrib.graph_runtime or any other TVM runtime compatible systems.
"""
from __future__ import absolute_import
from tvm.ndarray import empty
from tvm.runtime.ndarray import empty
from tvm.relay import _build_module
from tvm import target as _target
from tvm import expr as _expr
......
......@@ -23,9 +23,9 @@ Implements a Python interface to compiling and executing on the Relay VM.
import numpy as np
import tvm
import tvm.ndarray as _nd
import tvm.runtime.ndarray as _nd
from tvm.runtime import Object
from tvm import autotvm, container
from tvm._ffi.object import Object
from tvm.relay import expr as _expr
from tvm._ffi.runtime_ctypes import TVMByteArray
from tvm._ffi import base as _base
......
......@@ -18,12 +18,11 @@
"""The base node types for the Relay language."""
import tvm._ffi
from .._ffi.object import Object
from tvm.runtime import Object
from . import _make
from . import _expr
from . import _base
Object = Object
def register_relay_node(type_key=None):
"""Register a Relay node type.
......
......@@ -20,14 +20,14 @@ from __future__ import absolute_import
from numbers import Number as _Number
import numpy as _np
from tvm._ffi import base as _base
from tvm.runtime import NDArray, convert, ndarray as _nd
from .base import RelayNode, register_relay_node
from . import _make
from . import _expr
from . import ty as _ty
from .._ffi import base as _base
from .. import nd as _nd
from .. import convert
from ..ndarray import NDArray
# will be registered afterwards
_op_make = None
......
......@@ -23,7 +23,7 @@ from .expr_functor import ExprMutator
from .scope_builder import ScopeBuilder
from . import transform
from . import op, ty, expr
from .. import TVMType, register_func
from .. import DataType, register_func
from .backend import compile_engine
......@@ -109,7 +109,7 @@ class ManifestAllocPass(ExprMutator):
return expr.Tuple(new_fields)
def compute_alignment(self, dtype):
dtype = TVMType(dtype)
dtype = DataType(dtype)
align = (dtype.bits // 8) * dtype.lanes
# MAGIC CONSTANT FROM device_api.h
if align < 64:
......@@ -118,7 +118,7 @@ class ManifestAllocPass(ExprMutator):
return expr.const(align, dtype="int64")
def compute_storage_in_relay(self, shape, dtype):
dtype = TVMType(dtype)
dtype = DataType(dtype)
els = op.prod(shape)
num = expr.const(dtype.bits * dtype.lanes, self.compute_dtype)
num = num + expr.const(7, self.compute_dtype)
......@@ -126,7 +126,7 @@ class ManifestAllocPass(ExprMutator):
return els * (num / div)
def compute_storage(self, tensor_type):
dtype = TVMType(tensor_type.dtype)
dtype = DataType(tensor_type.dtype)
shape = [int(sh) for sh in tensor_type.shape]
size = 1
for sh in shape:
......
......@@ -15,11 +15,12 @@
# specific language governing permissions and limitations
# under the License.
"""Annotation operations."""
from __future__ import absolute_import as _abs
from tvm.runtime import ndarray as _nd
from tvm.runtime import TVMContext as _TVMContext
from . import _make
from ..op import register_schedule, schedule_injective
from .... import nd as _nd
from .... import TVMContext as _TVMContext
def on_device(data, device):
"""Annotate an expression with a certain device type.
......
......@@ -16,11 +16,12 @@
# under the License.
"""Basic tensor operations."""
# pylint: disable=redefined-builtin
from __future__ import absolute_import as _abs
from tvm.runtime import ndarray as _nd
from tvm.runtime import TVMContext as _TVMContext
from . import _make
from ..expr import Tuple
from ... import nd as _nd
from ... import TVMContext as _TVMContext
# We create a wrapper function for each operator in the
# python side to call into the positional _make.OpName function.
......
......@@ -22,12 +22,12 @@ import socket
import struct
import time
import tvm._ffi
from tvm.contrib import util
from tvm._ffi.base import TVMError
from tvm.runtime import ndarray as nd
from tvm.runtime import load_module as _load_module
from . import base
from ..contrib import util
from .._ffi.base import TVMError
from .._ffi import ndarray as nd
from ..module import load as _load_module
class RPCSession(object):
......
......@@ -38,10 +38,10 @@ import sys
import signal
import tvm._ffi
from .._ffi.base import py_str
from .._ffi.libinfo import find_lib_path
from ..module import load as _load_module
from ..contrib import util
from tvm._ffi.base import py_str
from tvm._ffi.libinfo import find_lib_path
from tvm.runtime.module import load as _load_module
from tvm.contrib import util
from . import base
from . base import TrackerCode
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""TVM runtime."""
# class exposures
from .packed_func import PackedFunc
from .object import Object
from .object_generic import ObjectGeneric, ObjectTypes
from .ndarray import NDArray, DataType, TypeCode, TVMContext
from .module import Module
# function exposures
from .object_generic import convert_to_object, convert, const
from .ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl
from .ndarray import vpi, rocm, opengl, ext_dev, micro_dev
from .module import load as load_module
DataType = DataType
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Runtime container structures."""
def getitem_helper(obj, elem_getter, length, idx):
"""Helper function to implement a pythonic getitem function.
Parameters
----------
obj: object
The original object
elem_getter : function
A simple function that takes index and return a single element.
length : int
The size of the array
idx : int or slice
The argument passed to getitem
Returns
-------
result : object
The result of getitem
"""
if isinstance(idx, slice):
start = idx.start if idx.start is not None else 0
stop = idx.stop if idx.stop is not None else length
step = idx.step if idx.step is not None else 1
if start < 0:
start += length
if stop < 0:
stop += length
return [elem_getter(obj, i) for i in range(start, stop, step)]
if idx < -length or idx >= length:
raise IndexError("Index out of range. size: {}, got index {}"
.format(length, idx))
if idx < 0:
idx += length
return elem_getter(obj, idx)
......@@ -17,19 +17,20 @@
# pylint: disable=invalid-name, unused-import
"""Runtime Object API"""
import ctypes
from tvm._ffi.base import _FFI_MODE, _RUNTIME_ONLY, check_call, _LIB, c_str
from .. import _api_internal
from .base import _FFI_MODE, _RUNTIME_ONLY, check_call, _LIB, c_str
try:
# pylint: disable=wrong-import-position,unused-import
if _FFI_MODE == "ctypes":
raise ImportError()
from ._cy3.core import _set_class_object, _set_class_object_generic
from ._cy3.core import ObjectBase
from tvm._ffi._cy3.core import _set_class_object, _set_class_object_generic
from tvm._ffi._cy3.core import ObjectBase
except (RuntimeError, ImportError):
# pylint: disable=wrong-import-position,unused-import
from ._ctypes.packed_func import _set_class_object, _set_class_object_generic
from ._ctypes.object import ObjectBase
from tvm._ffi._ctypes.packed_func import _set_class_object, _set_class_object_generic
from tvm._ffi._ctypes.object import ObjectBase
def _new_object(cls):
......@@ -91,44 +92,4 @@ class Object(ObjectBase):
return self.__hash__() == other.__hash__()
def getitem_helper(obj, elem_getter, length, idx):
"""Helper function to implement a pythonic getitem function.
Parameters
----------
obj: object
The original object
elem_getter : function
A simple function that takes index and return a single element.
length : int
The size of the array
idx : int or slice
The argument passed to getitem
Returns
-------
result : object
The result of getitem
"""
if isinstance(idx, slice):
start = idx.start if idx.start is not None else 0
stop = idx.stop if idx.stop is not None else length
step = idx.step if idx.step is not None else 1
if start < 0:
start += length
if stop < 0:
stop += length
return [elem_getter(obj, i) for i in range(start, stop, step)]
if idx < -length or idx >= length:
raise IndexError("Index out of range. size: {}, got index {}"
.format(length, idx))
if idx < 0:
idx += length
return elem_getter(obj, idx)
_set_class_object(Object)
......@@ -15,15 +15,15 @@
# specific language governing permissions and limitations
# under the License.
"""Common implementation of object generic related logic"""
# pylint: disable=unused-import
# pylint: disable=unused-import, invalid-name
from numbers import Number, Integral
from .. import _api_internal
from tvm._ffi.base import string_types
from .base import string_types
from .. import _api_internal
from .object import ObjectBase, _set_class_object_generic
from .ndarray import NDArrayBase
from .packed_func import PackedFuncBase, convert_to_tvm_func
from .module import ModuleBase
from .module import Module
class ObjectGeneric(object):
......@@ -33,7 +33,7 @@ class ObjectGeneric(object):
raise NotImplementedError()
_CLASS_OBJECTS = (ObjectBase, NDArrayBase, ModuleBase)
ObjectTypes = (ObjectBase, NDArrayBase, Module)
def convert_to_object(value):
......@@ -49,7 +49,7 @@ def convert_to_object(value):
obj : Object
The corresponding object value.
"""
if isinstance(value, _CLASS_OBJECTS):
if isinstance(value, ObjectTypes):
return value
if isinstance(value, bool):
return const(value, 'uint1x1')
......@@ -63,7 +63,7 @@ def convert_to_object(value):
if isinstance(value, dict):
vlist = []
for item in value.items():
if (not isinstance(item[0], _CLASS_OBJECTS) and
if (not isinstance(item[0], ObjectTypes) and
not isinstance(item[0], string_types)):
raise ValueError("key of map must already been a container type")
vlist.append(item[0])
......
......@@ -18,20 +18,20 @@
# pylint: disable=invalid-name, unused-import
"""Packed Function namespace."""
import ctypes
from .base import _LIB, check_call, c_str, string_types, _FFI_MODE
from tvm._ffi.base import _LIB, check_call, c_str, string_types, _FFI_MODE
try:
# pylint: disable=wrong-import-position
if _FFI_MODE == "ctypes":
raise ImportError()
from ._cy3.core import _set_class_packed_func, _set_class_module
from ._cy3.core import PackedFuncBase
from ._cy3.core import convert_to_tvm_func
from tvm._ffi._cy3.core import _set_class_packed_func, _set_class_module
from tvm._ffi._cy3.core import PackedFuncBase
from tvm._ffi._cy3.core import convert_to_tvm_func
except (RuntimeError, ImportError):
# pylint: disable=wrong-import-position
from ._ctypes.packed_func import _set_class_packed_func, _set_class_module
from ._ctypes.packed_func import PackedFuncBase
from ._ctypes.packed_func import convert_to_tvm_func
from tvm._ffi._ctypes.packed_func import _set_class_packed_func, _set_class_module
from tvm._ffi._ctypes.packed_func import PackedFuncBase
from tvm._ffi._ctypes.packed_func import convert_to_tvm_func
PackedFuncHandle = ctypes.c_void_p
......
......@@ -17,9 +17,8 @@
"""The computation schedule api of TVM."""
import tvm._ffi
from ._ffi.base import string_types
from ._ffi.object import Object
from ._ffi.object_generic import convert
from tvm._ffi.base import string_types
from tvm.runtime import Object, convert
from . import _api_internal
from . import tensor as _tensor
......
......@@ -30,7 +30,8 @@ Each statement node have subfields that can be visited from python side.
assert(st.buffer_var == a)
"""
import tvm._ffi
from ._ffi.object import Object
from tvm.runtime import Object
from . import make as _make
......
......@@ -57,8 +57,8 @@ We can also use other specific function in this module to create specific target
import warnings
import tvm._ffi
from tvm.runtime import Object
from ._ffi.base import _LIB_NAME
from ._ffi.object import Object
from . import _api_internal
try:
......
......@@ -18,8 +18,7 @@
# pylint: disable=invalid-name
import tvm._ffi
from ._ffi.object import Object
from ._ffi.object_generic import ObjectGeneric, convert_to_object
from tvm.runtime import Object, ObjectGeneric, convert_to_object
from . import _api_internal
from . import make as _make
......@@ -129,7 +128,6 @@ class Tensor(Object, _expr.ExprOp):
return "%s.v%d" % (op.name, self.value_index)
class Operation(Object):
"""Represent an operation that generates a tensor"""
......
......@@ -17,6 +17,7 @@
"""Tensor intrinsics"""
import tvm._ffi
from tvm.runtime import Object
from . import _api_internal
from . import api as _api
from . import expr as _expr
......@@ -25,7 +26,6 @@ from . import make as _make
from . import tensor as _tensor
from . import schedule as _schedule
from .build_module import current_build_config
from ._ffi.object import Object
def _get_region(tslice):
......
......@@ -16,7 +16,7 @@
# under the License.
import tvm
import tvm.contrib.sparse as tvmsp
import tvm.ndarray as _nd
import tvm.runtime.ndarray as _nd
import numpy as np
from collections import namedtuple
......
......@@ -132,7 +132,7 @@ def test_comments():
def test_int_literal():
assert isinstance(parse_text("1"), relay.Constant)
assert isinstance(parse_text("1").data, tvm.ndarray.NDArray)
assert isinstance(parse_text("1").data, tvm.nd.NDArray)
assert get_scalar(parse_text("1")) == 1
assert get_scalar(parse_text("10")) == 10
......
......@@ -207,7 +207,7 @@ def test_cuda_shuffle():
b_ = np.array((list(range(4))[::-1]) * 16, dtype='int32')
c_ = np.zeros((64, ), dtype='int32')
ref = a_ + np.array((list(range(4))) * 16, dtype='int32')
nda, ndb, ndc = [tvm.ndarray.array(i, tvm.gpu(0)) for i in [a_, b_, c_]]
nda, ndb, ndc = [tvm.nd.array(i, tvm.gpu(0)) for i in [a_, b_, c_]]
module(nda, ndb, ndc)
tvm.testing.assert_allclose(ndc.asnumpy(), ref)
......
......@@ -657,9 +657,9 @@ def test_llvm_shuffle():
with tvm.build_config(add_lower_pass=[(1, my_vectorize)]):
ir = tvm.lower(sch, [a, b, c], simple_mode=True)
module = tvm.build(sch, [a, b, c])
a_ = tvm.ndarray.array(np.arange(1, 9, dtype='int32'))
b_ = tvm.ndarray.array(np.arange(8, 0, -1, dtype='int32'))
c_ = tvm.ndarray.array(np.zeros((8, ), dtype='int32'))
a_ = tvm.nd.array(np.arange(1, 9, dtype='int32'))
b_ = tvm.nd.array(np.arange(8, 0, -1, dtype='int32'))
c_ = tvm.nd.array(np.zeros((8, ), dtype='int32'))
module(a_, b_, c_)
tvm.testing.assert_allclose(c_.asnumpy(), (a_.asnumpy() * 2).astype('int32'))
......
......@@ -405,8 +405,8 @@ def test_math_intrin():
func = tvm.build(sch, [a8, b8])
assert func
a = numpy.arange(2, 10).astype('float32')
tvm_a = tvm.ndarray.array(a)
tvm_b = tvm.ndarray.array(numpy.zeros((8, ), dtype='float32'))
tvm_a = tvm.nd.array(a)
tvm_b = tvm.nd.array(numpy.zeros((8, ), dtype='float32'))
b = intrin_real(a)
func(tvm_a, tvm_b)
tvm.testing.assert_allclose(b, tvm_b.asnumpy(), rtol=1e-5)
......@@ -423,8 +423,8 @@ def test_math_intrin():
func = tvm.build(sch, [a1, b1])
assert func
a = numpy.array([114514]).astype('int32')
tvm_a = tvm.ndarray.array(a)
tvm_b = tvm.ndarray.array(numpy.array([0]).astype('int32'))
tvm_a = tvm.nd.array(a)
tvm_b = tvm.nd.array(numpy.array([0]).astype('int32'))
b = intrin_int(a)
func(tvm_a, tvm_b)
assert tvm_b.asnumpy()[0] == b[0]
......@@ -578,8 +578,8 @@ def test_const_param():
np_b = 11
np_c = numpy.zeros((11, )).astype('int32')
nd_a = tvm.ndarray.array(np_a)
nd_c = tvm.ndarray.array(numpy.zeros((11, )).astype('int32'))
nd_a = tvm.nd.array(np_a)
nd_c = tvm.nd.array(numpy.zeros((11, )).astype('int32'))
module(nd_a, nd_c)
ref = add_something(np_a, 11)
......@@ -614,8 +614,8 @@ def test_value_index():
np_b, np_c = kernel_a(np_a)
ref = kernel_b(np_c, np_b)
res = tvm.ndarray.array(numpy.zeros((4, 4)).astype('int32'))
module(tvm.ndarray.array(np_a), res)
res = tvm.nd.array(numpy.zeros((4, 4)).astype('int32'))
module(tvm.nd.array(np_a), res)
tvm.testing.assert_allclose(res.asnumpy(), ref)
def test_func_call():
......
......@@ -28,7 +28,7 @@ def test_shared_memory():
N = 1024
M = 128
tvm_type = tvm.datatype._TVMType(dtype)
tvm_type = tvm.runtime.DataType(dtype)
type_size = tvm_type.bits // 8 * tvm_type.lanes
A = tvm.placeholder((N,), name='A', dtype=dtype)
......
......@@ -444,7 +444,7 @@ def test_reduction_and_dummy_fuse_split():
axo, axi = s[Y.op].split(ax, nparts=20)
f = tvm.build(s, [Y, X])
args = [tvm.nd.empty((), 'int32')] + [tvm.ndarray.array(np.ones((n,), dtype='int32'))]
args = [tvm.nd.empty((), 'int32')] + [tvm.nd.array(np.ones((n,), dtype='int32'))]
f(*args)
assert args[0].asnumpy() == n
......@@ -456,8 +456,8 @@ def test_reduction_and_dummy_fuse_split():
ax = s[Y.op].fuse(*(list(Y.op.axis) + list(Y.op.reduce_axis)))
f = tvm.build(s, [Y, X])
args = [tvm.ndarray.array(np.ones((n,), dtype='int32'))] + \
[tvm.ndarray.array(np.ones((n,), dtype='int32'))]
args = [tvm.nd.array(np.ones((n,), dtype='int32'))] + \
[tvm.nd.array(np.ones((n,), dtype='int32'))]
f(*args)
assert np.all(args[0].asnumpy() == n)
......
......@@ -231,8 +231,8 @@ def test_sparse_dense_csr():
Y = topi.nn.sparse_dense(X, W_data, W_indices, W_indptr)
s = tvm.create_schedule(Y.op)
func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y])
Y_tvm = tvm.ndarray.array(np.zeros(Y_np.shape, dtype=Y_np.dtype))
func(tvm.ndarray.array(X_np), tvm.ndarray.array(W_sp_np.data), tvm.ndarray.array(W_sp_np.indices), tvm.ndarray.array(W_sp_np.indptr), Y_tvm)
Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype))
func(tvm.nd.array(X_np), tvm.nd.array(W_sp_np.data), tvm.nd.array(W_sp_np.indices), tvm.nd.array(W_sp_np.indptr), Y_tvm)
tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-4, rtol=1e-4)
def test_sparse_transpose_csr():
......@@ -252,11 +252,11 @@ def test_sparse_transpose_csr():
func = tvm.build(s, [X_data, X_indices, X_indptr, X_T_data, X_T_indices, X_T_indptr])
X_T_data_tvm = tvm.ndarray.array(np.zeros(X_sp_T.data.shape, dtype=X_sp_T.data.dtype))
X_T_indices_tvm = tvm.ndarray.array(np.zeros(X_sp_T.indices.shape, dtype=X_sp_T.indices.dtype))
X_T_indptr_tvm = tvm.ndarray.array(np.zeros(X_sp_T.indptr.shape, dtype=X_sp_T.indptr.dtype))
X_T_data_tvm = tvm.nd.array(np.zeros(X_sp_T.data.shape, dtype=X_sp_T.data.dtype))
X_T_indices_tvm = tvm.nd.array(np.zeros(X_sp_T.indices.shape, dtype=X_sp_T.indices.dtype))
X_T_indptr_tvm = tvm.nd.array(np.zeros(X_sp_T.indptr.shape, dtype=X_sp_T.indptr.dtype))
func(tvm.ndarray.array(X_sp.data), tvm.ndarray.array(X_sp.indices), tvm.ndarray.array(X_sp.indptr),
func(tvm.nd.array(X_sp.data), tvm.nd.array(X_sp.indices), tvm.nd.array(X_sp.indptr),
X_T_data_tvm, X_T_indices_tvm, X_T_indptr_tvm)
X_T_out = sp.csr_matrix((X_T_data_tvm.asnumpy(), X_T_indices_tvm.asnumpy(), X_T_indptr_tvm.asnumpy()), shape=(N,N)).todense()
......@@ -295,11 +295,11 @@ def test_sparse_dense_bsr():
Y = topi.nn.sparse_dense(X, W_data, W_indices, W_indptr)
s = tvm.create_schedule(Y.op)
func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y])
Y_tvm = tvm.ndarray.array(np.zeros(Y_np.shape, dtype=Y_np.dtype))
func(tvm.ndarray.array(X_np),
tvm.ndarray.array(W_sp_np.data),
tvm.ndarray.array(W_sp_np.indices),
tvm.ndarray.array(W_sp_np.indptr),
Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype))
func(tvm.nd.array(X_np),
tvm.nd.array(W_sp_np.data),
tvm.nd.array(W_sp_np.indices),
tvm.nd.array(W_sp_np.indptr),
Y_tvm)
tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-4, rtol=1e-4)
......@@ -324,11 +324,11 @@ def test_sparse_dense_bsr_randomized():
Y = topi.nn.sparse_dense(X, W_data, W_indices, W_indptr)
s = tvm.create_schedule(Y.op)
func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y])
Y_tvm = tvm.ndarray.array(np.zeros(Y_np.shape, dtype=Y_np.dtype))
func(tvm.ndarray.array(X_np),
tvm.ndarray.array(W_sp_np.data),
tvm.ndarray.array(W_sp_np.indices),
tvm.ndarray.array(W_sp_np.indptr),
Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype))
func(tvm.nd.array(X_np),
tvm.nd.array(W_sp_np.data),
tvm.nd.array(W_sp_np.indices),
tvm.nd.array(W_sp_np.indptr),
Y_tvm)
tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-5, rtol=1e-5)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment