Unverified Commit d4f8d20e by Tianqi Chen Committed by GitHub

[RELAY][FRONTEND] Initial MXNet frontend support. (#2163)

parent 0cf3ddf7
......@@ -82,6 +82,7 @@ This level enables additional math and transform operators.
tvm.relay.reshape_like
tvm.relay.copy
tvm.relay.transpose
tvm.relay.squeeze
tvm.relay.floor
tvm.relay.ceil
tvm.relay.trunc
......@@ -114,7 +115,7 @@ This level enables additional math and transform operators.
tvm.relay.less_equal
tvm.relay.maximum
tvm.relay.minimum
tvm.relay.pow
tvm.relay.power
tvm.relay.where
tvm.relay.argmax
tvm.relay.argmin
......@@ -196,6 +197,7 @@ Level 3 Definitions
.. autofunction:: tvm.relay.reshape
.. autofunction:: tvm.relay.reshape_like
.. autofunction:: tvm.relay.copy
.. autofunction:: tvm.relay.squeeze
.. autofunction:: tvm.relay.transpose
.. autofunction:: tvm.relay.take
.. autofunction:: tvm.relay.zeros
......@@ -220,7 +222,7 @@ Level 4 Definitions
.. autofunction:: tvm.relay.less_equal
.. autofunction:: tvm.relay.maximum
.. autofunction:: tvm.relay.minimum
.. autofunction:: tvm.relay.pow
.. autofunction:: tvm.relay.power
.. autofunction:: tvm.relay.where
.. autofunction:: tvm.relay.argmax
.. autofunction:: tvm.relay.argmin
......
......@@ -89,7 +89,7 @@ struct SoftmaxAttrs : public tvm::AttrsNode<SoftmaxAttrs> {
int axis;
TVM_DECLARE_ATTRS(SoftmaxAttrs, "relay.attrs.SoftmaxAttrs") {
TVM_ATTR_FIELD(axis).set_default(1)
TVM_ATTR_FIELD(axis).set_default(-1)
.describe("The axis to sum over when computing softmax.");
}
};
......
......@@ -62,7 +62,7 @@ struct TransposeAttrs : public tvm::AttrsNode<TransposeAttrs> {
/*! \brief Attributes used in reshape operators */
struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
Array<IndexExpr> newshape;
Array<Integer> newshape;
TVM_DECLARE_ATTRS(ReshapeAttrs, "relay.attrs.ReshapeAttrs") {
TVM_ATTR_FIELD(newshape)
.describe("The new shape. Should be compatible with the original shape.");
......
......@@ -420,9 +420,9 @@ along which to split the array.
return Array<Tensor>{
topi::split_sections(inputs[0], param.indices_or_sections[0], param.axis) };
} else {
Array<Expr> indices;
Array<Integer> indices;
for (auto i : param.indices_or_sections) {
indices.push_back(tvm::make_const(tvm::Int(32), i));
indices.push_back(static_cast<int>(i));
}
return Array<Tensor>{ topi::split(inputs[0], indices, param.axis) };
}
......
......@@ -7,7 +7,7 @@ from . import ty
from . import expr
from . import module
from . import ir_pass
from .build_module import build, create_executor
from .build_module import build, build_config, create_executor
# Root operators
from .op import Op
......@@ -17,6 +17,7 @@ from .op.transform import *
from . import nn
from . import vision
from . import image
from . import frontend
from . import backend
from .scope_builder import ScopeBuilder
......@@ -40,6 +41,7 @@ IncompleteType = ty.IncompleteType
scalar_type = ty.scalar_type
# Expr
Expr = expr.Expr
Constant = expr.Constant
Tuple = expr.Tuple
Var = expr.Var
......
......@@ -72,8 +72,18 @@ class CompileEngine(NodeBase):
cached_func: CachedFunc
The result of lowering.
"""
key = _get_cache_key(source_func, target)
return _backend._CompileEngineLower(self, key)
# pylint: disable=broad-except
try:
key = _get_cache_key(source_func, target)
return _backend._CompileEngineLower(self, key)
except Exception:
import traceback
msg = traceback.format_exc()
msg += "Error during compile func\n"
msg += "--------------------------\n"
msg += source_func.astext(show_meta_data=False)
msg += "--------------------------\n"
raise RuntimeError(msg)
def jit(self, source_func, target=None):
"""JIT a source_func to a tvm.Function.
......
......@@ -357,4 +357,4 @@ class GraphRuntimeCodegen(ExprFunctor):
return name
index = self._name_map[name]
self._name_map[name] += 1
return self.get_unique_name(name + str(index))
return self._get_unique_name(name + str(index))
......@@ -13,7 +13,7 @@ from .backend import graph_runtime_codegen as _graph_gen
# List of optimization pass and level when switch on
OPT_PASS_LEVEL = {
"SimplifyInference": 0,
"CombineParallelConv2D": 1,
"CombineParallelConv2D": 4,
"OpFusion": 1,
"FoldConstant": 2,
"FoldScaleAxis": 3,
......@@ -157,7 +157,6 @@ def optimize(func, params=None):
if cfg.pass_enabled("FoldConstant"):
func = ir_pass.fold_constant(func)
return func
......
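As a minimal sketch of how these pass levels interact with the build_config export added above (the toy function and the chosen opt_level are illustrative assumptions, not part of this change):
from tvm import relay

x = relay.var("x", shape=(1, 10))
func = relay.Function([x], relay.add(x, x))
with relay.build_config(opt_level=2):
    # At opt_level=2, SimplifyInference (0), OpFusion (1) and FoldConstant (2)
    # run, while FoldScaleAxis (3) and CombineParallelConv2D (4) are skipped.
    graph, lib, params = relay.build(func, target="llvm")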
"""Relay frontends."""
from __future__ import absolute_import
from .mxnet import from_mxnet
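A minimal usage sketch of the new frontend; the toy symbol, the shape dictionary, and the assumed return convention (a Relay expression plus a parameter dict) are illustrative rather than a definitive API description:
import mxnet as mx
from tvm import relay

data = mx.sym.var("data")
sym = mx.sym.FullyConnected(data, num_hidden=10, name="fc1")
# Assumed return convention: the converted Relay function and its parameters.
func, params = relay.frontend.from_mxnet(sym, shape={"data": (1, 16)})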
"""Common utilities"""
from __future__ import absolute_import as _abs
class RequiredAttr(object):
"""Dummpy class to represent required attr"""
pass
class StrAttrsDict(object):
"""Helper class to parse attrs stored as Dict[str, str].
Parameters
----------
attrs : Dict[str, str]
The attributes to be used.
"""
def __init__(self, attrs):
self.attrs = attrs
def get_float(self, key, default=RequiredAttr()):
"""Get float attribute
Parameters
----------
key : str
The attribute key
default : float
The default value.
Returns
-------
value : The result
"""
if key in self.attrs:
return float(self.attrs[key])
if isinstance(default, RequiredAttr):
raise AttributeError("Required attribute {} not found.".format(key))
return default
def get_int(self, key, default=RequiredAttr()):
"""Get int attribute
Parameters
----------
key : str
The attribute key
default : int
The default value.
Returns
-------
value : The result
"""
if key in self.attrs:
val = self.attrs[key]
if val == "None":
return None
return int(val)
if isinstance(default, RequiredAttr):
raise AttributeError("Required attribute {} not found.".format(key))
return default
def get_str(self, key, default=RequiredAttr()):
"""Get str attribute
Parameters
----------
key : str
The attribute key
default : str
The default value.
Returns
-------
value : The result
"""
if key in self.attrs:
return self.attrs[key]
if isinstance(default, RequiredAttr):
raise AttributeError("Required attribute {} not found.".format(key))
return default
def get_int_tuple(self, key, default=RequiredAttr()):
"""Get int tuple attribute
Parameters
----------
key : str
The attribute key
default : tuple of int
The default value.
Returns
-------
value : The result
"""
if key in self.attrs:
tshape = self.attrs[key]
return tuple(int(x.strip()) for x in tshape.strip('()').split(','))
if isinstance(default, RequiredAttr):
raise AttributeError("Required attribute {} not found.".format(key))
return default
def get_bool(self, key, default=RequiredAttr()):
"""Get bool tuple attribute
Parameters
----------
key : str
The attribute key
default : bool
The default value.
Returns
-------
value : The result
"""
if key in self.attrs:
val = self.attrs[key]
return val.strip().lower() in ['true', '1', 't', 'y', 'yes']
if isinstance(default, RequiredAttr):
raise AttributeError("Required attribute {} not found.".format(key))
return default
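A short usage sketch of the helper above (the import path is an assumption based on where the class is defined in the frontend package):
from tvm.relay.frontend.common import StrAttrsDict

attrs = StrAttrsDict({"axis": "1", "kernel": "(3, 3)", "no_bias": "True"})
attrs.get_int("axis")           # 1
attrs.get_int_tuple("kernel")   # (3, 3)
attrs.get_bool("no_bias")       # True
attrs.get_float("eps", 1e-5)    # key missing, so the default 1e-5 is returned
attrs.get_str("layout")         # key missing and no default: raises AttributeError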
......@@ -14,6 +14,7 @@ from . import vision
# operator registry
from . import _tensor
from . import _transform
from . import _reduce
from ..expr import Expr
from ..base import register_relay_node
......
"""Backend compiler related feature registration"""
from __future__ import absolute_import
import topi
from . import op as _reg
def _schedule_reduce(_, outs, target):
"""Generic schedule for reduce"""
with target:
return topi.generic.schedule_reduce(outs)
_reg.register_schedule("argmax", _schedule_reduce)
_reg.register_schedule("argmin", _schedule_reduce)
_reg.register_schedule("sum", _schedule_reduce)
_reg.register_schedule("max", _schedule_reduce)
_reg.register_schedule("prod", _schedule_reduce)
_reg.register_schedule("mean", _schedule_reduce)
......@@ -273,4 +273,5 @@ def concatenate_compute(attrs, inputs, output_type, target):
return [topi.concatenate(inputs, axis=attrs.axis)]
register_schedule("concatenate", schedule_injective)
register_pattern("concatenate", OpPattern.INJECTIVE)
# TODO(tqchen): re-enable concat as injective
register_pattern("concatenate", OpPattern.OPAQUE)
#pylint: disable=invalid-name, unused-argument
"""Backend compiler related feature registration"""
# pylint: disable=invalid-name
from __future__ import absolute_import
import topi
import topi.cuda
from tvm import container
from . import op as _reg
from .op import (schedule_injective, register_compute, register_schedule,
register_pattern, OpPattern)
schedule_broadcast = schedule_injective
schedule_injective = _reg.schedule_injective
schedule_broadcast = _reg.schedule_injective
# squeeze
@register_compute("squeeze")
def squeeze_compiler(attrs, inputs, output_type, target):
"""Compiler for squeeze dims."""
assert len(inputs) == 1
if attrs.axis is None:
axis = None
elif isinstance(attrs.axis, container.Array):
axis = tuple(attrs.axis)
else:
axis = int(attrs.axis)
return [topi.squeeze(inputs[0], axis)]
register_pattern("squeeze", OpPattern.INJECTIVE)
register_schedule("squeeze", schedule_injective)
# expand_dims
@register_compute("expand_dims")
def expand_dims_compiler(attrs, inputs, output_type, target):
"""Compiler for expand_dims."""
assert len(inputs) == 1
new_axis = int(attrs.num_newaxis)
assert new_axis >= 0
# axis should be in range [-data.ndim - 1, data.ndim]
axis = int(attrs.axis)
assert axis >= -len(inputs[0].shape) - 1
assert axis <= len(inputs[0].shape)
return [topi.expand_dims(inputs[0], axis, new_axis)]
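As a quick illustration of the axis range asserted above (shapes are made up):
from tvm import relay

x = relay.var("x", shape=(3, 4))                 # rank 2, so axis must lie in [-3, 2]
y = relay.expand_dims(x, axis=1, num_newaxis=2)  # result shape (3, 1, 1, 4)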
_reg.register_schedule("squeeze", schedule_injective)
_reg.register_schedule("expand_dims", schedule_broadcast)
_reg.register_pattern("expand_dims", OpPattern.BROADCAST)
# strided_slice
_reg.register_schedule("strided_slice", schedule_injective)
# slice_like
_reg.register_schedule("slice_like", schedule_injective)
_reg.register_pattern("slice_like", OpPattern.INJECTIVE)
# reshape
_reg.register_schedule("reshape", schedule_injective)
_reg.register_pattern("reshape", OpPattern.INJECTIVE)
# reshape_like
_reg.register_schedule("reshape_like", schedule_injective)
_reg.register_pattern("reshape_like", OpPattern.INJECTIVE)
_reg.register_schedule("cast", schedule_broadcast)
_reg.register_schedule("strided_slice", schedule_injective)
_reg.register_schedule("slice_like", schedule_injective)
_reg.register_schedule("split", schedule_injective)
#pylint: disable=invalid-name, unused-argument
"""Backend compiler related feature registration"""
from __future__ import absolute_import
import topi
from topi.util import get_const_int, get_const_tuple
from .. import op as reg
......
......@@ -145,7 +145,7 @@ def conv2d_transpose(data,
weight_layout, output_padding, out_dtype)
def softmax(data, axis=1):
def softmax(data, axis=-1):
r"""Computes softmax.
.. math:: \text{softmax}(x)_i = \frac{exp(x_i)}{\sum_j exp(x_j)}
......@@ -169,7 +169,7 @@ def softmax(data, axis=1):
return _make.softmax(data, axis)
def log_softmax(data, axis):
def log_softmax(data, axis=-1):
r"""Computes log softmax.
.. math::
......
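With the new default the normalization runs over the last axis; a minimal sketch:
from tvm import relay

x = relay.var("x", shape=(2, 3))
y = relay.nn.softmax(x)   # same as relay.nn.softmax(x, axis=-1)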
......@@ -54,7 +54,7 @@ def Inception7A(data,
name=('%s_pool_%s_pool' % (pool, name)))
cproj = Conv(pooling, proj, name=('%s_tower_2' % name), suffix='_conv')
concat = relay.concatenate((tower_1x1, tower_5x5, tower_3x3, cproj), axis=0)
concat = relay.concatenate((tower_1x1, tower_5x5, tower_3x3, cproj), axis=1)
return concat
# First Downsample
......@@ -72,7 +72,7 @@ def Inception7B(data,
name=('%s_tower' % name), suffix='_conv_2')
pooling = Pooling(data=data, kernel=(3, 3), stride=(2, 2), pad=(0, 0), pool_type="max",
name=('max_pool_%s_pool' % name))
concat = relay.concatenate((tower_3x3, tower_d3x3, pooling), axis=0)
concat = relay.concatenate((tower_3x3, tower_d3x3, pooling), axis=1)
return concat
def Inception7C(data,
......@@ -101,7 +101,7 @@ def Inception7C(data,
cproj = Conv(data=pooling, num_filter=proj, kernel=(1, 1),
name=('%s_tower_2' % name), suffix='_conv')
# concat
concat = relay.concatenate((tower_1x1, tower_d7, tower_q7, cproj), axis=0)
concat = relay.concatenate((tower_1x1, tower_d7, tower_q7, cproj), axis=1)
return concat
def Inception7D(data,
......@@ -124,7 +124,7 @@ def Inception7D(data,
pooling = Pooling(data=data, kernel=(3, 3), stride=(2, 2), pool_type=pool, pad=(0, 0),
name=('%s_pool_%s_pool' % (pool, name)))
# concat
concat = relay.concatenate((tower_3x3, tower_d7_3x3, pooling), axis=0)
concat = relay.concatenate((tower_3x3, tower_d7_3x3, pooling), axis=1)
return concat
def Inception7E(data,
......@@ -153,7 +153,7 @@ def Inception7E(data,
suffix='_conv')
# concat
concat = relay.concatenate(
(tower_1x1, tower_d3_a, tower_d3_b, tower_3x3_d3_a, tower_3x3_d3_b, cproj), axis=0)
(tower_1x1, tower_d3_a, tower_d3_b, tower_3x3_d3_a, tower_3x3_d3_b, cproj), axis=1)
return concat
def get_net(batch_size,
......
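The axis fix above follows from the NCHW layout these models use: channels live on axis 1, so tower outputs must be concatenated along that axis. A quick numpy check with made-up shapes:
import numpy as np

a = np.zeros((1, 64, 35, 35))   # NCHW feature map
b = np.zeros((1, 96, 35, 35))
assert np.concatenate((a, b), axis=1).shape == (1, 160, 35, 35)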
......@@ -31,19 +31,21 @@ from .init import create_workload
from . import layers
# Helpers
def _make_fire(net, squeeze_channels, expand1x1_channels, expand3x3_channels):
net = _make_fire_conv(net, squeeze_channels, 1, 0)
def _make_fire(net, squeeze_channels, expand1x1_channels, expand3x3_channels, prefix):
net = _make_fire_conv(net, squeeze_channels, 1, 0, "%s_input" % prefix)
left = _make_fire_conv(net, expand1x1_channels, 1, 0)
right = _make_fire_conv(net, expand3x3_channels, 3, 1)
left = _make_fire_conv(net, expand1x1_channels, 1, 0, "%s_left" % prefix)
right = _make_fire_conv(net, expand3x3_channels, 3, 1, "%s_right" % prefix)
# NOTE : Assume NCHW layout here
net = relay.concatenate((left, right), axis=1)
return net
def _make_fire_conv(net, channels, kernel_size, padding=0):
net = layers.conv2d(net, channels=channels, kernel_size=(kernel_size, kernel_size),
padding=(padding, padding), name="conv2d")
def _make_fire_conv(net, channels, kernel_size, padding=0, prefix=""):
net = layers.conv2d(net,
channels=channels,
kernel_size=(kernel_size, kernel_size),
padding=(padding, padding), name="%s_conv" % prefix)
net = relay.nn.bias_add(net, relay.var("%s_conv_bias" % prefix))
net = relay.nn.relu(net)
return net
......@@ -75,41 +77,44 @@ def get_net(batch_size, image_shape, num_classes, version, dtype):
kernel_size=(7, 7),
strides=(2, 2),
padding=(3, 3),
name="conv2d")
net = relay.nn.bias_add(net, relay.var("dense1_bias"))
name="conv1")
net = relay.nn.bias_add(net, relay.var("conv1_bias"))
net = relay.nn.relu(net)
net = relay.nn.max_pool2d(net, pool_size=(3, 3), strides=(2, 2))
net = _make_fire(net, 16, 64, 64)
net = _make_fire(net, 16, 64, 64)
net = _make_fire(net, 32, 128, 128)
net = _make_fire(net, 16, 64, 64, "fire1")
net = _make_fire(net, 16, 64, 64, "fire2")
net = _make_fire(net, 32, 128, 128, "fire3")
net = relay.nn.max_pool2d(net, pool_size=(3, 3), strides=(2, 2))
net = _make_fire(net, 32, 128, 128)
net = _make_fire(net, 48, 192, 192)
net = _make_fire(net, 48, 192, 192)
net = _make_fire(net, 64, 256, 256)
net = _make_fire(net, 32, 128, 128, "fire4")
net = _make_fire(net, 48, 192, 192, "fire5")
net = _make_fire(net, 48, 192, 192, "fire6")
net = _make_fire(net, 64, 256, 256, "fire7")
net = relay.nn.max_pool2d(net, pool_size=(3, 3), strides=(2, 2))
net = _make_fire(net, 64, 256, 256)
net = _make_fire(net, 64, 256, 256, "fire8")
else:
net = layers.conv2d(net,
channels=64,
kernel_size=(3, 3),
strides=(2, 2),
padding=(1, 1),
name="conv2d")
name="conv1")
net = relay.nn.bias_add(net, relay.var("conv1_bias"))
net = relay.nn.relu(net)
net = relay.nn.max_pool2d(net, pool_size=(3, 3), strides=(2, 2))
net = _make_fire(net, 16, 64, 64)
net = _make_fire(net, 16, 64, 64)
net = _make_fire(net, 16, 64, 64, "fire1")
net = _make_fire(net, 16, 64, 64, "fire2")
net = relay.nn.max_pool2d(net, pool_size=(3, 3), strides=(2, 2))
net = _make_fire(net, 32, 128, 128)
net = _make_fire(net, 32, 128, 128)
net = _make_fire(net, 32, 128, 128, "fire3")
net = _make_fire(net, 32, 128, 128, "fire4")
net = relay.nn.max_pool2d(net, pool_size=(3, 3), strides=(2, 2))
net = _make_fire(net, 48, 192, 192)
net = _make_fire(net, 48, 192, 192)
net = _make_fire(net, 64, 256, 256)
net = _make_fire(net, 64, 256, 256)
net = _make_fire(net, 48, 192, 192, "fire5")
net = _make_fire(net, 48, 192, 192, "fire6")
net = _make_fire(net, 64, 256, 256, "fire7")
net = _make_fire(net, 64, 256, 256, "fire8")
net = relay.nn.dropout(net, rate=0.5)
net = layers.conv2d(net, channels=num_classes, kernel_size=(1, 1), name="conv2d")
net = layers.conv2d(
net, channels=num_classes, kernel_size=(1, 1), name="conv_final")
net = relay.nn.bias_add(net, relay.var("conv_final_bias"))
net = relay.nn.relu(net)
net = relay.nn.global_avg_pool2d(net)
net = relay.nn.batch_flatten(net)
......@@ -117,8 +122,12 @@ def get_net(batch_size, image_shape, num_classes, version, dtype):
args = relay.ir_pass.free_vars(net)
return relay.Function(args, net)
def get_workload(batch_size=1, num_classes=1000, version='1.0',
image_shape=(3, 224, 224), dtype="float32"):
def get_workload(batch_size=1,
num_classes=1000,
version='1.0',
image_shape=(3, 224, 224),
dtype="float32"):
"""Get benchmark workload for SqueezeNet
Parameters
......
......@@ -24,20 +24,24 @@ from tvm import relay
from .init import create_workload
from . import layers as wrapper
def get_feature(internel_layer, layers, filters, batch_norm=False):
def get_feature(internal_layer, layers, filters, batch_norm=False):
"""Get VGG feature body as stacks of convoltions."""
for i, num in enumerate(layers):
for j in range(num):
internel_layer = wrapper.conv2d(
data=internel_layer, kernel_size=(3, 3), padding=(1, 1),
channels=filters[i], name="conv%s_%s"%(i + 1, j + 1))
internal_layer = wrapper.conv2d(
data=internal_layer, kernel_size=(3, 3), padding=(1, 1),
channels=filters[i], name="conv%s_%s" % (i + 1, j + 1))
internal_layer = relay.nn.bias_add(
internal_layer, relay.var("conv%s_%s_bias" % (i + 1, j + 1)))
if batch_norm:
internel_layer = wrapper.batch_norm_infer(
data=internel_layer, name="bn%s_%s" %(i + 1, j + 1))
internel_layer = relay.nn.relu(data=internel_layer)
internel_layer = relay.nn.max_pool2d(
data=internel_layer, pool_size=(2, 2), strides=(2, 2))
return internel_layer
internal_layer = wrapper.batch_norm_infer(
data=internal_layer, name="bn%s_%s" %(i + 1, j + 1))
internal_layer = relay.nn.relu(data=internal_layer)
internal_layer = relay.nn.max_pool2d(
data=internal_layer, pool_size=(2, 2), strides=(2, 2))
return internal_layer
def get_classifier(input_data, num_classes):
"""Get VGG classifier layers as fc layers."""
......@@ -51,6 +55,7 @@ def get_classifier(input_data, num_classes):
fc8 = wrapper.dense_add_bias(data=drop7, units=num_classes, name="fc8")
return fc8
def get_net(batch_size, image_shape, num_classes, dtype, num_layers=11, batch_norm=False):
"""
Parameters
......@@ -68,7 +73,7 @@ def get_net(batch_size, image_shape, num_classes, dtype, num_layers=11, batch_no
The data type
num_layers : int
Number of layers for the variant of densenet. Options are 11, 13, 16, 19.
Number of layers for the variant of vgg. Options are 11, 13, 16, 19.
batch_norm : bool, default False
Use batch normalization.
......@@ -88,7 +93,12 @@ def get_net(batch_size, image_shape, num_classes, dtype, num_layers=11, batch_no
args = relay.ir_pass.free_vars(symbol)
return relay.Function(args, symbol)
def get_workload(batch_size, num_classes=1000, image_shape=(3, 224, 224), dtype="float32"):
def get_workload(batch_size,
num_classes=1000,
image_shape=(3, 224, 224),
dtype="float32",
num_layers=11):
"""Get benchmark workload for VGG nets.
Parameters
......@@ -105,6 +115,9 @@ def get_workload(batch_size, num_classes=1000, image_shape=(3, 224, 224), dtype=
dtype : str, optional
The data type
num_layers : int
Number of layers for the variant of vgg. Options are 11, 13, 16, 19.
Returns
-------
net : relay.Function
......@@ -113,5 +126,5 @@ def get_workload(batch_size, num_classes=1000, image_shape=(3, 224, 224), dtype=
params : dict of str to NDArray
The parameters.
"""
net = get_net(batch_size, image_shape, num_classes, dtype)
net = get_net(batch_size, image_shape, num_classes, dtype, num_layers)
return create_workload(net)
......@@ -163,6 +163,7 @@ class AttrsHashHandler :
* \param node The node to be hashed.
*/
size_t Hash(const NodeRef& node) {
if (!node.defined()) return 0;
return this->VisitAttr(node);
}
......
......@@ -31,7 +31,10 @@ class StorageAllocaBaseVisitor : public ExprVisitor {
for (Var param : func->params) {
CreateToken(param.operator->(), false);
}
this->VisitExpr(func->body);
// must always keep output alive.
for (StorageToken* tok : GetToken(func->body)) {
tok->ref_counter += 1;
}
}
void VisitExpr_(const ConstantNode* op) final {
......
......@@ -16,7 +16,7 @@ namespace tvm {
namespace relay {
template<typename T>
std::vector<T> AsVector(const Array<T> &array) {
inline std::vector<T> AsVector(const Array<T> &array) {
std::vector<T> result;
result.reserve(array.size());
for (const T& ele : array) {
......
......@@ -5,6 +5,8 @@
*/
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <topi/elemwise.h>
#include <topi/reduction.h>
#include <numeric>
#include <limits>
#include "../op_common.h"
......@@ -15,12 +17,12 @@ namespace relay {
/*! \brief Attributes for Reduce operators */
struct ReduceAttrs : public tvm::AttrsNode<ReduceAttrs> {
Array<IndexExpr> axis;
Array<Integer> axis;
bool keepdims;
bool exclude;
TVM_DECLARE_ATTRS(ReduceAttrs, "relay.attrs.ReduceAttrs") {
TVM_ATTR_FIELD(axis).set_default(NullValue<Array<IndexExpr>>())
TVM_ATTR_FIELD(axis).set_default(NullValue<Array<Integer>>())
.describe(R"code(The axis or axes along which to perform the reduction.
The default, `axis=()`, will compute over all elements into a
......@@ -50,7 +52,7 @@ struct ReduceAttrs : public tvm::AttrsNode<ReduceAttrs> {
* \return r_axes The new reduced axes of the output.
*/
inline std::vector<int64_t> GetReduceAxes(const uint32_t indim,
const Array<IndexExpr>& inaxis,
const Array<Integer>& inaxis,
bool exclude) {
if (!inaxis.defined()) {
std::vector<int64_t> r_axes(indim);
......@@ -60,9 +62,7 @@ inline std::vector<int64_t> GetReduceAxes(const uint32_t indim,
std::vector<int64_t> in_axes;
for (auto i : inaxis) {
const int64_t* k = as_const_int(i);
CHECK(k != nullptr) << "Reduce axis need to be constant, cannot be symbolic";
int64_t axis = k[0];
int64_t axis = i->value;
if (axis < 0) {
axis = axis + indim;
}
......@@ -97,6 +97,53 @@ inline std::vector<int64_t> GetReduceAxes(const uint32_t indim,
return r_axes;
}
// Get axis under exclude condition.
Array<Integer> GetExcludeAxes(size_t indim,
const Array<Integer>& inaxis) {
std::vector<bool> axis_flag(indim, true);
for (auto i : inaxis) {
int64_t axis = i->value;
if (axis < 0) {
axis = axis + static_cast<int64_t>(indim);
}
// Check out of bounds error
CHECK_GE(axis, 0)
<< "Axis out of bounds in reduce operator.";
CHECK_LT(axis, static_cast<int64_t>(indim))
<< "Axis out of bounds in reduce operator.";
axis_flag[axis] = false;
}
Array<Integer> r_axes;
for (size_t i = 0; i < axis_flag.size(); ++i) {
if (axis_flag[i]) {
r_axes.push_back(static_cast<int>(i));
}
}
return r_axes;
}
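A short Python-level sketch of the exclude semantics GetExcludeAxes implements (shapes are illustrative):
from tvm import relay

x = relay.var("x", shape=(2, 3, 4))
relay.sum(x, axis=[1])                  # reduces axis 1       -> shape (2, 4)
relay.sum(x, axis=[1], exclude=True)    # reduces axes 0 and 2 -> shape (3,)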
template<typename F>
Array<Tensor> ReduceCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target,
F f) {
const ReduceAttrs* param = attrs.as<ReduceAttrs>();
CHECK(param != nullptr);
auto axes = param->axis;
if (param->exclude) {
axes = GetExcludeAxes(inputs[0]->shape.size(), param->axis);
}
if (axes.size() == 0) {
return { topi::identity(inputs[0]) };
}
return { f(inputs[0], axes, param->keepdims, false) };
}
/*!
* \brief ReduceShapeImpl get the outshape for the reduction operator
* \param in_shape Shape of input data.
......@@ -200,7 +247,7 @@ bool ReduceRel(const Array<Type>& types,
TVM_REGISTER_API("relay.op._make." OpName) \
.set_body([](const TVMArgs& args, TVMRetValue* rv) { \
auto make_func = [](Expr data, \
Array<IndexExpr> axis, \
Array<Integer> axis, \
bool keepdims, \
bool exclude) { \
auto attrs = make_node<ReduceAttrs>(); \
......@@ -217,6 +264,14 @@ bool ReduceRel(const Array<Type>& types,
.add_argument("data", "Tensor", "The input tensor.")
Array<Tensor> ArgMaxCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
return ReduceCompute(attrs, inputs, out_type, target, topi::argmax);
}
RELAY_REGISTER_REDUCE_OP("argmax")
.describe(R"code(Creates an operation that finds the indices of the maximum
values over a given axis.
......@@ -224,8 +279,17 @@ values over a given axis.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.ReduceAttrs")
.set_support_level(4)
.add_type_rel("ArgReduce", ArgReduceRel);
.add_type_rel("ArgReduce", ArgReduceRel)
.set_attr<FTVMCompute>("FTVMCompute", ArgMaxCompute)
.set_attr<TOpPattern>("TOpPattern", kCommReduce);
Array<Tensor> ArgMinCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
return ReduceCompute(attrs, inputs, out_type, target, topi::argmin);
}
RELAY_REGISTER_REDUCE_OP("argmin")
.describe(R"code(Creates an operation that finds the indices of the minimum
......@@ -234,7 +298,16 @@ values over a given axis.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.ReduceAttrs")
.set_support_level(4)
.add_type_rel("ArgReduce", ArgReduceRel);
.add_type_rel("ArgReduce", ArgReduceRel)
.set_attr<FTVMCompute>("FTVMCompute", ArgMinCompute)
.set_attr<TOpPattern>("TOpPattern", kCommReduce);
Array<Tensor> SumCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
return ReduceCompute(attrs, inputs, out_type, target, topi::sum);
}
RELAY_REGISTER_REDUCE_OP("sum")
......@@ -257,16 +330,35 @@ Example::
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.ReduceAttrs")
.set_support_level(4)
.add_type_rel("Reduce", ReduceRel);
.add_type_rel("Reduce", ReduceRel)
.set_attr<FTVMCompute>("FTVMCompute", SumCompute)
.set_attr<TOpPattern>("TOpPattern", kCommReduce);
Array<Tensor> MaxCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
return ReduceCompute(attrs, inputs, out_type, target, topi::max);
}
RELAY_REGISTER_REDUCE_OP("max")
.describe(R"code(Computes the max of array elements over given axes.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.ReduceAttrs")
.set_support_level(4)
.add_type_rel("Reduce", ReduceRel);
.add_type_rel("Reduce", ReduceRel)
.set_attr<FTVMCompute>("FTVMCompute", MaxCompute)
.set_attr<TOpPattern>("TOpPattern", kCommReduce);
Array<Tensor> MinCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
return ReduceCompute(attrs, inputs, out_type, target, topi::min);
}
RELAY_REGISTER_REDUCE_OP("min")
......@@ -275,11 +367,20 @@ RELAY_REGISTER_REDUCE_OP("min")
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.ReduceAttrs")
.set_support_level(4)
.add_type_rel("Reduce", ReduceRel);
.add_type_rel("Reduce", ReduceRel)
.set_attr<FTVMCompute>("FTVMCompute", MinCompute)
.set_attr<TOpPattern>("TOpPattern", kCommReduce);
RELAY_REGISTER_REDUCE_OP("mean")
.describe(R"code(Computes the mean of array elements over given axes.
Array<Tensor> ProdCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
return ReduceCompute(attrs, inputs, out_type, target, topi::prod);
}
RELAY_REGISTER_REDUCE_OP("prod")
.describe(R"code(Computes the products of array elements over given axes.
Example::
......@@ -287,20 +388,40 @@ Example::
[[1,4],[4,3],[5,2]],
[[7,1],[7,2],[7,3]]]
mean(data)
[3.22]
mean(data, axis=1)
[35562240]
mean(data, axis=[1,2])
[ 2. 3.16666667 4.5]
[ 36 480 2058]
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.ReduceAttrs")
.set_support_level(4)
.add_type_rel("Reduce", ReduceRel);
.add_type_rel("Reduce", ReduceRel)
.set_attr<FTVMCompute>("FTVMCompute", ProdCompute)
.set_attr<TOpPattern>("TOpPattern", kCommReduce);
RELAY_REGISTER_REDUCE_OP("prod")
.describe(R"code(Computes the products of array elements over given axes.
Array<Tensor> MeanCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
IndexExpr count = make_const(inputs[0]->dtype, 1);
const ReduceAttrs* param = attrs.as<ReduceAttrs>();
CHECK(param != nullptr);
auto axes = param->axis;
for (int64_t i : GetReduceAxes(inputs[0]->shape.size(),
param->axis,
param->exclude)) {
count *= inputs[0]->shape[i];
}
auto res = ReduceCompute(attrs, inputs, out_type, target, topi::sum);
return {topi::divide(res[0], count)};
}
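MeanCompute above lowers mean as a sum divided by the number of reduced elements; a tiny numpy check of that identity with made-up data:
import numpy as np

data = np.arange(24, dtype="float32").reshape(2, 3, 4)
count = data.shape[1] * data.shape[2]   # elements covered by axes (1, 2)
np.testing.assert_allclose(data.sum(axis=(1, 2)) / count, data.mean(axis=(1, 2)))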
RELAY_REGISTER_REDUCE_OP("mean")
.describe(R"code(Computes the mean of array elements over given axes.
Example::
......@@ -308,16 +429,17 @@ Example::
[[1,4],[4,3],[5,2]],
[[7,1],[7,2],[7,3]]]
mean(data, axis=1)
[35562240]
mean(data)
[3.22]
mean(data, axis=[1,2])
[ 36 480 2058]
[ 2. 3.16666667 4.5]
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.ReduceAttrs")
.set_support_level(4)
.add_type_rel("Reduce", ReduceRel);
.add_type_rel("Reduce", ReduceRel)
.set_attr<FTVMCompute>("FTVMCompute", MeanCompute)
.set_attr<TOpPattern>("TOpPattern", kCommReduce);
} // namespace relay
} // namespace tvm
......@@ -344,6 +344,7 @@ Expr MultiplyForwardRewrite(const Call& ref_call,
const Array<Expr>& new_args,
const AxesSet& expected_out_axes) {
if (!expected_out_axes.defined()) return Expr();
if (expected_out_axes.size() == 0) return Expr();
// TODO(tvm-team) allow same axes accumulation
// not as important because it is less common in nn.
const auto* slhs = new_args[0].as<ScaledExprNode>();
......@@ -681,7 +682,9 @@ AxesSet AddSubBackwardPrep(const Call& call, const Array<AxesSet>& in_axes) {
// add of two elements.
return in_axes[0];
} else {
return NullValue<AxesSet>();
auto res = NullValue<AxesSet>();
CHECK(!res.defined());
return res;
}
}
......@@ -751,14 +754,14 @@ Expr MultiplyBackwardTransform(const Call& call,
const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
AxesSet lhs_axes = transformer->GetExpectedAxes(call->args[0]);
AxesSet rhs_axes = transformer->GetExpectedAxes(call->args[1]);
if (lhs_axes.defined()) {
if (lhs_axes.defined() && lhs_axes.size() != 0) {
// NOTE we won't recursively call mutating on scale part.
// since there won't be scale chance within scale part.
Expr rhs = call->args[1];
if (MatchBroadcastToLeftAxes(tlhs, trhs, lhs_axes, &rhs)) {
return transformer->Transform(call->args[0], lhs_axes, rhs);
}
} else if (rhs_axes.defined()) {
} else if (rhs_axes.defined() && rhs_axes.size() != 0) {
Expr lhs = call->args[0];
if (MatchBroadcastToLeftAxes(trhs, tlhs, rhs_axes, &lhs)) {
return transformer->Transform(call->args[1], rhs_axes, lhs);
......
"""MXNet model zoo for testing purposes."""
from __future__ import absolute_import
from . import mlp, vgg, resnet, dqn, inception_v3, squeezenet, dcgan
import tvm.relay.testing
# mlp
def mx_mlp():
num_class = 10
return mlp.get_symbol(num_class)
def relay_mlp():
num_class = 10
return tvm.relay.testing.mlp.get_workload(1, num_class)[0]
# vgg
def mx_vgg(num_layers):
num_class = 1000
return vgg.get_symbol(num_class, num_layers)
def relay_vgg(num_layers):
num_class = 1000
return tvm.relay.testing.vgg.get_workload(
1, num_class, num_layers=num_layers)[0]
# resnet
def mx_resnet(num_layers):
num_class = 1000
return resnet.get_symbol(num_class, num_layers, '3,224,224')
def relay_resnet(num_layers):
num_class = 1000
return tvm.relay.testing.resnet.get_workload(
1, num_class, num_layers=num_layers)[0]
# dqn
mx_dqn = dqn.get_symbol
def relay_dqn():
return tvm.relay.testing.dqn.get_workload(1)[0]
# squeezenet
def mx_squeezenet(version):
return squeezenet.get_symbol(version=version)
def relay_squeezenet(version):
return tvm.relay.testing.squeezenet.get_workload(1, version=version)[0]
# inception
mx_inception_v3 = inception_v3.get_symbol
def relay_inception_v3():
return tvm.relay.testing.inception_v3.get_workload(1)[0]
# dcgan generator
mx_dcgan = dcgan.get_symbol
def relay_dcgan(batch_size):
return tvm.relay.testing.dcgan.get_workload(batch_size=batch_size)[0]
# pylint: disable=unused-argument
"""
The MXNet symbol of DCGAN generator
Adapted from:
https://github.com/tqchen/mxnet-gan/blob/master/mxgan/generator.py
Reference:
Radford, Alec, Luke Metz, and Soumith Chintala.
"Unsupervised representation learning with deep convolutional generative adversarial networks."
arXiv preprint arXiv:1511.06434 (2015).
"""
import mxnet as mx
def deconv2d(data, ishape, oshape, kshape, name, stride=(2, 2)):
"""a deconv layer that enlarges the feature map"""
target_shape = (oshape[-2], oshape[-1])
pad_y = (kshape[0] - 1) // 2
pad_x = (kshape[1] - 1) // 2
adj_y = (target_shape[0] + 2 * pad_y - kshape[0]) % stride[0]
adj_x = (target_shape[1] + 2 * pad_x - kshape[1]) % stride[1]
net = mx.sym.Deconvolution(data,
kernel=kshape,
stride=stride,
pad=(pad_y, pad_x),
adj=(adj_y, adj_x),
num_filter=oshape[0],
no_bias=True,
name=name)
return net
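A quick arithmetic check of the padding/adjustment computed above, using the usual transposed-convolution size relation out = (in - 1) * stride - 2 * pad + kernel + adj (numbers taken from the 4x4 -> 8x8 step used later):
in_size, kernel, stride, target = 4, 4, 2, 8
pad = (kernel - 1) // 2                      # 1
adj = (target + 2 * pad - kernel) % stride   # 0
assert (in_size - 1) * stride - 2 * pad + kernel + adj == target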
def deconv2d_bn_relu(data, prefix, **kwargs):
"""a block of deconv + batch norm + relu"""
eps = 1e-5 + 1e-12
net = deconv2d(data, name="%s_deconv" % prefix, **kwargs)
net = mx.sym.BatchNorm(net, eps=eps, name="%s_bn" % prefix)
net = mx.sym.Activation(net, name="%s_act" % prefix, act_type='relu')
return net
def get_symbol(oshape=(3, 64, 64), ngf=128, code=None):
"""get symbol of dcgan generator"""
assert oshape[-1] == 64, "Only support 64x64 image"
assert oshape[-2] == 64, "Only support 64x64 image"
code = mx.sym.Variable("data") if code is None else code
net = mx.sym.FullyConnected(code, name="g1", num_hidden=ngf*8*4*4, no_bias=True, flatten=False)
net = mx.sym.Activation(net, act_type='relu')
# 4 x 4
net = mx.sym.reshape(net, shape=(-1, ngf * 8, 4, 4))
# 8 x 8
net = deconv2d_bn_relu(
net, ishape=(ngf * 8, 4, 4), oshape=(ngf * 4, 8, 8), kshape=(4, 4), prefix="g2")
# 16x16
net = deconv2d_bn_relu(
net, ishape=(ngf * 4, 8, 8), oshape=(ngf * 2, 16, 16), kshape=(4, 4), prefix="g3")
# 32x32
net = deconv2d_bn_relu(
net, ishape=(ngf * 2, 16, 16), oshape=(ngf, 32, 32), kshape=(4, 4), prefix="g4")
# 64x64
net = deconv2d(
net, ishape=(ngf, 32, 32), oshape=oshape[-3:], kshape=(4, 4), name="g5_deconv")
net = mx.sym.Activation(net, act_type='tanh')
return net
"""
The mxnet symbol of Nature DQN
Reference:
Mnih, Volodymyr, et al.
"Human-level control through deep reinforcement learning."
Nature 518.7540 (2015): 529.
"""
import mxnet as mx
def get_symbol(num_action=18):
data = mx.sym.Variable(name='data')
net = mx.sym.Convolution(data, kernel=(8, 8), stride=(4, 4),
num_filter=32, name='conv1')
net = mx.sym.Activation(net, act_type='relu', name='relu1')
net = mx.sym.Convolution(net, kernel=(4, 4), stride=(2, 2),
num_filter=64, name='conv2')
net = mx.sym.Activation(net, act_type='relu', name='relu2')
net = mx.sym.Convolution(net, kernel=(3, 3), stride=(1, 1),
num_filter=64, name='conv3')
net = mx.sym.Activation(net, act_type='relu', name='relu3')
net = mx.sym.FullyConnected(net, num_hidden=512, name='fc4')
net = mx.sym.Activation(net, act_type='relu', name='relu4')
net = mx.sym.FullyConnected(net, num_hidden=num_action, name='fc5', flatten=False)
return net
"""
Inception V3, suitable for images with around 299 x 299
Reference:
Szegedy, Christian, et al. "Rethinking the Inception Architecture for Computer Vision." arXiv preprint arXiv:1512.00567 (2015).
Adapted from https://github.com/apache/incubator-mxnet/blob/
master/example/image-classification/symbols/inception-v3.py
"""
import mxnet as mx
import numpy as np
def Conv(data, num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), name=None, suffix=''):
conv = mx.sym.Convolution(data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad, no_bias=True, name='%s%s_conv2d' %(name, suffix))
bn = mx.sym.BatchNorm(data=conv, eps=2e-5, name='%s%s_batchnorm' % (name, suffix))
act = mx.sym.Activation(data=bn, act_type='relu', name='%s%s_relu' %(name, suffix))
return act
def Inception7A(data,
num_1x1,
num_3x3_red, num_3x3_1, num_3x3_2,
num_5x5_red, num_5x5,
pool, proj,
name):
tower_1x1 = Conv(data, num_1x1, name=('%s_conv' % name))
tower_5x5 = Conv(data, num_5x5_red, name=('%s_tower' % name), suffix='_conv')
tower_5x5 = Conv(tower_5x5, num_5x5, kernel=(5, 5), pad=(2, 2), name=('%s_tower' % name), suffix='_conv_1')
tower_3x3 = Conv(data, num_3x3_red, name=('%s_tower_1' % name), suffix='_conv')
tower_3x3 = Conv(tower_3x3, num_3x3_1, kernel=(3, 3), pad=(1, 1), name=('%s_tower_1' % name), suffix='_conv_1')
tower_3x3 = Conv(tower_3x3, num_3x3_2, kernel=(3, 3), pad=(1, 1), name=('%s_tower_1' % name), suffix='_conv_2')
pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))
cproj = Conv(pooling, proj, name=('%s_tower_2' % name), suffix='_conv')
concat = mx.sym.Concat(*[tower_1x1, tower_5x5, tower_3x3, cproj], name='ch_concat_%s_chconcat' % name)
return concat
# First Downsample
def Inception7B(data,
num_3x3,
num_d3x3_red, num_d3x3_1, num_d3x3_2,
pool,
name):
tower_3x3 = Conv(data, num_3x3, kernel=(3, 3), pad=(0, 0), stride=(2, 2), name=('%s_conv' % name))
tower_d3x3 = Conv(data, num_d3x3_red, name=('%s_tower' % name), suffix='_conv')
tower_d3x3 = Conv(tower_d3x3, num_d3x3_1, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name=('%s_tower' % name), suffix='_conv_1')
tower_d3x3 = Conv(tower_d3x3, num_d3x3_2, kernel=(3, 3), pad=(0, 0), stride=(2, 2), name=('%s_tower' % name), suffix='_conv_2')
pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pad=(0,0), pool_type="max", name=('max_pool_%s_pool' % name))
concat = mx.sym.Concat(*[tower_3x3, tower_d3x3, pooling], name='ch_concat_%s_chconcat' % name)
return concat
def Inception7C(data,
num_1x1,
num_d7_red, num_d7_1, num_d7_2,
num_q7_red, num_q7_1, num_q7_2, num_q7_3, num_q7_4,
pool, proj,
name):
tower_1x1 = Conv(data=data, num_filter=num_1x1, kernel=(1, 1), name=('%s_conv' % name))
tower_d7 = Conv(data=data, num_filter=num_d7_red, name=('%s_tower' % name), suffix='_conv')
tower_d7 = Conv(data=tower_d7, num_filter=num_d7_1, kernel=(1, 7), pad=(0, 3), name=('%s_tower' % name), suffix='_conv_1')
tower_d7 = Conv(data=tower_d7, num_filter=num_d7_2, kernel=(7, 1), pad=(3, 0), name=('%s_tower' % name), suffix='_conv_2')
tower_q7 = Conv(data=data, num_filter=num_q7_red, name=('%s_tower_1' % name), suffix='_conv')
tower_q7 = Conv(data=tower_q7, num_filter=num_q7_1, kernel=(7, 1), pad=(3, 0), name=('%s_tower_1' % name), suffix='_conv_1')
tower_q7 = Conv(data=tower_q7, num_filter=num_q7_2, kernel=(1, 7), pad=(0, 3), name=('%s_tower_1' % name), suffix='_conv_2')
tower_q7 = Conv(data=tower_q7, num_filter=num_q7_3, kernel=(7, 1), pad=(3, 0), name=('%s_tower_1' % name), suffix='_conv_3')
tower_q7 = Conv(data=tower_q7, num_filter=num_q7_4, kernel=(1, 7), pad=(0, 3), name=('%s_tower_1' % name), suffix='_conv_4')
pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))
cproj = Conv(data=pooling, num_filter=proj, kernel=(1, 1), name=('%s_tower_2' % name), suffix='_conv')
# concat
concat = mx.sym.Concat(*[tower_1x1, tower_d7, tower_q7, cproj], name='ch_concat_%s_chconcat' % name)
return concat
def Inception7D(data,
num_3x3_red, num_3x3,
num_d7_3x3_red, num_d7_1, num_d7_2, num_d7_3x3,
pool,
name):
tower_3x3 = Conv(data=data, num_filter=num_3x3_red, name=('%s_tower' % name), suffix='_conv')
tower_3x3 = Conv(data=tower_3x3, num_filter=num_3x3, kernel=(3, 3), pad=(0,0), stride=(2, 2), name=('%s_tower' % name), suffix='_conv_1')
tower_d7_3x3 = Conv(data=data, num_filter=num_d7_3x3_red, name=('%s_tower_1' % name), suffix='_conv')
tower_d7_3x3 = Conv(data=tower_d7_3x3, num_filter=num_d7_1, kernel=(1, 7), pad=(0, 3), name=('%s_tower_1' % name), suffix='_conv_1')
tower_d7_3x3 = Conv(data=tower_d7_3x3, num_filter=num_d7_2, kernel=(7, 1), pad=(3, 0), name=('%s_tower_1' % name), suffix='_conv_2')
tower_d7_3x3 = Conv(data=tower_d7_3x3, num_filter=num_d7_3x3, kernel=(3, 3), stride=(2, 2), name=('%s_tower_1' % name), suffix='_conv_3')
pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))
# concat
concat = mx.sym.Concat(*[tower_3x3, tower_d7_3x3, pooling], name='ch_concat_%s_chconcat' % name)
return concat
def Inception7E(data,
num_1x1,
num_d3_red, num_d3_1, num_d3_2,
num_3x3_d3_red, num_3x3, num_3x3_d3_1, num_3x3_d3_2,
pool, proj,
name):
tower_1x1 = Conv(data=data, num_filter=num_1x1, kernel=(1, 1), name=('%s_conv' % name))
tower_d3 = Conv(data=data, num_filter=num_d3_red, name=('%s_tower' % name), suffix='_conv')
tower_d3_a = Conv(data=tower_d3, num_filter=num_d3_1, kernel=(1, 3), pad=(0, 1), name=('%s_tower' % name), suffix='_mixed_conv')
tower_d3_b = Conv(data=tower_d3, num_filter=num_d3_2, kernel=(3, 1), pad=(1, 0), name=('%s_tower' % name), suffix='_mixed_conv_1')
tower_3x3_d3 = Conv(data=data, num_filter=num_3x3_d3_red, name=('%s_tower_1' % name), suffix='_conv')
tower_3x3_d3 = Conv(data=tower_3x3_d3, num_filter=num_3x3, kernel=(3, 3), pad=(1, 1), name=('%s_tower_1' % name), suffix='_conv_1')
tower_3x3_d3_a = Conv(data=tower_3x3_d3, num_filter=num_3x3_d3_1, kernel=(1, 3), pad=(0, 1), name=('%s_tower_1' % name), suffix='_mixed_conv')
tower_3x3_d3_b = Conv(data=tower_3x3_d3, num_filter=num_3x3_d3_2, kernel=(3, 1), pad=(1, 0), name=('%s_tower_1' % name), suffix='_mixed_conv_1')
pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))
cproj = Conv(data=pooling, num_filter=proj, kernel=(1, 1), name=('%s_tower_2' % name), suffix='_conv')
# concat
concat = mx.sym.Concat(*[tower_1x1, tower_d3_a, tower_d3_b, tower_3x3_d3_a, tower_3x3_d3_b, cproj], name='ch_concat_%s_chconcat' % name)
return concat
def get_symbol(num_classes=1000, **kwargs):
data = mx.sym.Variable(name="data")
# stage 1
conv = Conv(data, 32, kernel=(3, 3), stride=(2, 2), name="conv")
conv_1 = Conv(conv, 32, kernel=(3, 3), name="conv_1")
conv_2 = Conv(conv_1, 64, kernel=(3, 3), pad=(1, 1), name="conv_2")
pool = mx.sym.Pooling(data=conv_2, kernel=(3, 3), stride=(2, 2), pool_type="max", name="pool")
# stage 2
conv_3 = Conv(pool, 80, kernel=(1, 1), name="conv_3")
conv_4 = Conv(conv_3, 192, kernel=(3, 3), name="conv_4")
pool1 = mx.sym.Pooling(data=conv_4, kernel=(3, 3), stride=(2, 2), pool_type="max", name="pool1")
# stage 3
in3a = Inception7A(pool1, 64,
64, 96, 96,
48, 64,
"avg", 32, "mixed")
in3b = Inception7A(in3a, 64,
64, 96, 96,
48, 64,
"avg", 64, "mixed_1")
in3c = Inception7A(in3b, 64,
64, 96, 96,
48, 64,
"avg", 64, "mixed_2")
in3d = Inception7B(in3c, 384,
64, 96, 96,
"max", "mixed_3")
# stage 4
in4a = Inception7C(in3d, 192,
128, 128, 192,
128, 128, 128, 128, 192,
"avg", 192, "mixed_4")
in4b = Inception7C(in4a, 192,
160, 160, 192,
160, 160, 160, 160, 192,
"avg", 192, "mixed_5")
in4c = Inception7C(in4b, 192,
160, 160, 192,
160, 160, 160, 160, 192,
"avg", 192, "mixed_6")
in4d = Inception7C(in4c, 192,
192, 192, 192,
192, 192, 192, 192, 192,
"avg", 192, "mixed_7")
in4e = Inception7D(in4d, 192, 320,
192, 192, 192, 192,
"max", "mixed_8")
# stage 5
in5a = Inception7E(in4e, 320,
384, 384, 384,
448, 384, 384, 384,
"avg", 192, "mixed_9")
in5b = Inception7E(in5a, 320,
384, 384, 384,
448, 384, 384, 384,
"max", 192, "mixed_10")
# pool
pool = mx.sym.Pooling(data=in5b, kernel=(8, 8), stride=(1, 1), pool_type="avg", name="global_pool")
flatten = mx.sym.Flatten(data=pool, name="flatten")
fc1 = mx.sym.FullyConnected(data=flatten, num_hidden=num_classes, name='fc1', flatten=False)
softmax = mx.sym.SoftmaxOutput(data=fc1, name='softmax')
return softmax
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
a simple multilayer perceptron
"""
import mxnet as mx
def get_symbol(num_classes=10, **kwargs):
data = mx.symbol.Variable('data')
data = mx.sym.Flatten(data=data)
try:
fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128, flatten=False)
act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu")
fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64, flatten=False)
act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu")
fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=num_classes, flatten=False)
mlp = mx.symbol.softmax(data = fc3, name = 'softmax')
except:
fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128)
act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu")
fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64)
act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu")
fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=num_classes)
mlp = mx.symbol.softmax(data = fc3, name = 'softmax')
return mlp
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
'''
Adapted from https://github.com/tornadomeet/ResNet/blob/master/symbol_resnet.py
Original author Wei Wu
Implemented the following paper:
Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. "Identity Mappings in Deep Residual Networks"
'''
import mxnet as mx
import numpy as np
def residual_unit(data, num_filter, stride, dim_match, name, bottle_neck=True, bn_mom=0.9, workspace=256, memonger=False):
"""Return ResNet Unit symbol for building ResNet
Parameters
----------
data : str
Input data
num_filter : int
Number of output channels
bnf : int
Bottleneck channel factor relative to num_filter
stride : tuple
Stride used in convolution
dim_match : Boolean
True if the number of channels is the same for input and output, otherwise they differ
name : str
Base name of the operators
workspace : int
Workspace used in convolution operator
"""
if bottle_neck:
bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn1')
act1 = mx.sym.Activation(data=bn1, act_type='relu', name=name + '_relu1')
conv1 = mx.sym.Convolution(data=act1, num_filter=int(num_filter*0.25), kernel=(1,1), stride=stride, pad=(0,0),
no_bias=True, workspace=workspace, name=name + '_conv1')
bn2 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn2')
act2 = mx.sym.Activation(data=bn2, act_type='relu', name=name + '_relu2')
conv2 = mx.sym.Convolution(data=act2, num_filter=int(num_filter*0.25), kernel=(3,3), stride=(1,1), pad=(1,1),
no_bias=True, workspace=workspace, name=name + '_conv2')
bn3 = mx.sym.BatchNorm(data=conv2, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn3')
act3 = mx.sym.Activation(data=bn3, act_type='relu', name=name + '_relu3')
conv3 = mx.sym.Convolution(data=act3, num_filter=num_filter, kernel=(1,1), stride=(1,1), pad=(0,0), no_bias=True,
workspace=workspace, name=name + '_conv3')
if dim_match:
shortcut = data
else:
shortcut = mx.sym.Convolution(data=act1, num_filter=num_filter, kernel=(1,1), stride=stride, no_bias=True,
workspace=workspace, name=name+'_sc')
if memonger:
shortcut._set_attr(mirror_stage='True')
return conv3 + shortcut
else:
bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_bn1')
act1 = mx.sym.Activation(data=bn1, act_type='relu', name=name + '_relu1')
conv1 = mx.sym.Convolution(data=act1, num_filter=num_filter, kernel=(3,3), stride=stride, pad=(1,1),
no_bias=True, workspace=workspace, name=name + '_conv1')
bn2 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_bn2')
act2 = mx.sym.Activation(data=bn2, act_type='relu', name=name + '_relu2')
conv2 = mx.sym.Convolution(data=act2, num_filter=num_filter, kernel=(3,3), stride=(1,1), pad=(1,1),
no_bias=True, workspace=workspace, name=name + '_conv2')
if dim_match:
shortcut = data
else:
shortcut = mx.sym.Convolution(data=act1, num_filter=num_filter, kernel=(1,1), stride=stride, no_bias=True,
workspace=workspace, name=name+'_sc')
if memonger:
shortcut._set_attr(mirror_stage='True')
return conv2 + shortcut
def resnet(units, num_stages, filter_list, num_classes, image_shape, bottle_neck=True, bn_mom=0.9, workspace=256, dtype='float32', memonger=False):
"""Return ResNet symbol of
Parameters
----------
units : list
Number of units in each stage
num_stages : int
Number of stages
filter_list : list
Channel size of each stage
num_classes : int
Output size of symbol
dataset : str
Dataset type; only cifar10 and imagenet are supported
workspace : int
Workspace used in convolution operator
dtype : str
Precision (float32 or float16)
"""
num_unit = len(units)
assert(num_unit == num_stages)
data = mx.sym.Variable(name='data')
if dtype == 'float32':
# data = mx.sym.identity(data=data, name='id')
data = data
else:
if dtype == 'float16':
data = mx.sym.Cast(data=data, dtype=np.float16)
data = mx.sym.BatchNorm(data=data, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='bn_data')
(nchannel, height, width) = image_shape
if height <= 32: # such as cifar10
body = mx.sym.Convolution(data=data, num_filter=filter_list[0], kernel=(3, 3), stride=(1,1), pad=(1, 1),
no_bias=True, name="conv0", workspace=workspace)
else: # often expected to be 224 such as imagenet
body = mx.sym.Convolution(data=data, num_filter=filter_list[0], kernel=(7, 7), stride=(2,2), pad=(3, 3),
no_bias=True, name="conv0", workspace=workspace)
body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='bn0')
body = mx.sym.Activation(data=body, act_type='relu', name='relu0')
body = mx.sym.Pooling(data=body, kernel=(3, 3), stride=(2,2), pad=(1,1), pool_type='max')
for i in range(num_stages):
body = residual_unit(body, filter_list[i+1], (1 if i==0 else 2, 1 if i==0 else 2), False,
name='stage%d_unit%d' % (i + 1, 1), bottle_neck=bottle_neck, workspace=workspace,
memonger=memonger)
for j in range(units[i]-1):
body = residual_unit(body, filter_list[i+1], (1,1), True, name='stage%d_unit%d' % (i + 1, j + 2),
bottle_neck=bottle_neck, workspace=workspace, memonger=memonger)
bn1 = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='bn1')
relu1 = mx.sym.Activation(data=bn1, act_type='relu', name='relu1')
# Although kernel is not used here when global_pool=True, we should put one
pool1 = mx.sym.Pooling(data=relu1, global_pool=True, kernel=(7, 7), pool_type='avg', name='pool1')
flat = mx.sym.Flatten(data=pool1)
try:
fc1 = mx.sym.FullyConnected(data=flat, num_hidden=num_classes, name='fc1', flatten=False)
except:
fc1 = mx.sym.FullyConnected(data=flat, num_hidden=num_classes, name='fc1')
if dtype == 'float16':
fc1 = mx.sym.Cast(data=fc1, dtype=np.float32)
return mx.sym.softmax(data=fc1, name='softmax')
def get_symbol(num_classes, num_layers, image_shape, conv_workspace=256, dtype='float32', **kwargs):
"""
Adapted from https://github.com/tornadomeet/ResNet/blob/master/train_resnet.py
Original author Wei Wu
"""
image_shape = [int(l) for l in image_shape.split(',')]
(nchannel, height, width) = image_shape
if height <= 28:
num_stages = 3
if (num_layers-2) % 9 == 0 and num_layers >= 164:
per_unit = [(num_layers-2)//9]
filter_list = [16, 64, 128, 256]
bottle_neck = True
elif (num_layers-2) % 6 == 0 and num_layers < 164:
per_unit = [(num_layers-2)//6]
filter_list = [16, 16, 32, 64]
bottle_neck = False
else:
raise ValueError("no experiments done on num_layers {}, you can do it yourself".format(num_layers))
units = per_unit * num_stages
else:
if num_layers >= 50:
filter_list = [64, 256, 512, 1024, 2048]
bottle_neck = True
else:
filter_list = [64, 64, 128, 256, 512]
bottle_neck = False
num_stages = 4
if num_layers == 18:
units = [2, 2, 2, 2]
elif num_layers == 34:
units = [3, 4, 6, 3]
elif num_layers == 50:
units = [3, 4, 6, 3]
elif num_layers == 101:
units = [3, 4, 23, 3]
elif num_layers == 152:
units = [3, 8, 36, 3]
elif num_layers == 200:
units = [3, 24, 36, 3]
elif num_layers == 269:
units = [3, 30, 48, 8]
else:
raise ValueError("no experiments done on num_layers {}, you can do it yourself".format(num_layers))
return resnet(units = units,
num_stages = num_stages,
filter_list = filter_list,
num_classes = num_classes,
image_shape = image_shape,
bottle_neck = bottle_neck,
workspace = conv_workspace,
dtype = dtype)
"""
Symbol of SqueezeNet
Reference:
Iandola, Forrest N., et al.
"Squeezenet: Alexnet-level accuracy with 50x fewer parameters and< 0.5 mb model size." (2016).
"""
import mxnet as mx
# Helpers
def _make_fire(net, squeeze_channels, expand1x1_channels, expand3x3_channels):
net = _make_fire_conv(net, squeeze_channels, 1, 0)
left = _make_fire_conv(net, expand1x1_channels, 1, 0)
right = _make_fire_conv(net, expand3x3_channels, 3, 1)
# NOTE : Assume NCHW layout here
net = mx.sym.concat(left, right, dim=1)
return net
def _make_fire_conv(net, channels, kernel_size, padding=0):
net = mx.sym.Convolution(net, num_filter=channels, kernel=(kernel_size, kernel_size),
pad=(padding, padding))
net = mx.sym.Activation(net, act_type='relu')
return net
# Net
def get_symbol(num_classes=1000, version='1.0', **kwargs):
"""Get symbol of SqueezeNet
Parameters
----------
num_classes: int
The number of classification results
version : str, optional
"1.0" or "1.1" of SqueezeNet
"""
assert version in ['1.0', '1.1'], ("Unsupported SqueezeNet version {version}:"
"1.0 or 1.1 expected".format(version=version))
net = mx.sym.Variable("data")
if version == '1.0':
net = mx.sym.Convolution(net, num_filter=96, kernel=(7, 7), stride=(2, 2), pad=(3, 3))
net = mx.sym.Activation(net, act_type='relu')
net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type='max', stride=(2, 2))
net = _make_fire(net, 16, 64, 64)
net = _make_fire(net, 16, 64, 64)
net = _make_fire(net, 32, 128, 128)
net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type='max', stride=(2, 2))
net = _make_fire(net, 32, 128, 128)
net = _make_fire(net, 48, 192, 192)
net = _make_fire(net, 48, 192, 192)
net = _make_fire(net, 64, 256, 256)
net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type='max', stride=(2, 2))
net = _make_fire(net, 64, 256, 256)
else:
net = mx.sym.Convolution(net, num_filter=64, kernel=(3, 3), stride=(2, 2), pad=(1, 1))
net = mx.sym.Activation(net, act_type='relu')
net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type='max', stride=(2, 2))
net = _make_fire(net, 16, 64, 64)
net = _make_fire(net, 16, 64, 64)
net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type='max', stride=(2, 2))
net = _make_fire(net, 32, 128, 128)
net = _make_fire(net, 32, 128, 128)
net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type='max', stride=(2, 2))
net = _make_fire(net, 48, 192, 192)
net = _make_fire(net, 48, 192, 192)
net = _make_fire(net, 64, 256, 256)
net = _make_fire(net, 64, 256, 256)
net = mx.sym.Dropout(net, p=0.5)
net = mx.sym.Convolution(net, num_filter=num_classes, kernel=(1, 1))
net = mx.sym.Activation(net, act_type='relu')
net = mx.sym.Pooling(data=net, global_pool=True, kernel=(13, 13), pool_type='avg')
net = mx.sym.flatten(net)
return mx.sym.softmax(net)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""References:
Simonyan, Karen, and Andrew Zisserman. "Very deep convolutional networks for
large-scale image recognition." arXiv preprint arXiv:1409.1556 (2014).
"""
import mxnet as mx
import numpy as np
def get_feature(internal_layer, layers, filters, batch_norm = False, **kwargs):
for i, num in enumerate(layers):
for j in range(num):
internal_layer = mx.sym.Convolution(data = internal_layer, kernel=(3, 3), pad=(1, 1), num_filter=filters[i], name="conv%s_%s" %(i + 1, j + 1))
if batch_norm:
internal_layer = mx.symbol.BatchNorm(data=internal_layer, name="bn%s_%s" %(i + 1, j + 1))
internal_layer = mx.sym.Activation(data=internal_layer, act_type="relu", name="relu%s_%s" %(i + 1, j + 1))
internal_layer = mx.sym.Pooling(data=internal_layer, pool_type="max", kernel=(2, 2), stride=(2,2), name="pool%s" %(i + 1))
return internal_layer
def get_classifier(input_data, num_classes, **kwargs):
flatten = mx.sym.Flatten(data=input_data, name="flatten")
    try:
        # MXNet >= 0.11.1 accepts the flatten keyword on FullyConnected.
        fc6 = mx.sym.FullyConnected(data=flatten, num_hidden=4096, name="fc6", flatten=False)
relu6 = mx.sym.Activation(data=fc6, act_type="relu", name="relu6")
drop6 = mx.sym.Dropout(data=relu6, p=0.5, name="drop6")
fc7 = mx.sym.FullyConnected(data=drop6, num_hidden=4096, name="fc7", flatten=False)
relu7 = mx.sym.Activation(data=fc7, act_type="relu", name="relu7")
drop7 = mx.sym.Dropout(data=relu7, p=0.5, name="drop7")
fc8 = mx.sym.FullyConnected(data=drop7, num_hidden=num_classes, name="fc8", flatten=False)
    except:  # pylint: disable=bare-except
        # Older MXNet versions do not accept the flatten keyword; fall back without it.
        fc6 = mx.sym.FullyConnected(data=flatten, num_hidden=4096, name="fc6")
relu6 = mx.sym.Activation(data=fc6, act_type="relu", name="relu6")
drop6 = mx.sym.Dropout(data=relu6, p=0.5, name="drop6")
fc7 = mx.sym.FullyConnected(data=drop6, num_hidden=4096, name="fc7")
relu7 = mx.sym.Activation(data=fc7, act_type="relu", name="relu7")
drop7 = mx.sym.Dropout(data=relu7, p=0.5, name="drop7")
fc8 = mx.sym.FullyConnected(data=drop7, num_hidden=num_classes, name="fc8")
return fc8
def get_symbol(num_classes, num_layers=11, batch_norm=False, dtype='float32', **kwargs):
"""
Parameters
----------
    num_classes : int
        Number of classification classes.
    num_layers : int
        Number of layers for the VGG variant. Options are 11, 13, 16, 19.
batch_norm : bool, default False
Use batch normalization.
dtype: str, float32 or float16
Data precision.
"""
vgg_spec = {11: ([1, 1, 2, 2, 2], [64, 128, 256, 512, 512]),
13: ([2, 2, 2, 2, 2], [64, 128, 256, 512, 512]),
16: ([2, 2, 3, 3, 3], [64, 128, 256, 512, 512]),
19: ([2, 2, 4, 4, 4], [64, 128, 256, 512, 512])}
if num_layers not in vgg_spec:
raise ValueError("Invalide num_layers {}. Possible choices are 11,13,16,19.".format(num_layers))
layers, filters = vgg_spec[num_layers]
data = mx.sym.Variable(name="data")
if dtype == 'float16':
data = mx.sym.Cast(data=data, dtype=np.float16)
feature = get_feature(data, layers, filters, batch_norm)
classifier = get_classifier(feature, num_classes)
if dtype == 'float16':
classifier = mx.sym.Cast(data=classifier, dtype=np.float32)
symbol = mx.sym.softmax(data=classifier, name='softmax')
return symbol
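A small sketch, assuming only what the function above defines: instantiating the VGG-16 variant with batch normalization and checking the output shape through MXNet's symbol shape inference.

# VGG-16 with batch normalization; "data" is the input Variable declared in get_symbol.
vgg16_sym = get_symbol(num_classes=1000, num_layers=16, batch_norm=True)
_, out_shapes, _ = vgg16_sym.infer_shape(data=(1, 3, 224, 224))
print(out_shapes)  # expected to be [(1, 1000)]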
import numpy as np
import tvm
from tvm.contrib import graph_runtime
from tvm.relay.testing.config import ctx_list
from tvm import relay
import mxnet as mx
from mxnet import gluon
from mxnet.gluon.model_zoo import vision
import model_zoo
def verify_mxnet_frontend_impl(mx_symbol,
data_shape=(1, 3, 224, 224),
out_shape=(1, 1000),
gluon_impl=False,
name=None,
dtype='float32'):
"""Use name different from test to avoid let nose pick it up"""
if gluon_impl:
def get_gluon_output(name, x):
net = vision.get_model(name)
net.collect_params().initialize(mx.init.Xavier())
net_sym = gluon.nn.SymbolBlock(outputs=net(mx.sym.var('data')),
inputs=mx.sym.var('data'),
params=net.collect_params())
out = net_sym(mx.nd.array(x.astype(dtype))).asnumpy()
return out, net_sym
else:
def get_mxnet_output(symbol, x, dtype='float32'):
from collections import namedtuple
Batch = namedtuple('Batch', ['data'])
mod = mx.mod.Module(symbol, label_names=None)
mod.bind(data_shapes=[('data', x.shape)], for_training=False)
mod.init_params()
mod.forward(Batch([mx.nd.array(x.astype(dtype))]))
out = mod.get_outputs()[0].asnumpy()
args, auxs = mod.get_params()
return out, args, auxs
def get_tvm_output(symbol, x, args, auxs, target, ctx, dtype='float32'):
shape_dict = {"data": x.shape}
if gluon_impl:
new_sym, params = relay.frontend.from_mxnet(symbol, shape_dict)
else:
new_sym, params = relay.frontend.from_mxnet(symbol,
shape_dict,
arg_params=args,
aux_params=auxs)
with relay.build_config(opt_level=3):
graph, lib, params = relay.build(new_sym, target, params=params)
m = graph_runtime.create(graph, lib, ctx)
# set inputs
m.set_input("data", tvm.nd.array(x.astype(dtype)))
m.set_input(**params)
m.run()
# get outputs
out = m.get_output(0, tvm.nd.empty(out_shape, dtype))
return out.asnumpy()
# random input
x = np.random.uniform(size=data_shape)
if gluon_impl:
gluon_out, gluon_sym = get_gluon_output(name, x)
for target, ctx in ctx_list():
tvm_out = get_tvm_output(gluon_sym, x, None, None, target, ctx, dtype)
tvm.testing.assert_allclose(gluon_out, tvm_out, rtol=1e-5, atol=1e-5)
else:
mx_out, args, auxs = get_mxnet_output(mx_symbol, x, dtype)
assert "data" not in args
for target, ctx in ctx_list():
tvm_out = get_tvm_output(mx_symbol, x, args, auxs, target, ctx, dtype)
tvm.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5, atol=1e-5)
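A hedged sketch of the Gluon path of the helper above, which the symbol-based tests below do not exercise; "resnet18_v1" is only an illustrative model-zoo name, and the first argument is unused when gluon_impl=True.

def _forward_gluon_resnet_sketch():
    # Not a real test: exercises gluon_impl=True, where the symbol argument is
    # ignored and the network is pulled from mxnet.gluon.model_zoo.vision by name.
    verify_mxnet_frontend_impl(None,
                               data_shape=(1, 3, 224, 224),
                               out_shape=(1, 1000),
                               gluon_impl=True,
                               name='resnet18_v1')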
def test_forward_mlp():
mlp = model_zoo.mx_mlp()
verify_mxnet_frontend_impl(mlp,
data_shape=(1, 1, 28, 28),
out_shape=(1, 10))
def test_forward_vgg():
for n in [11]:
mx_sym = model_zoo.mx_vgg(n)
verify_mxnet_frontend_impl(mx_sym)
def test_forward_resnet():
for n in [18]:
        mx_sym = model_zoo.mx_resnet(n)
verify_mxnet_frontend_impl(mx_sym)
def test_forward_elu():
data = mx.sym.var('data')
data = mx.sym.concat(data, -data, dim=1) # negative part explicitly
mx_sym = mx.sym.LeakyReLU(data, act_type='elu')
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
def test_forward_rrelu():
data = mx.sym.var('data')
data = mx.sym.concat(data, -data, dim=1) # negative part explicitly
mx_sym = mx.sym.LeakyReLU(data, act_type='rrelu', lower_bound=0.3, upper_bound=0.7)
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
def test_forward_prelu():
data = mx.sym.var('data')
data = mx.sym.concat(data, -data, dim=1) # negative part explicitly
mx_sym = mx.sym.LeakyReLU(data, act_type='prelu')
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
def test_forward_softrelu():
data = mx.sym.var('data')
data = mx.sym.concat(data, -data, dim=1) # negative part explicitly
mx_sym = mx.sym.Activation(data, act_type='softrelu')
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
def test_forward_fc_flatten():
# test flatten=True option in mxnet 0.11.1
data = mx.sym.var('data')
try:
mx_sym = mx.sym.FullyConnected(data, num_hidden=100, flatten=True)
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 100))
mx_sym = mx.sym.FullyConnected(mx.sym.Flatten(data), num_hidden=100, flatten=False)
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 100))
    except:  # pylint: disable=bare-except
        # Older MXNet versions do not support the flatten option; skip silently.
        pass
def test_forward_clip():
data = mx.sym.var('data')
    data = mx.sym.concat(data, -data, dim=1)  # negative part explicitly
mx_sym = mx.sym.clip(data, a_min=0, a_max=1)
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
def test_forward_split():
data = mx.sym.var('data')
mx_sym = mx.sym.split(data, axis=1, num_outputs=4, squeeze_axis=False)
verify_mxnet_frontend_impl(mx_sym, (1, 4, 2, 1), (1, 1, 2, 1))
def test_forward_split_squeeze():
data = mx.sym.var('data')
mx_sym = mx.sym.split(data, axis=1, num_outputs=4, squeeze_axis=True)
verify_mxnet_frontend_impl(mx_sym, (1, 4, 2, 1), (1, 2, 1))
def test_forward_expand_dims():
data = mx.sym.var('data')
mx_sym = mx.sym.expand_dims(data, axis=1)
verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 1, 3, 4))
def test_forward_pooling():
data = mx.sym.var('data')
mx_sym = mx.sym.Pooling(data, kernel=(3, 3), pad=(1, 1), pool_type='avg')
verify_mxnet_frontend_impl(mx_sym, (1, 20, 8, 8), (1, 20, 8, 8))
mx_sym = mx.sym.Pooling(data, kernel=(3, 3), pad=(1, 1), pool_type='max')
verify_mxnet_frontend_impl(mx_sym, (1, 20, 8, 8), (1, 20, 8, 8))
def test_forward_lrn():
data = mx.sym.var('data')
mx_sym = mx.sym.LRN(data, alpha=2, beta=2, knorm=1, nsize=5)
verify_mxnet_frontend_impl(mx_sym, (1, 10, 24, 24), (1, 10, 24, 24))
def test_forward_ones():
data = mx.sym.var('data')
ones = mx.sym.ones(shape=(2, 3, 4), dtype='float32')
mx_sym = mx.sym.elemwise_add(data, ones)
verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4))
def test_forward_zeros():
data = mx.sym.var('data')
zeros = mx.sym.zeros(shape=(2, 3, 4), dtype='float32')
mx_sym = mx.sym.elemwise_add(data, zeros)
verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4))
def test_forward_ones_like():
data = mx.sym.var('data')
mx_sym = mx.sym.ones_like(data, dtype='float32')
verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4))
def test_forward_zeros_like():
data = mx.sym.var('data')
mx_sym = mx.sym.zeros_like(data, dtype='float32')
verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4))
def test_forward_argmax():
data = mx.sym.var('data')
mx_sym = mx.sym.argmax(data, axis=1)
verify_mxnet_frontend_impl(mx_sym, (5, 3), (5,))
def test_forward_argmin():
data = mx.sym.var('data')
mx_sym = mx.sym.argmin(data, axis=0)
verify_mxnet_frontend_impl(mx_sym, (5, 4), (4,))
if __name__ == '__main__':
test_forward_mlp()
test_forward_vgg()
test_forward_resnet()
test_forward_elu()
test_forward_rrelu()
test_forward_prelu()
test_forward_softrelu()
test_forward_fc_flatten()
test_forward_clip()
test_forward_split()
test_forward_split_squeeze()
test_forward_expand_dims()
test_forward_pooling()
test_forward_lrn()
test_forward_ones()
test_forward_zeros()
test_forward_ones_like()
test_forward_zeros_like()
test_forward_argmax()
test_forward_argmin()
import mxnet as mx
from tvm import relay
import model_zoo
def compare_graph(f1, f2):
f1 = relay.ir_pass.infer_type(f1)
f2 = relay.ir_pass.infer_type(f2)
assert relay.ir_pass.alpha_equal(f1, f2)
def test_mlp():
shape = {"data": (1, 1, 28, 28)}
mx_fun = model_zoo.mx_mlp()
from_mx_fun, _ = relay.frontend.from_mxnet(mx_fun, shape=shape)
relay_fun = model_zoo.relay_mlp()
compare_graph(from_mx_fun, relay_fun)
def test_vgg():
shape = {"data": (1, 3, 224, 224)}
for n in [11, 13, 16, 19]:
mx_sym = model_zoo.mx_vgg(n)
from_mx_sym, _ = relay.frontend.from_mxnet(mx_sym, shape=shape)
relay_sym = model_zoo.relay_vgg(n)
compare_graph(from_mx_sym, relay_sym)
def test_resnet():
shape = {"data": (1, 3, 224, 224)}
for n in [18, 34, 50, 101]:
mx_sym = model_zoo.mx_resnet(n)
from_mx_sym, _ = relay.frontend.from_mxnet(mx_sym, shape=shape)
relay_sym = model_zoo.relay_resnet(n)
compare_graph(from_mx_sym, relay_sym)
def test_squeezenet():
shape = {"data": (1, 3, 224, 224)}
for version in ['1.0', '1.1']:
mx_sym = model_zoo.mx_squeezenet(version)
from_mx_sym, _ = relay.frontend.from_mxnet(mx_sym, shape)
relay_sym = model_zoo.relay_squeezenet(version)
compare_graph(from_mx_sym, relay_sym)
def test_inception_v3():
shape = {"data": (1, 3, 299, 299)}
mx_sym = model_zoo.mx_inception_v3()
from_mx_sym, _ = relay.frontend.from_mxnet(mx_sym, shape)
relay_sym = model_zoo.relay_inception_v3()
compare_graph(from_mx_sym, relay_sym)
def test_dqn():
shape = {"data": (1, 4, 84, 84)}
mx_sym = model_zoo.mx_dqn()
from_mx_sym, _ = relay.frontend.from_mxnet(mx_sym, shape)
relay_sym = model_zoo.relay_dqn()
compare_graph(from_mx_sym, relay_sym)
def test_dcgan():
shape = {"data": (2, 100)}
mx_sym = model_zoo.mx_dcgan()
from_mx_sym, _ = relay.frontend.from_mxnet(mx_sym, shape)
relay_sym = model_zoo.relay_dcgan(batch_size=2)
compare_graph(from_mx_sym, relay_sym)
def test_multi_outputs():
xshape = (10, 27)
yshape = (10, 9)
def mx_compose(F, **kwargs):
x = F.sym.Variable("x")
y = F.sym.Variable("y")
z = F.sym.split(x, **kwargs)
return F.sym.broadcast_sub(F.sym.broadcast_add(z[0], z[2]), y)
def relay_compose(F, **kwargs):
x = F.var("x", shape=xshape)
y = F.var("y", shape=yshape)
z = F.split(x, **kwargs)
z = F.subtract(F.add(z[0], z[2]), y)
return relay.Function(relay.ir_pass.free_vars(z), z)
mx_sym = mx_compose(mx, num_outputs=3, axis=1)
from_mx_sym, _ = relay.frontend.from_mxnet(
mx_sym, shape={"x":xshape, "y":yshape})
relay_sym = relay_compose(relay, indices_or_sections=3, axis=1)
compare_graph(from_mx_sym, relay_sym)
if __name__ == "__main__":
test_mlp()
test_resnet()
test_vgg()
test_multi_outputs()
test_dqn()
test_dcgan()
test_squeezenet()
test_inception_v3()
......@@ -115,7 +115,7 @@ def test_squeeze_bad_axes_infer_type():
def test_reshape_infer_type():
n, t, d1, d2 = tvm.var("n"), tvm.var("t"), 100, 20
n, t, d1, d2 = 10, 20, 100, 20
x = relay.var("x", relay.TensorType((n, t, d1, d2), "float32"))
y = relay.reshape(x, newshape=(n, t, 2000))
assert "newshape=" in y.astext()
......
......@@ -332,7 +332,7 @@ inline Tensor concatenate(const Array<Tensor>& inputs,
* \return A Tensor whose op member is the split operation
*/
inline Array<Tensor> split(const Tensor& x,
Array<Expr> split_indices,
Array<Integer> split_indices,
int axis,
std::string name = "tensor",
std::string tag = kInjective) {
......@@ -342,14 +342,15 @@ inline Array<Tensor> split(const Tensor& x,
CHECK_LT(axis, x->shape.size()) << "axis out of bounds";
auto src_axis_size = static_cast<int>(GetConstInt(x->shape[axis]));
auto split_indices_val = GetConstIntValues(split_indices, "split_indices");
CHECK(std::is_sorted(split_indices_val.begin(), split_indices_val.end())) <<
"split_indices must be sorted";
std::vector<int> begin_ids;
begin_ids.push_back(0);
std::copy(split_indices_val.begin(), split_indices_val.end(), std::back_inserter(begin_ids));
for (Integer idx : split_indices) {
int val = static_cast<int>(idx->value);
CHECK_GT(val, begin_ids.back())
<< "split_indices must be sorted";
begin_ids.push_back(val);
}
Array< Array<Expr> > out_shapes;
for (size_t i = 0; i < begin_ids.size(); ++i) {
......@@ -508,10 +509,10 @@ inline Tensor strided_slice(const Tensor& x,
* \return A Tensor whose op member is the split operation
*/
inline Array<Tensor> split_sections(const Tensor& x,
int num_sections,
int axis,
std::string name = "tensor",
std::string tag = kInjective) {
int num_sections,
int axis,
std::string name = "tensor",
std::string tag = kInjective) {
if (axis < 0) {
axis += static_cast<int>(x->shape.size());
}
......@@ -524,7 +525,7 @@ inline Array<Tensor> split_sections(const Tensor& x,
<< "num_sections must be an integer factor of the size of axis " << axis
<< " (" << src_axis_size << ")";
Array<Expr> split_indices;
Array<Integer> split_indices;
auto seg_size = src_axis_size / num_sections;
for (int i = 0; i < num_sections; ++i) {
// region at index 0 is added by split()
......
......@@ -53,7 +53,7 @@ inline Tensor region(const Tensor &data,
input_shape[2],
input_shape[3]};
auto data_block = reshape(data, intermediate_shape);
Array <Expr> split_indices;
Array <Integer> split_indices;
for (int i = 1; i < split_size; ++i) {
split_indices.push_back(i);
}
......