Unverified Commit fc7dd6d7 by Tianqi Chen Committed by GitHub

[REFACTOR][PY] Establish tvm.runtime (#4818)

* [REFACTOR][PY] Establish tvm.runtime

This PR establishes the tvm.runtime namespace that contains the core runtime data structures.
The top-level API are kept inact for now via re-exporting.

We will followup later to cleanup some of the top-level APIs.

* Fix ndarray name
parent 7d263c31
...@@ -15,19 +15,29 @@ ...@@ -15,19 +15,29 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
# pylint: disable=redefined-builtin, wildcard-import # pylint: disable=redefined-builtin, wildcard-import
"""TVM: Low level DSL/IR stack for tensor computation.""" """TVM: Open Deep Learning Compiler Stack."""
import multiprocessing import multiprocessing
import sys import sys
import traceback import traceback
# import ffi related features # top-level alias
# tvm._ffi
from ._ffi.base import TVMError, __version__ from ._ffi.base import TVMError, __version__
from ._ffi.runtime_ctypes import TypeCode, TVMType from ._ffi.runtime_ctypes import TypeCode, DataType
from ._ffi.ndarray import TVMContext
from ._ffi.packed_func import PackedFunc as Function
from ._ffi.registry import register_object, register_func, register_extension from ._ffi.registry import register_object, register_func, register_extension
from ._ffi.object import Object
# top-level alias
# tvm.runtime
from .runtime.object import Object
from .runtime.packed_func import PackedFunc as Function
from .runtime.ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl
from .runtime.ndarray import vpi, rocm, opengl, ext_dev, micro_dev
from .runtime import module
from .runtime import ndarray
# pylint: disable=reimported
from .runtime import ndarray as nd
# others
from . import tensor from . import tensor
from . import arith from . import arith
from . import expr from . import expr
...@@ -37,7 +47,7 @@ from . import ir_pass ...@@ -37,7 +47,7 @@ from . import ir_pass
from . import codegen from . import codegen
from . import container from . import container
from . import schedule from . import schedule
from . import module
from . import attrs from . import attrs
from . import ir_builder from . import ir_builder
from . import target from . import target
...@@ -47,9 +57,6 @@ from . import testing ...@@ -47,9 +57,6 @@ from . import testing
from . import error from . import error
from . import datatype from . import datatype
from . import ndarray as nd
from .ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl
from .ndarray import vpi, rocm, opengl, ext_dev, micro_dev
from .api import * from .api import *
from .intrin import * from .intrin import *
......
...@@ -23,7 +23,7 @@ from numbers import Number, Integral ...@@ -23,7 +23,7 @@ from numbers import Number, Integral
from ..base import _LIB, get_last_ffi_error, py2cerror, check_call from ..base import _LIB, get_last_ffi_error, py2cerror, check_call
from ..base import c_str, string_types from ..base import c_str, string_types
from ..runtime_ctypes import TVMType, TVMByteArray, TVMContext from ..runtime_ctypes import DataType, TVMByteArray, TVMContext
from . import ndarray as _nd from . import ndarray as _nd
from .ndarray import NDArrayBase, _make_array from .ndarray import NDArrayBase, _make_array
from .types import TVMValue, TypeCode from .types import TVMValue, TypeCode
...@@ -132,7 +132,7 @@ def _make_tvm_args(args, temp_args): ...@@ -132,7 +132,7 @@ def _make_tvm_args(args, temp_args):
elif isinstance(arg, Number): elif isinstance(arg, Number):
values[i].v_float64 = arg values[i].v_float64 = arg
type_codes[i] = TypeCode.FLOAT type_codes[i] = TypeCode.FLOAT
elif isinstance(arg, TVMType): elif isinstance(arg, DataType):
values[i].v_str = c_str(str(arg)) values[i].v_str = c_str(str(arg))
type_codes[i] = TypeCode.STR type_codes[i] = TypeCode.STR
elif isinstance(arg, TVMContext): elif isinstance(arg, TVMContext):
......
...@@ -20,7 +20,7 @@ import traceback ...@@ -20,7 +20,7 @@ import traceback
from cpython cimport Py_INCREF, Py_DECREF from cpython cimport Py_INCREF, Py_DECREF
from numbers import Number, Integral from numbers import Number, Integral
from ..base import string_types, py2cerror from ..base import string_types, py2cerror
from ..runtime_ctypes import TVMType, TVMContext, TVMByteArray from ..runtime_ctypes import DataType, TVMContext, TVMByteArray
cdef void tvm_callback_finalize(void* fhandle): cdef void tvm_callback_finalize(void* fhandle):
...@@ -129,7 +129,7 @@ cdef inline int make_arg(object arg, ...@@ -129,7 +129,7 @@ cdef inline int make_arg(object arg,
elif isinstance(arg, Number): elif isinstance(arg, Number):
value[0].v_float64 = arg value[0].v_float64 = arg
tcode[0] = kFloat tcode[0] = kFloat
elif isinstance(arg, TVMType): elif isinstance(arg, DataType):
tstr = c_str(str(arg)) tstr = c_str(str(arg))
value[0].v_str = tstr value[0].v_str = tstr
tcode[0] = kTVMStr tcode[0] = kTVMStr
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-import
"""Runtime Module namespace."""
import ctypes
from .base import _LIB, check_call, c_str, string_types
from .packed_func import PackedFunc, PackedFuncHandle, _set_class_module
class ModuleBase(object):
"""Base class for module"""
__slots__ = ["handle", "_entry", "entry_name"]
def __init__(self, handle):
self.handle = handle
self._entry = None
self.entry_name = "__tvm_main__"
def __del__(self):
check_call(_LIB.TVMModFree(self.handle))
def __hash__(self):
return ctypes.cast(self.handle, ctypes.c_void_p).value
@property
def entry_func(self):
"""Get the entry function
Returns
-------
f : Function
The entry function if exist
"""
if self._entry:
return self._entry
self._entry = self.get_function(self.entry_name)
return self._entry
def get_function(self, name, query_imports=False):
"""Get function from the module.
Parameters
----------
name : str
The name of the function
query_imports : bool
Whether also query modules imported by this module.
Returns
-------
f : Function
The result function.
"""
ret_handle = PackedFuncHandle()
check_call(_LIB.TVMModGetFunction(
self.handle, c_str(name),
ctypes.c_int(query_imports),
ctypes.byref(ret_handle)))
if not ret_handle.value:
raise AttributeError(
"Module has no function '%s'" % name)
return PackedFunc(ret_handle, False)
def import_module(self, module):
"""Add module to the import list of current one.
Parameters
----------
module : Module
The other module.
"""
check_call(_LIB.TVMModImport(self.handle, module.handle))
def __getitem__(self, name):
if not isinstance(name, string_types):
raise ValueError("Can only take string as function name")
return self.get_function(name)
def __call__(self, *args):
if self._entry:
return self._entry(*args)
f = self.entry_func
return f(*args)
...@@ -48,7 +48,7 @@ class TVMByteArray(ctypes.Structure): ...@@ -48,7 +48,7 @@ class TVMByteArray(ctypes.Structure):
_fields_ = [("data", ctypes.POINTER(ctypes.c_byte)), _fields_ = [("data", ctypes.POINTER(ctypes.c_byte)),
("size", ctypes.c_size_t)] ("size", ctypes.c_size_t)]
class TVMType(ctypes.Structure): class DataType(ctypes.Structure):
"""TVM datatype structure""" """TVM datatype structure"""
_fields_ = [("type_code", ctypes.c_uint8), _fields_ = [("type_code", ctypes.c_uint8),
("bits", ctypes.c_uint8), ("bits", ctypes.c_uint8),
...@@ -60,7 +60,7 @@ class TVMType(ctypes.Structure): ...@@ -60,7 +60,7 @@ class TVMType(ctypes.Structure):
4 : 'handle' 4 : 'handle'
} }
def __init__(self, type_str): def __init__(self, type_str):
super(TVMType, self).__init__() super(DataType, self).__init__()
if isinstance(type_str, np.dtype): if isinstance(type_str, np.dtype):
type_str = str(type_str) type_str = str(type_str)
...@@ -104,8 +104,8 @@ class TVMType(ctypes.Structure): ...@@ -104,8 +104,8 @@ class TVMType(ctypes.Structure):
def __repr__(self): def __repr__(self):
if self.bits == 1 and self.lanes == 1: if self.bits == 1 and self.lanes == 1:
return "bool" return "bool"
if self.type_code in TVMType.CODE2STR: if self.type_code in DataType.CODE2STR:
type_name = TVMType.CODE2STR[self.type_code] type_name = DataType.CODE2STR[self.type_code]
else: else:
type_name = "custom[%s]" % \ type_name = "custom[%s]" % \
_api_internal._datatype_get_type_name(self.type_code) _api_internal._datatype_get_type_name(self.type_code)
...@@ -263,7 +263,7 @@ class TVMArray(ctypes.Structure): ...@@ -263,7 +263,7 @@ class TVMArray(ctypes.Structure):
_fields_ = [("data", ctypes.c_void_p), _fields_ = [("data", ctypes.c_void_p),
("ctx", TVMContext), ("ctx", TVMContext),
("ndim", ctypes.c_int), ("ndim", ctypes.c_int),
("dtype", TVMType), ("dtype", DataType),
("shape", ctypes.POINTER(tvm_shape_index_t)), ("shape", ctypes.POINTER(tvm_shape_index_t)),
("strides", ctypes.POINTER(tvm_shape_index_t)), ("strides", ctypes.POINTER(tvm_shape_index_t)),
("byte_offset", ctypes.c_uint64)] ("byte_offset", ctypes.c_uint64)]
......
...@@ -20,10 +20,10 @@ from numbers import Integral as _Integral ...@@ -20,10 +20,10 @@ from numbers import Integral as _Integral
import tvm._ffi import tvm._ffi
from tvm.runtime import convert, const, DataType
from ._ffi.base import string_types, TVMError from ._ffi.base import string_types, TVMError
from ._ffi.object_generic import convert, const
from ._ffi.registry import register_func, get_global_func, extract_ext_funcs from ._ffi.registry import register_func, get_global_func, extract_ext_funcs
from ._ffi.runtime_ctypes import TVMType
from . import _api_internal from . import _api_internal
from . import make as _make from . import make as _make
from . import expr as _expr from . import expr as _expr
......
...@@ -15,10 +15,9 @@ ...@@ -15,10 +15,9 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""Arithmetic data structure and utility""" """Arithmetic data structure and utility"""
from __future__ import absolute_import as _abs
import tvm._ffi import tvm._ffi
from tvm.runtime import Object
from ._ffi.object import Object
from . import _api_internal from . import _api_internal
class IntSet(Object): class IntSet(Object):
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
""" TVM Attribute module, which is mainly used for defining attributes of operators""" """ TVM Attribute module, which is mainly used for defining attributes of operators"""
import tvm._ffi import tvm._ffi
from ._ffi.object import Object from tvm.runtime import Object
from . import _api_internal from . import _api_internal
......
...@@ -22,7 +22,7 @@ LoweredFunc and compiled Module. ...@@ -22,7 +22,7 @@ LoweredFunc and compiled Module.
import warnings import warnings
import tvm._ffi import tvm._ffi
from ._ffi.object import Object from tvm.runtime import Object, ndarray
from . import api from . import api
from . import _api_internal from . import _api_internal
from . import tensor from . import tensor
...@@ -33,7 +33,6 @@ from . import stmt as _stmt ...@@ -33,7 +33,6 @@ from . import stmt as _stmt
from . import container from . import container
from . import module from . import module
from . import codegen from . import codegen
from . import ndarray
from . import target as _target from . import target as _target
from . import make from . import make
......
...@@ -17,9 +17,9 @@ ...@@ -17,9 +17,9 @@
"""Container data structures used in TVM DSL.""" """Container data structures used in TVM DSL."""
import tvm._ffi import tvm._ffi
from tvm import ndarray as _nd from tvm.runtime import Object, ObjectTypes
from tvm.runtime.container import getitem_helper
from . import _api_internal from . import _api_internal
from ._ffi.object import Object, getitem_helper
@tvm._ffi.register_object @tvm._ffi.register_object
...@@ -31,23 +31,9 @@ class Array(Object): ...@@ -31,23 +31,9 @@ class Array(Object):
to Array during tvm function call. to Array during tvm function call.
You may get Array in return values of TVM function call. You may get Array in return values of TVM function call.
""" """
def __getitem__(self, i): def __getitem__(self, idx):
if isinstance(i, slice): return getitem_helper(
start = i.start if i.start is not None else 0 self, _api_internal._ArrayGetItem, len(self), idx)
stop = i.stop if i.stop is not None else len(self)
step = i.step if i.step is not None else 1
if start < 0:
start += len(self)
if stop < 0:
stop += len(self)
return [self[idx] for idx in range(start, stop, step)]
if i < -len(self) or i >= len(self):
raise IndexError("Array index out of range. Array size: {}, got index {}"
.format(len(self), i))
if i < 0:
i += len(self)
return _api_internal._ArrayGetItem(self, i)
def __len__(self): def __len__(self):
return _api_internal._ArraySize(self) return _api_internal._ArraySize(self)
...@@ -133,7 +119,7 @@ class ADT(Object): ...@@ -133,7 +119,7 @@ class ADT(Object):
""" """
def __init__(self, tag, fields): def __init__(self, tag, fields):
for f in fields: for f in fields:
assert isinstance(f, (Object, _nd.NDArray)), "Expect object or " \ assert isinstance(f, ObjectTypes), "Expect object or " \
"tvm NDArray type, but received : {0}".format(type(f)) "tvm NDArray type, but received : {0}".format(type(f))
self.__init_handle_by_constructor__(_ADT, tag, *fields) self.__init_handle_by_constructor__(_ADT, tag, *fields)
...@@ -164,7 +150,7 @@ def tuple_object(fields=None): ...@@ -164,7 +150,7 @@ def tuple_object(fields=None):
""" """
fields = fields if fields else [] fields = fields if fields else []
for f in fields: for f in fields:
assert isinstance(f, (Object, _nd.NDArray)), "Expect object or tvm " \ assert isinstance(f, ObjectTypes), "Expect object or tvm " \
"NDArray type, but received : {0}".format(type(f)) "NDArray type, but received : {0}".format(type(f))
return _Tuple(*fields) return _Tuple(*fields)
......
...@@ -23,7 +23,7 @@ import tvm._ffi ...@@ -23,7 +23,7 @@ import tvm._ffi
from tvm._ffi.base import string_types from tvm._ffi.base import string_types
from tvm.contrib import graph_runtime from tvm.contrib import graph_runtime
from tvm.ndarray import array from tvm.runtime.ndarray import array
from . import debug_result from . import debug_result
_DUMP_ROOT_PREFIX = "tvmdbg_" _DUMP_ROOT_PREFIX = "tvmdbg_"
......
...@@ -21,8 +21,9 @@ from __future__ import absolute_import as _abs ...@@ -21,8 +21,9 @@ from __future__ import absolute_import as _abs
import subprocess import subprocess
import os import os
import warnings import warnings
from tvm.runtime import ndarray as nd
from . import util from . import util
from .. import ndarray as nd
from ..api import register_func from ..api import register_func
from .._ffi.base import py_str from .._ffi.base import py_str
......
...@@ -20,7 +20,7 @@ import tvm._ffi ...@@ -20,7 +20,7 @@ import tvm._ffi
from . import make as _make from . import make as _make
from .api import convert from .api import convert
from .expr import Call as _Call, Cast as _Cast, FloatImm as _FloatImm from .expr import Call as _Call, Cast as _Cast, FloatImm as _FloatImm
from ._ffi.runtime_ctypes import TVMType as _TVMType from ._ffi.runtime_ctypes import DataType
from . import _api_internal from . import _api_internal
...@@ -131,7 +131,7 @@ def create_lower_func(extern_func_name): ...@@ -131,7 +131,7 @@ def create_lower_func(extern_func_name):
width as the custom type is returned. Otherwise, the type is width as the custom type is returned. Otherwise, the type is
unchanged.""" unchanged."""
dtype = op.dtype dtype = op.dtype
t = _TVMType(dtype) t = DataType(dtype)
if get_type_registered(t.type_code): if get_type_registered(t.type_code):
dtype = "uint" + str(t.bits) dtype = "uint" + str(t.bits)
if t.lanes > 1: if t.lanes > 1:
......
...@@ -31,12 +31,9 @@ For example, you can use addexp.a to get the left operand of an Add node. ...@@ -31,12 +31,9 @@ For example, you can use addexp.a to get the left operand of an Add node.
assert(y.a == x) assert(y.a == x)
""" """
# pylint: disable=missing-docstring # pylint: disable=missing-docstring
from __future__ import absolute_import as _abs
import tvm._ffi import tvm._ffi
from tvm.runtime import Object, ObjectGeneric, DataType, TypeCode
from ._ffi.object import Object
from ._ffi.object_generic import ObjectGeneric
from ._ffi.runtime_ctypes import TVMType, TypeCode
from . import make as _make from . import make as _make
from . import generic as _generic from . import generic as _generic
from . import _api_internal from . import _api_internal
...@@ -52,7 +49,7 @@ def _dtype_is_int(value): ...@@ -52,7 +49,7 @@ def _dtype_is_int(value):
if isinstance(value, int): if isinstance(value, int):
return True return True
return (isinstance(value, ExprOp) and return (isinstance(value, ExprOp) and
TVMType(value.dtype).type_code == TypeCode.INT) DataType(value.dtype).type_code == TypeCode.INT)
class ExprOp(object): class ExprOp(object):
......
...@@ -15,7 +15,9 @@ ...@@ -15,7 +15,9 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""Developer API of IR node builder make function.""" """Developer API of IR node builder make function."""
from __future__ import absolute_import as _abs from tvm.runtime import ObjectGeneric, DataType
from ._ffi.base import string_types
from . import api as _api from . import api as _api
from . import stmt as _stmt from . import stmt as _stmt
...@@ -23,9 +25,6 @@ from . import expr as _expr ...@@ -23,9 +25,6 @@ from . import expr as _expr
from . import make as _make from . import make as _make
from . import ir_pass as _pass from . import ir_pass as _pass
from . import container as _container from . import container as _container
from ._ffi.base import string_types
from ._ffi.object_generic import ObjectGeneric
from ._ffi.runtime_ctypes import TVMType
from .expr import Call as _Call from .expr import Call as _Call
class WithScope(object): class WithScope(object):
...@@ -78,7 +77,7 @@ class BufferVar(ObjectGeneric): ...@@ -78,7 +77,7 @@ class BufferVar(ObjectGeneric):
return self._content_type return self._content_type
def __getitem__(self, index): def __getitem__(self, index):
t = TVMType(self._content_type) t = DataType(self._content_type)
if t.lanes > 1: if t.lanes > 1:
index = _make.Ramp(index * t.lanes, 1, t.lanes) index = _make.Ramp(index * t.lanes, 1, t.lanes)
return _make.Load(self._content_type, self._buffer_var, index) return _make.Load(self._content_type, self._buffer_var, index)
...@@ -89,7 +88,7 @@ class BufferVar(ObjectGeneric): ...@@ -89,7 +88,7 @@ class BufferVar(ObjectGeneric):
raise ValueError( raise ValueError(
"data type does not match content type %s vs %s" % ( "data type does not match content type %s vs %s" % (
value.dtype, self._content_type)) value.dtype, self._content_type))
t = TVMType(self._content_type) t = DataType(self._content_type)
if t.lanes > 1: if t.lanes > 1:
index = _make.Ramp(index * t.lanes, 1, t.lanes) index = _make.Ramp(index * t.lanes, 1, t.lanes)
self._builder.emit(_make.Store(self._buffer_var, value, index)) self._builder.emit(_make.Store(self._buffer_var, value, index))
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""TVM Runtime NDArray API.
tvm.ndarray provides a minimum runtime array API to test
the correctness of the program.
"""
# pylint: disable=invalid-name,unused-import
import tvm._ffi
import numpy as _np
from ._ffi.ndarray import TVMContext, TVMType, NDArrayBase
from ._ffi.ndarray import context, empty, from_dlpack
from ._ffi.ndarray import _set_class_ndarray
@tvm._ffi.register_object
class NDArray(NDArrayBase):
"""Lightweight NDArray class of TVM runtime.
Strictly this is only an Array Container (a buffer object)
No arthimetic operations are defined.
All operations are performed by TVM functions.
The goal is not to re-build yet another array library.
Instead, this is a minimal data structure to demonstrate
how can we use TVM in existing project which might have their own array containers.
"""
def cpu(dev_id=0):
"""Construct a CPU device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(1, dev_id)
def gpu(dev_id=0):
"""Construct a CPU device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(2, dev_id)
def rocm(dev_id=0):
"""Construct a ROCM device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(10, dev_id)
def opencl(dev_id=0):
"""Construct a OpenCL device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(4, dev_id)
def metal(dev_id=0):
"""Construct a metal device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(8, dev_id)
def vpi(dev_id=0):
"""Construct a VPI simulated device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(9, dev_id)
def vulkan(dev_id=0):
"""Construct a Vulkan device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(7, dev_id)
def opengl(dev_id=0):
"""Construct a OpenGL device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(11, dev_id)
def ext_dev(dev_id=0):
"""Construct a extension device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
Note
----
This API is reserved for quick testing of new
device by plugin device API as ext_dev.
"""
return TVMContext(12, dev_id)
def micro_dev(dev_id=0):
"""Construct a micro device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(13, dev_id)
cl = opencl
mtl = metal
def array(arr, ctx=cpu(0)):
"""Create an array from source arr.
Parameters
----------
arr : numpy.ndarray
The array to be copied from
ctx : TVMContext, optional
The device context to create the array
Returns
-------
ret : NDArray
The created array
"""
if not isinstance(arr, (_np.ndarray, NDArray)):
arr = _np.array(arr)
return empty(arr.shape, arr.dtype, ctx).copyfrom(arr)
_set_class_ndarray(NDArray)
...@@ -33,9 +33,7 @@ To connect to the graph runtime, we use a printer that converts our graph format ...@@ -33,9 +33,7 @@ To connect to the graph runtime, we use a printer that converts our graph format
into TVM's JSON format. The resulting string can be loaded by into TVM's JSON format. The resulting string can be loaded by
contrib.graph_runtime or any other TVM runtime compatible systems. contrib.graph_runtime or any other TVM runtime compatible systems.
""" """
from __future__ import absolute_import from tvm.runtime.ndarray import empty
from tvm.ndarray import empty
from tvm.relay import _build_module from tvm.relay import _build_module
from tvm import target as _target from tvm import target as _target
from tvm import expr as _expr from tvm import expr as _expr
......
...@@ -23,9 +23,9 @@ Implements a Python interface to compiling and executing on the Relay VM. ...@@ -23,9 +23,9 @@ Implements a Python interface to compiling and executing on the Relay VM.
import numpy as np import numpy as np
import tvm import tvm
import tvm.ndarray as _nd import tvm.runtime.ndarray as _nd
from tvm.runtime import Object
from tvm import autotvm, container from tvm import autotvm, container
from tvm._ffi.object import Object
from tvm.relay import expr as _expr from tvm.relay import expr as _expr
from tvm._ffi.runtime_ctypes import TVMByteArray from tvm._ffi.runtime_ctypes import TVMByteArray
from tvm._ffi import base as _base from tvm._ffi import base as _base
......
...@@ -18,12 +18,11 @@ ...@@ -18,12 +18,11 @@
"""The base node types for the Relay language.""" """The base node types for the Relay language."""
import tvm._ffi import tvm._ffi
from .._ffi.object import Object from tvm.runtime import Object
from . import _make from . import _make
from . import _expr from . import _expr
from . import _base from . import _base
Object = Object
def register_relay_node(type_key=None): def register_relay_node(type_key=None):
"""Register a Relay node type. """Register a Relay node type.
......
...@@ -20,14 +20,14 @@ from __future__ import absolute_import ...@@ -20,14 +20,14 @@ from __future__ import absolute_import
from numbers import Number as _Number from numbers import Number as _Number
import numpy as _np import numpy as _np
from tvm._ffi import base as _base
from tvm.runtime import NDArray, convert, ndarray as _nd
from .base import RelayNode, register_relay_node from .base import RelayNode, register_relay_node
from . import _make from . import _make
from . import _expr from . import _expr
from . import ty as _ty from . import ty as _ty
from .._ffi import base as _base
from .. import nd as _nd
from .. import convert
from ..ndarray import NDArray
# will be registered afterwards # will be registered afterwards
_op_make = None _op_make = None
......
...@@ -23,7 +23,7 @@ from .expr_functor import ExprMutator ...@@ -23,7 +23,7 @@ from .expr_functor import ExprMutator
from .scope_builder import ScopeBuilder from .scope_builder import ScopeBuilder
from . import transform from . import transform
from . import op, ty, expr from . import op, ty, expr
from .. import TVMType, register_func from .. import DataType, register_func
from .backend import compile_engine from .backend import compile_engine
...@@ -109,7 +109,7 @@ class ManifestAllocPass(ExprMutator): ...@@ -109,7 +109,7 @@ class ManifestAllocPass(ExprMutator):
return expr.Tuple(new_fields) return expr.Tuple(new_fields)
def compute_alignment(self, dtype): def compute_alignment(self, dtype):
dtype = TVMType(dtype) dtype = DataType(dtype)
align = (dtype.bits // 8) * dtype.lanes align = (dtype.bits // 8) * dtype.lanes
# MAGIC CONSTANT FROM device_api.h # MAGIC CONSTANT FROM device_api.h
if align < 64: if align < 64:
...@@ -118,7 +118,7 @@ class ManifestAllocPass(ExprMutator): ...@@ -118,7 +118,7 @@ class ManifestAllocPass(ExprMutator):
return expr.const(align, dtype="int64") return expr.const(align, dtype="int64")
def compute_storage_in_relay(self, shape, dtype): def compute_storage_in_relay(self, shape, dtype):
dtype = TVMType(dtype) dtype = DataType(dtype)
els = op.prod(shape) els = op.prod(shape)
num = expr.const(dtype.bits * dtype.lanes, self.compute_dtype) num = expr.const(dtype.bits * dtype.lanes, self.compute_dtype)
num = num + expr.const(7, self.compute_dtype) num = num + expr.const(7, self.compute_dtype)
...@@ -126,7 +126,7 @@ class ManifestAllocPass(ExprMutator): ...@@ -126,7 +126,7 @@ class ManifestAllocPass(ExprMutator):
return els * (num / div) return els * (num / div)
def compute_storage(self, tensor_type): def compute_storage(self, tensor_type):
dtype = TVMType(tensor_type.dtype) dtype = DataType(tensor_type.dtype)
shape = [int(sh) for sh in tensor_type.shape] shape = [int(sh) for sh in tensor_type.shape]
size = 1 size = 1
for sh in shape: for sh in shape:
......
...@@ -15,11 +15,12 @@ ...@@ -15,11 +15,12 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""Annotation operations.""" """Annotation operations."""
from __future__ import absolute_import as _abs from tvm.runtime import ndarray as _nd
from tvm.runtime import TVMContext as _TVMContext
from . import _make from . import _make
from ..op import register_schedule, schedule_injective from ..op import register_schedule, schedule_injective
from .... import nd as _nd
from .... import TVMContext as _TVMContext
def on_device(data, device): def on_device(data, device):
"""Annotate an expression with a certain device type. """Annotate an expression with a certain device type.
......
...@@ -16,11 +16,12 @@ ...@@ -16,11 +16,12 @@
# under the License. # under the License.
"""Basic tensor operations.""" """Basic tensor operations."""
# pylint: disable=redefined-builtin # pylint: disable=redefined-builtin
from __future__ import absolute_import as _abs from tvm.runtime import ndarray as _nd
from tvm.runtime import TVMContext as _TVMContext
from . import _make from . import _make
from ..expr import Tuple from ..expr import Tuple
from ... import nd as _nd
from ... import TVMContext as _TVMContext
# We create a wrapper function for each operator in the # We create a wrapper function for each operator in the
# python side to call into the positional _make.OpName function. # python side to call into the positional _make.OpName function.
......
...@@ -22,12 +22,12 @@ import socket ...@@ -22,12 +22,12 @@ import socket
import struct import struct
import time import time
import tvm._ffi import tvm._ffi
from tvm.contrib import util
from tvm._ffi.base import TVMError
from tvm.runtime import ndarray as nd
from tvm.runtime import load_module as _load_module
from . import base from . import base
from ..contrib import util
from .._ffi.base import TVMError
from .._ffi import ndarray as nd
from ..module import load as _load_module
class RPCSession(object): class RPCSession(object):
......
...@@ -38,10 +38,10 @@ import sys ...@@ -38,10 +38,10 @@ import sys
import signal import signal
import tvm._ffi import tvm._ffi
from .._ffi.base import py_str from tvm._ffi.base import py_str
from .._ffi.libinfo import find_lib_path from tvm._ffi.libinfo import find_lib_path
from ..module import load as _load_module from tvm.runtime.module import load as _load_module
from ..contrib import util from tvm.contrib import util
from . import base from . import base
from . base import TrackerCode from . base import TrackerCode
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""TVM runtime."""
# class exposures
from .packed_func import PackedFunc
from .object import Object
from .object_generic import ObjectGeneric, ObjectTypes
from .ndarray import NDArray, DataType, TypeCode, TVMContext
from .module import Module
# function exposures
from .object_generic import convert_to_object, convert, const
from .ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl
from .ndarray import vpi, rocm, opengl, ext_dev, micro_dev
from .module import load as load_module
DataType = DataType
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Runtime container structures."""
def getitem_helper(obj, elem_getter, length, idx):
"""Helper function to implement a pythonic getitem function.
Parameters
----------
obj: object
The original object
elem_getter : function
A simple function that takes index and return a single element.
length : int
The size of the array
idx : int or slice
The argument passed to getitem
Returns
-------
result : object
The result of getitem
"""
if isinstance(idx, slice):
start = idx.start if idx.start is not None else 0
stop = idx.stop if idx.stop is not None else length
step = idx.step if idx.step is not None else 1
if start < 0:
start += length
if stop < 0:
stop += length
return [elem_getter(obj, i) for i in range(start, stop, step)]
if idx < -length or idx >= length:
raise IndexError("Index out of range. size: {}, got index {}"
.format(length, idx))
if idx < 0:
idx += length
return elem_getter(obj, idx)
...@@ -17,19 +17,20 @@ ...@@ -17,19 +17,20 @@
# pylint: disable=invalid-name, unused-import # pylint: disable=invalid-name, unused-import
"""Runtime Object API""" """Runtime Object API"""
import ctypes import ctypes
from tvm._ffi.base import _FFI_MODE, _RUNTIME_ONLY, check_call, _LIB, c_str
from .. import _api_internal from .. import _api_internal
from .base import _FFI_MODE, _RUNTIME_ONLY, check_call, _LIB, c_str
try: try:
# pylint: disable=wrong-import-position,unused-import # pylint: disable=wrong-import-position,unused-import
if _FFI_MODE == "ctypes": if _FFI_MODE == "ctypes":
raise ImportError() raise ImportError()
from ._cy3.core import _set_class_object, _set_class_object_generic from tvm._ffi._cy3.core import _set_class_object, _set_class_object_generic
from ._cy3.core import ObjectBase from tvm._ffi._cy3.core import ObjectBase
except (RuntimeError, ImportError): except (RuntimeError, ImportError):
# pylint: disable=wrong-import-position,unused-import # pylint: disable=wrong-import-position,unused-import
from ._ctypes.packed_func import _set_class_object, _set_class_object_generic from tvm._ffi._ctypes.packed_func import _set_class_object, _set_class_object_generic
from ._ctypes.object import ObjectBase from tvm._ffi._ctypes.object import ObjectBase
def _new_object(cls): def _new_object(cls):
...@@ -91,44 +92,4 @@ class Object(ObjectBase): ...@@ -91,44 +92,4 @@ class Object(ObjectBase):
return self.__hash__() == other.__hash__() return self.__hash__() == other.__hash__()
def getitem_helper(obj, elem_getter, length, idx):
"""Helper function to implement a pythonic getitem function.
Parameters
----------
obj: object
The original object
elem_getter : function
A simple function that takes index and return a single element.
length : int
The size of the array
idx : int or slice
The argument passed to getitem
Returns
-------
result : object
The result of getitem
"""
if isinstance(idx, slice):
start = idx.start if idx.start is not None else 0
stop = idx.stop if idx.stop is not None else length
step = idx.step if idx.step is not None else 1
if start < 0:
start += length
if stop < 0:
stop += length
return [elem_getter(obj, i) for i in range(start, stop, step)]
if idx < -length or idx >= length:
raise IndexError("Index out of range. size: {}, got index {}"
.format(length, idx))
if idx < 0:
idx += length
return elem_getter(obj, idx)
_set_class_object(Object) _set_class_object(Object)
...@@ -15,15 +15,15 @@ ...@@ -15,15 +15,15 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""Common implementation of object generic related logic""" """Common implementation of object generic related logic"""
# pylint: disable=unused-import # pylint: disable=unused-import, invalid-name
from numbers import Number, Integral from numbers import Number, Integral
from .. import _api_internal from tvm._ffi.base import string_types
from .base import string_types from .. import _api_internal
from .object import ObjectBase, _set_class_object_generic from .object import ObjectBase, _set_class_object_generic
from .ndarray import NDArrayBase from .ndarray import NDArrayBase
from .packed_func import PackedFuncBase, convert_to_tvm_func from .packed_func import PackedFuncBase, convert_to_tvm_func
from .module import ModuleBase from .module import Module
class ObjectGeneric(object): class ObjectGeneric(object):
...@@ -33,7 +33,7 @@ class ObjectGeneric(object): ...@@ -33,7 +33,7 @@ class ObjectGeneric(object):
raise NotImplementedError() raise NotImplementedError()
_CLASS_OBJECTS = (ObjectBase, NDArrayBase, ModuleBase) ObjectTypes = (ObjectBase, NDArrayBase, Module)
def convert_to_object(value): def convert_to_object(value):
...@@ -49,7 +49,7 @@ def convert_to_object(value): ...@@ -49,7 +49,7 @@ def convert_to_object(value):
obj : Object obj : Object
The corresponding object value. The corresponding object value.
""" """
if isinstance(value, _CLASS_OBJECTS): if isinstance(value, ObjectTypes):
return value return value
if isinstance(value, bool): if isinstance(value, bool):
return const(value, 'uint1x1') return const(value, 'uint1x1')
...@@ -63,7 +63,7 @@ def convert_to_object(value): ...@@ -63,7 +63,7 @@ def convert_to_object(value):
if isinstance(value, dict): if isinstance(value, dict):
vlist = [] vlist = []
for item in value.items(): for item in value.items():
if (not isinstance(item[0], _CLASS_OBJECTS) and if (not isinstance(item[0], ObjectTypes) and
not isinstance(item[0], string_types)): not isinstance(item[0], string_types)):
raise ValueError("key of map must already been a container type") raise ValueError("key of map must already been a container type")
vlist.append(item[0]) vlist.append(item[0])
......
...@@ -18,20 +18,20 @@ ...@@ -18,20 +18,20 @@
# pylint: disable=invalid-name, unused-import # pylint: disable=invalid-name, unused-import
"""Packed Function namespace.""" """Packed Function namespace."""
import ctypes import ctypes
from .base import _LIB, check_call, c_str, string_types, _FFI_MODE from tvm._ffi.base import _LIB, check_call, c_str, string_types, _FFI_MODE
try: try:
# pylint: disable=wrong-import-position # pylint: disable=wrong-import-position
if _FFI_MODE == "ctypes": if _FFI_MODE == "ctypes":
raise ImportError() raise ImportError()
from ._cy3.core import _set_class_packed_func, _set_class_module from tvm._ffi._cy3.core import _set_class_packed_func, _set_class_module
from ._cy3.core import PackedFuncBase from tvm._ffi._cy3.core import PackedFuncBase
from ._cy3.core import convert_to_tvm_func from tvm._ffi._cy3.core import convert_to_tvm_func
except (RuntimeError, ImportError): except (RuntimeError, ImportError):
# pylint: disable=wrong-import-position # pylint: disable=wrong-import-position
from ._ctypes.packed_func import _set_class_packed_func, _set_class_module from tvm._ffi._ctypes.packed_func import _set_class_packed_func, _set_class_module
from ._ctypes.packed_func import PackedFuncBase from tvm._ffi._ctypes.packed_func import PackedFuncBase
from ._ctypes.packed_func import convert_to_tvm_func from tvm._ffi._ctypes.packed_func import convert_to_tvm_func
PackedFuncHandle = ctypes.c_void_p PackedFuncHandle = ctypes.c_void_p
......
...@@ -17,9 +17,8 @@ ...@@ -17,9 +17,8 @@
"""The computation schedule api of TVM.""" """The computation schedule api of TVM."""
import tvm._ffi import tvm._ffi
from ._ffi.base import string_types from tvm._ffi.base import string_types
from ._ffi.object import Object from tvm.runtime import Object, convert
from ._ffi.object_generic import convert
from . import _api_internal from . import _api_internal
from . import tensor as _tensor from . import tensor as _tensor
......
...@@ -30,7 +30,8 @@ Each statement node have subfields that can be visited from python side. ...@@ -30,7 +30,8 @@ Each statement node have subfields that can be visited from python side.
assert(st.buffer_var == a) assert(st.buffer_var == a)
""" """
import tvm._ffi import tvm._ffi
from ._ffi.object import Object
from tvm.runtime import Object
from . import make as _make from . import make as _make
......
...@@ -57,8 +57,8 @@ We can also use other specific function in this module to create specific target ...@@ -57,8 +57,8 @@ We can also use other specific function in this module to create specific target
import warnings import warnings
import tvm._ffi import tvm._ffi
from tvm.runtime import Object
from ._ffi.base import _LIB_NAME from ._ffi.base import _LIB_NAME
from ._ffi.object import Object
from . import _api_internal from . import _api_internal
try: try:
......
...@@ -18,8 +18,7 @@ ...@@ -18,8 +18,7 @@
# pylint: disable=invalid-name # pylint: disable=invalid-name
import tvm._ffi import tvm._ffi
from ._ffi.object import Object from tvm.runtime import Object, ObjectGeneric, convert_to_object
from ._ffi.object_generic import ObjectGeneric, convert_to_object
from . import _api_internal from . import _api_internal
from . import make as _make from . import make as _make
...@@ -129,7 +128,6 @@ class Tensor(Object, _expr.ExprOp): ...@@ -129,7 +128,6 @@ class Tensor(Object, _expr.ExprOp):
return "%s.v%d" % (op.name, self.value_index) return "%s.v%d" % (op.name, self.value_index)
class Operation(Object): class Operation(Object):
"""Represent an operation that generates a tensor""" """Represent an operation that generates a tensor"""
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
"""Tensor intrinsics""" """Tensor intrinsics"""
import tvm._ffi import tvm._ffi
from tvm.runtime import Object
from . import _api_internal from . import _api_internal
from . import api as _api from . import api as _api
from . import expr as _expr from . import expr as _expr
...@@ -25,7 +26,6 @@ from . import make as _make ...@@ -25,7 +26,6 @@ from . import make as _make
from . import tensor as _tensor from . import tensor as _tensor
from . import schedule as _schedule from . import schedule as _schedule
from .build_module import current_build_config from .build_module import current_build_config
from ._ffi.object import Object
def _get_region(tslice): def _get_region(tslice):
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
# under the License. # under the License.
import tvm import tvm
import tvm.contrib.sparse as tvmsp import tvm.contrib.sparse as tvmsp
import tvm.ndarray as _nd import tvm.runtime.ndarray as _nd
import numpy as np import numpy as np
from collections import namedtuple from collections import namedtuple
......
...@@ -132,7 +132,7 @@ def test_comments(): ...@@ -132,7 +132,7 @@ def test_comments():
def test_int_literal(): def test_int_literal():
assert isinstance(parse_text("1"), relay.Constant) assert isinstance(parse_text("1"), relay.Constant)
assert isinstance(parse_text("1").data, tvm.ndarray.NDArray) assert isinstance(parse_text("1").data, tvm.nd.NDArray)
assert get_scalar(parse_text("1")) == 1 assert get_scalar(parse_text("1")) == 1
assert get_scalar(parse_text("10")) == 10 assert get_scalar(parse_text("10")) == 10
......
...@@ -207,7 +207,7 @@ def test_cuda_shuffle(): ...@@ -207,7 +207,7 @@ def test_cuda_shuffle():
b_ = np.array((list(range(4))[::-1]) * 16, dtype='int32') b_ = np.array((list(range(4))[::-1]) * 16, dtype='int32')
c_ = np.zeros((64, ), dtype='int32') c_ = np.zeros((64, ), dtype='int32')
ref = a_ + np.array((list(range(4))) * 16, dtype='int32') ref = a_ + np.array((list(range(4))) * 16, dtype='int32')
nda, ndb, ndc = [tvm.ndarray.array(i, tvm.gpu(0)) for i in [a_, b_, c_]] nda, ndb, ndc = [tvm.nd.array(i, tvm.gpu(0)) for i in [a_, b_, c_]]
module(nda, ndb, ndc) module(nda, ndb, ndc)
tvm.testing.assert_allclose(ndc.asnumpy(), ref) tvm.testing.assert_allclose(ndc.asnumpy(), ref)
......
...@@ -657,9 +657,9 @@ def test_llvm_shuffle(): ...@@ -657,9 +657,9 @@ def test_llvm_shuffle():
with tvm.build_config(add_lower_pass=[(1, my_vectorize)]): with tvm.build_config(add_lower_pass=[(1, my_vectorize)]):
ir = tvm.lower(sch, [a, b, c], simple_mode=True) ir = tvm.lower(sch, [a, b, c], simple_mode=True)
module = tvm.build(sch, [a, b, c]) module = tvm.build(sch, [a, b, c])
a_ = tvm.ndarray.array(np.arange(1, 9, dtype='int32')) a_ = tvm.nd.array(np.arange(1, 9, dtype='int32'))
b_ = tvm.ndarray.array(np.arange(8, 0, -1, dtype='int32')) b_ = tvm.nd.array(np.arange(8, 0, -1, dtype='int32'))
c_ = tvm.ndarray.array(np.zeros((8, ), dtype='int32')) c_ = tvm.nd.array(np.zeros((8, ), dtype='int32'))
module(a_, b_, c_) module(a_, b_, c_)
tvm.testing.assert_allclose(c_.asnumpy(), (a_.asnumpy() * 2).astype('int32')) tvm.testing.assert_allclose(c_.asnumpy(), (a_.asnumpy() * 2).astype('int32'))
......
...@@ -405,8 +405,8 @@ def test_math_intrin(): ...@@ -405,8 +405,8 @@ def test_math_intrin():
func = tvm.build(sch, [a8, b8]) func = tvm.build(sch, [a8, b8])
assert func assert func
a = numpy.arange(2, 10).astype('float32') a = numpy.arange(2, 10).astype('float32')
tvm_a = tvm.ndarray.array(a) tvm_a = tvm.nd.array(a)
tvm_b = tvm.ndarray.array(numpy.zeros((8, ), dtype='float32')) tvm_b = tvm.nd.array(numpy.zeros((8, ), dtype='float32'))
b = intrin_real(a) b = intrin_real(a)
func(tvm_a, tvm_b) func(tvm_a, tvm_b)
tvm.testing.assert_allclose(b, tvm_b.asnumpy(), rtol=1e-5) tvm.testing.assert_allclose(b, tvm_b.asnumpy(), rtol=1e-5)
...@@ -423,8 +423,8 @@ def test_math_intrin(): ...@@ -423,8 +423,8 @@ def test_math_intrin():
func = tvm.build(sch, [a1, b1]) func = tvm.build(sch, [a1, b1])
assert func assert func
a = numpy.array([114514]).astype('int32') a = numpy.array([114514]).astype('int32')
tvm_a = tvm.ndarray.array(a) tvm_a = tvm.nd.array(a)
tvm_b = tvm.ndarray.array(numpy.array([0]).astype('int32')) tvm_b = tvm.nd.array(numpy.array([0]).astype('int32'))
b = intrin_int(a) b = intrin_int(a)
func(tvm_a, tvm_b) func(tvm_a, tvm_b)
assert tvm_b.asnumpy()[0] == b[0] assert tvm_b.asnumpy()[0] == b[0]
...@@ -578,8 +578,8 @@ def test_const_param(): ...@@ -578,8 +578,8 @@ def test_const_param():
np_b = 11 np_b = 11
np_c = numpy.zeros((11, )).astype('int32') np_c = numpy.zeros((11, )).astype('int32')
nd_a = tvm.ndarray.array(np_a) nd_a = tvm.nd.array(np_a)
nd_c = tvm.ndarray.array(numpy.zeros((11, )).astype('int32')) nd_c = tvm.nd.array(numpy.zeros((11, )).astype('int32'))
module(nd_a, nd_c) module(nd_a, nd_c)
ref = add_something(np_a, 11) ref = add_something(np_a, 11)
...@@ -614,8 +614,8 @@ def test_value_index(): ...@@ -614,8 +614,8 @@ def test_value_index():
np_b, np_c = kernel_a(np_a) np_b, np_c = kernel_a(np_a)
ref = kernel_b(np_c, np_b) ref = kernel_b(np_c, np_b)
res = tvm.ndarray.array(numpy.zeros((4, 4)).astype('int32')) res = tvm.nd.array(numpy.zeros((4, 4)).astype('int32'))
module(tvm.ndarray.array(np_a), res) module(tvm.nd.array(np_a), res)
tvm.testing.assert_allclose(res.asnumpy(), ref) tvm.testing.assert_allclose(res.asnumpy(), ref)
def test_func_call(): def test_func_call():
......
...@@ -28,7 +28,7 @@ def test_shared_memory(): ...@@ -28,7 +28,7 @@ def test_shared_memory():
N = 1024 N = 1024
M = 128 M = 128
tvm_type = tvm.datatype._TVMType(dtype) tvm_type = tvm.runtime.DataType(dtype)
type_size = tvm_type.bits // 8 * tvm_type.lanes type_size = tvm_type.bits // 8 * tvm_type.lanes
A = tvm.placeholder((N,), name='A', dtype=dtype) A = tvm.placeholder((N,), name='A', dtype=dtype)
......
...@@ -444,7 +444,7 @@ def test_reduction_and_dummy_fuse_split(): ...@@ -444,7 +444,7 @@ def test_reduction_and_dummy_fuse_split():
axo, axi = s[Y.op].split(ax, nparts=20) axo, axi = s[Y.op].split(ax, nparts=20)
f = tvm.build(s, [Y, X]) f = tvm.build(s, [Y, X])
args = [tvm.nd.empty((), 'int32')] + [tvm.ndarray.array(np.ones((n,), dtype='int32'))] args = [tvm.nd.empty((), 'int32')] + [tvm.nd.array(np.ones((n,), dtype='int32'))]
f(*args) f(*args)
assert args[0].asnumpy() == n assert args[0].asnumpy() == n
...@@ -456,8 +456,8 @@ def test_reduction_and_dummy_fuse_split(): ...@@ -456,8 +456,8 @@ def test_reduction_and_dummy_fuse_split():
ax = s[Y.op].fuse(*(list(Y.op.axis) + list(Y.op.reduce_axis))) ax = s[Y.op].fuse(*(list(Y.op.axis) + list(Y.op.reduce_axis)))
f = tvm.build(s, [Y, X]) f = tvm.build(s, [Y, X])
args = [tvm.ndarray.array(np.ones((n,), dtype='int32'))] + \ args = [tvm.nd.array(np.ones((n,), dtype='int32'))] + \
[tvm.ndarray.array(np.ones((n,), dtype='int32'))] [tvm.nd.array(np.ones((n,), dtype='int32'))]
f(*args) f(*args)
assert np.all(args[0].asnumpy() == n) assert np.all(args[0].asnumpy() == n)
......
...@@ -231,8 +231,8 @@ def test_sparse_dense_csr(): ...@@ -231,8 +231,8 @@ def test_sparse_dense_csr():
Y = topi.nn.sparse_dense(X, W_data, W_indices, W_indptr) Y = topi.nn.sparse_dense(X, W_data, W_indices, W_indptr)
s = tvm.create_schedule(Y.op) s = tvm.create_schedule(Y.op)
func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y]) func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y])
Y_tvm = tvm.ndarray.array(np.zeros(Y_np.shape, dtype=Y_np.dtype)) Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype))
func(tvm.ndarray.array(X_np), tvm.ndarray.array(W_sp_np.data), tvm.ndarray.array(W_sp_np.indices), tvm.ndarray.array(W_sp_np.indptr), Y_tvm) func(tvm.nd.array(X_np), tvm.nd.array(W_sp_np.data), tvm.nd.array(W_sp_np.indices), tvm.nd.array(W_sp_np.indptr), Y_tvm)
tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-4, rtol=1e-4) tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-4, rtol=1e-4)
def test_sparse_transpose_csr(): def test_sparse_transpose_csr():
...@@ -252,11 +252,11 @@ def test_sparse_transpose_csr(): ...@@ -252,11 +252,11 @@ def test_sparse_transpose_csr():
func = tvm.build(s, [X_data, X_indices, X_indptr, X_T_data, X_T_indices, X_T_indptr]) func = tvm.build(s, [X_data, X_indices, X_indptr, X_T_data, X_T_indices, X_T_indptr])
X_T_data_tvm = tvm.ndarray.array(np.zeros(X_sp_T.data.shape, dtype=X_sp_T.data.dtype)) X_T_data_tvm = tvm.nd.array(np.zeros(X_sp_T.data.shape, dtype=X_sp_T.data.dtype))
X_T_indices_tvm = tvm.ndarray.array(np.zeros(X_sp_T.indices.shape, dtype=X_sp_T.indices.dtype)) X_T_indices_tvm = tvm.nd.array(np.zeros(X_sp_T.indices.shape, dtype=X_sp_T.indices.dtype))
X_T_indptr_tvm = tvm.ndarray.array(np.zeros(X_sp_T.indptr.shape, dtype=X_sp_T.indptr.dtype)) X_T_indptr_tvm = tvm.nd.array(np.zeros(X_sp_T.indptr.shape, dtype=X_sp_T.indptr.dtype))
func(tvm.ndarray.array(X_sp.data), tvm.ndarray.array(X_sp.indices), tvm.ndarray.array(X_sp.indptr), func(tvm.nd.array(X_sp.data), tvm.nd.array(X_sp.indices), tvm.nd.array(X_sp.indptr),
X_T_data_tvm, X_T_indices_tvm, X_T_indptr_tvm) X_T_data_tvm, X_T_indices_tvm, X_T_indptr_tvm)
X_T_out = sp.csr_matrix((X_T_data_tvm.asnumpy(), X_T_indices_tvm.asnumpy(), X_T_indptr_tvm.asnumpy()), shape=(N,N)).todense() X_T_out = sp.csr_matrix((X_T_data_tvm.asnumpy(), X_T_indices_tvm.asnumpy(), X_T_indptr_tvm.asnumpy()), shape=(N,N)).todense()
...@@ -295,11 +295,11 @@ def test_sparse_dense_bsr(): ...@@ -295,11 +295,11 @@ def test_sparse_dense_bsr():
Y = topi.nn.sparse_dense(X, W_data, W_indices, W_indptr) Y = topi.nn.sparse_dense(X, W_data, W_indices, W_indptr)
s = tvm.create_schedule(Y.op) s = tvm.create_schedule(Y.op)
func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y]) func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y])
Y_tvm = tvm.ndarray.array(np.zeros(Y_np.shape, dtype=Y_np.dtype)) Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype))
func(tvm.ndarray.array(X_np), func(tvm.nd.array(X_np),
tvm.ndarray.array(W_sp_np.data), tvm.nd.array(W_sp_np.data),
tvm.ndarray.array(W_sp_np.indices), tvm.nd.array(W_sp_np.indices),
tvm.ndarray.array(W_sp_np.indptr), tvm.nd.array(W_sp_np.indptr),
Y_tvm) Y_tvm)
tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-4, rtol=1e-4) tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-4, rtol=1e-4)
...@@ -324,11 +324,11 @@ def test_sparse_dense_bsr_randomized(): ...@@ -324,11 +324,11 @@ def test_sparse_dense_bsr_randomized():
Y = topi.nn.sparse_dense(X, W_data, W_indices, W_indptr) Y = topi.nn.sparse_dense(X, W_data, W_indices, W_indptr)
s = tvm.create_schedule(Y.op) s = tvm.create_schedule(Y.op)
func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y]) func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y])
Y_tvm = tvm.ndarray.array(np.zeros(Y_np.shape, dtype=Y_np.dtype)) Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype))
func(tvm.ndarray.array(X_np), func(tvm.nd.array(X_np),
tvm.ndarray.array(W_sp_np.data), tvm.nd.array(W_sp_np.data),
tvm.ndarray.array(W_sp_np.indices), tvm.nd.array(W_sp_np.indices),
tvm.ndarray.array(W_sp_np.indptr), tvm.nd.array(W_sp_np.indptr),
Y_tvm) Y_tvm)
tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-5, rtol=1e-5) tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-5, rtol=1e-5)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment