Unverified commit 9816efc2 by Tianqi Chen, committed by GitHub

[REFACTOR][PY][API-CHANGE] Remove legacy python files. (#4943)

* [REFACTOR][PY][API-CHANGE] Remove legacy python files.

Remove the legacy python files.
Use the tvm.te namespace for most of the tensor expression primitives, for example:

- tvm.create_schedule -> tvm.te.create_schedule
- tvm.placeholder -> tvm.te.placeholder
- tvm.compute -> tvm.te.compute

* Remove top-level exposures.
parent c9be16bd
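A minimal sketch of the new spelling, mirroring the vector-add pattern that recurs in the hunks below; the addone name and the llvm target are illustrative, not taken from this diff:

import tvm
from tvm import te

# tvm.var / tvm.placeholder / tvm.compute / tvm.create_schedule are removed;
# the tensor expression primitives now live under tvm.te:
n = te.var("n")
A = te.placeholder((n,), name="A")
B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name="B")
s = te.create_schedule(B.op)
fadd = tvm.build(s, [A, B], "llvm", name="addone")  # tvm.build itself is unchanged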
......@@ -22,6 +22,7 @@ Use "android" as the key if you wish to avoid modifying this script.
"""
import tvm
from tvm import te
import os
from tvm import rpc
from tvm.contrib import util, ndk
......@@ -44,9 +45,9 @@ test_vulkan = False
def test_rpc_module():
# graph
n = tvm.convert(1024)
A = tvm.placeholder((n,), name='A')
B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
n = tvm.runtime.convert(1024)
A = te.placeholder((n,), name='A')
B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
a_np = np.random.uniform(size=1024).astype(A.dtype)
temp = util.tempdir()
......@@ -56,7 +57,7 @@ def test_rpc_module():
session_timeout=60)
# Compile the Graph for CPU target
s = tvm.create_schedule(B.op)
s = te.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=64)
s[B].parallel(xi)
s[B].pragma(xo, "parallel_launch_point")
......@@ -79,10 +80,10 @@ def test_rpc_module():
# Compile the Graph for OpenCL target
if test_opencl:
s = tvm.create_schedule(B.op)
s = te.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=64)
s[B].bind(xi, tvm.thread_axis("threadIdx.x"))
s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
s[B].bind(xi, te.thread_axis("threadIdx.x"))
s[B].bind(xo, te.thread_axis("blockIdx.x"))
# Build the dynamic lib.
# If we don't want to do metal and only use cpu, just set target to be target
f = tvm.build(s, [A, B], "opencl", target_host=target, name="myadd")
......@@ -102,10 +103,10 @@ def test_rpc_module():
# Compile the Graph for Vulkan target
if test_vulkan:
s = tvm.create_schedule(B.op)
s = te.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=64)
s[B].bind(xi, tvm.thread_axis("threadIdx.x"))
s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
s[B].bind(xi, te.thread_axis("threadIdx.x"))
s[B].bind(xo, te.thread_axis("blockIdx.x"))
# Build the dynamic lib.
# If we don't want to do metal and only use cpu, just set target to be target
f = tvm.build(s, [A, B], "vulkan", target_host=target, name="myadd")
......
......@@ -22,6 +22,7 @@ import argparse
import numpy as np
import tvm
from tvm import te
from tvm.contrib.util import tempdir
import tvm.contrib.graph_runtime as runtime
from tvm import relay
......
......@@ -23,6 +23,7 @@ import threading
import numpy as np
import tvm
from tvm import te
import tvm.contrib.graph_runtime as runtime
from tvm import relay
......
......@@ -22,6 +22,7 @@ import argparse
import numpy as np
import tvm
from tvm import te
from tvm.contrib.util import tempdir
import tvm.contrib.graph_runtime as runtime
from tvm import relay
......
......@@ -20,6 +20,7 @@ import argparse
import os
from tvm import relay
import tvm
from tvm import te
import logging
......
......@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te
import os
def test_plugin_module():
......
......@@ -21,6 +21,7 @@ import os
import ctypes
# Import TVM first to get library symbols
import tvm
from tvm import te
def load_lib():
"""Load library, the functions will be registered into TVM"""
......
......@@ -16,6 +16,8 @@
# under the License.
import tvm_ext
import tvm
import tvm._ffi.registry
from tvm import te
import numpy as np
def test_bind_add():
......@@ -26,9 +28,9 @@ def test_bind_add():
def test_ext_dev():
n = 10
A = tvm.placeholder((n,), name='A')
B = tvm.compute((n,), lambda *i: A(*i) + 1.0, name='B')
s = tvm.create_schedule(B.op)
A = te.placeholder((n,), name='A')
B = te.compute((n,), lambda *i: A(*i) + 1.0, name='B')
s = te.create_schedule(B.op)
def check_llvm():
if not tvm.runtime.enabled("llvm"):
return
......@@ -43,8 +45,8 @@ def test_ext_dev():
def test_sym_add():
a = tvm.var('a')
b = tvm.var('b')
a = te.var('a')
b = te.var('b')
c = tvm_ext.sym_add(a, b)
assert c.a == a and c.b == b
......@@ -59,19 +61,20 @@ def test_ext_vec():
assert(isinstance(v2, tvm_ext.IntVec))
assert v2[2] == 3
tvm.convert(ivec_cb)(ivec)
tvm.runtime.convert(ivec_cb)(ivec)
def test_extract_ext():
fdict = tvm.extract_ext_funcs(tvm_ext._LIB.TVMExtDeclare)
fdict = tvm._ffi.registry.extract_ext_funcs(
tvm_ext._LIB.TVMExtDeclare)
assert fdict["mul"](3, 4) == 12
def test_extern_call():
n = 10
A = tvm.placeholder((n,), name='A')
B = tvm.compute((n,), lambda *i: tvm.call_extern("float32", "TVMTestAddOne", A(*i)), name='B')
s = tvm.create_schedule(B.op)
A = te.placeholder((n,), name='A')
B = te.compute((n,), lambda *i: tvm.tir.call_extern("float32", "TVMTestAddOne", A(*i)), name='B')
s = te.create_schedule(B.op)
def check_llvm():
if not tvm.runtime.enabled("llvm"):
......
......@@ -16,13 +16,14 @@
# under the License.
"""Script to prepare test_addone.so"""
import tvm
from tvm import te
import os
def prepare_test_libs(base_path):
n = tvm.var("n")
A = tvm.placeholder((n,), name='A')
B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
s = tvm.create_schedule(B.op)
n = te.var("n")
A = te.placeholder((n,), name='A')
B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
s = te.create_schedule(B.op)
# Compile library as dynamic library
fadd_dylib = tvm.build(s, [A, B], "llvm", name="addone")
dylib_path = os.path.join(base_path, "test_addone_dll.so")
......
......@@ -19,6 +19,7 @@
# file python_deploy.py
import tvm
from tvm import te
import numpy as np
def verify(mod, fname):
......
......@@ -21,6 +21,7 @@ And configure the proxy host field as commented.
"""
import tvm
from tvm import te
import os
import re
import sys
......@@ -54,14 +55,14 @@ def compile_metal(src):
def test_rpc_module():
# graph
n = tvm.convert(1024)
A = tvm.placeholder((n,), name='A')
B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
n = tvm.runtime.convert(1024)
A = te.placeholder((n,), name='A')
B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
temp = util.tempdir()
s = tvm.create_schedule(B.op)
s = te.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=64)
s[B].bind(xi, tvm.thread_axis("threadIdx.x"))
s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
s[B].bind(xi, te.thread_axis("threadIdx.x"))
s[B].bind(xo, te.thread_axis("blockIdx.x"))
# Build the dynamic lib.
# If we don't want to do metal and only use cpu, just set target to be target
f = tvm.build(s, [A, B], "metal", target_host=target, name="myadd")
......@@ -70,7 +71,7 @@ def test_rpc_module():
arch=arch, sdk=sdk)
xcode.codesign(path_dso1)
s = tvm.create_schedule(B.op)
s = te.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=64)
s[B].parallel(xi)
s[B].pragma(xo, "parallel_launch_point")
......
......@@ -23,6 +23,7 @@ from os import path as osp
from tvm import relay
from tvm.relay import testing
import tvm
from tvm import te
def main():
......
......@@ -17,6 +17,7 @@
import os.path as osp
import numpy as np
import tvm
from tvm import te
CWD = osp.abspath(osp.dirname(__file__))
......
......@@ -23,6 +23,7 @@ tvm.te
:members:
:imported-members:
:exclude-members:
any, all, min_value, max_value, trace,
exp, erf, tanh, sigmoid, log, cos, sin, atan, sqrt, rsqrt, floor, ceil,
trunc, abs, round, nearbyint, isnan, power, popcount, fmod, if_then_else,
div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod,
......
......@@ -20,5 +20,5 @@ tvm.tir
.. automodule:: tvm.tir
:members:
:imported-members:
:exclude-members: PrimExpr
:exclude-members: PrimExpr, const
:autosummary:
......@@ -61,6 +61,7 @@ source_parsers = {
os.environ['TVM_BUILD_DOC'] = '1'
# Version information.
import tvm
from tvm import te
version = tvm.__version__
release = tvm.__version__
......
......@@ -21,6 +21,7 @@ Get Started with TVM Go
from __future__ import absolute_import, print_function
import tvm
from tvm import te
import numpy as np
# Global declarations of environment.
......@@ -31,15 +32,15 @@ tgt="llvm"
######################################################################
# Describe the Computation
# ------------------------
n = tvm.var("n")
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda i: A[i] + B[i], name="C")
n = te.var("n")
A = te.placeholder((n,), name='A')
B = te.placeholder((n,), name='B')
C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")
######################################################################
# Schedule the Computation
# ------------------------
s = tvm.create_schedule(C.op)
s = te.create_schedule(C.op)
######################################################################
# Compilation
......
......@@ -17,14 +17,15 @@
import os
import tvm
from tvm import te
from tvm.contrib import cc, util
def test_add(target_dir):
n = tvm.var("n")
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda i: A[i] + B[i], name="C")
s = tvm.create_schedule(C.op)
n = te.var("n")
A = te.placeholder((n,), name='A')
B = te.placeholder((n,), name='B')
C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")
s = te.create_schedule(C.op)
fadd = tvm.build(s, [A, B, C], "llvm", target_host="llvm", name="myadd")
fadd.save(os.path.join(target_dir, "add_cpu.o"))
......
......@@ -17,22 +17,23 @@
import os
import tvm
from tvm import te
from tvm.contrib import cc, util
def test_add(target_dir):
if not tvm.runtime.enabled("cuda"):
print("skip %s because cuda is not enabled..." % __file__)
return
n = tvm.var("n")
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda i: A[i] + B[i], name="C")
n = te.var("n")
A = te.placeholder((n,), name='A')
B = te.placeholder((n,), name='B')
C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")
s = tvm.create_schedule(C.op)
s = te.create_schedule(C.op)
bx, tx = s[C].split(C.op.axis[0], factor=64)
s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
s[C].bind(bx, te.thread_axis("blockIdx.x"))
s[C].bind(tx, te.thread_axis("threadIdx.x"))
fadd_cuda = tvm.build(s, [A, B, C], "cuda", target_host="llvm", name="myadd")
fadd_cuda.save(os.path.join(target_dir, "add_gpu.o"))
......
......@@ -17,14 +17,15 @@
import os
import tvm
from tvm import te
import json
from tvm.contrib import graph_runtime
def dump_graph_lib(target_dir):
dim = 4
A = tvm.placeholder((dim,), name='A')
B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
sched = tvm.create_schedule(B.op)
A = te.placeholder((dim,), name='A')
B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
sched = te.create_schedule(B.op)
node0 = {"op": "null", "name": "x", "inputs": []}
node1 = {"op": "tvm_op", "name": "add",
......
......@@ -24,7 +24,7 @@ import traceback
# tvm._ffi
from ._ffi.base import TVMError, __version__
from ._ffi.runtime_ctypes import TypeCode, DataType
from ._ffi.registry import register_object, register_func, register_extension
from ._ffi import register_object, register_func, register_extension, get_global_func
# top-level alias
# tvm.runtime
......@@ -47,10 +47,9 @@ from . import tir
# tvm.target
from . import target
from .target import build_config
# tvm.te
from .te import decl_tensor_intrin, create_schedule, tag_scope
from . import te
# tvm.testing
from . import testing
......@@ -64,14 +63,6 @@ from . import hybrid
# others
from . import arith
# backward compact for topi, to be removed later
from .api import *
from .tir import expr, stmt, ir_builder, ir_pass, generic
from .te import tensor, schedule
from .tir.op import *
from . import intrin
from . import make
# Contrib initializers
from .contrib import rocm as _rocm, nvcc as _nvcc, sdaccel as _sdaccel
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Functions defined in TVM."""
# pylint: disable=invalid-name,unused-import,redefined-builtin
import tvm._ffi
import tvm.ir
import tvm.tir
from tvm.runtime import convert, const, DataType
from tvm.ir import container as _container, Range
from tvm.tir import decl_buffer, layout, bijective_layout
from tvm.tir import min_value, max_value, indexdiv, indexmod, all, any
from tvm.te import placeholder, compute, scan, extern, var, size_var, thread_axis, reduce_axis
from ._ffi.base import string_types, TVMError
from ._ffi.registry import register_func, get_global_func, extract_ext_funcs
from . import make as _make
int8 = "int8"
int32 = "int32"
float32 = "float32"
handle = "handle"
......@@ -212,7 +212,7 @@ class Analyzer:
--------
.. code-block:: python
x = tvm.var("x")
x = te.var("x")
analyzer = tvm.arith.Analyzer()
with analyzer.constraint_scope(x % 3 == 0):
# constraint in effect
......
......@@ -28,8 +28,11 @@ There are two types of feature
import struct
import numpy as np
import tvm._ffi
from tvm import schedule, ir_pass, get_global_func, target as _target
from tvm import target as _target
from tvm.tir import ir_pass
from tvm.te import schedule
from tvm.driver import build_module
def ana_lower(sch, args,
......@@ -49,10 +52,12 @@ def ana_lower(sch, args,
return stmt
try:
_get_buffer_curve_sample_flatten = get_global_func(
_get_buffer_curve_sample_flatten = tvm._ffi.get_global_func(
"autotvm.feature.GetCurveSampleFeatureFlatten")
_get_itervar_feature = get_global_func("autotvm.feature.GetItervarFeature")
_get_itervar_feature_flatten = get_global_func("autotvm.feature.GetItervarFeatureFlatten")
_get_itervar_feature = tvm._ffi.get_global_func(
"autotvm.feature.GetItervarFeature")
_get_itervar_feature_flatten = tvm._ffi.get_global_func(
"autotvm.feature.GetItervarFeatureFlatten")
except ValueError as e:
def raise_error(*args, **kwargs): # pylint: disable=unused-argument
raise RuntimeError("Cannot load autotvm c++ API")
......@@ -64,8 +69,8 @@ def get_itervar_feature(sch, args, take_log=False):
Parameters
----------
sch: tvm.schedule.Schedule
args: Array of tvm.tensor.Tensor
sch: tvm.te.schedule.Schedule
args: Array of te.tensor.Tensor
the buffer args for lower
take_log: bool
whether to take the log of numerical statistics
......@@ -112,8 +117,8 @@ def get_itervar_feature_flatten(sch, args, take_log=True):
Parameters
----------
sch: tvm.schedule.Schedule
args: Array of tvm.tensor.Tensor
sch: tvm.te.schedule.Schedule
args: Array of te.tensor.Tensor
the buffer args for lower
take_log: bool
whether to take the log of numerical statistics
......@@ -185,8 +190,8 @@ def get_buffer_curve_sample_flatten(sch, args, sample_n=30):
Parameters
----------
sch: tvm.schedule.Schedule
args: Array of tvm.tensor.Tensor
sch: tvm.te.schedule.Schedule
args: Array of te.tensor.Tensor
the buffer args for lower
sample_n: int
number of sample points along one dimension
......
......@@ -23,6 +23,7 @@ import numpy as np
import topi
import tvm
from tvm import te
from tvm import autotvm, relay
from tvm.autotvm.task import get_config
from tvm.autotvm.record import encode, load_from_file
......@@ -301,8 +302,8 @@ class BaseGraphTuner(object):
_, out_layout = o_input_info[0]
else:
_, out_layout = o_output_info[0]
data_placeholder = tvm.placeholder(in_shape, name="data",
dtype=self._dtype)
data_placeholder = te.placeholder(in_shape, name="data",
dtype=self._dtype)
args = [data_placeholder, in_layout, out_layout]
callback(i_idx, o_idx, m, n, args)
......
......@@ -33,9 +33,13 @@ import tempfile
import numpy as np
from ... import ir_pass, build, build_config, nd, TVMError, register_func, \
rpc as _rpc, target as _target
from ...contrib import nvcc, ndk, tar
import tvm._ffi
from tvm import nd, rpc as _rpc, target as _target
from tvm.tir import ir_pass
from tvm.error import TVMError
from tvm.target import build_config
from tvm.driver import build
from tvm.contrib import nvcc, ndk, tar
from ..util import get_const_tuple
from ..env import AutotvmGlobalScope
......@@ -581,7 +585,7 @@ def check_remote(target, device_key, host=None, port=None, priority=100, timeout
return not t.is_alive()
@register_func
@tvm._ffi.register_func
def tvm_callback_cuda_compile(code):
"""use nvcc to generate ptx code for better optimization"""
curr_cuda_target_arch = AutotvmGlobalScope.current.cuda_target_arch
......
......@@ -22,7 +22,7 @@ code hashing is used to check the consistence of schedule code and the parameter
import inspect
import zlib
from tvm import schedule
from tvm.te import schedule
def attach_code_hash(s):
"""Decorator for attaching a code hash to a schedule
......@@ -30,7 +30,7 @@ def attach_code_hash(s):
Parameters
----------
s: Schedule
tvm.schedule.Schedule to attach the hash to
tvm.te.schedule.Schedule to attach the hash to
"""
def decorator(func):
def wrapper(*args, **kwargs):
......
......@@ -32,7 +32,7 @@ import math
from collections import namedtuple, OrderedDict
import numpy as np
from tvm import schedule, thread_axis
from tvm.te import schedule, thread_axis
from tvm.autotvm.util import get_const_int
Axis = namedtuple('Axis', ['space', 'index'])
......@@ -57,7 +57,7 @@ class TransformSpace(object):
.. note::
We can regard our schedule code as a transformation graph of axes.
Starting from raw axes in the definition of tvm.compute, we can transform these axes
Starting from raw axes in the definition of te.compute, we can transform these axes
by some operators. The operator includes 'split', 'reorder' and 'annotate'.
Each operator has some tunable parameters (e.g. the split factor).
Then the tuning process is just to find good parameters of these op.
......@@ -106,7 +106,7 @@ class VirtualAxis(TransformSpace):
Parameters
----------
var: int or tvm.schedule.IterVar
var: int or tvm.te.schedule.IterVar
If is int, return a virtual axis whose length is the provided argument.
If is IterVar, return a virtual axis whose length is extracted from
the IterVar's extent domain.
......@@ -266,11 +266,11 @@ class SplitEntity(object):
Parameters
----------
sch: tvm.schedule.Schedule
sch: tvm.te.schedule.Schedule
The tvm schedule
op: tvm.tensor.Operation
op: tvm.te.Operation
The stage to be applied
axis: tvm.schedule.IterVar
axis: tvm.te.schedule.IterVar
axis to split
Returns
......@@ -390,11 +390,11 @@ class ReorderEntity(object):
Parameters
----------
sch: tvm.schedule.Schedule
sch: tvm.te.schedule.Schedule
The tvm schedule
op: tvm.tensor.Operation
op: tvm.te.Operation
The stage to be applied
axis: tvm.schedule.IterVar
axis: tvm.te.schedule.IterVar
axis to split
Returns
......@@ -513,11 +513,11 @@ class AnnotateEntity(object):
Parameters
----------
sch: tvm.schedule.Schedule
sch: tvm.te.schedule.Schedule
The tvm schedule
op: tvm.tensor.Operation
op: tvm.te.Operation
The stage to be applied
axes: Array of tvm.schedule.IterVar
axes: Array of tvm.te.schedule.IterVar
axis to split
axis_lens: Array of int, optional
the length of axes
......@@ -532,7 +532,7 @@ class AnnotateEntity(object):
Returns
-------
axes : list of tvm.schedule.IterVar
axes : list of tvm.te.schedule.IterVar
The transformed axes
"""
if source is not None: # special case : attach cache_read/cache_write
......@@ -624,7 +624,7 @@ class ConfigSpace(object):
Parameters
----------
var: int or tvm.schedule.IterVar
var: int or tvm.te.schedule.IterVar
If is int, return an axis whose length is the provided argument.
If is IterVar, return an axis whose length is extracted from the
IterVar's extent domain.
......@@ -640,7 +640,7 @@ class ConfigSpace(object):
----------
name: str
name to index the entity of this space
axis: tvm.schedule.IterVar
axis: tvm.te.schedule.IterVar
axis to split
policy: str
name of policy.
......@@ -681,7 +681,7 @@ class ConfigSpace(object):
----------
name: str
name to index the entity of this space
axes: Array of tvm.schedule.IterVar
axes: Array of tvm.te.schedule.IterVar
axes to reorder
policy: str
name of policy
......@@ -702,7 +702,7 @@ class ConfigSpace(object):
----------
name: str
name to index the entity of this space
axes: Array of tvm.schedule.IterVar
axes: Array of tvm.te.schedule.IterVar
axes to annotate
policy: str
name of policy
......
......@@ -21,10 +21,13 @@ Task can be constructed from tuple of func, args, and kwargs.
func is a state-less function, or a string that
registers the standard task.
"""
import numpy as np
from ... import tensor, expr, container, placeholder, target as _target
from tvm import target as _target
from tvm.ir import container
from tvm.tir import expr
from tvm.te import tensor, placeholder
from ..util import get_const_int, get_const_tuple
from .dispatcher import DispatchContext, ApplyConfig
......@@ -81,7 +84,7 @@ def deserialize_args(args):
def args_to_workload(args, task_name=None):
"""Convert argument list to hashable workload tuple.
This function will convert list to tuple, tvm node to python value and
flatten tvm.tensor.Tensor to a tuple
flatten te.tensor.Tensor to a tuple
Parameters
----------
......@@ -138,9 +141,9 @@ class Task(object):
Returns
-------
sch: tvm.schedule.Schedule
sch: tvm.te.schedule.Schedule
The tvm schedule
arg_bufs: Array of tvm.tensor.Tensor
arg_bufs: Array of te.tensor.Tensor
The input/output buffers
"""
config.flop = 0
......@@ -303,12 +306,12 @@ def register_customized_task(name, func=None):
@autotvm.register_customized_task("matmul")
def matmul(N, L, M, dtype):
A = tvm.placeholder((N, L), name='A', dtype=dtype)
B = tvm.placeholder((L, M), name='B', dtype=dtype)
A = te.placeholder((N, L), name='A', dtype=dtype)
B = te.placeholder((L, M), name='B', dtype=dtype)
k = tvm.reduce_axis((0, L), name='k')
C = tvm.compute((N, M), lambda i, j: tvm.sum(A[i, k] * B[k, j], axis=k), name='C')
s = tvm.create_schedule(C.op)
k = te.reduce_axis((0, L), name='k')
C = te.compute((N, M), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name='C')
s = te.create_schedule(C.op)
# schedule
y, x = s[C].op.axis
......@@ -400,7 +403,7 @@ def compute_flop(sch):
Parameters
----------
sch: tvm.schedule.Schedule
sch: tvm.te.schedule.Schedule
schedule
Returns
......@@ -475,8 +478,8 @@ def compute_flop(sch):
elif isinstance(op, tensor.PlaceholderOp):
pass
else:
raise FlopCalculationError("Only support tvm.compute currently. "
"Other ops like tvm.scan/tvm.extern is not supported")
raise FlopCalculationError("Only support te.compute currently. "
"Other ops like tvm.te.scan/te.extern is not supported")
return ret
try:
......
......@@ -21,15 +21,15 @@ Decorators for registering tunable templates to TOPI.
These decorators can make your simple implementation be able to use different configurations
for different workloads.
Here we directly use all arguments to the TOPI call as "workload", so make sure all the arguments
(except tvm.Tensor) in you calls are hashable. For tvm.Tensor, we will serialize it to a hashable
tuple.
(except tvm.te.Tensor) in your calls are hashable. For tvm.te.Tensor,
we will serialize it to a hashable tuple.
See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
"""
import tvm.te._ffi_api
from tvm import target as _target
from tvm.te import tensor
from ... import tensor
from .task import args_to_workload, DispatchContext, \
register_task_compute, register_task_schedule, serialize_args
......
......@@ -24,7 +24,7 @@ from random import randrange
import numpy as np
from .. import expr, ir_pass
from tvm.tir import expr, ir_pass
logger = logging.getLogger('autotvm')
......
......@@ -18,8 +18,9 @@
"""Utilities for binary file manipulation"""
import os
import subprocess
import tvm._ffi
from . import util
from ..api import register_func
RELOCATION_LD_SCRIPT_TEMPLATE = """
/* linker symbol for use in UTVMInit */
......@@ -95,7 +96,7 @@ def run_cmd(cmd):
return output
@register_func("tvm_callback_get_section_size")
@tvm._ffi.register_func("tvm_callback_get_section_size")
def tvm_callback_get_section_size(binary_path, section_name, toolchain_prefix):
"""Finds size of the section in the binary.
Assumes `size` shell command exists (typically works only on Linux machines)
......@@ -162,7 +163,7 @@ def tvm_callback_get_section_size(binary_path, section_name, toolchain_prefix):
return section_size
@register_func("tvm_callback_relocate_binary")
@tvm._ffi.register_func("tvm_callback_relocate_binary")
def tvm_callback_relocate_binary(
binary_path,
word_size,
......@@ -233,7 +234,7 @@ def tvm_callback_relocate_binary(
return rel_bin
@register_func("tvm_callback_read_binary_section")
@tvm._ffi.register_func("tvm_callback_read_binary_section")
def tvm_callback_read_binary_section(binary, section, toolchain_prefix):
"""Returns the contents of the specified section in the binary byte array
......@@ -273,7 +274,7 @@ def tvm_callback_read_binary_section(binary, section, toolchain_prefix):
return section_bin
@register_func("tvm_callback_get_symbol_map")
@tvm._ffi.register_func("tvm_callback_get_symbol_map")
def tvm_callback_get_symbol_map(binary, toolchain_prefix):
"""Obtains a map of symbols to addresses in the passed binary
......
......@@ -16,7 +16,7 @@
# under the License.
"""External function interface to BLAS libraries."""
import tvm
from .. import api as _api
from tvm import te
def matmul(lhs, rhs, transa=False, transb=False, **kwargs):
......@@ -41,7 +41,7 @@ def matmul(lhs, rhs, transa=False, transb=False, **kwargs):
"""
n = lhs.shape[1] if transa else lhs.shape[0]
m = rhs.shape[0] if transb else rhs.shape[1]
return _api.extern(
return te.extern(
(n, m),
[lhs, rhs],
lambda ins, outs: tvm.tir.call_packed(
......@@ -75,7 +75,7 @@ def batch_matmul(lhs, rhs, transa=False, transb=False, iterative=False, **kwargs
b = lhs.shape[0]
n = lhs.shape[2] if transa else lhs.shape[1]
m = rhs.shape[1] if transb else rhs.shape[2]
return _api.extern(
return te.extern(
(b, n, m),
[lhs, rhs],
lambda ins, outs: tvm.tir.call_packed(
......
......@@ -16,7 +16,8 @@
# under the License.
"""External function interface to cuBLAS libraries."""
import tvm
from .. import api as _api
from tvm import te
def matmul(lhs, rhs, transa=False, transb=False, dtype=None):
"""Create an extern op that compute matrix mult of A and rhs with cuBLAS
......@@ -40,7 +41,7 @@ def matmul(lhs, rhs, transa=False, transb=False, dtype=None):
n = lhs.shape[1] if transa else lhs.shape[0]
m = rhs.shape[0] if transb else rhs.shape[1]
dtype = dtype if dtype is not None else lhs.dtype
return _api.extern(
return te.extern(
(n, m), [lhs, rhs],
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.cublas.matmul",
......@@ -69,7 +70,7 @@ def batch_matmul(lhs, rhs, transa=False, transb=False, dtype=None):
n = lhs.shape[2] if transa else lhs.shape[1]
m = rhs.shape[1] if transb else rhs.shape[2]
dtype = dtype if dtype is not None else lhs.dtype
return _api.extern(
return te.extern(
(b, n, m), [lhs, rhs],
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.cublas.batch_matmul",
......
......@@ -16,7 +16,7 @@
# under the License.
"""External function interface to cuBLASlt libraries."""
import tvm
from .. import api as _api
from tvm import te
def matmul(lhs, rhs, transa=False, transb=False, n=0, m=0, dtype=None):
......@@ -43,7 +43,7 @@ def matmul(lhs, rhs, transa=False, transb=False, n=0, m=0, dtype=None):
if m == 0:
m = rhs.shape[0] if transb else rhs.shape[1]
dtype = dtype if dtype is not None else lhs.dtype
return _api.extern(
return te.extern(
(n, m), [lhs, rhs],
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.cublaslt.matmul",
......
......@@ -19,8 +19,9 @@
import ctypes
import numpy as np
import tvm
from .. import api as _api
from .. import get_global_func as _get_global_func
import tvm._ffi
from tvm import te
# algos can be read from cudnn.h
_FWD_ALGOS = [
......@@ -217,7 +218,7 @@ def conv_output_shape(tensor_format,
_prepare_global_func_params(dims - 2, pad, stride, dilation, x_shape, w_shape)
oshape = np.zeros((dims), dtype=np.int32)
func = _get_global_func("tvm.contrib.cudnn.conv.output_shape")
func = tvm._ffi.get_global_func("tvm.contrib.cudnn.conv.output_shape")
func(tensor_format,
dims - 2,
_get_np_int32_array_handle(pad),
......@@ -276,7 +277,7 @@ def conv_find_algo(tensor_format,
pad, stride, dilation, xshape, wshape = \
_prepare_global_func_params(dims - 2, pad, stride, dilation, x_shape, w_shape)
yshape = np.array(y_shape, dtype=np.int32)
func = _get_global_func("tvm.contrib.cudnn.conv.find_algo")
func = tvm._ffi.get_global_func("tvm.contrib.cudnn.conv.find_algo")
return func(tensor_format,
dims - 2,
_get_np_int32_array_handle(pad),
......@@ -363,7 +364,7 @@ def conv_forward(x,
conv_dtype)
if dims == 4:
return _api.extern(
return te.extern(
oshape, [x, w],
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.cudnn.conv2d.forward",
......@@ -381,7 +382,7 @@ def conv_forward(x,
outs[0],
conv_dtype), name="y")
return _api.extern(
return te.extern(
oshape, [x, w],
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.cudnn.conv3d.forward",
......
......@@ -21,6 +21,7 @@ import os
import numpy as np
import tvm
GRAPH_DUMP_FILE_NAME = '_tvmdbg_graph_dump.json'
CHROME_TRACE_FILE_NAME = "_tvmdbg_execution_trace.json"
......
......@@ -19,8 +19,9 @@
import ctypes
import numpy as np
import tvm
from .. import api as _api
from .. import get_global_func as _get_global_func
import tvm._ffi
from tvm import te
def _get_np_int32_array_handle(arr):
......@@ -91,7 +92,7 @@ def conv2d_forward(x,
oshape = np.zeros((len(x.shape)), dtype=np.int32)
xshape = x.shape
wshape = w.shape
setup_func = _get_global_func("tvm.contrib.miopen.conv2d.setup")
setup_func = tvm._ffi.get_global_func("tvm.contrib.miopen.conv2d.setup")
algo = setup_func(conv_mode,
data_type,
pad_h,
......@@ -111,7 +112,7 @@ def conv2d_forward(x,
group_count,
_get_np_int32_array_handle(oshape))
return _api.extern(
return te.extern(
list(oshape), [x, w],
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.miopen.conv2d.forward",
......
......@@ -16,7 +16,8 @@
# under the License.
"""External function interface to MPS libraries."""
import tvm
from .. import api as _api
from tvm import te
# pylint: disable=C0103,W0612
......@@ -47,7 +48,7 @@ def matmul(lhs, rhs, transa=False, transb=False):
m = b
if transb:
n = c
return _api.extern(
return te.extern(
(m, n), [lhs, rhs],
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.mps.matmul", ins[0], ins[1], outs[0], transa, transb),
......@@ -79,7 +80,7 @@ def conv2d(data, weight, pad='SAME', stride=1):
ho = hi // stride
wo = wi // stride
return _api.extern(
return te.extern(
(n, ho, wo, co), [data, weight],
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.mps.conv2d", ins[0], ins[1], outs[0], padding, stride),
......
......@@ -16,8 +16,8 @@
# under the License.
"""External function interface to NNPACK libraries."""
import tvm
from tvm import te
import tvm._ffi
from .. import api as _api
def is_available():
......@@ -43,7 +43,7 @@ def fully_connected_inference(lhs, rhs, nthreads=1):
lhs 1D array out[output_channels] of FP32 elements.
"""
m = rhs.shape[0]
return _api.extern(
return te.extern(
(m, ), [lhs, rhs],
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.nnpack.fully_connected_inference",
......@@ -100,13 +100,13 @@ def convolution_inference(
assert isinstance(stride, list) and len(stride) == 2
batch, _, input_height, input_width = data.shape
output_channels, _, kernel_height, kernel_width = kernel.shape
idxdiv = _api.indexdiv
idxdiv = te.indexdiv
output_height = idxdiv(
input_height + padding[0] + padding[1] - kernel_height, stride[0]) + 1
output_width = idxdiv(
input_width + padding[0] + padding[1] - kernel_width, stride[1]) + 1
return _api.extern(
return te.extern(
(batch, output_channels, output_height, output_width),
[data, kernel, bias] if bias is not None else [data, kernel],
lambda ins, outs: tvm.tir.call_packed(
......@@ -155,11 +155,11 @@ def convolution_inference_without_weight_transform(
batch, _, input_height, input_width = data.shape
output_channels, _, _, _ = transformed_kernel.shape
kernel_height, kernel_width = (3, 3)
idxdiv = _api.indexdiv
idxdiv = te.indexdiv
output_height = idxdiv(input_height + padding[0] + padding[1] - kernel_height, stride[0]) + 1
output_width = idxdiv(input_width + padding[0] + padding[1] - kernel_width, stride[1]) + 1
return _api.extern(
return te.extern(
(batch, output_channels, output_height, output_width),
[data, transformed_kernel, bias] if bias is not None else [data, transformed_kernel],
lambda ins, outs: tvm.tir.call_packed(
......@@ -194,7 +194,7 @@ def convolution_inference_weight_transform(
transform_tile_size = 8
if not isinstance(dtype, str):
dtype = dtype.dtype
return _api.extern(
return te.extern(
(output_channels, input_channels, transform_tile_size, transform_tile_size),
[kernel],
lambda ins, outs: tvm.tir.call_packed(
......
......@@ -21,10 +21,11 @@ from __future__ import absolute_import as _abs
import subprocess
import os
import warnings
import tvm._ffi
from tvm.runtime import ndarray as nd
from . import util
from ..api import register_func
from .._ffi.base import py_str
def compile_cuda(code,
......@@ -152,7 +153,7 @@ def get_cuda_version(cuda_path):
raise RuntimeError("Cannot read cuda version file")
@register_func("tvm_callback_libdevice_path")
@tvm._ffi.register_func("tvm_callback_libdevice_path")
def find_libdevice_path(arch):
"""Utility function to find libdevice
......
......@@ -19,6 +19,7 @@
import logging
import tvm
from tvm import te
from . import util
from .. import rpc
......@@ -79,17 +80,17 @@ def measure_bandwidth_sum(total_item, item_per_thread, stride,
base_type = str(base_type) + str(bits)
dtype = base_type if lanes == 1 else base_type + "x" + str(lanes)
k = tvm.reduce_axis((0, m), name="k")
k = te.reduce_axis((0, m), name="k")
x = tvm.placeholder((n,), dtype=dtype, name="x")
op = tvm.comm_reducer(lambda x, y: x*y, lambda t: tvm.const(1, dtype=t), name="sum")
y = tvm.compute((n // m,),
lambda i: op(x[i // stride * stride * m + i % stride + k * stride], axis=k))
s = tvm.create_schedule(y.op)
x = te.placeholder((n,), dtype=dtype, name="x")
op = te.comm_reducer(lambda x, y: x*y, lambda t: tvm.tir.const(1, dtype=t), name="sum")
y = te.compute((n // m,),
lambda i: op(x[i // stride * stride * m + i % stride + k * stride], axis=k))
s = te.create_schedule(y.op)
yo, yi = s[y].split(y.op.axis[0], target.max_num_threads)
s[y].bind(yo, tvm.thread_axis("blockIdx.x"))
s[y].bind(yi, tvm.thread_axis("threadIdx.x"))
s[y].bind(yo, te.thread_axis("blockIdx.x"))
s[y].bind(yi, te.thread_axis("threadIdx.x"))
s[y].unroll(k)
try:
......@@ -207,10 +208,10 @@ def measure_compute_mad(total_item, item_per_thread, base_type, bits, lanes,
def extern(ins, outs):
# pylint: disable=unused-argument
"""construct measurement function by building IR directly"""
ib = tvm.ir_builder.create()
ib = tvm.tir.ir_builder.create()
bx = tvm.thread_axis("blockIdx.x")
tx = tvm.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
tx = te.thread_axis("threadIdx.x")
ib.scope_attr(bx, "thread_extent", n // max_threads)
ib.scope_attr(tx, "thread_extent", max_threads)
......@@ -235,8 +236,8 @@ def measure_compute_mad(total_item, item_per_thread, base_type, bits, lanes,
ib.emit(outs[0].vstore(idx, b[0]))
return ib.get()
y = tvm.extern((n,), [], extern, name="y", dtype=dtype)
s = tvm.create_schedule(y.op)
y = te.extern((n,), [], extern, name="y", dtype=dtype)
s = te.create_schedule(y.op)
try:
func = tvm.build(s, [y], target, target_host=target_host)
......
......@@ -16,8 +16,8 @@
# under the License.
"""External function interface to random library."""
import tvm
from tvm import te
import tvm._ffi
from .. import api as _api
def randint(low, high, size, dtype='int32'):
......@@ -38,7 +38,7 @@ def randint(low, high, size, dtype='int32'):
A tensor with specified size and dtype
"""
assert 'int' in dtype, "the type of randint output must be int or uint"
return _api.extern(size, [], lambda ins, outs: tvm.tir.call_packed(
return te.extern(size, [], lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.random.randint", int(low), int(high), outs[0]), dtype=dtype)
......@@ -66,7 +66,7 @@ def uniform(low, high, size):
out : Tensor
A tensor with specified size and dtype.
"""
return _api.extern(size, [], lambda ins, outs: tvm.tir.call_packed(
return te.extern(size, [], lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.random.uniform", float(low), float(high), outs[0]), dtype='float32')
......@@ -90,7 +90,7 @@ def normal(loc, scale, size):
out : Tensor
A tensor with specified size and dtype
"""
return _api.extern(size, [], lambda ins, outs: tvm.tir.call_packed(
return te.extern(size, [], lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.random.normal", float(loc), float(scale), outs[0]), dtype='float32')
......
......@@ -16,7 +16,8 @@
# under the License.
"""External function interface to rocBLAS libraries."""
import tvm
from .. import api as _api
from tvm import te
def matmul(lhs, rhs, transa=False, transb=False):
"""Create an extern op that compute matrix mult of A and rhs with rocBLAS
......@@ -39,7 +40,7 @@ def matmul(lhs, rhs, transa=False, transb=False):
"""
n = lhs.shape[1] if transa else lhs.shape[0]
m = rhs.shape[0] if transb else rhs.shape[1]
return _api.extern(
return te.extern(
(n, m), [lhs, rhs],
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.rocblas.matmul",
......
......@@ -18,11 +18,13 @@
import subprocess
from os.path import join, exists
import tvm._ffi
from tvm._ffi.base import py_str
import tvm.runtime
import tvm.target
from . import util
from ..api import register_func, convert
def find_lld(required=True):
"""Find ld.lld in system.
......@@ -85,7 +87,7 @@ def rocm_link(in_file, out_file, lld=None):
raise RuntimeError(msg)
@register_func("tvm_callback_rocm_link")
@tvm._ffi.register_func("tvm_callback_rocm_link")
def callback_rocm_link(obj_bin):
"""Links object file generated from LLVM to HSA Code Object
......@@ -108,7 +110,7 @@ def callback_rocm_link(obj_bin):
cobj_bin = bytearray(open(tmp_cobj, "rb").read())
return cobj_bin
@register_func("tvm_callback_rocm_bitcode_path")
@tvm._ffi.register_func("tvm_callback_rocm_bitcode_path")
def callback_rocm_bitcode_path(rocdl_dir="/opt/rocm/lib/"):
"""Utility function to find ROCm device library bitcodes
......@@ -138,4 +140,4 @@ def callback_rocm_bitcode_path(rocdl_dir="/opt/rocm/lib/"):
"oclc_wavefrontsize64_on.amdgcn.bc"
]
paths = [join(rocdl_dir, bitcode) for bitcode in bitcode_files]
return convert([path for path in paths if exists(path)])
return tvm.runtime.convert([path for path in paths if exists(path)])
......@@ -17,11 +17,12 @@
"""Utility for Interacting with SDAccel Tools"""
import subprocess
import os
import tvm._ffi
from . import util
from ..api import register_func
@register_func("tvm_callback_sdaccel_compile")
@tvm._ffi.register_func("tvm_callback_sdaccel_compile")
def compile_vhls(kernel_info, device_name):
"""Compile Vivado HLS code for SDAccel.
......
......@@ -18,10 +18,9 @@
# pylint: disable=invalid-name
import numpy as _np
from tvm.runtime import ndarray as _nd
from .. import expr as _expr
from .. import api as _api
from .. import tensor as _tensor
from tvm import te
from tvm.tir import expr as _expr
from tvm.te import tensor as _tensor
float32 = "float32"
......@@ -136,9 +135,9 @@ class CSRPlaceholderOp(SparsePlaceholderOp):
"""
SparsePlaceholderOp.__init__(self, shape, nonzeros, dtype, name)
self.stype = 'csr'
self.data = _api.placeholder((nonzeros,), dtype=dtype, name=self.name+'_data')
self.indices = _api.placeholder((nonzeros,), dtype=itype, name=self.name+'_indices')
self.indptr = _api.placeholder((self.shape[0]+1,), dtype=itype, name=self.name+'_indptr')
self.data = te.placeholder((nonzeros,), dtype=dtype, name=self.name+'_data')
self.indices = te.placeholder((nonzeros,), dtype=itype, name=self.name+'_indices')
self.indptr = te.placeholder((self.shape[0]+1,), dtype=itype, name=self.name+'_indptr')
assert isinstance(self.data, _tensor.Tensor)
assert isinstance(self.indices, _tensor.Tensor)
assert isinstance(self.indptr, _tensor.Tensor)
......
......@@ -282,7 +282,7 @@ def dump_json(sch, need_range):
def encode_itervar_relation(obj_manager, rel):
"""Extract and encode IterVar Relationship visualization data to a dictionary"""
rel_type = type(rel)
if rel_type is tvm.schedule.Split:
if rel_type is tvm.te.schedule.Split:
node_type = 'Split_Relation'
rel_dict = {
"type": node_type,
......@@ -290,7 +290,7 @@ def dump_json(sch, need_range):
"outer": obj_manager.get_dom_path(rel.outer),
"inner": obj_manager.get_dom_path(rel.inner),
}
elif rel_type is tvm.schedule.Fuse:
elif rel_type is tvm.te.schedule.Fuse:
node_type = 'Fuse_Relation'
rel_dict = {
"type": node_type,
......@@ -298,7 +298,7 @@ def dump_json(sch, need_range):
"outer": obj_manager.get_dom_path(rel.outer),
"inner": obj_manager.get_dom_path(rel.inner),
}
elif rel_type is tvm.schedule.Singleton:
elif rel_type is tvm.te.schedule.Singleton:
node_type = 'Singleton_Relation'
rel_dict = {
"type": node_type,
......@@ -377,12 +377,12 @@ def dump_json(sch, need_range):
dict : dictionary
A nested dictionary
"""
assert isinstance(sch, tvm.schedule.Schedule
), 'Input is not a tvm.schedule.Schedule object.'
assert isinstance(sch, tvm.te.schedule.Schedule
), 'Input is not a tvm.te.schedule.Schedule object.'
range_map = None
if need_range:
try:
range_map = tvm.schedule.InferBound(sch)
range_map = tvm.te.schedule.InferBound(sch)
except tvm._ffi.base.TVMError as expt:
warnings.warn(
'Ranges are not available, because InferBound fails with the following error:\n'
......
......@@ -89,7 +89,7 @@ def form_body(sch):
"""According to the given schedule, form the raw body
Parameters
----------
sch : tvm.schedule.Schedule
sch : tvm.te.schedule.Schedule
The given scheduler to form the raw body
Returns
......@@ -113,7 +113,7 @@ def lower(sch,
Parameters
----------
sch : tvm.schedule.Schedule
sch : tvm.te.schedule.Schedule
The schedule to be built
args : list of Buffer or Tensor or Var
......@@ -286,7 +286,7 @@ def build(inputs,
Parameters
----------
inputs : tvm.Schedule, LoweredFunc, or dict of target to LoweredFunc list
inputs : tvm.te.Schedule, LoweredFunc, or dict of target to LoweredFunc list
The schedule to be built
args : list of Buffer or Tensor or Var, optional
......@@ -325,10 +325,10 @@ def build(inputs,
.. code-block:: python
n = 2
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
s = tvm.create_schedule(C.op)
A = te.placeholder((n,), name='A')
B = te.placeholder((n,), name='B')
C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
s = tvm.te.create_schedule(C.op)
f = tvm.lower(s, [A, B, C], name="test_add")
m = tvm.build(f, target="llvm")
......@@ -337,10 +337,10 @@ def build(inputs,
.. code-block:: python
n = 2
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
s1 = tvm.create_schedule(C.op)
A = te.placeholder((n,), name='A')
B = te.placeholder((n,), name='B')
C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
s1 = tvm.te.create_schedule(C.op)
with tvm.target.cuda() as cuda_tgt:
s2 = topi.cuda.schedule_injective(cuda_tgt, [C])
f1 = tvm.lower(s1, [A, B, C], name="test_add1")
......
......@@ -16,6 +16,9 @@
# under the License.
"""Intrinsics of TVM-Python Hybrid Script for Python compilation time
semantic support."""
from tvm.runtime import const, convert
import tvm.te
from tvm.ir.container import Array
from tvm import target as _tgt
from tvm.tir import expr as _expr
......@@ -23,8 +26,6 @@ from tvm.tir import ir_pass
from tvm.tir import call_pure_intrin
from tvm.tir.stmt import For
from .. import api as _api
from .util import _internal_assert
# pylint: disable=redefined-builtin
......@@ -42,11 +43,11 @@ def _range(annotation, args):
"""Handling TVM loop types"""
n = args.__len__()
if n == 1:
low, ext = _api.const(0, dtype='int32'), args[0]
low, ext = const(0, dtype='int32'), args[0]
else:
_internal_assert(n == 2, "A loop intrinsic should only have 1 or 2 arguments!")
low, ext = args[0], args[1]
if not ir_pass.Equal(low, _api.const(0, dtype='int32')):
if not ir_pass.Equal(low, const(0, dtype='int32')):
ext = ext - low
for_type = LOOP_INTRIN[annotation]
iter_var = None
......@@ -62,16 +63,16 @@ def bind(func_id, args):
_internal_assert(args.__len__() == 2, "A loop bind should only have 2 arguments!")
_internal_assert(isinstance(args[0], str), \
"A loop bind's first argument should be a string!")
low, ext = _api.const(0, "int32"), args[1]
iter_var = _api.thread_axis((low, ext), args[0])
low, ext = const(0, "int32"), args[1]
iter_var = tvm.te.thread_axis((low, ext), args[0])
for_type = None
return iter_var, low, ext, for_type
def _math_intrin(func_id, args):
# pylint: disable=import-outside-toplevel
import tvm.tir.op
return getattr(tvm.tir.op, func_id)(*args)
from tvm.tir import op
return getattr(op, func_id)(*args)
sqrt = log = exp = tanh = sigmoid = power = popcount = _math_intrin #pylint: disable=invalid-name
......@@ -88,7 +89,7 @@ def _allocate_tensor(func_id, args):
"""Handling TVM tensor allocation.
You may refer hybrid.intrin.allocate for more details."""
n = args.__len__()
_internal_assert(isinstance(_api.convert(args[0]), Array), \
_internal_assert(isinstance(convert(args[0]), Array), \
"allocate's first argument should be a tuple of shape!")
shape = args[0]
for i in shape:
......@@ -119,10 +120,10 @@ def len(func_id, args):
_internal_assert(args.__len__() == 1, "Only 1 argument is expected!")
_internal_assert(func_id == "len", "This function cannot be directly invoked!")
try:
return _api.convert(args[0].__len__())
return convert(args[0].__len__())
except: #pylint: disable=bare-except
_internal_assert(args[0].shape.__len__() == 1, "Only one-dimension array can get len")
return _api.convert(args[0].shape[0])
return convert(args[0].shape[0])
def _cast(func_id, args):
......@@ -159,4 +160,4 @@ def max_num_threads(func_id, args):
else:
_internal_assert(isinstance(args[0], _expr.IntImm), "In tvm bool should be uint")
res = _tgt.Target.current(args[0].value).max_num_threads
return _api.convert(res)
return convert(res)
......@@ -25,7 +25,9 @@ import numbers
from enum import Enum
from tvm.ir import Array, Range
import tvm.runtime
import tvm.tir
import tvm.te
import tvm.te._ffi_api
from tvm.tir import expr as _expr
......@@ -40,8 +42,6 @@ from . import calls
from . import util
from .preprocessor import determine_variable_usage
from .. import api as _api
def concat_list_to_block(lst):
"""Concatenate a list of Python IR nodes to HalideIR Block"""
......@@ -125,7 +125,7 @@ class HybridParser(ast.NodeVisitor):
"""
Parameters
----------
args: A list of tvm.placeholder or tvm.var
args: A list of tvm.te.placeholder or te.var
Provided by the user, the argument list of the function to be lowered.
usage: A dict of variables used in last in this function
......@@ -210,9 +210,9 @@ class HybridParser(ast.NodeVisitor):
_domain = [Range.make_by_min_extent(0, i) for i in _buf.shape]
_dtype = _buf.dtype
_true = _api.convert(True)
_true = tvm.runtime.convert(True)
body = tvm.tir.Realize(_buf.op, 0, _dtype, _domain, _true, body)
body = tvm.tir.AttrStmt(_buf.op, 'realize_scope', _api.convert(_scope), body)
body = tvm.tir.AttrStmt(_buf.op, 'realize_scope', tvm.runtime.convert(_scope), body)
for elem in to_pop:
self.symbols.pop(elem)
......@@ -256,10 +256,10 @@ class HybridParser(ast.NodeVisitor):
def visit_Name(self, node):
name = node.id
if sys.version_info[0] == 2 and name in ['True', 'False']:
return _api.convert(ast.literal_eval(name))
return tvm.runtime.convert(ast.literal_eval(name))
if name in self.closure_vars:
return _api.convert(self.closure_vars[name])
return tvm.runtime.convert(self.closure_vars[name])
ty, entry = self.symbols[name]
_internal_assert(name in self.symbols, "Unknown symbol %s!" % name)
......@@ -271,9 +271,9 @@ class HybridParser(ast.NodeVisitor):
return entry if isinstance(node.ctx, ast.Load) else None
if ty is Symbol.BufferVar:
if isinstance(node.ctx, ast.Load):
return tvm.tir.Call(entry.dtype, entry.name, [_api.const(0, 'int32')], \
return tvm.tir.Call(entry.dtype, entry.name, [tvm.runtime.const(0, 'int32')], \
_expr.Call.Halide, entry.op, entry.value_index)
return entry, [_api.const(0, 'int32')]
return entry, [tvm.runtime.const(0, 'int32')]
# Do I need any assertion here?
return entry
......@@ -287,11 +287,11 @@ class HybridParser(ast.NodeVisitor):
_internal_assert(isinstance(node.n, bool),
"The data type should be one of (int, float, bool)")
dtype = "bool"
return _api.const(node.n, dtype)
return tvm.runtime.const(node.n, dtype)
def visit_NameConstant(self, node):
return _api.convert(node.value)
return tvm.runtime.convert(node.value)
def visit_AugAssign(self, node):
......@@ -301,7 +301,7 @@ class HybridParser(ast.NodeVisitor):
_internal_assert(len(buf) == 2, "LHS is supposed to be (buf, args)!")
buf, args = buf
else:
args = [_api.const(0, 'int32')]
args = [tvm.runtime.const(0, 'int32')]
_internal_assert(isinstance(buf, Tensor), "LHS is supposed to be Tensor!")
read = tvm.tir.Call(buf.dtype, buf.name, args, _expr.Call.Halide, buf.op, buf.value_index)
......@@ -341,7 +341,7 @@ class HybridParser(ast.NodeVisitor):
"This value should not be defined before this point!")
if isinstance(rhs, tuple):
shape, dtype, scope = rhs
ph = _api.placeholder(shape, dtype=dtype, name=lhs)
ph = tvm.te.placeholder(shape, dtype=dtype, name=lhs)
self.add_symbol(lhs, getattr(Symbol, scope.title() + "Buffer"), ph)
if scope == 'output':
self.outputs.append(lhs)
......@@ -353,7 +353,7 @@ class HybridParser(ast.NodeVisitor):
"Single variable not supported in devices' side!\n" + \
"If you are using GPU, please allocate a 'local' spad " + \
"outside the bind body")
ph = _api.placeholder((1, ), dtype=rhs.dtype, name=lhs)
ph = tvm.te.placeholder((1, ), dtype=rhs.dtype, name=lhs)
self.add_symbol(lhs, Symbol.BufferVar, ph)
lhs = self.visit(lhs_)
if lhs is not None:
......@@ -524,8 +524,8 @@ class HybridParser(ast.NodeVisitor):
if iter_var is None:
_internal_assert(for_type is not None, "The loop iterating function parse error!")
offset = iter_var = _api.var(_name)
if not _ir_pass.Equal(low, _api.const(0, 'int32')):
offset = iter_var = tvm.te.var(_name)
if not _ir_pass.Equal(low, tvm.runtime.const(0, 'int32')):
offset = iter_var + low
self.add_symbol(_name, Symbol.LoopVar, offset)
_body = visit_list_to_block(self.visit, node.body)
......@@ -543,7 +543,7 @@ class HybridParser(ast.NodeVisitor):
else:
_internal_assert(not isinstance(for_type, tuple), \
"Micro expansion should be handled before!")
res = tvm.tir.For(iter_var, _api.const(0, 'int32'), ext, for_type, 0, _body)
res = tvm.tir.For(iter_var, tvm.runtime.const(0, 'int32'), ext, for_type, 0, _body)
self.symbols.pop(_name)
return res
......@@ -579,7 +579,7 @@ class HybridParser(ast.NodeVisitor):
def visit_Assert(self, node):
test = self.visit(node.test)
mesg = _api.convert(self.visit(node.msg))
mesg = tvm.runtime.convert(self.visit(node.msg))
return tvm.tir.AssertStmt(test, mesg, util.make_nop())
......
......@@ -22,6 +22,7 @@ import logging
import sys
import numpy
import tvm.runtime
from tvm._ffi.base import numeric_types
from tvm.ir.container import Array
......@@ -29,8 +30,6 @@ from tvm.tir import expr as _expr
from tvm.tir import stmt as _stmt
from tvm.te.tensor import Tensor
from .. import api as _api
#pylint: disable=invalid-name
np_arg_types = tuple(list(numeric_types) + [numpy.ndarray])
......@@ -47,7 +46,7 @@ def _internal_assert(cond, err):
# Useful constants. In avoid of runtime dependences, we use function calls to return them.
def make_nop():
"""Returns a 'no operation' node in HalideIR."""
return _stmt.Evaluate(_api.const(0, dtype='int32'))
return _stmt.Evaluate(tvm.runtime.const(0, dtype='int32'))
def is_docstring(node):
......@@ -73,7 +72,7 @@ def _pruned_source(func):
def replace_io(body, rmap):
"""Replacing tensors usage according to the dict given"""
# pylint: disable=import-outside-toplevel
from .. import ir_pass
from tvm.tir import ir_pass
def replace(op):
if isinstance(op, _stmt.Provide) and op.func in rmap.keys():
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint:disable=unused-wildcard-import, wildcard-import, redefined-builtin
"""Backwared compatible layer for intrin."""
from .tir.op import *
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-import
"""namespace of IR node builder make function
This namespace is used for developers. While you do not see any declarations.
The functions are automatically exported from C++ side via PackedFunc.
Each api is a PackedFunc that can be called in a positional argument manner.
You can use make function to build the IR node.
"""
import tvm._ffi
import tvm.ir
from tvm.ir import make_node as node
from tvm.tir import Call
def make_by_min_extent(min_value, extent):
"""Construct a Range by min and extent.
This constructs a range in [min_value, min_value + extent)
Parameters
----------
min_value : PrimExpr
The minimum value of the range.
extent : PrimExpr
The extent of the range.
Returns
-------
rng : Range
The constructed range.
"""
return tvm.ir.Range.make_by_min_extent(min_value, extent)
tvm._ffi._init_api("tvm.make")
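For reference, a minimal usage sketch of the compatibility helper above, assuming this shim ships as tvm/make.py; the literal values are illustrative:

import tvm
from tvm import make as _make

# build the range [0, 16); this simply forwards to
# tvm.ir.Range.make_by_min_extent(0, 16)
rng = _make.make_by_min_extent(0, 16)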
......@@ -18,7 +18,7 @@
"""The Relay IR namespace containing the IR definition and compiler."""
import os
from sys import setrecursionlimit
from ..api import register_func
from . import call_graph
from . import base
from . import ty
......
......@@ -26,10 +26,10 @@ def lower(sch, inputs, func_name, source_func):
Parameters
----------
sch : tvm.Schedule
sch : tvm.te.Schedule
The schedule.
inputs : List[tvm.Tensor]
inputs : List[tvm.te.Tensor]
The inputs to the function.
func_name : str
......
......@@ -21,6 +21,7 @@ from __future__ import absolute_import
import logging
import numpy as np
import tvm
from tvm import te
from ..base import register_relay_node, Object
from ... import target as _target
from ... import autotvm
......@@ -79,12 +80,12 @@ def get_shape(shape):
"""Convert the shape to correct dtype and vars."""
ret = []
for dim in shape:
if isinstance(dim, tvm.expr.IntImm):
if isinstance(dim, tvm.tir.IntImm):
val = int(dim)
assert val <= np.iinfo(np.int32).max
ret.append(tvm.expr.IntImm("int32", val))
elif isinstance(dim, tvm.expr.Any):
ret.append(tvm.var("any_dim", "int32"))
ret.append(tvm.tir.IntImm("int32", val))
elif isinstance(dim, tvm.tir.Any):
ret.append(te.var("any_dim", "int32"))
else:
ret.append(dim)
return ret
......@@ -103,7 +104,7 @@ def get_valid_implementations(op, attrs, inputs, out_type, target):
attrs : object
The op attribute.
inputs : List[tvm.Tensor]
inputs : List[tvm.te.Tensor]
Input tensors to the op.
out_type : relay.Type
......@@ -129,7 +130,7 @@ def get_valid_implementations(op, attrs, inputs, out_type, target):
flag = True
for clause in spec.condition.clauses:
clause = analyzer.canonical_simplify(clause)
if isinstance(clause, tvm.expr.IntImm) and clause.value:
if isinstance(clause, tvm.tir.IntImm) and clause.value:
continue
flag = False
break
......@@ -162,7 +163,7 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True)
attrs : object
The op attribute.
inputs : List[tvm.Tensor]
inputs : List[tvm.te.Tensor]
Input tensors to the op.
out_type : relay.Type
......@@ -176,7 +177,7 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True)
Returns
-------
ret : tuple(relay.op.OpImplementation, List[tvm.Tensor])
ret : tuple(relay.op.OpImplementation, List[tvm.te.Tensor])
The best op implementation and the corresponding output tensors.
"""
all_impls = get_valid_implementations(op, attrs, inputs, out_type, target)
......
......@@ -36,7 +36,7 @@ contrib.graph_runtime or any other TVM runtime compatible systems.
from tvm.runtime.ndarray import empty
from tvm.relay import _build_module
from tvm import target as _target
from tvm import expr as _expr
from tvm.tir import expr as _expr
class GraphRuntimeCodegen(object):
"""The compiler from Relay to the TVM runtime system."""
......
......@@ -23,7 +23,7 @@ import numpy as np
from tvm.ir import IRModule
from tvm import expr as tvm_expr
from tvm.tir import expr as tvm_expr
from .. import nd as _nd, target as _target, autotvm
from ..contrib import graph_runtime as _graph_rt
from . import _build_module
......
......@@ -16,22 +16,20 @@
# under the License.
# pylint: disable=wildcard-import, redefined-builtin, invalid-name
"""The Relay IR namespace containing the IR definition and compiler."""
from __future__ import absolute_import
from ..api import register_func
import tvm._ffi
# pylint: disable=unused-argument, import-outside-toplevel
def _debugger_init(expr, stack):
import pdb
pdb.set_trace()
@register_func("relay.debug")
@tvm._ffi.register_func("relay.debug")
def _debug(*args):
import pdb
pdb.set_trace()
# pylint: disable=unused-argument
@register_func("relay.debug_interp")
@tvm._ffi.register_func("relay.debug_interp")
def _debug_interp(*args):
_, _, _, ist = args
print("Relay Debugger")
......
......@@ -17,7 +17,6 @@
# pylint: disable=invalid-name, import-self, unused-argument, unused-variable
# pylint: disable=inconsistent-return-statements, import-outside-toplevel
"""CoreML frontend."""
from __future__ import absolute_import as _abs
import math
import numpy as np
import tvm
......
......@@ -19,7 +19,6 @@
DarkNet symbol frontend for Relay.
"""
from __future__ import absolute_import as _abs
from enum import Enum
import numpy as np
import tvm
......
......@@ -16,8 +16,6 @@
# under the License.
# pylint: disable=invalid-name, import-self, len-as-condition, no-else-return, too-many-lines
"""MXNet symbol frontend."""
from __future__ import absolute_import as _abs
import json
import numpy as np
import tvm
......
......@@ -406,7 +406,7 @@ def _numtotensor():
val = inputs[0]
dtype = type(val)
if isinstance(val, tvm.expr.IntImm):
if isinstance(val, tvm.tir.IntImm):
val = val.__int__()
dtype = int
......
......@@ -18,9 +18,6 @@
# pylint: disable=import-self, invalid-name, unused-argument, too-many-lines, len-as-condition, broad-except
# pylint: disable=import-outside-toplevel
"""TF: Tensorflow frontend."""
from __future__ import absolute_import as _abs
from __future__ import print_function
import warnings
from collections import defaultdict
......@@ -1012,7 +1009,7 @@ def _gather():
'Attribute batch_dims is not supported')
new_input = inputs[0:2]
return AttrCvt(op_name="take",
extras={'axis': tvm.const(axis, 'int32')},
extras={'axis': tvm.tir.const(axis, 'int32')},
ignores=['Tindices', 'Tparams', 'validate_indices',
'Taxis', '_class', 'batch_dims'])(new_input, attr)
return _impl
......
......@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument, too-many-lines, import-outside-toplevel
"""Tensorflow lite frontend."""
import math
import numpy as np
......
......@@ -17,9 +17,9 @@
"""Backend compiler related feature registration"""
from __future__ import absolute_import
from tvm.runtime import convert
from topi.util import get_const_int, get_const_tuple
from . import op as _reg
from ...api import convert
from ...hybrid import script
_reg.register_reduce_schedule("argmax")
......
......@@ -16,14 +16,14 @@
# under the License.
#pylint: disable=invalid-name, unused-argument, len-as-condition
"""Backend compiler related feature registration"""
from __future__ import absolute_import
import topi
from tvm.runtime import convert
from topi.util import get_const_tuple
from .op import register_compute, register_shape_func
from .op import register_broadcast_schedule, register_injective_schedule
from .op import register_pattern, OpPattern
from ...hybrid import script
from ...api import convert
register_broadcast_schedule("log")
......
......@@ -18,13 +18,14 @@
# pylint: disable=invalid-name,unused-argument, len-as-condition, too-many-nested-blocks, too-many-local-variables, too-many-arguments
from __future__ import absolute_import
import tvm
from tvm import te
from tvm.runtime import convert
import topi
from topi.util import get_const_int, get_const_tuple
from . import op as _reg
from . import strategy
from .op import OpPattern
from ...hybrid import script
from ...api import convert
_reg.register_broadcast_schedule("broadcast_to")
_reg.register_broadcast_schedule("broadcast_to_like")
......@@ -79,7 +80,7 @@ def compute_argwhere(attrs, inputs, output_type):
output_shape.append(s)
else:
# see Any, replace it with a var
output_shape.append(tvm.var("any_dim", "int32"))
output_shape.append(te.var("any_dim", "int32"))
new_output_type = tvm.relay.ty.TensorType(output_shape, "int32")
return [topi.argwhere(new_output_type, inputs[0])]
......@@ -473,7 +474,7 @@ def squeeze_shape_func(attrs, inputs, _):
if keep_axes:
out = _squeeze_shape_func(inputs[0], convert(keep_axes))
else:
out = tvm.compute((), lambda *indices: 0)
out = te.compute((), lambda *indices: 0)
return [out]
@script
......
......@@ -28,7 +28,7 @@ def argsort(data, axis=-1, is_ascend=1, dtype="int32"):
data : relay.Expr
The input data tensor.
valid_count : tvm.Tensor
valid_count : tvm.te.Tensor
The number of valid elements to be sorted.
axis : int, optional
......
......@@ -20,11 +20,12 @@ from __future__ import absolute_import
import topi
from topi.util import get_const_tuple
from tvm.runtime import convert
from .. import op as reg
from .. import strategy
from ..op import OpPattern
from .._tensor import elemwise_shape_func
from ....api import convert
from ....hybrid import script
# relu
......
......@@ -21,7 +21,6 @@ from tvm.driver import lower, build
from ..base import register_relay_node
from ..expr import RelayExpr
from ...api import register_func
from ...target import get_native_generic_func, GenericFunc
from ...runtime import Object
from . import _make
......@@ -155,7 +154,7 @@ class OpImplementation(Object):
attrs : Attrs
Op attributes.
inputs : list[tvm.tensor.Tensor]
inputs : list[te.tensor.Tensor]
The input tensors.
out_type : relay.Type
......@@ -163,7 +162,7 @@ class OpImplementation(Object):
Returns
-------
outs : list[tvm.tensor.Tensor]
outs : list[te.tensor.Tensor]
The output tensors.
"""
return _OpImplementationCompute(self, attrs, inputs, out_type)
......@@ -176,7 +175,7 @@ class OpImplementation(Object):
attrs : Attrs
Op attributes.
outs : list[tvm.tensor.Tensor]
outs : list[te.tensor.Tensor]
The output tensors.
target : tvm.target.Target
......@@ -184,7 +183,7 @@ class OpImplementation(Object):
Returns
-------
schedule : tvm.Schedule
schedule : tvm.te.Schedule
The schedule.
"""
return _OpImplementationSchedule(self, attrs, outs, target)
......@@ -454,11 +453,11 @@ def register_shape_func(op_name, data_dependant, shape_func=None, level=10):
get(op_name).set_attr("TShapeDataDependant", data_dependant, level)
return register(op_name, "FShapeFunc", shape_func, level)
@register_func("relay.op.compiler._lower")
@tvm._ffi.register_func("relay.op.compiler._lower")
def _lower(name, schedule, inputs, outputs):
return lower(schedule, list(inputs) + list(outputs), name=name)
@register_func("relay.op.compiler._build")
@tvm._ffi.register_func("relay.op.compiler._build")
def _build(lowered_funcs):
return build(lowered_funcs, target="llvm")
......@@ -473,7 +472,7 @@ def debug(expr, debug_func=None):
if debug_func:
name = "debugger_func{}".format(__DEBUG_COUNTER__)
register_func(name, debug_func)
tvm._ffi.register_func(name, debug_func)
__DEBUG_COUNTER__ += 1
else:
name = ''
......
......@@ -17,9 +17,11 @@
# pylint: disable=invalid-name
"""Helper utility to save parameter dicts."""
import tvm
import tvm._ffi
_save_param_dict = tvm.get_global_func("tvm.relay._save_param_dict")
_load_param_dict = tvm.get_global_func("tvm.relay._load_param_dict")
_save_param_dict = tvm._ffi.get_global_func("tvm.relay._save_param_dict")
_load_param_dict = tvm._ffi.get_global_func("tvm.relay._load_param_dict")
def save_param_dict(params):
"""Save parameter dictionary to binary bytes.
......
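A short, hedged sketch of round-tripping parameters through these helpers (assuming the public relay.save_param_dict / relay.load_param_dict wrappers delegate to the globals fetched above):
import numpy as np
import tvm
from tvm import relay
# Serialize a parameter dict to bytes, then load it back.
params = {"w": tvm.nd.array(np.zeros((2, 2), dtype="float32"))}
blob = relay.save_param_dict(params)
loaded = relay.load_param_dict(blob)
assert "w" in loaded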
......@@ -16,12 +16,12 @@
# under the License.
#pylint: disable=unused-argument, not-context-manager
"""Automatic quantization toolkit."""
from __future__ import absolute_import
import tvm.ir
from . import _quantize
from ._calibrate import calibrate
from .. import expr as _expr
from .. import transform as _transform
from ... import make as _make
from ..base import Object, register_relay_node
......@@ -181,7 +181,7 @@ def qconfig(**kwargs):
"""
node_args = {k: v if k not in kwargs else kwargs[k]
for k, v in QConfig._node_defaults.items()}
return _make.node("relay.quantize.QConfig", **node_args)
return tvm.ir.make_node("relay.quantize.QConfig", **node_args)
class QuantizeContext(object):
......
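For context on the tvm.ir.make_node call used above, a minimal hedged sketch of building a registered node by its type key:
import tvm
# make_node constructs any registered IR node from its type key and keyword fields.
x = tvm.ir.make_node("IntImm", dtype="int32", value=10)
assert isinstance(x, tvm.tir.IntImm)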
......@@ -20,6 +20,7 @@ from __future__ import absolute_import as _abs
import numpy as np
import tvm
from tvm import te
import tvm.relay as relay
import tvm.relay.op as op
from tvm.relay import transform
......
......@@ -20,6 +20,7 @@ from __future__ import absolute_import as _abs
import os
import tvm
def ctx_list():
"""Get context list for testcases"""
device_list = os.environ.get("RELAY_TEST_TARGETS", "")
......
......@@ -23,6 +23,7 @@ import inspect
import functools
import tvm
from tvm import te
from tvm.runtime import ndarray as _nd
from tvm.ir.transform import PassInfo, PassContext, Pass, ModulePass, Sequential, module_pass
......
......@@ -106,6 +106,7 @@ class Executable(object):
import numpy as np
import tvm
from tvm import te
from tvm import relay
# define a simple network.
x = relay.var('x', shape=(10, 10))
......
......@@ -35,7 +35,7 @@ class DumpIR(object):
-----------
.. code-block:: python
with tvm.build_config(dump_pass_ir=True)
with tvm.target.build_config(dump_pass_ir=True):
run()
"""
scope_level = 0
......
......@@ -116,6 +116,7 @@ def override_native_generic_func(func_name):
.. code-block:: python
import tvm
from tvm import te
# wrap function as target generic
@tvm.target.override_native_generic_func("my_func")
def my_func(a):
......@@ -210,6 +211,7 @@ def generic_func(fdefault):
.. code-block:: python
import tvm
from tvm import te
# wrap function as target generic
@tvm.target.generic_func
def my_func(a):
......
......@@ -18,6 +18,7 @@
"""Namespace for Tensor Expression Language
"""
# expose all operators in tvm tir.op
from tvm.tir import any, all, min_value, max_value, trace
from tvm.tir import exp, erf, tanh, sigmoid, log, cos, sin, atan, sqrt, rsqrt, floor, ceil
from tvm.tir import trunc, abs, round, nearbyint, isnan, power, popcount, fmod, if_then_else
from tvm.tir import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
......@@ -29,3 +30,5 @@ from .tensor_intrin import decl_tensor_intrin
from .tag import tag_scope
from .operation import placeholder, compute, scan, extern, var, size_var
from .operation import thread_axis, reduce_axis
from .tensor import PlaceholderOp, ComputeOp, TensorComputeOp, ScanOp, ExternOp, HybridOp
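A minimal sketch of the relocated tensor-expression API in action (hedged; it mirrors the vector-add pattern used throughout this change):
import tvm
from tvm import te
# Declare, schedule and build a simple element-wise add with the te namespace.
n = te.var("n")
A = te.placeholder((n,), name="A")
B = te.compute(A.shape, lambda i: A[i] + 1.0, name="B")
s = te.create_schedule(B.op)
fadd = tvm.build(s, [A, B], target="llvm", name="add_one")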
......@@ -167,13 +167,13 @@ def scan(init, update, state_placeholder, inputs=None, name="scan", tag="", attr
.. code-block:: python
# The following code is equivalent to numpy.cumsum
m = tvm.var("m")
n = tvm.var("n")
X = tvm.placeholder((m, n), name="X")
s_state = tvm.placeholder((m, n))
s_init = tvm.compute((1, n), lambda _, i: X[0, i])
s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i])
res = tvm.scan(s_init, s_update, s_state, X)
m = te.var("m")
n = te.var("n")
X = te.placeholder((m, n), name="X")
s_state = te.placeholder((m, n))
s_init = te.compute((1, n), lambda _, i: X[0, i])
s_update = te.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i])
res = tvm.te.scan(s_init, s_update, s_state, X)
"""
if _tag.TagScope.get_current() is not None:
if tag != "":
......@@ -264,10 +264,10 @@ def extern(shape,
.. code-block:: python
A = tvm.placeholder((n, l), name="A")
B = tvm.placeholder((l, m), name="B")
C = tvm.extern((n, m), [A, B],
lambda ins, outs: tvm.call_packed(
A = te.placeholder((n, l), name="A")
B = te.placeholder((l, m), name="B")
C = te.extern((n, m), [A, B],
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.cblas.matmul",
ins[0], ins[1], outs[0], 0, 0), name="C")
"""
......
......@@ -73,19 +73,19 @@ def tag_scope(tag):
-------
.. code-block:: python
n = tvm.var('n')
m = tvm.var('m')
l = tvm.var('l')
A = tvm.placeholder((n, l), name='A')
B = tvm.placeholder((m, l), name='B')
k = tvm.reduce_axis((0, l), name='k')
n = te.var('n')
m = te.var('m')
l = te.var('l')
A = te.placeholder((n, l), name='A')
B = te.placeholder((m, l), name='B')
k = te.reduce_axis((0, l), name='k')
with tvm.tag_scope(tag='matmul'):
C = tvm.compute((n, m), lambda i, j: tvm.sum(A[i, k] * B[j, k], axis=k))
with tvm.te.tag_scope(tag='matmul'):
C = te.compute((n, m), lambda i, j: te.sum(A[i, k] * B[j, k], axis=k))
# or use tag_scope as decorator
@tvm.tag_scope(tag="conv")
@tvm.te.tag_scope(tag="conv")
def compute_relu(data):
return tvm.compute(data.shape, lambda *i: tvm.select(data(*i) < 0, 0.0, data(*i)))
return te.compute(data.shape, lambda *i: tvm.select(data(*i) < 0, 0.0, data(*i)))
"""
return TagScope(tag)
......@@ -17,20 +17,22 @@
# pylint: disable=unused-import, redefined-builtin
"""Namespace for Tensor-level IR"""
from tvm.ir import PrimExpr
from tvm.runtime import const
from .buffer import Buffer, decl_buffer
from .data_layout import Layout, BijectiveLayout, bijective_layout, layout
from .expr import Var, SizeVar, Reduce, FloatImm, IntImm, StringImm, Cast
from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod
from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not
from .expr import Select, Load, Ramp, Broadcast, Shuffle, Call, Let
from .expr import IterVar
from .expr import IterVar, Any
from .stmt import Stmt, LetStmt, AssertStmt, ProducerConsumer, For
from .stmt import Store, Provide, Allocate, AttrStmt, Free, Realize, SeqStmt
from .stmt import IfThenElse, Evaluate, Prefetch, LoweredFunc, stmt_seq, stmt_list
from .op import call_packed, call_pure_intrin, call_intrin, call_pure_extern, call_extern
from .op import call_llvm_intrin, all, any, min_value, max_value
from .op import call_llvm_intrin, all, any, min_value, max_value, trace
from .op import exp, erf, tanh, sigmoid, log, cos, sin, atan, sqrt, rsqrt, floor, ceil
from .op import trunc, abs, round, nearbyint, isnan, power, popcount, fmod, if_then_else
from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
......
......@@ -201,15 +201,15 @@ def decl_buffer(shape,
.. code-block:: python
m0, m1, m2 = tvm.var("m0"), tvm.var("m1"), tvm.var("m2")
n0, n1, n2 = tvm.var("n0"), tvm.var("n1"), tvm.var("n2")
o0, o1, o2 = tvm.var("o0"), tvm.var("o1"), tvm.var("o2")
A = tvm.placeholder((m0, m1, m2), name='A')
B = tvm.placeholder((n0, n1, n2), name='B')
C = tvm.compute((o0, o1, o2), lambda i, j, k: A[i, j, k] + B[i, j, k], name='C')
m0, m1, m2 = te.var("m0"), te.var("m1"), te.var("m2")
n0, n1, n2 = te.var("n0"), te.var("n1"), te.var("n2")
o0, o1, o2 = te.var("o0"), te.var("o1"), te.var("o2")
A = te.placeholder((m0, m1, m2), name='A')
B = te.placeholder((n0, n1, n2), name='B')
C = te.compute((o0, o1, o2), lambda i, j, k: A[i, j, k] + B[i, j, k], name='C')
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="auto_broadcast")
Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="auto_broadcast")
s = tvm.create_schedule(C.op)
s = te.create_schedule(C.op)
fadd = tvm.build(s, [A, B, C], target='llvm', name='bcast_add', binds={A:Ab, B:Bb})
ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.uniform(size=(2, 4, 3)).astype(A.dtype), ctx)
......
......@@ -25,7 +25,7 @@ For example, you can use addexp.a to get the left operand of an Add node.
.. code-block:: python
x = tvm.var("n")
x = te.var("n")
y = x + 2
assert(isinstance(y, tvm.tir.Add))
assert(y.a == x)
......@@ -169,7 +169,7 @@ class ExprOp(object):
def __nonzero__(self):
raise ValueError("Cannot use and / or / not operator to Expr, hint: " +
"use tvm.all / tvm.any instead")
"use tvm.tir.all / tvm.tir.any instead")
def __bool__(self):
return self.__nonzero__()
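For example (a hedged sketch), conditions on expressions are combined with tvm.tir.all / tvm.tir.any rather than Python's and / or:
import tvm
from tvm import te
i = te.var("i")
# Python's `and` would call __bool__ and raise the error above; use tir.all instead.
cond = tvm.tir.all(i >= 0, i < 10)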
......@@ -346,8 +346,8 @@ class IterVar(Object, ExprOp):
See Also
--------
tvm.thread_axis: Create thread axis IterVar.
tvm.reduce_axis: Create reduce axis IterVar.
te.thread_axis: Create thread axis IterVar.
te.reduce_axis: Create reduce axis IterVar.
"""
DataPar = 0
ThreadIndex = 1
......@@ -812,7 +812,7 @@ class Select(PrimExprWithOp):
Note
----
Select may compute both true_value and false_value.
Use :py:class:`tvm.if_then_else` instead if you want to
Use :py:class:`tvm.tir.if_then_else` instead if you want to
get a conditional expression that only evaluates
the correct branch.
......
......@@ -16,7 +16,7 @@
# under the License.
"""Generic opertors in TVM.
We follow the numpy naming convention for this interface
(e.g., tvm.generic.multiply ~ numpy.multiply).
(e.g., tvm.tir.generic.multiply ~ numpy.multiply).
The default implementation is used by tvm.ExprOp.
"""
# pylint: disable=unused-argument
......
......@@ -98,8 +98,8 @@ class IRBuilder(object):
--------
.. code-block:: python
ib = tvm.ir_builder.create()
n = tvm.var("n")
ib = tvm.tir.ir_builder.create()
n = te.var("n")
A = ib.allocate("float32", n, name="A")
with ib.for_range(0, n, name="i") as i:
with ib.if_scope((i % 2) == 0):
......@@ -158,8 +158,8 @@ class IRBuilder(object):
--------
.. code-block:: python
ib = tvm.ir_builder.create()
i = tvm.var("i")
ib = tvm.tir.ir_builder.create()
i = te.var("i")
x = ib.pointer("float32")
ib.scope_attr(x, "storage_scope", "global")
x[i] = x[i - 1] + 1
......@@ -200,7 +200,7 @@ class IRBuilder(object):
--------
.. code-block:: python
ib = tvm.ir_builder.create()
ib = tvm.tir.ir_builder.create()
x = ib.pointer("float32")
with ib.for_range(1, 10, name="i") as i:
x[i] = x[i - 1] + 1
......@@ -243,8 +243,8 @@ class IRBuilder(object):
--------
.. code-block:: python
ib = tvm.ir_builder.create()
i = tvm.var("i")
ib = tvm.tir.ir_builder.create()
i = te.var("i")
x = ib.pointer("float32")
with ib.if_scope((i % 2) == 0):
x[i] = x[i - 1] + 1
......@@ -268,8 +268,8 @@ class IRBuilder(object):
--------
.. code-block:: python
ib = tvm.ir_builder.create()
i = tvm.var("i")
ib = tvm.tir.ir_builder.create()
i = te.var("i")
x = ib.pointer("float32")
with ib.if_scope((i % 2) == 0):
x[i] = x[i - 1] + 1
......
......@@ -64,7 +64,7 @@ def call_packed(*args):
See Also
--------
tvm.extern : Create tensor with extern function call.
te.extern : Create tensor with extern function call.
"""
call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args]
return Call(
......@@ -194,7 +194,7 @@ def call_llvm_intrin(dtype, name, *args):
from tvm.target import codegen
llvm_id = codegen.llvm_lookup_intrinsic_id(name)
assert llvm_id != 0, "%s is not an LLVM intrinsic" % name
return call_pure_intrin(dtype, 'llvm_intrin', tvm.const(llvm_id, 'uint32'), *args)
return call_pure_intrin(dtype, 'llvm_intrin', tvm.tir.const(llvm_id, 'uint32'), *args)
def any(*args):
......@@ -274,7 +274,7 @@ def trace(args, trace_action="tvm.default_trace_action"):
tvm.tir.call_packed : Creates packed function.
"""
if not isinstance(args, list):
raise Exception("tvm.trace consumes the args as list type")
raise Exception("tvm.tir.trace consumes the args as list type")
call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args]
call_args.insert(0, trace_action)
return tvm.tir.Call(
......@@ -556,9 +556,9 @@ def round(x):
def nearbyint(x):
"""Round elements of the array to the nearest integer.
This intrinsic uses llvm.nearbyint instead of llvm.round
which is faster but will results different from tvm.round.
which is faster but may produce results that differ from te.round.
Notably nearbyint rounds according to the rounding mode,
whereas tvm.round (llvm.round) ignores that.
whereas te.round (llvm.round) ignores that.
For differences between the two see:
https://en.cppreference.com/w/cpp/numeric/math/round
https://en.cppreference.com/w/cpp/numeric/math/nearbyint
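A hedged sketch of the distinction in practice:
import tvm
from tvm import te
A = te.placeholder((8,), name="A", dtype="float32")
# nearbyint honours the current rounding mode; round does not.
B = te.compute(A.shape, lambda i: tvm.tir.nearbyint(A[i]), name="B")
C = te.compute(A.shape, lambda i: tvm.tir.round(A[i]), name="C")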
......@@ -855,13 +855,13 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
-------
.. code-block:: python
n = tvm.var("n")
m = tvm.var("m")
mysum = tvm.comm_reducer(lambda x, y: x+y,
lambda t: tvm.const(0, dtype=t), name="mysum")
A = tvm.placeholder((n, m), name="A")
k = tvm.reduce_axis((0, m), name="k")
B = tvm.compute((n,), lambda i: mysum(A[i, k], axis=k), name="B")
n = te.var("n")
m = te.var("m")
mysum = te.comm_reducer(lambda x, y: x+y,
lambda t: tvm.tir.const(0, dtype=t), name="mysum")
A = te.placeholder((n, m), name="A")
k = te.reduce_axis((0, m), name="k")
B = te.compute((n,), lambda i: mysum(A[i, k], axis=k), name="B")
"""
def _reduce_directly(*args):
num = len(args)
......@@ -943,14 +943,14 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
-------
.. code-block:: python
m = tvm.var("m")
n = tvm.var("n")
A = tvm.placeholder((m, n), name="A")
k = tvm.reduce_axis((0, n), name="k")
m = te.var("m")
n = te.var("n")
A = te.placeholder((m, n), name="A")
k = te.reduce_axis((0, n), name="k")
# there are two way to use this {0} reducer:
# mode 1, accept (expr, axis, where) to produce an Reduce Expr
B = tvm.compute((m,), lambda i: tvm.{0}(A[i, k], axis=k), name="B")
B = te.compute((m,), lambda i: tvm.{0}(A[i, k], axis=k), name="B")
# mode 2, simply use it with multiple Exprs:
{0}_res = tvm.{0}(m, n)
......
......@@ -23,8 +23,8 @@ Each statement node has subfields that can be visited from the python side.
.. code-block:: python
x = tvm.var("n")
a = tvm.var("array", tvm.handle)
x = te.var("n")
a = te.var("array", "handle")
st = tvm.tir.stmt.Store(a, x + 1, 1)
assert isinstance(st, tvm.tir.stmt.Store)
assert(st.buffer_var == a)
......
......@@ -25,6 +25,7 @@ import sys
import numpy as np
import tvm
from tvm import te
from tvm import relay
from tvm.relay import testing
from tvm.contrib import graph_runtime, cc
......
......@@ -20,20 +20,21 @@ import os.path as osp
import sys
import tvm
from tvm import te
from tvm.contrib import cc
def main(target, out_dir):
n = tvm.var('n')
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda i: A[i] + B[i], name='C')
s = tvm.create_schedule(C.op)
n = te.var('n')
A = te.placeholder((n,), name='A')
B = te.placeholder((n,), name='B')
C = te.compute(A.shape, lambda i: A[i] + B[i], name='C')
s = te.create_schedule(C.op)
if target == 'cuda':
bx, tx = s[C].split(C.op.axis[0], factor=64)
s[C].bind(bx, tvm.thread_axis('blockIdx.x'))
s[C].bind(tx, tvm.thread_axis('threadIdx.x'))
s[C].bind(bx, te.thread_axis('blockIdx.x'))
s[C].bind(tx, te.thread_axis('threadIdx.x'))
fadd = tvm.build(s, [A, B, C], target, target_host='llvm', name='myadd')
......
......@@ -22,6 +22,7 @@ from os import path as osp
import numpy as np
import tvm
from tvm import te
from tvm import relay
from tvm.relay import testing
......
......@@ -23,6 +23,7 @@ import sys
import numpy as np
import tvm
from tvm import te
from tvm import relay
from tvm.relay import testing
......
......@@ -22,13 +22,14 @@ from os import path as osp
import sys
import tvm
from tvm import te
def main():
n = tvm.var('n')
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
s = tvm.create_schedule(C.op)
n = te.var('n')
A = te.placeholder((n,), name='A')
B = te.placeholder((n,), name='B')
C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
s = tvm.te.create_schedule(C.op)
s[C].parallel(s[C].op.axis[0])
print(tvm.lower(s, [A, B, C], simple_mode=True))
tvm.build(s, [A, B, C], 'llvm --system-lib').save(osp.join(sys.argv[1], 'test.o'))
......
......@@ -22,14 +22,15 @@ from os import path as osp
import sys
import tvm
from tvm import te
from tvm.contrib import cc
def main():
n = tvm.var('n')
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
s = tvm.create_schedule(C.op)
n = te.var('n')
A = te.placeholder((n,), name='A')
B = te.placeholder((n,), name='B')
C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
s = tvm.te.create_schedule(C.op)
s[C].parallel(s[C].op.axis[0])
print(tvm.lower(s, [A, B, C], simple_mode=True))
obj_file = osp.join(sys.argv[1], 'test.o')
......
......@@ -24,6 +24,7 @@ Specifically, we test the following capabilities:
"""
import tvm
from tvm import te
import subprocess
from tvm.contrib import util
from tvm.contrib import cc
......
......@@ -15,19 +15,20 @@
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te
import numpy as np
import topi.testing
from tvm.contrib import cblas
def verify_matmul_add(m, l, n, transa=False, transb=False, dtype=tvm.float32):
bias = tvm.var('bias', dtype=dtype)
def verify_matmul_add(m, l, n, transa=False, transb=False, dtype="float32"):
bias = te.var('bias', dtype=dtype)
ashape = (l, n) if transa else (n, l)
bshape = (m, l) if transb else (l, m)
A = tvm.placeholder(ashape, name='A', dtype=dtype)
B = tvm.placeholder(bshape, name='B', dtype=dtype)
A = te.placeholder(ashape, name='A', dtype=dtype)
B = te.placeholder(bshape, name='B', dtype=dtype)
C = cblas.matmul(A, B, transa, transb)
D = tvm.compute(C.shape, lambda i, j: C[i,j] + bias, name="D")
s = tvm.create_schedule(D.op)
D = te.compute(C.shape, lambda i, j: C[i,j] + bias, name="D")
s = te.create_schedule(D.op)
def get_numpy(a, b, bb, transa, transb):
if transa:
......@@ -64,14 +65,14 @@ def test_matmul_add():
verify_matmul_add(1, 16, 3, False, False)
verify_matmul_add(1, 16, 3, True, True)
def verify_batch_matmul(batch, m, l, n, transa=False, transb=False, iterative=False, dtype=tvm.float32):
def verify_batch_matmul(batch, m, l, n, transa=False, transb=False, iterative=False, dtype="float32"):
ashape = (batch, l, n) if transa else (batch, n, l)
bshape = (batch, m, l) if transb else (batch, l, m)
A = tvm.placeholder(ashape, name='A', dtype=dtype)
B = tvm.placeholder(bshape, name='B', dtype=dtype)
A = te.placeholder(ashape, name='A', dtype=dtype)
B = te.placeholder(bshape, name='B', dtype=dtype)
C = cblas.batch_matmul(A, B, transa, transb)
D = tvm.compute(C.shape, lambda k, i, j: C[k, i,j], name="D")
s = tvm.create_schedule(D.op)
D = te.compute(C.shape, lambda k, i, j: C[k, i,j], name="D")
s = te.create_schedule(D.op)
def get_numpy(a, b, transa, transb):
if transa:
......
......@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te
import numpy as np
from tvm.contrib import cublas
from tvm.contrib import cublaslt
......@@ -23,10 +24,10 @@ def verify_matmul_add(in_dtype, out_dtype, rtol=1e-5):
n = 1024
l = 128
m = 236
A = tvm.placeholder((n, l), name='A', dtype=in_dtype)
B = tvm.placeholder((l, m), name='B', dtype=in_dtype)
A = te.placeholder((n, l), name='A', dtype=in_dtype)
B = te.placeholder((l, m), name='B', dtype=in_dtype)
C = cublas.matmul(A, B, dtype=out_dtype)
s = tvm.create_schedule(C.op)
s = te.create_schedule(C.op)
def verify(target="cuda"):
if not tvm.runtime.enabled(target):
......@@ -56,11 +57,11 @@ def verify_matmul_add_igemm(in_dtype, out_dtype, rtol=1e-5):
N = roundoff(n, 8)
N_out = roundoff(n, 32)
A = tvm.placeholder((N, L), name='A', dtype=in_dtype)
B = tvm.placeholder((m, L), name='B', dtype=in_dtype)
A = te.placeholder((N, L), name='A', dtype=in_dtype)
B = te.placeholder((m, L), name='B', dtype=in_dtype)
# C has CUBLASLT_ORDER_COL32 layout, thus a different shape
C = cublaslt.matmul(A, B, False, True, m, N_out, dtype=out_dtype)
s = tvm.create_schedule(C.op)
s = te.create_schedule(C.op)
def verify(target="cuda"):
if not tvm.runtime.enabled(target):
......@@ -108,10 +109,10 @@ def verify_batch_matmul(in_dtype, out_dtype, rtol=1e-5):
n = 1024
l = 128
m = 236
A = tvm.placeholder((j, n, l), name='A', dtype=in_dtype)
B = tvm.placeholder((j, l, m), name='B', dtype=in_dtype)
A = te.placeholder((j, n, l), name='A', dtype=in_dtype)
B = te.placeholder((j, l, m), name='B', dtype=in_dtype)
C = cublas.batch_matmul(A, B, dtype=out_dtype)
s = tvm.create_schedule(C.op)
s = te.create_schedule(C.op)
def verify(target="cuda"):
if not tvm.runtime.enabled(target):
......
......@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te
from tvm.contrib import cudnn
import numpy as np
import topi.testing
......@@ -48,8 +49,8 @@ def verify_conv2d(data_dtype, conv_dtype, tensor_format=0):
xshape = [batch, height, weight, in_channel]
wshape = [out_channel, filter_h, filter_w, in_channel]
X = tvm.placeholder(xshape, name='X', dtype=data_dtype)
W = tvm.placeholder(wshape, name='W', dtype=data_dtype)
X = te.placeholder(xshape, name='X', dtype=data_dtype)
W = te.placeholder(wshape, name='W', dtype=data_dtype)
Y = cudnn.conv_forward(X,
W,
[pad_h, pad_w],
......@@ -60,7 +61,7 @@ def verify_conv2d(data_dtype, conv_dtype, tensor_format=0):
conv_dtype=conv_dtype,
algo=-1)
yshape = [x.value for x in Y.shape]
s = tvm.create_schedule(Y.op)
s = te.create_schedule(Y.op)
def verify():
ctx = tvm.gpu(0)
......@@ -120,8 +121,8 @@ def verify_conv3d(data_dtype, conv_dtype, tensor_format=0):
xshape = [batch, in_channel, depth, height, weight]
wshape = [out_channel, in_channel, filter_d, filter_h, filter_w]
X = tvm.placeholder(xshape, name='X', dtype=data_dtype)
W = tvm.placeholder(wshape, name='W', dtype=data_dtype)
X = te.placeholder(xshape, name='X', dtype=data_dtype)
W = te.placeholder(wshape, name='W', dtype=data_dtype)
Y = cudnn.conv_forward(X,
W,
[pad_d, pad_h, pad_w],
......@@ -132,7 +133,7 @@ def verify_conv3d(data_dtype, conv_dtype, tensor_format=0):
algo=-1,
conv_dtype=conv_dtype)
yshape = [x.value for x in Y.shape]
s = tvm.create_schedule(Y.op)
s = te.create_schedule(Y.op)
def verify():
ctx = tvm.gpu(0)
......