Unverified Commit a5661611 by Tianqi Chen Committed by GitHub

[REFACTOR][PY][API-CHANGE] establish tvm.ir, migrate corresponding files (#4862)

* [REFACTOR][PY][API-CHANGE] establish tvm.ir, migrate corresponding relay files.

This PR establishes tvm.ir and migrates the corresponding relay
files into the new folder.

API Change:
- relay.Module -> tvm.IRModule

* Update with ADT

* Migrate transform

* address comments

* Migrate module

* Migrate json_compact

* Migrate attrs

* Move LoweredFunc to stmt temporarily

* temp migrate container

* Finish migrate container
parent 15df204f
...@@ -34,7 +34,7 @@ def get_network(name, batch_size, dtype='float32'): ...@@ -34,7 +34,7 @@ def get_network(name, batch_size, dtype='float32'):
Returns Returns
------- -------
net: relay.Module net: tvm.IRModule
The relay function of network definition The relay function of network definition
params: dict params: dict
The random parameters for benchmark The random parameters for benchmark
...@@ -70,7 +70,7 @@ def get_network(name, batch_size, dtype='float32'): ...@@ -70,7 +70,7 @@ def get_network(name, batch_size, dtype='float32'):
net, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype) net, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
net = net["main"] net = net["main"]
net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs) net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs)
net = relay.Module.from_expr(net) net = tvm.IRModule.from_expr(net)
else: else:
raise ValueError("Unsupported network: " + name) raise ValueError("Unsupported network: " + name)
......
...@@ -21,8 +21,6 @@ The user facing API for computation declaration. ...@@ -21,8 +21,6 @@ The user facing API for computation declaration.
.. autosummary:: .. autosummary::
tvm.load_json
tvm.save_json
tvm.var tvm.var
tvm.size_var tvm.size_var
tvm.const tvm.const
...@@ -47,8 +45,7 @@ The user facing API for computation declaration. ...@@ -47,8 +45,7 @@ The user facing API for computation declaration.
tvm.max tvm.max
tvm.tag_scope tvm.tag_scope
.. autofunction:: tvm.load_json
.. autofunction:: tvm.save_json
.. autofunction:: tvm.var .. autofunction:: tvm.var
.. autofunction:: tvm.size_var .. autofunction:: tvm.size_var
.. autofunction:: tvm.const .. autofunction:: tvm.const
......
...@@ -178,7 +178,7 @@ class RelayExpr : public BaseExpr { ...@@ -178,7 +178,7 @@ class RelayExpr : public BaseExpr {
class GlobalVar; class GlobalVar;
/*! /*!
* \brief Global variable that leaves in the top-level module. * \brief Global variable that lives in the top-level module.
* *
* A GlobalVar only refers to function definitions. * A GlobalVar only refers to function definitions.
* This is used to enable recursive calls between function. * This is used to enable recursive calls between function.
......
...@@ -141,11 +141,12 @@ enum TypeKind : int { ...@@ -141,11 +141,12 @@ enum TypeKind : int {
}; };
/*! /*!
* \brief Type parameter in the function. * \brief Type parameter in functions.
* This can be viewed as template parameter in c++ template function. *
* A type variable can be viewed as template parameter in c++ template function.
* *
* For example, in the following pesudo code, * For example, in the following pesudo code,
* the TypeVar of f is TypeVar(kind=kShapeVar, var=n). * the TypeVar of f is TypeVar("n", kind=kShapeVar).
* This function can take in a Tensor with shape=(3, 3) and * This function can take in a Tensor with shape=(3, 3) and
* returns a Tensor with shape=(9,) * returns a Tensor with shape=(9,)
* *
......
...@@ -165,7 +165,7 @@ using TypeRelationFn = ...@@ -165,7 +165,7 @@ using TypeRelationFn =
const TypeReporter& reporter)>; const TypeReporter& reporter)>;
/*! /*!
* \brief User defined type relation, is an input-output relation on types. * \brief User defined type relation, it is an input-output relation on types.
* *
* TypeRelation is more generalized than type call as it allows inference * TypeRelation is more generalized than type call as it allows inference
* of both inputs and outputs. * of both inputs and outputs.
......
...@@ -33,6 +33,12 @@ from .runtime.ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl ...@@ -33,6 +33,12 @@ from .runtime.ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl
from .runtime.ndarray import vpi, rocm, opengl, ext_dev, micro_dev from .runtime.ndarray import vpi, rocm, opengl, ext_dev, micro_dev
from .runtime import ndarray as nd from .runtime import ndarray as nd
# tvm.ir
from .ir import IRModule
from .ir import transform
from .ir import container
from . import ir
# others # others
from . import tensor from . import tensor
from . import arith from . import arith
...@@ -41,10 +47,8 @@ from . import stmt ...@@ -41,10 +47,8 @@ from . import stmt
from . import make from . import make
from . import ir_pass from . import ir_pass
from . import codegen from . import codegen
from . import container
from . import schedule from . import schedule
from . import attrs
from . import ir_builder from . import ir_builder
from . import target from . import target
from . import generic from . import generic
......
...@@ -87,6 +87,7 @@ class ObjectBase(object): ...@@ -87,6 +87,7 @@ class ObjectBase(object):
instead of creating a new Node. instead of creating a new Node.
""" """
# assign handle first to avoid error raising # assign handle first to avoid error raising
# pylint: disable=not-callable
self.handle = None self.handle = None
handle = __init_by_constructor__(fconstructor, args) handle = __init_by_constructor__(fconstructor, args)
if not isinstance(handle, ObjectHandle): if not isinstance(handle, ObjectHandle):
......
...@@ -19,9 +19,11 @@ ...@@ -19,9 +19,11 @@
from numbers import Integral as _Integral from numbers import Integral as _Integral
import tvm._ffi import tvm._ffi
import tvm.runtime._ffi_node_api import tvm.ir
from tvm.runtime import convert, const, DataType from tvm.runtime import convert, const, DataType
from tvm.ir import container as _container
from ._ffi.base import string_types, TVMError from ._ffi.base import string_types, TVMError
from ._ffi.registry import register_func, get_global_func, extract_ext_funcs from ._ffi.registry import register_func, get_global_func, extract_ext_funcs
...@@ -30,9 +32,7 @@ from . import make as _make ...@@ -30,9 +32,7 @@ from . import make as _make
from . import expr as _expr from . import expr as _expr
from . import tensor as _tensor from . import tensor as _tensor
from . import schedule as _schedule from . import schedule as _schedule
from . import container as _container
from . import tag as _tag from . import tag as _tag
from . import json_compact
int8 = "int8" int8 = "int8"
int32 = "int32" int32 = "int32"
...@@ -71,66 +71,6 @@ def max_value(dtype): ...@@ -71,66 +71,6 @@ def max_value(dtype):
""" """
return _api_internal._max_value(dtype) return _api_internal._max_value(dtype)
def get_env_func(name):
"""Get an EnvFunc by a global name.
Parameters
----------
name: str
The name of the global function.
Returns
-------
env_func : EnvFunc
The result env function.
Note
----
EnvFunc is a Object wrapper around
global function that can be serialized via its name.
This can be used to serialize function field in the language.
"""
return _api_internal._EnvFuncGet(name)
def load_json(json_str):
"""Load tvm object from json_str.
Parameters
----------
json_str : str
The json string
Returns
-------
node : Object
The loaded tvm node.
"""
try:
return tvm.runtime._ffi_node_api.LoadJSON(json_str)
except TVMError:
json_str = json_compact.upgrade_json(json_str)
return tvm.runtime._ffi_node_api.LoadJSON(json_str)
def save_json(node):
"""Save tvm object as json string.
Parameters
----------
node : Object
A TVM object to be saved.
Returns
-------
json_str : str
Saved json string.
"""
return tvm.runtime._ffi_node_api.SaveJSON(node)
def var(name="tindex", dtype=int32): def var(name="tindex", dtype=int32):
"""Create a new variable with specified name and dtype """Create a new variable with specified name and dtype
...@@ -688,7 +628,7 @@ def _IterVar(dom, name, iter_type, thread_tag=''): ...@@ -688,7 +628,7 @@ def _IterVar(dom, name, iter_type, thread_tag=''):
raise TypeError("need to be list of ranges") raise TypeError("need to be list of ranges")
dom = Range(dom[0], dom[1]) dom = Range(dom[0], dom[1])
if not isinstance(dom, _container.Range): if not isinstance(dom, tvm.ir.Range):
raise TypeError("dom need to be Range") raise TypeError("dom need to be Range")
name = name if name else 'iter' name = name if name else 'iter'
v = var(name) v = var(name)
......
...@@ -141,7 +141,7 @@ class BaseGraphTuner(object): ...@@ -141,7 +141,7 @@ class BaseGraphTuner(object):
self._logger.propagate = False self._logger.propagate = False
# Generate workload and schedule dictionaries. # Generate workload and schedule dictionaries.
if isinstance(graph, relay.Module): if isinstance(graph, tvm.IRModule):
graph = graph["main"] graph = graph["main"]
if isinstance(graph, relay.expr.Function): if isinstance(graph, relay.expr.Function):
......
...@@ -20,6 +20,7 @@ import threading ...@@ -20,6 +20,7 @@ import threading
import topi import topi
import tvm
from tvm import relay, autotvm from tvm import relay, autotvm
from tvm.relay import transform from tvm.relay import transform
from tvm.relay.expr import Call, Function, TupleGetItem, Var, Constant, Tuple from tvm.relay.expr import Call, Function, TupleGetItem, Var, Constant, Tuple
...@@ -83,7 +84,7 @@ def expr2graph(expr, target_ops, node_dict, node_list): ...@@ -83,7 +84,7 @@ def expr2graph(expr, target_ops, node_dict, node_list):
def _infer_type(node): def _infer_type(node):
"""A method to infer the type of a relay expression.""" """A method to infer the type of a relay expression."""
mod = relay.Module.from_expr(node) mod = tvm.IRModule.from_expr(node)
mod = transform.InferType()(mod) mod = transform.InferType()(mod)
entry = mod["main"] entry = mod["main"]
return entry if isinstance(node, relay.Function) else entry.body return entry if isinstance(node, relay.Function) else entry.body
...@@ -136,7 +137,7 @@ def _expr2graph_impl(expr, target_ops, node_dict, node_list): ...@@ -136,7 +137,7 @@ def _expr2graph_impl(expr, target_ops, node_dict, node_list):
free_var = relay.Var("var_%d" % i, input_type) free_var = relay.Var("var_%d" % i, input_type)
params.append(free_var) params.append(free_var)
call = relay.Call(node.op, params, node.attrs) call = relay.Call(node.op, params, node.attrs)
mod = relay.Module.from_expr(relay.Function(params, call)) mod = tvm.IRModule.from_expr(relay.Function(params, call))
relay.backend.compile_engine.get().clear() relay.backend.compile_engine.get().clear()
build_thread = threading.Thread(target=relay.build, build_thread = threading.Thread(target=relay.build,
args=(mod, args=(mod,
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
# under the License. # under the License.
# pylint: disable=eval-used,invalid-name,too-many-arguments # pylint: disable=eval-used,invalid-name,too-many-arguments
"""Utility functions""" """Utility functions"""
import tvm
from tvm import relay from tvm import relay
from tvm.relay import transform from tvm.relay import transform
...@@ -136,7 +137,7 @@ def bind_inputs(expr, input_shapes=None, input_dtypes="float32"): ...@@ -136,7 +137,7 @@ def bind_inputs(expr, input_shapes=None, input_dtypes="float32"):
rebind_dict[var] = updated_input_dict[var.name_hint] rebind_dict[var] = updated_input_dict[var.name_hint]
updated_expr = relay.expr.bind(expr, rebind_dict) updated_expr = relay.expr.bind(expr, rebind_dict)
mod = relay.Module.from_expr(updated_expr) mod = tvm.IRModule.from_expr(updated_expr)
mod = transform.InferType()(mod) mod = transform.InferType()(mod)
entry = mod["main"] entry = mod["main"]
return entry if isinstance(updated_expr, relay.Function) else entry.body return entry if isinstance(updated_expr, relay.Function) else entry.body
...@@ -63,7 +63,7 @@ def extract_from_program(mod, params, ops, target, target_host=None, ...@@ -63,7 +63,7 @@ def extract_from_program(mod, params, ops, target, target_host=None,
Parameters Parameters
---------- ----------
mod: relay.module.Module or relay.expr.Function mod: tvm.IRModule or relay.expr.Function
The module or function to tune The module or function to tune
params: dict of str to numpy array params: dict of str to numpy array
The associated parameters of the program The associated parameters of the program
...@@ -95,7 +95,7 @@ def extract_from_multiple_program(mods, params, ops, target, target_host=None, ...@@ -95,7 +95,7 @@ def extract_from_multiple_program(mods, params, ops, target, target_host=None,
Parameters Parameters
---------- ----------
mods: List[relay.module.Module] or List[relay.expr.Function] mods: List[tvm.IRModule] or List[relay.expr.Function]
The list of modules or functions to tune The list of modules or functions to tune
params: List of dict of str to numpy array params: List of dict of str to numpy array
The associated parameters of the programs The associated parameters of the programs
...@@ -151,8 +151,8 @@ def extract_from_multiple_program(mods, params, ops, target, target_host=None, ...@@ -151,8 +151,8 @@ def extract_from_multiple_program(mods, params, ops, target, target_host=None,
for mod, param in zip(mods, params): for mod, param in zip(mods, params):
if isinstance(mod, relay.expr.Function): if isinstance(mod, relay.expr.Function):
mod = relay.Module.from_expr(mod) mod = tvm.IRModule.from_expr(mod)
assert isinstance(mod, relay.module.Module), \ assert isinstance(mod, tvm.IRModule), \
"only support relay Module or Function to be tuned" "only support relay Module or Function to be tuned"
relay.backend.compile_engine.get().clear() relay.backend.compile_engine.get().clear()
# wrap build call in thread to avoid multiprocessing problems # wrap build call in thread to avoid multiprocessing problems
......
...@@ -24,6 +24,7 @@ import tvm._ffi ...@@ -24,6 +24,7 @@ import tvm._ffi
import tvm.runtime import tvm.runtime
from tvm.runtime import Object, ndarray from tvm.runtime import Object, ndarray
from tvm.ir import container
from . import api from . import api
from . import _api_internal from . import _api_internal
from . import tensor from . import tensor
...@@ -31,10 +32,11 @@ from . import schedule ...@@ -31,10 +32,11 @@ from . import schedule
from . import expr from . import expr
from . import ir_pass from . import ir_pass
from . import stmt as _stmt from . import stmt as _stmt
from . import container
from . import codegen from . import codegen
from . import target as _target from . import target as _target
from . import make from . import make
from .stmt import LoweredFunc
class DumpIR(object): class DumpIR(object):
""" """
...@@ -58,16 +60,16 @@ class DumpIR(object): ...@@ -58,16 +60,16 @@ class DumpIR(object):
def dump(*args, **kwargs): def dump(*args, **kwargs):
"""dump function""" """dump function"""
retv = func(*args, **kwargs) retv = func(*args, **kwargs)
if not isinstance(retv, (_stmt.Stmt, container.LoweredFunc, container.Array)): if not isinstance(retv, (_stmt.Stmt, LoweredFunc, container.Array)):
return retv return retv
fname = func.func_name if hasattr(func, 'func_name') else func.__name__ fname = func.func_name if hasattr(func, 'func_name') else func.__name__
pname = str(self._pass_id) + "_" + fname + "_ir.cc" pname = str(self._pass_id) + "_" + fname + "_ir.cc"
with open(pname, "a") as f: with open(pname, "a") as f:
out = retv.body if isinstance(retv, container.LoweredFunc) else retv out = retv.body if isinstance(retv, LoweredFunc) else retv
f.write(str(out)) f.write(str(out))
if isinstance(retv, container.Array): if isinstance(retv, container.Array):
for x in retv: for x in retv:
out = x.body if isinstance(x, container.LoweredFunc) else x out = x.body if isinstance(x, LoweredFunc) else x
f.write("---------%s\n%s\n-----------\n"%(x.name, str(out))) f.write("---------%s\n%s\n-----------\n"%(x.name, str(out)))
self._pass_id += 1 self._pass_id += 1
return retv return retv
...@@ -459,7 +461,7 @@ def _build_for_device(flist, target, target_host): ...@@ -459,7 +461,7 @@ def _build_for_device(flist, target, target_host):
raise ValueError( raise ValueError(
"Direct host side access to device memory is detected in %s. " "Direct host side access to device memory is detected in %s. "
"Did you forget to bind?" % func.name) "Did you forget to bind?" % func.name)
if func.func_type == container.LoweredFunc.MixedFunc: if func.func_type == LoweredFunc.MixedFunc:
if current_build_config().detect_global_barrier: if current_build_config().detect_global_barrier:
func = ir_pass.ThreadSync(func, "global") func = ir_pass.ThreadSync(func, "global")
func = ir_pass.ThreadSync(func, "shared") func = ir_pass.ThreadSync(func, "shared")
...@@ -471,9 +473,9 @@ def _build_for_device(flist, target, target_host): ...@@ -471,9 +473,9 @@ def _build_for_device(flist, target, target_host):
fhost.append(fsplits[0]) fhost.append(fsplits[0])
for x in fsplits[1:]: for x in fsplits[1:]:
fdevice.append(x) fdevice.append(x)
elif func.func_type == container.LoweredFunc.HostFunc: elif func.func_type == LoweredFunc.HostFunc:
fhost.append(func) fhost.append(func)
elif func.func_type == container.LoweredFunc.DeviceFunc: elif func.func_type == LoweredFunc.DeviceFunc:
fdevice.append(func) fdevice.append(func)
else: else:
raise ValueError("unknown function type %d" % func.func_type) raise ValueError("unknown function type %d" % func.func_type)
...@@ -586,9 +588,9 @@ def build(inputs, ...@@ -586,9 +588,9 @@ def build(inputs,
flist = lower(inputs, args, flist = lower(inputs, args,
name=name, name=name,
binds=binds) binds=binds)
if isinstance(flist, container.LoweredFunc): if isinstance(flist, LoweredFunc):
flist = [flist] flist = [flist]
elif isinstance(inputs, container.LoweredFunc): elif isinstance(inputs, LoweredFunc):
if args: if args:
raise ValueError("args must be done when build from LoweredFunc.") raise ValueError("args must be done when build from LoweredFunc.")
flist = [inputs] flist = [inputs]
...@@ -612,7 +614,7 @@ def build(inputs, ...@@ -612,7 +614,7 @@ def build(inputs,
"_target.Target when inputs is dict.") "_target.Target when inputs is dict.")
fname_set = set() fname_set = set()
for x in flist: for x in flist:
if not isinstance(x, container.LoweredFunc): if not isinstance(x, LoweredFunc):
raise ValueError("inputs must be Schedule, LoweredFunc, list " raise ValueError("inputs must be Schedule, LoweredFunc, list "
"of LoweredFunc, or dict of str to list of " "of LoweredFunc, or dict of str to list of "
"LoweredFunc.") "LoweredFunc.")
......
...@@ -38,7 +38,7 @@ class CSRNDArray(object): ...@@ -38,7 +38,7 @@ class CSRNDArray(object):
The corresponding a dense numpy array, The corresponding a dense numpy array,
or a tuple for constructing a sparse matrix directly. or a tuple for constructing a sparse matrix directly.
ctx: tvm.TVMContext ctx: tvmContext
The corresponding context. The corresponding context.
shape : tuple of int shape : tuple of int
......
...@@ -16,12 +16,11 @@ ...@@ -16,12 +16,11 @@
# under the License. # under the License.
"""Intrinsics of TVM-Python Hybrid Script for Python compilation time """Intrinsics of TVM-Python Hybrid Script for Python compilation time
semantic support.""" semantic support."""
from tvm.ir.container import Array
from .. import api as _api from .. import api as _api
from .. import expr as _expr from .. import expr as _expr
from .. import make as _make from .. import make as _make
from .. import target as _tgt from .. import target as _tgt
from ..container import Array
from .. import ir_pass from .. import ir_pass
from ..stmt import For from ..stmt import For
from .util import _internal_assert from .util import _internal_assert
......
...@@ -24,6 +24,7 @@ import types ...@@ -24,6 +24,7 @@ import types
import numbers import numbers
from enum import Enum from enum import Enum
from tvm.ir.container import Array
from .util import _internal_assert from .util import _internal_assert
from . import calls from . import calls
...@@ -32,7 +33,6 @@ from .preprocessor import determine_variable_usage ...@@ -32,7 +33,6 @@ from .preprocessor import determine_variable_usage
from ..api import all as _all from ..api import all as _all
from ..api import any as _any from ..api import any as _any
from ..container import Array
from ..tensor import Tensor, Operation from ..tensor import Tensor, Operation
from .. import _api_internal as _tvm_internal from .. import _api_internal as _tvm_internal
from .. import expr as _expr from .. import expr as _expr
......
...@@ -21,13 +21,14 @@ import inspect ...@@ -21,13 +21,14 @@ import inspect
import logging import logging
import sys import sys
import numpy import numpy
from tvm.ir.container import Array
from .. import api as _api from .. import api as _api
from .. import make as _make from .. import make as _make
from .. import expr as _expr from .. import expr as _expr
from .. import stmt as _stmt from .. import stmt as _stmt
from .._ffi.base import numeric_types from .._ffi.base import numeric_types
from ..tensor import Tensor from ..tensor import Tensor
from ..container import Array
#pylint: disable=invalid-name #pylint: disable=invalid-name
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-import
"""Common data structures across all IR variants."""
from .base import SourceName, Span, Node, EnvFunc, load_json, save_json
from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, BaseFunc, Range
from .type import Type, TypeKind, TypeVar, GlobalTypeVar, TupleType
from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType
from .type_relation import TypeCall, TypeRelation
from .tensor_type import TensorType
from .adt import Constructor, TypeData
from .module import IRModule
from .attrs import Attrs
from .container import Array, Map
from . import transform
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable """FFI APIs for tvm.ir"""
"""The interface to the Module exposed from C++."""
import tvm._ffi import tvm._ffi
tvm._ffi._init_api("relay._module", __name__)
tvm._ffi._init_api("ir", __name__)
# Licensed to the Apache Software Foundation (ASF) under one # Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file # or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information # distributed with this work for additional information
...@@ -14,9 +15,8 @@ ...@@ -14,9 +15,8 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""FFI APIs for tvm.transform"""
import tvm._ffi
from typing import Union, Tuple, Dict, List
from relay.ir import GlobalId, OperatorId, Item, Object, Span, FileId
from relay.ir import ShapeExtension, Operator, Defn
class Module(Object): ... tvm._ffi._init_api("transform", __name__)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Algebraic data type definitions."""
import tvm._ffi
from .type import Type
from .expr import RelayExpr
from . import _ffi_api
@tvm._ffi.register_object("relay.Constructor")
class Constructor(RelayExpr):
"""Relay ADT constructor.
Parameters
----------
name_hint : str
Name of constructor (only a hint).
inputs : List[Type]
Input types.
belong_to : GlobalTypeVar
Denotes which ADT the constructor belongs to.
"""
def __init__(self, name_hint, inputs, belong_to):
self.__init_handle_by_constructor__(
_ffi_api.Constructor, name_hint, inputs, belong_to)
def __call__(self, *args):
"""Call the constructor.
Parameters
----------
args: List[RelayExpr]
The arguments to the constructor.
Returns
-------
call: RelayExpr
A call to the constructor.
"""
# pylint: disable=import-outside-toplevel
from tvm import relay
return relay.Call(self, args)
@tvm._ffi.register_object("relay.TypeData")
class TypeData(Type):
"""Stores the definition for an Algebraic Data Type (ADT) in Relay.
Note that ADT definitions are treated as type-level functions because
the type parameters need to be given for an instance of the ADT. Thus,
any global type var that is an ADT header needs to be wrapped in a
type call that passes in the type params.
Parameters
----------
header: GlobalTypeVar
The name of the ADT.
ADTs with the same constructors but different names are
treated as different types.
type_vars: List[TypeVar]
Type variables that appear in constructors.
constructors: List[Constructor]
The constructors for the ADT.
"""
def __init__(self, header, type_vars, constructors):
self.__init_handle_by_constructor__(
_ffi_api.TypeData, header, type_vars, constructors)
...@@ -14,11 +14,11 @@ ...@@ -14,11 +14,11 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
""" TVM Attribute module, which is mainly used for defining attributes of operators""" """ TVM Attribute module, which is mainly used for defining attributes of operators."""
import tvm._ffi import tvm._ffi
from tvm.runtime import Object from tvm.runtime import Object
from . import _api_internal from . import _ffi_api
@tvm._ffi.register_object @tvm._ffi.register_object
...@@ -36,7 +36,7 @@ class Attrs(Object): ...@@ -36,7 +36,7 @@ class Attrs(Object):
infos: list of AttrFieldInfo infos: list of AttrFieldInfo
List of field information List of field information
""" """
return _api_internal._AttrsListFieldInfo(self) return _ffi_api.AttrsListFieldInfo(self)
def keys(self): def keys(self):
"""Get list of names in the attribute. """Get list of names in the attribute.
...@@ -91,6 +91,3 @@ class Attrs(Object): ...@@ -91,6 +91,3 @@ class Attrs(Object):
def __getitem__(self, item): def __getitem__(self, item):
return self.__getattr__(item) return self.__getattr__(item)
tvm._ffi._init_api("tvm.attrs")
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Common base structures."""
import tvm._ffi
import tvm.error
import tvm.runtime._ffi_node_api
from tvm.runtime import Object
from . import _ffi_api
from . import json_compact
class Node(Object):
"""Base class of all IR Nodes, implements astext function."""
def astext(self, show_meta_data=True, annotate=None):
"""Get the text format of the expression.
Parameters
----------
show_meta_data : bool
Whether to include meta data section in the text
if there is meta data.
annotate: Optional[Object->str]
Optionally annotate function to provide additional
information in the comment block.
Note
----
The meta data section is necessary to fully parse the text format.
However, it can contain dumps that are big (e.g constant weights),
so it can be helpful to skip printing the meta data section.
Returns
-------
text : str
The text format of the expression.
"""
return _ffi_api.AsText(self, show_meta_data, annotate)
def __str__(self):
return self.astext(show_meta_data=False)
@tvm._ffi.register_object("relay.SourceName")
class SourceName(Object):
"""A identifier for a source location.
Parameters
----------
name : str
The name of the source.
"""
def __init__(self, name):
self.__init_handle_by_constructor__(_ffi_api.SourceName, name)
@tvm._ffi.register_object("relay.Span")
class Span(Object):
"""Specifies a location in a source program.
Parameters
----------
source : SourceName
The source name.
lineno : int
The line number.
col_offset : int
The column offset of the location.
"""
def __init__(self, source, lineno, col_offset):
self.__init_handle_by_constructor__(
_ffi_api.Span, source, lineno, col_offset)
@tvm._ffi.register_object
class EnvFunc(Object):
"""Environment function.
This is a global function object that can be serialized by its name.
"""
def __call__(self, *args):
return _ffi_api.EnvFuncCall(self, *args)
@property
def func(self):
return _ffi_api.EnvFuncGetPackedFunc(self)
@staticmethod
def get(name):
"""Get a static env function
Parameters
----------
name : str
The name of the function.
"""
return _ffi_api.EnvFuncGet(name)
def load_json(json_str):
"""Load tvm object from json_str.
Parameters
----------
json_str : str
The json string
Returns
-------
node : Object
The loaded tvm node.
"""
try:
return tvm.runtime._ffi_node_api.LoadJSON(json_str)
except tvm.error.TVMError:
json_str = json_compact.upgrade_json(json_str)
return tvm.runtime._ffi_node_api.LoadJSON(json_str)
def save_json(node):
"""Save tvm object as json string.
Parameters
----------
node : Object
A TVM object to be saved.
Returns
-------
json_str : str
Saved json string.
"""
return tvm.runtime._ffi_node_api.SaveJSON(node)
...@@ -14,13 +14,12 @@ ...@@ -14,13 +14,12 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""Container data structures used in TVM DSL.""" """Additional container data structures used across IR variants."""
import tvm._ffi import tvm._ffi
from tvm.runtime import Object from tvm.runtime import Object
from tvm.runtime.container import getitem_helper from tvm.runtime.container import getitem_helper
from tvm.runtime import _ffi_node_api from tvm.runtime import _ffi_node_api
from . import _api_internal
@tvm._ffi.register_object @tvm._ffi.register_object
...@@ -41,20 +40,6 @@ class Array(Object): ...@@ -41,20 +40,6 @@ class Array(Object):
@tvm._ffi.register_object @tvm._ffi.register_object
class EnvFunc(Object):
"""Environment function.
This is a global function object that can be serialized by its name.
"""
def __call__(self, *args):
return _api_internal._EnvFuncCall(self, *args)
@property
def func(self):
return _api_internal._EnvFuncGetPackedFunc(self)
@tvm._ffi.register_object
class Map(Object): class Map(Object):
"""Map container of TVM. """Map container of TVM.
...@@ -87,20 +72,3 @@ class StrMap(Map): ...@@ -87,20 +72,3 @@ class StrMap(Map):
"""Get the items from the map""" """Get the items from the map"""
akvs = _ffi_node_api.MapItems(self) akvs = _ffi_node_api.MapItems(self)
return [(akvs[i].value, akvs[i+1]) for i in range(0, len(akvs), 2)] return [(akvs[i].value, akvs[i+1]) for i in range(0, len(akvs), 2)]
@tvm._ffi.register_object
class Range(Object):
"""Represent a range in TVM.
You do not need to create a Range explicitly.
Python lists and tuples will be converted automatically to a Range in API functions.
"""
@tvm._ffi.register_object
class LoweredFunc(Object):
"""Represent a LoweredFunc in TVM."""
MixedFunc = 0
HostFunc = 1
DeviceFunc = 2
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Common expressions data structures in the IR."""
import tvm._ffi
from .base import Node
from . import _ffi_api
class BaseExpr(Node):
"""Base class of all the expressions."""
class PrimExpr(BaseExpr):
"""Base class of all primitive expressions.
PrimExpr is used in the low-level code
optimizations and integer analysis.
"""
class RelayExpr(BaseExpr):
"""Base class of all non-primitive expressions."""
@property
def checked_type(self):
"""Get the checked type of tvm.relay.Expr.
Returns
-------
checked_type : tvm.relay.Type
The checked type.
"""
ret = self._checked_type_
if ret is None:
raise ValueError("The type checker has not populated"
" the checked_type for this node")
return ret
class BaseFunc(RelayExpr):
"""Base class of all functions."""
@tvm._ffi.register_object("relay.GlobalVar")
class GlobalVar(RelayExpr):
"""A global variable in the IR.
GlobalVar is used to refer to the global functions
stored in the IRModule.
Parameters
----------
name_hint: str
The name of the variable.
"""
def __init__(self, name_hint):
self.__init_handle_by_constructor__(_ffi_api.GlobalVar, name_hint)
def __call__(self, *args):
"""Call the global variable.
Parameters
----------
args: List[RelayExpr]
The arguments to the call.
Returns
-------
call: BaseExpr
A call taking the variable as a function.
"""
# pylint: disable=import-outside-toplevel
if all(isinstance(x, RelayExpr) for x in args):
from tvm import relay
return relay.Call(self, args)
arg_types = [type(x) for x in args]
raise RuntimeError(
"Do not know how to handle GlobalVar.__call__ for types {}".format(arg_types))
@tvm._ffi.register_object
class Range(Node):
"""Represent a range in TVM.
You do not need to create a Range explicitly.
Python lists and tuples will be converted automatically to a Range in API functions.
"""
...@@ -14,36 +14,26 @@ ...@@ -14,36 +14,26 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import """IRModule that holds the functions and type definitions."""
"""A global module storing everything needed to interpret or compile a Relay program.""" from tvm._ffi.base import string_types
import os import tvm._ffi
from .base import register_relay_node, RelayNode
from .. import register_func
from .._ffi import base as _base
from . import _make
from . import _module
from . import expr as _expr
from . import ty as _ty
__STD_PATH__ = os.path.join(os.path.dirname(os.path.realpath(__file__)), "std") from .base import Node
from . import expr as _expr
from . import type as _ty
from . import _ffi_api
@register_func("tvm.relay.std_path")
def _std_path():
global __STD_PATH__
return __STD_PATH__
@register_relay_node @tvm._ffi.register_object("relay.Module")
class Module(RelayNode): class IRModule(Node):
"""The global Relay module containing collection of functions. """IRModule that holds functions and type definitions.
Each global function is identified by an unique tvm.relay.GlobalVar. IRModule is the basic unit for all IR transformations across the stack.
tvm.relay.GlobalVar and Module is necessary in order to enable
recursions in function to avoid cyclic reference in the function.x
Parameters Parameters
---------- ----------
functions: Optional[dict]. functions: Optional[dict].
Map of global var to Function Map of global var to BaseFunc
""" """
def __init__(self, functions=None, type_definitions=None): def __init__(self, functions=None, type_definitions=None):
if functions is None: if functions is None:
...@@ -51,7 +41,7 @@ class Module(RelayNode): ...@@ -51,7 +41,7 @@ class Module(RelayNode):
elif isinstance(functions, dict): elif isinstance(functions, dict):
mapped_funcs = {} mapped_funcs = {}
for k, v in functions.items(): for k, v in functions.items():
if isinstance(k, _base.string_types): if isinstance(k, string_types):
k = _expr.GlobalVar(k) k = _expr.GlobalVar(k)
if not isinstance(k, _expr.GlobalVar): if not isinstance(k, _expr.GlobalVar):
raise TypeError("Expect functions to be Dict[GlobalVar, Function]") raise TypeError("Expect functions to be Dict[GlobalVar, Function]")
...@@ -62,13 +52,13 @@ class Module(RelayNode): ...@@ -62,13 +52,13 @@ class Module(RelayNode):
elif isinstance(type_definitions, dict): elif isinstance(type_definitions, dict):
mapped_type_defs = {} mapped_type_defs = {}
for k, v in type_definitions.items(): for k, v in type_definitions.items():
if isinstance(k, _base.string_types): if isinstance(k, string_types):
k = _ty.GlobalTypeVar(k) k = _ty.GlobalTypeVar(k)
if not isinstance(k, _ty.GlobalTypeVar): if not isinstance(k, _ty.GlobalTypeVar):
raise TypeError("Expect type_definitions to be Dict[GlobalTypeVar, Type]") raise TypeError("Expect type_definitions to be Dict[GlobalTypeVar, Type]")
mapped_type_defs[k] = v mapped_type_defs[k] = v
type_definitions = mapped_type_defs type_definitions = mapped_type_defs
self.__init_handle_by_constructor__(_make.Module, functions, type_definitions) self.__init_handle_by_constructor__(_ffi_api.IRModule, functions, type_definitions)
def __setitem__(self, var, val): def __setitem__(self, var, val):
...@@ -85,18 +75,18 @@ class Module(RelayNode): ...@@ -85,18 +75,18 @@ class Module(RelayNode):
return self._add(var, val) return self._add(var, val)
def _add(self, var, val, update=False): def _add(self, var, val, update=False):
if isinstance(val, _expr.Expr): if isinstance(val, _expr.RelayExpr):
if isinstance(var, _base.string_types): if isinstance(var, string_types):
if _module.Module_ContainGlobalVar(self, var): if _ffi_api.Module_ContainGlobalVar(self, var):
var = _module.Module_GetGlobalVar(self, var) var = _ffi_api.Module_GetGlobalVar(self, var)
else: else:
var = _expr.GlobalVar(var) var = _expr.GlobalVar(var)
_module.Module_Add(self, var, val, update) _ffi_api.Module_Add(self, var, val, update)
else: else:
assert isinstance(val, _ty.Type) assert isinstance(val, _ty.Type)
if isinstance(var, _base.string_types): if isinstance(var, string_types):
var = _ty.GlobalTypeVar(var) var = _ty.GlobalTypeVar(var)
_module.Module_AddDef(self, var, val, update) _ffi_api.Module_AddDef(self, var, val, update)
def __getitem__(self, var): def __getitem__(self, var):
"""Lookup a global definition by name or by variable. """Lookup a global definition by name or by variable.
...@@ -111,12 +101,11 @@ class Module(RelayNode): ...@@ -111,12 +101,11 @@ class Module(RelayNode):
val: Union[Function, Type] val: Union[Function, Type]
The definition referenced by :code:`var` (either a function or type). The definition referenced by :code:`var` (either a function or type).
""" """
if isinstance(var, _base.string_types): if isinstance(var, string_types):
return _module.Module_Lookup_str(self, var) return _ffi_api.Module_Lookup_str(self, var)
elif isinstance(var, _expr.GlobalVar): if isinstance(var, _expr.GlobalVar):
return _module.Module_Lookup(self, var) return _ffi_api.Module_Lookup(self, var)
else: return _ffi_api.Module_LookupDef(self, var)
return _module.Module_LookupDef(self, var)
def update(self, other): def update(self, other):
"""Insert functions in another Module to current one. """Insert functions in another Module to current one.
...@@ -128,7 +117,7 @@ class Module(RelayNode): ...@@ -128,7 +117,7 @@ class Module(RelayNode):
""" """
if isinstance(other, dict): if isinstance(other, dict):
other = Module(other) other = Module(other)
return _module.Module_Update(self, other) return _ffi_api.Module_Update(self, other)
def get_global_var(self, name): def get_global_var(self, name):
"""Get a global variable in the function by name. """Get a global variable in the function by name.
...@@ -145,9 +134,9 @@ class Module(RelayNode): ...@@ -145,9 +134,9 @@ class Module(RelayNode):
Raises Raises
------ ------
tvm.TVMError if we cannot find corresponding global var. tvm.error.TVMError if we cannot find corresponding global var.
""" """
return _module.Module_GetGlobalVar(self, name) return _ffi_api.Module_GetGlobalVar(self, name)
def get_global_vars(self): def get_global_vars(self):
"""Collect all global vars defined in this module. """Collect all global vars defined in this module.
...@@ -157,7 +146,7 @@ class Module(RelayNode): ...@@ -157,7 +146,7 @@ class Module(RelayNode):
global_vars: tvm.Array[GlobalVar] global_vars: tvm.Array[GlobalVar]
An array of global vars. An array of global vars.
""" """
return _module.Module_GetGlobalVars(self) return _ffi_api.Module_GetGlobalVars(self)
def get_global_type_vars(self): def get_global_type_vars(self):
"""Collect all global type vars defined in this module. """Collect all global type vars defined in this module.
...@@ -167,7 +156,7 @@ class Module(RelayNode): ...@@ -167,7 +156,7 @@ class Module(RelayNode):
global_type_vars: tvm.Array[GlobalTypeVar] global_type_vars: tvm.Array[GlobalTypeVar]
An array of global type vars. An array of global type vars.
""" """
return _module.Module_GetGlobalTypeVars(self) return _ffi_api.Module_GetGlobalTypeVars(self)
def get_global_type_var(self, name): def get_global_type_var(self, name):
"""Get a global type variable in the function by name. """Get a global type variable in the function by name.
...@@ -184,9 +173,9 @@ class Module(RelayNode): ...@@ -184,9 +173,9 @@ class Module(RelayNode):
Raises Raises
------ ------
tvm.TVMError if we cannot find corresponding global type var. tvm.error.TVMError if we cannot find corresponding global type var.
""" """
return _module.Module_GetGlobalTypeVar(self, name) return _ffi_api.Module_GetGlobalTypeVar(self, name)
def get_constructor(self, tag): def get_constructor(self, tag):
"""Look up an ADT constructor by tag. """Look up an ADT constructor by tag.
...@@ -203,9 +192,9 @@ class Module(RelayNode): ...@@ -203,9 +192,9 @@ class Module(RelayNode):
Raises Raises
------ ------
tvm.TVMError if the corresponding constructor cannot be found. tvm.error.TVMError if the corresponding constructor cannot be found.
""" """
return _module.Module_LookupTag(self, tag) return _ffi_api.Module_LookupTag(self, tag)
@staticmethod @staticmethod
def from_expr(expr, functions=None, type_defs=None): def from_expr(expr, functions=None, type_defs=None):
...@@ -213,14 +202,15 @@ class Module(RelayNode): ...@@ -213,14 +202,15 @@ class Module(RelayNode):
Parameters Parameters
---------- ----------
expr: Expr expr: RelayExpr
The starting expression The starting expression
global_funcs: Optional[dict] global_funcs: Optional[dict]
Map of global vars to function definitions Map of global vars to function definitions
type_defs: Optional[dict] type_defs: Optional[dict]
Map of global type vars to type definitions Map of global type vars to type definitions
Returns Returns
------- -------
mod: Module mod: Module
...@@ -230,10 +220,10 @@ class Module(RelayNode): ...@@ -230,10 +220,10 @@ class Module(RelayNode):
""" """
funcs = functions if functions is not None else {} funcs = functions if functions is not None else {}
defs = type_defs if type_defs is not None else {} defs = type_defs if type_defs is not None else {}
return _module.Module_FromExpr(expr, funcs, defs) return _ffi_api.Module_FromExpr(expr, funcs, defs)
def _import(self, file_to_import): def _import(self, file_to_import):
return _module.Module_Import(self, file_to_import) return _ffi_api.Module_Import(self, file_to_import)
def import_from_std(self, file_to_import): def import_from_std(self, file_to_import):
return _module.Module_ImportFromStd(self, file_to_import) return _ffi_api.Module_ImportFromStd(self, file_to_import)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Type relation and function for type checking."""
import tvm._ffi
from .type import Type
from . import _ffi_api
@tvm._ffi.register_object("relay.TensorType")
class TensorType(Type):
"""A concrete TensorType in Relay.
This is the type assigned to tensors with a known dtype and shape.
For example, a tensor of `float32` and `(5, 5)`.
Parameters
----------
shape : List[tvm.ir.PrimExpr]
The shape of the Tensor
dtype : Optional[str]
The content data type.
"""
def __init__(self, shape, dtype="float32"):
self.__init_handle_by_constructor__(
_ffi_api.TensorType, shape, dtype)
@property
def concrete_shape(self):
"""Get shape of the type as concrete tuple of int.
Returns
-------
shape : List[int]
The concrete shape of the Type.
Raises
------
TypeError : If the shape is symbolic
"""
return tuple(int(x) for x in self.shape)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unified type system in the project."""
from enum import IntEnum
import tvm._ffi
from .base import Node
from . import _ffi_api
class Type(Node):
"""The base class of all types."""
def __eq__(self, other):
"""Compare two types for structural equivalence."""
return bool(_ffi_api.type_alpha_equal(self, other))
def __ne__(self, other):
return not self.__eq__(other)
def same_as(self, other):
"""Compares two Relay types by referential equality."""
return super().__eq__(other)
class TypeKind(IntEnum):
"""Possible kinds of TypeVars."""
Type = 0
ShapeVar = 1
BaseType = 2
Constraint = 4
AdtHandle = 5
TypeData = 6
@tvm._ffi.register_object("relay.TypeVar")
class TypeVar(Type):
"""Type parameter in functions.
A type variable represents a type placeholder which will
be filled in later on. This allows the user to write
functions which are generic over types.
Parameters
----------
name_hint: str
The name of the type variable. This name only acts as a hint, and
is not used for equality.
kind : Optional[TypeKind]
The kind of the type parameter.
"""
def __init__(self, name_hint, kind=TypeKind.Type):
self.__init_handle_by_constructor__(
_ffi_api.TypeVar, name_hint, kind)
def __call__(self, *args):
"""Create a type call from this type.
Parameters
----------
args: List[Type]
The arguments to the type call.
Returns
-------
call: Type
The result type call.
"""
# pylint: disable=import-outside-toplevel
from .type_relation import TypeCall
return TypeCall(self, args)
@tvm._ffi.register_object("relay.GlobalTypeVar")
class GlobalTypeVar(Type):
"""A global type variable that is used for defining new types or type aliases.
Parameters
----------
name_hint: str
The name of the type variable. This name only acts as a hint, and
is not used for equality.
kind : Optional[TypeKind]
The kind of the type parameter.
"""
def __init__(self, name_hint, kind=TypeKind.AdtHandle):
self.__init_handle_by_constructor__(
_ffi_api.GlobalTypeVar, name_hint, kind)
def __call__(self, *args):
"""Create a type call from this type.
Parameters
----------
args: List[Type]
The arguments to the type call.
Returns
-------
call: Type
The result type call.
"""
# pylint: disable=import-outside-toplevel
from .type_relation import TypeCall
return TypeCall(self, args)
@tvm._ffi.register_object("relay.TupleType")
class TupleType(Type):
"""The type of tuple values.
Parameters
----------
fields : List[Type]
The fields in the tuple
"""
def __init__(self, fields):
self.__init_handle_by_constructor__(
_ffi_api.TupleType, fields)
@tvm._ffi.register_object("relay.TypeConstraint")
class TypeConstraint(Type):
"""Abstract class representing a type constraint."""
@tvm._ffi.register_object("relay.FuncType")
class FuncType(Type):
"""Function type.
A function type consists of a list of type parameters to enable
the definition of generic functions,
a set of type constraints which we omit for the time being,
a sequence of argument types, and a return type.
We can informally write them as:
`forall (type_params), (arg_types) -> ret_type where type_constraints`
Parameters
----------
arg_types : List[tvm.relay.Type]
The argument types
ret_type : tvm.relay.Type
The return type.
type_params : Optional[List[tvm.relay.TypeVar]]
The type parameters
type_constraints : Optional[List[tvm.relay.TypeConstraint]]
The type constraints.
"""
def __init__(self,
arg_types,
ret_type,
type_params=None,
type_constraints=None):
if type_params is None:
type_params = []
if type_constraints is None:
type_constraints = []
self.__init_handle_by_constructor__(
_ffi_api.FuncType, arg_types, ret_type, type_params, type_constraints)
@tvm._ffi.register_object("relay.IncompleteType")
class IncompleteType(Type):
"""Incomplete type during type inference.
kind : Optional[TypeKind]
The kind of the incomplete type.
"""
def __init__(self, kind=TypeKind.Type):
self.__init_handle_by_constructor__(
_ffi_api.IncompleteType, kind)
@tvm._ffi.register_object("relay.RefType")
class RelayRefType(Type):
"""Reference Type in relay.
Parameters
----------
value: Type
The value type.
"""
def __init__(self, value):
self.__init_handle_by_constructor__(_ffi_api.RelayRefType, value)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Type relation and function for type checking."""
import tvm._ffi
from .type import Type, TypeConstraint
from . import _ffi_api
class TypeCall(Type):
"""Type function application.
Parameters
----------
func: tvm.ir.Type
The function.
args: List[tvm.ir.Type]
The arguments.
Returns
-------
type_call: TypeCall
The type function application.
"""
def __init__(self, func, args):
self.__init_handle_by_constructor__(_ffi_api.TypeCall, func, args)
@tvm._ffi.register_object("relay.TypeRelation")
class TypeRelation(TypeConstraint):
"""User defined type relation, it is an input-output relation on types.
TypeRelation is more generalized than TypeCall as it allows inference
of both inputs and outputs.
Parameters
----------
func : EnvFunc
User defined relation function.
args : [tvm.ir.Type]
List of types to the func.
num_inputs : int
Number of input arguments in args,
this act as a hint for type inference.
attrs : Attrs
The attribute attached to the relation information
Returns
-------
type_relation : tvm.ir.TypeRelation
The type relation.
"""
def __init__(self, func, args, num_inputs, attrs):
self.__init_handle_by_constructor__(
_ffi_api.TypeRelation, func, args, num_inputs, attrs)
...@@ -15,16 +15,15 @@ ...@@ -15,16 +15,15 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""Developer API of IR node builder make function.""" """Developer API of IR node builder make function."""
from tvm._ffi.base import string_types
from tvm.runtime import ObjectGeneric, DataType from tvm.runtime import ObjectGeneric, DataType
from tvm.ir import container as _container
from ._ffi.base import string_types
from . import api as _api from . import api as _api
from . import stmt as _stmt from . import stmt as _stmt
from . import expr as _expr from . import expr as _expr
from . import make as _make from . import make as _make
from . import ir_pass as _pass from . import ir_pass as _pass
from . import container as _container
from .expr import Call as _Call from .expr import Call as _Call
class WithScope(object): class WithScope(object):
......
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@
# under the License. # under the License.
# pylint: disable=wildcard-import, redefined-builtin, invalid-name # pylint: disable=wildcard-import, redefined-builtin, invalid-name
"""The Relay IR namespace containing the IR definition and compiler.""" """The Relay IR namespace containing the IR definition and compiler."""
from __future__ import absolute_import
import os import os
from sys import setrecursionlimit from sys import setrecursionlimit
from ..api import register_func from ..api import register_func
...@@ -25,7 +24,6 @@ from . import ty ...@@ -25,7 +24,6 @@ from . import ty
from . import expr from . import expr
from . import type_functor from . import type_functor
from . import expr_functor from . import expr_functor
from . import module
from . import adt from . import adt
from . import analysis from . import analysis
from . import transform from . import transform
...@@ -66,14 +64,11 @@ setrecursionlimit(10000) ...@@ -66,14 +64,11 @@ setrecursionlimit(10000)
# Span # Span
Span = base.Span Span = base.Span
# Env
Module = module.Module
# Type # Type
Type = ty.Type Type = ty.Type
TupleType = ty.TupleType TupleType = ty.TupleType
TensorType = ty.TensorType TensorType = ty.TensorType
Kind = ty.Kind TypeKind = ty.TypeKind
TypeVar = ty.TypeVar TypeVar = ty.TypeVar
ShapeVar = ty.ShapeVar ShapeVar = ty.ShapeVar
TypeConstraint = ty.TypeConstraint TypeConstraint = ty.TypeConstraint
...@@ -87,7 +82,7 @@ TypeCall = ty.TypeCall ...@@ -87,7 +82,7 @@ TypeCall = ty.TypeCall
Any = ty.Any Any = ty.Any
# Expr # Expr
Expr = expr.Expr Expr = expr.RelayExpr
Constant = expr.Constant Constant = expr.Constant
Tuple = expr.Tuple Tuple = expr.Tuple
Var = expr.Var Var = expr.Var
......
...@@ -37,8 +37,9 @@ except ImportError: ...@@ -37,8 +37,9 @@ except ImportError:
return deque.__new__(cls, *args, **kwds) return deque.__new__(cls, *args, **kwds)
import tvm import tvm
import tvm.ir._ffi_api
from tvm.ir import IRModule
from . import module
from .base import Span, SourceName from .base import Span, SourceName
from . import adt from . import adt
from . import expr from . import expr
...@@ -190,7 +191,7 @@ def spanify(f): ...@@ -190,7 +191,7 @@ def spanify(f):
sp = Span(sn, line, col) sp = Span(sn, line, col)
if isinstance(ast, tvm.relay.expr.TupleWrapper): if isinstance(ast, tvm.relay.expr.TupleWrapper):
ast = ast.astuple() ast = ast.astuple()
ast.set_span(sp) tvm.ir._ffi_api.NodeSetSpan(ast, sp)
return ast return ast
return _wrapper return _wrapper
...@@ -201,7 +202,7 @@ class ParseTreeToRelayIR(RelayVisitor): ...@@ -201,7 +202,7 @@ class ParseTreeToRelayIR(RelayVisitor):
def __init__(self, source_name: str) -> None: def __init__(self, source_name: str) -> None:
self.source_name = source_name self.source_name = source_name
self.module = module.Module({}) # type: module.Module self.module = IRModule({}) # type: IRModule
# Adding an empty scope allows naked lets without pain. # Adding an empty scope allows naked lets without pain.
self.var_scopes = deque([deque()]) # type: Scopes[expr.Var] self.var_scopes = deque([deque()]) # type: Scopes[expr.Var]
...@@ -243,7 +244,7 @@ class ParseTreeToRelayIR(RelayVisitor): ...@@ -243,7 +244,7 @@ class ParseTreeToRelayIR(RelayVisitor):
"""Pop off the current TypeVar scope and return it.""" """Pop off the current TypeVar scope and return it."""
return self.type_var_scopes.popleft() return self.type_var_scopes.popleft()
def mk_typ(self, name: str, kind: ty.Kind) -> ty.TypeVar: def mk_typ(self, name: str, kind: ty.TypeKind) -> ty.TypeVar:
"""Create a new TypeVar and add it to the TypeVar scope.""" """Create a new TypeVar and add it to the TypeVar scope."""
typ = ty.TypeVar(name, kind) typ = ty.TypeVar(name, kind)
self.type_var_scopes[0].append((name, typ)) self.type_var_scopes[0].append((name, typ))
...@@ -274,7 +275,7 @@ class ParseTreeToRelayIR(RelayVisitor): ...@@ -274,7 +275,7 @@ class ParseTreeToRelayIR(RelayVisitor):
if isinstance(e, adt.Constructor): if isinstance(e, adt.Constructor):
return "`{0}` ADT constructor".format(e.belong_to.name_hint) return "`{0}` ADT constructor".format(e.belong_to.name_hint)
if isinstance(e, ty.GlobalTypeVar): if isinstance(e, ty.GlobalTypeVar):
if e.kind == ty.Kind.AdtHandle: if e.kind == ty.TypeKind.AdtHandle:
return "ADT definition" return "ADT definition"
return "function definition" return "function definition"
...@@ -352,12 +353,12 @@ class ParseTreeToRelayIR(RelayVisitor): ...@@ -352,12 +353,12 @@ class ParseTreeToRelayIR(RelayVisitor):
return self.visit(ctx) return self.visit(ctx)
def visitProg(self, ctx: RelayParser.ProgContext) -> Union[expr.Expr, module.Module]: def visitProg(self, ctx: RelayParser.ProgContext) -> Union[expr.Expr, IRModule]:
self.meta = None self.meta = None
if ctx.METADATA(): if ctx.METADATA():
header, data = str(ctx.METADATA()).split("\n", 1) header, data = str(ctx.METADATA()).split("\n", 1)
assert header == "METADATA:" assert header == "METADATA:"
self.meta = tvm.load_json(data) self.meta = tvm.ir.load_json(data)
if ctx.defn(): if ctx.defn():
self.visit_list(ctx.defn()) self.visit_list(ctx.defn())
return self.module return self.module
...@@ -492,7 +493,7 @@ class ParseTreeToRelayIR(RelayVisitor): ...@@ -492,7 +493,7 @@ class ParseTreeToRelayIR(RelayVisitor):
assert type_params assert type_params
for ty_param in type_params: for ty_param in type_params:
name = ty_param.getText() name = ty_param.getText()
self.mk_typ(name, ty.Kind.Type) self.mk_typ(name, ty.TypeKind.Type)
var_list, attr_list = self.visit(ctx.argList()) var_list, attr_list = self.visit(ctx.argList())
if var_list is None: if var_list is None:
...@@ -528,13 +529,13 @@ class ParseTreeToRelayIR(RelayVisitor): ...@@ -528,13 +529,13 @@ class ParseTreeToRelayIR(RelayVisitor):
ctx: Union[RelayParser.ExternAdtDefnContext, RelayParser.AdtDefnContext]): ctx: Union[RelayParser.ExternAdtDefnContext, RelayParser.AdtDefnContext]):
"""Handles parsing of the name and type params of an ADT definition.""" """Handles parsing of the name and type params of an ADT definition."""
adt_name = ctx.generalIdent().getText() adt_name = ctx.generalIdent().getText()
adt_var = self.mk_global_typ_var(adt_name, ty.Kind.AdtHandle) adt_var = self.mk_global_typ_var(adt_name, ty.TypeKind.AdtHandle)
# parse type params # parse type params
type_params = ctx.typeParamList() type_params = ctx.typeParamList()
if type_params is None: if type_params is None:
type_params = [] type_params = []
else: else:
type_params = [self.mk_typ(type_ident.getText(), ty.Kind.Type) type_params = [self.mk_typ(type_ident.getText(), ty.TypeKind.Type)
for type_ident in type_params.typeExpr()] for type_ident in type_params.typeExpr()]
return adt_var, type_params return adt_var, type_params
...@@ -746,7 +747,7 @@ class StrictErrorListener(ErrorListener): ...@@ -746,7 +747,7 @@ class StrictErrorListener(ErrorListener):
def reportContextSensitivity(self, recognizer, dfa, startIndex, stopIndex, prediction, configs): def reportContextSensitivity(self, recognizer, dfa, startIndex, stopIndex, prediction, configs):
raise Exception("Context Sensitivity in:\n" + self.text) raise Exception("Context Sensitivity in:\n" + self.text)
def fromtext(data: str, source_name: str = None) -> Union[expr.Expr, module.Module]: def fromtext(data: str, source_name: str = None) -> Union[expr.Expr, IRModule]:
"""Parse a Relay program.""" """Parse a Relay program."""
if data == "": if data == "":
raise ParseError("cannot parse the empty string.") raise ParseError("cannot parse the empty string.")
......
...@@ -14,12 +14,14 @@ ...@@ -14,12 +14,14 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-import
"""Algebraic data types in Relay.""" """Algebraic data types in Relay."""
from tvm.ir import Constructor, TypeData
from .base import RelayNode, register_relay_node, Object from .base import RelayNode, register_relay_node, Object
from . import _make from . import _make
from .ty import Type from .ty import Type
from .expr import Expr, Call from .expr import ExprWithOp, RelayExpr, Call
class Pattern(RelayNode): class Pattern(RelayNode):
...@@ -113,77 +115,6 @@ class PatternTuple(Pattern): ...@@ -113,77 +115,6 @@ class PatternTuple(Pattern):
@register_relay_node @register_relay_node
class Constructor(Expr):
"""Relay ADT constructor."""
def __init__(self, name_hint, inputs, belong_to):
"""Defines an ADT constructor.
Parameters
----------
name_hint : str
Name of constructor (only a hint).
inputs : List[Type]
Input types.
belong_to : tvm.relay.GlobalTypeVar
Denotes which ADT the constructor belongs to.
Returns
-------
con: Constructor
A constructor.
"""
self.__init_handle_by_constructor__(_make.Constructor, name_hint, inputs, belong_to)
def __call__(self, *args):
"""Call the constructor.
Parameters
----------
args: List[relay.Expr]
The arguments to the constructor.
Returns
-------
call: relay.Call
A call to the constructor.
"""
return Call(self, args)
@register_relay_node
class TypeData(Type):
"""Stores the definition for an Algebraic Data Type (ADT) in Relay.
Note that ADT definitions are treated as type-level functions because
the type parameters need to be given for an instance of the ADT. Thus,
any global type var that is an ADT header needs to be wrapped in a
type call that passes in the type params.
"""
def __init__(self, header, type_vars, constructors):
"""Defines a TypeData object.
Parameters
----------
header: tvm.relay.GlobalTypeVar
The name of the ADT.
ADTs with the same constructors but different names are
treated as different types.
type_vars: List[TypeVar]
Type variables that appear in constructors.
constructors: List[tvm.relay.Constructor]
The constructors for the ADT.
Returns
-------
type_data: TypeData
The adt declaration.
"""
self.__init_handle_by_constructor__(_make.TypeData, header, type_vars, constructors)
@register_relay_node
class Clause(Object): class Clause(Object):
"""Clause for pattern matching in Relay.""" """Clause for pattern matching in Relay."""
...@@ -206,7 +137,7 @@ class Clause(Object): ...@@ -206,7 +137,7 @@ class Clause(Object):
@register_relay_node @register_relay_node
class Match(Expr): class Match(ExprWithOp):
"""Pattern matching expression in Relay.""" """Pattern matching expression in Relay."""
def __init__(self, data, clauses, complete=True): def __init__(self, data, clauses, complete=True):
......
...@@ -20,11 +20,11 @@ ...@@ -20,11 +20,11 @@
This file contains the set of passes for Relay, which exposes an interface for This file contains the set of passes for Relay, which exposes an interface for
configuring the passes and scripting them in Python. configuring the passes and scripting them in Python.
""" """
from tvm.ir import RelayExpr, IRModule
from . import _analysis from . import _analysis
from . import _make from . import _make
from .expr import Expr
from .ty import Type from .ty import Type
from .module import Module
from .feature import Feature from .feature import Feature
...@@ -70,7 +70,7 @@ def check_kind(t, mod=None): ...@@ -70,7 +70,7 @@ def check_kind(t, mod=None):
t : tvm.relay.Type t : tvm.relay.Type
The type to check The type to check
mod : Optional[tvm.relay.Module] mod : Optional[tvm.IRModule]
The global module. The global module.
Returns Returns
...@@ -169,7 +169,7 @@ def free_type_vars(expr, mod=None): ...@@ -169,7 +169,7 @@ def free_type_vars(expr, mod=None):
expr : Union[tvm.relay.Expr,tvm.relay.Type] expr : Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type The input expression/type
mod : Optional[tvm.relay.Module] mod : Optional[tvm.IRModule]
The global module The global module
Returns Returns
...@@ -177,7 +177,7 @@ def free_type_vars(expr, mod=None): ...@@ -177,7 +177,7 @@ def free_type_vars(expr, mod=None):
free : List[tvm.relay.TypeVar] free : List[tvm.relay.TypeVar]
The list of free type variables in post-DFS order The list of free type variables in post-DFS order
""" """
use_mod = mod if mod is not None else Module() use_mod = mod if mod is not None else IRModule()
return _analysis.free_type_vars(expr, use_mod) return _analysis.free_type_vars(expr, use_mod)
...@@ -189,7 +189,7 @@ def bound_type_vars(expr, mod=None): ...@@ -189,7 +189,7 @@ def bound_type_vars(expr, mod=None):
expr : Union[tvm.relay.Expr,tvm.relay.Type] expr : Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type The input expression/type
mod : Optional[tvm.relay.Module] mod : Optional[tvm.IRModule]
The global module The global module
Returns Returns
...@@ -197,7 +197,7 @@ def bound_type_vars(expr, mod=None): ...@@ -197,7 +197,7 @@ def bound_type_vars(expr, mod=None):
free : List[tvm.relay.TypeVar] free : List[tvm.relay.TypeVar]
The list of bound type variables in post-DFS order The list of bound type variables in post-DFS order
""" """
use_mod = mod if mod is not None else Module() use_mod = mod if mod is not None else IRModule()
return _analysis.bound_type_vars(expr, use_mod) return _analysis.bound_type_vars(expr, use_mod)
...@@ -209,7 +209,7 @@ def all_type_vars(expr, mod=None): ...@@ -209,7 +209,7 @@ def all_type_vars(expr, mod=None):
expr : Union[tvm.relay.Expr,tvm.relay.Type] expr : Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type The input expression/type
mod : Optional[tvm.relay.Module] mod : Optional[tvm.IRModule]
The global module The global module
Returns Returns
...@@ -217,7 +217,7 @@ def all_type_vars(expr, mod=None): ...@@ -217,7 +217,7 @@ def all_type_vars(expr, mod=None):
free : List[tvm.relay.TypeVar] free : List[tvm.relay.TypeVar]
The list of all type variables in post-DFS order The list of all type variables in post-DFS order
""" """
use_mod = mod if mod is not None else Module() use_mod = mod if mod is not None else IRModule()
return _analysis.all_type_vars(expr, use_mod) return _analysis.all_type_vars(expr, use_mod)
...@@ -353,7 +353,7 @@ def unmatched_cases(match, mod=None): ...@@ -353,7 +353,7 @@ def unmatched_cases(match, mod=None):
match : tvm.relay.Match match : tvm.relay.Match
The match expression The match expression
mod : Optional[tvm.relay.Module] mod : Optional[tvm.IRModule]
The module (defaults to an empty module) The module (defaults to an empty module)
Returns Returns
...@@ -370,10 +370,10 @@ def detect_feature(a, b=None): ...@@ -370,10 +370,10 @@ def detect_feature(a, b=None):
Parameters Parameters
---------- ----------
a : Union[tvm.relay.Expr, tvm.relay.Module] a : Union[tvm.relay.Expr, tvm.IRModule]
The input expression or module. The input expression or module.
b : Optional[Union[tvm.relay.Expr, tvm.relay.Module]] b : Optional[Union[tvm.relay.Expr, tvm.IRModule]]
The input expression or module. The input expression or module.
The two arguments cannot both be expression or module. The two arguments cannot both be expression or module.
...@@ -382,7 +382,7 @@ def detect_feature(a, b=None): ...@@ -382,7 +382,7 @@ def detect_feature(a, b=None):
features : Set[Feature] features : Set[Feature]
Features used in the program. Features used in the program.
""" """
if isinstance(a, Module): if isinstance(a, IRModule):
a, b = b, a a, b = b, a
return {Feature(int(x)) for x in _analysis.detect_feature(a, b)} return {Feature(int(x)) for x in _analysis.detect_feature(a, b)}
...@@ -400,7 +400,7 @@ def structural_hash(value): ...@@ -400,7 +400,7 @@ def structural_hash(value):
result : int result : int
The hash value The hash value
""" """
if isinstance(value, Expr): if isinstance(value, RelayExpr):
return int(_analysis._expr_hash(value)) return int(_analysis._expr_hash(value))
elif isinstance(value, Type): elif isinstance(value, Type):
return int(_analysis._type_hash(value)) return int(_analysis._type_hash(value))
......
...@@ -16,9 +16,9 @@ ...@@ -16,9 +16,9 @@
# under the License. # under the License.
"""The interface of expr function exposed from C++.""" """The interface of expr function exposed from C++."""
import tvm._ffi import tvm._ffi
from tvm.ir import container as _container
from ... import build_module as _build from ... import build_module as _build
from ... import container as _container
@tvm._ffi.register_func("relay.backend.lower") @tvm._ffi.register_func("relay.backend.lower")
......
...@@ -21,9 +21,10 @@ from __future__ import absolute_import ...@@ -21,9 +21,10 @@ from __future__ import absolute_import
import numpy as np import numpy as np
from tvm.runtime import container from tvm.runtime import container
from tvm.ir import IRModule
from . import _backend from . import _backend
from .. import _make, analysis, transform from .. import _make, analysis, transform
from .. import module
from ... import nd from ... import nd
from ..base import Object, register_relay_node from ..base import Object, register_relay_node
from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const
...@@ -186,10 +187,10 @@ class Interpreter(Executor): ...@@ -186,10 +187,10 @@ class Interpreter(Executor):
Parameters Parameters
---------- ----------
mod : tvm.relay.Module mod : tvm.IRModule
The module to support the execution. The module to support the execution.
ctx : tvm.TVMContext ctx : tvmContext
The runtime context to run the code on. The runtime context to run the code on.
target : tvm.Target target : tvm.Target
...@@ -205,7 +206,7 @@ class Interpreter(Executor): ...@@ -205,7 +206,7 @@ class Interpreter(Executor):
Returns Returns
------- -------
opt_mod : tvm.relay.Module opt_mod : tvm.IRModule
The optimized module. The optimized module.
""" """
seq = transform.Sequential([transform.SimplifyInference(), seq = transform.Sequential([transform.SimplifyInference(),
...@@ -239,7 +240,7 @@ class Interpreter(Executor): ...@@ -239,7 +240,7 @@ class Interpreter(Executor):
if self.mod: if self.mod:
self.mod["main"] = func self.mod["main"] = func
else: else:
self.mod = module.Module.from_expr(func) self.mod = IRModule.from_expr(func)
mod = self.optimize() mod = self.optimize()
opt_expr = Call(mod["main"], relay_args) opt_expr = Call(mod["main"], relay_args)
......
...@@ -36,7 +36,7 @@ def compile(mod, target=None, target_host=None, params=None): ...@@ -36,7 +36,7 @@ def compile(mod, target=None, target_host=None, params=None):
Parameters Parameters
---------- ----------
mod : relay.Module mod : tvm.IRModule
The Relay module to build. The Relay module to build.
target : str, :any:`tvm.target.Target`, or dict of str(i.e. target : str, :any:`tvm.target.Target`, or dict of str(i.e.
...@@ -110,7 +110,7 @@ class VMCompiler(object): ...@@ -110,7 +110,7 @@ class VMCompiler(object):
Parameters Parameters
---------- ----------
mod : relay.Module mod : tvm.IRModule
The Relay module to build. The Relay module to build.
target : str, :any:`tvm.target.Target`, or dict of str(i.e. target : str, :any:`tvm.target.Target`, or dict of str(i.e.
...@@ -142,7 +142,7 @@ class VMCompiler(object): ...@@ -142,7 +142,7 @@ class VMCompiler(object):
Parameters Parameters
---------- ----------
mod : relay.Module mod : tvm.IRModule
target : str, :any:`tvm.target.Target`, or dict of str (i.e. target : str, :any:`tvm.target.Target`, or dict of str (i.e.
device/context name) to str/tvm.target.Target, optional device/context name) to str/tvm.target.Target, optional
...@@ -153,7 +153,7 @@ class VMCompiler(object): ...@@ -153,7 +153,7 @@ class VMCompiler(object):
Returns Returns
------- -------
mod : relay.Module mod : tvm.IRModule
The optimized relay module. The optimized relay module.
params : dict params : dict
...@@ -229,10 +229,10 @@ class VMExecutor(Executor): ...@@ -229,10 +229,10 @@ class VMExecutor(Executor):
Parameters Parameters
---------- ----------
mod : :py:class:`~tvm.relay.module.Module` mod : :py:class:`~tvm.IRModule`
The module to support the execution. The module to support the execution.
ctx : :py:class:`~tvm.TVMContext` ctx : :py:class:`~tvmContext`
The runtime context to run the code on. The runtime context to run the code on.
target : :py:class:`Target` target : :py:class:`Target`
......
...@@ -14,16 +14,25 @@ ...@@ -14,16 +14,25 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck # pylint: disable=no-else-return, unidiomatic-typecheck, unused-import
"""The base node types for the Relay language.""" """The base node types for the Relay language."""
import os
import tvm._ffi import tvm._ffi
from tvm.runtime import Object from tvm.runtime import Object
from tvm.ir import SourceName, Span, Node as RelayNode
from . import _make from . import _make
from . import _expr from . import _expr
from . import _base from . import _base
__STD_PATH__ = os.path.join(os.path.dirname(os.path.realpath(__file__)), "std")
@tvm._ffi.register_func("tvm.relay.std_path")
def _std_path():
return __STD_PATH__
def register_relay_node(type_key=None): def register_relay_node(type_key=None):
"""Register a Relay node type. """Register a Relay node type.
...@@ -52,55 +61,6 @@ def register_relay_attr_node(type_key=None): ...@@ -52,55 +61,6 @@ def register_relay_attr_node(type_key=None):
return tvm._ffi.register_object(type_key) return tvm._ffi.register_object(type_key)
class RelayNode(Object):
"""Base class of all Relay nodes."""
def astext(self, show_meta_data=True, annotate=None):
"""Get the text format of the expression.
Parameters
----------
show_meta_data : bool
Whether to include meta data section in the text
if there is meta data.
annotate: Optional[relay.Expr->str]
Optional annotate function to provide additional
information in the comment block.
Note
----
The meta data section is necessary to fully parse the text format.
However, it can contain dumps that are big (e.g constant weights),
so it can be helpful to skip printing the meta data section.
Returns
-------
text : str
The text format of the expression.
"""
return _expr.AsText(self, show_meta_data, annotate)
def set_span(self, span):
_base.set_span(self, span)
def __str__(self):
return self.astext(show_meta_data=False)
@register_relay_node
class Span(RelayNode):
"""Specifies a location in a source program."""
def __init__(self, source, lineno, col_offset):
self.__init_handle_by_constructor__(_make.Span, source, lineno, col_offset)
@register_relay_node
class SourceName(RelayNode):
"""A identifier for a source location"""
def __init__(self, name):
self.__init_handle_by_constructor__(_make.SourceName, name)
@register_relay_node @register_relay_node
class Id(Object): class Id(Object):
"""Unique identifier(name) used in Var. """Unique identifier(name) used in Var.
......
...@@ -21,13 +21,14 @@ from a Relay expression. ...@@ -21,13 +21,14 @@ from a Relay expression.
import warnings import warnings
import numpy as np import numpy as np
from tvm.ir import IRModule
from tvm import expr as tvm_expr from tvm import expr as tvm_expr
from .. import nd as _nd, target as _target, autotvm from .. import nd as _nd, target as _target, autotvm
from ..contrib import graph_runtime as _graph_rt from ..contrib import graph_runtime as _graph_rt
from . import _build_module from . import _build_module
from . import ty as _ty from . import ty as _ty
from . import expr as _expr from . import expr as _expr
from .module import Module as _Module
from .backend import interpreter as _interpreter from .backend import interpreter as _interpreter
from .backend.vm import VMExecutor from .backend.vm import VMExecutor
...@@ -141,7 +142,7 @@ class BuildModule(object): ...@@ -141,7 +142,7 @@ class BuildModule(object):
Returns Returns
------- -------
mod : relay.Module mod : tvm.IRModule
The optimized relay module. The optimized relay module.
params : dict params : dict
...@@ -185,7 +186,7 @@ def build(mod, target=None, target_host=None, params=None): ...@@ -185,7 +186,7 @@ def build(mod, target=None, target_host=None, params=None):
Parameters Parameters
---------- ----------
mod : relay.Module mod : tvm.IRModule
The module to build. Using relay.Function is deprecated. The module to build. Using relay.Function is deprecated.
target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context
...@@ -217,16 +218,16 @@ def build(mod, target=None, target_host=None, params=None): ...@@ -217,16 +218,16 @@ def build(mod, target=None, target_host=None, params=None):
params : dict params : dict
The parameters of the final graph. The parameters of the final graph.
""" """
if isinstance(mod, _Module): if isinstance(mod, IRModule):
func = mod["main"] func = mod["main"]
elif isinstance(mod, _expr.Function): elif isinstance(mod, _expr.Function):
func = mod func = mod
warnings.warn( warnings.warn(
"Please use input parameter mod (tvm.relay.module.Module) " "Please use input parameter mod (tvm.IRModule) "
"instead of deprecated parameter func (tvm.relay.expr.Function)", "instead of deprecated parameter func (tvm.relay.expr.Function)",
DeprecationWarning) DeprecationWarning)
else: else:
raise ValueError("Type of input parameter mod must be tvm.relay.module.Module") raise ValueError("Type of input parameter mod must be tvm.IRModule")
target = _update_target(target) target = _update_target(target)
...@@ -254,7 +255,7 @@ def optimize(mod, target=None, params=None): ...@@ -254,7 +255,7 @@ def optimize(mod, target=None, params=None):
Parameters Parameters
---------- ----------
mod : relay.Module mod : tvm.IRModule
The module to build. Using relay.Function is deprecated. The module to build. Using relay.Function is deprecated.
target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context
...@@ -268,7 +269,7 @@ def optimize(mod, target=None, params=None): ...@@ -268,7 +269,7 @@ def optimize(mod, target=None, params=None):
Returns Returns
------- -------
mod : relay.Module mod : tvm.IRModule
The optimized relay module. The optimized relay module.
params : dict params : dict
...@@ -279,11 +280,11 @@ def optimize(mod, target=None, params=None): ...@@ -279,11 +280,11 @@ def optimize(mod, target=None, params=None):
elif isinstance(mod, _expr.Function): elif isinstance(mod, _expr.Function):
func = mod func = mod
warnings.warn( warnings.warn(
"Please use input parameter mod (tvm.relay.module.Module) " "Please use input parameter mod (tvm.IRModule) "
"instead of deprecated parameter func (tvm.relay.expr.Function)", "instead of deprecated parameter func (tvm.relay.expr.Function)",
DeprecationWarning) DeprecationWarning)
else: else:
raise ValueError("Type of input parameter mod must be tvm.relay.module.Module") raise ValueError("Type of input parameter mod must be tvm.IRModule")
target = _update_target(target) target = _update_target(target)
...@@ -330,7 +331,7 @@ class GraphExecutor(_interpreter.Executor): ...@@ -330,7 +331,7 @@ class GraphExecutor(_interpreter.Executor):
Parameters Parameters
---------- ----------
mod : :py:class:`~tvm.relay.module.Module` mod : :py:class:`~tvm.IRModule`
The module to support the execution. The module to support the execution.
ctx : :py:class:`TVMContext` ctx : :py:class:`TVMContext`
...@@ -385,17 +386,17 @@ def create_executor(kind="debug", ...@@ -385,17 +386,17 @@ def create_executor(kind="debug",
kind : str kind : str
The type of executor The type of executor
mod : :py:class:`~tvm.relay.module.Module` mod : :py:class:`~tvm.IRModule`
The Relay module containing collection of functions The Relay module containing collection of functions
ctx : :py:class:`tvm.TVMContext` ctx : :py:class:`tvmContext`
The context to execute the code. The context to execute the code.
target : :py:class:`tvm.Target` target : :py:class:`tvm.Target`
The corresponding context The corresponding context
""" """
if mod is None: if mod is None:
mod = _Module() mod = IRModule()
if ctx is not None: if ctx is not None:
assert ctx.device_type == _nd.context(str(target), 0).device_type assert ctx.device_type == _nd.context(str(target), 0).device_type
else: else:
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name # pylint: disable=no-else-return, invalid-name, unused-import
"""The expression nodes of Relay.""" """The expression nodes of Relay."""
from __future__ import absolute_import from __future__ import absolute_import
from numbers import Number as _Number from numbers import Number as _Number
...@@ -22,33 +22,21 @@ from numbers import Number as _Number ...@@ -22,33 +22,21 @@ from numbers import Number as _Number
import numpy as _np import numpy as _np
from tvm._ffi import base as _base from tvm._ffi import base as _base
from tvm.runtime import NDArray, convert, ndarray as _nd from tvm.runtime import NDArray, convert, ndarray as _nd
from tvm.ir import RelayExpr, GlobalVar, BaseFunc
from .base import RelayNode, register_relay_node from .base import RelayNode, register_relay_node
from . import _make from . import _make
from . import _expr from . import _expr
from . import ty as _ty from . import ty as _ty
# alias relay expr as Expr.
Expr = RelayExpr
# will be registered afterwards # will be registered afterwards
_op_make = None _op_make = None
class Expr(RelayNode): class ExprWithOp(RelayExpr):
"""The base type for all Relay expressions.""" """Basetype of all relay expressions that defines op overloading."""
@property
def checked_type(self):
"""Get the checked type of tvm.relay.Expr.
Returns
-------
checked_type : tvm.relay.Type
The checked type.
"""
ret = self._checked_type_
if ret is None:
raise ValueError("The type checker has not populated"
" the checked_type for this node")
return ret
def astype(self, dtype): def astype(self, dtype):
"""Cast the content type of the current data to dtype. """Cast the content type of the current data to dtype.
...@@ -173,7 +161,7 @@ class Expr(RelayNode): ...@@ -173,7 +161,7 @@ class Expr(RelayNode):
return Call(self, args) return Call(self, args)
@register_relay_node @register_relay_node
class Constant(Expr): class Constant(ExprWithOp):
"""A constant expression in Relay. """A constant expression in Relay.
Parameters Parameters
...@@ -186,7 +174,7 @@ class Constant(Expr): ...@@ -186,7 +174,7 @@ class Constant(Expr):
@register_relay_node @register_relay_node
class Tuple(Expr): class Tuple(ExprWithOp):
"""Tuple expression that groups several fields together. """Tuple expression that groups several fields together.
Parameters Parameters
...@@ -210,7 +198,7 @@ class Tuple(Expr): ...@@ -210,7 +198,7 @@ class Tuple(Expr):
@register_relay_node @register_relay_node
class Var(Expr): class Var(ExprWithOp):
"""A local variable in Relay. """A local variable in Relay.
Local variable can be used to declare input Local variable can be used to declare input
...@@ -238,33 +226,7 @@ class Var(Expr): ...@@ -238,33 +226,7 @@ class Var(Expr):
@register_relay_node @register_relay_node
class GlobalVar(Expr): class Function(BaseFunc):
"""A global variable in Tvm.Relay.
GlobalVar is used to refer to the global functions
stored in the module.
Parameters
----------
name_hint: str
The name of the variable.
"""
def __init__(self, name_hint):
self.__init_handle_by_constructor__(_make.GlobalVar, name_hint)
def __call__(self, *args):
"""Invoke the gobal function.
Parameters
----------
args: List[relay.Expr]
Arguments.
"""
return Call(self, args, None, None)
@register_relay_node
class Function(Expr):
"""A function declaration expression. """A function declaration expression.
Parameters Parameters
...@@ -320,7 +282,7 @@ class Function(Expr): ...@@ -320,7 +282,7 @@ class Function(Expr):
@register_relay_node @register_relay_node
class Call(Expr): class Call(ExprWithOp):
"""Function call node in Relay. """Function call node in Relay.
Call node corresponds the operator application node Call node corresponds the operator application node
...@@ -349,7 +311,7 @@ class Call(Expr): ...@@ -349,7 +311,7 @@ class Call(Expr):
@register_relay_node @register_relay_node
class Let(Expr): class Let(ExprWithOp):
"""Let variable binding expression. """Let variable binding expression.
Parameters Parameters
...@@ -369,7 +331,7 @@ class Let(Expr): ...@@ -369,7 +331,7 @@ class Let(Expr):
@register_relay_node @register_relay_node
class If(Expr): class If(ExprWithOp):
"""A conditional expression in Relay. """A conditional expression in Relay.
Parameters Parameters
...@@ -389,7 +351,7 @@ class If(Expr): ...@@ -389,7 +351,7 @@ class If(Expr):
@register_relay_node @register_relay_node
class TupleGetItem(Expr): class TupleGetItem(ExprWithOp):
"""Get index-th item from a tuple. """Get index-th item from a tuple.
Parameters Parameters
...@@ -406,7 +368,7 @@ class TupleGetItem(Expr): ...@@ -406,7 +368,7 @@ class TupleGetItem(Expr):
@register_relay_node @register_relay_node
class RefCreate(Expr): class RefCreate(ExprWithOp):
"""Create a new reference from initial value. """Create a new reference from initial value.
Parameters Parameters
---------- ----------
...@@ -418,7 +380,7 @@ class RefCreate(Expr): ...@@ -418,7 +380,7 @@ class RefCreate(Expr):
@register_relay_node @register_relay_node
class RefRead(Expr): class RefRead(ExprWithOp):
"""Get the value inside the reference. """Get the value inside the reference.
Parameters Parameters
---------- ----------
...@@ -430,7 +392,7 @@ class RefRead(Expr): ...@@ -430,7 +392,7 @@ class RefRead(Expr):
@register_relay_node @register_relay_node
class RefWrite(Expr): class RefWrite(ExprWithOp):
""" """
Update the value inside the reference. Update the value inside the reference.
The whole expression will evaluate to an empty tuple. The whole expression will evaluate to an empty tuple.
...@@ -445,7 +407,7 @@ class RefWrite(Expr): ...@@ -445,7 +407,7 @@ class RefWrite(Expr):
self.__init_handle_by_constructor__(_make.RefWrite, ref, value) self.__init_handle_by_constructor__(_make.RefWrite, ref, value)
class TempExpr(Expr): class TempExpr(ExprWithOp):
"""Baseclass of all TempExpr. """Baseclass of all TempExpr.
TempExprs are pass specific expression that can be TempExprs are pass specific expression that can be
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import List
import tvm
from .base import Span, Object
from .ty import Type, TypeParam
from ._analysis import _get_checked_type
class Expr(Object):
def checked_type(self):
...
def __call__(self, *args):
...
class Constant(Expr):
data = ... # type: tvm.nd.NDArray
def __init__(self, data):
# type: (tvm.nd.NDArray) -> None
...
class Tuple(Expr):
fields = ... # type: List[Expr]
def __init__(self, fields):
# type: (List[Expr]) -> None
...
class Var(Expr):
"""A local variable in Relay."""
name_hint = ... # type: str
def __init__(self, name_hint):
# type: (str) -> None
...
class GlobalVar(Expr):
name_hint = ... # type: str
def __init__(self, name_hint):
# type: (str) -> None
...
class Param(Expr):
var = ... # type: Var
type = ... # type: Type
def __init__(self, var, ty):
# type: (Var, Type) -> None
...
class Function(Expr):
"""A function in Relay, see tvm/relay/expr.h for more details."""
type_params = ... # type: List[TypeParam]
params = ... # type: List[Param]
ret_type = ... # type: Type
body = ... # type: Expr
def __init__(self,
params, # type: List[Param],
ret_type, # type: Type,
body, # type: Expr,
type_params=None, # type: List[TypeParam]
):
# type: (...) -> None
...
@register_relay_node
class Call(Expr):
"""A function call in Relay, see tvm/relay/expr.h for more details."""
op = ... # type: Expr
args = ... # type: List[Expr]
# todo(@jroesch): add attrs. revise attrs type in __init__
def __init__(self, op, args, attrs=None, ty_args=None):
# type: (Expr, List[Expr], Optional[List[Any]], Optional[List[Type]]) -> None
if not ty_args:
ty_args = []
self.__init_handle_by_constructor__(
_make.Call, op, args, attrs, ty_args)
@register_relay_node
class Let(Expr):
"""A variable bindings in Relay, see tvm/relay/expr.h for more details."""
var = ... # type: Var
value = ... # type: Expr
body = ... # type: Expr
value_type = ... # type: Type
def __init__(self, var, value, body, value_type):
# type: (Var, Expr, Expr, Type) -> None
...
@register_relay_node
class If(Expr):
"""A conditional expression in Relay, see tvm/relay/expr.h for more details."""
cond = ... # type: Expr
true_value = ... # type: Expr
false_value = ... # type: Expr
span = ... # type: Span
def __init__(self, cond, true_value, false_value):
# type: (Expr, Expr, Expr) -> None
...
...@@ -16,11 +16,11 @@ ...@@ -16,11 +16,11 @@
# under the License. # under the License.
# pylint: disable=import-self, invalid-name, line-too-long, unused-argument # pylint: disable=import-self, invalid-name, line-too-long, unused-argument
"""Caffe2 frontend""" """Caffe2 frontend"""
from __future__ import absolute_import as _abs
import tvm import tvm
from tvm.ir import IRModule
from .. import analysis from .. import analysis
from .. import expr as _expr from .. import expr as _expr
from .. import module as _module
from .. import op as _op from .. import op as _op
from ... import nd as _nd from ... import nd as _nd
from .common import AttrCvt, Renamer from .common import AttrCvt, Renamer
...@@ -383,7 +383,7 @@ class Caffe2NetDef(object): ...@@ -383,7 +383,7 @@ class Caffe2NetDef(object):
self._ops = {} self._ops = {}
self._shape = shape self._shape = shape
self._dtype = dtype self._dtype = dtype
self._mod = _module.Module({}) self._mod = IRModule({})
def from_caffe2(self, init_net, predict_net): def from_caffe2(self, init_net, predict_net):
"""Construct Relay expression from caffe2 graph. """Construct Relay expression from caffe2 graph.
...@@ -395,7 +395,7 @@ class Caffe2NetDef(object): ...@@ -395,7 +395,7 @@ class Caffe2NetDef(object):
Returns Returns
------- -------
mod : tvm.relay.Module mod : tvm.IRModule
The module that optimizations will be performed on. The module that optimizations will be performed on.
params : dict params : dict
...@@ -565,7 +565,7 @@ def from_caffe2(init_net, predict_net, shape=None, dtype="float32"): ...@@ -565,7 +565,7 @@ def from_caffe2(init_net, predict_net, shape=None, dtype="float32"):
Returns Returns
------- -------
mod : tvm.relay.Module mod : tvm.IRModule
The module that optimizations will be performed on. The module that optimizations will be performed on.
params : dict of str to tvm.nd.NDArray params : dict of str to tvm.nd.NDArray
......
...@@ -20,9 +20,10 @@ import logging ...@@ -20,9 +20,10 @@ import logging
import numpy as np import numpy as np
import tvm import tvm
from tvm.ir import IRModule
from topi.util import get_const_tuple from topi.util import get_const_tuple
from .. import expr as _expr from .. import expr as _expr
from .. import module as _module
from .. import transform as _transform from .. import transform as _transform
from .. import op as _op from .. import op as _op
from .. import analysis from .. import analysis
...@@ -453,7 +454,7 @@ def get_name(node): ...@@ -453,7 +454,7 @@ def get_name(node):
def infer_type(node, mod=None): def infer_type(node, mod=None):
"""A method to infer the type of an intermediate node in the relay graph.""" """A method to infer the type of an intermediate node in the relay graph."""
new_mod = _module.Module.from_expr(node) new_mod = IRModule.from_expr(node)
if mod is not None: if mod is not None:
new_mod.update(mod) new_mod.update(mod)
new_mod = _transform.InferType()(new_mod) new_mod = _transform.InferType()(new_mod)
......
...@@ -21,9 +21,10 @@ from __future__ import absolute_import as _abs ...@@ -21,9 +21,10 @@ from __future__ import absolute_import as _abs
import math import math
import numpy as np import numpy as np
import tvm import tvm
from tvm.ir import IRModule
from .. import analysis from .. import analysis
from .. import expr as _expr from .. import expr as _expr
from .. import module as _module
from .. import op as _op from .. import op as _op
from ... import nd as _nd from ... import nd as _nd
from ..._ffi import base as _base from ..._ffi import base as _base
...@@ -449,7 +450,7 @@ def from_coreml(model, shape=None): ...@@ -449,7 +450,7 @@ def from_coreml(model, shape=None):
Returns Returns
------- -------
mod : tvm.relay.Module mod : tvm.IRModule
The relay module for compilation. The relay module for compilation.
params : dict of str to tvm.nd.NDArray params : dict of str to tvm.nd.NDArray
...@@ -505,4 +506,4 @@ def from_coreml(model, shape=None): ...@@ -505,4 +506,4 @@ def from_coreml(model, shape=None):
outexpr = outexpr[0] outexpr = outexpr[0]
func = _expr.Function(analysis.free_vars(outexpr), outexpr) func = _expr.Function(analysis.free_vars(outexpr), outexpr)
params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()} params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()}
return _module.Module.from_expr(func), params return IRModule.from_expr(func), params
...@@ -23,9 +23,10 @@ from __future__ import absolute_import as _abs ...@@ -23,9 +23,10 @@ from __future__ import absolute_import as _abs
from enum import Enum from enum import Enum
import numpy as np import numpy as np
import tvm import tvm
from tvm.ir import IRModule
from .. import analysis from .. import analysis
from .. import expr as _expr from .. import expr as _expr
from .. import module as _module
from .common import get_relay_op, new_var from .common import get_relay_op, new_var
__all__ = ['from_darknet'] __all__ = ['from_darknet']
...@@ -822,7 +823,7 @@ class GraphProto(object): ...@@ -822,7 +823,7 @@ class GraphProto(object):
outputs = _as_list(sym) + self._outs outputs = _as_list(sym) + self._outs
outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
sym = _expr.Function(analysis.free_vars(outputs), outputs) sym = _expr.Function(analysis.free_vars(outputs), outputs)
return _module.Module.from_expr(sym), self._tvmparams return IRModule.from_expr(sym), self._tvmparams
def from_darknet(net, def from_darknet(net,
shape=None, shape=None,
...@@ -840,7 +841,7 @@ def from_darknet(net, ...@@ -840,7 +841,7 @@ def from_darknet(net,
Returns Returns
------- -------
mod : tvm.relay.Module mod : tvm.IRModule
The relay module for compilation. The relay module for compilation.
params : dict of str to tvm.nd.NDArray params : dict of str to tvm.nd.NDArray
......
...@@ -19,9 +19,10 @@ ...@@ -19,9 +19,10 @@
import sys import sys
import numpy as np import numpy as np
import tvm import tvm
from tvm.ir import IRModule
from .. import analysis from .. import analysis
from .. import expr as _expr from .. import expr as _expr
from .. import module as _module
from .. import op as _op from .. import op as _op
from ... import nd as _nd from ... import nd as _nd
from .common import ExprTable, new_var from .common import ExprTable, new_var
...@@ -752,7 +753,7 @@ def from_keras(model, shape=None): ...@@ -752,7 +753,7 @@ def from_keras(model, shape=None):
Returns Returns
------- -------
mod : tvm.relay.Module mod : tvm.IRModule
The relay module for compilation. The relay module for compilation.
params : dict of str to tvm.nd.NDArray params : dict of str to tvm.nd.NDArray
...@@ -837,4 +838,4 @@ def from_keras(model, shape=None): ...@@ -837,4 +838,4 @@ def from_keras(model, shape=None):
outexpr = outexpr[0] if len(outexpr) == 1 else _expr.Tuple(outexpr) outexpr = outexpr[0] if len(outexpr) == 1 else _expr.Tuple(outexpr)
func = _expr.Function(analysis.free_vars(outexpr), outexpr) func = _expr.Function(analysis.free_vars(outexpr), outexpr)
params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()} params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()}
return _module.Module.from_expr(func), params return IRModule.from_expr(func), params
...@@ -21,12 +21,13 @@ from __future__ import absolute_import as _abs ...@@ -21,12 +21,13 @@ from __future__ import absolute_import as _abs
import json import json
import numpy as np import numpy as np
import tvm import tvm
from tvm.ir import IRModule
from tvm import relay from tvm import relay
from topi.util import get_const_tuple from topi.util import get_const_tuple
from .. import analysis from .. import analysis
from .. import expr as _expr from .. import expr as _expr
from .. import op as _op from .. import op as _op
from .. import module as _module
from .. import scope_builder as _scope_builder from .. import scope_builder as _scope_builder
from ... import nd as _nd from ... import nd as _nd
...@@ -1902,7 +1903,7 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info, params=None, mod=None): ...@@ -1902,7 +1903,7 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info, params=None, mod=None):
dtype_info : dict or str. dtype_info : dict or str.
Known parameter dtypes Known parameter dtypes
mod : tvm.relay.Module mod : tvm.IRModule
The module that contains global information. It will be used for The module that contains global information. It will be used for
converting ops that need global information, e.g. control-flow ops. converting ops that need global information, e.g. control-flow ops.
...@@ -2009,7 +2010,7 @@ def from_mxnet(symbol, ...@@ -2009,7 +2010,7 @@ def from_mxnet(symbol,
Returns Returns
------- -------
mod : tvm.relay.Module mod : tvm.IRModule
The relay module for compilation The relay module for compilation
params : dict of str to tvm.nd.NDArray params : dict of str to tvm.nd.NDArray
...@@ -2020,7 +2021,7 @@ def from_mxnet(symbol, ...@@ -2020,7 +2021,7 @@ def from_mxnet(symbol,
except ImportError as e: except ImportError as e:
raise ImportError("{}. MXNet is required to parse symbols.".format(e)) raise ImportError("{}. MXNet is required to parse symbols.".format(e))
mod = _module.Module() mod = IRModule()
if isinstance(symbol, mx.sym.Symbol): if isinstance(symbol, mx.sym.Symbol):
params = {} params = {}
arg_params = arg_params if arg_params else {} arg_params = arg_params if arg_params else {}
......
...@@ -17,14 +17,13 @@ ...@@ -17,14 +17,13 @@
# pylint: disable=invalid-name, import-self, len-as-condition, unused-argument, too-many-lines # pylint: disable=invalid-name, import-self, len-as-condition, unused-argument, too-many-lines
# pylint: disable=import-outside-toplevel # pylint: disable=import-outside-toplevel
"""ONNX: Open Neural Network Exchange frontend for Relay.""" """ONNX: Open Neural Network Exchange frontend for Relay."""
from __future__ import absolute_import as _abs
import numpy as np import numpy as np
import tvm import tvm
from tvm.ir import IRModule
from ... import nd as _nd from ... import nd as _nd
from .. import analysis from .. import analysis
from .. import expr as _expr from .. import expr as _expr
from .. import module as _module
from .. import op as _op from .. import op as _op
from .common import AttrCvt, Renamer from .common import AttrCvt, Renamer
from .common import get_relay_op, new_var, infer_shape, infer_channels from .common import get_relay_op, new_var, infer_shape, infer_channels
...@@ -1615,7 +1614,7 @@ class GraphProto(object): ...@@ -1615,7 +1614,7 @@ class GraphProto(object):
Returns Returns
------- -------
mod : tvm.relay.Module mod : tvm.IRModule
The returned relay module The returned relay module
params : dict params : dict
...@@ -1708,7 +1707,7 @@ class GraphProto(object): ...@@ -1708,7 +1707,7 @@ class GraphProto(object):
outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output] outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output]
outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
func = _expr.Function(analysis.free_vars(outputs), outputs) func = _expr.Function(analysis.free_vars(outputs), outputs)
return _module.Module.from_expr(func), self._params return IRModule.from_expr(func), self._params
def _parse_value_proto(self, value_proto): def _parse_value_proto(self, value_proto):
"""Parse ValueProto or raw str.""" """Parse ValueProto or raw str."""
...@@ -1836,7 +1835,7 @@ def from_onnx(model, ...@@ -1836,7 +1835,7 @@ def from_onnx(model,
Returns Returns
------- -------
mod : tvm.relay.Module mod : tvm.IRModule
The relay module for compilation The relay module for compilation
params : dict of str to tvm.nd.NDArray params : dict of str to tvm.nd.NDArray
......
...@@ -29,13 +29,13 @@ import numpy as np ...@@ -29,13 +29,13 @@ import numpy as np
import tvm import tvm
from tvm.ir import IRModule
from tvm.relay.prelude import Prelude from tvm.relay.prelude import Prelude
from .. import analysis from .. import analysis
from .. import expr as _expr from .. import expr as _expr
from .. import op as _op from .. import op as _op
from ..expr_functor import ExprMutator from ..expr_functor import ExprMutator
from .. import module as _module
from .common import AttrCvt, get_relay_op from .common import AttrCvt, get_relay_op
from .common import infer_type as _infer_type from .common import infer_type as _infer_type
from .common import infer_shape as _infer_shape from .common import infer_shape as _infer_shape
...@@ -2136,7 +2136,7 @@ class GraphProto(object): ...@@ -2136,7 +2136,7 @@ class GraphProto(object):
self._input_shapes = {} self._input_shapes = {}
self._loops = {} self._loops = {}
self._branches = {} self._branches = {}
self._mod = _module.Module({}) self._mod = IRModule({})
self._prelude = Prelude(self._mod) self._prelude = Prelude(self._mod)
def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
...@@ -2171,7 +2171,7 @@ class GraphProto(object): ...@@ -2171,7 +2171,7 @@ class GraphProto(object):
Returns Returns
------- -------
mod : tvm.relay.Module mod : tvm.IRModule
The module that optimizations will be performed on. The module that optimizations will be performed on.
params : dict params : dict
...@@ -2653,7 +2653,7 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None): ...@@ -2653,7 +2653,7 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
Returns Returns
------- -------
mod : tvm.relay.Module mod : tvm.IRModule
The module that optimizations will be performed on. The module that optimizations will be performed on.
params : dict of str to tvm.nd.NDArray params : dict of str to tvm.nd.NDArray
......
...@@ -17,14 +17,14 @@ ...@@ -17,14 +17,14 @@
# pylint: disable=invalid-name, unused-argument, too-many-lines, import-outside-toplevel # pylint: disable=invalid-name, unused-argument, too-many-lines, import-outside-toplevel
"""Tensorflow lite frontend.""" """Tensorflow lite frontend."""
from __future__ import absolute_import as _abs
import math import math
import numpy as np import numpy as np
import tvm import tvm
from tvm.ir import IRModule
from tvm import relay from tvm import relay
from .. import analysis from .. import analysis
from .. import expr as _expr from .. import expr as _expr
from .. import module as _module
from .. import op as _op from .. import op as _op
from .. import qnn as _qnn from .. import qnn as _qnn
from ..util import get_scalar_from_constant from ..util import get_scalar_from_constant
...@@ -1901,7 +1901,7 @@ def from_tflite(model, shape_dict, dtype_dict): ...@@ -1901,7 +1901,7 @@ def from_tflite(model, shape_dict, dtype_dict):
Returns Returns
------- -------
mod : tvm.relay.Module mod : tvm.IRModule
The relay module for compilation. The relay module for compilation.
params : dict of str to tvm.nd.NDArray params : dict of str to tvm.nd.NDArray
...@@ -1940,5 +1940,5 @@ def from_tflite(model, shape_dict, dtype_dict): ...@@ -1940,5 +1940,5 @@ def from_tflite(model, shape_dict, dtype_dict):
outputs = [exp_tab.get_expr(get_tensor_name(subgraph, i)) for i in model_outputs] outputs = [exp_tab.get_expr(get_tensor_name(subgraph, i)) for i in model_outputs]
outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
func = _expr.Function(analysis.free_vars(outputs), outputs) func = _expr.Function(analysis.free_vars(outputs), outputs)
mod = _module.Module.from_expr(func) mod = IRModule.from_expr(func)
return mod, params return mod, params
...@@ -176,7 +176,7 @@ class ManifestAllocPass(ExprMutator): ...@@ -176,7 +176,7 @@ class ManifestAllocPass(ExprMutator):
view = LinearizeRetType(ret_type) view = LinearizeRetType(ret_type)
out_types = view.unpack() out_types = view.unpack()
is_dynamic = ret_type.is_dynamic() is_dynamic = ty.type_has_any(ret_type)
# TODO(@jroesch): restore this code, more complex then it seems # TODO(@jroesch): restore this code, more complex then it seems
# for arg in call.args: # for arg in call.args:
# is_dynamic = is_dynamic or arg.checked_type.is_dynamic() # is_dynamic = is_dynamic or arg.checked_type.is_dynamic()
......
...@@ -41,7 +41,6 @@ from . import _tensor_grad ...@@ -41,7 +41,6 @@ from . import _tensor_grad
from . import _transform from . import _transform
from . import _reduce from . import _reduce
from . import _algorithm from . import _algorithm
from ..expr import Expr
from ..base import register_relay_node from ..base import register_relay_node
......
...@@ -275,7 +275,7 @@ def legalize_conv2d(attrs, inputs, types): ...@@ -275,7 +275,7 @@ def legalize_conv2d(attrs, inputs, types):
Parameters Parameters
---------- ----------
attrs : tvm.attrs.Attrs attrs : tvm.ir.Attrs
Attributes of current convolution Attributes of current convolution
inputs : list of tvm.relay.Expr inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized The args of the Relay expr to be legalized
...@@ -296,7 +296,7 @@ def convert_conv2d(attrs, inputs, tinfos, desired_layout): ...@@ -296,7 +296,7 @@ def convert_conv2d(attrs, inputs, tinfos, desired_layout):
Parameters Parameters
---------- ----------
attrs : tvm.attrs.Attrs attrs : tvm.ir.Attrs
Attributes of current convolution Attributes of current convolution
inputs : list of tvm.relay.Expr inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized The args of the Relay expr to be legalized
...@@ -413,7 +413,7 @@ def legalize_conv2d_transpose(attrs, inputs, types): ...@@ -413,7 +413,7 @@ def legalize_conv2d_transpose(attrs, inputs, types):
Parameters Parameters
---------- ----------
attrs : tvm.attrs.Attrs attrs : tvm.ir.Attrs
Attributes of current Transposed convolution Attributes of current Transposed convolution
inputs : list of tvm.relay.Expr inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized The args of the Relay expr to be legalized
...@@ -947,7 +947,7 @@ def legalize_bitserial_conv2d(attrs, inputs, types): ...@@ -947,7 +947,7 @@ def legalize_bitserial_conv2d(attrs, inputs, types):
Parameters Parameters
---------- ----------
attrs : tvm.attrs.Attrs attrs : tvm.ir.Attrs
Attributes of current convolution Attributes of current convolution
inputs : list of tvm.relay.Expr inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized The args of the Relay expr to be legalized
......
...@@ -16,8 +16,8 @@ ...@@ -16,8 +16,8 @@
# under the License. # under the License.
# pylint: disable=invalid-name, unused-variable # pylint: disable=invalid-name, unused-variable
"""NN operator common utilities""" """NN operator common utilities"""
from __future__ import absolute_import from tvm.ir import container
from .... import container
def get_pad_tuple2d(padding): def get_pad_tuple2d(padding):
"""Common code to get the pad option """Common code to get the pad option
......
...@@ -20,13 +20,13 @@ import topi ...@@ -20,13 +20,13 @@ import topi
import tvm._ffi import tvm._ffi
from ..base import register_relay_node from ..base import register_relay_node
from ..expr import Expr from ..expr import RelayExpr
from ...api import register_func from ...api import register_func
from ...build_module import lower, build from ...build_module import lower, build
from . import _make from . import _make
@register_relay_node @register_relay_node
class Op(Expr): class Op(RelayExpr):
"""A Relay operator definition.""" """A Relay operator definition."""
def __init__(self): def __init__(self):
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
# under the License. # under the License.
"""The attributes node used for Relay operators""" """The attributes node used for Relay operators"""
from ...attrs import Attrs from tvm.ir import Attrs
from ..base import register_relay_attr_node from ..base import register_relay_attr_node
......
...@@ -16,13 +16,15 @@ ...@@ -16,13 +16,15 @@
# under the License. # under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""A prelude containing useful global functions and ADT definitions.""" """A prelude containing useful global functions and ADT definitions."""
from tvm.ir import IRModule
from .ty import GlobalTypeVar, TensorType, Any, scalar_type from .ty import GlobalTypeVar, TensorType, Any, scalar_type
from .expr import Var, Function, GlobalVar, If, const from .expr import Var, Function, GlobalVar, If, const
from .op.tensor import add, subtract, equal from .op.tensor import add, subtract, equal
from .adt import Constructor, TypeData, Clause, Match from .adt import Constructor, TypeData, Clause, Match
from .adt import PatternConstructor, PatternVar, PatternWildcard from .adt import PatternConstructor, PatternVar, PatternWildcard
from . import op from . import op
from .module import Module
class TensorArrayOps(object): class TensorArrayOps(object):
"""Contains tensor array related ops""" """Contains tensor array related ops"""
...@@ -648,7 +650,7 @@ class Prelude: ...@@ -648,7 +650,7 @@ class Prelude:
def __init__(self, mod=None): def __init__(self, mod=None):
if mod is None: if mod is None:
mod = Module() mod = IRModule()
self.mod = mod self.mod = mod
self.load_prelude() self.load_prelude()
......
...@@ -63,7 +63,7 @@ def helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay_op): ...@@ -63,7 +63,7 @@ def helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay_op):
Parameters Parameters
---------- ----------
attrs : tvm.attrs.Attrs attrs : tvm.ir.Attrs
Attributes of current convolution Attributes of current convolution
inputs : list of tvm.relay.Expr inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized The args of the Relay expr to be legalized
...@@ -106,7 +106,7 @@ def helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay_op): ...@@ -106,7 +106,7 @@ def helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay_op):
Parameters Parameters
---------- ----------
attrs : tvm.attrs.Attrs attrs : tvm.ir.Attrs
Attributes of current convolution Attributes of current convolution
inputs : list of tvm.relay.Expr inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized The args of the Relay expr to be legalized
...@@ -169,7 +169,7 @@ def helper_change_dtypes_to_be_same(attrs, inputs, types, relay_op): ...@@ -169,7 +169,7 @@ def helper_change_dtypes_to_be_same(attrs, inputs, types, relay_op):
Parameters Parameters
---------- ----------
attrs : tvm.attrs.Attrs attrs : tvm.ir.Attrs
Attributes of current convolution Attributes of current convolution
inputs : list of tvm.relay.Expr inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized The args of the Relay expr to be legalized
......
...@@ -42,7 +42,7 @@ def CanonicalizeOps(): ...@@ -42,7 +42,7 @@ def CanonicalizeOps():
# We want to utilize all the existing Relay infrastructure. So, instead of supporting this # We want to utilize all the existing Relay infrastructure. So, instead of supporting this
# QNN requantize op, we convert it into a sequence of existing Relay operators. # QNN requantize op, we convert it into a sequence of existing Relay operators.
mod = relay.Module.from_expr(qnn_expr) mod = tvm.IRModule.from_expr(qnn_expr)
mod = relay.qnn.transform.CanonicalizeOps()(mod) mod = relay.qnn.transform.CanonicalizeOps()(mod)
relay_expr = mod['main'] relay_expr = mod['main']
print(relay_expr) print(relay_expr)
......
...@@ -20,12 +20,12 @@ import logging ...@@ -20,12 +20,12 @@ import logging
import multiprocessing as mp import multiprocessing as mp
import numpy as np import numpy as np
import tvm import tvm
from tvm.ir import IRModule
from . import _quantize from . import _quantize
from . import quantize from . import quantize
from .. import op as _op from .. import op as _op
from .. import expr as _expr from .. import expr as _expr
from .. import module as _module
from .. import analysis as _analysis from .. import analysis as _analysis
from .. import transform as _transform from .. import transform as _transform
from .. import build_module as _build_module from .. import build_module as _build_module
...@@ -141,7 +141,7 @@ def _set_params(mod, input_scale_func, weight_scale_func): ...@@ -141,7 +141,7 @@ def _set_params(mod, input_scale_func, weight_scale_func):
func = mod['main'] func = mod['main']
_analysis.post_order_visit(func, visit_func) _analysis.post_order_visit(func, visit_func)
func = _expr.bind(func, const_params) func = _expr.bind(func, const_params)
return _module.Module.from_expr(func) return IRModule.from_expr(func)
# weight scale functions # weight scale functions
......
...@@ -47,7 +47,7 @@ from ..transform import gradient ...@@ -47,7 +47,7 @@ from ..transform import gradient
def run_opt_pass(expr, opt_pass): def run_opt_pass(expr, opt_pass):
assert isinstance(opt_pass, transform.Pass) assert isinstance(opt_pass, transform.Pass)
mod = relay.Module.from_expr(expr) mod = tvm.IRModule.from_expr(expr)
mod = opt_pass(mod) mod = opt_pass(mod)
entry = mod["main"] entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body return entry if isinstance(expr, relay.Function) else entry.body
......
...@@ -103,7 +103,7 @@ def get_workload(batch_size, oshape=(3, 64, 64), ngf=128, random_len=100, dtype= ...@@ -103,7 +103,7 @@ def get_workload(batch_size, oshape=(3, 64, 64), ngf=128, random_len=100, dtype=
Returns Returns
------- -------
mod : tvm.relay.Module mod : tvm.IRModule
The relay module that contains a DCGAN network. The relay module that contains a DCGAN network.
params : dict of str to NDArray params : dict of str to NDArray
The parameters. The parameters.
......
...@@ -105,7 +105,7 @@ def get_workload(densenet_size=121, classes=1000, batch_size=4, ...@@ -105,7 +105,7 @@ def get_workload(densenet_size=121, classes=1000, batch_size=4,
Returns Returns
------- -------
mod: tvm.relay.Module mod: tvm.IRModule
The relay module that contains a DenseNet network. The relay module that contains a DenseNet network.
params : dict of str to NDArray params : dict of str to NDArray
......
...@@ -72,7 +72,7 @@ def get_workload(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="flo ...@@ -72,7 +72,7 @@ def get_workload(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="flo
The data type The data type
Returns Returns
------- -------
mod : tvm.relay.Module mod : tvm.IRModule
The relay module that contains a DQN network. The relay module that contains a DQN network.
params : dict of str to NDArray params : dict of str to NDArray
The parameters. The parameters.
......
...@@ -290,7 +290,7 @@ def get_workload(batch_size=1, num_classes=1000, ...@@ -290,7 +290,7 @@ def get_workload(batch_size=1, num_classes=1000,
Returns Returns
------- -------
mod : tvm.relay.Module mod : tvm.IRModule
The relay module that contains an Inception V3 network. The relay module that contains an Inception V3 network.
params : dict of str to NDArray params : dict of str to NDArray
......
...@@ -144,13 +144,13 @@ def create_workload(net, initializer=None, seed=0): ...@@ -144,13 +144,13 @@ def create_workload(net, initializer=None, seed=0):
Returns Returns
------- -------
mod : tvm.relay.Module mod : tvm.IRModule
The created relay module. The created relay module.
params : dict of str to NDArray params : dict of str to NDArray
The parameters. The parameters.
""" """
mod = relay.Module.from_expr(net) mod = tvm.IRModule.from_expr(net)
mod = relay.transform.InferType()(mod) mod = relay.transform.InferType()(mod)
shape_dict = { shape_dict = {
v.name_hint : v.checked_type for v in mod["main"].params} v.name_hint : v.checked_type for v in mod["main"].params}
......
...@@ -173,7 +173,7 @@ def get_workload(iterations, num_hidden, batch_size=1, dtype="float32"): ...@@ -173,7 +173,7 @@ def get_workload(iterations, num_hidden, batch_size=1, dtype="float32"):
The data type The data type
Returns Returns
------- -------
mod : tvm.relay.Module mod : tvm.IRModule
The relay module that contains a LSTM network. The relay module that contains a LSTM network.
params : dict of str to NDArray params : dict of str to NDArray
The parameters. The parameters.
......
...@@ -84,7 +84,7 @@ def get_workload(batch_size, ...@@ -84,7 +84,7 @@ def get_workload(batch_size,
Returns Returns
------- -------
mod : tvm.relay.Module mod : tvm.IRModule
The relay module that contains a mlp network. The relay module that contains a mlp network.
params : dict of str to NDArray params : dict of str to NDArray
......
...@@ -151,7 +151,7 @@ def get_workload(batch_size=1, num_classes=1000, image_shape=(3, 224, 224), ...@@ -151,7 +151,7 @@ def get_workload(batch_size=1, num_classes=1000, image_shape=(3, 224, 224),
Returns Returns
------- -------
mod : tvm.relay.Module mod : tvm.IRModule
The relay module that contains a MobileNet network. The relay module that contains a MobileNet network.
params : dict of str to NDArray params : dict of str to NDArray
......
...@@ -584,7 +584,7 @@ class PythonConverter(ExprFunctor): ...@@ -584,7 +584,7 @@ class PythonConverter(ExprFunctor):
def to_python(expr: Expr, mod=None, target=tvm.target.create('llvm')): def to_python(expr: Expr, mod=None, target=tvm.target.create('llvm')):
"""Converts the given Relay expression into a Python script (as a Python AST object). """Converts the given Relay expression into a Python script (as a Python AST object).
For easiest debugging, import the astor package and use to_source().""" For easiest debugging, import the astor package and use to_source()."""
mod = mod if mod is not None else relay.Module() mod = mod if mod is not None else tvm.IRModule()
converter = PythonConverter(mod, target) converter = PythonConverter(mod, target)
return converter.convert(expr) return converter.convert(expr)
...@@ -592,7 +592,7 @@ def to_python(expr: Expr, mod=None, target=tvm.target.create('llvm')): ...@@ -592,7 +592,7 @@ def to_python(expr: Expr, mod=None, target=tvm.target.create('llvm')):
def run_as_python(expr: Expr, mod=None, target=tvm.target.create('llvm')): def run_as_python(expr: Expr, mod=None, target=tvm.target.create('llvm')):
"""Converts the given Relay expression into a Python script and """Converts the given Relay expression into a Python script and
executes it.""" executes it."""
mod = mod if mod is not None else relay.Module() mod = mod if mod is not None else tvm.IRModule()
py_ast = to_python(expr, mod, target) py_ast = to_python(expr, mod, target)
code = compile(py_ast, '<string>', 'exec') code = compile(py_ast, '<string>', 'exec')
var_map = { var_map = {
......
...@@ -262,7 +262,7 @@ def get_workload(batch_size=1, ...@@ -262,7 +262,7 @@ def get_workload(batch_size=1,
Returns Returns
------- -------
mod : tvm.relay.Module mod : tvm.IRModule
The relay module that contains a ResNet network. The relay module that contains a ResNet network.
params : dict of str to NDArray params : dict of str to NDArray
......
...@@ -149,7 +149,7 @@ def get_workload(batch_size=1, ...@@ -149,7 +149,7 @@ def get_workload(batch_size=1,
Returns Returns
------- -------
mod : tvm.relay.Module mod : tvm.IRModule
The relay module that contains a SqueezeNet network. The relay module that contains a SqueezeNet network.
params : dict of str to NDArray params : dict of str to NDArray
......
...@@ -124,7 +124,7 @@ def get_workload(batch_size, ...@@ -124,7 +124,7 @@ def get_workload(batch_size,
Returns Returns
------- -------
mod : tvm.relay.Module mod : tvm.IRModule
The relay module that contains a VGG network. The relay module that contains a VGG network.
params : dict of str to NDArray params : dict of str to NDArray
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from .base import Object
class PassContext(Object):
def __init__(self):
...
class PassInfo(Object):
name = ... # type: str
opt_level = ... # type: int
required = ... # type: list
def __init__(self, name, opt_level, required)
# type: (str, int, list) -> None
class Pass(Object):
def __init__(self):
...
class ModulePass(Pass):
name = ... # type: str
opt_level = ... # type: int
pass_func = ... # type: Callable
required = ... # type: list
def __init__(self, name, opt_level, pass_func, required):
# type: (str, int, Callable, list) -> None
...
class FunctionPass(Pass):
name = ... # type: str
opt_level = ... # type: int
pass_func = ... # type: Callable
required = ... # type: list
def __init__(self, name, opt_level, pass_func, required):
# type: (str, int, Callable, list) -> None
...
class Sequential(Pass):
name = ... # type: str
opt_level = ... # type: int
passes = ... # type: list
required = ... # type: list
disabled = ... # type: list
def __init__(self, name, opt_level, passes, required, disabled):
# type: (str, int, list, list, list) -> None
...
...@@ -14,133 +14,30 @@ ...@@ -14,133 +14,30 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name # pylint: disable=invalid-name, unused-import
"""The type nodes of the Relay language.""" """The type nodes of the Relay language."""
from enum import IntEnum from tvm.ir import Type, TypeKind, TypeVar, GlobalTypeVar
from tvm.ir import TypeConstraint, FuncType, TupleType, IncompleteType
from tvm.ir import TypeCall, TypeRelation, TensorType, RelayRefType as RefType
from .base import RelayNode, register_relay_node from .base import RelayNode, register_relay_node
from . import _make from . import _make
Any = _make.Any Any = _make.Any
class Type(RelayNode): def type_has_any(tensor_type):
"""The base type for all Relay types.""" """Check whether type has any as a shape.
def __eq__(self, other):
"""Compare two Relay types for structural equivalence using
alpha equivalence.
"""
return bool(_make._alpha_equal(self, other))
def __ne__(self, other):
return not self.__eq__(other)
def same_as(self, other):
"""Compares two Relay types by referential equality."""
return super().__eq__(other)
def __call__(self, *args):
"""Create a type call from this type.
Parameters tensor_type : Type
---------- The type to be inspected
args: List[relay.Type]
The arguments to the type call.
Returns
-------
call: relay.TypeCall
"""
return TypeCall(self, args)
def is_dynamic(self):
return _make.IsDynamic(self)
@register_relay_node
class TensorType(Type):
"""A concrete TensorType in Relay.
This is the type assigned to tensors with a known dtype and shape. For
example, a tensor of `float32` and `(5, 5)`.
Parameters
----------
shape : List[tvm.Expr]
The shape of the Tensor
dtype : Optional[str]
The content data type.
Default to "float32".
Returns Returns
------- -------
tensor_type : tvm.relay.TensorType has_any : bool
The tensor type. The check result.
"""
def __init__(self, shape, dtype="float32"):
self.__init_handle_by_constructor__(
_make.TensorType, shape, dtype)
@property
def concrete_shape(self):
"""Get shape of the type as concrete tuple of int.
Returns
-------
shape : List[int]
The concrete shape of the Type.
Raises
------
TypeError : If the shape is symbolic
"""
return tuple(int(x) for x in self.shape)
class Kind(IntEnum):
"""The kind of a type parameter, represents a variable shape,
base type, type, or dimension.
This controls what a type parameter is allowed to be instantiated
with. For example one's of kind BaseType can only be `float32`, `int32`,
and so on.
"""
Type = 0
ShapeVar = 1
BaseType = 2
Shape = 3
Constraint = 4
AdtHandle = 5
TypeData = 6
@register_relay_node
class TypeVar(Type):
"""A type variable used for generic types in Relay,
see tvm/relay/type.h for more details.
A type variable represents a type placeholder which will
be filled in later on. This allows the user to write
functions which are generic over types.
""" """
return _make.IsDynamic(tensor_type)
def __init__(self, name_hint, kind=Kind.Type):
"""Construct a TypeVar.
Parameters
----------
name_hint: str
The name of the type variable. This name only acts as a hint, and
is not used for equality.
kind : Optional[Kind]
The kind of the type parameter.
Default to Kind.Type.
Returns
-------
type_var : tvm.relay.TypeVar
The type variable.
"""
self.__init_handle_by_constructor__(_make.TypeVar, name_hint, kind)
def ShapeVar(name): def ShapeVar(name):
"""A helper which constructs a type var of which the shape kind. """A helper which constructs a type var of which the shape kind.
...@@ -154,172 +51,9 @@ def ShapeVar(name): ...@@ -154,172 +51,9 @@ def ShapeVar(name):
type_var : tvm.relay.TypeVar type_var : tvm.relay.TypeVar
The shape variable. The shape variable.
""" """
return TypeVar(name, kind=Kind.ShapeVar) return TypeVar(name, kind=TypeKind.ShapeVar)
@register_relay_node
class GlobalTypeVar(Type):
"""A global type variable in Relay.
GlobalTypeVar is used to refer to the global type-level definitions
stored in the environment.
"""
def __init__(self, name_hint, kind=Kind.AdtHandle):
"""Construct a GlobalTypeVar.
Parameters
----------
name_hint: str
The name of the global type variable. This name only acts as a
hint, and is not used for equality.
kind: Kind, optional
The kind of the type parameter, Kind.AdtHandle by default.
Returns
-------
type_var: GlobalTypeVar
The global type variable.
"""
self.__init_handle_by_constructor__(_make.GlobalTypeVar, name_hint, kind)
@register_relay_node
class TypeCall(Type):
"""Type-level function application in Relay.
A type call applies argument types to a constructor (type-level function).
"""
def __init__(self, func, args):
"""Construct a TypeCall.
Parameters
----------
func: tvm.relay.Type
The function.
args: List[tvm.expr.Type]
The arguments.
Returns
-------
type_call: TypeCall
The type function application.
"""
self.__init_handle_by_constructor__(_make.TypeCall, func, args)
@register_relay_node
class TypeConstraint(Type):
"""Abstract class representing a type constraint."""
@register_relay_node
class TupleType(Type):
"""A tuple type in Relay, see tvm/relay/type.h for more details.
Lists the type of each field in the tuple.
"""
def __init__(self, fields):
"""Constructs a tuple type
Parameters
----------
fields : List[tvm.relay.Type]
The fields in the tuple
Returns
-------
tuple_type : tvm.relay.TupleType
the tuple type
"""
self.__init_handle_by_constructor__(_make.TupleType, fields)
@register_relay_node
class FuncType(Type):
"""A function type in Relay, see tvm/relay/type.h for more details.
This is the type assigned to functions in Relay. They consist of
a list of type parameters which enable the definition of generic
functions, a set of type constraints which we omit for the time
being, a sequence of argument types, and a return type.
We informally write them as:
`forall (type_params), (arg_types) -> ret_type where type_constraints`
Parameters
----------
arg_types : List[tvm.relay.Type]
The argument types
ret_type : tvm.relay.Type
The return type.
type_params : Optional[List[tvm.relay.TypeVar]]
The type parameters
type_constraints : Optional[List[tvm.relay.TypeConstraint]]
The type constraints.
"""
def __init__(self,
arg_types,
ret_type,
type_params=None,
type_constraints=None):
if type_params is None:
type_params = []
if type_constraints is None:
type_constraints = []
self.__init_handle_by_constructor__(
_make.FuncType, arg_types, ret_type, type_params, type_constraints)
@register_relay_node
class IncompleteType(Type):
"""An incomplete type."""
def __init__(self, kind=Kind.Type):
self.__init_handle_by_constructor__(_make.IncompleteType, kind)
@register_relay_node
class TypeRelation(TypeConstraint):
"""Type relation in relay.
Parameters
----------
func : EnvFunc
User defined relation function.
args : [tvm.relay.Type]
List of types to the func.
num_inputs : int
Number of input arguments in args,
this act as a hint for type inference.
attrs : Attrs
The attribute attached to the relation information
Returns
-------
type_relation : tvm.relay.TypeRelation
The type relation.
"""
def __init__(self, func, args, num_inputs, attrs):
self.__init_handle_by_constructor__(_make.TypeRelation,
func, args, num_inputs, attrs)
@register_relay_node
class RefType(Type):
"""Reference Type in relay.
Parameters
----------
value: Type
The value type.
"""
def __init__(self, value):
self.__init_handle_by_constructor__(_make.RefType, value)
def scalar_type(dtype): def scalar_type(dtype):
"""Creates a scalar type. """Creates a scalar type.
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""The type nodes of the Relay language."""
from enum import IntEnum
from .base import Object, register_relay_node
from . import _make
class Type(Object):
"""The base type for all Relay types."""
def __eq__(self, other):
"""Compare two Relay types for structural equivalence using
alpha equivalence.
"""
return bool(_make._type_alpha_eq(self, other))
def __ne__(self, other):
return not self.__eq__(other)
def same_as(self, other):
"""Compares two Relay types by referential equality."""
return super().__eq__(other)
@register_relay_node
class TensorType(Type):
"""A concrete TensorType in Relay, see tvm/relay/type.h for more details.
This is the type assigned to tensor's with a known dype and shape. For
example a tensor of `float32` and `(5, 5)`.
"""
def __init__(self, shape, dtype):
"""Construct a tensor type.
Parameters
----------
shape: list of tvm.Expr
dtype: str
Returns
-------
tensor_type: The TensorType
"""
self.__init_handle_by_constructor__(_make.TensorType, shape, dtype)
class Kind(IntEnum):
"""The kind of a type parameter, represents a variable shape,
base type, type, or dimension.
This controls what a type parameter is allowed to be instantiated
with. For example one's of kind BaseType can only be `float32`, `int32`,
and so on.
"""
ShapeVar = 0
Shape = 1
BaseType = 2
Type = 3
@register_relay_node
class TypeParam(Type):
"""A type parameter used for generic types in Relay,
see tvm/relay/type.h for more details.
A type parameter represents a type placeholder which will
be filled in later on. This allows the user to write
functions which are generic over types.
"""
def __init__(self, var, kind):
"""Construct a TypeParam.
Parameters
----------
var: tvm.expr.Var
The tvm.Var which backs the type parameter.
kind: Kind
The kind of the type parameter.
Returns
-------
type_param: TypeParam
The type parameter.
"""
self.__init_handle_by_constructor__(_make.TypeParam, var, kind)
@register_relay_node
class TypeConstraint(Type):
"""Abstract class representing a type constraint."""
pass
@register_relay_node
class TupleType(Type):
"""A tuple type in Relay, see tvm/relay/type.h for more details.
Lists the type of each field in the tuple.
"""
def __init__(self, fields):
"""Constructs a tuple type
Parameters
----------
fields: list of tvm.Type
Returns
-------
tuple_type: the tuple type
"""
self.__init_handle_by_constructor__(_make.TupleType, fields)
@register_relay_node
class FuncType(Type):
"""A function type in Relay, see tvm/relay/type.h for more details.
This is the type assigned to functions in Relay. They consist of
a list of type parameters which enable the definition of generic
functions, a set of type constraints which we omit for the time
being, a sequence of argument types, and a return type.
We informally write them as:
`forall (type_params), (arg_types) -> ret_type where type_constraints`
"""
def __init__(self,
arg_types,
ret_type,
type_params,
type_constraints,
):
"""Construct a function type.
Parameters
----------
arg_types: list of Type
ret_type: Type
type_params: list of TypeParam
type_constraints: list of TypeConstraint
Returns
-------
func_type: FuncType
The function type.
"""
self.__init_handle_by_constructor__(
_make.FuncType, arg_types, ret_type, type_params, type_constraints)
@register_relay_node
class IncompleteType(Type):
"""An incomplete type."""
def __init__(self, kind=Kind.Type):
self.__init_handle_by_constructor__(_make.IncompleteType, kind)
@register_relay_node
class TypeRelation(TypeConstraint):
"""Type relation in relay.
Parameters
----------
func : EnvFunc
User defined relation function.
args : list of types
List of types to the func.
num_inputs: int
Number of input arguments in args,
this act as a hint for type inference.
attrs : Attrs
The attribute attached to the relation information
"""
def __init__(self, func, args, num_inputs, attrs):
self.__init_handle_by_constructor__(_make.TypeRelation,
func, args, num_inputs, attrs)
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
# under the License. # under the License.
# pylint: disable=invalid-name, unused-argument # pylint: disable=invalid-name, unused-argument
"""FFI for tvm.runtime.extra""" """FFI for tvm.node"""
import tvm._ffi import tvm._ffi
# The implementations below are default ones when the corresponding # The implementations below are default ones when the corresponding
......
...@@ -19,11 +19,11 @@ import tvm._ffi ...@@ -19,11 +19,11 @@ import tvm._ffi
from tvm._ffi.base import string_types from tvm._ffi.base import string_types
from tvm.runtime import Object, convert from tvm.runtime import Object, convert
from tvm.ir import container as _container
from . import _api_internal from . import _api_internal
from . import tensor as _tensor from . import tensor as _tensor
from . import expr as _expr from . import expr as _expr
from . import container as _container
@tvm._ffi.register_object @tvm._ffi.register_object
......
...@@ -366,6 +366,14 @@ class Prefetch(Stmt): ...@@ -366,6 +366,14 @@ class Prefetch(Stmt):
_make.Prefetch, func, value_index, dtype, bounds) _make.Prefetch, func, value_index, dtype, bounds)
@tvm._ffi.register_object
class LoweredFunc(Object):
"""Represent a LoweredFunc in TVM."""
MixedFunc = 0
HostFunc = 1
DeviceFunc = 2
def stmt_seq(*args): def stmt_seq(*args):
"""Make sequence of statements """Make sequence of statements
......
...@@ -38,7 +38,7 @@ Constructor::Constructor(std::string name_hint, ...@@ -38,7 +38,7 @@ Constructor::Constructor(std::string name_hint,
TVM_REGISTER_NODE_TYPE(ConstructorNode); TVM_REGISTER_NODE_TYPE(ConstructorNode);
TVM_REGISTER_GLOBAL("relay._make.Constructor") TVM_REGISTER_GLOBAL("ir.Constructor")
.set_body_typed([](std::string name_hint, .set_body_typed([](std::string name_hint,
tvm::Array<Type> inputs, tvm::Array<Type> inputs,
GlobalTypeVar belong_to) { GlobalTypeVar belong_to) {
...@@ -64,7 +64,7 @@ TypeData::TypeData(GlobalTypeVar header, ...@@ -64,7 +64,7 @@ TypeData::TypeData(GlobalTypeVar header,
TVM_REGISTER_NODE_TYPE(TypeDataNode); TVM_REGISTER_NODE_TYPE(TypeDataNode);
TVM_REGISTER_GLOBAL("relay._make.TypeData") TVM_REGISTER_GLOBAL("ir.TypeData")
.set_body_typed([](GlobalTypeVar header, .set_body_typed([](GlobalTypeVar header,
tvm::Array<TypeVar> type_vars, tvm::Array<TypeVar> type_vars,
tvm::Array<Constructor> constructors) { tvm::Array<Constructor> constructors) {
......
...@@ -334,7 +334,7 @@ bool DictAttrsNode::ContentEqual(const Object* other, AttrsEqual equal) const { ...@@ -334,7 +334,7 @@ bool DictAttrsNode::ContentEqual(const Object* other, AttrsEqual equal) const {
return equal(this->dict, static_cast<const DictAttrsNode*>(other)->dict); return equal(this->dict, static_cast<const DictAttrsNode*>(other)->dict);
} }
TVM_REGISTER_GLOBAL("_AttrsListFieldInfo") TVM_REGISTER_GLOBAL("ir.AttrsListFieldInfo")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Attrs()->ListFieldInfo(); *ret = args[0].operator Attrs()->ListFieldInfo();
}); });
......
...@@ -50,10 +50,10 @@ EnvFunc EnvFunc::Get(const std::string& name) { ...@@ -50,10 +50,10 @@ EnvFunc EnvFunc::Get(const std::string& name) {
return EnvFunc(CreateEnvNode(name)); return EnvFunc(CreateEnvNode(name));
} }
TVM_REGISTER_GLOBAL("_EnvFuncGet") TVM_REGISTER_GLOBAL("ir.EnvFuncGet")
.set_body_typed(EnvFunc::Get); .set_body_typed(EnvFunc::Get);
TVM_REGISTER_GLOBAL("_EnvFuncCall") TVM_REGISTER_GLOBAL("ir.EnvFuncCall")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
EnvFunc env = args[0]; EnvFunc env = args[0];
CHECK_GE(args.size(), 1); CHECK_GE(args.size(), 1);
...@@ -62,7 +62,7 @@ TVM_REGISTER_GLOBAL("_EnvFuncCall") ...@@ -62,7 +62,7 @@ TVM_REGISTER_GLOBAL("_EnvFuncCall")
args.size() - 1), rv); args.size() - 1), rv);
}); });
TVM_REGISTER_GLOBAL("_EnvFuncGetPackedFunc") TVM_REGISTER_GLOBAL("ir.EnvFuncGetPackedFunc")
.set_body_typed([](const EnvFunc&n) { .set_body_typed([](const EnvFunc&n) {
return n->func; return n->func;
}); });
......
...@@ -154,7 +154,7 @@ GlobalVar::GlobalVar(std::string name_hint) { ...@@ -154,7 +154,7 @@ GlobalVar::GlobalVar(std::string name_hint) {
TVM_REGISTER_NODE_TYPE(GlobalVarNode); TVM_REGISTER_NODE_TYPE(GlobalVarNode);
TVM_REGISTER_GLOBAL("relay._make.GlobalVar") TVM_REGISTER_GLOBAL("ir.GlobalVar")
.set_body_typed([](std::string name){ .set_body_typed([](std::string name){
return GlobalVar(name); return GlobalVar(name);
}); });
......
...@@ -338,13 +338,13 @@ IRModule IRModule::FromText(const std::string& text, const std::string& source_p ...@@ -338,13 +338,13 @@ IRModule IRModule::FromText(const std::string& text, const std::string& source_p
TVM_REGISTER_NODE_TYPE(IRModuleNode); TVM_REGISTER_NODE_TYPE(IRModuleNode);
TVM_REGISTER_GLOBAL("relay._make.Module") TVM_REGISTER_GLOBAL("ir.IRModule")
.set_body_typed([](tvm::Map<GlobalVar, BaseFunc> funcs, .set_body_typed([](tvm::Map<GlobalVar, BaseFunc> funcs,
tvm::Map<GlobalTypeVar, TypeData> types) { tvm::Map<GlobalTypeVar, TypeData> types) {
return IRModule(funcs, types, {}); return IRModule(funcs, types, {});
}); });
TVM_REGISTER_GLOBAL("relay._module.Module_Add") TVM_REGISTER_GLOBAL("ir.Module_Add")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
IRModule mod = args[0]; IRModule mod = args[0];
GlobalVar var = args[1]; GlobalVar var = args[1];
...@@ -369,67 +369,67 @@ TVM_REGISTER_GLOBAL("relay._module.Module_Add") ...@@ -369,67 +369,67 @@ TVM_REGISTER_GLOBAL("relay._module.Module_Add")
*ret = mod; *ret = mod;
}); });
TVM_REGISTER_GLOBAL("relay._module.Module_AddDef") TVM_REGISTER_GLOBAL("ir.Module_AddDef")
.set_body_method<IRModule>(&IRModuleNode::AddTypeDef); .set_body_method<IRModule>(&IRModuleNode::AddTypeDef);
TVM_REGISTER_GLOBAL("relay._module.Module_GetGlobalVar") TVM_REGISTER_GLOBAL("ir.Module_GetGlobalVar")
.set_body_method<IRModule>(&IRModuleNode::GetGlobalVar); .set_body_method<IRModule>(&IRModuleNode::GetGlobalVar);
TVM_REGISTER_GLOBAL("relay._module.Module_GetGlobalVars") TVM_REGISTER_GLOBAL("ir.Module_GetGlobalVars")
.set_body_method<IRModule>(&IRModuleNode::GetGlobalVars); .set_body_method<IRModule>(&IRModuleNode::GetGlobalVars);
TVM_REGISTER_GLOBAL("relay._module.Module_GetGlobalTypeVars") TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVars")
.set_body_method<IRModule>(&IRModuleNode::GetGlobalTypeVars); .set_body_method<IRModule>(&IRModuleNode::GetGlobalTypeVars);
TVM_REGISTER_GLOBAL("relay._module.Module_ContainGlobalVar") TVM_REGISTER_GLOBAL("ir.Module_ContainGlobalVar")
.set_body_method<IRModule>(&IRModuleNode::ContainGlobalVar); .set_body_method<IRModule>(&IRModuleNode::ContainGlobalVar);
TVM_REGISTER_GLOBAL("relay._module.Module_GetGlobalTypeVar") TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVar")
.set_body_method<IRModule>(&IRModuleNode::GetGlobalTypeVar); .set_body_method<IRModule>(&IRModuleNode::GetGlobalTypeVar);
TVM_REGISTER_GLOBAL("relay._module.Module_Lookup") TVM_REGISTER_GLOBAL("ir.Module_Lookup")
.set_body_typed([](IRModule mod, GlobalVar var) { .set_body_typed([](IRModule mod, GlobalVar var) {
return mod->Lookup(var); return mod->Lookup(var);
}); });
TVM_REGISTER_GLOBAL("relay._module.Module_Lookup_str") TVM_REGISTER_GLOBAL("ir.Module_Lookup_str")
.set_body_typed([](IRModule mod, std::string var) { .set_body_typed([](IRModule mod, std::string var) {
return mod->Lookup(var); return mod->Lookup(var);
}); });
TVM_REGISTER_GLOBAL("relay._module.Module_LookupDef") TVM_REGISTER_GLOBAL("ir.Module_LookupDef")
.set_body_typed([](IRModule mod, GlobalTypeVar var) { .set_body_typed([](IRModule mod, GlobalTypeVar var) {
return mod->LookupTypeDef(var); return mod->LookupTypeDef(var);
}); });
TVM_REGISTER_GLOBAL("relay._module.Module_LookupDef_str") TVM_REGISTER_GLOBAL("ir.Module_LookupDef_str")
.set_body_typed([](IRModule mod, std::string var) { .set_body_typed([](IRModule mod, std::string var) {
return mod->LookupTypeDef(var); return mod->LookupTypeDef(var);
}); });
TVM_REGISTER_GLOBAL("relay._module.Module_LookupTag") TVM_REGISTER_GLOBAL("ir.Module_LookupTag")
.set_body_typed([](IRModule mod, int32_t tag) { .set_body_typed([](IRModule mod, int32_t tag) {
return mod->LookupTag(tag); return mod->LookupTag(tag);
}); });
TVM_REGISTER_GLOBAL("relay._module.Module_FromExpr") TVM_REGISTER_GLOBAL("ir.Module_FromExpr")
.set_body_typed([](RelayExpr e, .set_body_typed([](RelayExpr e,
tvm::Map<GlobalVar, BaseFunc> funcs, tvm::Map<GlobalVar, BaseFunc> funcs,
tvm::Map<GlobalTypeVar, TypeData> type_defs) { tvm::Map<GlobalTypeVar, TypeData> type_defs) {
return IRModule::FromExpr(e, funcs, type_defs); return IRModule::FromExpr(e, funcs, type_defs);
}); });
TVM_REGISTER_GLOBAL("relay._module.Module_Update") TVM_REGISTER_GLOBAL("ir.Module_Update")
.set_body_typed([](IRModule mod, IRModule from) { .set_body_typed([](IRModule mod, IRModule from) {
mod->Update(from); mod->Update(from);
}); });
TVM_REGISTER_GLOBAL("relay._module.Module_Import") TVM_REGISTER_GLOBAL("ir.Module_Import")
.set_body_typed([](IRModule mod, std::string path) { .set_body_typed([](IRModule mod, std::string path) {
mod->Import(path); mod->Import(path);
}); });
TVM_REGISTER_GLOBAL("relay._module.Module_ImportFromStd") TVM_REGISTER_GLOBAL("ir.Module_ImportFromStd")
.set_body_typed([](IRModule mod, std::string path) { .set_body_typed([](IRModule mod, std::string path) {
mod->ImportFromStd(path); mod->ImportFromStd(path);
});; });;
......
...@@ -45,7 +45,7 @@ SourceName SourceName::Get(const std::string& name) { ...@@ -45,7 +45,7 @@ SourceName SourceName::Get(const std::string& name) {
return SourceName(GetSourceNameNode(name)); return SourceName(GetSourceNameNode(name));
} }
TVM_REGISTER_GLOBAL("relay._make.SourceName") TVM_REGISTER_GLOBAL("ir.SourceName")
.set_body_typed(SourceName::Get); .set_body_typed(SourceName::Get);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
...@@ -70,7 +70,7 @@ Span SpanNode::make(SourceName source, int lineno, int col_offset) { ...@@ -70,7 +70,7 @@ Span SpanNode::make(SourceName source, int lineno, int col_offset) {
TVM_REGISTER_NODE_TYPE(SpanNode); TVM_REGISTER_NODE_TYPE(SpanNode);
TVM_REGISTER_GLOBAL("relay._make.Span") TVM_REGISTER_GLOBAL("ir.Span")
.set_body_typed(SpanNode::make); .set_body_typed(SpanNode::make);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
......
...@@ -55,7 +55,7 @@ PrimExpr TensorTypeNode::Size() const { ...@@ -55,7 +55,7 @@ PrimExpr TensorTypeNode::Size() const {
TVM_REGISTER_NODE_TYPE(TensorTypeNode); TVM_REGISTER_NODE_TYPE(TensorTypeNode);
TVM_REGISTER_GLOBAL("relay._make.TensorType") TVM_REGISTER_GLOBAL("ir.TensorType")
.set_body_typed([](Array<PrimExpr> shape, DataType dtype) { .set_body_typed([](Array<PrimExpr> shape, DataType dtype) {
return TensorType(shape, dtype); return TensorType(shape, dtype);
}); });
......
...@@ -300,10 +300,15 @@ bool SequentialNode::PassEnabled(const PassInfo& info) const { ...@@ -300,10 +300,15 @@ bool SequentialNode::PassEnabled(const PassInfo& info) const {
Pass GetPass(const std::string& pass_name) { Pass GetPass(const std::string& pass_name) {
using tvm::runtime::Registry; using tvm::runtime::Registry;
std::string fpass_name = "relay._transform." + pass_name; const runtime::PackedFunc* f = nullptr;
const auto* f = Registry::Get(fpass_name); if (pass_name.find("transform.") != std::string::npos) {
CHECK(f != nullptr) << "Cannot find " << fpass_name f = Registry::Get(pass_name);
<< "to create the pass " << pass_name; } else if ((f = Registry::Get("transform." + pass_name))) {
// pass
} else if ((f = Registry::Get("relay._transform." + pass_name))) {
}
CHECK(f != nullptr) << "Cannot use " << pass_name
<< "to create the pass";
return (*f)(); return (*f)();
} }
...@@ -311,7 +316,7 @@ Pass GetPass(const std::string& pass_name) { ...@@ -311,7 +316,7 @@ Pass GetPass(const std::string& pass_name) {
// a Sequential without the consideration of their orders. The phase // a Sequential without the consideration of their orders. The phase
// ordering problem needs to be handled in the future. // ordering problem needs to be handled in the future.
IRModule SequentialNode::operator()(const IRModule& module, IRModule SequentialNode::operator()(const IRModule& module,
const PassContext& pass_ctx) const { const PassContext& pass_ctx) const {
IRModule mod = module; IRModule mod = module;
for (const Pass& pass : passes) { for (const Pass& pass : passes) {
CHECK(pass.defined()) << "Found undefined pass for optimization."; CHECK(pass.defined()) << "Found undefined pass for optimization.";
...@@ -339,12 +344,12 @@ Pass CreateModulePass( ...@@ -339,12 +344,12 @@ Pass CreateModulePass(
TVM_REGISTER_NODE_TYPE(PassInfoNode); TVM_REGISTER_NODE_TYPE(PassInfoNode);
TVM_REGISTER_GLOBAL("relay._transform.PassInfo") TVM_REGISTER_GLOBAL("transform.PassInfo")
.set_body_typed([](int opt_level, std::string name, tvm::Array<PrimExpr> required) { .set_body_typed([](int opt_level, std::string name, tvm::Array<PrimExpr> required) {
return PassInfo(opt_level, name, required); return PassInfo(opt_level, name, required);
}); });
TVM_REGISTER_GLOBAL("relay._transform.Info") TVM_REGISTER_GLOBAL("transform.Info")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
Pass pass = args[0]; Pass pass = args[0];
*ret = pass->Info(); *ret = pass->Info();
...@@ -366,14 +371,14 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) ...@@ -366,14 +371,14 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
TVM_REGISTER_NODE_TYPE(ModulePassNode); TVM_REGISTER_NODE_TYPE(ModulePassNode);
TVM_REGISTER_GLOBAL("relay._transform.MakeModulePass") TVM_REGISTER_GLOBAL("transform.MakeModulePass")
.set_body_typed( .set_body_typed(
[](runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func, [](runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func,
PassInfo pass_info) { PassInfo pass_info) {
return ModulePass(pass_func, pass_info); return ModulePass(pass_func, pass_info);
}); });
TVM_REGISTER_GLOBAL("relay._transform.RunPass") TVM_REGISTER_GLOBAL("transform.RunPass")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
Pass pass = args[0]; Pass pass = args[0];
IRModule mod = args[1]; IRModule mod = args[1];
...@@ -390,7 +395,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) ...@@ -390,7 +395,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
TVM_REGISTER_NODE_TYPE(SequentialNode); TVM_REGISTER_NODE_TYPE(SequentialNode);
TVM_REGISTER_GLOBAL("relay._transform.Sequential") TVM_REGISTER_GLOBAL("transform.Sequential")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
tvm::Array<Pass> passes = args[0]; tvm::Array<Pass> passes = args[0];
int opt_level = args[1]; int opt_level = args[1];
...@@ -416,7 +421,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) ...@@ -416,7 +421,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
TVM_REGISTER_NODE_TYPE(PassContextNode); TVM_REGISTER_NODE_TYPE(PassContextNode);
TVM_REGISTER_GLOBAL("relay._transform.PassContext") TVM_REGISTER_GLOBAL("transform.PassContext")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
auto pctx = PassContext::Create(); auto pctx = PassContext::Create();
int opt_level = args[0]; int opt_level = args[0];
...@@ -465,13 +470,13 @@ class PassContext::Internal { ...@@ -465,13 +470,13 @@ class PassContext::Internal {
} }
}; };
TVM_REGISTER_GLOBAL("relay._transform.GetCurrentPassContext") TVM_REGISTER_GLOBAL("transform.GetCurrentPassContext")
.set_body_typed(PassContext::Current); .set_body_typed(PassContext::Current);
TVM_REGISTER_GLOBAL("relay._transform.EnterPassContext") TVM_REGISTER_GLOBAL("transform.EnterPassContext")
.set_body_typed(PassContext::Internal::EnterScope); .set_body_typed(PassContext::Internal::EnterScope);
TVM_REGISTER_GLOBAL("relay._transform.ExitPassContext") TVM_REGISTER_GLOBAL("transform.ExitPassContext")
.set_body_typed(PassContext::Internal::ExitScope); .set_body_typed(PassContext::Internal::ExitScope);
} // namespace transform } // namespace transform
......
...@@ -33,7 +33,7 @@ PrimType::PrimType(runtime::DataType dtype) { ...@@ -33,7 +33,7 @@ PrimType::PrimType(runtime::DataType dtype) {
TVM_REGISTER_NODE_TYPE(PrimTypeNode); TVM_REGISTER_NODE_TYPE(PrimTypeNode);
TVM_REGISTER_GLOBAL("relay._make.PrimType") TVM_REGISTER_GLOBAL("ir.PrimType")
.set_body_typed([](runtime::DataType dtype) { .set_body_typed([](runtime::DataType dtype) {
return PrimType(dtype); return PrimType(dtype);
}); });
...@@ -54,7 +54,7 @@ TypeVar::TypeVar(std::string name, TypeKind kind) { ...@@ -54,7 +54,7 @@ TypeVar::TypeVar(std::string name, TypeKind kind) {
TVM_REGISTER_NODE_TYPE(TypeVarNode); TVM_REGISTER_NODE_TYPE(TypeVarNode);
TVM_REGISTER_GLOBAL("relay._make.TypeVar") TVM_REGISTER_GLOBAL("ir.TypeVar")
.set_body_typed([](std::string name, int kind) { .set_body_typed([](std::string name, int kind) {
return TypeVar(name, static_cast<TypeKind>(kind)); return TypeVar(name, static_cast<TypeKind>(kind));
}); });
...@@ -76,7 +76,7 @@ GlobalTypeVar::GlobalTypeVar(std::string name, TypeKind kind) { ...@@ -76,7 +76,7 @@ GlobalTypeVar::GlobalTypeVar(std::string name, TypeKind kind) {
TVM_REGISTER_NODE_TYPE(GlobalTypeVarNode); TVM_REGISTER_NODE_TYPE(GlobalTypeVarNode);
TVM_REGISTER_GLOBAL("relay._make.GlobalTypeVar") TVM_REGISTER_GLOBAL("ir.GlobalTypeVar")
.set_body_typed([](std::string name, int kind) { .set_body_typed([](std::string name, int kind) {
return GlobalTypeVar(name, static_cast<TypeKind>(kind)); return GlobalTypeVar(name, static_cast<TypeKind>(kind));
}); });
...@@ -102,7 +102,7 @@ FuncType::FuncType(tvm::Array<Type> arg_types, ...@@ -102,7 +102,7 @@ FuncType::FuncType(tvm::Array<Type> arg_types,
TVM_REGISTER_NODE_TYPE(FuncTypeNode); TVM_REGISTER_NODE_TYPE(FuncTypeNode);
TVM_REGISTER_GLOBAL("relay._make.FuncType") TVM_REGISTER_GLOBAL("ir.FuncType")
.set_body_typed([](tvm::Array<Type> arg_types, .set_body_typed([](tvm::Array<Type> arg_types,
Type ret_type, Type ret_type,
tvm::Array<TypeVar> type_params, tvm::Array<TypeVar> type_params,
...@@ -131,7 +131,7 @@ TupleType TupleType::Empty() { ...@@ -131,7 +131,7 @@ TupleType TupleType::Empty() {
TVM_REGISTER_NODE_TYPE(TupleTypeNode); TVM_REGISTER_NODE_TYPE(TupleTypeNode);
TVM_REGISTER_GLOBAL("relay._make.TupleType") TVM_REGISTER_GLOBAL("ir.TupleType")
.set_body_typed([](Array<Type> fields) { .set_body_typed([](Array<Type> fields) {
return TupleType(fields); return TupleType(fields);
}); });
...@@ -151,7 +151,7 @@ IncompleteType::IncompleteType(TypeKind kind) { ...@@ -151,7 +151,7 @@ IncompleteType::IncompleteType(TypeKind kind) {
TVM_REGISTER_NODE_TYPE(IncompleteTypeNode); TVM_REGISTER_NODE_TYPE(IncompleteTypeNode);
TVM_REGISTER_GLOBAL("relay._make.IncompleteType") TVM_REGISTER_GLOBAL("ir.IncompleteType")
.set_body_typed([](int kind) { .set_body_typed([](int kind) {
return IncompleteType(static_cast<TypeKind>(kind)); return IncompleteType(static_cast<TypeKind>(kind));
}); });
...@@ -169,7 +169,7 @@ RelayRefType::RelayRefType(Type value) { ...@@ -169,7 +169,7 @@ RelayRefType::RelayRefType(Type value) {
data_ = std::move(n); data_ = std::move(n);
} }
TVM_REGISTER_GLOBAL("relay._make.RefType") TVM_REGISTER_GLOBAL("ir.RelayRefType")
.set_body_typed([](Type value) { .set_body_typed([](Type value) {
return RelayRefType(value); return RelayRefType(value);
}); });
......
...@@ -35,7 +35,7 @@ TypeCall::TypeCall(Type func, tvm::Array<Type> args) { ...@@ -35,7 +35,7 @@ TypeCall::TypeCall(Type func, tvm::Array<Type> args) {
TVM_REGISTER_NODE_TYPE(TypeCallNode); TVM_REGISTER_NODE_TYPE(TypeCallNode);
TVM_REGISTER_GLOBAL("relay._make.TypeCall") TVM_REGISTER_GLOBAL("ir.TypeCall")
.set_body_typed([](Type func, Array<Type> type) { .set_body_typed([](Type func, Array<Type> type) {
return TypeCall(func, type); return TypeCall(func, type);
}); });
...@@ -61,7 +61,7 @@ TypeRelation::TypeRelation(TypeRelationFn func, ...@@ -61,7 +61,7 @@ TypeRelation::TypeRelation(TypeRelationFn func,
TVM_REGISTER_NODE_TYPE(TypeRelationNode); TVM_REGISTER_NODE_TYPE(TypeRelationNode);
TVM_REGISTER_GLOBAL("relay._make.TypeRelation") TVM_REGISTER_GLOBAL("ir.TypeRelation")
.set_body_typed([](TypeRelationFn func, .set_body_typed([](TypeRelationFn func,
Array<Type> args, Array<Type> args,
int num_inputs, int num_inputs,
......
...@@ -131,6 +131,7 @@ class RelayTextPrinter : ...@@ -131,6 +131,7 @@ class RelayTextPrinter :
} else if (node.as<IRModuleNode>()) { } else if (node.as<IRModuleNode>()) {
return PrintMod(Downcast<IRModule>(node)); return PrintMod(Downcast<IRModule>(node));
} else { } else {
// default module.
std::ostringstream os; std::ostringstream os;
os << node; os << node;
return Doc() << os.str(); return Doc() << os.str();
...@@ -905,20 +906,18 @@ static const char* kSemVer = "v0.0.4"; ...@@ -905,20 +906,18 @@ static const char* kSemVer = "v0.0.4";
// - relay_text_printer.cc (specific printing logics for relay) // - relay_text_printer.cc (specific printing logics for relay)
// - tir_text_printer.cc (specific printing logics for TIR) // - tir_text_printer.cc (specific printing logics for TIR)
std::string PrettyPrint(const ObjectRef& node) { std::string PrettyPrint(const ObjectRef& node) {
Doc doc; return AsText(node, false, nullptr);
doc << relay::RelayTextPrinter(false, nullptr).PrintFinal(node);
return doc.str();
} }
std::string AsText(const ObjectRef& node, std::string AsText(const ObjectRef& node,
bool show_meta_data, bool show_meta_data,
runtime::TypedPackedFunc<std::string(ObjectRef)> annotate) { runtime::TypedPackedFunc<std::string(ObjectRef)> annotate) {
Doc doc; Doc doc;
doc << kSemVer << Doc::NewLine() doc << kSemVer << Doc::NewLine();
<< relay::RelayTextPrinter(show_meta_data, annotate).PrintFinal(node); doc << relay::RelayTextPrinter(show_meta_data, annotate).PrintFinal(node);
return doc.str(); return doc.str();
} }
TVM_REGISTER_GLOBAL("relay._expr.AsText") TVM_REGISTER_GLOBAL("ir.AsText")
.set_body_typed(AsText); .set_body_typed(AsText);
} // namespace tvm } // namespace tvm
...@@ -599,6 +599,11 @@ TVM_REGISTER_GLOBAL("relay._make._alpha_equal") ...@@ -599,6 +599,11 @@ TVM_REGISTER_GLOBAL("relay._make._alpha_equal")
return AlphaEqualHandler(false, false).Equal(a, b); return AlphaEqualHandler(false, false).Equal(a, b);
}); });
TVM_REGISTER_GLOBAL("ir.type_alpha_equal")
.set_body_typed([](Type a, Type b) {
return AlphaEqual(a, b);
});
TVM_REGISTER_GLOBAL("relay._make._assert_alpha_equal") TVM_REGISTER_GLOBAL("relay._make._assert_alpha_equal")
.set_body_typed([](ObjectRef a, ObjectRef b) { .set_body_typed([](ObjectRef a, ObjectRef b) {
bool alpha_equal = AlphaEqualHandler(false, true).Equal(a, b); bool alpha_equal = AlphaEqualHandler(false, true).Equal(a, b);
......
...@@ -33,7 +33,7 @@ using namespace tvm::runtime; ...@@ -33,7 +33,7 @@ using namespace tvm::runtime;
TVM_REGISTER_NODE_TYPE(IdNode); TVM_REGISTER_NODE_TYPE(IdNode);
TVM_REGISTER_GLOBAL("relay._base.set_span") TVM_REGISTER_GLOBAL("ir.NodeSetSpan")
.set_body_typed([](ObjectRef node_ref, Span sp) { .set_body_typed([](ObjectRef node_ref, Span sp) {
if (auto* rn = node_ref.as<RelayNode>()) { if (auto* rn = node_ref.as<RelayNode>()) {
rn->span = sp; rn->span = sp;
......
...@@ -84,7 +84,7 @@ def check_server_drop(): ...@@ -84,7 +84,7 @@ def check_server_drop():
f1 = remote2.get_function("rpc.test2.addone") f1 = remote2.get_function("rpc.test2.addone")
assert f1(10) == 11 assert f1(10) == 11
except tvm.TVMError as e: except tvm.error.TVMError as e:
pass pass
remote3 = tclient.request("abc") remote3 = tclient.request("abc")
f1 = remote3.get_function("rpc.test2.addone") f1 = remote3.get_function("rpc.test2.addone")
......
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import mxnet as mx import mxnet as mx
import tvm
from tvm import relay from tvm import relay
from tvm.relay import transform from tvm.relay import transform
import model_zoo import model_zoo
...@@ -99,7 +101,7 @@ def test_multi_outputs(): ...@@ -99,7 +101,7 @@ def test_multi_outputs():
z = F.split(x, **kwargs) z = F.split(x, **kwargs)
z = F.subtract(F.add(z[0], z[2]), y) z = F.subtract(F.add(z[0], z[2]), y)
func = relay.Function(relay.analysis.free_vars(z), z) func = relay.Function(relay.analysis.free_vars(z), z)
return relay.Module.from_expr(func) return tvm.IRModule.from_expr(func)
mx_sym = mx_compose(mx, num_outputs=3, axis=1) mx_sym = mx_compose(mx, num_outputs=3, axis=1)
mod, _ = relay.frontend.from_mxnet( mod, _ = relay.frontend.from_mxnet(
......
...@@ -34,7 +34,7 @@ def test_mkldnn_dequantize(): ...@@ -34,7 +34,7 @@ def test_mkldnn_dequantize():
max_range=max_range, max_range=max_range,
in_dtype=in_dtype) in_dtype=in_dtype)
mod = relay.Function(relay.analysis.free_vars(dequantized_output), dequantized_output) mod = relay.Function(relay.analysis.free_vars(dequantized_output), dequantized_output)
mod = relay.Module.from_expr(mod) mod = tvm.IRModule.from_expr(mod)
with relay.build_config(opt_level=3): with relay.build_config(opt_level=3):
graph, lib, params = relay.build(mod, "llvm", params=None) graph, lib, params = relay.build(mod, "llvm", params=None)
rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0)) rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
...@@ -90,7 +90,7 @@ def test_mkldnn_quantize(): ...@@ -90,7 +90,7 @@ def test_mkldnn_quantize():
max_range=max_range, max_range=max_range,
out_dtype=out_dtype) out_dtype=out_dtype)
mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output) mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output)
mod = relay.Module.from_expr(mod) mod = tvm.IRModule.from_expr(mod)
with relay.build_config(opt_level=3): with relay.build_config(opt_level=3):
graph, lib, params = relay.build(mod, "llvm", params=None) graph, lib, params = relay.build(mod, "llvm", params=None)
rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0)) rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
......
...@@ -23,7 +23,7 @@ from tvm.relay.testing import add_nat_definitions, count as count_, make_nat_val ...@@ -23,7 +23,7 @@ from tvm.relay.testing import add_nat_definitions, count as count_, make_nat_val
import numpy as np import numpy as np
mod = relay.Module() mod = tvm.IRModule()
p = Prelude(mod) p = Prelude(mod)
add_nat_definitions(p) add_nat_definitions(p)
...@@ -730,7 +730,7 @@ def check_tensor_array(ta_mod, ref_res, *args, dtype="float32", ...@@ -730,7 +730,7 @@ def check_tensor_array(ta_mod, ref_res, *args, dtype="float32",
def test_tensor_expand_dims(): def test_tensor_expand_dims():
def run(dtype): def run(dtype):
x = relay.var('x') x = relay.var('x')
mod = relay.Module() mod = tvm.IRModule()
p = Prelude(mod) p = Prelude(mod)
expand_dims_func = p.get_var('tensor_expand_dims', dtype) expand_dims_func = p.get_var('tensor_expand_dims', dtype)
tensor1 = p.get_var('tensor1', dtype) tensor1 = p.get_var('tensor1', dtype)
...@@ -745,7 +745,7 @@ def test_tensor_expand_dims(): ...@@ -745,7 +745,7 @@ def test_tensor_expand_dims():
def test_tensor_array_constructor(): def test_tensor_array_constructor():
def run(dtype): def run(dtype):
x = relay.var('x') x = relay.var('x')
mod = relay.Module() mod = tvm.IRModule()
p = Prelude(mod) p = Prelude(mod)
tensor_array = p.get_var('tensor_array', dtype) tensor_array = p.get_var('tensor_array', dtype)
mod["main"] = relay.Function([x], tensor_array(x)) mod["main"] = relay.Function([x], tensor_array(x))
...@@ -757,7 +757,7 @@ def test_tensor_array_constructor(): ...@@ -757,7 +757,7 @@ def test_tensor_array_constructor():
def test_tensor_array_read(): def test_tensor_array_read():
def run(dtype): def run(dtype):
mod = relay.Module() mod = tvm.IRModule()
p = Prelude(mod) p = Prelude(mod)
l = relay.var('l') l = relay.var('l')
i = relay.var('i') i = relay.var('i')
...@@ -773,7 +773,7 @@ def test_tensor_array_read(): ...@@ -773,7 +773,7 @@ def test_tensor_array_read():
def test_tensor_array_write(): def test_tensor_array_write():
def run(dtype): def run(dtype):
mod = relay.Module() mod = tvm.IRModule()
p = Prelude(mod) p = Prelude(mod)
v1 = relay.var('v1') v1 = relay.var('v1')
v2 = relay.var('v2') v2 = relay.var('v2')
...@@ -793,7 +793,7 @@ def test_tensor_array_write(): ...@@ -793,7 +793,7 @@ def test_tensor_array_write():
def test_tensor_array_stack(): def test_tensor_array_stack():
def run(dtype): def run(dtype):
mod = relay.Module() mod = tvm.IRModule()
p = Prelude(mod) p = Prelude(mod)
tensor_array = p.get_var('tensor_array', dtype) tensor_array = p.get_var('tensor_array', dtype)
tensor1 = p.get_var('tensor1', dtype) tensor1 = p.get_var('tensor1', dtype)
...@@ -815,7 +815,7 @@ def test_tensor_array_stack(): ...@@ -815,7 +815,7 @@ def test_tensor_array_stack():
def test_tensor_array_unstack(): def test_tensor_array_unstack():
def run(dtype): def run(dtype):
mod = relay.Module() mod = tvm.IRModule()
p = Prelude(mod) p = Prelude(mod)
unstack_tensor1 = p.get_var('tensor_array_unstack_tensor1', dtype) unstack_tensor1 = p.get_var('tensor_array_unstack_tensor1', dtype)
v = relay.var('v') v = relay.var('v')
...@@ -828,7 +828,7 @@ def test_tensor_array_unstack(): ...@@ -828,7 +828,7 @@ def test_tensor_array_unstack():
def test_tensor_take(): def test_tensor_take():
def run(dtype): def run(dtype):
mod = relay.Module() mod = tvm.IRModule()
p = Prelude(mod) p = Prelude(mod)
take = p.get_var('tensor_take', dtype) take = p.get_var('tensor_take', dtype)
tensor2 = p.get_var('tensor2', dtype) tensor2 = p.get_var('tensor2', dtype)
...@@ -847,7 +847,7 @@ def test_tensor_take(): ...@@ -847,7 +847,7 @@ def test_tensor_take():
def test_tensor_concatenate(): def test_tensor_concatenate():
def run(dtype): def run(dtype):
mod = relay.Module() mod = tvm.IRModule()
p = Prelude(mod) p = Prelude(mod)
concat = p.get_var('tensor_concatenate', dtype) concat = p.get_var('tensor_concatenate', dtype)
tensor1 = p.get_var('tensor1', dtype) tensor1 = p.get_var('tensor1', dtype)
...@@ -865,7 +865,7 @@ def test_tensor_concatenate(): ...@@ -865,7 +865,7 @@ def test_tensor_concatenate():
def test_tensor_array_concat(): def test_tensor_array_concat():
def run(dtype): def run(dtype):
mod = relay.Module() mod = tvm.IRModule()
p = Prelude(mod) p = Prelude(mod)
v1 = relay.var('v1') v1 = relay.var('v1')
v2 = relay.var('v2') v2 = relay.var('v2')
...@@ -888,9 +888,9 @@ def test_tensor_array_concat(): ...@@ -888,9 +888,9 @@ def test_tensor_array_concat():
def test_tensor_array_scatter(): def test_tensor_array_scatter():
def run(dtype): def run(dtype):
mod = relay.Module() mod = tvm.IRModule()
p = Prelude(mod) p = Prelude(mod)
# tensor array # tensor array
v1 = relay.var('v1') v1 = relay.var('v1')
v2 = relay.var('v2') v2 = relay.var('v2')
...@@ -938,9 +938,9 @@ def test_tensor_array_scatter(): ...@@ -938,9 +938,9 @@ def test_tensor_array_scatter():
def test_tensor_array_split(): def test_tensor_array_split():
def run(dtype): def run(dtype):
mod = relay.Module() mod = tvm.IRModule()
p = Prelude(mod) p = Prelude(mod)
# tensor array # tensor array
v1 = relay.var('v1') v1 = relay.var('v1')
v2 = relay.var('v2') v2 = relay.var('v2')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment