Commit 51a265af by tqchen Committed by Tianqi Chen

[REFACTOR][PY][API-CHANGE] Establish tvm.target

Move the related target modules into tvm.target.

API change:
- tvm.target.current_target -> tvm.target.Target.current
- tvm.datatype -> tvm.target.datatype
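A minimal migration sketch (hedged; the calls simply mirror the renames above):

.. code-block:: python

    import tvm

    # previously: target = tvm.target.current_target(allow_none=True)
    target = tvm.target.Target.current(allow_none=True)

    # previously: tvm.datatype.register("bfloat", 129)
    tvm.target.datatype.register("bfloat", 129)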
parent 79cfab00
...@@ -26,10 +26,10 @@ Python API ...@@ -26,10 +26,10 @@ Python API
ndarray ndarray
error error
ir ir
target
intrin intrin
tensor tensor
schedule schedule
target
build build
function function
autotvm autotvm
......
...@@ -19,3 +19,4 @@ tvm.target ...@@ -19,3 +19,4 @@ tvm.target
---------- ----------
.. automodule:: tvm.target .. automodule:: tvm.target
:members: :members:
:imported-members:
...@@ -46,7 +46,6 @@ from . import expr ...@@ -46,7 +46,6 @@ from . import expr
from . import stmt from . import stmt
from . import make from . import make
from . import ir_pass from . import ir_pass
from . import codegen
from . import schedule from . import schedule
from . import ir_builder from . import ir_builder
...@@ -55,7 +54,6 @@ from . import generic ...@@ -55,7 +54,6 @@ from . import generic
from . import hybrid from . import hybrid
from . import testing from . import testing
from . import error from . import error
from . import datatype
from .api import * from .api import *
......
...@@ -20,7 +20,6 @@ import ctypes ...@@ -20,7 +20,6 @@ import ctypes
import json import json
import numpy as np import numpy as np
from .base import _LIB, check_call from .base import _LIB, check_call
from .. import _api_internal
tvm_shape_index_t = ctypes.c_int64 tvm_shape_index_t = ctypes.c_int64
...@@ -48,6 +47,7 @@ class TVMByteArray(ctypes.Structure): ...@@ -48,6 +47,7 @@ class TVMByteArray(ctypes.Structure):
_fields_ = [("data", ctypes.POINTER(ctypes.c_byte)), _fields_ = [("data", ctypes.POINTER(ctypes.c_byte)),
("size", ctypes.c_size_t)] ("size", ctypes.c_size_t)]
class DataType(ctypes.Structure): class DataType(ctypes.Structure):
"""TVM datatype structure""" """TVM datatype structure"""
_fields_ = [("type_code", ctypes.c_uint8), _fields_ = [("type_code", ctypes.c_uint8),
...@@ -89,11 +89,13 @@ class DataType(ctypes.Structure): ...@@ -89,11 +89,13 @@ class DataType(ctypes.Structure):
bits = 64 bits = 64
head = "" head = ""
elif head.startswith("custom"): elif head.startswith("custom"):
# pylint: disable=import-outside-toplevel
import tvm.runtime._ffi_api
low, high = head.find('['), head.find(']') low, high = head.find('['), head.find(']')
if not low or not high or low >= high: if not low or not high or low >= high:
raise ValueError("Badly formatted custom type string %s" % type_str) raise ValueError("Badly formatted custom type string %s" % type_str)
type_name = head[low + 1:high] type_name = head[low + 1:high]
self.type_code = _api_internal._datatype_get_type_code(type_name) self.type_code = tvm.runtime._ffi_api._datatype_get_type_code(type_name)
head = head[high+1:] head = head[high+1:]
else: else:
raise ValueError("Do not know how to handle type %s" % type_str) raise ValueError("Do not know how to handle type %s" % type_str)
...@@ -102,13 +104,15 @@ class DataType(ctypes.Structure): ...@@ -102,13 +104,15 @@ class DataType(ctypes.Structure):
def __repr__(self): def __repr__(self):
# pylint: disable=import-outside-toplevel
if self.bits == 1 and self.lanes == 1: if self.bits == 1 and self.lanes == 1:
return "bool" return "bool"
if self.type_code in DataType.CODE2STR: if self.type_code in DataType.CODE2STR:
type_name = DataType.CODE2STR[self.type_code] type_name = DataType.CODE2STR[self.type_code]
else: else:
import tvm.runtime._ffi_api
type_name = "custom[%s]" % \ type_name = "custom[%s]" % \
_api_internal._datatype_get_type_name(self.type_code) tvm.runtime._ffi_api._datatype_get_type_name(self.type_code)
x = "%s%d" % (type_name, self.bits) x = "%s%d" % (type_name, self.bits)
if self.lanes != 1: if self.lanes != 1:
x += "x%d" % self.lanes x += "x%d" % self.lanes
...@@ -168,28 +172,35 @@ class TVMContext(ctypes.Structure): ...@@ -168,28 +172,35 @@ class TVMContext(ctypes.Structure):
self.device_type = device_type self.device_type = device_type
self.device_id = device_id self.device_id = device_id
def _GetDeviceAttr(self, device_type, device_id, attr_id):
"""Internal helper function to invoke runtime.GetDeviceAttr"""
# pylint: disable=import-outside-toplevel
import tvm.runtime._ffi_api
return tvm.runtime._ffi_api.GetDeviceAttr(
device_type, device_id, attr_id)
@property @property
def exist(self): def exist(self):
"""Whether this device exist.""" """Whether this device exist."""
return _api_internal._GetDeviceAttr( return self._GetDeviceAttr(
self.device_type, self.device_id, 0) != 0 self.device_type, self.device_id, 0) != 0
@property @property
def max_threads_per_block(self): def max_threads_per_block(self):
"""Maximum number of threads on each block.""" """Maximum number of threads on each block."""
return _api_internal._GetDeviceAttr( return self._GetDeviceAttr(
self.device_type, self.device_id, 1) self.device_type, self.device_id, 1)
@property @property
def warp_size(self): def warp_size(self):
"""Number of threads that executes in concurrent.""" """Number of threads that executes in concurrent."""
return _api_internal._GetDeviceAttr( return self._GetDeviceAttr(
self.device_type, self.device_id, 2) self.device_type, self.device_id, 2)
@property @property
def max_shared_memory_per_block(self): def max_shared_memory_per_block(self):
"""Total amount of shared memory per block in bytes.""" """Total amount of shared memory per block in bytes."""
return _api_internal._GetDeviceAttr( return self._GetDeviceAttr(
self.device_type, self.device_id, 3) self.device_type, self.device_id, 3)
@property @property
...@@ -203,25 +214,25 @@ class TVMContext(ctypes.Structure): ...@@ -203,25 +214,25 @@ class TVMContext(ctypes.Structure):
version : str version : str
The version string in `major.minor` format. The version string in `major.minor` format.
""" """
return _api_internal._GetDeviceAttr( return self._GetDeviceAttr(
self.device_type, self.device_id, 4) self.device_type, self.device_id, 4)
@property @property
def device_name(self): def device_name(self):
"""Return the string name of device.""" """Return the string name of device."""
return _api_internal._GetDeviceAttr( return self._GetDeviceAttr(
self.device_type, self.device_id, 5) self.device_type, self.device_id, 5)
@property @property
def max_clock_rate(self): def max_clock_rate(self):
"""Return the max clock frequency of device.""" """Return the max clock frequency of device."""
return _api_internal._GetDeviceAttr( return self._GetDeviceAttr(
self.device_type, self.device_id, 6) self.device_type, self.device_id, 6)
@property @property
def multi_processor_count(self): def multi_processor_count(self):
"""Return the number of compute units of device.""" """Return the number of compute units of device."""
return _api_internal._GetDeviceAttr( return self._GetDeviceAttr(
self.device_type, self.device_id, 7) self.device_type, self.device_id, 7)
@property @property
...@@ -233,7 +244,7 @@ class TVMContext(ctypes.Structure): ...@@ -233,7 +244,7 @@ class TVMContext(ctypes.Structure):
dims: List of int dims: List of int
The maximum length of threadIdx.x, threadIdx.y, threadIdx.z The maximum length of threadIdx.x, threadIdx.y, threadIdx.z
""" """
return json.loads(_api_internal._GetDeviceAttr( return json.loads(self._GetDeviceAttr(
self.device_type, self.device_id, 8)) self.device_type, self.device_id, 8))
def sync(self): def sync(self):
......
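A short usage sketch of the refactored attribute queries (assumes a TVM build with at least the CPU runtime):

.. code-block:: python

    import tvm

    ctx = tvm.cpu(0)
    # every property now routes through TVMContext._GetDeviceAttr, which
    # imports tvm.runtime._ffi_api lazily instead of using tvm._api_internal
    print(ctx.exist)
    print(ctx.device_name)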
...@@ -106,7 +106,7 @@ class DispatchContext(object): ...@@ -106,7 +106,7 @@ class DispatchContext(object):
def _alter_conv2d_layout(attrs, inputs, tinfo): def _alter_conv2d_layout(attrs, inputs, tinfo):
workload = get_conv2d_workload(...) workload = get_conv2d_workload(...)
dispatch_ctx = autotvm.task.DispatchContext.current dispatch_ctx = autotvm.task.DispatchContext.current
target = tvm.target.current_target() target = tvm.target.Target.current()
config = dispatch_ctx.query(target, workload) config = dispatch_ctx.query(target, workload)
# Get conv2d_NCHWc workload from config # Get conv2d_NCHWc workload from config
...@@ -207,7 +207,7 @@ def dispatcher(fworkload): ...@@ -207,7 +207,7 @@ def dispatcher(fworkload):
def dispatch_func(func, *args, **kwargs): def dispatch_func(func, *args, **kwargs):
"""The wrapped dispatch function""" """The wrapped dispatch function"""
tgt = _target.current_target() tgt = _target.Target.current()
workload = func(*args, **kwargs) workload = func(*args, **kwargs)
cfg = DispatchContext.current.query(tgt, workload) cfg = DispatchContext.current.query(tgt, workload)
if cfg.is_fallback and not cfg.template_key: if cfg.is_fallback and not cfg.template_key:
......
...@@ -25,6 +25,8 @@ import tvm.runtime ...@@ -25,6 +25,8 @@ import tvm.runtime
from tvm.runtime import Object, ndarray from tvm.runtime import Object, ndarray
from tvm.ir import container from tvm.ir import container
from tvm.target import codegen
from . import api from . import api
from . import _api_internal from . import _api_internal
from . import tensor from . import tensor
...@@ -32,7 +34,6 @@ from . import schedule ...@@ -32,7 +34,6 @@ from . import schedule
from . import expr from . import expr
from . import ir_pass from . import ir_pass
from . import stmt as _stmt from . import stmt as _stmt
from . import codegen
from . import target as _target from . import target as _target
from . import make from . import make
from .stmt import LoweredFunc from .stmt import LoweredFunc
...@@ -602,7 +603,7 @@ def build(inputs, ...@@ -602,7 +603,7 @@ def build(inputs,
"LoweredFunc.") "LoweredFunc.")
if not isinstance(inputs, (dict, container.Map)): if not isinstance(inputs, (dict, container.Map)):
target = _target.current_target() if target is None else target target = _target.Target.current() if target is None else target
target = target if target else "llvm" target = target if target else "llvm"
target_flist = {target: flist} target_flist = {target: flist}
else: else:
......
...@@ -16,11 +16,10 @@ ...@@ -16,11 +16,10 @@
# under the License. # under the License.
"""Util to invoke clang in the system.""" """Util to invoke clang in the system."""
# pylint: disable=invalid-name # pylint: disable=invalid-name
from __future__ import absolute_import as _abs
import subprocess import subprocess
from .._ffi.base import py_str from tvm._ffi.base import py_str
from .. import codegen import tvm.target
from . import util from . import util
...@@ -44,8 +43,8 @@ def find_clang(required=True): ...@@ -44,8 +43,8 @@ def find_clang(required=True):
matches the major LLVM version that TVM was built with matches the major LLVM version that TVM was built with
""" """
cc_list = [] cc_list = []
if hasattr(codegen, "llvm_version_major"): major = tvm.target.codegen.llvm_version_major(allow_none=True)
major = codegen.llvm_version_major() if major is not None:
cc_list += ["clang-%d.0" % major] cc_list += ["clang-%d.0" % major]
cc_list += ["clang-%d" % major] cc_list += ["clang-%d" % major]
cc_list += ["clang"] cc_list += ["clang"]
......
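A hedged sketch of the simplified version probe (assumes tvm.contrib.clang is importable):

.. code-block:: python

    from tvm.contrib import clang

    # llvm_version_major(allow_none=True) returns None when TVM was built
    # without LLVM, so the old hasattr() probe is no longer needed
    compilers = clang.find_clang(required=False)
    print(compilers)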
...@@ -17,9 +17,11 @@ ...@@ -17,9 +17,11 @@
"""Utility for ROCm backend""" """Utility for ROCm backend"""
import subprocess import subprocess
from os.path import join, exists from os.path import join, exists
from tvm._ffi.base import py_str
import tvm.target
from . import util from . import util
from .._ffi.base import py_str
from .. import codegen
from ..api import register_func, convert from ..api import register_func, convert
def find_lld(required=True): def find_lld(required=True):
...@@ -42,8 +44,8 @@ def find_lld(required=True): ...@@ -42,8 +44,8 @@ def find_lld(required=True):
matches the major LLVM version that TVM was built with matches the major LLVM version that TVM was built with
""" """
lld_list = [] lld_list = []
if hasattr(codegen, "llvm_version_major"): major = tvm.target.codegen.llvm_version_major(allow_none=True)
major = codegen.llvm_version_major() if major is not None:
lld_list += ["ld.lld-%d.0" % major] lld_list += ["ld.lld-%d.0" % major]
lld_list += ["ld.lld-%d" % major] lld_list += ["ld.lld-%d" % major]
lld_list += ["ld.lld"] lld_list += ["ld.lld"]
......
...@@ -154,8 +154,8 @@ def max_num_threads(func_id, args): ...@@ -154,8 +154,8 @@ def max_num_threads(func_id, args):
_internal_assert(func_id == "max_num_threads", "This function cannot be directly invoked!") _internal_assert(func_id == "max_num_threads", "This function cannot be directly invoked!")
_internal_assert(args.__len__() <= 1, "At most one argument accepted!") _internal_assert(args.__len__() <= 1, "At most one argument accepted!")
if args.__len__() == 0: if args.__len__() == 0:
res = _tgt.current_target().max_num_threads res = _tgt.Target.current().max_num_threads
else: else:
_internal_assert(isinstance(args[0], _expr.IntImm), "In tvm bool should be uint") _internal_assert(isinstance(args[0], _expr.IntImm), "In tvm bool should be uint")
res = _tgt.current_target(args[0].value).max_num_threads res = _tgt.Target.current(args[0].value).max_num_threads
return _api.convert(res) return _api.convert(res)
...@@ -107,7 +107,7 @@ def sigmoid(x): ...@@ -107,7 +107,7 @@ def sigmoid(x):
def max_num_threads(allow_none=True): def max_num_threads(allow_none=True):
"""Get max number of threads for GPU targets.""" """Get max number of threads for GPU targets."""
return target.current_target(allow_none).max_num_threads return target.Target.current(allow_none).max_num_threads
HYBRID_GLOBALS = { HYBRID_GLOBALS = {
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
"""Expression Intrinsics and math functions in TVM.""" """Expression Intrinsics and math functions in TVM."""
# pylint: disable=redefined-builtin # pylint: disable=redefined-builtin
import tvm._ffi import tvm._ffi
import tvm.codegen import tvm.target.codegen
from . import make as _make from . import make as _make
from .api import convert, const from .api import convert, const
...@@ -189,7 +189,7 @@ def call_llvm_intrin(dtype, name, *args): ...@@ -189,7 +189,7 @@ def call_llvm_intrin(dtype, name, *args):
call : Expr call : Expr
The call expression. The call expression.
""" """
llvm_id = tvm.codegen.llvm_lookup_intrinsic_id(name) llvm_id = tvm.target.codegen.llvm_lookup_intrinsic_id(name)
assert llvm_id != 0, "%s is not an LLVM intrinsic" % name assert llvm_id != 0, "%s is not an LLVM intrinsic" % name
return call_pure_intrin(dtype, 'llvm_intrin', tvm.const(llvm_id, 'uint32'), *args) return call_pure_intrin(dtype, 'llvm_intrin', tvm.const(llvm_id, 'uint32'), *args)
......
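A minimal sketch of the relocated intrinsic lookup (assumes an LLVM-enabled build):

.. code-block:: python

    import tvm.target.codegen

    # intrinsic ids now come from tvm.target.codegen rather than tvm.codegen
    llvm_id = tvm.target.codegen.llvm_lookup_intrinsic_id("llvm.sqrt")
    assert llvm_id != 0, "llvm.sqrt should resolve to an LLVM intrinsic"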
...@@ -176,7 +176,7 @@ class VMCompiler(object): ...@@ -176,7 +176,7 @@ class VMCompiler(object):
def _update_target(self, target): def _update_target(self, target):
"""Update target.""" """Update target."""
target = target if target else tvm.target.current_target() target = target if target else tvm.target.Target.current()
if target is None: if target is None:
raise ValueError("Target is not set in env or passed as argument.") raise ValueError("Target is not set in env or passed as argument.")
tgts = {} tgts = {}
......
...@@ -33,7 +33,7 @@ from .backend import interpreter as _interpreter ...@@ -33,7 +33,7 @@ from .backend import interpreter as _interpreter
from .backend.vm import VMExecutor from .backend.vm import VMExecutor
def _update_target(target): def _update_target(target):
target = target if target else _target.current_target() target = target if target else _target.Target.current()
if target is None: if target is None:
raise ValueError("Target is not set in env or passed as argument.") raise ValueError("Target is not set in env or passed as argument.")
......
...@@ -220,13 +220,13 @@ def helper_change_dtypes_to_be_same(attrs, inputs, types, relay_op): ...@@ -220,13 +220,13 @@ def helper_change_dtypes_to_be_same(attrs, inputs, types, relay_op):
def is_fast_int8_on_intel(): def is_fast_int8_on_intel():
""" Checks whether the hardware has support for fast Int8 arithmetic operations. """ """ Checks whether the hardware has support for fast Int8 arithmetic operations. """
target = tvm.target.current_target(allow_none=False) target = tvm.target.Target.current(allow_none=False)
intel_supported_arches = {'-mcpu=skylake-avx512', '-mcpu=cascadelake'} intel_supported_arches = {'-mcpu=skylake-avx512', '-mcpu=cascadelake'}
return intel_supported_arches.intersection(set(target.options)) return intel_supported_arches.intersection(set(target.options))
def is_fast_int8_on_arm(): def is_fast_int8_on_arm():
""" Checks whether the hardware has support for fast Int8 arithmetic operations. """ """ Checks whether the hardware has support for fast Int8 arithmetic operations. """
target = tvm.target.current_target(allow_none=False) target = tvm.target.Target.current(allow_none=False)
return '+v8.2a,+dotprod' in ' '.join(target.options) return '+v8.2a,+dotprod' in ' '.join(target.options)
######################## ########################
......
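A hedged sketch of the same option check under the new API:

.. code-block:: python

    import tvm

    with tvm.target.create("llvm -mcpu=skylake-avx512"):
        target = tvm.target.Target.current(allow_none=False)
        print(target.options)  # ['-mcpu=skylake-avx512']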
...@@ -37,8 +37,8 @@ def _get_profile_runtime(mod): ...@@ -37,8 +37,8 @@ def _get_profile_runtime(mod):
func = mod['main'] func = mod['main']
func = _quantize.CreateStatsCollector(func) func = _quantize.CreateStatsCollector(func)
if tvm.target.current_target(): if tvm.target.Target.current():
target = tvm.target.current_target() target = tvm.target.Target.current()
ctx = tvm.context(target.target_name) ctx = tvm.context(target.target_name)
else: else:
target = 'llvm' target = 'llvm'
......
...@@ -16,9 +16,7 @@ ...@@ -16,9 +16,7 @@
# under the License. # under the License.
#pylint: disable=unused-argument,inconsistent-return-statements #pylint: disable=unused-argument,inconsistent-return-statements
"""Internal module for registering attribute for annotation.""" """Internal module for registering attribute for annotation."""
from __future__ import absolute_import import tvm
from ... import target as _target
from .. import expr as _expr from .. import expr as _expr
from .. import analysis as _analysis from .. import analysis as _analysis
from ..base import register_relay_node from ..base import register_relay_node
...@@ -133,7 +131,7 @@ def add_partition_generic(ref_call, new_args, ctx): ...@@ -133,7 +131,7 @@ def add_partition_generic(ref_call, new_args, ctx):
@register_partition_function("add") @register_partition_function("add")
def add_partition_function(ref_call, new_args, ctx): def add_partition_function(ref_call, new_args, ctx):
"""Rewrite function for ewise add for partition""" """Rewrite function for ewise add for partition"""
target = _target.current_target() target = tvm.target.Target.current()
if target and 'cuda' in target.keys: if target and 'cuda' in target.keys:
#TODO(wuwei/ziheng) cuda specific rules #TODO(wuwei/ziheng) cuda specific rules
return add_partition_generic(ref_call, new_args, ctx) return add_partition_generic(ref_call, new_args, ctx)
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Target description and codgen module.
TVM's target string is in fomat ``<target_name> [-option=value]...``.
Note
----
The list of options includes:
- **-device=<device name>**
The device name.
- **-mtriple=<target triple>** or **-target**
Specify the target triple, which is useful for cross
compilation.
- **-mcpu=<cpuname>**
Specify a specific chip in the current architecture to
generate code for. By default this is inferred from the
target triple and autodetected for the current architecture.
- **-mattr=a1,+a2,-a3,...**
Override or control specific attributes of the target,
such as whether SIMD operations are enabled or not. The
default set of attributes is set by the current CPU.
- **-system-lib**
Build TVM system library module. The system lib is a global module that contains
self-registered functions at program startup. Users can get the module using
:any:`tvm.runtime.system_lib`.
It is useful in environments where dynamic loading APIs like dlopen are banned.
The system lib will be available as long as the resulting code is linked into the program.
We can use :py:func:`~tvm.target.create` to create a tvm.target.Target from the target string.
We can also use the other specific functions in this module to create specific targets.
"""
from .target import Target, create
from .target import cuda, rocm, mali, intel_graphics, opengl, arm_cpu, rasp, vta, bifrost
from .generic_func import GenericFunc
from .generic_func import generic_func, get_native_generic_func, override_native_generic_func
from . import datatype
from . import codegen
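A minimal sketch of parsing a target string with the new module layout:

.. code-block:: python

    import tvm

    target = tvm.target.create("cuda -model=1080ti")
    print(target.target_name)  # 'cuda'
    print(target.model)        # '1080ti'
    print(target.keys)         # includes 'cuda' and 'gpu'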
...@@ -14,25 +14,8 @@ ...@@ -14,25 +14,8 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""Code generation related functions.""" """FFI APIs for tvm.target"""
import tvm._ffi import tvm._ffi
def build_module(lowered_func, target):
"""Build lowered_func into Module.
Parameters tvm._ffi._init_api("target", __name__)
----------
lowered_func : LoweredFunc
The lowered function
target : str
The target module type.
Returns
-------
module : Module
The corressponding module.
"""
return _Build(lowered_func, target)
tvm._ffi._init_api("tvm.codegen")
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Code generation related functions."""
from . import _ffi_api
def build_module(lowered_func, target):
"""Build lowered_func into Module.
Parameters
----------
lowered_func : LoweredFunc
The lowered function
target : str
The target module type.
Returns
-------
module : runtime.Module
The corresponding module.
"""
return _ffi_api.Build(lowered_func, target)
def llvm_lookup_intrinsic_id(name):
"""Lookup LLVM intrinsic id by name.
Parameters
----------
name : str
The name of the intrinsic.
Returns
-------
intrin_id : int
The intrinsic id.
"""
return _ffi_api.llvm_lookup_intrinsic_id(name)
def llvm_version_major(allow_none=False):
"""Get the major LLVM version.
Parameters
----------
allow_none : bool
Whether to allow returning None instead of raising when LLVM is unavailable.
Returns
-------
major : int
The major LLVM version.
"""
try:
return _ffi_api.llvm_version_major()
except AttributeError:
if allow_none:
return None
raise RuntimeError(
"LLVM version is not available, please check if you build with LLVM")
...@@ -17,11 +17,9 @@ ...@@ -17,11 +17,9 @@
"""Custom datatype functionality""" """Custom datatype functionality"""
import tvm._ffi import tvm._ffi
from . import make as _make import tvm.runtime._ffi_api
from .api import convert from tvm.runtime import convert, DataType
from .expr import Call as _Call, Cast as _Cast, FloatImm as _FloatImm from tvm.expr import Call as _Call, Cast as _Cast, FloatImm as _FloatImm
from ._ffi.runtime_ctypes import DataType
from . import _api_internal
def register(type_name, type_code): def register(type_name, type_code):
...@@ -39,7 +37,7 @@ def register(type_name, type_code): ...@@ -39,7 +37,7 @@ def register(type_name, type_code):
type_code : int type_code : int
The type's code, which should be >= kCustomBegin The type's code, which should be >= kCustomBegin
""" """
_api_internal._datatype_register(type_name, type_code) tvm.runtime._ffi_api._datatype_register(type_name, type_code)
def get_type_name(type_code): def get_type_name(type_code):
...@@ -50,7 +48,7 @@ def get_type_name(type_code): ...@@ -50,7 +48,7 @@ def get_type_name(type_code):
type_code : int type_code : int
The type code The type code
""" """
return _api_internal._datatype_get_type_name(type_code) return tvm.runtime._ffi_api._datatype_get_type_name(type_code)
def get_type_code(type_name): def get_type_code(type_name):
...@@ -61,7 +59,7 @@ def get_type_code(type_name): ...@@ -61,7 +59,7 @@ def get_type_code(type_name):
type_name : str type_name : str
The type name The type name
""" """
return _api_internal._datatype_get_type_code(type_name) return tvm.runtime._ffi_api._datatype_get_type_code(type_name)
def get_type_registered(type_code): def get_type_registered(type_code):
...@@ -72,7 +70,7 @@ def get_type_registered(type_code): ...@@ -72,7 +70,7 @@ def get_type_registered(type_code):
type_code: int type_code: int
The type code The type code
""" """
return _api_internal._datatype_get_type_registered(type_code) return tvm.runtime._ffi_api._datatype_get_type_registered(type_code)
def register_op(lower_func, op_name, target, type_name, src_type_name=None): def register_op(lower_func, op_name, target, type_name, src_type_name=None):
...@@ -137,9 +135,9 @@ def create_lower_func(extern_func_name): ...@@ -137,9 +135,9 @@ def create_lower_func(extern_func_name):
if t.lanes > 1: if t.lanes > 1:
dtype += "x" + str(t.lanes) dtype += "x" + str(t.lanes)
if isinstance(op, (_Cast, _FloatImm)): if isinstance(op, (_Cast, _FloatImm)):
return _make.Call(dtype, extern_func_name, convert([op.value]), return _Call(dtype, extern_func_name, convert([op.value]),
_Call.Extern, None, 0) _Call.Extern, None, 0)
return _make.Call(dtype, extern_func_name, convert([op.a, op.b]), return _Call(dtype, extern_func_name, convert([op.a, op.b]),
_Call.Extern, None, 0) _Call.Extern, None, 0)
return lower return lower
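A hedged sketch of the relocated custom-datatype registry (the "posit" name and code 131 are only illustrative):

.. code-block:: python

    import tvm.target

    tvm.target.datatype.register("posit", 131)
    assert tvm.target.datatype.get_type_code("posit") == 131
    assert tvm.target.datatype.get_type_name(131) == "posit"
    assert tvm.target.datatype.get_type_registered(131)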
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Generic function."""
import tvm._ffi
from tvm._ffi.base import _LIB_NAME
try:
from decorator import decorate
except ImportError as err_msg:
# Allow decorator to be missing in runtime
if _LIB_NAME != "libtvm_runtime.so":
raise err_msg
from tvm.runtime import Object
from .target import Target
from . import _ffi_api
@tvm._ffi.register_object
class GenericFunc(Object):
"""GenericFunc node reference. This represents a generic function
that may be specialized for different targets. When this object is
called, a specialization is chosen based on the current target.
Note
----
Do not construct an instance of this object; it should only ever be
used as a return value from calling into C++.
"""
def __call__(self, *args):
return _ffi_api.GenericFuncCallFunc(self, *args)
def set_default(self, func, allow_override=False):
"""Set the default function to be used if no specializations match
the current target.
Parameters
----------
func : function
The default function
allow_override : bool
Whether to allow the current default to be overridden
"""
_ffi_api.GenericFuncSetDefault(self, func, allow_override)
def register(self, func, key_list, allow_override=False):
"""Register a specialization for this GenericFunc.
Parameters
----------
func : function
The function to be registered.
key_list : str or list of str
The key to be registered.
allow_override : bool, optional
Whether to allow existing keys to be overridden.
"""
key_list = [key_list] if isinstance(key_list, str) else key_list
_ffi_api.GenericFuncRegisterFunc(self, func, key_list, allow_override)
def get_native_generic_func(name):
"""Get a generic function from the global registry. If no
function is registered under the given name, a new generic
function is created.
Parameters
----------
name : string
The name of the generic function to get
Returns
-------
func : GenericFunc
The generic function for the given name
"""
return _ffi_api.GenericFuncGetGlobal(name)
def override_native_generic_func(func_name):
"""Override a generic function defined in C++
Generic function allows registration of further functions
that can be dispatched on current target context.
If no registered dispatch is matched, the fdefault will be called.
Parameters
----------
func_name : string
The name of the generic func to be overridden
Returns
-------
fgeneric : function
A wrapped generic function.
Example
-------
.. code-block:: python
import tvm
# wrap function as target generic
@tvm.target.override_native_generic_func("my_func")
def my_func(a):
return a + 1
# register specialization of my_func under target cuda
@my_func.register("cuda")
def my_func_cuda(a):
return a + 2
# displays 3, because my_func is called
print(my_func(2))
# displays 4, because my_func_cuda is called
with tvm.target.cuda():
print(my_func(2))
"""
generic_func_node = get_native_generic_func(func_name)
def fdecorate(fdefault):
"""Wrap a target generic function, overriding the previous
default that was set for the generic function.
Parameters
----------
fdefault : function
The default function.
Returns
-------
fgeneric : function
A wrapped generic function.
"""
generic_func_node.set_default(fdefault, allow_override=True)
def register(key, func=None, override=True):
"""Register function to be the dispatch function.
Parameters
----------
key : str or list of str
The key to be registered.
func : function
The function to be registered.
override : bool, optional
Whether to override an existing registration.
Returns
-------
The registered function, or a decorator that performs the registration when func is not given.
"""
def _do_reg(myf):
generic_func_node.register(myf, key, override)
return myf
if func:
return _do_reg(func)
return _do_reg
def dispatch_func(func, *args, **kwargs):
#pylint: disable=unused-argument
"""The wrapped dispath function"""
if kwargs:
raise RuntimeError(
"Keyword arguments cannot be used when invoking generic_func %s" % func_name)
return generic_func_node(*args)
fresult = decorate(fdefault, dispatch_func)
fresult.fdefault = fdefault
fresult.register = register
return fresult
return fdecorate
def generic_func(fdefault):
"""Wrap a target generic function.
A generic function allows registration of further functions
that can be dispatched on the current target context.
If no registered dispatch matches, fdefault is called.
Parameters
----------
fdefault : function
The default function.
Returns
-------
fgeneric : function
A wrapped generic function.
Example
-------
.. code-block:: python
import tvm
# wrap function as target generic
@tvm.target.generic_func
def my_func(a):
return a + 1
# register specialization of my_func under target cuda
@my_func.register("cuda")
def my_func_cuda(a):
return a + 2
# displays 3, because my_func is called
print(my_func(2))
# displays 4, because my_func_cuda is called
with tvm.target.cuda():
print(my_func(2))
"""
dispatch_dict = {}
func_name = fdefault.__name__
def register(key, func=None, override=False):
"""Register function to be the dispatch function.
Parameters
----------
key : str or list of str
The key to be registered.
func : function
The function to be registered.
override : bool
Whether to override an existing registration.
Returns
-------
The registered function, or a decorator that performs the registration when func is not given.
"""
def _do_reg(myf):
key_list = [key] if isinstance(key, str) else key
for k in key_list:
if k in dispatch_dict and not override:
raise ValueError(
"Key is already registered for %s" % func_name)
dispatch_dict[k] = myf
return myf
if func:
return _do_reg(func)
return _do_reg
def dispatch_func(func, *args, **kwargs):
"""The wrapped dispath function"""
target = Target.current()
if target is None:
return func(*args, **kwargs)
for k in target.keys:
if k in dispatch_dict:
return dispatch_dict[k](*args, **kwargs)
return func(*args, **kwargs)
fdecorate = decorate(fdefault, dispatch_func)
fdecorate.register = register
fdecorate.fdefault = fdefault
return fdecorate
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Target data structure."""
import warnings
import tvm._ffi
from tvm.runtime import Object
from . import _ffi_api
@tvm._ffi.register_object
class Target(Object):
"""Target device information, use through TVM API.
Note
----
Do not use the class constructor; create targets using the following functions:
- :py:func:`~tvm.target.create` create target from string
- :py:func:`~tvm.target.arm_cpu` create arm_cpu target
- :py:func:`~tvm.target.cuda` create CUDA target
- :py:func:`~tvm.target.rocm` create ROCM target
- :py:func:`~tvm.target.mali` create Mali target
- :py:func:`~tvm.target.intel_graphics` create Intel Graphics target
"""
def __new__(cls):
# Always override __new__ so the attribute caches below are initialized
obj = Object.__new__(cls)
obj._keys = None
obj._options = None
obj._libs = None
return obj
@property
def keys(self):
if not self._keys:
self._keys = [k.value for k in self.keys_array]
return self._keys
@property
def options(self):
if not self._options:
self._options = [o.value for o in self.options_array]
return self._options
@property
def libs(self):
if not self._libs:
self._libs = [l.value for l in self.libs_array]
return self._libs
@property
def model(self):
for opt in self.options_array:
if opt.value.startswith('-model='):
return opt.value[7:]
return 'unknown'
@property
def mcpu(self):
"""Returns the mcpu from the target if it exists."""
mcpu = ''
if self.options is not None:
for opt in self.options:
if 'mcpu' in opt:
mcpu = opt.split('=')[1]
return mcpu
def __enter__(self):
_ffi_api.EnterTargetScope(self)
return self
def __exit__(self, ptype, value, trace):
_ffi_api.ExitTargetScope(self)
@staticmethod
def current(allow_none=True):
"""Returns the current target.
Parameters
----------
allow_none : bool
Whether to allow the current target to be None
Raises
------
ValueError if the current target is not set and allow_none is False.
"""
return _ffi_api.GetCurrentTarget(allow_none)
def _merge_opts(opts, new_opts):
"""Helper function to merge options"""
if isinstance(new_opts, str):
new_opts = new_opts.split()
if new_opts:
opt_set = set(opts)
new_opts = [opt for opt in new_opts if opt not in opt_set]
return opts + new_opts
return opts
def cuda(model='unknown', options=None):
"""Returns a cuda target.
Parameters
----------
model: str
The model of the CUDA device (e.g. 1080ti)
options : str or list of str
Additional options
"""
opts = _merge_opts(['-model=%s' % model], options)
return _ffi_api.TargetCreate("cuda", *opts)
def rocm(model='unknown', options=None):
"""Returns a ROCM target.
Parameters
----------
model: str
The model of this device
options : str or list of str
Additional options
"""
opts = _merge_opts(["-model=%s" % model], options)
return _ffi_api.TargetCreate("rocm", *opts)
def mali(model='unknown', options=None):
"""Returns a ARM Mali GPU target.
Parameters
----------
model: str
The model of this device
options : str or list of str
Additional options
"""
opts = ["-device=mali", '-model=%s' % model]
opts = _merge_opts(opts, options)
return _ffi_api.TargetCreate("opencl", *opts)
def intel_graphics(model='unknown', options=None):
"""Returns an Intel Graphics target.
Parameters
----------
model: str
The model of this device
options : str or list of str
Additional options
"""
opts = ["-device=intel_graphics", '-model=%s' % model]
opts = _merge_opts(opts, options)
return _ffi_api.TargetCreate("opencl", *opts)
def opengl(model='unknown', options=None):
"""Returns a OpenGL target.
Parameters
----------
options : str or list of str
Additional options
"""
opts = _merge_opts(["-model=%s" % model], options)
return _ffi_api.TargetCreate("opengl", *opts)
def arm_cpu(model='unknown', options=None):
"""Returns a ARM CPU target.
This function will also download pre-tuned op parameters when there is none.
Parameters
----------
model: str
SoC name or phone name of the arm board.
options : str or list of str
Additional options
"""
trans_table = {
"pixel2": ["-model=snapdragon835", "-target=arm64-linux-android -mattr=+neon"],
"mate10": ["-model=kirin970", "-target=arm64-linux-android -mattr=+neon"],
"mate10pro": ["-model=kirin970", "-target=arm64-linux-android -mattr=+neon"],
"p20": ["-model=kirin970", "-target=arm64-linux-android -mattr=+neon"],
"p20pro": ["-model=kirin970", "-target=arm64-linux-android -mattr=+neon"],
"rasp3b": ["-model=bcm2837", "-target=armv7l-linux-gnueabihf -mattr=+neon"],
"rasp4b": ["-model=bcm2711", "-target=arm-linux-gnueabihf -mattr=+neon"],
"rk3399": ["-model=rk3399", "-target=aarch64-linux-gnu -mattr=+neon"],
"pynq": ["-model=pynq", "-target=armv7a-linux-eabi -mattr=+neon"],
"ultra96": ["-model=ultra96", "-target=aarch64-linux-gnu -mattr=+neon"],
}
pre_defined_opt = trans_table.get(model, ["-model=%s" % model])
opts = ["-device=arm_cpu"] + pre_defined_opt
opts = _merge_opts(opts, options)
return _ffi_api.TargetCreate("llvm", *opts)
def rasp(options=None):
"""Return a Raspberry 3b target.
Parameters
----------
options : str or list of str
Additional options
"""
warnings.warn('tvm.target.rasp() is going to be deprecated. '
'Please use tvm.target.arm_cpu("rasp3b")')
return arm_cpu('rasp3b', options)
def vta(model='unknown', options=None):
"""Returns a VTA target.
Parameters
----------
model: str
The model of this device
options : str or list of str
Additional options
"""
opts = ["-device=vta", '-keys=cpu', '-model=%s' % model]
opts = _merge_opts(opts, options)
ret = _ffi_api.TargetCreate("ext_dev", *opts)
return ret
def bifrost(model='unknown', options=None):
"""Return an ARM Mali GPU target (Bifrost architecture).
Parameters
----------
options : str or list of str
Additional options
"""
opts = ["-device=bifrost", '-model=%s' % model]
opts = _merge_opts(opts, options)
return _ffi_api.TargetCreate("opencl", *opts)
def create(target_str):
"""Get a target given target string.
Parameters
----------
target_str : str
The target string.
Returns
-------
target : Target
The target object
Note
----
See the note on :py:mod:`~tvm.target` on target string format.
"""
if isinstance(target_str, Target):
return target_str
if not isinstance(target_str, str):
raise ValueError("target_str has to be string type")
return _ffi_api.TargetFromString(target_str)
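A short sketch of the scoping semantics implemented by __enter__/__exit__ above:

.. code-block:: python

    import tvm

    with tvm.target.create("cuda -model=1080ti"):
        target = tvm.target.Target.current()
        print(target.target_name)  # 'cuda'

    # the with-block pops the scope again
    print(tvm.target.Target.current(allow_none=True))  # None, unless nested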
...@@ -46,20 +46,20 @@ namespace tvm { ...@@ -46,20 +46,20 @@ namespace tvm {
namespace runtime { namespace runtime {
std::string GetCustomTypeName(uint8_t type_code) { std::string GetCustomTypeName(uint8_t type_code) {
auto f = tvm::runtime::Registry::Get("_datatype_get_type_name"); auto f = tvm::runtime::Registry::Get("runtime._datatype_get_type_name");
CHECK(f) << "Function _datatype_get_type_name not found"; CHECK(f) << "Function runtime._datatype_get_type_name not found";
return (*f)(type_code).operator std::string(); return (*f)(type_code).operator std::string();
} }
uint8_t GetCustomTypeCode(const std::string& type_name) { uint8_t GetCustomTypeCode(const std::string& type_name) {
auto f = tvm::runtime::Registry::Get("_datatype_get_type_code"); auto f = tvm::runtime::Registry::Get("runtime._datatype_get_type_code");
CHECK(f) << "Function _datatype_get_type_code not found"; CHECK(f) << "Function runtime._datatype_get_type_code not found";
return (*f)(type_name).operator int(); return (*f)(type_name).operator int();
} }
bool GetCustomTypeRegistered(uint8_t type_code) { bool GetCustomTypeRegistered(uint8_t type_code) {
auto f = tvm::runtime::Registry::Get("_datatype_get_type_registered"); auto f = tvm::runtime::Registry::Get("runtime._datatype_get_type_registered");
CHECK(f) << "Function _datatype_get_type_registered not found"; CHECK(f) << "Function runtime._datatype_get_type_registered not found";
return (*f)(type_code).operator bool(); return (*f)(type_code).operator bool();
} }
...@@ -612,7 +612,7 @@ TVM_REGISTER_GLOBAL(tvm::runtime::symbol::tvm_set_device) ...@@ -612,7 +612,7 @@ TVM_REGISTER_GLOBAL(tvm::runtime::symbol::tvm_set_device)
}); });
// set device api // set device api
TVM_REGISTER_GLOBAL("_GetDeviceAttr") TVM_REGISTER_GLOBAL("runtime.GetDeviceAttr")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
TVMContext ctx; TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(args[0].operator int()); ctx.device_type = static_cast<DLDeviceType>(args[0].operator int());
......
...@@ -244,7 +244,7 @@ runtime::Module PackImportsToLLVM(const runtime::Module& mod, ...@@ -244,7 +244,7 @@ runtime::Module PackImportsToLLVM(const runtime::Module& mod,
return (*codegen_f)(blob_byte_array, system_lib, target_triple); return (*codegen_f)(blob_byte_array, system_lib, target_triple);
} }
TVM_REGISTER_GLOBAL("codegen._Build") TVM_REGISTER_GLOBAL("target.Build")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
if (args[0].IsObjectRef<tir::LoweredFunc>()) { if (args[0].IsObjectRef<tir::LoweredFunc>()) {
*ret = Build({args[0]}, args[1]); *ret = Build({args[0]}, args[1]);
......
...@@ -25,22 +25,22 @@ namespace datatype { ...@@ -25,22 +25,22 @@ namespace datatype {
using runtime::TVMArgs; using runtime::TVMArgs;
using runtime::TVMRetValue; using runtime::TVMRetValue;
TVM_REGISTER_GLOBAL("_datatype_register") TVM_REGISTER_GLOBAL("runtime._datatype_register")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
datatype::Registry::Global()->Register(args[0], static_cast<uint8_t>(args[1].operator int())); datatype::Registry::Global()->Register(args[0], static_cast<uint8_t>(args[1].operator int()));
}); });
TVM_REGISTER_GLOBAL("_datatype_get_type_code") TVM_REGISTER_GLOBAL("runtime._datatype_get_type_code")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = datatype::Registry::Global()->GetTypeCode(args[0]); *ret = datatype::Registry::Global()->GetTypeCode(args[0]);
}); });
TVM_REGISTER_GLOBAL("_datatype_get_type_name") TVM_REGISTER_GLOBAL("runtime._datatype_get_type_name")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = Registry::Global()->GetTypeName(args[0].operator int()); *ret = Registry::Global()->GetTypeName(args[0].operator int());
}); });
TVM_REGISTER_GLOBAL("_datatype_get_type_registered") TVM_REGISTER_GLOBAL("runtime._datatype_get_type_registered")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = Registry::Global()->GetTypeRegistered(args[0].operator int()); *ret = Registry::Global()->GetTypeRegistered(args[0].operator int());
}); });
...@@ -90,7 +90,6 @@ const runtime::PackedFunc* GetCastLowerFunc(const std::string& target, uint8_t t ...@@ -90,7 +90,6 @@ const runtime::PackedFunc* GetCastLowerFunc(const std::string& target, uint8_t t
} else { } else {
ss << runtime::TypeCode2Str(src_type_code); ss << runtime::TypeCode2Str(src_type_code);
} }
return runtime::Registry::Get(ss.str()); return runtime::Registry::Get(ss.str());
} }
......
...@@ -123,18 +123,18 @@ void GenericFunc::CallPacked(TVMArgs args, TVMRetValue* ret) const { ...@@ -123,18 +123,18 @@ void GenericFunc::CallPacked(TVMArgs args, TVMRetValue* ret) const {
func.CallPacked(args, ret); func.CallPacked(args, ret);
} }
TVM_REGISTER_GLOBAL("_GenericFuncCreate") TVM_REGISTER_GLOBAL("target.GenericFuncCreate")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = GenericFunc(make_object<GenericFuncNode>()); *ret = GenericFunc(make_object<GenericFuncNode>());
}); });
TVM_REGISTER_GLOBAL("_GenericFuncGetGlobal") TVM_REGISTER_GLOBAL("target.GenericFuncGetGlobal")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
std::string func_name = args[0]; std::string func_name = args[0];
*ret = GenericFunc::Get(func_name); *ret = GenericFunc::Get(func_name);
}); });
TVM_REGISTER_GLOBAL("_GenericFuncSetDefault") TVM_REGISTER_GLOBAL("target.GenericFuncSetDefault")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
GenericFunc generic_func = args[0]; GenericFunc generic_func = args[0];
// Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown // Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown
...@@ -145,7 +145,7 @@ TVM_REGISTER_GLOBAL("_GenericFuncSetDefault") ...@@ -145,7 +145,7 @@ TVM_REGISTER_GLOBAL("_GenericFuncSetDefault")
.set_default(*func, allow_override); .set_default(*func, allow_override);
}); });
TVM_REGISTER_GLOBAL("_GenericFuncRegisterFunc") TVM_REGISTER_GLOBAL("target.GenericFuncRegisterFunc")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
GenericFunc generic_func = args[0]; GenericFunc generic_func = args[0];
// Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown // Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown
...@@ -162,7 +162,7 @@ TVM_REGISTER_GLOBAL("_GenericFuncRegisterFunc") ...@@ -162,7 +162,7 @@ TVM_REGISTER_GLOBAL("_GenericFuncRegisterFunc")
.register_func(tags_vector, *func, allow_override); .register_func(tags_vector, *func, allow_override);
}); });
TVM_REGISTER_GLOBAL("_GenericFuncCallFunc") TVM_REGISTER_GLOBAL("target.GenericFuncCallFunc")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
GenericFunc generic_func = args[0]; GenericFunc generic_func = args[0];
TVMArgs func_args(&args.values[1], &args.type_codes[1], args.num_args - 1); TVMArgs func_args(&args.values[1], &args.type_codes[1], args.num_args - 1);
......
...@@ -349,11 +349,6 @@ unsigned LookupLLVMIntrinsic(const std::string& name) { ...@@ -349,11 +349,6 @@ unsigned LookupLLVMIntrinsic(const std::string& name) {
return llvm::Function::lookupIntrinsicID(name); return llvm::Function::lookupIntrinsicID(name);
} }
TVM_REGISTER_GLOBAL("codegen.llvm_lookup_intrinsic_id")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = static_cast<int64_t>(LookupLLVMIntrinsic(args[0]));
});
TVM_REGISTER_GLOBAL("codegen.build_llvm") TVM_REGISTER_GLOBAL("codegen.build_llvm")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
auto n = make_object<LLVMModuleNode>(); auto n = make_object<LLVMModuleNode>();
...@@ -361,9 +356,13 @@ TVM_REGISTER_GLOBAL("codegen.build_llvm") ...@@ -361,9 +356,13 @@ TVM_REGISTER_GLOBAL("codegen.build_llvm")
*rv = runtime::Module(n); *rv = runtime::Module(n);
}); });
TVM_REGISTER_GLOBAL("codegen.llvm_version_major") TVM_REGISTER_GLOBAL("target.llvm_lookup_intrinsic_id")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = static_cast<int64_t>(LookupLLVMIntrinsic(args[0]));
});
TVM_REGISTER_GLOBAL("target.llvm_version_major")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
std::ostringstream os;
int major = TVM_LLVM_VERSION / 10; int major = TVM_LLVM_VERSION / 10;
*rv = major; *rv = major;
}); });
......
...@@ -144,7 +144,7 @@ Target CreateTarget(const std::string& target_name, ...@@ -144,7 +144,7 @@ Target CreateTarget(const std::string& target_name,
return Target(t); return Target(t);
} }
TVM_REGISTER_GLOBAL("_TargetCreate") TVM_REGISTER_GLOBAL("target.TargetCreate")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
std::string target_name = args[0]; std::string target_name = args[0];
std::vector<std::string> options; std::vector<std::string> options;
...@@ -156,7 +156,7 @@ TVM_REGISTER_GLOBAL("_TargetCreate") ...@@ -156,7 +156,7 @@ TVM_REGISTER_GLOBAL("_TargetCreate")
*ret = CreateTarget(target_name, options); *ret = CreateTarget(target_name, options);
}); });
TVM_REGISTER_GLOBAL("_TargetFromString") TVM_REGISTER_GLOBAL("target.TargetFromString")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
std::string target_str = args[0]; std::string target_str = args[0];
*ret = Target::Create(target_str); *ret = Target::Create(target_str);
...@@ -269,7 +269,7 @@ tvm::Target Target::Current(bool allow_not_defined) { ...@@ -269,7 +269,7 @@ tvm::Target Target::Current(bool allow_not_defined) {
return Target(); return Target();
} }
TVM_REGISTER_GLOBAL("_GetCurrentTarget") TVM_REGISTER_GLOBAL("target.GetCurrentTarget")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
bool allow_not_defined = args[0]; bool allow_not_defined = args[0];
*ret = Target::Current(allow_not_defined); *ret = Target::Current(allow_not_defined);
...@@ -284,10 +284,10 @@ class Target::Internal { ...@@ -284,10 +284,10 @@ class Target::Internal {
} }
}; };
TVM_REGISTER_GLOBAL("_EnterTargetScope") TVM_REGISTER_GLOBAL("target.EnterTargetScope")
.set_body_typed(Target::Internal::EnterScope); .set_body_typed(Target::Internal::EnterScope);
TVM_REGISTER_GLOBAL("_ExitTargetScope") TVM_REGISTER_GLOBAL("target.ExitTargetScope")
.set_body_typed(Target::Internal::ExitScope); .set_body_typed(Target::Internal::ExitScope);
namespace target { namespace target {
......
...@@ -95,19 +95,19 @@ class CustomDatatypesLowerer : public StmtExprMutator { ...@@ -95,19 +95,19 @@ class CustomDatatypesLowerer : public StmtExprMutator {
return expr; return expr;
} }
#define DEFINE_MUTATE__(OP, NodeName) \ #define DEFINE_MUTATE__(OP, NodeName) \
inline PrimExpr VisitExpr_(const NodeName* op) final { \ inline PrimExpr VisitExpr_(const NodeName* op) final { \
auto type_code = op->dtype.code(); \ auto type_code = op->dtype.code(); \
bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code); \ bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code); \
PrimExpr expr = StmtExprMutator::VisitExpr_(op); \ PrimExpr expr = StmtExprMutator::VisitExpr_(op); \
op = expr.as<NodeName>(); \ op = expr.as<NodeName>(); \
if (toBeLowered) { \ if (toBeLowered) { \
auto lower = datatype::Get##OP##LowerFunc(target_, type_code); \ auto lower = datatype::Get##OP##LowerFunc(target_, type_code); \
CHECK(lower) << #OP " lowering function for target " << target_ << " type " \ CHECK(lower) << #OP " lowering function for target " << target_ << " type " \
<< static_cast<unsigned>(type_code) << " not found"; \ << static_cast<unsigned>(type_code) << " not found"; \
return (*lower)(expr); \ return (*lower)(expr); \
} \ } \
return expr; \ return expr; \
} }
DEFINE_MUTATE__(Add, AddNode); DEFINE_MUTATE__(Add, AddNode);
......
...@@ -103,11 +103,11 @@ TEST(BuildModule, Heterogeneous) { ...@@ -103,11 +103,11 @@ TEST(BuildModule, Heterogeneous) {
return copy[i] - C[i]; return copy[i] - C[i];
}, "elemwise_sub"); }, "elemwise_sub");
const runtime::PackedFunc* enter_target_scope_func = runtime::Registry::Get("_EnterTargetScope"); With<Target> cuda_scope(target_cuda);
(*enter_target_scope_func)(target_cuda);
auto s1 = topi::cuda::schedule_injective(target_cuda, {elemwise_add}); auto s1 = topi::cuda::schedule_injective(target_cuda, {elemwise_add});
(*enter_target_scope_func)(target_llvm);
With<Target> llvm_scope(target_llvm);
auto s2 = create_schedule({elemwise_sub->op}); auto s2 = create_schedule({elemwise_sub->op});
auto config = BuildConfig::Create(); auto config = BuildConfig::Create();
......
...@@ -55,7 +55,7 @@ def test_dot(): ...@@ -55,7 +55,7 @@ def test_dot():
if not tvm.runtime.enabled(target): if not tvm.runtime.enabled(target):
print("Target %s is not enabled" % target) print("Target %s is not enabled" % target)
return return
f = tvm.codegen.build_module(fapi, target) f = tvm.target.codegen.build_module(fapi, target)
# verify # verify
ctx = tvm.cpu(0) ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.uniform(size=(nn,)).astype(A.dtype), ctx) a = tvm.nd.array(np.random.uniform(size=(nn,)).astype(A.dtype), ctx)
......
...@@ -1115,7 +1115,7 @@ def test_conv2d_int8_intrinsics(): ...@@ -1115,7 +1115,7 @@ def test_conv2d_int8_intrinsics():
# compile conv2d for x86 (skylake, cascadelake) and test assembly contains *pmadd* instructions # compile conv2d for x86 (skylake, cascadelake) and test assembly contains *pmadd* instructions
targets = ["llvm -mcpu=skylake-avx512", "llvm -mcpu=cascadelake"] targets = ["llvm -mcpu=skylake-avx512", "llvm -mcpu=cascadelake"]
llvm_version = tvm.codegen.llvm_version_major() llvm_version = tvm.target.codegen.llvm_version_major()
for target in targets: for target in targets:
if llvm_version >= 8: if llvm_version >= 8:
dtypes = ('uint8', 'int8', 'int32') dtypes = ('uint8', 'int8', 'int32')
...@@ -1208,7 +1208,7 @@ def test_depthwise_conv2d_int8(): ...@@ -1208,7 +1208,7 @@ def test_depthwise_conv2d_int8():
parameters = {"weight": tvm.nd.array(wdata.astype(weight_dtype))} parameters = {"weight": tvm.nd.array(wdata.astype(weight_dtype))}
targets = ["llvm -mcpu=skylake-avx512", "llvm -mcpu=cascadelake"] targets = ["llvm -mcpu=skylake-avx512", "llvm -mcpu=cascadelake"]
llvm_version = tvm.codegen.llvm_version_major() llvm_version = tvm.target.codegen.llvm_version_major()
for target in targets: for target in targets:
if llvm_version >= 8: if llvm_version >= 8:
with relay.build_config(opt_level=3): with relay.build_config(opt_level=3):
......
...@@ -50,7 +50,7 @@ def matmul(N, L, M, dtype): ...@@ -50,7 +50,7 @@ def matmul(N, L, M, dtype):
@autotvm.template @autotvm.template
def bad_matmul(N, L, M, dtype): def bad_matmul(N, L, M, dtype):
if 'bad_device' in tvm.target.current_target().keys: if 'bad_device' in tvm.target.Target.current().keys:
A = tvm.placeholder((N, L), name='A', dtype=dtype) A = tvm.placeholder((N, L), name='A', dtype=dtype)
B = tvm.placeholder((L, M), name='B', dtype=dtype) B = tvm.placeholder((L, M), name='B', dtype=dtype)
......
...@@ -75,7 +75,7 @@ def test_add_pipeline(): ...@@ -75,7 +75,7 @@ def test_add_pipeline():
f1 = tvm.lower(s, [A,B,C], name="fadd_pipeline") f1 = tvm.lower(s, [A,B,C], name="fadd_pipeline")
fsplits = [x for x in tvm.ir_pass.SplitHostDevice(f1)] fsplits = [x for x in tvm.ir_pass.SplitHostDevice(f1)]
fsplits[0] = tvm.ir_pass.LowerTVMBuiltin(fsplits[0]) fsplits[0] = tvm.ir_pass.LowerTVMBuiltin(fsplits[0])
mhost = tvm.codegen.build_module(fsplits[0], "c") mhost = tvm.target.codegen.build_module(fsplits[0], "c")
temp = util.tempdir() temp = util.tempdir()
path_dso = temp.relpath("temp.so") path_dso = temp.relpath("temp.so")
mhost.export_library(path_dso) mhost.export_library(path_dso)
......
...@@ -84,8 +84,8 @@ def test_add_pipeline(): ...@@ -84,8 +84,8 @@ def test_add_pipeline():
return return
if not tvm.runtime.enabled(host): if not tvm.runtime.enabled(host):
return return
mhost = tvm.codegen.build_module(fsplits[0], host) mhost = tvm.target.codegen.build_module(fsplits[0], host)
mdev = tvm.codegen.build_module(fsplits[1:], device) mdev = tvm.target.codegen.build_module(fsplits[1:], device)
mhost.import_module(mdev) mhost.import_module(mdev)
code = mdev.get_source() code = mdev.get_source()
f = mhost.entry_func f = mhost.entry_func
...@@ -110,8 +110,8 @@ def test_add_pipeline(): ...@@ -110,8 +110,8 @@ def test_add_pipeline():
fmt = "hsaco" fmt = "hsaco"
else: else:
fmt = device fmt = device
mhost = tvm.codegen.build_module(fsplits[0], host) mhost = tvm.target.codegen.build_module(fsplits[0], host)
mdev = tvm.codegen.build_module(fsplits[1:], device) mdev = tvm.target.codegen.build_module(fsplits[1:], device)
temp = util.tempdir() temp = util.tempdir()
mpath = temp.relpath("test.%s" % fmt) mpath = temp.relpath("test.%s" % fmt)
mdev.save(mpath) mdev.save(mpath)
......
...@@ -570,9 +570,9 @@ def test_dwarf_debug_information(): ...@@ -570,9 +570,9 @@ def test_dwarf_debug_information():
def check_llvm_object(): def check_llvm_object():
if not tvm.runtime.enabled("llvm"): if not tvm.runtime.enabled("llvm"):
return return
if tvm.codegen.llvm_version_major() < 5: if tvm.target.codegen.llvm_version_major() < 5:
return return
if tvm.codegen.llvm_version_major() > 6: if tvm.target.codegen.llvm_version_major() > 6:
return return
# build two functions # build two functions
f2 = tvm.lower(s, [A, B, C], name="fadd1") f2 = tvm.lower(s, [A, B, C], name="fadd1")
...@@ -607,9 +607,9 @@ def test_dwarf_debug_information(): ...@@ -607,9 +607,9 @@ def test_dwarf_debug_information():
def check_llvm_ir(): def check_llvm_ir():
if not tvm.runtime.enabled("llvm"): if not tvm.runtime.enabled("llvm"):
return return
if tvm.codegen.llvm_version_major() < 5: if tvm.target.codegen.llvm_version_major() < 5:
return return
if tvm.codegen.llvm_version_major() > 6: if tvm.target.codegen.llvm_version_major() > 6:
return return
# build two functions # build two functions
f2 = tvm.lower(s, [A, B, C], name="fadd1") f2 = tvm.lower(s, [A, B, C], name="fadd1")
......
...@@ -33,7 +33,7 @@ def test_static_callback(): ...@@ -33,7 +33,7 @@ def test_static_callback():
stmt = ib.get() stmt = ib.get()
fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True) fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
fapi = tvm.ir_pass.LowerTVMBuiltin(fapi) fapi = tvm.ir_pass.LowerTVMBuiltin(fapi)
f = tvm.codegen.build_module(fapi, "llvm") f = tvm.target.codegen.build_module(fapi, "llvm")
a = tvm.nd.array(np.zeros(10, dtype=dtype)) a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a) f(a)
f(a) f(a)
...@@ -57,7 +57,7 @@ def test_static_init(): ...@@ -57,7 +57,7 @@ def test_static_init():
stmt = ib.get() stmt = ib.get()
fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True) fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
fapi = tvm.ir_pass.LowerTVMBuiltin(fapi) fapi = tvm.ir_pass.LowerTVMBuiltin(fapi)
f = tvm.codegen.build_module(fapi, "llvm") f = tvm.target.codegen.build_module(fapi, "llvm")
a = tvm.nd.array(np.zeros(10, dtype=dtype)) a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a) f(a)
......
...@@ -21,7 +21,7 @@ def run_jit(fapi, check): ...@@ -21,7 +21,7 @@ def run_jit(fapi, check):
for target in ["llvm", "stackvm"]: for target in ["llvm", "stackvm"]:
if not tvm.runtime.enabled(target): if not tvm.runtime.enabled(target):
continue continue
f = tvm.codegen.build_module(fapi, target) f = tvm.target.codegen.build_module(fapi, target)
s = f.get_source() s = f.get_source()
check(f) check(f)
......
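The pattern in the hunks above (lower or MakeAPI, then build) is the common route into the relocated builder. A self-contained sketch, assuming an LLVM-enabled build; the function name "add_one" is purely illustrative:

    import numpy as np
    import tvm

    n = 16
    A = tvm.placeholder((n,), name="A")
    B = tvm.compute((n,), lambda i: A[i] + 1.0, name="B")
    s = tvm.create_schedule(B.op)
    fapi = tvm.lower(s, [A, B], name="add_one")
    # was tvm.codegen.build_module before this refactor
    f = tvm.target.codegen.build_module(fapi, "llvm")
    a = tvm.nd.array(np.zeros(n, dtype="float32"))
    b = tvm.nd.array(np.zeros(n, dtype="float32"))
    f(a, b)  # module call dispatches to the entry function
    np.testing.assert_allclose(b.asnumpy(), a.asnumpy() + 1.0)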
...@@ -19,9 +19,9 @@ import re ...@@ -19,9 +19,9 @@ import re
def test_fp16_to_fp32(): def test_fp16_to_fp32():
if tvm.codegen.llvm_version_major() < 6: if tvm.target.codegen.llvm_version_major() < 6:
print("Skipping due to LLVM version being {} < 6".format( print("Skipping due to LLVM version being {} < 6".format(
tvm.codegen.llvm_version_major())) tvm.target.codegen.llvm_version_major()))
return return
def fp16_to_fp32(target, width, match=None, not_match=None): def fp16_to_fp32(target, width, match=None, not_match=None):
......
...@@ -29,19 +29,19 @@ def setup_module(): ...@@ -29,19 +29,19 @@ def setup_module():
# In this case, we have built the test functions used below right into TVM. # In this case, we have built the test functions used below right into TVM.
# CDLL("libmybfloat16.so", RTLD_GLOBAL) # CDLL("libmybfloat16.so", RTLD_GLOBAL)
tvm.datatype.register("bfloat", 129) tvm.target.datatype.register("bfloat", 129)
tvm.datatype.register_op( tvm.target.datatype.register_op(
tvm.datatype.create_lower_func("FloatToBFloat16_wrapper"), "Cast", tvm.target.datatype.create_lower_func("FloatToBFloat16_wrapper"), "Cast",
"llvm", "bfloat", "float") "llvm", "bfloat", "float")
tvm.datatype.register_op( tvm.target.datatype.register_op(
tvm.datatype.create_lower_func("BFloat16ToFloat_wrapper"), "Cast", tvm.target.datatype.create_lower_func("BFloat16ToFloat_wrapper"), "Cast",
"llvm", "float", "bfloat") "llvm", "float", "bfloat")
tvm.datatype.register_op( tvm.target.datatype.register_op(
tvm.datatype.create_lower_func("BFloat16Add_wrapper"), "Add", "llvm", tvm.target.datatype.create_lower_func("BFloat16Add_wrapper"), "Add", "llvm",
"bfloat") "bfloat")
tvm.datatype.register_op( tvm.target.datatype.register_op(
tvm.datatype.create_lower_func("FloatToBFloat16_wrapper"), "FloatImm", tvm.target.datatype.create_lower_func("FloatToBFloat16_wrapper"), "FloatImm",
"llvm", "bfloat") "llvm", "bfloat")
def lower_datatypes_and_build(schedule, args): def lower_datatypes_and_build(schedule, args):
......
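Once setup_module() above has registered "bfloat" and its lowering wrappers (the wrapper symbols must already be loaded into the process), the type is reachable through the usual custom-dtype string. A sketch assuming that registration has run:

    import tvm

    n = 8
    A = tvm.placeholder((n,), dtype="float32", name="A")
    # "custom[bfloat]16": custom type name plus bit width
    B = tvm.compute((n,), lambda i: A[i].astype("custom[bfloat]16"), name="B")
    C = tvm.compute((n,), lambda i: B[i].astype("float32"), name="C")
    s = tvm.create_schedule(C.op)
    # lowering the Cast nodes is what dispatches to the registered wrappers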
...@@ -50,7 +50,7 @@ def test_target_dispatch(): ...@@ -50,7 +50,7 @@ def test_target_dispatch():
with tvm.target.create("metal"): with tvm.target.create("metal"):
assert mygeneric(1) == 3 assert mygeneric(1) == 3
assert tvm.target.current_target() is None assert tvm.target.Target.current() is None
def test_target_string_parse(): def test_target_string_parse():
......
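Target.current() replaces current_target() with identical scoping: it returns None outside any with-block and the previous scope is restored on exit. A minimal sketch:

    import tvm

    assert tvm.target.Target.current() is None
    with tvm.target.create("llvm"):
        assert tvm.target.Target.current().target_name == "llvm"
    assert tvm.target.Target.current() is None  # scope restored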
...@@ -39,7 +39,7 @@ def test_dltensor_compatible(): ...@@ -39,7 +39,7 @@ def test_dltensor_compatible():
stmt = ib.get() stmt = ib.get()
fapi = tvm.ir_pass.MakeAPI(stmt, "arange", [Ab], 0, True) fapi = tvm.ir_pass.MakeAPI(stmt, "arange", [Ab], 0, True)
fapi = tvm.ir_pass.LowerTVMBuiltin(fapi) fapi = tvm.ir_pass.LowerTVMBuiltin(fapi)
f = tvm.codegen.build_module(fapi, "stackvm") f = tvm.target.codegen.build_module(fapi, "stackvm")
a = tvm.nd.array(np.zeros(10, dtype=dtype)) a = tvm.nd.array(np.zeros(10, dtype=dtype))
aview = MyTensorView(a) aview = MyTensorView(a)
f(aview) f(aview)
......
...@@ -57,7 +57,7 @@ def test_dso_module_load(): ...@@ -57,7 +57,7 @@ def test_dso_module_load():
i + 1)) i + 1))
fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True) fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
fapi = tvm.ir_pass.LowerTVMBuiltin(fapi) fapi = tvm.ir_pass.LowerTVMBuiltin(fapi)
m = tvm.codegen.build_module(fapi, "llvm") m = tvm.target.codegen.build_module(fapi, "llvm")
for name in names: for name in names:
m.save(name) m.save(name)
......
...@@ -588,7 +588,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F): ...@@ -588,7 +588,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
idxd = tvm.indexdiv idxd = tvm.indexdiv
if groups == 1: if groups == 1:
target = tvm.target.current_target() target = tvm.target.Target.current()
dispatch_ctx = autotvm.DispatchContext.current dispatch_ctx = autotvm.DispatchContext.current
cfg = dispatch_ctx.query(target, workload) cfg = dispatch_ctx.query(target, workload)
...@@ -693,12 +693,12 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F): ...@@ -693,12 +693,12 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
else: else:
raise RuntimeError("Unsupported template_key '%s'" % cfg.template_key) raise RuntimeError("Unsupported template_key '%s'" % cfg.template_key)
else: else:
target = tvm.target.current_target() target = tvm.target.Target.current()
dispatch_ctx = autotvm.DispatchContext.current dispatch_ctx = autotvm.DispatchContext.current
cfg = dispatch_ctx.query(target, workload) cfg = dispatch_ctx.query(target, workload)
if cfg.is_fallback: # if is fallback, clear query cache and return None if cfg.is_fallback: # if is fallback, clear query cache and return None
autotvm.task.clear_fallback_cache(tvm.target.current_target(), workload) autotvm.task.clear_fallback_cache(tvm.target.Target.current(), workload)
if layout == 'NHWC' and kernel_layout == 'HWOI': if layout == 'NHWC' and kernel_layout == 'HWOI':
new_attrs['data_layout'] = 'NCHW' new_attrs['data_layout'] = 'NCHW'
new_attrs['kernel_layout'] = 'OIHW' new_attrs['kernel_layout'] = 'OIHW'
......
...@@ -156,7 +156,7 @@ def _schedule_spatial_pack(cfg, s, output, conv, data_vec, kernel_vec): ...@@ -156,7 +156,7 @@ def _schedule_spatial_pack(cfg, s, output, conv, data_vec, kernel_vec):
# this part to make tuning records correct # this part to make tuning records correct
s[kernel_vec].pragma(s[kernel_vec].op.axis[0], 'debug_skip_region') s[kernel_vec].pragma(s[kernel_vec].op.axis[0], 'debug_skip_region')
else: else:
max_threads = tvm.target.current_target(allow_none=False).max_num_threads max_threads = tvm.target.Target.current(allow_none=False).max_num_threads
co, ci, kh, kw, vc = s[kernel_vec].op.axis co, ci, kh, kw, vc = s[kernel_vec].op.axis
fused = s[kernel_vec].fuse(co, ci, kh, kw, vc) fused = s[kernel_vec].fuse(co, ci, kh, kw, vc)
fused, vec = s[kernel_vec].split(fused, VC) fused, vec = s[kernel_vec].split(fused, VC)
......
...@@ -24,7 +24,7 @@ def fuse_and_bind(s, tensor, axis=None, num_thread=None): ...@@ -24,7 +24,7 @@ def fuse_and_bind(s, tensor, axis=None, num_thread=None):
"""Fuse all the axis and bind to GPU threads""" """Fuse all the axis and bind to GPU threads"""
axis = axis or s[tensor].op.axis axis = axis or s[tensor].op.axis
fused = s[tensor].fuse(*axis) fused = s[tensor].fuse(*axis)
max_threads = tvm.target.current_target(allow_none=False).max_num_threads max_threads = tvm.target.Target.current(allow_none=False).max_num_threads
bx, tx = s[tensor].split(fused, num_thread or max_threads) bx, tx = s[tensor].split(fused, num_thread or max_threads)
s[tensor].bind(bx, tvm.thread_axis("blockIdx.x")) s[tensor].bind(bx, tvm.thread_axis("blockIdx.x"))
s[tensor].bind(tx, tvm.thread_axis("threadIdx.x")) s[tensor].bind(tx, tvm.thread_axis("threadIdx.x"))
......
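allow_none=False turns a missing ambient target into an immediate error instead of a None dereference, which is why the thread-budget lookups above use it. A sketch (no GPU is needed just to inspect the target object):

    import tvm

    with tvm.target.create("cuda"):
        # raises if no target is in scope, rather than returning None
        max_threads = tvm.target.Target.current(allow_none=False).max_num_threads
        print(max_threads)  # 1024 for the stock cuda target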
...@@ -41,7 +41,7 @@ def batch_matmul_cuda(x, y): ...@@ -41,7 +41,7 @@ def batch_matmul_cuda(x, y):
output : tvm.Tensor output : tvm.Tensor
3-D with shape [batch, M, N] 3-D with shape [batch, M, N]
""" """
target = tvm.target.current_target() target = tvm.target.Target.current()
if target.target_name == "cuda" and "cublas" in target.libs: if target.target_name == "cuda" and "cublas" in target.libs:
return cublas.batch_matmul(x, y, False, True) return cublas.batch_matmul(x, y, False, True)
return batch_matmul_default(x, y) return batch_matmul_default(x, y)
...@@ -61,7 +61,7 @@ def schedule_batch_matmul(outs): ...@@ -61,7 +61,7 @@ def schedule_batch_matmul(outs):
s: Schedule s: Schedule
The computation schedule for the op. The computation schedule for the op.
""" """
target = tvm.target.current_target() target = tvm.target.Target.current()
if target.target_name == "cuda" and "cublas" in target.libs: if target.target_name == "cuda" and "cublas" in target.libs:
return generic.schedule_extern(outs) return generic.schedule_extern(outs)
......
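External-library dispatch keys off Target.libs, which -libs=... populates. A sketch of the check the compute/schedule pair above performs:

    import tvm

    with tvm.target.create("cuda -libs=cublas"):
        target = tvm.target.Target.current()
        use_cublas = target.target_name == "cuda" and "cublas" in target.libs
        assert use_cublas  # batch_matmul would offload to cuBLAS here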
...@@ -115,7 +115,7 @@ def schedule_conv1d_ncw(cfg, outs): ...@@ -115,7 +115,7 @@ def schedule_conv1d_ncw(cfg, outs):
cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3) cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3)
cfg.define_knob("auto_unroll_max_step", [64, 512, 1500]) cfg.define_knob("auto_unroll_max_step", [64, 512, 1500])
target = tvm.target.current_target() target = tvm.target.Target.current()
if target.target_name in ['nvptx', 'rocm']: if target.target_name in ['nvptx', 'rocm']:
cfg.define_knob("unroll_explicit", [1]) cfg.define_knob("unroll_explicit", [1])
else: else:
...@@ -230,7 +230,7 @@ def schedule_conv1d_nwc(cfg, outs): ...@@ -230,7 +230,7 @@ def schedule_conv1d_nwc(cfg, outs):
cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3) cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3)
cfg.define_knob("auto_unroll_max_step", [64, 512, 1500]) cfg.define_knob("auto_unroll_max_step", [64, 512, 1500])
target = tvm.target.current_target() target = tvm.target.Target.current()
if target.target_name in ['nvptx', 'rocm']: if target.target_name in ['nvptx', 'rocm']:
cfg.define_knob("unroll_explicit", [1]) cfg.define_knob("unroll_explicit", [1])
else: else:
......
...@@ -116,7 +116,7 @@ def schedule_conv1d_transpose_ncw_cuda(cfg, outs): ...@@ -116,7 +116,7 @@ def schedule_conv1d_transpose_ncw_cuda(cfg, outs):
cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3) cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3)
cfg.define_knob("auto_unroll_max_step", [64, 512, 1500]) cfg.define_knob("auto_unroll_max_step", [64, 512, 1500])
target = tvm.target.current_target() target = tvm.target.Target.current()
if target.target_name in ['nvptx', 'rocm']: if target.target_name in ['nvptx', 'rocm']:
cfg.define_knob("unroll_explicit", [1]) cfg.define_knob("unroll_explicit", [1])
else: else:
......
...@@ -69,7 +69,7 @@ def conv2d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou ...@@ -69,7 +69,7 @@ def conv2d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou
output : tvm.Tensor output : tvm.Tensor
4-D with shape [batch, out_channel, out_height, out_width] 4-D with shape [batch, out_channel, out_height, out_width]
""" """
target = tvm.target.current_target() target = tvm.target.Target.current()
if "cudnn" in target.libs: if "cudnn" in target.libs:
if layout == 'NCHW': if layout == 'NCHW':
...@@ -148,7 +148,7 @@ def schedule_conv2d_nchw_cuda(cfg, outs): ...@@ -148,7 +148,7 @@ def schedule_conv2d_nchw_cuda(cfg, outs):
s: Schedule s: Schedule
The computation schedule for conv2d. The computation schedule for conv2d.
""" """
target = tvm.target.current_target() target = tvm.target.Target.current()
if 'cudnn' in target.libs: if 'cudnn' in target.libs:
return generic.schedule_extern(outs) return generic.schedule_extern(outs)
...@@ -186,7 +186,7 @@ def schedule_conv2d_nhwc_cuda(cfg, outs): ...@@ -186,7 +186,7 @@ def schedule_conv2d_nhwc_cuda(cfg, outs):
s: Schedule s: Schedule
The computation schedule for conv2d. The computation schedule for conv2d.
""" """
target = tvm.target.current_target() target = tvm.target.Target.current()
if 'cudnn' in target.libs: if 'cudnn' in target.libs:
return generic.schedule_extern(outs) return generic.schedule_extern(outs)
......
...@@ -34,7 +34,7 @@ def schedule_direct_cuda(cfg, s, conv): ...@@ -34,7 +34,7 @@ def schedule_direct_cuda(cfg, s, conv):
cfg.define_split("tile_rx", rx, num_outputs=2) cfg.define_split("tile_rx", rx, num_outputs=2)
cfg.define_knob("auto_unroll_max_step", [0, 512, 1500]) cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
target = tvm.target.current_target() target = tvm.target.Target.current()
if target.target_name in ['nvptx', 'rocm']: if target.target_name in ['nvptx', 'rocm']:
cfg.define_knob("unroll_explicit", [1]) cfg.define_knob("unroll_explicit", [1])
else: else:
......
...@@ -170,7 +170,7 @@ def schedule_conv2d_transpose_nchw_cuda(cfg, outs): ...@@ -170,7 +170,7 @@ def schedule_conv2d_transpose_nchw_cuda(cfg, outs):
cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3) cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3)
cfg.define_knob("auto_unroll_max_step", [64, 512, 1500]) cfg.define_knob("auto_unroll_max_step", [64, 512, 1500])
target = tvm.target.current_target() target = tvm.target.Target.current()
if target.target_name in ['nvptx', 'rocm']: if target.target_name in ['nvptx', 'rocm']:
cfg.define_knob("unroll_explicit", [1]) cfg.define_knob("unroll_explicit", [1])
else: else:
......
...@@ -194,7 +194,7 @@ def schedule_winograd_cuda(cfg, s, output, pre_computed): ...@@ -194,7 +194,7 @@ def schedule_winograd_cuda(cfg, s, output, pre_computed):
cfg.define_split("tile_x", x, num_outputs=4) cfg.define_split("tile_x", x, num_outputs=4)
cfg.define_split("tile_rc", rc, num_outputs=2) cfg.define_split("tile_rc", rc, num_outputs=2)
cfg.define_knob("auto_unroll_max_step", [0, 128, 1500]) cfg.define_knob("auto_unroll_max_step", [0, 128, 1500])
target = tvm.target.current_target() target = tvm.target.Target.current()
if target.target_name in ['nvptx', 'rocm']: if target.target_name in ['nvptx', 'rocm']:
cfg.define_knob("unroll_explicit", [1]) cfg.define_knob("unroll_explicit", [1])
else: else:
...@@ -325,7 +325,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, F): ...@@ -325,7 +325,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, F):
Unlike other TOPI functions, this function operates on both graph level and operator level, Unlike other TOPI functions, this function operates on both graph level and operator level,
so we have to pass 'F' to make it support our two versions of graph IR, Relay. so we have to pass 'F' to make it support our two versions of graph IR, Relay.
""" """
if 'cudnn' in tvm.target.current_target().libs or 'miopen' in tvm.target.current_target().libs: if 'cudnn' in tvm.target.Target.current().libs or 'miopen' in tvm.target.Target.current().libs:
return None return None
copy_inputs = list(inputs) copy_inputs = list(inputs)
...@@ -349,7 +349,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, F): ...@@ -349,7 +349,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, F):
CO, _, KH, KW = get_const_tuple(kernel.shape) CO, _, KH, KW = get_const_tuple(kernel.shape)
dispatch_ctx = autotvm.DispatchContext.current dispatch_ctx = autotvm.DispatchContext.current
target = tvm.target.current_target() target = tvm.target.Target.current()
if groups == 1: if groups == 1:
# query config of this workload # query config of this workload
......
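The early-out above skips layout rewriting whenever the convolution is offloaded to cuDNN or MIOpen. The equivalent standalone check, as a sketch:

    import tvm

    with tvm.target.create("cuda -libs=cudnn"):
        libs = tvm.target.Target.current().libs
        offloaded = "cudnn" in libs or "miopen" in libs
        assert offloaded  # _alter_conv2d_layout returns None in this case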
...@@ -64,7 +64,7 @@ def conv3d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCDHW', o ...@@ -64,7 +64,7 @@ def conv3d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCDHW', o
output : tvm.Tensor output : tvm.Tensor
5-D with shape [batch, out_channel, out_depth, out_height, out_width] 5-D with shape [batch, out_channel, out_depth, out_height, out_width]
""" """
target = tvm.target.current_target() target = tvm.target.Target.current()
if "cudnn" in target.libs: if "cudnn" in target.libs:
if layout == 'NCDHW': if layout == 'NCDHW':
...@@ -126,7 +126,7 @@ def schedule_conv3d_ncdhw_cuda(cfg, outs): ...@@ -126,7 +126,7 @@ def schedule_conv3d_ncdhw_cuda(cfg, outs):
s: Schedule s: Schedule
The computation schedule for conv2d. The computation schedule for conv2d.
""" """
target = tvm.target.current_target() target = tvm.target.Target.current()
if 'cudnn' in target.libs: if 'cudnn' in target.libs:
return generic.schedule_extern(outs) return generic.schedule_extern(outs)
...@@ -160,7 +160,7 @@ def schedule_conv3d_ndhwc_cuda(cfg, outs): ...@@ -160,7 +160,7 @@ def schedule_conv3d_ndhwc_cuda(cfg, outs):
s: Schedule s: Schedule
The computation schedule for conv2d. The computation schedule for conv2d.
""" """
target = tvm.target.current_target() target = tvm.target.Target.current()
if 'cudnn' in target.libs: if 'cudnn' in target.libs:
return generic.schedule_extern(outs) return generic.schedule_extern(outs)
......
...@@ -36,7 +36,7 @@ def schedule_direct_3d_cuda(cfg, s, conv): ...@@ -36,7 +36,7 @@ def schedule_direct_3d_cuda(cfg, s, conv):
cfg.define_split("tile_rx", rx, num_outputs=2) cfg.define_split("tile_rx", rx, num_outputs=2)
cfg.define_knob("auto_unroll_max_step", [0, 512, 1500]) cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
target = tvm.target.current_target() target = tvm.target.Target.current()
if target.target_name in ['nvptx', 'rocm']: if target.target_name in ['nvptx', 'rocm']:
cfg.define_knob("unroll_explicit", [1]) cfg.define_knob("unroll_explicit", [1])
else: else:
......
...@@ -67,7 +67,7 @@ def schedule_direct_cuda(cfg, s, conv): ...@@ -67,7 +67,7 @@ def schedule_direct_cuda(cfg, s, conv):
cfg.define_split("tile_rx", rx, num_outputs=2) cfg.define_split("tile_rx", rx, num_outputs=2)
cfg.define_knob("auto_unroll_max_step", [0, 512, 1500]) cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
target = tvm.target.current_target() target = tvm.target.Target.current()
if target.target_name in ['nvptx', 'rocm']: if target.target_name in ['nvptx', 'rocm']:
cfg.define_knob("unroll_explicit", [1]) cfg.define_knob("unroll_explicit", [1])
else: else:
......
...@@ -60,7 +60,7 @@ def dense_cuda(cfg, data, weight, bias=None, out_dtype=None): ...@@ -60,7 +60,7 @@ def dense_cuda(cfg, data, weight, bias=None, out_dtype=None):
out_dtype = data.dtype out_dtype = data.dtype
batch, in_dim = data.shape batch, in_dim = data.shape
out_dim, _ = weight.shape out_dim, _ = weight.shape
target = tvm.target.current_target() target = tvm.target.Target.current()
if "cublas" in target.libs: if "cublas" in target.libs:
matmul = cublas.matmul(data, weight, False, True, out_dtype) matmul = cublas.matmul(data, weight, False, True, out_dtype)
if bias is not None: if bias is not None:
...@@ -87,7 +87,7 @@ def schedule_dense(cfg, outs): ...@@ -87,7 +87,7 @@ def schedule_dense(cfg, outs):
The computation schedule for dense. The computation schedule for dense.
""" """
# pylint: disable=unused-argument # pylint: disable=unused-argument
target = tvm.target.current_target() target = tvm.target.Target.current()
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
if target.target_name == "cuda" and "cublas" in target.libs: if target.target_name == "cuda" and "cublas" in target.libs:
...@@ -259,7 +259,7 @@ def dense_int8(cfg, data, weight, bias=None, out_dtype=None): ...@@ -259,7 +259,7 @@ def dense_int8(cfg, data, weight, bias=None, out_dtype=None):
batch, in_dim = get_const_tuple(data.shape) batch, in_dim = get_const_tuple(data.shape)
out_dim, _ = get_const_tuple(weight.shape) out_dim, _ = get_const_tuple(weight.shape)
target = tvm.target.current_target() target = tvm.target.Target.current()
if "cublas" in target.libs: if "cublas" in target.libs:
matmul = cublas.matmul(data, weight, False, True, out_dtype) matmul = cublas.matmul(data, weight, False, True, out_dtype)
if bias is not None: if bias is not None:
...@@ -290,7 +290,7 @@ def dense_int8(cfg, data, weight, bias=None, out_dtype=None): ...@@ -290,7 +290,7 @@ def dense_int8(cfg, data, weight, bias=None, out_dtype=None):
def schedule_dense_int8(cfg, outs): def schedule_dense_int8(cfg, outs):
"""Dense schedule for int8 on CUDA""" """Dense schedule for int8 on CUDA"""
s = tvm.create_schedule([x.op for x in outs]) s = tvm.create_schedule([x.op for x in outs])
target = tvm.target.current_target() target = tvm.target.Target.current()
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
if "cublas" in target.libs: if "cublas" in target.libs:
......
...@@ -57,7 +57,7 @@ def schedule_depthwise_conv2d_nchw_cuda(cfg, outs): ...@@ -57,7 +57,7 @@ def schedule_depthwise_conv2d_nchw_cuda(cfg, outs):
cfg.define_split("tile_x", x, num_outputs=4) cfg.define_split("tile_x", x, num_outputs=4)
cfg.define_knob("auto_unroll_max_step", [0, 256, 1500]) cfg.define_knob("auto_unroll_max_step", [0, 256, 1500])
target = tvm.target.current_target() target = tvm.target.Target.current()
if target.target_name in ['nvptx', 'rocm']: if target.target_name in ['nvptx', 'rocm']:
cfg.define_knob("unroll_explicit", [1]) cfg.define_knob("unroll_explicit", [1])
else: else:
...@@ -166,7 +166,7 @@ def schedule_depthwise_conv2d_nhwc(outs): ...@@ -166,7 +166,7 @@ def schedule_depthwise_conv2d_nhwc(outs):
# num_thread here could be 728, it is larger than cuda.max_num_threads # num_thread here could be 728, it is larger than cuda.max_num_threads
num_thread = tvm.ir_pass.Simplify(temp.shape[3]).value num_thread = tvm.ir_pass.Simplify(temp.shape[3]).value
target = tvm.target.current_target() target = tvm.target.Target.current()
if target and (target.target_name not in ["cuda", "nvptx"]): if target and (target.target_name not in ["cuda", "nvptx"]):
num_thread = target.max_num_threads num_thread = target.max_num_threads
xoc, xic = s[Output].split(c, factor=num_thread) xoc, xic = s[Output].split(c, factor=num_thread)
......
...@@ -340,7 +340,7 @@ def schedule_group_conv2d_nchw_direct(cfg, s, conv): ...@@ -340,7 +340,7 @@ def schedule_group_conv2d_nchw_direct(cfg, s, conv):
cfg.define_split("tile_rx", rx, num_outputs=2) cfg.define_split("tile_rx", rx, num_outputs=2)
cfg.define_knob("auto_unroll_max_step", [0, 512, 1500]) cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
target = tvm.target.current_target() target = tvm.target.Target.current()
if target.target_name in ['nvptx', 'rocm']: if target.target_name in ['nvptx', 'rocm']:
cfg.define_knob("unroll_explicit", [1]) cfg.define_knob("unroll_explicit", [1])
else: else:
......
...@@ -37,7 +37,7 @@ def schedule_injective_from_existing(sch, out): ...@@ -37,7 +37,7 @@ def schedule_injective_from_existing(sch, out):
The updated schedule. The updated schedule.
""" """
fused = sch[out].fuse(*sch[out].op.axis) fused = sch[out].fuse(*sch[out].op.axis)
num_thread = tvm.target.current_target(allow_none=False).max_num_threads num_thread = tvm.target.Target.current(allow_none=False).max_num_threads
max_block = 256 max_block = 256
try: try:
......
...@@ -71,7 +71,7 @@ def get_valid_counts_pre(data, flag, idx, score_threshold, id_index, score_index ...@@ -71,7 +71,7 @@ def get_valid_counts_pre(data, flag, idx, score_threshold, id_index, score_index
id_index = tvm.make.node("IntImm", dtype="int32", value=id_index) id_index = tvm.make.node("IntImm", dtype="int32", value=id_index)
score_index = tvm.make.node("IntImm", dtype="int32", value=score_index) score_index = tvm.make.node("IntImm", dtype="int32", value=score_index)
max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
nthread_tx = max_threads nthread_tx = max_threads
nthread_bx = batch_size * num_anchors // max_threads + 1 nthread_bx = batch_size * num_anchors // max_threads + 1
tx = tvm.thread_axis("threadIdx.x") tx = tvm.thread_axis("threadIdx.x")
...@@ -120,7 +120,7 @@ def get_valid_counts_upsweep(data, idx_in, idx, partial): ...@@ -120,7 +120,7 @@ def get_valid_counts_upsweep(data, idx_in, idx, partial):
idx_in = ib.buffer_ptr(idx_in) idx_in = ib.buffer_ptr(idx_in)
idx = ib.buffer_ptr(idx) idx = ib.buffer_ptr(idx)
partial = ib.buffer_ptr(partial) partial = ib.buffer_ptr(partial)
max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
elem_per_thread = num_anchors // max_threads + 1 elem_per_thread = num_anchors // max_threads + 1
nthread_tx = max_threads nthread_tx = max_threads
nthread_bx = batch_size nthread_bx = batch_size
...@@ -176,7 +176,7 @@ def get_valid_counts_scan(data, partial_in, partial): ...@@ -176,7 +176,7 @@ def get_valid_counts_scan(data, partial_in, partial):
ib = tvm.ir_builder.create() ib = tvm.ir_builder.create()
partial_in = ib.buffer_ptr(partial_in) partial_in = ib.buffer_ptr(partial_in)
partial = ib.buffer_ptr(partial) partial = ib.buffer_ptr(partial)
max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
elem_per_thread = num_anchors // max_threads + 1 elem_per_thread = num_anchors // max_threads + 1
nthread_tx = max_threads nthread_tx = max_threads
nthread_bx = batch_size nthread_bx = batch_size
...@@ -234,7 +234,7 @@ def get_valid_counts_downsweep(data, idx_in, partial, idx): ...@@ -234,7 +234,7 @@ def get_valid_counts_downsweep(data, idx_in, partial, idx):
idx_in = ib.buffer_ptr(idx_in) idx_in = ib.buffer_ptr(idx_in)
idx = ib.buffer_ptr(idx) idx = ib.buffer_ptr(idx)
partial = ib.buffer_ptr(partial) partial = ib.buffer_ptr(partial)
max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
elem_per_thread = num_anchors // max_threads + 1 elem_per_thread = num_anchors // max_threads + 1
nthread_tx = max_threads nthread_tx = max_threads
nthread_bx = batch_size * num_anchors // max_threads + 1 nthread_bx = batch_size * num_anchors // max_threads + 1
...@@ -297,7 +297,7 @@ def get_valid_counts_ir(data, flag, idx, valid_count, out): ...@@ -297,7 +297,7 @@ def get_valid_counts_ir(data, flag, idx, valid_count, out):
valid_count = ib.buffer_ptr(valid_count) valid_count = ib.buffer_ptr(valid_count)
out = ib.buffer_ptr(out) out = ib.buffer_ptr(out)
max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
nthread_tx = max_threads nthread_tx = max_threads
nthread_bx = batch_size * num_anchors * elem_length // max_threads + 1 nthread_bx = batch_size * num_anchors * elem_length // max_threads + 1
tx = tvm.thread_axis("threadIdx.x") tx = tvm.thread_axis("threadIdx.x")
...@@ -356,7 +356,7 @@ def get_valid_counts_gpu(data, score_threshold=0, id_index=0, score_index=1): ...@@ -356,7 +356,7 @@ def get_valid_counts_gpu(data, score_threshold=0, id_index=0, score_index=1):
""" """
batch_size = data.shape[0] batch_size = data.shape[0]
num_anchors = data.shape[1] num_anchors = data.shape[1]
max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
elem_per_thread = num_anchors // max_threads + 1 elem_per_thread = num_anchors // max_threads + 1
new_range = num_anchors // elem_per_thread + 1 new_range = num_anchors // elem_per_thread + 1
temp_flag_buf = api.decl_buffer( temp_flag_buf = api.decl_buffer(
...@@ -482,7 +482,7 @@ def nms_ir(data, sorted_index, valid_count, out, box_indices, ...@@ -482,7 +482,7 @@ def nms_ir(data, sorted_index, valid_count, out, box_indices,
num_valid_boxes = ib.allocate("int32", (1,), name="num_valid_boxes", scope="local") num_valid_boxes = ib.allocate("int32", (1,), name="num_valid_boxes", scope="local")
max_threads = int( max_threads = int(
tvm.target.current_target(allow_none=False).max_num_threads) tvm.target.Target.current(allow_none=False).max_num_threads)
nthread_tx = max_threads nthread_tx = max_threads
nthread_bx = num_anchors // max_threads + 1 nthread_bx = num_anchors // max_threads + 1
tx = tvm.thread_axis("threadIdx.x") tx = tvm.thread_axis("threadIdx.x")
...@@ -594,7 +594,7 @@ def invalid_to_bottom_pre(data, flag, idx): ...@@ -594,7 +594,7 @@ def invalid_to_bottom_pre(data, flag, idx):
idx = ib.buffer_ptr(idx) idx = ib.buffer_ptr(idx)
max_threads = int(math.sqrt( max_threads = int(math.sqrt(
tvm.target.current_target(allow_none=False).max_num_threads)) tvm.target.Target.current(allow_none=False).max_num_threads))
nthread_tx = max_threads nthread_tx = max_threads
nthread_bx = num_anchors // max_threads + 1 nthread_bx = num_anchors // max_threads + 1
tx = tvm.thread_axis("threadIdx.x") tx = tvm.thread_axis("threadIdx.x")
...@@ -654,7 +654,7 @@ def invalid_to_bottom_ir(data, flag, idx, out): ...@@ -654,7 +654,7 @@ def invalid_to_bottom_ir(data, flag, idx, out):
out = ib.buffer_ptr(out) out = ib.buffer_ptr(out)
max_threads = int(math.sqrt( max_threads = int(math.sqrt(
tvm.target.current_target(allow_none=False).max_num_threads)) tvm.target.Target.current(allow_none=False).max_num_threads))
nthread_tx = max_threads nthread_tx = max_threads
nthread_bx = num_anchors // max_threads + 1 nthread_bx = num_anchors // max_threads + 1
tx = tvm.thread_axis("threadIdx.x") tx = tvm.thread_axis("threadIdx.x")
......
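The sqrt in the two hunks above splits the per-block thread budget across a 2-D launch: threadIdx.x times threadIdx.y must stay within max_num_threads. A sketch of the arithmetic:

    import math
    import tvm

    with tvm.target.create("cuda"):
        max_threads = tvm.target.Target.current(allow_none=False).max_num_threads
        per_dim = int(math.sqrt(max_threads))
        # 32 x 32 = 1024 threads per block for the stock cuda target
        print(per_dim, per_dim * per_dim <= max_threads)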
...@@ -37,6 +37,6 @@ def schedule_lrn(outs): ...@@ -37,6 +37,6 @@ def schedule_lrn(outs):
sch: Schedule sch: Schedule
The computation schedule for the op. The computation schedule for the op.
""" """
target = tvm.target.current_target(allow_none=False) target = tvm.target.Target.current(allow_none=False)
cpp_target = cpp.TEST_create_target(target.target_name) cpp_target = cpp.TEST_create_target(target.target_name)
return cpp.cuda.schedule_lrn(cpp_target, outs) return cpp.cuda.schedule_lrn(cpp_target, outs)
...@@ -112,7 +112,7 @@ def schedule_pool(outs, layout): ...@@ -112,7 +112,7 @@ def schedule_pool(outs, layout):
def _schedule(PaddedInput, Pool): def _schedule(PaddedInput, Pool):
if isinstance(PaddedInput.op, tvm.tensor.ComputeOp): if isinstance(PaddedInput.op, tvm.tensor.ComputeOp):
s[PaddedInput].compute_inline() s[PaddedInput].compute_inline()
num_thread = tvm.target.current_target(allow_none=False).max_num_threads num_thread = tvm.target.Target.current(allow_none=False).max_num_threads
if Pool.op in s.outputs: if Pool.op in s.outputs:
Out = Pool Out = Pool
OL = s.cache_write(Pool, "local") OL = s.cache_write(Pool, "local")
...@@ -177,7 +177,7 @@ def schedule_pool_grad_cuda(outs): ...@@ -177,7 +177,7 @@ def schedule_pool_grad_cuda(outs):
else: else:
out = outs[0].op.output(0) out = outs[0].op.output(0)
fused = s[out].fuse(*s[out].op.axis) fused = s[out].fuse(*s[out].op.axis)
num_thread = tvm.target.current_target(allow_none=False).max_num_threads num_thread = tvm.target.Target.current(allow_none=False).max_num_threads
bx, tx = s[out].split(fused, factor=num_thread) bx, tx = s[out].split(fused, factor=num_thread)
s[out].bind(bx, tvm.thread_axis("blockIdx.x")) s[out].bind(bx, tvm.thread_axis("blockIdx.x"))
s[out].bind(tx, tvm.thread_axis("threadIdx.x")) s[out].bind(tx, tvm.thread_axis("threadIdx.x"))
......
...@@ -64,7 +64,7 @@ def predict_bbox_ir(cls_prob_buf, bbox_pred_buf, im_info_buf, out_buf, scales, r ...@@ -64,7 +64,7 @@ def predict_bbox_ir(cls_prob_buf, bbox_pred_buf, im_info_buf, out_buf, scales, r
""" """
batch, num_anchors, height, width = get_const_tuple(cls_prob_buf.shape) batch, num_anchors, height, width = get_const_tuple(cls_prob_buf.shape)
num_anchors //= 2 num_anchors //= 2
max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
nthread_tx = max_threads nthread_tx = max_threads
nthread_bx = (batch * height * width) // max_threads + 1 nthread_bx = (batch * height * width) // max_threads + 1
tx = tvm.thread_axis("threadIdx.x") tx = tvm.thread_axis("threadIdx.x")
...@@ -152,7 +152,7 @@ def argsort_ir(data_buf, out_index_buf): ...@@ -152,7 +152,7 @@ def argsort_ir(data_buf, out_index_buf):
The result IR statement. The result IR statement.
""" """
batch, num_bbox = get_const_tuple(data_buf.shape) batch, num_bbox = get_const_tuple(data_buf.shape)
max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
ib = tvm.ir_builder.create() ib = tvm.ir_builder.create()
p_data = ib.buffer_ptr(data_buf) p_data = ib.buffer_ptr(data_buf)
index_out = ib.buffer_ptr(out_index_buf) index_out = ib.buffer_ptr(out_index_buf)
...@@ -225,7 +225,7 @@ def nms_ir(sorted_bbox_buf, out_buf, nms_threshold): ...@@ -225,7 +225,7 @@ def nms_ir(sorted_bbox_buf, out_buf, nms_threshold):
return i / u return i / u
batch, num_bbox = get_const_tuple(out_buf.shape) batch, num_bbox = get_const_tuple(out_buf.shape)
max_threads = int(math.sqrt(tvm.target.current_target(allow_none=False).max_num_threads)) max_threads = int(math.sqrt(tvm.target.Target.current(allow_none=False).max_num_threads))
tx = tvm.thread_axis("threadIdx.x") tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("blockIdx.x") bx = tvm.thread_axis("blockIdx.x")
ib = tvm.ir_builder.create() ib = tvm.ir_builder.create()
......
...@@ -35,7 +35,7 @@ def _schedule_reduce(op, sch, is_idx_reduce=False): ...@@ -35,7 +35,7 @@ def _schedule_reduce(op, sch, is_idx_reduce=False):
if len(sch[data_out].op.axis) > 0: if len(sch[data_out].op.axis) > 0:
all_reduce = False all_reduce = False
num_thread = 32 num_thread = 32
target = tvm.target.current_target() target = tvm.target.Target.current()
if target and target.target_name == "opencl": if target and target.target_name == "opencl":
# without it, CL_INVALID_WORK_GROUP_SIZE occurred when running test_topi_reduce.py # without it, CL_INVALID_WORK_GROUP_SIZE occurred when running test_topi_reduce.py
# don't know why # don't know why
...@@ -45,7 +45,7 @@ def _schedule_reduce(op, sch, is_idx_reduce=False): ...@@ -45,7 +45,7 @@ def _schedule_reduce(op, sch, is_idx_reduce=False):
thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y") thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
else: else:
all_reduce = True all_reduce = True
num_thread = tvm.target.current_target(allow_none=False).max_num_threads num_thread = tvm.target.Target.current(allow_none=False).max_num_threads
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x") thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
# Fuse and refactor the reduce axis # Fuse and refactor the reduce axis
......
...@@ -87,7 +87,7 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None): ...@@ -87,7 +87,7 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None):
axis_mul_before *= value axis_mul_before *= value
elif i > axis: elif i > axis:
axis_mul_after *= value axis_mul_after *= value
max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
ib = tvm.ir_builder.create() ib = tvm.ir_builder.create()
data = ib.buffer_ptr(data) data = ib.buffer_ptr(data)
values_out = ib.buffer_ptr(values_out) values_out = ib.buffer_ptr(values_out)
...@@ -186,7 +186,7 @@ def sort_nms_ir(data, valid_count, output, axis, is_ascend): ...@@ -186,7 +186,7 @@ def sort_nms_ir(data, valid_count, output, axis, is_ascend):
axis_mul_before *= value axis_mul_before *= value
elif i > axis: elif i > axis:
axis_mul_after *= value axis_mul_after *= value
max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
ib = tvm.ir_builder.create() ib = tvm.ir_builder.create()
data = ib.buffer_ptr(data) data = ib.buffer_ptr(data)
valid_count = ib.buffer_ptr(valid_count) valid_count = ib.buffer_ptr(valid_count)
......
...@@ -60,7 +60,7 @@ def multibox_prior_ir(data, out, sizes, ratios, steps, offsets): ...@@ -60,7 +60,7 @@ def multibox_prior_ir(data, out, sizes, ratios, steps, offsets):
The result IR statement. The result IR statement.
""" """
max_threads = int(math.sqrt( max_threads = int(math.sqrt(
tvm.target.current_target(allow_none=False).max_num_threads)) tvm.target.Target.current(allow_none=False).max_num_threads))
tx = tvm.thread_axis("threadIdx.x") tx = tvm.thread_axis("threadIdx.x")
ty = tvm.thread_axis("threadIdx.y") ty = tvm.thread_axis("threadIdx.y")
bx = tvm.thread_axis("blockIdx.x") bx = tvm.thread_axis("blockIdx.x")
...@@ -196,7 +196,7 @@ def transform_loc_pre(cls_prob, valid_count, temp_valid_count, temp_cls_id, temp ...@@ -196,7 +196,7 @@ def transform_loc_pre(cls_prob, valid_count, temp_valid_count, temp_cls_id, temp
threshold = tvm.make.node("FloatImm", dtype="float32", value=threshold) threshold = tvm.make.node("FloatImm", dtype="float32", value=threshold)
max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
nthread_tx = max_threads nthread_tx = max_threads
nthread_bx = (batch_size * num_anchors) // max_threads + 1 nthread_bx = (batch_size * num_anchors) // max_threads + 1
tx = tvm.thread_axis("threadIdx.x") tx = tvm.thread_axis("threadIdx.x")
...@@ -307,7 +307,7 @@ def transform_loc_ir(loc_pred, anchor, temp_valid_count, temp_cls_id, temp_score ...@@ -307,7 +307,7 @@ def transform_loc_ir(loc_pred, anchor, temp_valid_count, temp_cls_id, temp_score
score = ib.buffer_ptr(temp_score) score = ib.buffer_ptr(temp_score)
out_loc = ib.buffer_ptr(out) out_loc = ib.buffer_ptr(out)
max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
nthread_tx = max_threads nthread_tx = max_threads
nthread_bx = (batch_size * num_anchors) // max_threads + 1 nthread_bx = (batch_size * num_anchors) // max_threads + 1
tx = tvm.thread_axis("threadIdx.x") tx = tvm.thread_axis("threadIdx.x")
......
...@@ -53,7 +53,7 @@ def schedule_reorg(outs): ...@@ -53,7 +53,7 @@ def schedule_reorg(outs):
s: Schedule s: Schedule
The computation schedule for reorg. The computation schedule for reorg.
""" """
target = tvm.target.current_target(allow_none=False) target = tvm.target.Target.current(allow_none=False)
cpp_target = cpp.TEST_create_target(target.target_name) cpp_target = cpp.TEST_create_target(target.target_name)
return cpp.cuda.schedule_injective(cpp_target, outs) return cpp.cuda.schedule_injective(cpp_target, outs)
......
...@@ -36,5 +36,5 @@ def schedule_extern(outs): ...@@ -36,5 +36,5 @@ def schedule_extern(outs):
sch: Schedule sch: Schedule
The computation schedule for the op. The computation schedule for the op.
""" """
target = tvm.target.current_target() target = tvm.target.Target.current()
return cpp.generic.schedule_extern(target, outs) return cpp.generic.schedule_extern(target, outs)
...@@ -54,7 +54,7 @@ def schedule_injective(outs): ...@@ -54,7 +54,7 @@ def schedule_injective(outs):
sch: Schedule sch: Schedule
The computation schedule for the op. The computation schedule for the op.
""" """
target = tvm.target.current_target(allow_none=False) target = tvm.target.Target.current(allow_none=False)
if target.target_name != "llvm": if target.target_name != "llvm":
raise RuntimeError("schedule_injective not registered for '%s'" % target) raise RuntimeError("schedule_injective not registered for '%s'" % target)
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
......
...@@ -22,7 +22,7 @@ from .. import cpp ...@@ -22,7 +22,7 @@ from .. import cpp
def _default_schedule(outs, auto_inline): def _default_schedule(outs, auto_inline):
"""Default schedule for llvm.""" """Default schedule for llvm."""
target = tvm.target.current_target(allow_none=False) target = tvm.target.Target.current(allow_none=False)
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
if target.target_name not in ("llvm", "c"): if target.target_name not in ("llvm", "c"):
raise RuntimeError("schedule not registered for '%s'" % target) raise RuntimeError("schedule not registered for '%s'" % target)
...@@ -645,7 +645,7 @@ def schedule_lrn(outs): ...@@ -645,7 +645,7 @@ def schedule_lrn(outs):
sch: Schedule sch: Schedule
The computation schedule for the op. The computation schedule for the op.
""" """
target = tvm.target.current_target(allow_none=False) target = tvm.target.Target.current(allow_none=False)
cpp_target = cpp.TEST_create_target(target.target_name) cpp_target = cpp.TEST_create_target(target.target_name)
return cpp.generic.default_schedule(cpp_target, outs, False) return cpp.generic.default_schedule(cpp_target, outs, False)
...@@ -686,6 +686,6 @@ def schedule_sparse_transpose(outs): ...@@ -686,6 +686,6 @@ def schedule_sparse_transpose(outs):
@tvm.target.generic_func @tvm.target.generic_func
def schedule_batch_matmul(outs): def schedule_batch_matmul(outs):
target = tvm.target.current_target(allow_none=False) target = tvm.target.Target.current(allow_none=False)
cpp_target = cpp.TEST_create_target(target.target_name) cpp_target = cpp.TEST_create_target(target.target_name)
return cpp.generic.default_schedule(cpp_target, outs, False) return cpp.generic.default_schedule(cpp_target, outs, False)
...@@ -22,7 +22,7 @@ from .. import cpp ...@@ -22,7 +22,7 @@ from .. import cpp
def _default_schedule(outs, auto_inline): def _default_schedule(outs, auto_inline):
"""Default schedule for llvm.""" """Default schedule for llvm."""
target = tvm.target.current_target(allow_none=False) target = tvm.target.Target.current(allow_none=False)
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
if target.target_name != "llvm": if target.target_name != "llvm":
raise RuntimeError("schedule not registered for '%s'" % target) raise RuntimeError("schedule not registered for '%s'" % target)
...@@ -48,7 +48,7 @@ def schedule_reorg(outs): ...@@ -48,7 +48,7 @@ def schedule_reorg(outs):
s: Schedule s: Schedule
The computation schedule for the op. The computation schedule for the op.
""" """
target = tvm.target.current_target(allow_none=False) target = tvm.target.Target.current(allow_none=False)
cpp_target = cpp.TEST_create_target(target.target_name) cpp_target = cpp.TEST_create_target(target.target_name)
return cpp.generic.default_schedule(cpp_target, outs, False) return cpp.generic.default_schedule(cpp_target, outs, False)
......
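generic_func dispatch is what routes these default schedules: an override registered for the ambient target's keys wins inside a matching scope, and the decorated body is the fallback. A sketch mirroring test_target_dispatch; "myschedule" is an illustrative name:

    import tvm

    @tvm.target.generic_func
    def myschedule(outs):
        return "generic"  # fallback implementation

    @myschedule.register(["cuda", "gpu"])
    def _myschedule_cuda(outs):
        return "cuda"

    with tvm.target.create("cuda"):
        assert myschedule(None) == "cuda"
    assert myschedule(None) == "generic"  # no ambient target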
...@@ -221,7 +221,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F): ...@@ -221,7 +221,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
return None return None
dispatch_ctx = autotvm.task.DispatchContext.current dispatch_ctx = autotvm.task.DispatchContext.current
target = tvm.target.current_target() target = tvm.target.Target.current()
# query schedule and fallback if necessary # query schedule and fallback if necessary
workload = autotvm.task.args_to_workload( workload = autotvm.task.args_to_workload(
......
...@@ -59,7 +59,7 @@ def schedule_depthwise_conv2d_nchw_intel_graphics(cfg, outs): ...@@ -59,7 +59,7 @@ def schedule_depthwise_conv2d_nchw_intel_graphics(cfg, outs):
cfg.define_split("tile_x", x, num_outputs=4) cfg.define_split("tile_x", x, num_outputs=4)
cfg.define_knob("auto_unroll_max_step", [0, 256, 1500]) cfg.define_knob("auto_unroll_max_step", [0, 256, 1500])
target = tvm.target.current_target() target = tvm.target.Target.current()
if target.target_name in ['nvptx', 'rocm']: if target.target_name in ['nvptx', 'rocm']:
cfg.define_knob("unroll_explicit", [1]) cfg.define_knob("unroll_explicit", [1])
else: else:
...@@ -167,7 +167,7 @@ def schedule_depthwise_conv2d_nhwc(outs): ...@@ -167,7 +167,7 @@ def schedule_depthwise_conv2d_nhwc(outs):
# num_thread here could be 728, it is larger than cuda.max_num_threads # num_thread here could be 728, it is larger than cuda.max_num_threads
num_thread = tvm.ir_pass.Simplify(temp.shape[3]).value num_thread = tvm.ir_pass.Simplify(temp.shape[3]).value
target = tvm.target.current_target() target = tvm.target.Target.current()
if target and (target.target_name not in ["cuda", "nvptx"]): if target and (target.target_name not in ["cuda", "nvptx"]):
num_thread = target.max_num_threads num_thread = target.max_num_threads
xoc, xic = s[Output].split(c, factor=num_thread) xoc, xic = s[Output].split(c, factor=num_thread)
......
...@@ -153,7 +153,7 @@ def _schedule_spatial_pack(cfg, s, output, conv, data_vec, kernel_vec): ...@@ -153,7 +153,7 @@ def _schedule_spatial_pack(cfg, s, output, conv, data_vec, kernel_vec):
# this part to make tuning records correct # this part to make tuning records correct
s[kernel_vec].pragma(s[kernel_vec].op.axis[0], 'debug_skip_region') s[kernel_vec].pragma(s[kernel_vec].op.axis[0], 'debug_skip_region')
else: else:
max_threads = tvm.target.current_target(allow_none=False).max_num_threads max_threads = tvm.target.Target.current(allow_none=False).max_num_threads
co, ci, kh, kw, vc = s[kernel_vec].op.axis co, ci, kh, kw, vc = s[kernel_vec].op.axis
fused = s[kernel_vec].fuse(co, ci, kh, kw, vc) fused = s[kernel_vec].fuse(co, ci, kh, kw, vc)
fused, vec = s[kernel_vec].split(fused, VC) fused, vec = s[kernel_vec].split(fused, VC)
......
...@@ -465,7 +465,7 @@ def conv2d_NCHWc_compute(data, kernel, strides, padding, dilation, layout, out_l ...@@ -465,7 +465,7 @@ def conv2d_NCHWc_compute(data, kernel, strides, padding, dilation, layout, out_l
n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape) n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
in_channel = ic_chunk * ic_bn in_channel = ic_chunk * ic_bn
target = tvm.target.current_target(allow_none=False) target = tvm.target.Target.current(allow_none=False)
oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn = \ oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn = \
get_const_tuple(kernel.shape) get_const_tuple(kernel.shape)
num_filter = oc_chunk * oc_bn num_filter = oc_chunk * oc_bn
......
...@@ -57,7 +57,7 @@ def conv2d_rocm(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou ...@@ -57,7 +57,7 @@ def conv2d_rocm(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou
4-D with shape [batch, out_channel, out_height, out_width] 4-D with shape [batch, out_channel, out_height, out_width]
""" """
target = tvm.target.current_target() target = tvm.target.Target.current()
if "miopen" in target.libs: if "miopen" in target.libs:
assert layout == 'NCHW', "Only NCHW layout is supported." assert layout == 'NCHW', "Only NCHW layout is supported."
CO, CI, KH, KW = get_const_tuple(kernel.shape) CO, CI, KH, KW = get_const_tuple(kernel.shape)
...@@ -106,7 +106,7 @@ def schedule_conv2d_nchw_rocm(cfg, outs): ...@@ -106,7 +106,7 @@ def schedule_conv2d_nchw_rocm(cfg, outs):
s: Schedule s: Schedule
The computation schedule for conv2d. The computation schedule for conv2d.
""" """
target = tvm.target.current_target() target = tvm.target.Target.current()
if target and "miopen" in target.libs: if target and "miopen" in target.libs:
return generic.schedule_extern(outs) return generic.schedule_extern(outs)
......
...@@ -56,7 +56,7 @@ def dense_rocm(cfg, data, weight, bias=None, out_dtype=None): ...@@ -56,7 +56,7 @@ def dense_rocm(cfg, data, weight, bias=None, out_dtype=None):
out_dtype = data.dtype out_dtype = data.dtype
batch, in_dim = data.shape batch, in_dim = data.shape
out_dim, _ = weight.shape out_dim, _ = weight.shape
target = tvm.target.current_target() target = tvm.target.Target.current()
if "rocblas" in target.libs: if "rocblas" in target.libs:
assert out_dtype == data.dtype, "Mixed precision not supported." assert out_dtype == data.dtype, "Mixed precision not supported."
matmul = rocblas.matmul(data, weight, False, True) matmul = rocblas.matmul(data, weight, False, True)
...@@ -83,7 +83,7 @@ def schedule_dense(cfg, outs): ...@@ -83,7 +83,7 @@ def schedule_dense(cfg, outs):
s: Schedule s: Schedule
The computation schedule for dense. The computation schedule for dense.
""" """
target = tvm.target.current_target() target = tvm.target.Target.current()
if target.target_name == "rocm" and "rocblas" in target.libs: if target.target_name == "rocm" and "rocblas" in target.libs:
return generic.schedule_extern(outs) return generic.schedule_extern(outs)
return topi.cuda.schedule_dense(cfg, outs) return topi.cuda.schedule_dense(cfg, outs)
...@@ -23,6 +23,6 @@ from .. import cpp ...@@ -23,6 +23,6 @@ from .. import cpp
@generic.schedule_lrn.register(["rocm", "gpu"]) @generic.schedule_lrn.register(["rocm", "gpu"])
def schedule_lrn(outs): def schedule_lrn(outs):
target = tvm.target.current_target(allow_none=False) target = tvm.target.Target.current(allow_none=False)
cpp_target = cpp.TEST_create_target(target.target_name) cpp_target = cpp.TEST_create_target(target.target_name)
return cpp.rocm.schedule_lrn(cpp_target, outs) return cpp.rocm.schedule_lrn(cpp_target, outs)
...@@ -43,7 +43,7 @@ def _declaration_batch_matmul_nopack(cfg, x, y): ...@@ -43,7 +43,7 @@ def _declaration_batch_matmul_nopack(cfg, x, y):
output : tvm.Tensor output : tvm.Tensor
3-D with shape [batch, M, N] 3-D with shape [batch, M, N]
""" """
target = tvm.target.current_target() target = tvm.target.Target.current()
if "cblas" in target.libs: if "cblas" in target.libs:
return cblas.batch_matmul(x, y, False, True) return cblas.batch_matmul(x, y, False, True)
...@@ -83,7 +83,7 @@ def schedule_batch_matmul(cfg, outs): ...@@ -83,7 +83,7 @@ def schedule_batch_matmul(cfg, outs):
sch: Schedule sch: Schedule
The computation schedule for the op. The computation schedule for the op.
""" """
target = tvm.target.current_target() target = tvm.target.Target.current()
if "cblas" in target.libs: if "cblas" in target.libs:
return generic.schedule_extern(outs) return generic.schedule_extern(outs)
......
...@@ -74,7 +74,7 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout): ...@@ -74,7 +74,7 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout):
kh, kw, oc, _ = kshape kh, kw, oc, _ = kshape
elif pat.match(layout) is not None: elif pat.match(layout) is not None:
n, ic_chunk, h, w, ic_bn = dshape n, ic_chunk, h, w, ic_bn = dshape
target = tvm.target.current_target(allow_none=False) target = tvm.target.Target.current(allow_none=False)
oc_chunk, k_ic_chunk, kh, kw, k_ic_bn, oc_bn = kshape oc_chunk, k_ic_chunk, kh, kw, k_ic_bn, oc_bn = kshape
assert ic_chunk == k_ic_chunk assert ic_chunk == k_ic_chunk
assert ic_bn == k_ic_bn assert ic_bn == k_ic_bn
...@@ -423,7 +423,7 @@ def _schedule_conv2d_NCHWc(cfg, outs): ...@@ -423,7 +423,7 @@ def _schedule_conv2d_NCHWc(cfg, outs):
data = data_pad.op.input_tensors[0] data = data_pad.op.input_tensors[0]
args = [s, cfg, data_vec, conv_out, outs[0]] args = [s, cfg, data_vec, conv_out, outs[0]]
target = tvm.target.current_target(allow_none=False) target = tvm.target.Target.current(allow_none=False)
_, _, kh, kw, _, _, = get_const_tuple(kernel.shape) _, _, kh, kw, _, _, = get_const_tuple(kernel.shape)
if kh == 1 and kw == 1: if kh == 1 and kw == 1:
conv2d_avx_1x1._schedule_conv_NCHWc(*args) conv2d_avx_1x1._schedule_conv_NCHWc(*args)
......
...@@ -75,7 +75,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F): ...@@ -75,7 +75,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
# Set workload. Config update. # Set workload. Config update.
dispatch_ctx = autotvm.task.DispatchContext.current dispatch_ctx = autotvm.task.DispatchContext.current
target = tvm.target.current_target() target = tvm.target.Target.current()
if is_depthwise: if is_depthwise:
workload = autotvm.task.args_to_workload( workload = autotvm.task.args_to_workload(
......
...@@ -64,11 +64,11 @@ def _is_int8_hw_support(data_dtype, kernel_dtype): ...@@ -64,11 +64,11 @@ def _is_int8_hw_support(data_dtype, kernel_dtype):
is_dtype_support = data_dtype == 'uint8' and kernel_dtype == 'int8' is_dtype_support = data_dtype == 'uint8' and kernel_dtype == 'int8'
# 2) Check LLVM support # 2) Check LLVM support
llvm_version = tvm.codegen.llvm_version_major() llvm_version = tvm.target.codegen.llvm_version_major()
is_llvm_support = llvm_version >= 8 is_llvm_support = llvm_version >= 8
# 3) Check target # 3) Check target
mcpu = tvm.target.current_target().mcpu mcpu = tvm.target.Target.current().mcpu
is_target_support = False is_target_support = False
if mcpu in ('skylake-avx512', 'cascadelake'): if mcpu in ('skylake-avx512', 'cascadelake'):
is_target_support = True is_target_support = True
...@@ -89,7 +89,7 @@ def _create_tuning_space_int8(cfg, data, kernel, strides, padding, dilation, lay ...@@ -89,7 +89,7 @@ def _create_tuning_space_int8(cfg, data, kernel, strides, padding, dilation, lay
kh, kw, oc, _ = kshape kh, kw, oc, _ = kshape
elif pat.match(layout) is not None: elif pat.match(layout) is not None:
n, ic_chunk, h, w, ic_bn = dshape n, ic_chunk, h, w, ic_bn = dshape
target = tvm.target.current_target(allow_none=False) target = tvm.target.Target.current(allow_none=False)
oc_chunk, k_ic, kh, kw, k_ic_f, oc_bn, k_ic_s = kshape oc_chunk, k_ic, kh, kw, k_ic_f, oc_bn, k_ic_s = kshape
ic = ic_chunk * ic_bn ic = ic_chunk * ic_bn
assert ic == k_ic * k_ic_f * k_ic_s assert ic == k_ic * k_ic_f * k_ic_s
...@@ -205,7 +205,7 @@ def _schedule_conv2d_NCHWc_int8(cfg, outs): ...@@ -205,7 +205,7 @@ def _schedule_conv2d_NCHWc_int8(cfg, outs):
data = data_pad.op.input_tensors[0] data = data_pad.op.input_tensors[0]
args = [s, cfg, data_vec, conv_out, outs[0]] args = [s, cfg, data_vec, conv_out, outs[0]]
target = tvm.target.current_target(allow_none=False) target = tvm.target.Target.current(allow_none=False)
# int8 conv kernel is 7-dim # int8 conv kernel is 7-dim
_, _, kh, kw, _, _, _ = get_const_tuple(kernel.shape) _, _, kh, kw, _, _, _ = get_const_tuple(kernel.shape)
if kh == 1 and kw == 1: if kh == 1 and kw == 1:
......
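The support predicate above combines all three gates: operand dtypes, LLVM version, and target mcpu. Pulled out as a standalone sketch ("int8_fast_path_ok" is an illustrative name):

    import tvm

    def int8_fast_path_ok(data_dtype, kernel_dtype):
        ok_dtype = data_dtype == "uint8" and kernel_dtype == "int8"
        ok_llvm = tvm.target.codegen.llvm_version_major() >= 8
        mcpu = tvm.target.Target.current().mcpu
        ok_cpu = mcpu in ("skylake-avx512", "cascadelake")
        return ok_dtype and ok_llvm and ok_cpu

    with tvm.target.create("llvm -mcpu=cascadelake"):
        print(int8_fast_path_ok("uint8", "int8"))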
...@@ -28,7 +28,7 @@ from ..util import traverse_inline, get_const_tuple
@autotvm.register_topi_compute(nn.dense, "cpu", "direct")
def _declaration_dense(cfg, data, weight, bias=None, out_dtype=None):
-   target = tvm.target.current_target()
+   target = tvm.target.Target.current()
    if "cblas" in target.libs:
        C = cblas.matmul(data, weight, False, True)
        if bias is not None:
...@@ -119,7 +119,7 @@ def _declaration_dense_nopack(cfg, data, weight, bias=None, out_dtype=None):
@autotvm.register_topi_schedule(generic.schedule_dense, "cpu", "direct")
def _schedule_dense(cfg, outs):
-   target = tvm.target.current_target()
+   target = tvm.target.Target.current()
    if "cblas" in target.libs:
        return generic.schedule_extern(outs)
...@@ -136,7 +136,7 @@ def _schedule_dense(cfg, outs):
@autotvm.register_topi_schedule(generic.schedule_dense, "cpu", "direct_pack")
def _schedule_dense_pack(cfg, outs):
-   target = tvm.target.current_target()
+   target = tvm.target.Target.current()
    if "cblas" in target.libs:
        return generic.schedule_extern(outs)
...@@ -151,7 +151,7 @@ def _schedule_dense_pack(cfg, outs):
@autotvm.register_topi_schedule(generic.schedule_dense, "cpu", "direct_nopack")
def _schedule_dense_nopack(cfg, outs):
-   target = tvm.target.current_target()
+   target = tvm.target.Target.current()
    if "cblas" in target.libs:
        return generic.schedule_extern(outs)
...
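All four dense kernels dispatch on `target.libs`. A hedged usage sketch of how that flag gets set (the target string is illustrative; actually running the extern path needs a TVM build linked against cblas):

    import tvm

    with tvm.target.create("llvm -libs=cblas"):
        target = tvm.target.Target.current()
        # with "cblas" in target.libs, the schedules above defer to
        # generic.schedule_extern instead of the packed x86 schedule
        print("cblas" in target.libs)  # True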
...@@ -17,11 +17,12 @@
"""Core kernel of dot product of 4 Int8 operations"""
#pylint: disable=invalid-name
import tvm
+import tvm.target.codegen

def dot_16x1x16_uint8_int8_int32():
    """Dispatch the most optimized intrin depending on the target"""
-   mcpu = tvm.target.current_target().mcpu
+   mcpu = tvm.target.Target.current().mcpu
    assert mcpu in ("skylake-avx512", "cascadelake"), \
        "An old Intel machine that does not have fast Int8 support."
...@@ -254,7 +255,7 @@ def dot_16x1x16_uint8_int8_int32_cascadelake():
    vec_b = ins[1].vload([0, 0], "int8x64")
    vnni_inst_name = 'llvm.x86.avx512.vpdpbusd.512'
-   llvm_id = tvm.codegen.llvm_lookup_intrinsic_id(vnni_inst_name)
+   llvm_id = tvm.target.codegen.llvm_lookup_intrinsic_id(vnni_inst_name)
    if llvm_id != 0:  # VNNI is available for current LLVM version
        vec_bi32 = tvm.call_pure_intrin('int32x16', 'reinterpret', vec_b)
...
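The intrinsic-id lookup doubles as a feature probe: it returns 0 when the running LLVM does not know the intrinsic. A minimal sketch of that probe on its own (assumes an LLVM-enabled build):

    import tvm
    import tvm.target.codegen

    # a nonzero id means this LLVM build understands the VNNI dot-product intrinsic
    llvm_id = tvm.target.codegen.llvm_lookup_intrinsic_id('llvm.x86.avx512.vpdpbusd.512')
    print("VNNI available" if llvm_id != 0 else "falling back to the non-VNNI path")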
...@@ -19,7 +19,7 @@ from __future__ import absolute_import as _abs
import tvm

def get_fp32_len():
-   mcpu = tvm.target.current_target().mcpu
+   mcpu = tvm.target.Target.current().mcpu
    fp32_vec_len = 8
    if mcpu in ('skylake-avx512', 'cascadelake'):
        fp32_vec_len = 16
...
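The vector width follows register size: AVX-512 CPUs fit 16 fp32 lanes in one 512-bit register, other x86 targets get 8. A hedged check of the dispatch (`-mcpu` is a standard llvm target option):

    import tvm

    with tvm.target.create("llvm -mcpu=skylake-avx512"):
        assert tvm.target.Target.current().mcpu == 'skylake-avx512'
        # get_fp32_len() would return 16 inside this scope, 8 otherwise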
...@@ -84,7 +84,7 @@ def compute_conv2d(attrs, inputs, output_type, target):
            groups,
            out_dtype)]
    # If it's not packed, run on ARM CPU
-   with tvm.target.arm_cpu(tvm.target.current_target().model):
+   with tvm.target.arm_cpu(tvm.target.Target.current().model):
        return _nn.compute_conv2d(attrs, inputs, output_type, target)
    # If VTA is not the target, default to _nn def
...@@ -105,8 +105,8 @@ def schedule_conv2d(attrs, outs, target):
            return topi.generic.schedule_conv2d_nchw(outs)
        return topi.generic.schedule_group_conv2d_nchw(outs)
    # If it's not packed, run on ARM CPU
-   with tvm.target.arm_cpu(tvm.target.current_target().model):
-       return _nn.schedule_conv2d(attrs, outs, tvm.target.current_target())
+   with tvm.target.arm_cpu(tvm.target.Target.current().model):
+       return _nn.schedule_conv2d(attrs, outs, tvm.target.Target.current())
    # If VTA is not the target, default to _nn def
    return _nn.schedule_conv2d(attrs, outs, target)
...@@ -128,7 +128,7 @@ def compute_conv2d_transpose(attrs, inputs, output_type, target):
        return [topi.nn.conv2d_transpose_nchw(
            inputs[0], inputs[1], strides, padding, out_dtype)]
    # If it's not packed, run on ARM CPU
-   with tvm.target.arm_cpu(tvm.target.current_target().model):
+   with tvm.target.arm_cpu(tvm.target.Target.current().model):
        return _nn.compute_conv2d_transpose(attrs, inputs, output_type, target)
    # If VTA is not the target, default to _nn def
...@@ -145,11 +145,11 @@ def schedule_conv2d_transpose(attrs, outputs, target):
    if is_packed_layout(layout):
        return topi.nn.schedule_conv2d_transpose_nchw(outputs)
    # If it's not packed, run on ARM CPU
-   with tvm.target.arm_cpu(tvm.target.current_target().model):
-       return _nn.schedule_conv2d_transpose(attrs, outputs, tvm.target.current_target())
+   with tvm.target.arm_cpu(tvm.target.Target.current().model):
+       return _nn.schedule_conv2d_transpose(attrs, outputs, tvm.target.Target.current())
    # If VTA is not the target, default to _nn def
-   return _nn.schedule_conv2d_transpose(attrs, outputs, tvm.target.current_target())
+   return _nn.schedule_conv2d_transpose(attrs, outputs, tvm.target.Target.current())

@reg.register_compute("nn.dense", level=15)
...@@ -163,7 +163,7 @@ def compute_dense(attrs, inputs, out_type, target):
        target = tvm.target.create(target)
        return [topi.nn.dense(inputs[0], inputs[1], None, out_dtype)]
    # If it's not packed, run on ARM CPU
-   with tvm.target.arm_cpu(tvm.target.current_target().model):
+   with tvm.target.arm_cpu(tvm.target.Target.current().model):
        return _nn.compute_dense(attrs, inputs, out_type, target)
    # If VTA is not the target, default to _nn def
...@@ -179,8 +179,8 @@ def schedule_dense(attrs, outs, target):
        assert target.device_name == "vta"
        return topi.generic.schedule_dense(outs)
    # If it's not packed, run on ARM CPU
-   with tvm.target.arm_cpu(tvm.target.current_target().model):
-       return _nn.schedule_dense(attrs, outs, tvm.target.current_target())
+   with tvm.target.arm_cpu(tvm.target.Target.current().model):
+       return _nn.schedule_dense(attrs, outs, tvm.target.Target.current())
    # If VTA is not the target, default to _nn def
    return _nn.schedule_dense(attrs, outs, target)
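Every VTA operator in this file follows the same fallback shape: a packed layout goes to the VTA schedule, anything else re-enters an ARM CPU target derived from the current target's model string. A hedged distillation of that pattern (`fallback_on_arm_cpu` and `schedule_fn` are illustrative names standing in for the `_nn` fallbacks above):

    import tvm

    def fallback_on_arm_cpu(schedule_fn, attrs, outs):
        # rebuild an ARM CPU target from the model of the enclosing target,
        # then schedule under that scope instead of the VTA one
        with tvm.target.arm_cpu(tvm.target.Target.current().model):
            return schedule_fn(attrs, outs, tvm.target.Target.current())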
...@@ -80,7 +80,7 @@ def conv2d(N, CI, H, W, CO, KH, KW, strides, padding, dilation):
    res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 1)
    res = topi.cast(res, env.out_dtype)
-   if tvm.target.current_target().device_name == 'vta':
+   if tvm.target.Target.current().device_name == 'vta':
        s = topi.generic.schedule_conv2d_nchw([res])
    else:
        s = tvm.create_schedule([res.op])
...
...@@ -68,7 +68,7 @@ def conv2d_transpose(N, CI, H, W, CO, KH, KW, strides, padding):
    res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 1)
    res = topi.cast(res, env.out_dtype)
-   if tvm.target.current_target().device_name == 'vta':
+   if tvm.target.Target.current().device_name == 'vta':
        s = topi.generic.schedule_conv2d_transpose_nchw([res])
    else:
        s = tvm.create_schedule([res.op])
...
...@@ -59,7 +59,7 @@ def dense(N, CI, CO):
    res = my_clip(res, 0, 127)
    res = topi.cast(res, "int8")
-   if tvm.target.current_target().device_name == 'vta':
+   if tvm.target.Target.current().device_name == 'vta':
        s = topi.generic.schedule_dense([res])
    else:
        s = tvm.create_schedule([res.op])
...
...@@ -80,7 +80,7 @@ def group_conv2d(N, CI, H, W, CO, KH, KW, strides, padding, dilation, group):
    res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 1)
    res = topi.cast(res, env.out_dtype)
-   if tvm.target.current_target().device_name == 'vta':
+   if tvm.target.Target.current().device_name == 'vta':
        s = topi.generic.schedule_group_conv2d_nchw([res])
    else:
        s = tvm.create_schedule([res.op])
...
...@@ -84,7 +84,7 @@ def register_vta_tuning_tasks():
    res = my_clip(res, 0, 127)
    res = topi.cast(res, "int8")
-   if tvm.target.current_target().device_name == 'vta':
+   if tvm.target.Target.current().device_name == 'vta':
        s = topi.generic.schedule_conv2d_nchw([res])
    else:
        s = tvm.create_schedule([res.op])
...@@ -102,7 +102,7 @@ def register_vta_tuning_tasks():
    res = my_clip(res, 0, 127)
    res = topi.cast(res, "int8")
-   if tvm.target.current_target().device_name == 'vta':
+   if tvm.target.Target.current().device_name == 'vta':
        s = topi.generic.schedule_dense([res])
    else:
        s = tvm.create_schedule([res.op])
...
...@@ -321,7 +321,7 @@ def register_vta_tuning_tasks():
    res = my_clip(res, 0, 127)
    res = topi.cast(res, "int8")
-   if tvm.target.current_target().device_name == 'vta':
+   if tvm.target.Target.current().device_name == 'vta':
        s = topi.generic.schedule_conv2d_nchw([res])
    else:
        s = tvm.create_schedule([res.op])
...
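Across all of these files the mechanical change is the same, matching the API change called out in the commit message. A hedged before/after sketch (assumes an LLVM-enabled build):

    import tvm
    import tvm.target.codegen

    with tvm.target.create("llvm"):
        # old API (removed by this refactor):
        #   target = tvm.target.current_target()
        #   major  = tvm.codegen.llvm_version_major()
        # new API:
        target = tvm.target.Target.current()
        major = tvm.target.codegen.llvm_version_major()
        print(target, major)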