Commit a5661611 by Tianqi Chen (committed by GitHub)

[REFACTOR][PY][API-CHANGE] establish tvm.ir, migrate corresponding files (#4862)

* [REFACTOR][PY][API-CHANGE] establish tvm.ir, migrate corresponding relay files.

This PR establishes tvm.ir and migrates the corresponding relay
files into the new folder.

API Change:
- relay.Module -> tvm.IRModule
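
A hedged sketch of what the rename means for downstream code (`func` here is
just a placeholder Relay function):

    import tvm
    from tvm import relay

    func = relay.Function([], relay.const(0))
    # before this PR: mod = relay.Module.from_expr(func)
    mod = tvm.IRModule.from_expr(func)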

* Update with ADT

* Migrate transform

* address comments

* Migrate module

* Migrate json_compact

* Migrate attrs

* Move LoweredFunc to stmt temporarily

* temp migrate container

* Finish migrate container
parent 15df204f
......@@ -34,7 +34,7 @@ def get_network(name, batch_size, dtype='float32'):
Returns
-------
net: relay.Module
net: tvm.IRModule
The Relay function of the network definition
params: dict
The random parameters for benchmark
......@@ -70,7 +70,7 @@ def get_network(name, batch_size, dtype='float32'):
net, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
net = net["main"]
net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs)
net = relay.Module.from_expr(net)
net = tvm.IRModule.from_expr(net)
else:
raise ValueError("Unsupported network: " + name)
......
......@@ -21,8 +21,6 @@ The user facing API for computation declaration.
.. autosummary::
tvm.load_json
tvm.save_json
tvm.var
tvm.size_var
tvm.const
......@@ -47,8 +45,7 @@ The user facing API for computation declaration.
tvm.max
tvm.tag_scope
.. autofunction:: tvm.load_json
.. autofunction:: tvm.save_json
.. autofunction:: tvm.var
.. autofunction:: tvm.size_var
.. autofunction:: tvm.const
......
......@@ -178,7 +178,7 @@ class RelayExpr : public BaseExpr {
class GlobalVar;
/*!
* \brief Global variable that leaves in the top-level module.
* \brief Global variable that lives in the top-level module.
*
* A GlobalVar only refers to function definitions.
 * This is used to enable recursive calls between functions.
......
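
A hedged Python sketch (editorial, not part of this diff) of the recursion
GlobalVar enables: the variable names a module slot, so a function body can
call it before the definition is installed.

    import tvm
    from tvm import relay

    fact = relay.GlobalVar("fact")
    n = relay.var("n", shape=(), dtype="int32")
    body = relay.If(relay.equal(n, relay.const(0, "int32")),
                    relay.const(1, "int32"),
                    n * fact(n - relay.const(1, "int32")))  # call via the GlobalVar
    mod = tvm.IRModule({fact: relay.Function([n], body)})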
......@@ -141,11 +141,12 @@ enum TypeKind : int {
};
/*!
* \brief Type parameter in the function.
* This can be viewed as template parameter in c++ template function.
* \brief Type parameter in functions.
*
 * A type variable can be viewed as a template parameter in a C++ template function.
*
 * For example, in the following pseudo code,
* the TypeVar of f is TypeVar(kind=kShapeVar, var=n).
* the TypeVar of f is TypeVar("n", kind=kShapeVar).
* This function can take in a Tensor with shape=(3, 3) and
 * return a Tensor with shape=(9,).
*
......
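
A hedged sketch of the same idea with the Python classes this PR adds under
tvm.ir (names are illustrative):

    from tvm.ir import TypeVar, TypeKind, FuncType

    n = TypeVar("n", TypeKind.ShapeVar)   # the "template parameter" above
    t = TypeVar("t", TypeKind.Type)
    id_type = FuncType([t], t, type_params=[t])  # forall t. (t) -> t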
......@@ -165,7 +165,7 @@ using TypeRelationFn =
const TypeReporter& reporter)>;
/*!
* \brief User defined type relation, is an input-output relation on types.
 * \brief User defined type relation, which is an input-output relation on types.
*
* TypeRelation is more generalized than type call as it allows inference
* of both inputs and outputs.
......
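
A hedged sketch of constructing such a relation with the Python classes added
later in this diff; the env-func name is assumed to be the broadcast relation
Relay registers, and is only illustrative:

    from tvm.ir import EnvFunc, IncompleteType, TypeRelation

    t0, t1, out = IncompleteType(), IncompleteType(), IncompleteType()
    bcast = EnvFunc.get("tvm.relay.type_relation.Broadcast")
    # num_inputs=2: the first two types are inputs, the third is inferred.
    rel = TypeRelation(bcast, [t0, t1, out], 2, None)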
......@@ -33,6 +33,12 @@ from .runtime.ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl
from .runtime.ndarray import vpi, rocm, opengl, ext_dev, micro_dev
from .runtime import ndarray as nd
# tvm.ir
from .ir import IRModule
from .ir import transform
from .ir import container
from . import ir
# others
from . import tensor
from . import arith
......@@ -41,10 +47,8 @@ from . import stmt
from . import make
from . import ir_pass
from . import codegen
from . import container
from . import schedule
from . import attrs
from . import ir_builder
from . import target
from . import generic
......
......@@ -87,6 +87,7 @@ class ObjectBase(object):
instead of creating a new Node.
"""
# assign handle first to avoid error raising
# pylint: disable=not-callable
self.handle = None
handle = __init_by_constructor__(fconstructor, args)
if not isinstance(handle, ObjectHandle):
......
......@@ -19,9 +19,11 @@
from numbers import Integral as _Integral
import tvm._ffi
import tvm.runtime._ffi_node_api
import tvm.ir
from tvm.runtime import convert, const, DataType
from tvm.ir import container as _container
from ._ffi.base import string_types, TVMError
from ._ffi.registry import register_func, get_global_func, extract_ext_funcs
......@@ -30,9 +32,7 @@ from . import make as _make
from . import expr as _expr
from . import tensor as _tensor
from . import schedule as _schedule
from . import container as _container
from . import tag as _tag
from . import json_compact
int8 = "int8"
int32 = "int32"
......@@ -71,66 +71,6 @@ def max_value(dtype):
"""
return _api_internal._max_value(dtype)
def get_env_func(name):
"""Get an EnvFunc by a global name.
Parameters
----------
name: str
The name of the global function.
Returns
-------
env_func : EnvFunc
The result env function.
Note
----
EnvFunc is a Object wrapper around
global function that can be serialized via its name.
This can be used to serialize function field in the language.
"""
return _api_internal._EnvFuncGet(name)
def load_json(json_str):
"""Load tvm object from json_str.
Parameters
----------
json_str : str
The json string
Returns
-------
node : Object
The loaded tvm node.
"""
try:
return tvm.runtime._ffi_node_api.LoadJSON(json_str)
except TVMError:
json_str = json_compact.upgrade_json(json_str)
return tvm.runtime._ffi_node_api.LoadJSON(json_str)
def save_json(node):
"""Save tvm object as json string.
Parameters
----------
node : Object
A TVM object to be saved.
Returns
-------
json_str : str
Saved json string.
"""
return tvm.runtime._ffi_node_api.SaveJSON(node)
def var(name="tindex", dtype=int32):
"""Create a new variable with specified name and dtype
......@@ -688,7 +628,7 @@ def _IterVar(dom, name, iter_type, thread_tag=''):
raise TypeError("need to be list of ranges")
dom = Range(dom[0], dom[1])
if not isinstance(dom, _container.Range):
if not isinstance(dom, tvm.ir.Range):
raise TypeError("dom need to be Range")
name = name if name else 'iter'
v = var(name)
......
......@@ -141,7 +141,7 @@ class BaseGraphTuner(object):
self._logger.propagate = False
# Generate workload and schedule dictionaries.
if isinstance(graph, relay.Module):
if isinstance(graph, tvm.IRModule):
graph = graph["main"]
if isinstance(graph, relay.expr.Function):
......
......@@ -20,6 +20,7 @@ import threading
import topi
import tvm
from tvm import relay, autotvm
from tvm.relay import transform
from tvm.relay.expr import Call, Function, TupleGetItem, Var, Constant, Tuple
......@@ -83,7 +84,7 @@ def expr2graph(expr, target_ops, node_dict, node_list):
def _infer_type(node):
"""A method to infer the type of a relay expression."""
mod = relay.Module.from_expr(node)
mod = tvm.IRModule.from_expr(node)
mod = transform.InferType()(mod)
entry = mod["main"]
return entry if isinstance(node, relay.Function) else entry.body
......@@ -136,7 +137,7 @@ def _expr2graph_impl(expr, target_ops, node_dict, node_list):
free_var = relay.Var("var_%d" % i, input_type)
params.append(free_var)
call = relay.Call(node.op, params, node.attrs)
mod = relay.Module.from_expr(relay.Function(params, call))
mod = tvm.IRModule.from_expr(relay.Function(params, call))
relay.backend.compile_engine.get().clear()
build_thread = threading.Thread(target=relay.build,
args=(mod,
......
......@@ -16,6 +16,7 @@
# under the License.
# pylint: disable=eval-used,invalid-name,too-many-arguments
"""Utility functions"""
import tvm
from tvm import relay
from tvm.relay import transform
......@@ -136,7 +137,7 @@ def bind_inputs(expr, input_shapes=None, input_dtypes="float32"):
rebind_dict[var] = updated_input_dict[var.name_hint]
updated_expr = relay.expr.bind(expr, rebind_dict)
mod = relay.Module.from_expr(updated_expr)
mod = tvm.IRModule.from_expr(updated_expr)
mod = transform.InferType()(mod)
entry = mod["main"]
return entry if isinstance(updated_expr, relay.Function) else entry.body
......@@ -63,7 +63,7 @@ def extract_from_program(mod, params, ops, target, target_host=None,
Parameters
----------
mod: relay.module.Module or relay.expr.Function
mod: tvm.IRModule or relay.expr.Function
The module or function to tune
params: dict of str to numpy array
The associated parameters of the program
......@@ -95,7 +95,7 @@ def extract_from_multiple_program(mods, params, ops, target, target_host=None,
Parameters
----------
mods: List[relay.module.Module] or List[relay.expr.Function]
mods: List[tvm.IRModule] or List[relay.expr.Function]
The list of modules or functions to tune
params: List of dict of str to numpy array
The associated parameters of the programs
......@@ -151,8 +151,8 @@ def extract_from_multiple_program(mods, params, ops, target, target_host=None,
for mod, param in zip(mods, params):
if isinstance(mod, relay.expr.Function):
mod = relay.Module.from_expr(mod)
assert isinstance(mod, relay.module.Module), \
mod = tvm.IRModule.from_expr(mod)
assert isinstance(mod, tvm.IRModule), \
"only support relay Module or Function to be tuned"
relay.backend.compile_engine.get().clear()
# wrap build call in thread to avoid multiprocessing problems
......
......@@ -24,6 +24,7 @@ import tvm._ffi
import tvm.runtime
from tvm.runtime import Object, ndarray
from tvm.ir import container
from . import api
from . import _api_internal
from . import tensor
......@@ -31,10 +32,11 @@ from . import schedule
from . import expr
from . import ir_pass
from . import stmt as _stmt
from . import container
from . import codegen
from . import target as _target
from . import make
from .stmt import LoweredFunc
class DumpIR(object):
"""
......@@ -58,16 +60,16 @@ class DumpIR(object):
def dump(*args, **kwargs):
"""dump function"""
retv = func(*args, **kwargs)
if not isinstance(retv, (_stmt.Stmt, container.LoweredFunc, container.Array)):
if not isinstance(retv, (_stmt.Stmt, LoweredFunc, container.Array)):
return retv
fname = func.func_name if hasattr(func, 'func_name') else func.__name__
pname = str(self._pass_id) + "_" + fname + "_ir.cc"
with open(pname, "a") as f:
out = retv.body if isinstance(retv, container.LoweredFunc) else retv
out = retv.body if isinstance(retv, LoweredFunc) else retv
f.write(str(out))
if isinstance(retv, container.Array):
for x in retv:
out = x.body if isinstance(x, container.LoweredFunc) else x
out = x.body if isinstance(x, LoweredFunc) else x
f.write("---------%s\n%s\n-----------\n"%(x.name, str(out)))
self._pass_id += 1
return retv
......@@ -459,7 +461,7 @@ def _build_for_device(flist, target, target_host):
raise ValueError(
"Direct host side access to device memory is detected in %s. "
"Did you forget to bind?" % func.name)
if func.func_type == container.LoweredFunc.MixedFunc:
if func.func_type == LoweredFunc.MixedFunc:
if current_build_config().detect_global_barrier:
func = ir_pass.ThreadSync(func, "global")
func = ir_pass.ThreadSync(func, "shared")
......@@ -471,9 +473,9 @@ def _build_for_device(flist, target, target_host):
fhost.append(fsplits[0])
for x in fsplits[1:]:
fdevice.append(x)
elif func.func_type == container.LoweredFunc.HostFunc:
elif func.func_type == LoweredFunc.HostFunc:
fhost.append(func)
elif func.func_type == container.LoweredFunc.DeviceFunc:
elif func.func_type == LoweredFunc.DeviceFunc:
fdevice.append(func)
else:
raise ValueError("unknown function type %d" % func.func_type)
......@@ -586,9 +588,9 @@ def build(inputs,
flist = lower(inputs, args,
name=name,
binds=binds)
if isinstance(flist, container.LoweredFunc):
if isinstance(flist, LoweredFunc):
flist = [flist]
elif isinstance(inputs, container.LoweredFunc):
elif isinstance(inputs, LoweredFunc):
if args:
raise ValueError("args must be done when build from LoweredFunc.")
flist = [inputs]
......@@ -612,7 +614,7 @@ def build(inputs,
"_target.Target when inputs is dict.")
fname_set = set()
for x in flist:
if not isinstance(x, container.LoweredFunc):
if not isinstance(x, LoweredFunc):
raise ValueError("inputs must be Schedule, LoweredFunc, list "
"of LoweredFunc, or dict of str to list of "
"LoweredFunc.")
......
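
A hedged sketch of the inputs these branches normalize, written against the
pre-refactor API still used in this file:

    import tvm

    A = tvm.placeholder((16,), name="A")
    B = tvm.compute((16,), lambda i: A[i] + 1.0, name="B")
    s = tvm.create_schedule(B.op)

    f = tvm.lower(s, [A, B], name="add_one")  # a single LoweredFunc
    m = tvm.build(f, target="llvm")           # wrapped into flist = [f] above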
......@@ -38,7 +38,7 @@ class CSRNDArray(object):
The corresponding dense numpy array,
or a tuple for constructing a sparse matrix directly.
ctx: tvm.TVMContext
The corresponding context.
shape : tuple of int
......
......@@ -16,12 +16,11 @@
# under the License.
"""Intrinsics of TVM-Python Hybrid Script for Python compilation time
semantic support."""
from tvm.ir.container import Array
from .. import api as _api
from .. import expr as _expr
from .. import make as _make
from .. import target as _tgt
from ..container import Array
from .. import ir_pass
from ..stmt import For
from .util import _internal_assert
......
......@@ -24,6 +24,7 @@ import types
import numbers
from enum import Enum
from tvm.ir.container import Array
from .util import _internal_assert
from . import calls
......@@ -32,7 +33,6 @@ from .preprocessor import determine_variable_usage
from ..api import all as _all
from ..api import any as _any
from ..container import Array
from ..tensor import Tensor, Operation
from .. import _api_internal as _tvm_internal
from .. import expr as _expr
......
......@@ -21,13 +21,14 @@ import inspect
import logging
import sys
import numpy
from tvm.ir.container import Array
from .. import api as _api
from .. import make as _make
from .. import expr as _expr
from .. import stmt as _stmt
from .._ffi.base import numeric_types
from ..tensor import Tensor
from ..container import Array
#pylint: disable=invalid-name
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-import
"""Common data structures across all IR variants."""
from .base import SourceName, Span, Node, EnvFunc, load_json, save_json
from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, BaseFunc, Range
from .type import Type, TypeKind, TypeVar, GlobalTypeVar, TupleType
from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType
from .type_relation import TypeCall, TypeRelation
from .tensor_type import TensorType
from .adt import Constructor, TypeData
from .module import IRModule
from .attrs import Attrs
from .container import Array, Map
from . import transform
......@@ -14,8 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
"""The interface to the Module exposed from C++."""
"""FFI APIs for tvm.ir"""
import tvm._ffi
tvm._ffi._init_api("relay._module", __name__)
tvm._ffi._init_api("ir", __name__)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
......@@ -14,9 +15,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI APIs for tvm.transform"""
import tvm._ffi
from typing import Union, Tuple, Dict, List
from relay.ir import GlobalId, OperatorId, Item, Object, Span, FileId
from relay.ir import ShapeExtension, Operator, Defn
class Module(Object): ...
tvm._ffi._init_api("transform", __name__)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Algebraic data type definitions."""
import tvm._ffi
from .type import Type
from .expr import RelayExpr
from . import _ffi_api
@tvm._ffi.register_object("relay.Constructor")
class Constructor(RelayExpr):
"""Relay ADT constructor.
Parameters
----------
name_hint : str
Name of constructor (only a hint).
inputs : List[Type]
Input types.
belong_to : GlobalTypeVar
Denotes which ADT the constructor belongs to.
"""
def __init__(self, name_hint, inputs, belong_to):
self.__init_handle_by_constructor__(
_ffi_api.Constructor, name_hint, inputs, belong_to)
def __call__(self, *args):
"""Call the constructor.
Parameters
----------
args: List[RelayExpr]
The arguments to the constructor.
Returns
-------
call: RelayExpr
A call to the constructor.
"""
# pylint: disable=import-outside-toplevel
from tvm import relay
return relay.Call(self, args)
@tvm._ffi.register_object("relay.TypeData")
class TypeData(Type):
"""Stores the definition for an Algebraic Data Type (ADT) in Relay.
Note that ADT definitions are treated as type-level functions because
the type parameters need to be given for an instance of the ADT. Thus,
any global type var that is an ADT header needs to be wrapped in a
type call that passes in the type params.
Parameters
----------
header: GlobalTypeVar
The name of the ADT.
ADTs with the same constructors but different names are
treated as different types.
type_vars: List[TypeVar]
Type variables that appear in constructors.
constructors: List[Constructor]
The constructors for the ADT.
"""
def __init__(self, header, type_vars, constructors):
self.__init_handle_by_constructor__(
_ffi_api.TypeData, header, type_vars, constructors)
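
A hedged sketch of declaring a small ADT with the two classes above (names are
illustrative, loosely mirroring Relay's built-in Option type):

    import tvm
    from tvm.ir import GlobalTypeVar, TypeVar, Constructor, TypeData

    opt = GlobalTypeVar("my_option")
    a = TypeVar("a")
    my_none = Constructor("my_none", [], opt)
    my_some = Constructor("my_some", [a], opt)

    mod = tvm.IRModule()
    mod[opt] = TypeData(opt, [a], [my_none, my_some])  # register the type definition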
......@@ -14,11 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
""" TVM Attribute module, which is mainly used for defining attributes of operators"""
""" TVM Attribute module, which is mainly used for defining attributes of operators."""
import tvm._ffi
from tvm.runtime import Object
from . import _api_internal
from . import _ffi_api
@tvm._ffi.register_object
......@@ -36,7 +36,7 @@ class Attrs(Object):
infos: list of AttrFieldInfo
List of field information
"""
return _api_internal._AttrsListFieldInfo(self)
return _ffi_api.AttrsListFieldInfo(self)
def keys(self):
"""Get list of names in the attribute.
......@@ -91,6 +91,3 @@ class Attrs(Object):
def __getitem__(self, item):
return self.__getattr__(item)
tvm._ffi._init_api("tvm.attrs")
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Common base structures."""
import tvm._ffi
import tvm.error
import tvm.runtime._ffi_node_api
from tvm.runtime import Object
from . import _ffi_api
from . import json_compact
class Node(Object):
"""Base class of all IR Nodes, implements astext function."""
def astext(self, show_meta_data=True, annotate=None):
"""Get the text format of the expression.
Parameters
----------
show_meta_data : bool
Whether to include meta data section in the text
if there is meta data.
annotate: Optional[Object->str]
Optionally annotate function to provide additional
information in the comment block.
Note
----
The meta data section is necessary to fully parse the text format.
However, it can contain dumps that are big (e.g. constant weights),
so it can be helpful to skip printing the meta data section.
Returns
-------
text : str
The text format of the expression.
"""
return _ffi_api.AsText(self, show_meta_data, annotate)
def __str__(self):
return self.astext(show_meta_data=False)
@tvm._ffi.register_object("relay.SourceName")
class SourceName(Object):
"""A identifier for a source location.
Parameters
----------
name : str
The name of the source.
"""
def __init__(self, name):
self.__init_handle_by_constructor__(_ffi_api.SourceName, name)
@tvm._ffi.register_object("relay.Span")
class Span(Object):
"""Specifies a location in a source program.
Parameters
----------
source : SourceName
The source name.
lineno : int
The line number.
col_offset : int
The column offset of the location.
"""
def __init__(self, source, lineno, col_offset):
self.__init_handle_by_constructor__(
_ffi_api.Span, source, lineno, col_offset)
@tvm._ffi.register_object
class EnvFunc(Object):
"""Environment function.
This is a global function object that can be serialized by its name.
"""
def __call__(self, *args):
return _ffi_api.EnvFuncCall(self, *args)
@property
def func(self):
return _ffi_api.EnvFuncGetPackedFunc(self)
@staticmethod
def get(name):
"""Get a static env function
Parameters
----------
name : str
The name of the function.
"""
return _ffi_api.EnvFuncGet(name)
def load_json(json_str):
"""Load tvm object from json_str.
Parameters
----------
json_str : str
The json string
Returns
-------
node : Object
The loaded tvm node.
"""
try:
return tvm.runtime._ffi_node_api.LoadJSON(json_str)
except tvm.error.TVMError:
json_str = json_compact.upgrade_json(json_str)
return tvm.runtime._ffi_node_api.LoadJSON(json_str)
def save_json(node):
"""Save tvm object as json string.
Parameters
----------
node : Object
A TVM object to be saved.
Returns
-------
json_str : str
Saved json string.
"""
return tvm.runtime._ffi_node_api.SaveJSON(node)
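
A quick round-trip sketch for the two helpers above (any serializable IR node
works; a TensorType is used for concreteness):

    import tvm

    ty = tvm.ir.TensorType([3, 3], "float32")
    json_str = tvm.ir.save_json(ty)
    ty2 = tvm.ir.load_json(json_str)
    assert ty2 == ty  # structural equality via Type.__eq__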
......@@ -14,13 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Container data structures used in TVM DSL."""
"""Additional container data structures used across IR variants."""
import tvm._ffi
from tvm.runtime import Object
from tvm.runtime.container import getitem_helper
from tvm.runtime import _ffi_node_api
from . import _api_internal
@tvm._ffi.register_object
......@@ -41,20 +40,6 @@ class Array(Object):
@tvm._ffi.register_object
class EnvFunc(Object):
"""Environment function.
This is a global function object that can be serialized by its name.
"""
def __call__(self, *args):
return _api_internal._EnvFuncCall(self, *args)
@property
def func(self):
return _api_internal._EnvFuncGetPackedFunc(self)
@tvm._ffi.register_object
class Map(Object):
"""Map container of TVM.
......@@ -87,20 +72,3 @@ class StrMap(Map):
"""Get the items from the map"""
akvs = _ffi_node_api.MapItems(self)
return [(akvs[i].value, akvs[i+1]) for i in range(0, len(akvs), 2)]
@tvm._ffi.register_object
class Range(Object):
"""Represent a range in TVM.
You do not need to create a Range explicitly.
Python lists and tuples will be converted automatically to a Range in API functions.
"""
@tvm._ffi.register_object
class LoweredFunc(Object):
"""Represent a LoweredFunc in TVM."""
MixedFunc = 0
HostFunc = 1
DeviceFunc = 2
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Common expressions data structures in the IR."""
import tvm._ffi
from .base import Node
from . import _ffi_api
class BaseExpr(Node):
"""Base class of all the expressions."""
class PrimExpr(BaseExpr):
"""Base class of all primitive expressions.
PrimExpr is used in the low-level code
optimizations and integer analysis.
"""
class RelayExpr(BaseExpr):
"""Base class of all non-primitive expressions."""
@property
def checked_type(self):
"""Get the checked type of tvm.relay.Expr.
Returns
-------
checked_type : tvm.relay.Type
The checked type.
"""
ret = self._checked_type_
if ret is None:
raise ValueError("The type checker has not populated"
" the checked_type for this node")
return ret
class BaseFunc(RelayExpr):
"""Base class of all functions."""
@tvm._ffi.register_object("relay.GlobalVar")
class GlobalVar(RelayExpr):
"""A global variable in the IR.
GlobalVar is used to refer to the global functions
stored in the IRModule.
Parameters
----------
name_hint: str
The name of the variable.
"""
def __init__(self, name_hint):
self.__init_handle_by_constructor__(_ffi_api.GlobalVar, name_hint)
def __call__(self, *args):
"""Call the global variable.
Parameters
----------
args: List[RelayExpr]
The arguments to the call.
Returns
-------
call: BaseExpr
A call taking the variable as a function.
"""
# pylint: disable=import-outside-toplevel
if all(isinstance(x, RelayExpr) for x in args):
from tvm import relay
return relay.Call(self, args)
arg_types = [type(x) for x in args]
raise RuntimeError(
"Do not know how to handle GlobalVar.__call__ for types {}".format(arg_types))
@tvm._ffi.register_object
class Range(Node):
"""Represent a range in TVM.
You do not need to create a Range explicitly.
Python lists and tuples will be converted automatically to a Range in API functions.
"""
......@@ -14,36 +14,26 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import
"""A global module storing everything needed to interpret or compile a Relay program."""
import os
from .base import register_relay_node, RelayNode
from .. import register_func
from .._ffi import base as _base
from . import _make
from . import _module
from . import expr as _expr
from . import ty as _ty
"""IRModule that holds the functions and type definitions."""
from tvm._ffi.base import string_types
import tvm._ffi
__STD_PATH__ = os.path.join(os.path.dirname(os.path.realpath(__file__)), "std")
from .base import Node
from . import expr as _expr
from . import type as _ty
from . import _ffi_api
@register_func("tvm.relay.std_path")
def _std_path():
global __STD_PATH__
return __STD_PATH__
@register_relay_node
class Module(RelayNode):
"""The global Relay module containing collection of functions.
@tvm._ffi.register_object("relay.Module")
class IRModule(Node):
"""IRModule that holds functions and type definitions.
Each global function is identified by a unique tvm.relay.GlobalVar.
tvm.relay.GlobalVar and Module are necessary to enable recursion in
functions while avoiding cyclic references.
IRModule is the basic unit for all IR transformations across the stack.
Parameters
----------
functions: Optional[dict].
Map of global var to Function
Map of global var to BaseFunc
"""
def __init__(self, functions=None, type_definitions=None):
if functions is None:
......@@ -51,7 +41,7 @@ class Module(RelayNode):
elif isinstance(functions, dict):
mapped_funcs = {}
for k, v in functions.items():
if isinstance(k, _base.string_types):
if isinstance(k, string_types):
k = _expr.GlobalVar(k)
if not isinstance(k, _expr.GlobalVar):
raise TypeError("Expect functions to be Dict[GlobalVar, Function]")
......@@ -62,13 +52,13 @@ class Module(RelayNode):
elif isinstance(type_definitions, dict):
mapped_type_defs = {}
for k, v in type_definitions.items():
if isinstance(k, _base.string_types):
if isinstance(k, string_types):
k = _ty.GlobalTypeVar(k)
if not isinstance(k, _ty.GlobalTypeVar):
raise TypeError("Expect type_definitions to be Dict[GlobalTypeVar, Type]")
mapped_type_defs[k] = v
type_definitions = mapped_type_defs
self.__init_handle_by_constructor__(_make.Module, functions, type_definitions)
self.__init_handle_by_constructor__(_ffi_api.IRModule, functions, type_definitions)
def __setitem__(self, var, val):
......@@ -85,18 +75,18 @@ class Module(RelayNode):
return self._add(var, val)
def _add(self, var, val, update=False):
if isinstance(val, _expr.Expr):
if isinstance(var, _base.string_types):
if _module.Module_ContainGlobalVar(self, var):
var = _module.Module_GetGlobalVar(self, var)
if isinstance(val, _expr.RelayExpr):
if isinstance(var, string_types):
if _ffi_api.Module_ContainGlobalVar(self, var):
var = _ffi_api.Module_GetGlobalVar(self, var)
else:
var = _expr.GlobalVar(var)
_module.Module_Add(self, var, val, update)
_ffi_api.Module_Add(self, var, val, update)
else:
assert isinstance(val, _ty.Type)
if isinstance(var, _base.string_types):
if isinstance(var, string_types):
var = _ty.GlobalTypeVar(var)
_module.Module_AddDef(self, var, val, update)
_ffi_api.Module_AddDef(self, var, val, update)
def __getitem__(self, var):
"""Lookup a global definition by name or by variable.
......@@ -111,12 +101,11 @@ class Module(RelayNode):
val: Union[Function, Type]
The definition referenced by :code:`var` (either a function or type).
"""
if isinstance(var, _base.string_types):
return _module.Module_Lookup_str(self, var)
elif isinstance(var, _expr.GlobalVar):
return _module.Module_Lookup(self, var)
else:
return _module.Module_LookupDef(self, var)
if isinstance(var, string_types):
return _ffi_api.Module_Lookup_str(self, var)
if isinstance(var, _expr.GlobalVar):
return _ffi_api.Module_Lookup(self, var)
return _ffi_api.Module_LookupDef(self, var)
def update(self, other):
"""Insert functions in another Module to current one.
......@@ -128,7 +117,7 @@ class Module(RelayNode):
"""
if isinstance(other, dict):
other = Module(other)
return _module.Module_Update(self, other)
return _ffi_api.Module_Update(self, other)
def get_global_var(self, name):
"""Get a global variable in the function by name.
......@@ -145,9 +134,9 @@ class Module(RelayNode):
Raises
------
tvm.TVMError if we cannot find corresponding global var.
tvm.error.TVMError if we cannot find corresponding global var.
"""
return _module.Module_GetGlobalVar(self, name)
return _ffi_api.Module_GetGlobalVar(self, name)
def get_global_vars(self):
"""Collect all global vars defined in this module.
......@@ -157,7 +146,7 @@ class Module(RelayNode):
global_vars: tvm.Array[GlobalVar]
An array of global vars.
"""
return _module.Module_GetGlobalVars(self)
return _ffi_api.Module_GetGlobalVars(self)
def get_global_type_vars(self):
"""Collect all global type vars defined in this module.
......@@ -167,7 +156,7 @@ class Module(RelayNode):
global_type_vars: tvm.Array[GlobalTypeVar]
An array of global type vars.
"""
return _module.Module_GetGlobalTypeVars(self)
return _ffi_api.Module_GetGlobalTypeVars(self)
def get_global_type_var(self, name):
"""Get a global type variable in the function by name.
......@@ -184,9 +173,9 @@ class Module(RelayNode):
Raises
------
tvm.TVMError if we cannot find corresponding global type var.
tvm.error.TVMError if we cannot find corresponding global type var.
"""
return _module.Module_GetGlobalTypeVar(self, name)
return _ffi_api.Module_GetGlobalTypeVar(self, name)
def get_constructor(self, tag):
"""Look up an ADT constructor by tag.
......@@ -203,9 +192,9 @@ class Module(RelayNode):
Raises
------
tvm.TVMError if the corresponding constructor cannot be found.
tvm.error.TVMError if the corresponding constructor cannot be found.
"""
return _module.Module_LookupTag(self, tag)
return _ffi_api.Module_LookupTag(self, tag)
@staticmethod
def from_expr(expr, functions=None, type_defs=None):
......@@ -213,14 +202,15 @@ class Module(RelayNode):
Parameters
----------
expr: Expr
expr: RelayExpr
The starting expression
functions: Optional[dict]
Map of global vars to function definitions
type_defs: Optional[dict]
Map of global type vars to type definitions
Returns
-------
mod: Module
......@@ -230,10 +220,10 @@ class Module(RelayNode):
"""
funcs = functions if functions is not None else {}
defs = type_defs if type_defs is not None else {}
return _module.Module_FromExpr(expr, funcs, defs)
return _ffi_api.Module_FromExpr(expr, funcs, defs)
def _import(self, file_to_import):
return _module.Module_Import(self, file_to_import)
return _ffi_api.Module_Import(self, file_to_import)
def import_from_std(self, file_to_import):
return _module.Module_ImportFromStd(self, file_to_import)
return _ffi_api.Module_ImportFromStd(self, file_to_import)
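
A hedged end-to-end sketch of the migrated module API:

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(3,))
    mod = tvm.IRModule.from_expr(relay.Function([x], x + x))

    y = relay.var("y", shape=(3,))
    mod["double"] = relay.Function([y], y + y)  # str key promoted to GlobalVar in _add
    main = mod["main"]                          # lookup by name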
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Type relation and function for type checking."""
import tvm._ffi
from .type import Type
from . import _ffi_api
@tvm._ffi.register_object("relay.TensorType")
class TensorType(Type):
"""A concrete TensorType in Relay.
This is the type assigned to tensors with a known dtype and shape.
For example, a tensor with dtype `float32` and shape `(5, 5)`.
Parameters
----------
shape : List[tvm.ir.PrimExpr]
The shape of the Tensor
dtype : Optional[str]
The content data type.
"""
def __init__(self, shape, dtype="float32"):
self.__init_handle_by_constructor__(
_ffi_api.TensorType, shape, dtype)
@property
def concrete_shape(self):
"""Get shape of the type as concrete tuple of int.
Returns
-------
shape : List[int]
The concrete shape of the Type.
Raises
------
TypeError : If the shape is symbolic
"""
return tuple(int(x) for x in self.shape)
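
A hedged sketch of the concrete/symbolic split behind this property:

    import tvm

    t = tvm.ir.TensorType([5, 5], "float32")
    assert t.concrete_shape == (5, 5)

    n = tvm.var("n")
    sym = tvm.ir.TensorType([n, 5], "float32")
    # sym.concrete_shape raises TypeError: the first dimension is symbolic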
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unified type system in the project."""
from enum import IntEnum
import tvm._ffi
from .base import Node
from . import _ffi_api
class Type(Node):
"""The base class of all types."""
def __eq__(self, other):
"""Compare two types for structural equivalence."""
return bool(_ffi_api.type_alpha_equal(self, other))
def __ne__(self, other):
return not self.__eq__(other)
def same_as(self, other):
"""Compares two Relay types by referential equality."""
return super().__eq__(other)
class TypeKind(IntEnum):
"""Possible kinds of TypeVars."""
Type = 0
ShapeVar = 1
BaseType = 2
Constraint = 4
AdtHandle = 5
TypeData = 6
@tvm._ffi.register_object("relay.TypeVar")
class TypeVar(Type):
"""Type parameter in functions.
A type variable represents a type placeholder which will
be filled in later on. This allows the user to write
functions which are generic over types.
Parameters
----------
name_hint: str
The name of the type variable. This name only acts as a hint, and
is not used for equality.
kind : Optional[TypeKind]
The kind of the type parameter.
"""
def __init__(self, name_hint, kind=TypeKind.Type):
self.__init_handle_by_constructor__(
_ffi_api.TypeVar, name_hint, kind)
def __call__(self, *args):
"""Create a type call from this type.
Parameters
----------
args: List[Type]
The arguments to the type call.
Returns
-------
call: Type
The result type call.
"""
# pylint: disable=import-outside-toplevel
from .type_relation import TypeCall
return TypeCall(self, args)
@tvm._ffi.register_object("relay.GlobalTypeVar")
class GlobalTypeVar(Type):
"""A global type variable that is used for defining new types or type aliases.
Parameters
----------
name_hint: str
The name of the type variable. This name only acts as a hint, and
is not used for equality.
kind : Optional[TypeKind]
The kind of the type parameter.
"""
def __init__(self, name_hint, kind=TypeKind.AdtHandle):
self.__init_handle_by_constructor__(
_ffi_api.GlobalTypeVar, name_hint, kind)
def __call__(self, *args):
"""Create a type call from this type.
Parameters
----------
args: List[Type]
The arguments to the type call.
Returns
-------
call: Type
The result type call.
"""
# pylint: disable=import-outside-toplevel
from .type_relation import TypeCall
return TypeCall(self, args)
@tvm._ffi.register_object("relay.TupleType")
class TupleType(Type):
"""The type of tuple values.
Parameters
----------
fields : List[Type]
The fields in the tuple
"""
def __init__(self, fields):
self.__init_handle_by_constructor__(
_ffi_api.TupleType, fields)
@tvm._ffi.register_object("relay.TypeConstraint")
class TypeConstraint(Type):
"""Abstract class representing a type constraint."""
@tvm._ffi.register_object("relay.FuncType")
class FuncType(Type):
"""Function type.
A function type consists of a list of type parameters to enable
the definition of generic functions,
a set of type constraints which we omit for the time being,
a sequence of argument types, and a return type.
We can informally write them as:
`forall (type_params), (arg_types) -> ret_type where type_constraints`
Parameters
----------
arg_types : List[tvm.relay.Type]
The argument types
ret_type : tvm.relay.Type
The return type.
type_params : Optional[List[tvm.relay.TypeVar]]
The type parameters
type_constraints : Optional[List[tvm.relay.TypeConstraint]]
The type constraints.
"""
def __init__(self,
arg_types,
ret_type,
type_params=None,
type_constraints=None):
if type_params is None:
type_params = []
if type_constraints is None:
type_constraints = []
self.__init_handle_by_constructor__(
_ffi_api.FuncType, arg_types, ret_type, type_params, type_constraints)
@tvm._ffi.register_object("relay.IncompleteType")
class IncompleteType(Type):
"""Incomplete type during type inference.
kind : Optional[TypeKind]
The kind of the incomplete type.
"""
def __init__(self, kind=TypeKind.Type):
self.__init_handle_by_constructor__(
_ffi_api.IncompleteType, kind)
@tvm._ffi.register_object("relay.RefType")
class RelayRefType(Type):
"""Reference Type in relay.
Parameters
----------
value: Type
The value type.
"""
def __init__(self, value):
self.__init_handle_by_constructor__(_ffi_api.RelayRefType, value)
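
A hedged sketch of the structural/referential split defined at the top of this
file:

    import tvm

    a = tvm.ir.TupleType([])
    b = tvm.ir.TupleType([])
    assert a == b            # __eq__: alpha-equal structure
    assert not a.same_as(b)  # same_as: distinct objects
    assert a.same_as(a)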
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Type relation and function for type checking."""
import tvm._ffi
from .type import Type, TypeConstraint
from . import _ffi_api
class TypeCall(Type):
"""Type function application.
Parameters
----------
func: tvm.ir.Type
The function.
args: List[tvm.ir.Type]
The arguments.
Returns
-------
type_call: TypeCall
The type function application.
"""
def __init__(self, func, args):
self.__init_handle_by_constructor__(_ffi_api.TypeCall, func, args)
@tvm._ffi.register_object("relay.TypeRelation")
class TypeRelation(TypeConstraint):
"""User defined type relation, it is an input-output relation on types.
TypeRelation is more generalized than TypeCall as it allows inference
of both inputs and outputs.
Parameters
----------
func : EnvFunc
User defined relation function.
args : [tvm.ir.Type]
List of types to the func.
num_inputs : int
Number of input arguments in args;
this acts as a hint for type inference.
attrs : Attrs
The attribute attached to the relation information
Returns
-------
type_relation : tvm.ir.TypeRelation
The type relation.
"""
def __init__(self, func, args, num_inputs, attrs):
self.__init_handle_by_constructor__(
_ffi_api.TypeRelation, func, args, num_inputs, attrs)
......@@ -15,16 +15,15 @@
# specific language governing permissions and limitations
# under the License.
"""Developer API of IR node builder make function."""
from tvm._ffi.base import string_types
from tvm.runtime import ObjectGeneric, DataType
from ._ffi.base import string_types
from tvm.ir import container as _container
from . import api as _api
from . import stmt as _stmt
from . import expr as _expr
from . import make as _make
from . import ir_pass as _pass
from . import container as _container
from .expr import Call as _Call
class WithScope(object):
......
......@@ -16,7 +16,6 @@
# under the License.
# pylint: disable=wildcard-import, redefined-builtin, invalid-name
"""The Relay IR namespace containing the IR definition and compiler."""
from __future__ import absolute_import
import os
from sys import setrecursionlimit
from ..api import register_func
......@@ -25,7 +24,6 @@ from . import ty
from . import expr
from . import type_functor
from . import expr_functor
from . import module
from . import adt
from . import analysis
from . import transform
......@@ -66,14 +64,11 @@ setrecursionlimit(10000)
# Span
Span = base.Span
# Env
Module = module.Module
# Type
Type = ty.Type
TupleType = ty.TupleType
TensorType = ty.TensorType
Kind = ty.Kind
TypeKind = ty.TypeKind
TypeVar = ty.TypeVar
ShapeVar = ty.ShapeVar
TypeConstraint = ty.TypeConstraint
......@@ -87,7 +82,7 @@ TypeCall = ty.TypeCall
Any = ty.Any
# Expr
Expr = expr.Expr
Expr = expr.RelayExpr
Constant = expr.Constant
Tuple = expr.Tuple
Var = expr.Var
......
......@@ -37,8 +37,9 @@ except ImportError:
return deque.__new__(cls, *args, **kwds)
import tvm
import tvm.ir._ffi_api
from tvm.ir import IRModule
from . import module
from .base import Span, SourceName
from . import adt
from . import expr
......@@ -190,7 +191,7 @@ def spanify(f):
sp = Span(sn, line, col)
if isinstance(ast, tvm.relay.expr.TupleWrapper):
ast = ast.astuple()
ast.set_span(sp)
tvm.ir._ffi_api.NodeSetSpan(ast, sp)
return ast
return _wrapper
......@@ -201,7 +202,7 @@ class ParseTreeToRelayIR(RelayVisitor):
def __init__(self, source_name: str) -> None:
self.source_name = source_name
self.module = module.Module({}) # type: module.Module
self.module = IRModule({}) # type: IRModule
# Adding an empty scope allows naked lets without pain.
self.var_scopes = deque([deque()]) # type: Scopes[expr.Var]
......@@ -243,7 +244,7 @@ class ParseTreeToRelayIR(RelayVisitor):
"""Pop off the current TypeVar scope and return it."""
return self.type_var_scopes.popleft()
def mk_typ(self, name: str, kind: ty.Kind) -> ty.TypeVar:
def mk_typ(self, name: str, kind: ty.TypeKind) -> ty.TypeVar:
"""Create a new TypeVar and add it to the TypeVar scope."""
typ = ty.TypeVar(name, kind)
self.type_var_scopes[0].append((name, typ))
......@@ -274,7 +275,7 @@ class ParseTreeToRelayIR(RelayVisitor):
if isinstance(e, adt.Constructor):
return "`{0}` ADT constructor".format(e.belong_to.name_hint)
if isinstance(e, ty.GlobalTypeVar):
if e.kind == ty.Kind.AdtHandle:
if e.kind == ty.TypeKind.AdtHandle:
return "ADT definition"
return "function definition"
......@@ -352,12 +353,12 @@ class ParseTreeToRelayIR(RelayVisitor):
return self.visit(ctx)
def visitProg(self, ctx: RelayParser.ProgContext) -> Union[expr.Expr, module.Module]:
def visitProg(self, ctx: RelayParser.ProgContext) -> Union[expr.Expr, IRModule]:
self.meta = None
if ctx.METADATA():
header, data = str(ctx.METADATA()).split("\n", 1)
assert header == "METADATA:"
self.meta = tvm.load_json(data)
self.meta = tvm.ir.load_json(data)
if ctx.defn():
self.visit_list(ctx.defn())
return self.module
......@@ -492,7 +493,7 @@ class ParseTreeToRelayIR(RelayVisitor):
assert type_params
for ty_param in type_params:
name = ty_param.getText()
self.mk_typ(name, ty.Kind.Type)
self.mk_typ(name, ty.TypeKind.Type)
var_list, attr_list = self.visit(ctx.argList())
if var_list is None:
......@@ -528,13 +529,13 @@ class ParseTreeToRelayIR(RelayVisitor):
ctx: Union[RelayParser.ExternAdtDefnContext, RelayParser.AdtDefnContext]):
"""Handles parsing of the name and type params of an ADT definition."""
adt_name = ctx.generalIdent().getText()
adt_var = self.mk_global_typ_var(adt_name, ty.Kind.AdtHandle)
adt_var = self.mk_global_typ_var(adt_name, ty.TypeKind.AdtHandle)
# parse type params
type_params = ctx.typeParamList()
if type_params is None:
type_params = []
else:
type_params = [self.mk_typ(type_ident.getText(), ty.Kind.Type)
type_params = [self.mk_typ(type_ident.getText(), ty.TypeKind.Type)
for type_ident in type_params.typeExpr()]
return adt_var, type_params
......@@ -746,7 +747,7 @@ class StrictErrorListener(ErrorListener):
def reportContextSensitivity(self, recognizer, dfa, startIndex, stopIndex, prediction, configs):
raise Exception("Context Sensitivity in:\n" + self.text)
def fromtext(data: str, source_name: str = None) -> Union[expr.Expr, module.Module]:
def fromtext(data: str, source_name: str = None) -> Union[expr.Expr, IRModule]:
"""Parse a Relay program."""
if data == "":
raise ParseError("cannot parse the empty string.")
......
......@@ -14,12 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-import
"""Algebraic data types in Relay."""
from tvm.ir import Constructor, TypeData
from .base import RelayNode, register_relay_node, Object
from . import _make
from .ty import Type
from .expr import Expr, Call
from .expr import ExprWithOp, RelayExpr, Call
class Pattern(RelayNode):
......@@ -113,77 +115,6 @@ class PatternTuple(Pattern):
@register_relay_node
class Constructor(Expr):
"""Relay ADT constructor."""
def __init__(self, name_hint, inputs, belong_to):
"""Defines an ADT constructor.
Parameters
----------
name_hint : str
Name of constructor (only a hint).
inputs : List[Type]
Input types.
belong_to : tvm.relay.GlobalTypeVar
Denotes which ADT the constructor belongs to.
Returns
-------
con: Constructor
A constructor.
"""
self.__init_handle_by_constructor__(_make.Constructor, name_hint, inputs, belong_to)
def __call__(self, *args):
"""Call the constructor.
Parameters
----------
args: List[relay.Expr]
The arguments to the constructor.
Returns
-------
call: relay.Call
A call to the constructor.
"""
return Call(self, args)
@register_relay_node
class TypeData(Type):
"""Stores the definition for an Algebraic Data Type (ADT) in Relay.
Note that ADT definitions are treated as type-level functions because
the type parameters need to be given for an instance of the ADT. Thus,
any global type var that is an ADT header needs to be wrapped in a
type call that passes in the type params.
"""
def __init__(self, header, type_vars, constructors):
"""Defines a TypeData object.
Parameters
----------
header: tvm.relay.GlobalTypeVar
The name of the ADT.
ADTs with the same constructors but different names are
treated as different types.
type_vars: List[TypeVar]
Type variables that appear in constructors.
constructors: List[tvm.relay.Constructor]
The constructors for the ADT.
Returns
-------
type_data: TypeData
The adt declaration.
"""
self.__init_handle_by_constructor__(_make.TypeData, header, type_vars, constructors)
@register_relay_node
class Clause(Object):
"""Clause for pattern matching in Relay."""
......@@ -206,7 +137,7 @@ class Clause(Object):
@register_relay_node
class Match(Expr):
class Match(ExprWithOp):
"""Pattern matching expression in Relay."""
def __init__(self, data, clauses, complete=True):
......
......@@ -20,11 +20,11 @@
This file contains the set of passes for Relay, which exposes an interface for
configuring the passes and scripting them in Python.
"""
from tvm.ir import RelayExpr, IRModule
from . import _analysis
from . import _make
from .expr import Expr
from .ty import Type
from .module import Module
from .feature import Feature
......@@ -70,7 +70,7 @@ def check_kind(t, mod=None):
t : tvm.relay.Type
The type to check
mod : Optional[tvm.relay.Module]
mod : Optional[tvm.IRModule]
The global module.
Returns
......@@ -169,7 +169,7 @@ def free_type_vars(expr, mod=None):
expr : Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type
mod : Optional[tvm.relay.Module]
mod : Optional[tvm.IRModule]
The global module
Returns
......@@ -177,7 +177,7 @@ def free_type_vars(expr, mod=None):
free : List[tvm.relay.TypeVar]
The list of free type variables in post-DFS order
"""
use_mod = mod if mod is not None else Module()
use_mod = mod if mod is not None else IRModule()
return _analysis.free_type_vars(expr, use_mod)
......@@ -189,7 +189,7 @@ def bound_type_vars(expr, mod=None):
expr : Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type
mod : Optional[tvm.relay.Module]
mod : Optional[tvm.IRModule]
The global module
Returns
......@@ -197,7 +197,7 @@ def bound_type_vars(expr, mod=None):
free : List[tvm.relay.TypeVar]
The list of bound type variables in post-DFS order
"""
use_mod = mod if mod is not None else Module()
use_mod = mod if mod is not None else IRModule()
return _analysis.bound_type_vars(expr, use_mod)
......@@ -209,7 +209,7 @@ def all_type_vars(expr, mod=None):
expr : Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type
mod : Optional[tvm.relay.Module]
mod : Optional[tvm.IRModule]
The global module
Returns
......@@ -217,7 +217,7 @@ def all_type_vars(expr, mod=None):
free : List[tvm.relay.TypeVar]
The list of all type variables in post-DFS order
"""
use_mod = mod if mod is not None else Module()
use_mod = mod if mod is not None else IRModule()
return _analysis.all_type_vars(expr, use_mod)
......@@ -353,7 +353,7 @@ def unmatched_cases(match, mod=None):
match : tvm.relay.Match
The match expression
mod : Optional[tvm.relay.Module]
mod : Optional[tvm.IRModule]
The module (defaults to an empty module)
Returns
......@@ -370,10 +370,10 @@ def detect_feature(a, b=None):
Parameters
----------
a : Union[tvm.relay.Expr, tvm.relay.Module]
a : Union[tvm.relay.Expr, tvm.IRModule]
The input expression or module.
b : Optional[Union[tvm.relay.Expr, tvm.relay.Module]]
b : Optional[Union[tvm.relay.Expr, tvm.IRModule]]
The input expression or module.
The two arguments cannot both be expressions or both be modules.
......@@ -382,7 +382,7 @@ def detect_feature(a, b=None):
features : Set[Feature]
Features used in the program.
"""
if isinstance(a, Module):
if isinstance(a, IRModule):
a, b = b, a
return {Feature(int(x)) for x in _analysis.detect_feature(a, b)}
......@@ -400,7 +400,7 @@ def structural_hash(value):
result : int
The hash value
"""
if isinstance(value, Expr):
if isinstance(value, RelayExpr):
return int(_analysis._expr_hash(value))
elif isinstance(value, Type):
return int(_analysis._type_hash(value))
......
......@@ -16,9 +16,9 @@
# under the License.
"""The interface of expr function exposed from C++."""
import tvm._ffi
from tvm.ir import container as _container
from ... import build_module as _build
from ... import container as _container
@tvm._ffi.register_func("relay.backend.lower")
......
......@@ -21,9 +21,10 @@ from __future__ import absolute_import
import numpy as np
from tvm.runtime import container
from tvm.ir import IRModule
from . import _backend
from .. import _make, analysis, transform
from .. import module
from ... import nd
from ..base import Object, register_relay_node
from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const
......@@ -186,10 +187,10 @@ class Interpreter(Executor):
Parameters
----------
mod : tvm.relay.Module
mod : tvm.IRModule
The module to support the execution.
ctx : tvm.TVMContext
The runtime context to run the code on.
target : tvm.Target
......@@ -205,7 +206,7 @@ class Interpreter(Executor):
Returns
-------
opt_mod : tvm.relay.Module
opt_mod : tvm.IRModule
The optimized module.
"""
seq = transform.Sequential([transform.SimplifyInference(),
......@@ -239,7 +240,7 @@ class Interpreter(Executor):
if self.mod:
self.mod["main"] = func
else:
self.mod = module.Module.from_expr(func)
self.mod = IRModule.from_expr(func)
mod = self.optimize()
opt_expr = Call(mod["main"], relay_args)
......
......@@ -36,7 +36,7 @@ def compile(mod, target=None, target_host=None, params=None):
Parameters
----------
mod : relay.Module
mod : tvm.IRModule
The Relay module to build.
target : str, :any:`tvm.target.Target`, or dict of str(i.e.
......@@ -110,7 +110,7 @@ class VMCompiler(object):
Parameters
----------
mod : relay.Module
mod : tvm.IRModule
The Relay module to build.
target : str, :any:`tvm.target.Target`, or dict of str(i.e.
......@@ -142,7 +142,7 @@ class VMCompiler(object):
Parameters
----------
mod : relay.Module
mod : tvm.IRModule
target : str, :any:`tvm.target.Target`, or dict of str (i.e.
device/context name) to str/tvm.target.Target, optional
......@@ -153,7 +153,7 @@ class VMCompiler(object):
Returns
-------
mod : relay.Module
mod : tvm.IRModule
The optimized relay module.
params : dict
......@@ -229,10 +229,10 @@ class VMExecutor(Executor):
Parameters
----------
mod : :py:class:`~tvm.relay.module.Module`
mod : :py:class:`~tvm.IRModule`
The module to support the execution.
ctx : :py:class:`~tvm.TVMContext`
The runtime context to run the code on.
target : :py:class:`Target`
......
......@@ -14,16 +14,25 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck
# pylint: disable=no-else-return, unidiomatic-typecheck, unused-import
"""The base node types for the Relay language."""
import os
import tvm._ffi
from tvm.runtime import Object
from tvm.ir import SourceName, Span, Node as RelayNode
from . import _make
from . import _expr
from . import _base
__STD_PATH__ = os.path.join(os.path.dirname(os.path.realpath(__file__)), "std")
@tvm._ffi.register_func("tvm.relay.std_path")
def _std_path():
return __STD_PATH__
def register_relay_node(type_key=None):
"""Register a Relay node type.
......@@ -52,55 +61,6 @@ def register_relay_attr_node(type_key=None):
return tvm._ffi.register_object(type_key)
class RelayNode(Object):
"""Base class of all Relay nodes."""
def astext(self, show_meta_data=True, annotate=None):
"""Get the text format of the expression.
Parameters
----------
show_meta_data : bool
Whether to include meta data section in the text
if there is meta data.
annotate: Optional[relay.Expr->str]
Optional annotate function to provide additional
information in the comment block.
Note
----
The meta data section is necessary to fully parse the text format.
However, it can contain dumps that are big (e.g. constant weights),
so it can be helpful to skip printing the meta data section.
Returns
-------
text : str
The text format of the expression.
"""
return _expr.AsText(self, show_meta_data, annotate)
def set_span(self, span):
_base.set_span(self, span)
def __str__(self):
return self.astext(show_meta_data=False)
@register_relay_node
class Span(RelayNode):
"""Specifies a location in a source program."""
def __init__(self, source, lineno, col_offset):
self.__init_handle_by_constructor__(_make.Span, source, lineno, col_offset)
@register_relay_node
class SourceName(RelayNode):
"""An identifier for a source location."""
def __init__(self, name):
self.__init_handle_by_constructor__(_make.SourceName, name)
@register_relay_node
class Id(Object):
"""Unique identifier (name) used in Var.
......
......@@ -21,13 +21,14 @@ from a Relay expression.
import warnings
import numpy as np
from tvm.ir import IRModule
from tvm import expr as tvm_expr
from .. import nd as _nd, target as _target, autotvm
from ..contrib import graph_runtime as _graph_rt
from . import _build_module
from . import ty as _ty
from . import expr as _expr
from .module import Module as _Module
from .backend import interpreter as _interpreter
from .backend.vm import VMExecutor
......@@ -141,7 +142,7 @@ class BuildModule(object):
Returns
-------
mod : relay.Module
mod : tvm.IRModule
The optimized relay module.
params : dict
......@@ -185,7 +186,7 @@ def build(mod, target=None, target_host=None, params=None):
Parameters
----------
mod : relay.Module
mod : tvm.IRModule
The module to build. Using relay.Function is deprecated.
target : str, :any:`tvm.target.Target`, or dict of str (i.e. device/context
......@@ -217,16 +218,16 @@ def build(mod, target=None, target_host=None, params=None):
params : dict
The parameters of the final graph.
"""
if isinstance(mod, _Module):
if isinstance(mod, IRModule):
func = mod["main"]
elif isinstance(mod, _expr.Function):
func = mod
warnings.warn(
"Please use input parameter mod (tvm.relay.module.Module) "
"Please use input parameter mod (tvm.IRModule) "
"instead of deprecated parameter func (tvm.relay.expr.Function)",
DeprecationWarning)
else:
raise ValueError("Type of input parameter mod must be tvm.relay.module.Module")
raise ValueError("Type of input parameter mod must be tvm.IRModule")
target = _update_target(target)
......@@ -254,7 +255,7 @@ def optimize(mod, target=None, params=None):
Parameters
----------
mod : relay.Module
mod : tvm.IRModule
The module to build. Using relay.Function is deprecated.
target : str, :any:`tvm.target.Target`, or dict of str (i.e. device/context
......@@ -268,7 +269,7 @@ def optimize(mod, target=None, params=None):
Returns
-------
mod : relay.Module
mod : tvm.IRModule
The optimized relay module.
params : dict
......@@ -279,11 +280,11 @@ def optimize(mod, target=None, params=None):
elif isinstance(mod, _expr.Function):
func = mod
warnings.warn(
"Please use input parameter mod (tvm.relay.module.Module) "
"Please use input parameter mod (tvm.IRModule) "
"instead of deprecated parameter func (tvm.relay.expr.Function)",
DeprecationWarning)
else:
raise ValueError("Type of input parameter mod must be tvm.relay.module.Module")
raise ValueError("Type of input parameter mod must be tvm.IRModule")
target = _update_target(target)
......@@ -330,7 +331,7 @@ class GraphExecutor(_interpreter.Executor):
Parameters
----------
mod : :py:class:`~tvm.relay.module.Module`
mod : :py:class:`~tvm.IRModule`
The module to support the execution.
ctx : :py:class:`TVMContext`
......@@ -385,17 +386,17 @@ def create_executor(kind="debug",
kind : str
The type of executor
mod : :py:class:`~tvm.relay.module.Module`
mod : :py:class:`~tvm.IRModule`
The Relay module containing a collection of functions
ctx : :py:class:`tvm.TVMContext`
ctx : :py:class:`tvm.TVMContext`
The context to execute the code.
target : :py:class:`tvm.Target`
The compilation target.
"""
if mod is None:
mod = _Module()
mod = IRModule()
if ctx is not None:
assert ctx.device_type == _nd.context(str(target), 0).device_type
else:
......
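A usage sketch of the updated executor entry point (cpu/llvm chosen so the ctx/target assertion above holds):

import numpy as np
import tvm
from tvm import relay

x = relay.var("x", shape=(2, 2))
mod = tvm.IRModule.from_expr(relay.Function([x], x + x))

# kind="debug" evaluates mod["main"] through the interpreter.
ex = relay.create_executor("debug", mod=mod, ctx=tvm.cpu(0), target="llvm")
out = ex.evaluate()(np.ones((2, 2), dtype="float32"))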
......@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
# pylint: disable=no-else-return, invalid-name, unused-import
"""The expression nodes of Relay."""
from __future__ import absolute_import
from numbers import Number as _Number
......@@ -22,33 +22,21 @@ from numbers import Number as _Number
import numpy as _np
from tvm._ffi import base as _base
from tvm.runtime import NDArray, convert, ndarray as _nd
from tvm.ir import RelayExpr, GlobalVar, BaseFunc
from .base import RelayNode, register_relay_node
from . import _make
from . import _expr
from . import ty as _ty
# alias relay expr as Expr.
Expr = RelayExpr
# will be registered afterwards
_op_make = None
class Expr(RelayNode):
"""The base type for all Relay expressions."""
@property
def checked_type(self):
"""Get the checked type of tvm.relay.Expr.
Returns
-------
checked_type : tvm.relay.Type
The checked type.
"""
ret = self._checked_type_
if ret is None:
raise ValueError("The type checker has not populated"
" the checked_type for this node")
return ret
class ExprWithOp(RelayExpr):
"""Basetype of all relay expressions that defines op overloading."""
def astype(self, dtype):
"""Cast the content type of the current data to dtype.
......@@ -173,7 +161,7 @@ class Expr(RelayNode):
return Call(self, args)
@register_relay_node
class Constant(Expr):
class Constant(ExprWithOp):
"""A constant expression in Relay.
Parameters
......@@ -186,7 +174,7 @@ class Constant(Expr):
@register_relay_node
class Tuple(Expr):
class Tuple(ExprWithOp):
"""Tuple expression that groups several fields together.
Parameters
......@@ -210,7 +198,7 @@ class Tuple(Expr):
@register_relay_node
class Var(Expr):
class Var(ExprWithOp):
"""A local variable in Relay.
Local variable can be used to declare input
......@@ -238,33 +226,7 @@ class Var(Expr):
@register_relay_node
class GlobalVar(Expr):
"""A global variable in tvm.relay.
GlobalVar is used to refer to the global functions
stored in the module.
Parameters
----------
name_hint: str
The name of the variable.
"""
def __init__(self, name_hint):
self.__init_handle_by_constructor__(_make.GlobalVar, name_hint)
def __call__(self, *args):
"""Invoke the global function.
Parameters
----------
args: List[relay.Expr]
Arguments.
"""
return Call(self, args, None, None)
@register_relay_node
class Function(Expr):
class Function(BaseFunc):
"""A function declaration expression.
Parameters
......@@ -320,7 +282,7 @@ class Function(Expr):
@register_relay_node
class Call(Expr):
class Call(ExprWithOp):
"""Function call node in Relay.
Call node corresponds to the operator application node
......@@ -349,7 +311,7 @@ class Call(Expr):
@register_relay_node
class Let(Expr):
class Let(ExprWithOp):
"""Let variable binding expression.
Parameters
......@@ -369,7 +331,7 @@ class Let(Expr):
@register_relay_node
class If(Expr):
class If(ExprWithOp):
"""A conditional expression in Relay.
Parameters
......@@ -389,7 +351,7 @@ class If(Expr):
@register_relay_node
class TupleGetItem(Expr):
class TupleGetItem(ExprWithOp):
"""Get index-th item from a tuple.
Parameters
......@@ -406,7 +368,7 @@ class TupleGetItem(Expr):
@register_relay_node
class RefCreate(Expr):
class RefCreate(ExprWithOp):
"""Create a new reference from initial value.
Parameters
----------
......@@ -418,7 +380,7 @@ class RefCreate(Expr):
@register_relay_node
class RefRead(Expr):
class RefRead(ExprWithOp):
"""Get the value inside the reference.
Parameters
----------
......@@ -430,7 +392,7 @@ class RefRead(Expr):
@register_relay_node
class RefWrite(Expr):
class RefWrite(ExprWithOp):
"""
Update the value inside the reference.
The whole expression will evaluate to an empty tuple.
......@@ -445,7 +407,7 @@ class RefWrite(Expr):
self.__init_handle_by_constructor__(_make.RefWrite, ref, value)
class TempExpr(Expr):
class TempExpr(ExprWithOp):
"""Baseclass of all TempExpr.
TempExprs are pass-specific expressions that can be
......
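The operator-overloading sugar is unchanged by the rebasing onto tvm.ir; only the base class moved from the old Expr to ExprWithOp. A quick sketch:

from tvm import relay

x = relay.var("x", shape=(3,), dtype="float32")
# Var now derives from ExprWithOp, so arithmetic still builds Call nodes.
y = x + relay.const(1.0)
assert isinstance(y, relay.expr.Call)
assert y.op.name == "add"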
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import List
import tvm
from .base import Span, Object
from .ty import Type, TypeParam
from ._analysis import _get_checked_type
class Expr(Object):
def checked_type(self):
...
def __call__(self, *args):
...
class Constant(Expr):
data = ... # type: tvm.nd.NDArray
def __init__(self, data):
# type: (tvm.nd.NDArray) -> None
...
class Tuple(Expr):
fields = ... # type: List[Expr]
def __init__(self, fields):
# type: (List[Expr]) -> None
...
class Var(Expr):
"""A local variable in Relay."""
name_hint = ... # type: str
def __init__(self, name_hint):
# type: (str) -> None
...
class GlobalVar(Expr):
name_hint = ... # type: str
def __init__(self, name_hint):
# type: (str) -> None
...
class Param(Expr):
var = ... # type: Var
type = ... # type: Type
def __init__(self, var, ty):
# type: (Var, Type) -> None
...
class Function(Expr):
"""A function in Relay, see tvm/relay/expr.h for more details."""
type_params = ... # type: List[TypeParam]
params = ... # type: List[Param]
ret_type = ... # type: Type
body = ... # type: Expr
def __init__(self,
params, # type: List[Param],
ret_type, # type: Type,
body, # type: Expr,
type_params=None, # type: List[TypeParam]
):
# type: (...) -> None
...
@register_relay_node
class Call(Expr):
"""A function call in Relay, see tvm/relay/expr.h for more details."""
op = ... # type: Expr
args = ... # type: List[Expr]
# todo(@jroesch): add attrs. revise attrs type in __init__
def __init__(self, op, args, attrs=None, ty_args=None):
# type: (Expr, List[Expr], Optional[List[Any]], Optional[List[Type]]) -> None
if not ty_args:
ty_args = []
self.__init_handle_by_constructor__(
_make.Call, op, args, attrs, ty_args)
@register_relay_node
class Let(Expr):
"""A variable binding in Relay, see tvm/relay/expr.h for more details."""
var = ... # type: Var
value = ... # type: Expr
body = ... # type: Expr
value_type = ... # type: Type
def __init__(self, var, value, body, value_type):
# type: (Var, Expr, Expr, Type) -> None
...
@register_relay_node
class If(Expr):
"""A conditional expression in Relay, see tvm/relay/expr.h for more details."""
cond = ... # type: Expr
true_value = ... # type: Expr
false_value = ... # type: Expr
span = ... # type: Span
def __init__(self, cond, true_value, false_value):
# type: (Expr, Expr, Expr) -> None
...
......@@ -16,11 +16,11 @@
# under the License.
# pylint: disable=import-self, invalid-name, line-too-long, unused-argument
"""Caffe2 frontend"""
from __future__ import absolute_import as _abs
import tvm
from tvm.ir import IRModule
from .. import analysis
from .. import expr as _expr
from .. import module as _module
from .. import op as _op
from ... import nd as _nd
from .common import AttrCvt, Renamer
......@@ -383,7 +383,7 @@ class Caffe2NetDef(object):
self._ops = {}
self._shape = shape
self._dtype = dtype
self._mod = _module.Module({})
self._mod = IRModule({})
def from_caffe2(self, init_net, predict_net):
"""Construct Relay expression from caffe2 graph.
......@@ -395,7 +395,7 @@ class Caffe2NetDef(object):
Returns
-------
mod : tvm.relay.Module
mod : tvm.IRModule
The module that optimizations will be performed on.
params : dict
......@@ -565,7 +565,7 @@ def from_caffe2(init_net, predict_net, shape=None, dtype="float32"):
Returns
-------
mod : tvm.relay.Module
mod : tvm.IRModule
The module that optimizations will be performed on.
params : dict of str to tvm.nd.NDArray
......
......@@ -20,9 +20,10 @@ import logging
import numpy as np
import tvm
from tvm.ir import IRModule
from topi.util import get_const_tuple
from .. import expr as _expr
from .. import module as _module
from .. import transform as _transform
from .. import op as _op
from .. import analysis
......@@ -453,7 +454,7 @@ def get_name(node):
def infer_type(node, mod=None):
"""A method to infer the type of an intermediate node in the relay graph."""
new_mod = _module.Module.from_expr(node)
new_mod = IRModule.from_expr(node)
if mod is not None:
new_mod.update(mod)
new_mod = _transform.InferType()(new_mod)
......
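The same wrap/update/infer pattern works in user code; a minimal standalone sketch of the helper above (the function name here is hypothetical):

import tvm
from tvm import relay

def infer_type_standalone(node, mod=None):
    # Wrap the expression, merge optional caller definitions, run InferType.
    new_mod = tvm.IRModule.from_expr(node)
    if mod is not None:
        new_mod.update(mod)
    new_mod = relay.transform.InferType()(new_mod)
    entry = new_mod["main"]
    return entry if isinstance(node, relay.Function) else entry.body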
......@@ -21,9 +21,10 @@ from __future__ import absolute_import as _abs
import math
import numpy as np
import tvm
from tvm.ir import IRModule
from .. import analysis
from .. import expr as _expr
from .. import module as _module
from .. import op as _op
from ... import nd as _nd
from ..._ffi import base as _base
......@@ -449,7 +450,7 @@ def from_coreml(model, shape=None):
Returns
-------
mod : tvm.relay.Module
mod : tvm.IRModule
The relay module for compilation.
params : dict of str to tvm.nd.NDArray
......@@ -505,4 +506,4 @@ def from_coreml(model, shape=None):
outexpr = outexpr[0]
func = _expr.Function(analysis.free_vars(outexpr), outexpr)
params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()}
return _module.Module.from_expr(func), params
return IRModule.from_expr(func), params
......@@ -23,9 +23,10 @@ from __future__ import absolute_import as _abs
from enum import Enum
import numpy as np
import tvm
from tvm.ir import IRModule
from .. import analysis
from .. import expr as _expr
from .. import module as _module
from .common import get_relay_op, new_var
__all__ = ['from_darknet']
......@@ -822,7 +823,7 @@ class GraphProto(object):
outputs = _as_list(sym) + self._outs
outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
sym = _expr.Function(analysis.free_vars(outputs), outputs)
return _module.Module.from_expr(sym), self._tvmparams
return IRModule.from_expr(sym), self._tvmparams
def from_darknet(net,
shape=None,
......@@ -840,7 +841,7 @@ def from_darknet(net,
Returns
-------
mod : tvm.relay.Module
mod : tvm.IRModule
The relay module for compilation.
params : dict of str to tvm.nd.NDArray
......
......@@ -19,9 +19,10 @@
import sys
import numpy as np
import tvm
from tvm.ir import IRModule
from .. import analysis
from .. import expr as _expr
from .. import module as _module
from .. import op as _op
from ... import nd as _nd
from .common import ExprTable, new_var
......@@ -752,7 +753,7 @@ def from_keras(model, shape=None):
Returns
-------
mod : tvm.relay.Module
mod : tvm.IRModule
The relay module for compilation.
params : dict of str to tvm.nd.NDArray
......@@ -837,4 +838,4 @@ def from_keras(model, shape=None):
outexpr = outexpr[0] if len(outexpr) == 1 else _expr.Tuple(outexpr)
func = _expr.Function(analysis.free_vars(outexpr), outexpr)
params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()}
return _module.Module.from_expr(func), params
return IRModule.from_expr(func), params
......@@ -21,12 +21,13 @@ from __future__ import absolute_import as _abs
import json
import numpy as np
import tvm
from tvm.ir import IRModule
from tvm import relay
from topi.util import get_const_tuple
from .. import analysis
from .. import expr as _expr
from .. import op as _op
from .. import module as _module
from .. import scope_builder as _scope_builder
from ... import nd as _nd
......@@ -1902,7 +1903,7 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info, params=None, mod=None):
dtype_info : dict or str.
Known parameter dtypes
mod : tvm.relay.Module
mod : tvm.IRModule
The module that contains global information. It will be used for
converting ops that need global information, e.g. control-flow ops.
......@@ -2009,7 +2010,7 @@ def from_mxnet(symbol,
Returns
-------
mod : tvm.relay.Module
mod : tvm.IRModule
The relay module for compilation
params : dict of str to tvm.nd.NDArray
......@@ -2020,7 +2021,7 @@ def from_mxnet(symbol,
except ImportError as e:
raise ImportError("{}. MXNet is required to parse symbols.".format(e))
mod = _module.Module()
mod = IRModule()
if isinstance(symbol, mx.sym.Symbol):
params = {}
arg_params = arg_params if arg_params else {}
......
......@@ -17,14 +17,13 @@
# pylint: disable=invalid-name, import-self, len-as-condition, unused-argument, too-many-lines
# pylint: disable=import-outside-toplevel
"""ONNX: Open Neural Network Exchange frontend for Relay."""
from __future__ import absolute_import as _abs
import numpy as np
import tvm
from tvm.ir import IRModule
from ... import nd as _nd
from .. import analysis
from .. import expr as _expr
from .. import module as _module
from .. import op as _op
from .common import AttrCvt, Renamer
from .common import get_relay_op, new_var, infer_shape, infer_channels
......@@ -1615,7 +1614,7 @@ class GraphProto(object):
Returns
-------
mod : tvm.relay.Module
mod : tvm.IRModule
The returned relay module
params : dict
......@@ -1708,7 +1707,7 @@ class GraphProto(object):
outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output]
outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
func = _expr.Function(analysis.free_vars(outputs), outputs)
return _module.Module.from_expr(func), self._params
return IRModule.from_expr(func), self._params
def _parse_value_proto(self, value_proto):
"""Parse ValueProto or raw str."""
......@@ -1836,7 +1835,7 @@ def from_onnx(model,
Returns
-------
mod : tvm.relay.Module
mod : tvm.IRModule
The relay module for compilation
params : dict of str to tvm.nd.NDArray
......
......@@ -29,13 +29,13 @@ import numpy as np
import tvm
from tvm.ir import IRModule
from tvm.relay.prelude import Prelude
from .. import analysis
from .. import expr as _expr
from .. import op as _op
from ..expr_functor import ExprMutator
from .. import module as _module
from .common import AttrCvt, get_relay_op
from .common import infer_type as _infer_type
from .common import infer_shape as _infer_shape
......@@ -2136,7 +2136,7 @@ class GraphProto(object):
self._input_shapes = {}
self._loops = {}
self._branches = {}
self._mod = _module.Module({})
self._mod = IRModule({})
self._prelude = Prelude(self._mod)
def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
......@@ -2171,7 +2171,7 @@ class GraphProto(object):
Returns
-------
mod : tvm.relay.Module
mod : tvm.IRModule
The module that optimizations will be performed on.
params : dict
......@@ -2653,7 +2653,7 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
Returns
-------
mod : tvm.relay.Module
mod : tvm.IRModule
The module that optimizations will be performed on.
params : dict of str to tvm.nd.NDArray
......
......@@ -17,14 +17,14 @@
# pylint: disable=invalid-name, unused-argument, too-many-lines, import-outside-toplevel
"""Tensorflow lite frontend."""
from __future__ import absolute_import as _abs
import math
import numpy as np
import tvm
from tvm.ir import IRModule
from tvm import relay
from .. import analysis
from .. import expr as _expr
from .. import module as _module
from .. import op as _op
from .. import qnn as _qnn
from ..util import get_scalar_from_constant
......@@ -1901,7 +1901,7 @@ def from_tflite(model, shape_dict, dtype_dict):
Returns
-------
mod : tvm.relay.Module
mod : tvm.IRModule
The relay module for compilation.
params : dict of str to tvm.nd.NDArray
......@@ -1940,5 +1940,5 @@ def from_tflite(model, shape_dict, dtype_dict):
outputs = [exp_tab.get_expr(get_tensor_name(subgraph, i)) for i in model_outputs]
outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
func = _expr.Function(analysis.free_vars(outputs), outputs)
mod = _module.Module.from_expr(func)
mod = IRModule.from_expr(func)
return mod, params
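Since every frontend now returns the unified module type, the downstream build call is identical across frameworks. A sketch (`model`, `shape_dict`, and `dtype_dict` are placeholders for a parsed TFLite model and its input metadata):

from tvm import relay

# model/shape_dict/dtype_dict are assumed to come from the TFLite loader.
mod, params = relay.frontend.from_tflite(model, shape_dict, dtype_dict)
graph, lib, params = relay.build(mod, target="llvm", params=params)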
......@@ -176,7 +176,7 @@ class ManifestAllocPass(ExprMutator):
view = LinearizeRetType(ret_type)
out_types = view.unpack()
is_dynamic = ret_type.is_dynamic()
is_dynamic = ty.type_has_any(ret_type)
# TODO(@jroesch): restore this code, more complex then it seems
# for arg in call.args:
# is_dynamic = is_dynamic or arg.checked_type.is_dynamic()
......
......@@ -41,7 +41,6 @@ from . import _tensor_grad
from . import _transform
from . import _reduce
from . import _algorithm
from ..expr import Expr
from ..base import register_relay_node
......
......@@ -275,7 +275,7 @@ def legalize_conv2d(attrs, inputs, types):
Parameters
----------
attrs : tvm.attrs.Attrs
attrs : tvm.ir.Attrs
Attributes of current convolution
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
......@@ -296,7 +296,7 @@ def convert_conv2d(attrs, inputs, tinfos, desired_layout):
Parameters
----------
attrs : tvm.attrs.Attrs
attrs : tvm.ir.Attrs
Attributes of current convolution
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
......@@ -413,7 +413,7 @@ def legalize_conv2d_transpose(attrs, inputs, types):
Parameters
----------
attrs : tvm.attrs.Attrs
attrs : tvm.ir.Attrs
Attributes of current Transposed convolution
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
......@@ -947,7 +947,7 @@ def legalize_bitserial_conv2d(attrs, inputs, types):
Parameters
----------
attrs : tvm.attrs.Attrs
attrs : tvm.ir.Attrs
Attributes of current convolution
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
......
......@@ -16,8 +16,8 @@
# under the License.
# pylint: disable=invalid-name, unused-variable
"""NN operator common utilities"""
from __future__ import absolute_import
from .... import container
from tvm.ir import container
def get_pad_tuple2d(padding):
"""Common code to get the pad option
......
......@@ -20,13 +20,13 @@ import topi
import tvm._ffi
from ..base import register_relay_node
from ..expr import Expr
from ..expr import RelayExpr
from ...api import register_func
from ...build_module import lower, build
from . import _make
@register_relay_node
class Op(Expr):
class Op(RelayExpr):
"""A Relay operator definition."""
def __init__(self):
......
......@@ -16,7 +16,7 @@
# under the License.
"""The attributes node used for Relay operators"""
from ...attrs import Attrs
from tvm.ir import Attrs
from ..base import register_relay_attr_node
......
......@@ -16,13 +16,15 @@
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""A prelude containing useful global functions and ADT definitions."""
from tvm.ir import IRModule
from .ty import GlobalTypeVar, TensorType, Any, scalar_type
from .expr import Var, Function, GlobalVar, If, const
from .op.tensor import add, subtract, equal
from .adt import Constructor, TypeData, Clause, Match
from .adt import PatternConstructor, PatternVar, PatternWildcard
from . import op
from .module import Module
class TensorArrayOps(object):
"""Contains tensor array related ops"""
......@@ -648,7 +650,7 @@ class Prelude:
def __init__(self, mod=None):
if mod is None:
mod = Module()
mod = IRModule()
self.mod = mod
self.load_prelude()
......
......@@ -63,7 +63,7 @@ def helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay_op):
Parameters
----------
attrs : tvm.attrs.Attrs
attrs : tvm.ir.Attrs
Attributes of current convolution
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
......@@ -106,7 +106,7 @@ def helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay_op):
Parameters
----------
attrs : tvm.attrs.Attrs
attrs : tvm.ir.Attrs
Attributes of current convolution
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
......@@ -169,7 +169,7 @@ def helper_change_dtypes_to_be_same(attrs, inputs, types, relay_op):
Parameters
----------
attrs : tvm.attrs.Attrs
attrs : tvm.ir.Attrs
Attributes of current convolution
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
......
......@@ -42,7 +42,7 @@ def CanonicalizeOps():
# We want to utilize all the existing Relay infrastructure. So, instead of supporting this
# QNN requantize op, we convert it into a sequence of existing Relay operators.
mod = relay.Module.from_expr(qnn_expr)
mod = tvm.IRModule.from_expr(qnn_expr)
mod = relay.qnn.transform.CanonicalizeOps()(mod)
relay_expr = mod['main']
print(relay_expr)
......
......@@ -20,12 +20,12 @@ import logging
import multiprocessing as mp
import numpy as np
import tvm
from tvm.ir import IRModule
from . import _quantize
from . import quantize
from .. import op as _op
from .. import expr as _expr
from .. import module as _module
from .. import analysis as _analysis
from .. import transform as _transform
from .. import build_module as _build_module
......@@ -141,7 +141,7 @@ def _set_params(mod, input_scale_func, weight_scale_func):
func = mod['main']
_analysis.post_order_visit(func, visit_func)
func = _expr.bind(func, const_params)
return _module.Module.from_expr(func)
return IRModule.from_expr(func)
# weight scale functions
......
......@@ -47,7 +47,7 @@ from ..transform import gradient
def run_opt_pass(expr, opt_pass):
assert isinstance(opt_pass, transform.Pass)
mod = relay.Module.from_expr(expr)
mod = tvm.IRModule.from_expr(expr)
mod = opt_pass(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body
......
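Typical use of this test helper, with FoldConstant as an example pass:

import tvm
from tvm import relay

x = relay.var("x", shape=(2,))
expr = x + (relay.const(1.0) + relay.const(2.0))
# FoldConstant collapses the constant subtree; run_opt_pass returns the
# rewritten body since `expr` is not itself a relay.Function.
folded = run_opt_pass(expr, relay.transform.FoldConstant())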
......@@ -103,7 +103,7 @@ def get_workload(batch_size, oshape=(3, 64, 64), ngf=128, random_len=100, dtype=
Returns
-------
mod : tvm.relay.Module
mod : tvm.IRModule
The relay module that contains a DCGAN network.
params : dict of str to NDArray
The parameters.
......
......@@ -105,7 +105,7 @@ def get_workload(densenet_size=121, classes=1000, batch_size=4,
Returns
-------
mod: tvm.relay.Module
mod: tvm.IRModule
The relay module that contains a DenseNet network.
params : dict of str to NDArray
......
......@@ -72,7 +72,7 @@ def get_workload(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="flo
The data type
Returns
-------
mod : tvm.relay.Module
mod : tvm.IRModule
The relay module that contains a DQN network.
params : dict of str to NDArray
The parameters.
......
......@@ -290,7 +290,7 @@ def get_workload(batch_size=1, num_classes=1000,
Returns
-------
mod : tvm.relay.Module
mod : tvm.IRModule
The relay module that contains an Inception V3 network.
params : dict of str to NDArray
......
......@@ -144,13 +144,13 @@ def create_workload(net, initializer=None, seed=0):
Returns
-------
mod : tvm.relay.Module
mod : tvm.IRModule
The created relay module.
params : dict of str to NDArray
The parameters.
"""
mod = relay.Module.from_expr(net)
mod = tvm.IRModule.from_expr(net)
mod = relay.transform.InferType()(mod)
shape_dict = {
v.name_hint : v.checked_type for v in mod["main"].params}
......
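A call-site sketch using the mlp testing network, which routes through create_workload:

import tvm
from tvm.relay.testing import mlp

# Returns a type-checked tvm.IRModule plus random parameters for
# every non-data input.
mod, params = mlp.get_workload(batch_size=1)
assert isinstance(mod, tvm.IRModule)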
......@@ -173,7 +173,7 @@ def get_workload(iterations, num_hidden, batch_size=1, dtype="float32"):
The data type
Returns
-------
mod : tvm.relay.Module
mod : tvm.IRModule
The relay module that contains an LSTM network.
params : dict of str to NDArray
The parameters.
......
......@@ -84,7 +84,7 @@ def get_workload(batch_size,
Returns
-------
mod : tvm.relay.Module
mod : tvm.IRModule
The relay module that contains an MLP network.
params : dict of str to NDArray
......
......@@ -151,7 +151,7 @@ def get_workload(batch_size=1, num_classes=1000, image_shape=(3, 224, 224),
Returns
-------
mod : tvm.relay.Module
mod : tvm.IRModule
The relay module that contains a MobileNet network.
params : dict of str to NDArray
......
......@@ -584,7 +584,7 @@ class PythonConverter(ExprFunctor):
def to_python(expr: Expr, mod=None, target=tvm.target.create('llvm')):
"""Converts the given Relay expression into a Python script (as a Python AST object).
For easiest debugging, import the astor package and use to_source()."""
mod = mod if mod is not None else relay.Module()
mod = mod if mod is not None else tvm.IRModule()
converter = PythonConverter(mod, target)
return converter.convert(expr)
......@@ -592,7 +592,7 @@ def to_python(expr: Expr, mod=None, target=tvm.target.create('llvm')):
def run_as_python(expr: Expr, mod=None, target=tvm.target.create('llvm')):
"""Converts the given Relay expression into a Python script and
executes it."""
mod = mod if mod is not None else relay.Module()
mod = mod if mod is not None else tvm.IRModule()
py_ast = to_python(expr, mod, target)
code = compile(py_ast, '<string>', 'exec')
var_map = {
......
......@@ -262,7 +262,7 @@ def get_workload(batch_size=1,
Returns
-------
mod : tvm.relay.Module
mod : tvm.IRModule
The relay module that contains a ResNet network.
params : dict of str to NDArray
......
......@@ -149,7 +149,7 @@ def get_workload(batch_size=1,
Returns
-------
mod : tvm.relay.Module
mod : tvm.IRModule
The relay module that contains a SqueezeNet network.
params : dict of str to NDArray
......
......@@ -124,7 +124,7 @@ def get_workload(batch_size,
Returns
-------
mod : tvm.relay.Module
mod : tvm.IRModule
The relay module that contains a VGG network.
params : dict of str to NDArray
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from .base import Object
class PassContext(Object):
def __init__(self):
...
class PassInfo(Object):
name = ... # type: str
opt_level = ... # type: int
required = ... # type: list
def __init__(self, name, opt_level, required):
# type: (str, int, list) -> None
...
class Pass(Object):
def __init__(self):
...
class ModulePass(Pass):
name = ... # type: str
opt_level = ... # type: int
pass_func = ... # type: Callable
required = ... # type: list
def __init__(self, name, opt_level, pass_func, required):
# type: (str, int, Callable, list) -> None
...
class FunctionPass(Pass):
name = ... # type: str
opt_level = ... # type: int
pass_func = ... # type: Callable
required = ... # type: list
def __init__(self, name, opt_level, pass_func, required):
# type: (str, int, Callable, list) -> None
...
class Sequential(Pass):
name = ... # type: str
opt_level = ... # type: int
passes = ... # type: list
required = ... # type: list
disabled = ... # type: list
def __init__(self, name, opt_level, passes, required, disabled):
# type: (str, int, list, list, list) -> None
...
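A sketch of composing passes against the stubbed Sequential interface (opt_level gates which passes actually run under the current PassContext):

from tvm import relay
from tvm.ir import transform

seq = transform.Sequential([relay.transform.SimplifyInference(),
                            relay.transform.FoldConstant()])
# mod = seq(mod) applies the whole pipeline to a tvm.IRModule.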
......@@ -14,133 +14,30 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
# pylint: disable=invalid-name, unused-import
"""The type nodes of the Relay language."""
from enum import IntEnum
from tvm.ir import Type, TypeKind, TypeVar, GlobalTypeVar
from tvm.ir import TypeConstraint, FuncType, TupleType, IncompleteType
from tvm.ir import TypeCall, TypeRelation, TensorType, RelayRefType as RefType
from .base import RelayNode, register_relay_node
from . import _make
Any = _make.Any
class Type(RelayNode):
"""The base type for all Relay types."""
def __eq__(self, other):
"""Compare two Relay types for structural equivalence using
alpha equivalence.
"""
return bool(_make._alpha_equal(self, other))
def __ne__(self, other):
return not self.__eq__(other)
def same_as(self, other):
"""Compares two Relay types by referential equality."""
return super().__eq__(other)
def __call__(self, *args):
"""Create a type call from this type.
def type_has_any(tensor_type):
"""Check whether the type has Any as a shape.
Parameters
----------
args: List[relay.Type]
The arguments to the type call.
Returns
-------
call: relay.TypeCall
"""
return TypeCall(self, args)
def is_dynamic(self):
return _make.IsDynamic(self)
@register_relay_node
class TensorType(Type):
"""A concrete TensorType in Relay.
This is the type assigned to tensors with a known dtype and shape. For
example, a tensor of `float32` and `(5, 5)`.
Parameters
----------
shape : List[tvm.Expr]
The shape of the Tensor
dtype : Optional[str]
The content data type.
Default to "float32".
tensor_type : Type
The type to be inspected
Returns
-------
tensor_type : tvm.relay.TensorType
The tensor type.
"""
def __init__(self, shape, dtype="float32"):
self.__init_handle_by_constructor__(
_make.TensorType, shape, dtype)
@property
def concrete_shape(self):
"""Get shape of the type as concrete tuple of int.
Returns
-------
shape : List[int]
The concrete shape of the Type.
Raises
------
TypeError : If the shape is symbolic
"""
return tuple(int(x) for x in self.shape)
class Kind(IntEnum):
"""The kind of a type parameter; it represents a variable shape,
base type, type, or dimension.
This controls what a type parameter is allowed to be instantiated
with. For example, ones of kind BaseType can only be `float32`, `int32`,
and so on.
"""
Type = 0
ShapeVar = 1
BaseType = 2
Shape = 3
Constraint = 4
AdtHandle = 5
TypeData = 6
@register_relay_node
class TypeVar(Type):
"""A type variable used for generic types in Relay,
see tvm/relay/type.h for more details.
A type variable represents a type placeholder which will
be filled in later on. This allows the user to write
functions which are generic over types.
has_any : bool
The check result.
"""
return _make.IsDynamic(tensor_type)
def __init__(self, name_hint, kind=Kind.Type):
"""Construct a TypeVar.
Parameters
----------
name_hint: str
The name of the type variable. This name only acts as a hint, and
is not used for equality.
kind : Optional[Kind]
The kind of the type parameter.
Default to Kind.Type.
Returns
-------
type_var : tvm.relay.TypeVar
The type variable.
"""
self.__init_handle_by_constructor__(_make.TypeVar, name_hint, kind)
def ShapeVar(name):
"""A helper which constructs a type var of the shape kind.
......@@ -154,172 +51,9 @@ def ShapeVar(name):
type_var : tvm.relay.TypeVar
The shape variable.
"""
return TypeVar(name, kind=Kind.ShapeVar)
@register_relay_node
class GlobalTypeVar(Type):
"""A global type variable in Relay.
GlobalTypeVar is used to refer to the global type-level definitions
stored in the environment.
"""
def __init__(self, name_hint, kind=Kind.AdtHandle):
"""Construct a GlobalTypeVar.
Parameters
----------
name_hint: str
The name of the global type variable. This name only acts as a
hint, and is not used for equality.
kind: Kind, optional
The kind of the type parameter, Kind.AdtHandle by default.
Returns
-------
type_var: GlobalTypeVar
The global type variable.
"""
self.__init_handle_by_constructor__(_make.GlobalTypeVar, name_hint, kind)
@register_relay_node
class TypeCall(Type):
"""Type-level function application in Relay.
A type call applies argument types to a constructor (type-level function).
"""
def __init__(self, func, args):
"""Construct a TypeCall.
Parameters
----------
func: tvm.relay.Type
The function.
args: List[tvm.expr.Type]
The arguments.
Returns
-------
type_call: TypeCall
The type function application.
"""
self.__init_handle_by_constructor__(_make.TypeCall, func, args)
@register_relay_node
class TypeConstraint(Type):
"""Abstract class representing a type constraint."""
@register_relay_node
class TupleType(Type):
"""A tuple type in Relay, see tvm/relay/type.h for more details.
Lists the type of each field in the tuple.
"""
def __init__(self, fields):
"""Constructs a tuple type
Parameters
----------
fields : List[tvm.relay.Type]
The fields in the tuple
Returns
-------
tuple_type : tvm.relay.TupleType
the tuple type
"""
self.__init_handle_by_constructor__(_make.TupleType, fields)
@register_relay_node
class FuncType(Type):
"""A function type in Relay, see tvm/relay/type.h for more details.
This is the type assigned to functions in Relay. They consist of
a list of type parameters which enable the definition of generic
functions, a set of type constraints which we omit for the time
being, a sequence of argument types, and a return type.
We informally write them as:
`forall (type_params), (arg_types) -> ret_type where type_constraints`
Parameters
----------
arg_types : List[tvm.relay.Type]
The argument types
ret_type : tvm.relay.Type
The return type.
type_params : Optional[List[tvm.relay.TypeVar]]
The type parameters
type_constraints : Optional[List[tvm.relay.TypeConstraint]]
The type constraints.
"""
def __init__(self,
arg_types,
ret_type,
type_params=None,
type_constraints=None):
if type_params is None:
type_params = []
if type_constraints is None:
type_constraints = []
self.__init_handle_by_constructor__(
_make.FuncType, arg_types, ret_type, type_params, type_constraints)
return TypeVar(name, kind=TypeKind.ShapeVar)
@register_relay_node
class IncompleteType(Type):
"""An incomplete type."""
def __init__(self, kind=Kind.Type):
self.__init_handle_by_constructor__(_make.IncompleteType, kind)
@register_relay_node
class TypeRelation(TypeConstraint):
"""Type relation in relay.
Parameters
----------
func : EnvFunc
User defined relation function.
args : [tvm.relay.Type]
List of types to the func.
num_inputs : int
Number of input arguments in args,
this acts as a hint for type inference.
attrs : Attrs
The attribute attached to the relation information
Returns
-------
type_relation : tvm.relay.TypeRelation
The type relation.
"""
def __init__(self, func, args, num_inputs, attrs):
self.__init_handle_by_constructor__(_make.TypeRelation,
func, args, num_inputs, attrs)
@register_relay_node
class RefType(Type):
"""Reference Type in relay.
Parameters
----------
value: Type
The value type.
"""
def __init__(self, value):
self.__init_handle_by_constructor__(_make.RefType, value)
def scalar_type(dtype):
"""Creates a scalar type.
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""The type nodes of the Relay language."""
from enum import IntEnum
from .base import Object, register_relay_node
from . import _make
class Type(Object):
"""The base type for all Relay types."""
def __eq__(self, other):
"""Compare two Relay types for structural equivalence using
alpha equivalence.
"""
return bool(_make._type_alpha_eq(self, other))
def __ne__(self, other):
return not self.__eq__(other)
def same_as(self, other):
"""Compares two Relay types by referential equality."""
return super().__eq__(other)
@register_relay_node
class TensorType(Type):
"""A concrete TensorType in Relay, see tvm/relay/type.h for more details.
This is the type assigned to tensors with a known dtype and shape. For
example, a tensor of `float32` and `(5, 5)`.
"""
def __init__(self, shape, dtype):
"""Construct a tensor type.
Parameters
----------
shape: list of tvm.Expr
dtype: str
Returns
-------
tensor_type: The TensorType
"""
self.__init_handle_by_constructor__(_make.TensorType, shape, dtype)
class Kind(IntEnum):
"""The kind of a type parameter; it represents a variable shape,
base type, type, or dimension.
This controls what a type parameter is allowed to be instantiated
with. For example, ones of kind BaseType can only be `float32`, `int32`,
and so on.
"""
ShapeVar = 0
Shape = 1
BaseType = 2
Type = 3
@register_relay_node
class TypeParam(Type):
"""A type parameter used for generic types in Relay,
see tvm/relay/type.h for more details.
A type parameter represents a type placeholder which will
be filled in later on. This allows the user to write
functions which are generic over types.
"""
def __init__(self, var, kind):
"""Construct a TypeParam.
Parameters
----------
var: tvm.expr.Var
The tvm.Var which backs the type parameter.
kind: Kind
The kind of the type parameter.
Returns
-------
type_param: TypeParam
The type parameter.
"""
self.__init_handle_by_constructor__(_make.TypeParam, var, kind)
@register_relay_node
class TypeConstraint(Type):
"""Abstract class representing a type constraint."""
pass
@register_relay_node
class TupleType(Type):
"""A tuple type in Relay, see tvm/relay/type.h for more details.
Lists the type of each field in the tuple.
"""
def __init__(self, fields):
"""Constructs a tuple type
Parameters
----------
fields: list of tvm.Type
Returns
-------
tuple_type: the tuple type
"""
self.__init_handle_by_constructor__(_make.TupleType, fields)
@register_relay_node
class FuncType(Type):
"""A function type in Relay, see tvm/relay/type.h for more details.
This is the type assigned to functions in Relay. They consist of
a list of type parameters which enable the definition of generic
functions, a set of type constraints which we omit for the time
being, a sequence of argument types, and a return type.
We informally write them as:
`forall (type_params), (arg_types) -> ret_type where type_constraints`
"""
def __init__(self,
arg_types,
ret_type,
type_params,
type_constraints,
):
"""Construct a function type.
Parameters
----------
arg_types: list of Type
ret_type: Type
type_params: list of TypeParam
type_constraints: list of TypeConstraint
Returns
-------
func_type: FuncType
The function type.
"""
self.__init_handle_by_constructor__(
_make.FuncType, arg_types, ret_type, type_params, type_constraints)
@register_relay_node
class IncompleteType(Type):
"""An incomplete type."""
def __init__(self, kind=Kind.Type):
self.__init_handle_by_constructor__(_make.IncompleteType, kind)
@register_relay_node
class TypeRelation(TypeConstraint):
"""Type relation in relay.
Parameters
----------
func : EnvFunc
User defined relation function.
args : list of types
List of types to the func.
num_inputs: int
Number of input arguments in args,
this acts as a hint for type inference.
attrs : Attrs
The attribute attached to the relation information
"""
def __init__(self, func, args, num_inputs, attrs):
self.__init_handle_by_constructor__(_make.TypeRelation,
func, args, num_inputs, attrs)
......@@ -16,7 +16,7 @@
# under the License.
# pylint: disable=invalid-name, unused-argument
"""FFI for tvm.runtime.extra"""
"""FFI for tvm.node"""
import tvm._ffi
# The implementations below are default ones when the corresponding
......
......@@ -19,11 +19,11 @@ import tvm._ffi
from tvm._ffi.base import string_types
from tvm.runtime import Object, convert
from tvm.ir import container as _container
from . import _api_internal
from . import tensor as _tensor
from . import expr as _expr
from . import container as _container
@tvm._ffi.register_object
......
......@@ -366,6 +366,14 @@ class Prefetch(Stmt):
_make.Prefetch, func, value_index, dtype, bounds)
@tvm._ffi.register_object
class LoweredFunc(Object):
"""Represent a LoweredFunc in TVM."""
MixedFunc = 0
HostFunc = 1
DeviceFunc = 2
def stmt_seq(*args):
"""Make sequence of statements
......
......@@ -38,7 +38,7 @@ Constructor::Constructor(std::string name_hint,
TVM_REGISTER_NODE_TYPE(ConstructorNode);
TVM_REGISTER_GLOBAL("relay._make.Constructor")
TVM_REGISTER_GLOBAL("ir.Constructor")
.set_body_typed([](std::string name_hint,
tvm::Array<Type> inputs,
GlobalTypeVar belong_to) {
......@@ -64,7 +64,7 @@ TypeData::TypeData(GlobalTypeVar header,
TVM_REGISTER_NODE_TYPE(TypeDataNode);
TVM_REGISTER_GLOBAL("relay._make.TypeData")
TVM_REGISTER_GLOBAL("ir.TypeData")
.set_body_typed([](GlobalTypeVar header,
tvm::Array<TypeVar> type_vars,
tvm::Array<Constructor> constructors) {
......
......@@ -334,7 +334,7 @@ bool DictAttrsNode::ContentEqual(const Object* other, AttrsEqual equal) const {
return equal(this->dict, static_cast<const DictAttrsNode*>(other)->dict);
}
TVM_REGISTER_GLOBAL("_AttrsListFieldInfo")
TVM_REGISTER_GLOBAL("ir.AttrsListFieldInfo")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Attrs()->ListFieldInfo();
});
......
......@@ -50,10 +50,10 @@ EnvFunc EnvFunc::Get(const std::string& name) {
return EnvFunc(CreateEnvNode(name));
}
TVM_REGISTER_GLOBAL("_EnvFuncGet")
TVM_REGISTER_GLOBAL("ir.EnvFuncGet")
.set_body_typed(EnvFunc::Get);
TVM_REGISTER_GLOBAL("_EnvFuncCall")
TVM_REGISTER_GLOBAL("ir.EnvFuncCall")
.set_body([](TVMArgs args, TVMRetValue* rv) {
EnvFunc env = args[0];
CHECK_GE(args.size(), 1);
......@@ -62,7 +62,7 @@ TVM_REGISTER_GLOBAL("_EnvFuncCall")
args.size() - 1), rv);
});
TVM_REGISTER_GLOBAL("_EnvFuncGetPackedFunc")
TVM_REGISTER_GLOBAL("ir.EnvFuncGetPackedFunc")
.set_body_typed([](const EnvFunc& n) {
return n->func;
});
......
......@@ -154,7 +154,7 @@ GlobalVar::GlobalVar(std::string name_hint) {
TVM_REGISTER_NODE_TYPE(GlobalVarNode);
TVM_REGISTER_GLOBAL("relay._make.GlobalVar")
TVM_REGISTER_GLOBAL("ir.GlobalVar")
.set_body_typed([](std::string name){
return GlobalVar(name);
});
......
......@@ -338,13 +338,13 @@ IRModule IRModule::FromText(const std::string& text, const std::string& source_p
TVM_REGISTER_NODE_TYPE(IRModuleNode);
TVM_REGISTER_GLOBAL("relay._make.Module")
TVM_REGISTER_GLOBAL("ir.IRModule")
.set_body_typed([](tvm::Map<GlobalVar, BaseFunc> funcs,
tvm::Map<GlobalTypeVar, TypeData> types) {
return IRModule(funcs, types, {});
});
TVM_REGISTER_GLOBAL("relay._module.Module_Add")
TVM_REGISTER_GLOBAL("ir.Module_Add")
.set_body([](TVMArgs args, TVMRetValue* ret) {
IRModule mod = args[0];
GlobalVar var = args[1];
......@@ -369,67 +369,67 @@ TVM_REGISTER_GLOBAL("relay._module.Module_Add")
*ret = mod;
});
TVM_REGISTER_GLOBAL("relay._module.Module_AddDef")
TVM_REGISTER_GLOBAL("ir.Module_AddDef")
.set_body_method<IRModule>(&IRModuleNode::AddTypeDef);
TVM_REGISTER_GLOBAL("relay._module.Module_GetGlobalVar")
TVM_REGISTER_GLOBAL("ir.Module_GetGlobalVar")
.set_body_method<IRModule>(&IRModuleNode::GetGlobalVar);
TVM_REGISTER_GLOBAL("relay._module.Module_GetGlobalVars")
TVM_REGISTER_GLOBAL("ir.Module_GetGlobalVars")
.set_body_method<IRModule>(&IRModuleNode::GetGlobalVars);
TVM_REGISTER_GLOBAL("relay._module.Module_GetGlobalTypeVars")
TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVars")
.set_body_method<IRModule>(&IRModuleNode::GetGlobalTypeVars);
TVM_REGISTER_GLOBAL("relay._module.Module_ContainGlobalVar")
TVM_REGISTER_GLOBAL("ir.Module_ContainGlobalVar")
.set_body_method<IRModule>(&IRModuleNode::ContainGlobalVar);
TVM_REGISTER_GLOBAL("relay._module.Module_GetGlobalTypeVar")
TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVar")
.set_body_method<IRModule>(&IRModuleNode::GetGlobalTypeVar);
TVM_REGISTER_GLOBAL("relay._module.Module_Lookup")
TVM_REGISTER_GLOBAL("ir.Module_Lookup")
.set_body_typed([](IRModule mod, GlobalVar var) {
return mod->Lookup(var);
});
TVM_REGISTER_GLOBAL("relay._module.Module_Lookup_str")
TVM_REGISTER_GLOBAL("ir.Module_Lookup_str")
.set_body_typed([](IRModule mod, std::string var) {
return mod->Lookup(var);
});
TVM_REGISTER_GLOBAL("relay._module.Module_LookupDef")
TVM_REGISTER_GLOBAL("ir.Module_LookupDef")
.set_body_typed([](IRModule mod, GlobalTypeVar var) {
return mod->LookupTypeDef(var);
});
TVM_REGISTER_GLOBAL("relay._module.Module_LookupDef_str")
TVM_REGISTER_GLOBAL("ir.Module_LookupDef_str")
.set_body_typed([](IRModule mod, std::string var) {
return mod->LookupTypeDef(var);
});
TVM_REGISTER_GLOBAL("relay._module.Module_LookupTag")
TVM_REGISTER_GLOBAL("ir.Module_LookupTag")
.set_body_typed([](IRModule mod, int32_t tag) {
return mod->LookupTag(tag);
});
TVM_REGISTER_GLOBAL("relay._module.Module_FromExpr")
TVM_REGISTER_GLOBAL("ir.Module_FromExpr")
.set_body_typed([](RelayExpr e,
tvm::Map<GlobalVar, BaseFunc> funcs,
tvm::Map<GlobalTypeVar, TypeData> type_defs) {
return IRModule::FromExpr(e, funcs, type_defs);
});
TVM_REGISTER_GLOBAL("relay._module.Module_Update")
TVM_REGISTER_GLOBAL("ir.Module_Update")
.set_body_typed([](IRModule mod, IRModule from) {
mod->Update(from);
});
TVM_REGISTER_GLOBAL("relay._module.Module_Import")
TVM_REGISTER_GLOBAL("ir.Module_Import")
.set_body_typed([](IRModule mod, std::string path) {
mod->Import(path);
});
TVM_REGISTER_GLOBAL("relay._module.Module_ImportFromStd")
TVM_REGISTER_GLOBAL("ir.Module_ImportFromStd")
.set_body_typed([](IRModule mod, std::string path) {
mod->ImportFromStd(path);
});
......
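On the Python side the renamed globals resolve through the ordinary FFI lookup; a sketch (the lookup name must match the string registered above):

import tvm

# The packed function behind tvm.IRModule.from_expr, under its new prefix.
from_expr = tvm.get_global_func("ir.Module_FromExpr")
assert from_expr is not None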
......@@ -45,7 +45,7 @@ SourceName SourceName::Get(const std::string& name) {
return SourceName(GetSourceNameNode(name));
}
TVM_REGISTER_GLOBAL("relay._make.SourceName")
TVM_REGISTER_GLOBAL("ir.SourceName")
.set_body_typed(SourceName::Get);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
......@@ -70,7 +70,7 @@ Span SpanNode::make(SourceName source, int lineno, int col_offset) {
TVM_REGISTER_NODE_TYPE(SpanNode);
TVM_REGISTER_GLOBAL("relay._make.Span")
TVM_REGISTER_GLOBAL("ir.Span")
.set_body_typed(SpanNode::make);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
......
......@@ -55,7 +55,7 @@ PrimExpr TensorTypeNode::Size() const {
TVM_REGISTER_NODE_TYPE(TensorTypeNode);
TVM_REGISTER_GLOBAL("relay._make.TensorType")
TVM_REGISTER_GLOBAL("ir.TensorType")
.set_body_typed([](Array<PrimExpr> shape, DataType dtype) {
return TensorType(shape, dtype);
});
......
......@@ -300,10 +300,15 @@ bool SequentialNode::PassEnabled(const PassInfo& info) const {
Pass GetPass(const std::string& pass_name) {
using tvm::runtime::Registry;
std::string fpass_name = "relay._transform." + pass_name;
const auto* f = Registry::Get(fpass_name);
CHECK(f != nullptr) << "Cannot find " << fpass_name
<< " to create the pass " << pass_name;
const runtime::PackedFunc* f = nullptr;
if (pass_name.find("transform.") != std::string::npos) {
f = Registry::Get(pass_name);
} else if ((f = Registry::Get("transform." + pass_name))) {
// pass
} else if ((f = Registry::Get("relay._transform." + pass_name))) {
}
CHECK(f != nullptr) << "Cannot use " << pass_name
<< " to create the pass";
return (*f)();
}
......@@ -311,7 +316,7 @@ Pass GetPass(const std::string& pass_name) {
// a Sequential without the consideration of their orders. The phase
// ordering problem needs to be handled in the future.
IRModule SequentialNode::operator()(const IRModule& module,
const PassContext& pass_ctx) const {
const PassContext& pass_ctx) const {
IRModule mod = module;
for (const Pass& pass : passes) {
CHECK(pass.defined()) << "Found undefined pass for optimization.";
......@@ -339,12 +344,12 @@ Pass CreateModulePass(
TVM_REGISTER_NODE_TYPE(PassInfoNode);
TVM_REGISTER_GLOBAL("relay._transform.PassInfo")
TVM_REGISTER_GLOBAL("transform.PassInfo")
.set_body_typed([](int opt_level, std::string name, tvm::Array<PrimExpr> required) {
return PassInfo(opt_level, name, required);
});
TVM_REGISTER_GLOBAL("relay._transform.Info")
TVM_REGISTER_GLOBAL("transform.Info")
.set_body([](TVMArgs args, TVMRetValue* ret) {
Pass pass = args[0];
*ret = pass->Info();
......@@ -366,14 +371,14 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
TVM_REGISTER_NODE_TYPE(ModulePassNode);
TVM_REGISTER_GLOBAL("relay._transform.MakeModulePass")
TVM_REGISTER_GLOBAL("transform.MakeModulePass")
.set_body_typed(
[](runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func,
PassInfo pass_info) {
return ModulePass(pass_func, pass_info);
});
TVM_REGISTER_GLOBAL("relay._transform.RunPass")
TVM_REGISTER_GLOBAL("transform.RunPass")
.set_body([](TVMArgs args, TVMRetValue* ret) {
Pass pass = args[0];
IRModule mod = args[1];
......@@ -390,7 +395,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
TVM_REGISTER_NODE_TYPE(SequentialNode);
TVM_REGISTER_GLOBAL("relay._transform.Sequential")
TVM_REGISTER_GLOBAL("transform.Sequential")
.set_body([](TVMArgs args, TVMRetValue* ret) {
tvm::Array<Pass> passes = args[0];
int opt_level = args[1];
......@@ -416,7 +421,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
TVM_REGISTER_NODE_TYPE(PassContextNode);
TVM_REGISTER_GLOBAL("relay._transform.PassContext")
TVM_REGISTER_GLOBAL("transform.PassContext")
.set_body([](TVMArgs args, TVMRetValue* ret) {
auto pctx = PassContext::Create();
int opt_level = args[0];
......@@ -465,13 +470,13 @@ class PassContext::Internal {
}
};
TVM_REGISTER_GLOBAL("relay._transform.GetCurrentPassContext")
TVM_REGISTER_GLOBAL("transform.GetCurrentPassContext")
.set_body_typed(PassContext::Current);
TVM_REGISTER_GLOBAL("relay._transform.EnterPassContext")
TVM_REGISTER_GLOBAL("transform.EnterPassContext")
.set_body_typed(PassContext::Internal::EnterScope);
TVM_REGISTER_GLOBAL("relay._transform.ExitPassContext")
TVM_REGISTER_GLOBAL("transform.ExitPassContext")
.set_body_typed(PassContext::Internal::ExitScope);
} // namespace transform
......
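The new transform.* prefix is what the Python-side pass decorators resolve against; a sketch of a user-defined module pass that exercises the MakeModulePass and RunPass globals registered above:

import tvm
from tvm import relay

@relay.transform.module_pass(opt_level=1)
def identity_pass(mod, ctx):
    # A do-nothing module pass; returns the module unchanged.
    return mod

mod = tvm.IRModule.from_expr(relay.Function([], relay.const(0.0)))
mod = identity_pass(mod)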
......@@ -33,7 +33,7 @@ PrimType::PrimType(runtime::DataType dtype) {
TVM_REGISTER_NODE_TYPE(PrimTypeNode);
TVM_REGISTER_GLOBAL("relay._make.PrimType")
TVM_REGISTER_GLOBAL("ir.PrimType")
.set_body_typed([](runtime::DataType dtype) {
return PrimType(dtype);
});
......@@ -54,7 +54,7 @@ TypeVar::TypeVar(std::string name, TypeKind kind) {
TVM_REGISTER_NODE_TYPE(TypeVarNode);
TVM_REGISTER_GLOBAL("relay._make.TypeVar")
TVM_REGISTER_GLOBAL("ir.TypeVar")
.set_body_typed([](std::string name, int kind) {
return TypeVar(name, static_cast<TypeKind>(kind));
});
......@@ -76,7 +76,7 @@ GlobalTypeVar::GlobalTypeVar(std::string name, TypeKind kind) {
TVM_REGISTER_NODE_TYPE(GlobalTypeVarNode);
TVM_REGISTER_GLOBAL("relay._make.GlobalTypeVar")
TVM_REGISTER_GLOBAL("ir.GlobalTypeVar")
.set_body_typed([](std::string name, int kind) {
return GlobalTypeVar(name, static_cast<TypeKind>(kind));
});
......@@ -102,7 +102,7 @@ FuncType::FuncType(tvm::Array<Type> arg_types,
TVM_REGISTER_NODE_TYPE(FuncTypeNode);
TVM_REGISTER_GLOBAL("relay._make.FuncType")
TVM_REGISTER_GLOBAL("ir.FuncType")
.set_body_typed([](tvm::Array<Type> arg_types,
Type ret_type,
tvm::Array<TypeVar> type_params,
......@@ -131,7 +131,7 @@ TupleType TupleType::Empty() {
TVM_REGISTER_NODE_TYPE(TupleTypeNode);
TVM_REGISTER_GLOBAL("relay._make.TupleType")
TVM_REGISTER_GLOBAL("ir.TupleType")
.set_body_typed([](Array<Type> fields) {
return TupleType(fields);
});
......@@ -151,7 +151,7 @@ IncompleteType::IncompleteType(TypeKind kind) {
TVM_REGISTER_NODE_TYPE(IncompleteTypeNode);
TVM_REGISTER_GLOBAL("relay._make.IncompleteType")
TVM_REGISTER_GLOBAL("ir.IncompleteType")
.set_body_typed([](int kind) {
return IncompleteType(static_cast<TypeKind>(kind));
});
......@@ -169,7 +169,7 @@ RelayRefType::RelayRefType(Type value) {
data_ = std::move(n);
}
TVM_REGISTER_GLOBAL("relay._make.RefType")
TVM_REGISTER_GLOBAL("ir.RelayRefType")
.set_body_typed([](Type value) {
return RelayRefType(value);
});
......
......@@ -35,7 +35,7 @@ TypeCall::TypeCall(Type func, tvm::Array<Type> args) {
TVM_REGISTER_NODE_TYPE(TypeCallNode);
TVM_REGISTER_GLOBAL("relay._make.TypeCall")
TVM_REGISTER_GLOBAL("ir.TypeCall")
.set_body_typed([](Type func, Array<Type> type) {
return TypeCall(func, type);
});
......@@ -61,7 +61,7 @@ TypeRelation::TypeRelation(TypeRelationFn func,
TVM_REGISTER_NODE_TYPE(TypeRelationNode);
TVM_REGISTER_GLOBAL("relay._make.TypeRelation")
TVM_REGISTER_GLOBAL("ir.TypeRelation")
.set_body_typed([](TypeRelationFn func,
Array<Type> args,
int num_inputs,
......
......@@ -131,6 +131,7 @@ class RelayTextPrinter :
} else if (node.as<IRModuleNode>()) {
return PrintMod(Downcast<IRModule>(node));
} else {
// default module.
std::ostringstream os;
os << node;
return Doc() << os.str();
......@@ -905,20 +906,18 @@ static const char* kSemVer = "v0.0.4";
// - relay_text_printer.cc (specific printing logics for relay)
// - tir_text_printer.cc (specific printing logics for TIR)
std::string PrettyPrint(const ObjectRef& node) {
Doc doc;
doc << relay::RelayTextPrinter(false, nullptr).PrintFinal(node);
return doc.str();
return AsText(node, false, nullptr);
}
std::string AsText(const ObjectRef& node,
bool show_meta_data,
runtime::TypedPackedFunc<std::string(ObjectRef)> annotate) {
Doc doc;
doc << kSemVer << Doc::NewLine()
<< relay::RelayTextPrinter(show_meta_data, annotate).PrintFinal(node);
doc << kSemVer << Doc::NewLine();
doc << relay::RelayTextPrinter(show_meta_data, annotate).PrintFinal(node);
return doc.str();
}
TVM_REGISTER_GLOBAL("relay._expr.AsText")
TVM_REGISTER_GLOBAL("ir.AsText")
.set_body_typed(AsText);
} // namespace tvm
......@@ -599,6 +599,11 @@ TVM_REGISTER_GLOBAL("relay._make._alpha_equal")
return AlphaEqualHandler(false, false).Equal(a, b);
});
TVM_REGISTER_GLOBAL("ir.type_alpha_equal")
.set_body_typed([](Type a, Type b) {
return AlphaEqual(a, b);
});
TVM_REGISTER_GLOBAL("relay._make._assert_alpha_equal")
.set_body_typed([](ObjectRef a, ObjectRef b) {
bool alpha_equal = AlphaEqualHandler(false, true).Equal(a, b);
......
......@@ -33,7 +33,7 @@ using namespace tvm::runtime;
TVM_REGISTER_NODE_TYPE(IdNode);
TVM_REGISTER_GLOBAL("relay._base.set_span")
TVM_REGISTER_GLOBAL("ir.NodeSetSpan")
.set_body_typed([](ObjectRef node_ref, Span sp) {
if (auto* rn = node_ref.as<RelayNode>()) {
rn->span = sp;
......
......@@ -84,7 +84,7 @@ def check_server_drop():
f1 = remote2.get_function("rpc.test2.addone")
assert f1(10) == 11
except tvm.TVMError as e:
except tvm.error.TVMError as e:
pass
remote3 = tclient.request("abc")
f1 = remote3.get_function("rpc.test2.addone")
......
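
The test fix above reflects that the exception is now caught through its tvm.error.TVMError spelling rather than the bare tvm.TVMError attribute. A hedged sketch of the catching pattern (the helper below is hypothetical, not from the diff):

    import tvm

    def call_packed_safely(func, *args):
        # tvm.error.TVMError is the base class for errors propagated
        # from the C++ side of an FFI call.
        try:
            return func(*args)
        except tvm.error.TVMError as err:
            print("FFI call failed:", err)
            return None
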
......@@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.
import mxnet as mx
import tvm
from tvm import relay
from tvm.relay import transform
import model_zoo
......@@ -99,7 +101,7 @@ def test_multi_outputs():
z = F.split(x, **kwargs)
z = F.subtract(F.add(z[0], z[2]), y)
func = relay.Function(relay.analysis.free_vars(z), z)
return relay.Module.from_expr(func)
return tvm.IRModule.from_expr(func)
mx_sym = mx_compose(mx, num_outputs=3, axis=1)
mod, _ = relay.frontend.from_mxnet(
......
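
The pattern repeated throughout these test updates is the headline API change: modules are built with tvm.IRModule instead of relay.Module. A self-contained sketch of the new spelling (the toy function is hypothetical, for illustration only):

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(3,), dtype="float32")
    func = relay.Function([x], relay.nn.relu(x))
    # relay.Module.from_expr(func) becomes:
    mod = tvm.IRModule.from_expr(func)
    print(mod["main"])
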
......@@ -34,7 +34,7 @@ def test_mkldnn_dequantize():
max_range=max_range,
in_dtype=in_dtype)
mod = relay.Function(relay.analysis.free_vars(dequantized_output), dequantized_output)
mod = relay.Module.from_expr(mod)
mod = tvm.IRModule.from_expr(mod)
with relay.build_config(opt_level=3):
graph, lib, params = relay.build(mod, "llvm", params=None)
rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
......@@ -90,7 +90,7 @@ def test_mkldnn_quantize():
max_range=max_range,
out_dtype=out_dtype)
mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output)
mod = relay.Module.from_expr(mod)
mod = tvm.IRModule.from_expr(mod)
with relay.build_config(opt_level=3):
graph, lib, params = relay.build(mod, "llvm", params=None)
rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
......
......@@ -23,7 +23,7 @@ from tvm.relay.testing import add_nat_definitions, count as count_, make_nat_val
import numpy as np
mod = relay.Module()
mod = tvm.IRModule()
p = Prelude(mod)
add_nat_definitions(p)
......@@ -730,7 +730,7 @@ def check_tensor_array(ta_mod, ref_res, *args, dtype="float32",
def test_tensor_expand_dims():
def run(dtype):
x = relay.var('x')
mod = relay.Module()
mod = tvm.IRModule()
p = Prelude(mod)
expand_dims_func = p.get_var('tensor_expand_dims', dtype)
tensor1 = p.get_var('tensor1', dtype)
......@@ -745,7 +745,7 @@ def test_tensor_expand_dims():
def test_tensor_array_constructor():
def run(dtype):
x = relay.var('x')
mod = relay.Module()
mod = tvm.IRModule()
p = Prelude(mod)
tensor_array = p.get_var('tensor_array', dtype)
mod["main"] = relay.Function([x], tensor_array(x))
......@@ -757,7 +757,7 @@ def test_tensor_array_constructor():
def test_tensor_array_read():
def run(dtype):
mod = relay.Module()
mod = tvm.IRModule()
p = Prelude(mod)
l = relay.var('l')
i = relay.var('i')
......@@ -773,7 +773,7 @@ def test_tensor_array_read():
def test_tensor_array_write():
def run(dtype):
mod = relay.Module()
mod = tvm.IRModule()
p = Prelude(mod)
v1 = relay.var('v1')
v2 = relay.var('v2')
......@@ -793,7 +793,7 @@ def test_tensor_array_write():
def test_tensor_array_stack():
def run(dtype):
mod = relay.Module()
mod = tvm.IRModule()
p = Prelude(mod)
tensor_array = p.get_var('tensor_array', dtype)
tensor1 = p.get_var('tensor1', dtype)
......@@ -815,7 +815,7 @@ def test_tensor_array_stack():
def test_tensor_array_unstack():
def run(dtype):
mod = relay.Module()
mod = tvm.IRModule()
p = Prelude(mod)
unstack_tensor1 = p.get_var('tensor_array_unstack_tensor1', dtype)
v = relay.var('v')
......@@ -828,7 +828,7 @@ def test_tensor_array_unstack():
def test_tensor_take():
def run(dtype):
mod = relay.Module()
mod = tvm.IRModule()
p = Prelude(mod)
take = p.get_var('tensor_take', dtype)
tensor2 = p.get_var('tensor2', dtype)
......@@ -847,7 +847,7 @@ def test_tensor_take():
def test_tensor_concatenate():
def run(dtype):
mod = relay.Module()
mod = tvm.IRModule()
p = Prelude(mod)
concat = p.get_var('tensor_concatenate', dtype)
tensor1 = p.get_var('tensor1', dtype)
......@@ -865,7 +865,7 @@ def test_tensor_concatenate():
def test_tensor_array_concat():
def run(dtype):
mod = relay.Module()
mod = tvm.IRModule()
p = Prelude(mod)
v1 = relay.var('v1')
v2 = relay.var('v2')
......@@ -888,9 +888,9 @@ def test_tensor_array_concat():
def test_tensor_array_scatter():
def run(dtype):
mod = relay.Module()
mod = tvm.IRModule()
p = Prelude(mod)
# tensor array
v1 = relay.var('v1')
v2 = relay.var('v2')
......@@ -938,9 +938,9 @@ def test_tensor_array_scatter():
def test_tensor_array_split():
def run(dtype):
mod = relay.Module()
mod = tvm.IRModule()
p = Prelude(mod)
# tensor array
v1 = relay.var('v1')
v2 = relay.var('v2')
......
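
All of the tensor-array tests above share the same setup, with only the module constructor renamed. The common skeleton, written out once (a sketch mirroring the tests rather than adding anything new):

    import tvm
    from tvm.relay.prelude import Prelude

    mod = tvm.IRModule()  # formerly relay.Module()
    p = Prelude(mod)      # registers the list and tensor-array definitions
    ta = p.get_var("tensor_array", "float32")
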