Unverified commit c76fce9f authored by Tianqi Chen, committed by GitHub

[RELAY] BiasAdd, MLP, Resnet testing (#1969)

* [RELAY] BiasAdd, MLP, Resnet testing

* fix review comments
parent 399b39f1
@@ -40,6 +40,8 @@ This level enables fully connected multi-layer perceptron.
   tvm.relay.nn.relu
   tvm.relay.nn.dropout
   tvm.relay.nn.batch_norm
   tvm.relay.nn.bias_add

**Level 2: Convolutions**
@@ -85,8 +87,13 @@ This level enables additional math and transform operators.
   tvm.relay.abs
   tvm.relay.negative
   tvm.relay.take
   tvm.relay.zeros
   tvm.relay.zeros_like
   tvm.relay.ones
   tvm.relay.ones_like
   tvm.relay.full
   tvm.relay.full_like
   tvm.relay.cast

**Level 4: Broadcast and Reductions**
@@ -151,6 +158,9 @@ Level 1 Definitions
.. autofunction:: tvm.relay.nn.softmax
.. autofunction:: tvm.relay.nn.log_softmax
.. autofunction:: tvm.relay.nn.relu
.. autofunction:: tvm.relay.nn.dropout
.. autofunction:: tvm.relay.nn.batch_norm
.. autofunction:: tvm.relay.nn.bias_add

Level 2 Definitions
@@ -185,6 +195,9 @@ Level 3 Definitions
.. autofunction:: tvm.relay.zeros_like
.. autofunction:: tvm.relay.ones
.. autofunction:: tvm.relay.ones_like
.. autofunction:: tvm.relay.full
.. autofunction:: tvm.relay.full_like
.. autofunction:: tvm.relay.cast

Level 4 Definitions
...
@@ -12,6 +12,23 @@
namespace tvm {
namespace relay {

/*!
 * \brief Add a 1D Tensor to an axis of a data.
 *
 * \note bias_add is a special add operator that is in nn
 *   and enables automatic derivation of bias's shape.
 *   You can directly use add for more generalized case.
 */
struct BiasAddAttrs : public tvm::AttrsNode<BiasAddAttrs> {
  int axis;

  TVM_DECLARE_ATTRS(BiasAddAttrs, "relay.attrs.BiasAddAttrs") {
    TVM_ATTR_FIELD(axis)
      .describe("The axis to add the bias")
      .set_default(1);
  }
};

/*! \brief Attributes used in convolution operators */
struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
  Array<IndexExpr> strides;
...
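A quick Python sketch of the note above: bias_add exists mainly so the bias shape can be derived from `axis`, while the generalized form uses `add` with an explicit reshape. Shapes and variable names here are illustrative, not part of the patch:

from tvm import relay

x = relay.var("x", shape=(1, 16, 32, 32))   # NCHW data
b = relay.var("b")                          # shape deliberately omitted
y1 = relay.nn.bias_add(x, b, axis=1)        # bias shape (16,) is inferred
# The more general spelling broadcasts an explicitly shaped bias:
b2 = relay.var("b2", shape=(16,))
y2 = relay.add(x, relay.reshape(b2, (1, 16, 1, 1)))
print(relay.ir_pass.infer_type(y1).args[1].checked_type)  # Tensor[(16,), float32]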
@@ -12,6 +12,16 @@
namespace tvm {
namespace relay {

/*! \brief data type cast */
struct CastAttrs : public tvm::AttrsNode<CastAttrs> {
  DataType dtype;

  TVM_DECLARE_ATTRS(CastAttrs, "relay.attrs.CastAttrs") {
    TVM_ATTR_FIELD(dtype)
        .describe("Target data type");
  }
};  // struct CastAttrs.

/*! \brief Attributes used in expand_dims operators */
struct ExpandDimsAttrs : public tvm::AttrsNode<ExpandDimsAttrs> {
  int axis;
...
@@ -112,15 +112,17 @@ class ExprFunctor<R(const Expr& n, Args...)> {
  }
};

-/*! \brief A simple visitor wrapper around ExprFunctor.
+/*!
+ * \brief A simple visitor wrapper around ExprFunctor.
+ *  Recursively visit the content.
 *
- * Exposes two visitors with default traversal strategies, one
- * which doesn't compute a result but can mutate internal state,
- * and another which functionally builds a new Expr.
+ * ExprVisitor treats Expr as dataflow graph,
+ * and only visit each Expr node once.
 */
-class ExprVisitor : public ::tvm::relay::ExprFunctor<void(const Expr& n)> {
+class ExprVisitor
+    : public ::tvm::relay::ExprFunctor<void(const Expr& n)> {
 public:
+  void VisitExpr(const Expr& expr) override;
  void VisitExpr_(const VarNode* op) override;
  void VisitExpr_(const GlobalVarNode* op) override;
  void VisitExpr_(const ConstantNode* op) override;
@@ -132,13 +134,19 @@ class ExprVisitor : public ::tvm::relay::ExprFunctor<void(const Expr& n)> {
  void VisitExpr_(const OpNode* op) override;
  void VisitExpr_(const TupleGetItemNode* op) override;
  virtual void VisitType(const Type& t);
+
+ private:
+  // internal visited flag.
+  std::unordered_set<const Node*> visited_;
};

-/*! \brief A wrapper around ExprFunctor which functionally updates the AST.
- *
- * ExprMutator uses memoization and self return in order to amortize
- * the cost of using functional updates.
+/*!
+ * \brief A wrapper around ExprFunctor which functionally updates the AST.
+ *
+ * ExprMutator treats Expr as dataflow graph, and only Mutate each Expr once.
+ * The mutated results are memoized in a map and reused so that
+ * local transformation on the dataflow preserves the graph structure.
 */
class ExprMutator
    : public ::tvm::relay::ExprFunctor<Expr(const Expr&)> {
 public:
...
@@ -102,35 +102,26 @@ bool AlphaEqual(const Type& t1, const Type& t2);
 */
bool WellFormed(const Expr& e);

-/*! \brief Get free variables from expression e.
+/*! \brief Get free Vars from expr in PostDFS order.
 *
- * Free variables are variables that are not bound by a let or a function parameter in the context.
+ * Free variables are variables that are not bound by a
+ * let or a function parameter in the context.
 *
- * \param e the expression.
+ * \param expr the expression.
 *
- * \return the set of free variable.
+ * \return List of free vars, in the PostDFS order visited by expr.
 */
-tvm::Array<Var> FreeVariables(const Expr& e);
+tvm::Array<Var> FreeVars(const Expr& expr);

-/*! \brief Get free type parameters from expression e.
+/*! \brief Get free TypeVars from expression expr.
 *
 * Free type parameters are type parameters that are not bound by a function type in the context.
 *
- * \param e the expression.
+ * \param expr the expression.
 *
- * \return the set of free type variables.
+ * \return List of free vars, in the PostDFS order visited by expr.
 */
-tvm::Array<TypeVar> FreeTypeVariables(const Expr& e);
-
-/*! \brief Get free type parameters from type t.
- *
- * Free type parameters are type parameters that are not bound by a function type in the context.
- *
- * \param t the type.
- *
- * \return the set of free type variables.
- */
-tvm::Array<TypeVar> FreeTypeVariables(const Type& t);
+tvm::Array<TypeVar> FreeTypeVars(const Expr& expr);

/*! \brief Remove expressions which does not effect the program result.
 *
...
@@ -299,6 +299,9 @@ class IntImm(ConstExpr):
        self.__init_handle_by_constructor__(
            _make.IntImm, dtype, value)

    def __int__(self):
        return self.value

@register_node
class UIntImm(ConstExpr):
...
@@ -6,7 +6,7 @@ import numpy as _np
from .base import RelayNode, register_relay_node
from . import _make
from . import ty as _ty
-from .._ffi import base as _base, node as _node
+from .._ffi import base as _base
from .. import nd as _nd
from .. import convert
@@ -28,6 +28,25 @@ class Expr(RelayNode):
                " the checked_type for this node")
        return ret

    def astype(self, dtype):
        """Cast the content type of the current data to dtype.

        Parameters
        ----------
        dtype : str
            The target data type.

        Note
        ----
        This function only works for TensorType Exprs.

        Returns
        -------
        result : tvm.relay.Expr
            The result expression.
        """
        return _make.dtype_cast(self, dtype)

@register_relay_node
class Constant(Expr):
@@ -62,6 +81,9 @@ class Tuple(Expr):
    def __len__(self):
        return len(self.fields)

    def astype(self, _):
        raise TypeError("astype cannot be used on tuple")

@register_relay_node
class Var(Expr):
@@ -238,7 +260,7 @@ class TupleGetItem(Expr):
            _make.TupleGetItem, tuple_value, index)

-class TupleWrapper(_node.NodeGeneric):
+class TupleWrapper(object):
    """TupleWrapper.
    This class is a Python wrapper for a Relay tuple of known size.
@@ -257,10 +279,9 @@ class TupleWrapper(_node.NodeGeneric):
        self.tuple_value = tuple_value
        self.size = size

-    def asnode(self):
+    def astuple(self):
        """Returns the underlying Relay tuple if this wrapper is passed
        as an argument to an FFI function."""
        return self.tuple_value

    def __getitem__(self, index):
@@ -275,6 +296,9 @@ class TupleWrapper(_node.NodeGeneric):
        return ("TupleWrapper(" + self.tuple_value.__repr__() +
                ", " + self.size + ")")

    def astype(self, _):
        raise TypeError("astype cannot be used on tuple")

def var(name_hint,
        type_annotation=None,
...
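A short usage sketch of the Python-side changes above, with illustrative shapes: `astype` on a tensor expression, and `astuple()` on the TupleWrapper returned by tuple-producing ops such as batch_norm:

from tvm import relay

x = relay.var("x", shape=(3, 4), dtype="float32")
y = x.astype("int32")        # sugar for the new _make.dtype_cast / cast op

data = relay.var("data", shape=(1, 2, 3))
bn = relay.nn.batch_norm(data, relay.var("g"), relay.var("b"),
                         relay.var("mm"), relay.var("mv"))
norm = bn[0]          # indexing the wrapper yields a TupleGetItem expression
tup = bn.astuple()    # the underlying relay.Tuple, e.g. for ir_pass.infer_type
# bn.astype("int32") raises TypeError: astype cannot be used on tuple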
@@ -40,7 +40,7 @@ def well_formed(expr):
    Returns
    -------
    well_form : bool
-        whether the input expression is well formed
+        Whether the input expression is well formed
    """
    return _ir_pass.well_formed(expr)
@@ -75,20 +75,26 @@ def check_kind(t, env=None):
    return _ir_pass.check_kind(t)

-def free_vars(e):
-    """Get free variables from expression e.
+def free_vars(expr):
+    """Get free Vars from expression expr in Post DFS order.

    Parameters
    ----------
-    e: tvm.relay.Expr
+    expr: tvm.relay.Expr
        The input expression

    Returns
    -------
    free : List[tvm.relay.Var]
-        The list of free variables
+        The list of free variables in post DFS order.

    Note
    ----
    The fact that Vars are post-DFS ordered is useful in
    neural networks: usually this means weights of previous
    layers are ordered first.
    """
-    return _ir_pass.free_vars(e)
+    return _ir_pass.free_vars(expr)

def free_type_vars(expr):
...
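A minimal sketch of the post-DFS ordering the docstring above describes (a two-variable toy network; names are illustrative). Arguments are visited before the call that consumes them, so vars of earlier layers come first:

from tvm import relay

x = relay.var("x", shape=(1, 10))
w = relay.var("w")
y = relay.nn.dense(x, w, units=10)
print([v.name_hint for v in relay.ir_pass.free_vars(y)])  # ['x', 'w']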
"""Utilities for testing and benchmarks"""
from __future__ import absolute_import as _abs
from . import mlp
from . import resnet
"""Initializer of parameters."""
import tvm
from tvm import relay
import numpy as np
class Initializer(object):
"""The base class of an initializer."""
def __init__(self, **kwargs):
self._kwargs = kwargs
def __call__(self, desc, arr):
"""Initialize an array
Parameters
----------
desc : str
Initialization pattern descriptor.
arr : NDArray
The array to be initialized.
"""
if desc.endswith('weight'):
self._init_weight(desc, arr)
elif desc.endswith('bias'):
self._init_bias(desc, arr)
elif desc.endswith('gamma'):
self._init_gamma(desc, arr)
elif desc.endswith('beta'):
self._init_beta(desc, arr)
elif desc.endswith('mean'):
self._init_mean(desc, arr)
elif desc.endswith('var'):
self._init_var(desc, arr)
else:
self._init_default(desc, arr)
def _init_bias(self, _, arr):
arr[:] = 0.0
def _init_gamma(self, _, arr):
arr[:] = 1.0
def _init_beta(self, _, arr):
arr[:] = 0.0
def _init_mean(self, _, arr):
arr[:] = 0.0
def _init_var(self, _, arr):
arr[:] = 1.0
def _init_weight(self, name, arr):
"""Abstract method to Initialize weight."""
raise NotImplementedError("Must override it")
def _init_default(self, name, _):
raise ValueError(
'Unknown initialization pattern for %s. '
'Default initialization is now limited to '
'"weight", "bias", "gamma" (1.0), and "beta" (0.0). '
'Please pass an explicit Initializer to use other patterns.' % name)
class Xavier(Initializer):
""" "Xavier" initialization for weights
Parameters
----------
rnd_type: str, optional
Random generator type, can be ``'gaussian'`` or ``'uniform'``.
factor_type: str, optional
Can be ``'avg'``, ``'in'``, or ``'out'``.
magnitude: float, optional
Scale of random number.
"""
def __init__(self, rnd_type="uniform", factor_type="avg", magnitude=3):
super(Xavier, self).__init__(rnd_type=rnd_type,
factor_type=factor_type,
magnitude=magnitude)
self.rnd_type = rnd_type
self.factor_type = factor_type
self.magnitude = float(magnitude)
def _init_weight(self, name, arr):
shape = arr.shape
hw_scale = 1.
if len(shape) < 2:
raise ValueError('Xavier initializer cannot be applied to vector {0}. It requires at'
' least 2D.'.format(name))
if len(shape) > 2:
hw_scale = np.prod(shape[2:])
fan_in, fan_out = shape[1] * hw_scale, shape[0] * hw_scale
factor = 1.
if self.factor_type == "avg":
factor = (fan_in + fan_out) / 2.0
elif self.factor_type == "in":
factor = fan_in
elif self.factor_type == "out":
factor = fan_out
else:
raise ValueError("Incorrect factor type")
# Hack for mobilenet, because there is less connectivity
if "depthwise" in name:
factor = 3 * 3
scale = np.sqrt(self.magnitude / factor)
if self.rnd_type == "uniform":
arr[:] = np.random.uniform(-scale, scale, size=arr.shape)
else:
raise ValueError("Unknown random type")
def create_workload(net, initializer=None, seed=0):
"""Helper function to create benchmark image classification workload.
Parameters
----------
net : tvm.relay.Function
The selected function of the network.
initializer : Initializer
The initializer used
seed : int
The seed used in initialization.
Returns
-------
net : tvm.relay.Function
The updated dataflow
params : dict of str to NDArray
The parameters.
"""
net = relay.ir_pass.infer_type(net)
shape_dict = {
v.name_hint : v.checked_type for v in net.params}
net.astext()
np.random.seed(seed)
initializer = initializer if initializer else Xavier()
params = {}
for k, v in shape_dict.items():
if k == "data":
continue
init_value = np.zeros(v.concrete_shape).astype(v.dtype)
initializer(k, init_value)
params[k] = tvm.nd.array(init_value, ctx=tvm.cpu(0))
return net, params
"""Simple Layer DSL wrapper to ease creation of neural nets."""
from tvm import relay
def batch_norm_infer(data,
gamma=None,
beta=None,
moving_mean=None,
moving_var=None,
**kwargs):
"""Wrapper of batch_norm.
This function automatically creates weights and return
the first output(normalized result).
Parameters
----------
data : relay.Expr
The input expression.
gamma : relay.Expr
The gamma scale factor.
beta : relay.Expr
The beta offset factor.
moving_mean : relay.Expr
Running mean of input.
moving_var : relay.Expr
Running variance of input.
kwargs : dict
Additional arguments.
Returns
-------
result : relay.Expr
The result.
"""
name = kwargs.get("name")
kwargs.pop("name")
if not gamma:
gamma = relay.var(name + "_gamma")
if not beta:
beta = relay.var(name + "_beta")
if not moving_mean:
moving_mean = relay.var(name + "_moving_mean")
if not moving_var:
moving_var = relay.var(name + "_moving_var")
return relay.nn.batch_norm(data,
gamma=gamma,
beta=beta,
moving_mean=moving_mean,
moving_var=moving_var,
**kwargs)[0]
def conv2d(data, weight=None, **kwargs):
"""Wrapper of conv2d which automatically creates weights if not given.
Parameters
----------
data : relay.Expr
The input expression.
weight : relay.Expr
The weight to conv2d.
kwargs : dict
Additional arguments.
Returns
-------
result : relay.Expr
The result.
"""
name = kwargs.get("name")
kwargs.pop("name")
if not weight:
weight = relay.var(name + "_weight")
return relay.nn.conv2d(data, weight, **kwargs)
def dense_add_bias(data, weight=None, bias=None, **kwargs):
"""Wrapper of dense which automatically creates weights if not given.
Parameters
----------
data : relay.Expr
The input expression.
weight : relay.Expr
The weight to dense.
bias : relay.Expr
The bias.
kwargs : dict
Additional arguments.
Returns
-------
result : relay.Expr
The result.
"""
name = kwargs.get("name")
kwargs.pop("name")
if not weight:
weight = relay.var(name + "_weight")
if not bias:
bias = relay.var(name + "_bias")
data = relay.nn.dense(data, weight, **kwargs)
data = relay.nn.bias_add(data, bias)
return data
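A small sketch of how these wrappers compose (mirroring their use in resnet.py below); shapes and names are only illustrative:

from tvm import relay
from tvm.relay.testing import layers

data = relay.var("data", shape=(1, 3, 224, 224))
body = layers.conv2d(data=data, channels=64, kernel_size=(7, 7),
                     strides=(2, 2), padding=(3, 3), name="conv0")
body = layers.batch_norm_infer(data=body, epsilon=2e-5, name="bn0")
body = relay.nn.relu(body)
out = layers.dense_add_bias(relay.nn.batch_flatten(body), units=10, name="fc")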
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
a simple multilayer perceptron
"""
from tvm import relay
from .init import create_workload
def get_net(batch_size,
num_classes=10,
image_shape=(1, 28, 28),
dtype="float32"):
"""Get network a simple multilayer perceptron.
batch_size : int
The batch size used in the model
num_classes : int, optional
Number of claseses
image_shape : tuple, optional
The input image shape
dtype : str, optional
The data type
Returns
-------
net : relay.Function
The dataflow.
"""
data_shape = (batch_size,) + image_shape
data = relay.var("data",
shape=data_shape,
dtype=dtype)
data = relay.nn.batch_flatten(data)
fc1 = relay.nn.dense(data, relay.var("fc1_weight"), units=128)
fc1 = relay.nn.bias_add(fc1, relay.var("fc1_bias"))
act1 = relay.nn.relu(fc1)
fc2 = relay.nn.dense(act1, relay.var("fc2_weight"), units=64)
fc2 = relay.nn.bias_add(fc2, relay.var("fc2_bias"))
act2 = relay.nn.relu(fc2)
fc3 = relay.nn.dense(act2, relay.var("fc3_weight"), units=num_classes)
fc3 = relay.nn.bias_add(fc3, relay.var("fc3_bias"))
mlp = relay.nn.softmax(data=fc3)
args = relay.ir_pass.free_vars(mlp)
return relay.Function(args, mlp)
def get_workload(batch_size,
num_classes=10,
image_shape=(1, 28, 28),
dtype="float32"):
"""Get benchmark workload for a simple multilayer perceptron.
Parameters
----------
batch_size : int
The batch size used in the model
num_classes : int, optional
Number of classes
image_shape : tuple, optional
The input image shape
dtype : str, optional
The data type
Returns
-------
net : relay.Function
The dataflow.
params : dict of str to NDArray
The parameters.
"""
net = get_net(batch_size, num_classes, image_shape, dtype)
return create_workload(net)
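Usage sketch for the workload helper above; the parameter names in the comment assume the fc1_bias naming used in get_net:

from tvm.relay.testing import mlp

net, params = mlp.get_workload(batch_size=1)
print(net.astext())           # text form of the typed dataflow
print(sorted(params.keys()))  # fc1_weight, fc1_bias, fc2_weight, ...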
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Adapted from https://github.com/tornadomeet/ResNet/blob/master/symbol_resnet.py
Original author Wei Wu
Implemented the following paper:
Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. "Identity Mappings in Deep Residual Networks"
"""
# pylint: disable=unused-argument
from tvm import relay
from .init import create_workload
from . import layers
def residual_unit(data,
num_filter,
stride,
dim_match,
name,
bottle_neck=True):
"""Return ResNet Unit symbol for building ResNet
Parameters
----------
data : relay.Expr
Input data
num_filter : int
Number of output channels
stride : tuple
Stride used in convolution
dim_match : bool
True if the number of channels of the input and output match,
otherwise False
name : str
Base name of the operators
"""
if bottle_neck:
bn1 = layers.batch_norm_infer(data=data,
epsilon=2e-5,
name=name + '_bn1')
act1 = relay.nn.relu(data=bn1)
conv1 = layers.conv2d(
data=act1,
channels=int(num_filter*0.25),
kernel_size=(1, 1),
strides=stride,
padding=(0, 0),
name=name + '_conv1')
bn2 = layers.batch_norm_infer(data=conv1, epsilon=2e-5, name=name + '_bn2')
act2 = relay.nn.relu(data=bn2)
conv2 = layers.conv2d(
data=act2, channels=int(num_filter*0.25), kernel_size=(3, 3),
strides=(1, 1), padding=(1, 1), name=name + '_conv2')
bn3 = layers.batch_norm_infer(data=conv2, epsilon=2e-5, name=name + '_bn3')
act3 = relay.nn.relu(data=bn3)
conv3 = layers.conv2d(
data=act3, channels=num_filter, kernel_size=(1, 1),
strides=(1, 1), padding=(0, 0), name=name + '_conv3')
if dim_match:
shortcut = data
else:
shortcut = layers.conv2d(
data=act1, channels=num_filter, kernel_size=(1, 1),
strides=stride, name=name+'_sc')
return relay.add(conv3, shortcut)
else:
bn1 = layers.batch_norm_infer(data=data, epsilon=2e-5, name=name + '_bn1')
act1 = relay.nn.relu(data=bn1)
conv1 = layers.conv2d(
data=act1, channels=num_filter, kernel_size=(3, 3),
strides=stride, padding=(1, 1), name=name + '_conv1')
bn2 = layers.batch_norm_infer(data=conv1, epsilon=2e-5, name=name + '_bn2')
act2 = relay.nn.relu(data=bn2)
conv2 = layers.conv2d(
data=act2, channels=num_filter, kernel_size=(3, 3),
strides=(1, 1), padding=(1, 1), name=name + '_conv2')
if dim_match:
shortcut = data
else:
shortcut = layers.conv2d(
data=act1, channels=num_filter, kernel_size=(1, 1),
strides=stride, name=name+'_sc')
return relay.add(conv2, shortcut)
def resnet(units,
num_stages,
filter_list,
num_classes,
data_shape,
bottle_neck=True,
dtype="float32"):
"""Return ResNet Program.
Parameters
----------
units : list
Number of units in each stage
num_stages : int
Number of stage
filter_list : list
Channel size of each stage
num_classes : int
Output size of the network
data_shape : tuple of int.
The shape of input data.
bottle_neck : bool
Whether apply bottleneck transformation.
dtype : str
The global data type.
"""
num_unit = len(units)
assert num_unit == num_stages
data = relay.var("data", shape=data_shape, dtype=dtype)
data = layers.batch_norm_infer(data=data, epsilon=2e-5, scale=False, name='bn_data')
(_, _, height, _) = data_shape
if height <= 32: # such as cifar10
body = layers.conv2d(
data=data, channels=filter_list[0], kernel_size=(3, 3),
strides=(1, 1), padding=(1, 1), name="conv0")
else: # often expected to be 224 such as imagenet
body = layers.conv2d(
data=data, channels=filter_list[0], kernel_size=(7, 7),
strides=(2, 2), padding=(3, 3), name="conv0")
body = layers.batch_norm_infer(data=body, epsilon=2e-5, name='bn0')
body = relay.nn.relu(data=body)
body = relay.nn.max_pool2d(data=body, pool_size=(3, 3), strides=(2, 2), padding=(1, 1))
for i in range(num_stages):
body = residual_unit(
body, filter_list[i+1], (1 if i == 0 else 2, 1 if i == 0 else 2),
False, name='stage%d_unit%d' % (i + 1, 1), bottle_neck=bottle_neck)
for j in range(units[i]-1):
body = residual_unit(
body, filter_list[i+1], (1, 1), True,
name='stage%d_unit%d' % (i + 1, j + 2), bottle_neck=bottle_neck)
bn1 = layers.batch_norm_infer(data=body, epsilon=2e-5, name='bn1')
relu1 = relay.nn.relu(data=bn1)
# global_avg_pool2d reduces each channel of the feature map to a single value
pool1 = relay.nn.global_avg_pool2d(data=relu1)
flat = relay.nn.batch_flatten(data=pool1)
fc1 = layers.dense_add_bias(data=flat, units=num_classes, name='fc1')
net = relay.nn.softmax(data=fc1)
return relay.Function(relay.ir_pass.free_vars(net), net)
def get_net(batch_size,
num_classes,
num_layers=50,
image_shape=(3, 224, 224),
dtype="float32",
**kwargs):
"""
Adapted from https://github.com/tornadomeet/ResNet/blob/master/train_resnet.py
Original author Wei Wu
"""
(_, height, _) = image_shape
data_shape = (batch_size,) + image_shape
if height <= 28:
num_stages = 3
if (num_layers-2) % 9 == 0 and num_layers >= 164:
per_unit = [(num_layers-2)//9]
filter_list = [16, 64, 128, 256]
bottle_neck = True
elif (num_layers-2) % 6 == 0 and num_layers < 164:
per_unit = [(num_layers-2)//6]
filter_list = [16, 16, 32, 64]
bottle_neck = False
else:
raise ValueError("no experiments done on num_layers {}".format(num_layers))
units = per_unit * num_stages
else:
if num_layers >= 50:
filter_list = [64, 256, 512, 1024, 2048]
bottle_neck = True
else:
filter_list = [64, 64, 128, 256, 512]
bottle_neck = False
num_stages = 4
if num_layers == 18:
units = [2, 2, 2, 2]
elif num_layers == 34:
units = [3, 4, 6, 3]
elif num_layers == 50:
units = [3, 4, 6, 3]
elif num_layers == 101:
units = [3, 4, 23, 3]
elif num_layers == 152:
units = [3, 8, 36, 3]
elif num_layers == 200:
units = [3, 24, 36, 3]
elif num_layers == 269:
units = [3, 30, 48, 8]
else:
raise ValueError("no experiments done on num_layers {}".format(num_layers))
return resnet(units=units,
num_stages=num_stages,
filter_list=filter_list,
num_classes=num_classes,
data_shape=data_shape,
bottle_neck=bottle_neck,
dtype=dtype)
def get_workload(batch_size=1,
num_classes=1000,
num_layers=18,
image_shape=(3, 224, 224),
dtype="float32",
**kwargs):
"""Get benchmark workload for resnet
Parameters
----------
batch_size : int
The batch size used in the model
num_classes : int, optional
Number of classes
num_layers : int, optional
Number of layers
image_shape : tuple, optional
The input image shape
dtype : str, optional
The data type
kwargs : dict
Extra arguments
Returns
-------
net : relay.Function
The computational graph
params : dict of str to NDArray
The parameters.
"""
net = get_net(batch_size=batch_size,
num_classes=num_classes,
num_layers=num_layers,
image_shape=image_shape,
dtype=dtype,
**kwargs)
return create_workload(net)
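Usage sketch, mirroring the new text-printer test further below:

from tvm.relay.testing import resnet

# ResNet-18 on a (3, 224, 224) input
net, params = resnet.get_workload(batch_size=1, num_layers=18)
net.astext()   # prints the typed dataflow without error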
@@ -47,6 +47,21 @@ class TensorType(Type):
        self.__init_handle_by_constructor__(
            _make.TensorType, shape, dtype)

    @property
    def concrete_shape(self):
        """Get shape of the type as concrete tuple of int.

        Returns
        -------
        shape : List[int]
            The concrete shape of the Type.

        Raises
        ------
        TypeError : If the shape is symbolic
        """
        return tuple(int(x) for x in self.shape)

class Kind(IntEnum):
    """The kind of a type parameter, represents a variable shape,
...
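A brief sketch of the new property (it relies on the IntImm.__int__ added earlier in this commit); the symbolic case is only described in the comment:

import tvm
from tvm import relay

t = relay.TensorType((3, 224, 224), "float32")
assert t.concrete_shape == (3, 224, 224)

n = tvm.var("n")
sym = relay.TensorType((n, 4), "float32")
# sym.concrete_shape raises TypeError because the first dimension is symbolic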
@@ -159,6 +159,13 @@ Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) {

Type ExprMutator::VisitType(const Type& t) { return t; }

void ExprVisitor::VisitExpr(const Expr& expr) {
  if (visited_.count(expr.get())) return;
  using TParent = ExprFunctor<void(const Expr&)>;
  TParent::VisitExpr(expr);
  visited_.insert(expr.get());
}

void ExprVisitor::ExprVisitor::VisitExpr_(const VarNode* op) {
  if (op->type_annotation.defined()) {
    this->VisitType(op->type_annotation);
@@ -197,8 +204,8 @@ void ExprVisitor::VisitExpr_(const CallNode* op) {
}

void ExprVisitor::VisitExpr_(const LetNode* op) {
-  this->VisitExpr(op->var);
  this->VisitExpr(op->value);
+  this->VisitExpr(op->var);
  this->VisitExpr(op->body);
}
...
@@ -63,7 +63,7 @@ inline std::ostream& operator<<(std::ostream& os, const TextValue& val) {  // NOLINT(*)
 *
 * \code
 *
- * function(%x: Tensor[(meta.Variable(id=0),), float32]) {
+ * fn (%x: Tensor[(meta.Variable(id=0),), float32]) {
 *   %x
 * }
 * # Meta data section is a json-serialized string
@@ -154,7 +154,7 @@ class TextPrinter :
  }

  void PrintFunc(const Function& func) {
-    this->PrintFuncInternal("function", func);
+    this->PrintFuncInternal("fn ", func);
    stream_ << "\n";
  }
@@ -343,7 +343,7 @@ class TextPrinter :
    TextValue tuple = GetValue(op->tuple);
    TextValue id = this->AllocTempVar();
    this->PrintIndent();
-    stream_ << id << " = " << tuple << "[" << op->index << "]";
+    stream_ << id << " = " << tuple << "." << op->index << "";
    this->PrintEndInst("\n");
    return id;
  }
@@ -379,6 +379,17 @@ class TextPrinter :
    os << "), " << runtime::TVMType2String(Type2TVMType(node->dtype)) << "]";
  }

  void VisitType_(const TupleTypeNode* node, std::ostream& os) final {  // NOLINT(*)
    os << "Tuple[";
    for (size_t i = 0; i < node->fields.size(); ++i) {
      this->PrintType(node->fields[i], os);
      if (i + 1 != node->fields.size()) {
        os << ", ";
      }
    }
    os << "]";
  }

  void VisitTypeDefault_(const Node* node, std::ostream& os) final {  // NOLINT(*)
    // by default always print as meta-data
    os << meta_.GetMetaNode(GetRef<NodeRef>(node));
...
@@ -96,40 +96,41 @@ class TypeFunctor<R(const Type& n, Args...)> {
 *
 * We recursively visit each type contained inside the visitor.
 */
-template <typename... Args>
-struct TypeVisitor : ::tvm::relay::TypeFunctor<void(const Type& n, Args...)> {
-  void VisitType_(const TypeVarNode* op, Args... args) override {}
+class TypeVisitor :
+    public ::tvm::relay::TypeFunctor<void(const Type& n)> {
+ public:
+  void VisitType_(const TypeVarNode* op) override {}

-  void VisitType_(const FuncTypeNode* op, Args... args) override {
+  void VisitType_(const FuncTypeNode* op) override {
    for (auto type_param : op->type_params) {
-      this->VisitType(type_param, std::forward<Args>(args)...);
+      this->VisitType(type_param);
    }
    for (auto type_cs : op->type_constraints) {
-      this->VisitType(type_cs, std::forward<Args>(args)...);
+      this->VisitType(type_cs);
    }
    for (auto arg_type : op->arg_types) {
-      this->VisitType(arg_type, std::forward<Args>(args)...);
+      this->VisitType(arg_type);
    }
-    this->VisitType(op->ret_type, std::forward<Args>(args)...);
+    this->VisitType(op->ret_type);
  }

-  void VisitType_(const TensorTypeNode* op, Args... args) override {}
+  void VisitType_(const TensorTypeNode* op) override {}

-  void VisitType_(const TupleTypeNode* op, Args... args) override {
+  void VisitType_(const TupleTypeNode* op) override {
    for (const Type& t : op->fields) {
-      this->VisitType(t, std::forward<Args>(args)...);
+      this->VisitType(t);
    }
  }

-  void VisitType_(const TypeRelationNode* op, Args... args) override {
+  void VisitType_(const TypeRelationNode* op) override {
    for (const Type& t : op->args) {
-      this->VisitType(t, std::forward<Args>(args)...);
+      this->VisitType(t);
    }
  }

-  void VisitType_(const IncompleteTypeNode* op, Args... args) override {}
+  void VisitType_(const IncompleteTypeNode* op) override {}
};

// A functional visitor for rebuilding an AST in place.
...
@@ -15,6 +15,62 @@
namespace tvm {
namespace relay {
TVM_REGISTER_NODE_TYPE(BiasAddAttrs);
bool BiasAddRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;
const BiasAddAttrs* param = attrs.as<BiasAddAttrs>();
CHECK(param != nullptr);
int axis = param->axis;
if (axis < 0) {
axis = data->shape.size() + axis;
}
CHECK_LE(axis, static_cast<int>(data->shape.size()))
<< "axis " << param->axis << " is out of range";
// assign output type
reporter->Assign(types[1], TensorTypeNode::make(
{data->shape[axis]}, data->dtype));
reporter->Assign(types[2], types[0]);
return true;
}
// Positional relay function to create bias_add operator used by frontend FFI.
Expr MakeBiasAdd(Expr data,
Expr bias,
int axis) {
auto attrs = make_node<BiasAddAttrs>();
attrs->axis = axis;
static const Op& op = Op::Get("nn.bias_add");
return CallNode::make(op, {data, bias}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op.nn._make.bias_add")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 3>(MakeBiasAdd, args, rv);
});
RELAY_REGISTER_OP("nn.bias_add")
.describe(R"code(Add bias to an axis of the input.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.BiasAddAttrs")
.set_num_inputs(2)
.add_argument("data", "nD Tensor", "Input data.")
.add_argument("bias", "1D Tensor", "Bias.")
.set_support_level(1)
.add_type_rel("BiasAdd", BiasAddRel);
TVM_REGISTER_NODE_TYPE(DenseAttrs);
@@ -82,7 +138,7 @@ RELAY_REGISTER_OP("nn.dense")
.set_num_inputs(2)
.add_argument("data", "nD Tensor", "Input data.")
.add_argument("weight", "2D Tensor", "Weight matrix.")
-.set_support_level(2)
+.set_support_level(1)
.add_type_rel("Dense", DenseRel);
@@ -235,13 +291,23 @@ Example::
.set_support_level(2)
.add_type_rel("BatchFlatten", BatchFlattenRel);
-RELAY_REGISTER_UNARY_OP("relay.op.nn._make.", "relu")
+// relu
+TVM_REGISTER_API("relay.op.nn._make.relu")
+.set_body_typed<Expr(Expr)>([](Expr data) {
+    static const Op& op = Op::Get("nn.relu");
+    return CallNode::make(op, {data}, Attrs(), {});
+  });
+
+RELAY_REGISTER_OP("nn.relu")
.describe(R"code(Returns the relu input array, computed element-wise.
.. math::
   max(x, 0)
)code" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(1)
.add_type_rel("Identity", IdentityRel);
@@ -371,24 +437,6 @@ The whole array is rescaled by ``1/(1-p)`` to keep the expected sum of the input
// batch_norm
TVM_REGISTER_NODE_TYPE(BatchNormAttrs);

-bool CheckVectorLength(int64_t dim, const DataType& dtype, Type vector, const char* name) {
-  const auto* candidate = vector.as<TensorTypeNode>();
-  CHECK(candidate != nullptr)
-    << name << " should be a vector but is not a tensor type,";
-  CHECK_EQ(dtype, candidate->dtype)
-    << name << " should be of the same data type as the original but it is not.";
-  CHECK_EQ(candidate->shape.size(), 1)
-    << name << " should be a vector but has a shape of "
-    << candidate->shape.size() << " dimensions instead of 1.";
-  const int64_t* length = as_const_int(candidate->shape[0]);
-  if (length == nullptr) return false;
-  CHECK(*length == dim)
-    << name << " should be as long as the channel but has length "
-    << *length << " instead of " << dim << ".";
-  return true;
-}
-
bool BatchNormRel(const Array<Type>& types,
                  int num_inputs,
                  const Attrs& attrs,
@@ -396,33 +444,19 @@ bool BatchNormRel(const Array<Type>& types,
  CHECK_EQ(types.size(), 6);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) return false;
+  if (data->shape.size() == 0) return false;
  const BatchNormAttrs* param = attrs.as<BatchNormAttrs>();

  // axis of -1 means use the last dimension
  CHECK(param->axis >= -1 && param->axis < (int)data->shape.size());
  int axis = (param->axis != -1) ? param->axis : data->shape.size() - 1;
-  auto dim = as_const_int(data->shape[axis]);
-  if (dim == nullptr) return false;
+  auto axis_size = data->shape[axis];

  // if we are using beta and gamma, they need to be of shape (dim,)
-  if (param->scale && !CheckVectorLength(*dim, data->dtype, types[1], "The gamma scale factor")) {
-    return false;
-  }
-  if (param->center && !CheckVectorLength(*dim, data->dtype, types[2], "The beta offset factor")) {
-    return false;
-  }
-  // the two running averages must also be vectors of length dim
-  if (!CheckVectorLength(*dim, data->dtype, types[3], "The moving mean")) {
-    return false;
-  }
-  if (!CheckVectorLength(*dim, data->dtype, types[4], "The moving variance")) {
-    return false;
-  }
+  reporter->Assign(types[1], TensorTypeNode::make({axis_size}, data->dtype));
+  reporter->Assign(types[2], TensorTypeNode::make({axis_size}, data->dtype));
+  reporter->Assign(types[3], TensorTypeNode::make({axis_size}, data->dtype));
+  reporter->Assign(types[4], TensorTypeNode::make({axis_size}, data->dtype));

  // output is a tuple of the normed data (same shape as input), new running mean,
  // and new running average (the latter two are both vectors of length dim)
...
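The rewritten BatchNormRel assigns the gamma/beta/moving-stat types from the chosen axis instead of checking user-provided shapes, so the Python-side vars no longer need annotations. A hedged sketch with illustrative shapes:

from tvm import relay

data = relay.var("data", shape=(1, 16, 32, 32))
bn = relay.nn.batch_norm(data, relay.var("gamma"), relay.var("beta"),
                         relay.var("mean"), relay.var("var"))
yy = relay.ir_pass.infer_type(bn.astuple())
# gamma, beta, mean and var are all inferred as Tensor[(16,), float32] from axis=1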
@@ -13,8 +13,52 @@
namespace tvm {
namespace relay {
-/* relay.expand_dims */
+// relay.cast
TVM_REGISTER_NODE_TYPE(CastAttrs);
bool CastRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) {
CHECK(types[0].as<IncompleteTypeNode>())
<< "cast: expect input type to be TensorType but get "
<< types[0];
return false;
}
const auto* param = attrs.as<CastAttrs>();
reporter->Assign(types[1], TensorTypeNode::make(
data->shape, param->dtype));
return true;
}
Expr MakeCast(Expr data,
DataType dtype) {
auto attrs = make_node<CastAttrs>();
attrs->dtype = dtype;
static const Op& op = Op::Get("cast");
return CallNode::make(op, {data}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay._make.dtype_cast")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeCast, args, rv);
});
RELAY_REGISTER_OP("cast")
.describe(R"code(Cast the data into a new data type.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_attrs_type_key("relay.attrs.CastAttrs")
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(3)
.add_type_rel("Cast", CastRel);
// relay.expand_dims
TVM_REGISTER_NODE_TYPE(ExpandDimsAttrs);

bool ExpandDimsRel(const Array<Type>& types,
@@ -25,6 +69,9 @@ bool ExpandDimsRel(const Array<Type>& types,
  CHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) {
CHECK(types[0].as<IncompleteTypeNode>())
<< "expand_dims: expect input type to be TensorType but get "
<< types[0];
    return false;
  }
  const auto* param = attrs.as<ExpandDimsAttrs>();
@@ -91,6 +138,9 @@ bool ConcatenateRel(const Array<Type>& types,
  CHECK_EQ(types.size(), 2);
  const auto* tensor_tuple = types[0].as<TupleTypeNode>();
  if (tensor_tuple == nullptr) {
CHECK(types[0].as<TupleTypeNode>())
<< "cast: expect input type to be TupleType but get "
<< types[0];
    return false;
  }
  const auto* param = attrs.as<ConcatenateAttrs>();
@@ -161,6 +211,9 @@ bool TransposeRel(const Array<Type>& types,
  CHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) {
CHECK(types[0].as<IncompleteTypeNode>())
<< "transpose: expect input type to be TensorType but get "
<< types[0];
    return false;
  }
  const auto* param = attrs.as<TransposeAttrs>();
@@ -243,6 +296,9 @@ bool ReshapeRel(const Array<Type>& types,
  CHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) {
CHECK(types[0].as<IncompleteTypeNode>())
<< "reshape: expect input type to be TensorType but get "
<< types[0];
    return false;
  }
  const auto* param = attrs.as<ReshapeAttrs>();
...
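The Python-facing effect of the cast registration above, mirroring the new test_cast added further below:

from tvm import relay

x = relay.var("x", shape=(8, 9, 4), dtype="float32")
y = x.astype("int32")                    # lowers to the "cast" op
yy = relay.ir_pass.infer_type(y)
assert yy.checked_type == relay.TensorType((8, 9, 4), "int32")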
@@ -22,7 +22,7 @@ namespace relay {
using namespace tvm::runtime;
using Kind = TypeVarNode::Kind;

-struct KindChecker : TypeVisitor<> {
+struct KindChecker : TypeVisitor {
  bool valid;
  KindChecker() : valid(true) {}
...
@@ -471,6 +471,5 @@ TVM_REGISTER_API("relay._ir_pass.infer_type")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = InferType(args[0], args[1]);
  });

}  // namespace relay
}  // namespace tvm
@@ -12,107 +12,120 @@
namespace tvm {
namespace relay {

-class FreeVar;
-class FreeTypeVar : private TypeVisitor<> {
-  std::unordered_set<TypeVar, NodeHash, NodeEqual>* free_vars;
-  std::unordered_set<TypeVar, NodeHash, NodeEqual>* bound_vars;
-  FreeTypeVar(std::unordered_set<TypeVar, NodeHash, NodeEqual>* free_vars,
-              std::unordered_set<TypeVar, NodeHash, NodeEqual>* bound_vars) :
-    free_vars(free_vars), bound_vars(bound_vars) { }
+// FreeTypeVar
+class FreeTypeVarTVisitor : public TypeVisitor {
+ public:
+  FreeTypeVarTVisitor(
+      Array<TypeVar>* free_vars,
+      std::unordered_set<TypeVar, NodeHash, NodeEqual>* bound_vars)
+      : free_vars_(free_vars), bound_vars_(bound_vars) { }

  void VisitType_(const TypeVarNode* tp) final {
-    auto var = GetRef<TypeVar>(tp);
-    if (bound_vars->count(var) == 0) {
-      free_vars->insert(var);
+    TypeVar var = GetRef<TypeVar>(tp);
+    if (bound_vars_->count(var) == 0) {
+      free_vars_->push_back(var);
    }
  }

  void VisitType_(const FuncTypeNode* f) final {
    for (auto type_param : f->type_params) {
-      bound_vars->insert(type_param);
+      bound_vars_->insert(type_param);
    }
-    for (auto type_cs : f->type_constraints) {
-      this->VisitType(type_cs);
-    }
-    for (auto arg_type : f->arg_types) {
-      this->VisitType(arg_type);
-    }
-    this->VisitType(f->ret_type);
+    TypeVisitor::VisitType_(f);
  }
-  friend FreeVar;
+
+ private:
+  Array<TypeVar>* free_vars_;
+  std::unordered_set<TypeVar, NodeHash, NodeEqual>* bound_vars_;
};

-class FreeVar : public ExprVisitor {
-  void VisitExpr_(const VarNode* v) final {
-    auto var = GetRef<Var>(v);
-    if (bound_vars.count(var) == 0) {
-      free_vars.insert(var);
-    }
-    if (v->type_annotation.defined()) {
-      VisitType(v->type_annotation);
-    }
-  }
+class FreeTypeVarEVisitor : private ExprVisitor {
+ public:
+  Array<TypeVar> Find(const Expr& expr) {
+    this->VisitExpr(expr);
+    return free_vars_;
+  }
+
+  Array<TypeVar> Find(const Type& type) {
+    this->VisitType(type);
+    return free_vars_;
+  }

  void VisitExpr_(const FunctionNode* f) final {
    for (const auto& tp : f->type_params) {
-      bound_types.insert(tp);
-    }
-    for (const auto& param : f->params) {
-      bound_vars.insert(param);
+      bound_vars_.insert(tp);
    }
-    VisitExpr(f->body);
-    VisitType(f->ret_type);
+    ExprVisitor::VisitExpr_(f);
  }

-  void VisitExpr_(const LetNode* l) final {
-    bound_vars.insert(l->var);
-    VisitExpr(l->value);
-    VisitExpr(l->body);
-  }
-
- public:
-  std::unordered_set<Var, NodeHash, NodeEqual> free_vars;
-  std::unordered_set<Var, NodeHash, NodeEqual> bound_vars;
-  std::unordered_set<TypeVar, NodeHash, NodeEqual> free_types;
-  std::unordered_set<TypeVar, NodeHash, NodeEqual> bound_types;
-
  void VisitType(const Type& t) final {
-    FreeTypeVar(&free_types, &bound_types)(t);
+    FreeTypeVarTVisitor(&free_vars_, &bound_vars_)
+        .VisitType(t);
  }
+
+ private:
+  // The result list
+  Array<TypeVar> free_vars_;
+  std::unordered_set<TypeVar, NodeHash, NodeEqual> bound_vars_;
+};
+
+class FreeVarVisitor : protected ExprVisitor {
+ public:
+  Array<Var> Find(const Expr& expr) {
+    this->VisitExpr(expr);
+    return free_vars_;
+  }
+
+  void VisitExpr_(const VarNode* var) final {
+    if (bound_vars_.count(var) == 0) {
+      free_vars_.push_back(GetRef<Var>(var));
+    }
+  }
+
+  void VisitExpr_(const FunctionNode* op) final {
+    for (const auto& param : op->params) {
+      bound_vars_.insert(param.operator->());
+    }
+    VisitExpr(op->body);
+  }
+
+  void VisitExpr_(const LetNode* op) final {
+    bound_vars_.insert(op->var.operator->());
+    VisitExpr(op->value);
+    VisitExpr(op->body);
+  }
+
+ private:
+  // The result list
+  Array<Var> free_vars_;
+  std::unordered_set<const VarNode*> bound_vars_;
};

-tvm::Array<Var> FreeVariables(const Expr& e) {
-  FreeVar fv;
-  fv.VisitExpr(e);
-  return tvm::Array<Var>(fv.free_vars.begin(), fv.free_vars.end());
+tvm::Array<TypeVar> FreeTypeVars(const Expr& expr) {
+  return FreeTypeVarEVisitor().Find(expr);
}

-tvm::Array<TypeVar> FreeTypeVariables(const Expr& e) {
-  FreeVar fv;
-  fv.VisitExpr(e);
-  return tvm::Array<TypeVar>(fv.free_types.begin(), fv.free_types.end());
+tvm::Array<TypeVar> FreeTypeVars(const Type& type) {
+  return FreeTypeVarEVisitor().Find(type);
}

-tvm::Array<TypeVar> FreeTypeVariables(const Type& t) {
-  FreeVar fv;
-  fv.VisitType(t);
-  return tvm::Array<TypeVar>(fv.free_types.begin(), fv.free_types.end());
+tvm::Array<Var> FreeVars(const Expr& expr) {
+  return FreeVarVisitor().Find(expr);
}

TVM_REGISTER_API("relay._ir_pass.free_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = FreeVariables(args[0]);
+    *ret = FreeVars(args[0]);
  });

TVM_REGISTER_API("relay._ir_pass.free_type_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    NodeRef x = args[0];
    if (x.as<TypeNode>()) {
-      *ret = FreeTypeVariables(Downcast<Type>(x));
+      *ret = FreeTypeVars(Downcast<Type>(x));
    } else {
-      *ret = FreeTypeVariables(Downcast<Expr>(x));
+      *ret = FreeTypeVars(Downcast<Expr>(x));
    }
  });
...
@@ -10,7 +10,6 @@
namespace tvm {
namespace relay {

-struct NotWellFormed { };

//! brief make sure each Var is bind at most once.
class WellFormedChecker : private ExprVisitor {
...
import tvm
import tvm.relay.testing
import numpy as np
from tvm import relay

do_print = [False]

def show(text):
@@ -94,9 +96,18 @@ def test_variable_name():
    v1 = relay.var("1")
    assert "%v1" in v1.astext()

def test_mlp():
    net, params = tvm.relay.testing.mlp.get_workload(batch_size=1)
    net.astext()

def test_resnet():
    net, params = tvm.relay.testing.resnet.get_workload(batch_size=1)
    net.astext()

if __name__ == "__main__":
    do_print[0] = True
    test_resnet()
    test_mlp()
    test_func()
    test_env()
    test_meta_data()
...
@@ -12,9 +12,8 @@ def test_well_formed():
    assert not well_formed(relay.Let(x, v, let))
    f = relay.Function([x], x, ty)
    assert well_formed(f)
-    # this test should pass in case of weak uniqueness (only test for shadowing)
-    # but we want all binder to be distinct from each other.
-    assert not well_formed(relay.Let(relay.Var("y"), f,
+    assert well_formed(
+        relay.Let(relay.Var("y"), f,
                  relay.Let(relay.Var("z"), f, v)))
@@ -25,7 +24,7 @@ def test_tuple():
    let = relay.Let(x, v, x)
    assert well_formed(let)
    assert well_formed(relay.Tuple([v, v]))
-    assert not well_formed(relay.Tuple([let, let]))
+    assert not well_formed(relay.Tuple([let, relay.Let(x, v, x)]))

def test_tuple_get_item():
...
@@ -42,6 +42,15 @@ def test_binary_op():
        check_binary_op(opfunc)

def test_bias_add():
    x = relay.var("x", shape=(10, 2, 3, 4))
    bias = relay.var("bias")
    z = relay.nn.bias_add(x, bias)
    zz = relay.ir_pass.infer_type(z)
    assert "axis=" not in zz.astext()
    assert zz.args[1].checked_type == relay.TensorType((2,))

def test_expand_dims_infer_type():
    n, t, d = tvm.var("n"), tvm.var("t"), 100
    x = relay.var("x", shape=(n, t, d))
@@ -91,7 +100,7 @@ def test_dropout():
    n, t, d = tvm.var("n"), tvm.var("t"), tvm.var("d")
    input_ty = relay.TensorType((n, t, d), "float32")
    x = relay.var("x", input_ty)
-    y, _ = relay.nn.dropout(x, rate=0.75)
+    y = relay.nn.dropout(x, rate=0.75)
    assert "rate=" in y.astext()
    yy = relay.ir_pass.infer_type(y)
    assert yy.checked_type == input_ty
@@ -106,7 +115,7 @@ def test_batch_norm():
    moving_var = relay.var("moving_var", relay.TensorType((2,)))
    y = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var,
                            center=False, scale=False)
-    yy = relay.ir_pass.infer_type(y)
+    yy = relay.ir_pass.infer_type(y.astuple())
    assert "center=" in yy.astext()
    assert yy.checked_type == relay.ty.TupleType(tvm.convert([
        relay.TensorType((3, 2, 1), "float32"),
@@ -121,7 +130,7 @@ def test_batch_norm():
    y = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var,
                            axis=0, center=False, scale=False)
-    yy = relay.ir_pass.infer_type(y)
+    yy = relay.ir_pass.infer_type(y.astuple())
    assert yy.checked_type == relay.ty.TupleType(tvm.convert([
        relay.ty.TensorType((3, 2, 1), "float32"),
        relay.ty.TensorType((3,), "float32"),
@@ -136,7 +145,7 @@ def test_batch_norm():
    moving_var = relay.var("moving_var", relay.TensorType((3,)))
    y = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var,
                            axis=-1, center=False, scale=False)
-    yy = relay.ir_pass.infer_type(y)
+    yy = relay.ir_pass.infer_type(y.astuple())
    assert yy.checked_type == relay.ty.TupleType(tvm.convert([
        relay.ty.TensorType((1, 2, 3), "float32"),
        relay.ty.TensorType((3,), "float32"),
@@ -145,6 +154,7 @@ def test_batch_norm():

if __name__ == "__main__":
    test_bias_add()
    test_unary_op()
    test_binary_op()
    test_expand_dims_infer_type()
...
@@ -27,6 +27,14 @@ def test_unary_identity():
    assert yy.checked_type == relay.TensorType((8, 9, 4), "float32")

def test_cast():
    x = relay.var("x", relay.TensorType((8, 9, 4), "float32"))
    y = x.astype("int32")
    yy = relay.ir_pass.infer_type(y)
    assert "dtype=" in yy.astext()
    assert yy.checked_type == relay.TensorType((8, 9, 4), "int32")

def test_clip_type():
    a = relay.var("a", relay.TensorType((10, 4), "float32"))
    y = relay.clip(a, 1., 4.)
@@ -139,7 +147,9 @@ def test_infer_type_leaky_relu():
    yy = relay.ir_pass.infer_type(y)
    assert yy.checked_type == relay.TensorType((n, c, h, w), "float32")

if __name__ == "__main__":
    test_cast()
    test_zeros_ones()
    test_unary_identity()
    test_clip_type()
...