Commit 426e3bb0 by Lianmin Zheng, committed by Tianqi Chen

[RELAY] Port winograd ops to relay (#2356)

parent 9b4b360f
...@@ -38,7 +38,7 @@ struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> { ...@@ -38,7 +38,7 @@ struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
IndexExpr channels; IndexExpr channels;
Array<IndexExpr> kernel_size; Array<IndexExpr> kernel_size;
std::string data_layout; std::string data_layout;
std::string weight_layout; std::string kernel_layout;
std::string out_layout; std::string out_layout;
DataType out_dtype; DataType out_dtype;
...@@ -68,7 +68,7 @@ struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> { ...@@ -68,7 +68,7 @@ struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Convolution is applied on the 'H' and" "dimensions respectively. Convolution is applied on the 'H' and"
"'W' dimensions."); "'W' dimensions.");
TVM_ATTR_FIELD(weight_layout).set_default("OIHW") TVM_ATTR_FIELD(kernel_layout).set_default("OIHW")
.describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc." .describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc."
"'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
"dimensions respectively."); "dimensions respectively.");
...@@ -84,13 +84,85 @@ struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> { ...@@ -84,13 +84,85 @@ struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
} }
}; };
/*! \brief Attributes used in winograd weight transformation operators */
struct Conv2DWinogradWeightTransformAttrs :
public tvm::AttrsNode<Conv2DWinogradWeightTransformAttrs> {
int tile_size;
TVM_DECLARE_ATTRS(Conv2DWinogradWeightTransformAttrs,
"relay.attrs.Conv2DWinogradWeightTransformAttrs") {
TVM_ATTR_FIELD(tile_size)
.describe("Tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)");
}
};
/*! \brief Attributes used in convolution operators with winograd algorithm */
struct Conv2DWinogradAttrs : public tvm::AttrsNode<Conv2DWinogradAttrs> {
int tile_size;
Array<IndexExpr> strides;
Array<IndexExpr> padding;
Array<IndexExpr> dilation;
int groups;
IndexExpr channels;
Array<IndexExpr> kernel_size;
std::string data_layout;
std::string kernel_layout;
std::string out_layout;
DataType out_dtype;
TVM_DECLARE_ATTRS(Conv2DWinogradAttrs, "relay.attrs.Conv2DWinogradAttrs") {
TVM_ATTR_FIELD(tile_size)
.describe("The tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)");
TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1}))
.describe("Specifies the strides of the convolution.");
TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0}))
.describe("If padding is non-zero, then the input is implicitly zero-padded"
"on both sides for padding number of points");
TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, 1}))
.describe("Specifies the dilation rate to use for dilated convolution.");
TVM_ATTR_FIELD(groups).set_default(1)
.describe("Controls the connections between inputs and outputs."
"At groups=1, all inputs are convolved to all outputs."
"At groups=2, the operation becomes equivalent to having two convolution"
"layers side by side, each seeing half the input channels, and producing"
"half the output channels, and both subsequently concatenated.");
TVM_ATTR_FIELD(channels)
.describe("The number of output channels in the convolution."
" If it is not set, inferred by shape of the weight.")
.set_default(NullValue<IndexExpr>());
TVM_ATTR_FIELD(kernel_size)
.describe("Specifies the dimensions of the convolution window.")
.set_default(NullValue<Array<IndexExpr> >());
TVM_ATTR_FIELD(data_layout).set_default("NCHW")
.describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Convolution is applied on the 'H' and"
"'W' dimensions.");
TVM_ATTR_FIELD(kernel_layout).set_default("OIHW")
.describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc."
"'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
"dimensions respectively.");
TVM_ATTR_FIELD(out_layout).set_default("")
.describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Default to be same as input layout.");
// use 0 bits to indicate none.
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");
}
};
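For reference, in the Winograd notation F(m x m, r x r) used by the tile_size descriptions above, m is the output tile size and r is the kernel size; the transformed tile has side length m + r - 1, so F(2x2, 3x3) works on 4x4 transformed tiles and F(4x4, 3x3) on 6x6 tiles (a general property of Winograd convolution, not something introduced by this change).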
/*! \brief Attributes used in softmax operators */ /*! \brief Attributes used in softmax operators */
struct SoftmaxAttrs : public tvm::AttrsNode<SoftmaxAttrs> { struct SoftmaxAttrs : public tvm::AttrsNode<SoftmaxAttrs> {
int axis; int axis;
TVM_DECLARE_ATTRS(SoftmaxAttrs, "relay.attrs.SoftmaxAttrs") { TVM_DECLARE_ATTRS(SoftmaxAttrs, "relay.attrs.SoftmaxAttrs") {
TVM_ATTR_FIELD(axis).set_default(-1) TVM_ATTR_FIELD(axis).set_default(-1)
.describe("The axis to sum over when computing softmax."); .describe("The axis to sum over when computing softmax.");
} }
}; };
...@@ -104,7 +176,7 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode<Conv2DTransposeAttrs> { ...@@ -104,7 +176,7 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode<Conv2DTransposeAttrs> {
Array<IndexExpr> dilation; Array<IndexExpr> dilation;
int groups; int groups;
std::string data_layout; std::string data_layout;
std::string weight_layout; std::string kernel_layout;
std::string out_layout; std::string out_layout;
DataType out_dtype; DataType out_dtype;
...@@ -136,7 +208,7 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode<Conv2DTransposeAttrs> { ...@@ -136,7 +208,7 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode<Conv2DTransposeAttrs> {
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Convolution is applied on the 'H' and" "dimensions respectively. Convolution is applied on the 'H' and"
"'W' dimensions."); "'W' dimensions.");
TVM_ATTR_FIELD(weight_layout).set_default("OIHW") TVM_ATTR_FIELD(kernel_layout).set_default("OIHW")
.describe("Dimension ordering of data and weight. Can be 'OIHW', 'OIHW16o16i', etc." .describe("Dimension ordering of data and weight. Can be 'OIHW', 'OIHW16o16i', etc."
"'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
"dimensions respectively."); "dimensions respectively.");
......
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-argument # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-argument
"""Convert an NNVM graph to Relay.""" """Convert an NNVM graph to Relay."""
import json import json
import numpy
from tvm import relay, nd from tvm import relay, nd
from tvm.relay import op, expr, var from tvm.relay import op, expr, var
from tvm.relay.frontend.common import StrAttrsDict from tvm.relay.frontend.common import StrAttrsDict
from tvm.relay.frontend.nnvm_common import _rename from tvm.relay.frontend.nnvm_common import _rename
import numpy
from .symbol import Symbol from .symbol import Symbol
from .compiler import graph_attr from .compiler import graph_attr
from .graph import create as graph_create from .graph import create as graph_create
...@@ -42,7 +43,7 @@ def _conv2d(children, attrs, odtype='float32'): ...@@ -42,7 +43,7 @@ def _conv2d(children, attrs, odtype='float32'):
dilation = attrs.get_int_tuple('dilation', (1, 1)) dilation = attrs.get_int_tuple('dilation', (1, 1))
groups = attrs.get_int('groups', 1) groups = attrs.get_int('groups', 1)
data_layout = attrs.get_str('layout', 'NCHW') data_layout = attrs.get_str('layout', 'NCHW')
weight_layout = attrs.get_str('kernel_layout', 'OIHW') kernel_layout = attrs.get_str('kernel_layout', 'OIHW')
out_layout = '' out_layout = ''
out_dtype = attrs.get_str('out_dtype', '') out_dtype = attrs.get_str('out_dtype', '')
...@@ -54,7 +55,7 @@ def _conv2d(children, attrs, odtype='float32'): ...@@ -54,7 +55,7 @@ def _conv2d(children, attrs, odtype='float32'):
dilation=dilation, dilation=dilation,
groups=groups, groups=groups,
data_layout=data_layout, data_layout=data_layout,
weight_layout=weight_layout, kernel_layout=kernel_layout,
out_layout=out_layout, out_layout=out_layout,
out_dtype=out_dtype) out_dtype=out_dtype)
...@@ -77,7 +78,7 @@ def _conv2d_transpose(children, attrs, odtype='float32'): ...@@ -77,7 +78,7 @@ def _conv2d_transpose(children, attrs, odtype='float32'):
dilation = attrs.get_int_tuple('dilation', (1, 1)) dilation = attrs.get_int_tuple('dilation', (1, 1))
groups = attrs.get_int('groups', 1) groups = attrs.get_int('groups', 1)
data_layout = attrs.get_str('layout', 'NCHW') data_layout = attrs.get_str('layout', 'NCHW')
weight_layout = attrs.get_str('kernel_layout', 'OIHW') kernel_layout = attrs.get_str('kernel_layout', 'OIHW')
out_dtype = attrs.get_str('out_dtype', '') out_dtype = attrs.get_str('out_dtype', '')
out_conv2d = op.nn.conv2d_transpose( out_conv2d = op.nn.conv2d_transpose(
...@@ -88,7 +89,7 @@ def _conv2d_transpose(children, attrs, odtype='float32'): ...@@ -88,7 +89,7 @@ def _conv2d_transpose(children, attrs, odtype='float32'):
dilation=dilation, dilation=dilation,
groups=groups, groups=groups,
data_layout=data_layout, data_layout=data_layout,
weight_layout=weight_layout, kernel_layout=kernel_layout,
out_dtype=out_dtype) out_dtype=out_dtype)
if use_bias: if use_bias:
......
...@@ -138,7 +138,7 @@ class AttrDict(object): ...@@ -138,7 +138,7 @@ class AttrDict(object):
else: else:
raise ValueError("Wrong bool format for key %s" % key) raise ValueError("Wrong bool format for key %s" % key)
def get_string(self, key): def get_str(self, key):
"""Get string from attr dict """Get string from attr dict
Parameters Parameters
......
...@@ -153,7 +153,25 @@ def schedule_conv2d(attrs, outs, target): ...@@ -153,7 +153,25 @@ def schedule_conv2d(attrs, outs, target):
@reg.register_alter_op_layout("conv2d") @reg.register_alter_op_layout("conv2d")
def alter_conv2d_layout(attrs, inputs, tinfos): def alter_conv2d_layout(attrs, inputs, tinfos):
return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos) """Replace conv2d op with other layouts or algorithms"""
import nnvm.symbol as sym
# map relay op names to nnvm op names
sym.contrib_conv2d_winograd_without_weight_transform = \
sym.contrib.conv2d_winograd_without_weight_transform
sym.contrib_conv2d_winograd_weight_transform = \
sym.contrib.conv2d_winograd_weight_transform
sym.nn = sym
# map relay argument names to nnvm argument names
raw_reshape = sym.reshape
def _reshape(*args, **kwargs):
if "newshape" in kwargs:
kwargs['shape'] = kwargs.pop('newshape')
return raw_reshape(*args, **kwargs)
sym.reshape = _reshape
return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, sym)
reg.register_pattern("conv2d", OpPattern.OUT_ELEMWISE_FUSABLE) reg.register_pattern("conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
...@@ -166,9 +184,9 @@ def compute_contrib_conv2d_NCHWc(attrs, inputs, _): ...@@ -166,9 +184,9 @@ def compute_contrib_conv2d_NCHWc(attrs, inputs, _):
dilation = attrs.get_int_tuple("dilation") dilation = attrs.get_int_tuple("dilation")
out_channel = attrs.get_int("channels") out_channel = attrs.get_int("channels")
groups = attrs.get_int("groups") groups = attrs.get_int("groups")
layout = attrs.get_string("layout") layout = attrs.get_str("layout")
out_layout = attrs.get_string("out_layout") out_layout = attrs.get_str("out_layout")
out_dtype = attrs.get_string("out_dtype") out_dtype = attrs.get_str("out_dtype")
out_dtype = inputs[0].dtype if out_dtype == "same" else out_dtype out_dtype = inputs[0].dtype if out_dtype == "same" else out_dtype
if layout == "NCHW": if layout == "NCHW":
_, in_channel, _, _ = get_const_tuple(inputs[0].shape) _, in_channel, _, _ = get_const_tuple(inputs[0].shape)
...@@ -227,8 +245,8 @@ def compute_contrib_conv2d_winograd_without_weight_transform(attrs, inputs, _): ...@@ -227,8 +245,8 @@ def compute_contrib_conv2d_winograd_without_weight_transform(attrs, inputs, _):
strides = attrs.get_int_tuple("strides") strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation") dilation = attrs.get_int_tuple("dilation")
groups = attrs.get_int("groups") groups = attrs.get_int("groups")
layout = attrs.get_string("layout") layout = attrs.get_str("layout")
out_dtype = attrs.get_string("out_dtype") out_dtype = attrs.get_str("out_dtype")
tile_size = attrs.get_int("tile_size") tile_size = attrs.get_int("tile_size")
out_dtype = inputs[0].dtype if out_dtype == "same" else out_dtype out_dtype = inputs[0].dtype if out_dtype == "same" else out_dtype
assert dilation == (1, 1), "Do not support dilate now" assert dilation == (1, 1), "Do not support dilate now"
...@@ -262,7 +280,7 @@ def compute_conv2d_transpose(attrs, inputs, _): ...@@ -262,7 +280,7 @@ def compute_conv2d_transpose(attrs, inputs, _):
strides = attrs.get_int_tuple("strides") strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation") dilation = attrs.get_int_tuple("dilation")
groups = attrs.get_int("groups") groups = attrs.get_int("groups")
out_dtype = attrs.get_string("out_dtype") out_dtype = attrs.get_str("out_dtype")
layout = attrs["layout"] layout = attrs["layout"]
out_dtype = inputs[0].dtype if out_dtype == "same" else out_dtype out_dtype = inputs[0].dtype if out_dtype == "same" else out_dtype
......
...@@ -33,6 +33,45 @@ class Attrs(NodeBase): ...@@ -33,6 +33,45 @@ class Attrs(NodeBase):
for field in fields: for field in fields:
yield field.name yield field.name
def get_int_tuple(self, key):
"""Get a python int tuple of a key
Parameters
----------
key: str
Returns
-------
value: Tuple of int
"""
return tuple(x.value for x in self.__getattr__(key))
def get_int(self, key):
"""Get a python int value of a key
Parameters
----------
key: str
Returns
-------
value: int
"""
return self.__getattr__(key)
def get_str(self, key):
"""Get a python int value of a key
Parameters
----------
key: str
Returns
-------
value: str
"""
return self.__getattr__(key)
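A minimal sketch (hypothetical helper, not part of this change) of how these getters are typically used when reading Conv2DAttrs fields, mirroring the compute callbacks elsewhere in this diff:

def read_conv2d_attrs(attrs):
    # attrs is a relay Attrs node, e.g. Conv2DAttrs
    strides = attrs.get_int_tuple("strides")        # e.g. (1, 1)
    groups = attrs.get_int("groups")                # e.g. 1
    data_layout = attrs.get_str("data_layout")      # e.g. "NCHW"
    kernel_layout = attrs.get_str("kernel_layout")  # e.g. "OIHW"
    return strides, groups, data_layout, kernel_layout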
def __getitem__(self, item): def __getitem__(self, item):
return self.__getattr__(item) return self.__getattr__(item)
......
...@@ -119,7 +119,7 @@ def _bind_params_by_name(func, params): ...@@ -119,7 +119,7 @@ def _bind_params_by_name(func, params):
return expr.bind(func, bind_dict) return expr.bind(func, bind_dict)
def optimize(func, params=None): def optimize(func, target, params=None):
"""Perform target invariant optimizations. """Perform target invariant optimizations.
Parameters Parameters
...@@ -127,6 +127,9 @@ def optimize(func, params=None): ...@@ -127,6 +127,9 @@ def optimize(func, params=None):
func : tvm.relay.Function func : tvm.relay.Function
The input to optimization. The input to optimization.
target: :any:`tvm.target.Target`
The optimization target. Some optimization passes are target specific.
params : Optional[Dict[str, tvm.nd.NDArray]] params : Optional[Dict[str, tvm.nd.NDArray]]
Input parameters to the graph that do not change Input parameters to the graph that do not change
during inference time. used for constant folding. during inference time. used for constant folding.
...@@ -164,7 +167,11 @@ def optimize(func, params=None): ...@@ -164,7 +167,11 @@ def optimize(func, params=None):
func = ir_pass.infer_type(func) func = ir_pass.infer_type(func)
func = ir_pass.canonicalize_ops(func) func = ir_pass.canonicalize_ops(func)
func = ir_pass.infer_type(func) func = ir_pass.infer_type(func)
func = ir_pass.alter_op_layout(func) with target:
func = ir_pass.alter_op_layout(func)
if cfg.pass_enabled("FoldConstant"):
func = ir_pass.fold_constant(func)
return func return func
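A minimal sketch of the updated flow, assuming the relay.build API of this era; shapes and the llvm target are illustrative, and optimize() is invoked internally by relay.build:

import tvm
from tvm import relay

x = relay.var("x", shape=(1, 64, 56, 56))
w = relay.var("w", shape=(64, 64, 3, 3))
y = relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1))
func = relay.Function(relay.ir_pass.free_vars(y), y)

# The target is now threaded into optimize(), so target-specific passes
# such as alter_op_layout run inside a `with target:` scope.
target = tvm.target.create("llvm")
graph, lib, params = relay.build(func, target)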
...@@ -222,7 +229,7 @@ def build(func, ...@@ -222,7 +229,7 @@ def build(func,
cfg = BuildConfig.current cfg = BuildConfig.current
with tophub_context: with tophub_context:
func = optimize(func, params) func = optimize(func, target, params)
# Fuse ops before running code gen # Fuse ops before running code gen
func = ir_pass.infer_type(func) func = ir_pass.infer_type(func)
func = ir_pass.fuse_ops(func, cfg.opt_level) func = ir_pass.fuse_ops(func, cfg.opt_level)
......
...@@ -72,9 +72,9 @@ def _mx_conv2d(inputs, attrs): ...@@ -72,9 +72,9 @@ def _mx_conv2d(inputs, attrs):
channel_axis = _get_channel_axis(data_layout, "conv2d") channel_axis = _get_channel_axis(data_layout, "conv2d")
if "kernel_layout" in attrs.attrs: if "kernel_layout" in attrs.attrs:
weight_layout = attrs.get_str("kernel_layout") kernel_layout = attrs.get_str("kernel_layout")
else: else:
weight_layout = "HWIO" if data_layout == "NHWC" else "OIHW" kernel_layout = "HWIO" if data_layout == "NHWC" else "OIHW"
new_attrs = {} new_attrs = {}
new_attrs["channels"] = attrs.get_int("num_filter") new_attrs["channels"] = attrs.get_int("num_filter")
...@@ -84,7 +84,7 @@ def _mx_conv2d(inputs, attrs): ...@@ -84,7 +84,7 @@ def _mx_conv2d(inputs, attrs):
new_attrs["dilation"] = attrs.get_int_tuple("dilate", (1, 1)) new_attrs["dilation"] = attrs.get_int_tuple("dilate", (1, 1))
new_attrs["groups"] = attrs.get_int("num_group", 1) new_attrs["groups"] = attrs.get_int("num_group", 1)
new_attrs["data_layout"] = data_layout new_attrs["data_layout"] = data_layout
new_attrs["weight_layout"] = weight_layout new_attrs["kernel_layout"] = kernel_layout
use_bias = not attrs.get_bool("no_bias", False) use_bias = not attrs.get_bool("no_bias", False)
res = _op.nn.conv2d(inputs[0], inputs[1], **new_attrs) res = _op.nn.conv2d(inputs[0], inputs[1], **new_attrs)
if use_bias: if use_bias:
...@@ -103,9 +103,9 @@ def _mx_conv2d_transpose(inputs, attrs): ...@@ -103,9 +103,9 @@ def _mx_conv2d_transpose(inputs, attrs):
channel_axis = _get_channel_axis(data_layout, "conv2d_transpose") channel_axis = _get_channel_axis(data_layout, "conv2d_transpose")
if "kernel_layout" in attrs.attrs: if "kernel_layout" in attrs.attrs:
weight_layout = attrs.get_str("kernel_layout") kernel_layout = attrs.get_str("kernel_layout")
else: else:
weight_layout = "HWIO" if data_layout == "NHWC" else "OIHW" kernel_layout = "HWIO" if data_layout == "NHWC" else "OIHW"
new_attrs = {} new_attrs = {}
new_attrs["channels"] = attrs.get_int("num_filter") new_attrs["channels"] = attrs.get_int("num_filter")
...@@ -116,7 +116,7 @@ def _mx_conv2d_transpose(inputs, attrs): ...@@ -116,7 +116,7 @@ def _mx_conv2d_transpose(inputs, attrs):
new_attrs["dilation"] = attrs.get_int_tuple("dilate", (1, 1)) new_attrs["dilation"] = attrs.get_int_tuple("dilate", (1, 1))
new_attrs["groups"] = attrs.get_int("num_group", 1) new_attrs["groups"] = attrs.get_int("num_group", 1)
new_attrs["data_layout"] = data_layout new_attrs["data_layout"] = data_layout
new_attrs["weight_layout"] = weight_layout new_attrs["kernel_layout"] = kernel_layout
use_bias = not attrs.get_bool("no_bias", False) use_bias = not attrs.get_bool("no_bias", False)
res = _op.nn.conv2d_transpose(inputs[0], inputs[1], **new_attrs) res = _op.nn.conv2d_transpose(inputs[0], inputs[1], **new_attrs)
......
...@@ -55,7 +55,7 @@ def compute_conv2d(attrs, inputs, out_type, target): ...@@ -55,7 +55,7 @@ def compute_conv2d(attrs, inputs, out_type, target):
dilation = get_const_tuple(attrs.dilation) dilation = get_const_tuple(attrs.dilation)
groups = attrs.groups groups = attrs.groups
layout = attrs.data_layout layout = attrs.data_layout
weight_layout = attrs.weight_layout kernel_layout = attrs.kernel_layout
out_dtype = attrs.out_dtype out_dtype = attrs.out_dtype
out_dtype = (inputs[0].dtype if (out_dtype == "same" or out_dtype == "") out_dtype = (inputs[0].dtype if (out_dtype == "same" or out_dtype == "")
else out_dtype) else out_dtype)
...@@ -70,13 +70,13 @@ def compute_conv2d(attrs, inputs, out_type, target): ...@@ -70,13 +70,13 @@ def compute_conv2d(attrs, inputs, out_type, target):
inputs[0], inputs[1], strides, padding, inputs[0], inputs[1], strides, padding,
dilation, layout, out_dtype=out_dtype) dilation, layout, out_dtype=out_dtype)
elif layout == "NCHW" and \ elif layout == "NCHW" and \
weight_layout == "OIHW" and \ kernel_layout == "OIHW" and \
get_const_int(inputs[1].shape[0]) == groups and \ get_const_int(inputs[1].shape[0]) == groups and \
get_const_int(inputs[1].shape[1]) == 1: get_const_int(inputs[1].shape[1]) == 1:
out = topi.nn.depthwise_conv2d_nchw( out = topi.nn.depthwise_conv2d_nchw(
inputs[0], inputs[1], strides, padding, dilation, out_dtype=out_dtype) inputs[0], inputs[1], strides, padding, dilation, out_dtype=out_dtype)
elif layout == "NHWC" and \ elif layout == "NHWC" and \
weight_layout == "HWOI" and\ kernel_layout == "HWOI" and\
get_const_int(inputs[1].shape[2]) == groups and \ get_const_int(inputs[1].shape[2]) == groups and \
get_const_int(inputs[1].shape[3]) == 1: get_const_int(inputs[1].shape[3]) == 1:
out = topi.nn.depthwise_conv2d_nhwc( out = topi.nn.depthwise_conv2d_nhwc(
...@@ -91,7 +91,7 @@ def schedule_conv2d(attrs, outs, target): ...@@ -91,7 +91,7 @@ def schedule_conv2d(attrs, outs, target):
"""Schedule definition of conv2d""" """Schedule definition of conv2d"""
groups = attrs.groups groups = attrs.groups
layout = attrs.data_layout layout = attrs.data_layout
kernel_layout = attrs.weight_layout kernel_layout = attrs.kernel_layout
with target: with target:
if groups == 1 and layout == "NCHW": if groups == 1 and layout == "NCHW":
return topi.generic.schedule_conv2d_nchw(outs) return topi.generic.schedule_conv2d_nchw(outs)
...@@ -111,7 +111,8 @@ def schedule_conv2d(attrs, outs, target): ...@@ -111,7 +111,8 @@ def schedule_conv2d(attrs, outs, target):
@reg.register_alter_op_layout("nn.conv2d") @reg.register_alter_op_layout("nn.conv2d")
def alter_op_layout_conv2d(attrs, inputs, tinfos): def alter_op_layout_conv2d(attrs, inputs, tinfos):
"""Alternate the layout of conv2d""" """Alternate the layout of conv2d"""
return None from ... import op
return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, op)
reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE) reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
...@@ -249,7 +250,7 @@ def schedule_l2_normalize(attrs, outs, target): ...@@ -249,7 +250,7 @@ def schedule_l2_normalize(attrs, outs, target):
reg.register_pattern("nn.l2_normalize", OpPattern.OUT_ELEMWISE_FUSABLE) reg.register_pattern("nn.l2_normalize", OpPattern.OUT_ELEMWISE_FUSABLE)
# Upsampling # upsampling
reg.register_schedule("nn.upsampling", reg.schedule_injective) reg.register_schedule("nn.upsampling", reg.schedule_injective)
def schedule_upsampling(_, outs, target): def schedule_upsampling(_, outs, target):
"""Schedule definition of upsampling""" """Schedule definition of upsampling"""
...@@ -257,3 +258,50 @@ def schedule_upsampling(_, outs, target): ...@@ -257,3 +258,50 @@ def schedule_upsampling(_, outs, target):
return topi.generic.schedule_injective(outs) return topi.generic.schedule_injective(outs)
# pad # pad
reg.register_schedule("nn.pad", schedule_broadcast) reg.register_schedule("nn.pad", schedule_broadcast)
# winograd related operators
@reg.register_compute("nn.contrib_conv2d_winograd_without_weight_transform")
def compute_contrib_conv2d_winograd_without_weight_transform(attrs, inputs, out_dtype, target):
"""Compute definition of conv2d_winograd_without_weight_transform"""
# pylint: disable=assignment-from-no-return
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation")
groups = attrs.get_int("groups")
data_layout = attrs.get_str("data_layout")
out_dtype = attrs.get_str("out_dtype")
tile_size = attrs.get_int("tile_size")
out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
assert dilation == (1, 1), "Do not support dilate now"
assert groups == 1, "Do not support arbitrary group number"
out = topi.nn.conv2d_winograd_without_weight_transform(
inputs[0], inputs[1], strides, padding, dilation, data_layout,
out_dtype, tile_size)
return [out]
@reg.register_schedule("nn.contrib_conv2d_winograd_without_weight_transform")
def schedule_contrib_conv2d_winograd_without_weight_transform(attrs, outs, target):
"""Schedule definition of conv2d_winograd_without_weight_transform"""
with target:
return topi.generic.schedule_conv2d_winograd_without_weight_transform(outs)
reg.register_pattern("nn.contrib_conv2d_winograd_without_weight_transform",
OpPattern.OUT_ELEMWISE_FUSABLE)
@reg.register_compute("nn.contrib_conv2d_winograd_weight_transform")
def compute_contrib_conv2d_winograd_weight_transform(attrs, inputs, out_dtype, target):
"""Compute definition of contrib_conv2d_winograd_weight_transform"""
out = topi.nn.conv2d_winograd_weight_transform(inputs[0], attrs.get_int('tile_size'))
return [out]
@reg.register_schedule("nn.contrib_conv2d_winograd_weight_transform")
def schedule_contrib_conv2d_winograd_weight_transform(attrs, outs, target):
"""Schedule definition of contrib_conv2d_winograd_weight_transform"""
with target:
return topi.generic.schedule_conv2d_winograd_weight_transform(outs)
reg.register_pattern("nn.contrib_conv2d_winograd_weight_transform",
OpPattern.OUT_ELEMWISE_FUSABLE)
...@@ -13,7 +13,7 @@ def conv2d(data, ...@@ -13,7 +13,7 @@ def conv2d(data,
channels=None, channels=None,
kernel_size=None, kernel_size=None,
data_layout="NCHW", data_layout="NCHW",
weight_layout="OIHW", kernel_layout="OIHW",
out_layout="", out_layout="",
out_dtype=""): out_dtype=""):
r"""2D convolution. r"""2D convolution.
...@@ -23,7 +23,7 @@ def conv2d(data, ...@@ -23,7 +23,7 @@ def conv2d(data,
In the default case, where the data_layout is `NCHW` In the default case, where the data_layout is `NCHW`
and weight_layout is `OIHW`, conv2d takes in and kernel_layout is `OIHW`, conv2d takes in
a data Tensor with shape `(batch_size, in_channels, height, width)`, a data Tensor with shape `(batch_size, in_channels, height, width)`,
and a weight Tensor with shape `(channels, in_channels, kernel_size[0], kernel_size[1])` and a weight Tensor with shape `(channels, in_channels, kernel_size[0], kernel_size[1])`
to produce an output Tensor with the following rule: to produce an output Tensor with the following rule:
...@@ -70,7 +70,7 @@ def conv2d(data, ...@@ -70,7 +70,7 @@ def conv2d(data,
data_layout : str, optional data_layout : str, optional
Layout of the input. Layout of the input.
weight_layout : str, optional kernel_layout : str, optional
Layout of the weight. Layout of the weight.
out_layout : str, optional out_layout : str, optional
...@@ -86,7 +86,7 @@ def conv2d(data, ...@@ -86,7 +86,7 @@ def conv2d(data,
""" """
return _make.conv2d(data, weight, strides, padding, dilation, return _make.conv2d(data, weight, strides, padding, dilation,
groups, channels, kernel_size, data_layout, groups, channels, kernel_size, data_layout,
weight_layout, out_layout, out_dtype) kernel_layout, out_layout, out_dtype)
def conv2d_transpose(data, def conv2d_transpose(data,
...@@ -98,7 +98,7 @@ def conv2d_transpose(data, ...@@ -98,7 +98,7 @@ def conv2d_transpose(data,
channels=None, channels=None,
kernel_size=None, kernel_size=None,
data_layout="NCHW", data_layout="NCHW",
weight_layout="OIHW", kernel_layout="OIHW",
output_padding=(0, 0), output_padding=(0, 0),
out_dtype=""): out_dtype=""):
"""Two dimensional trnasposed convolution operator. """Two dimensional trnasposed convolution operator.
...@@ -126,7 +126,7 @@ def conv2d_transpose(data, ...@@ -126,7 +126,7 @@ def conv2d_transpose(data,
data_layout : str, optional data_layout : str, optional
Layout of the input. Layout of the input.
weight_layout : str, optional kernel_layout : str, optional
Layout of the weight. Layout of the weight.
output_padding : Tuple[int], optional output_padding : Tuple[int], optional
...@@ -142,7 +142,7 @@ def conv2d_transpose(data, ...@@ -142,7 +142,7 @@ def conv2d_transpose(data,
""" """
return _make.conv2d_transpose(data, weight, strides, padding, dilation, return _make.conv2d_transpose(data, weight, strides, padding, dilation,
groups, channels, kernel_size, data_layout, groups, channels, kernel_size, data_layout,
weight_layout, output_padding, out_dtype) kernel_layout, output_padding, out_dtype)
def softmax(data, axis=-1): def softmax(data, axis=-1):
...@@ -765,3 +765,96 @@ def batch_norm(data, ...@@ -765,3 +765,96 @@ def batch_norm(data,
center, center,
scale) scale)
return TupleWrapper(result, 3) return TupleWrapper(result, 3)
def contrib_conv2d_winograd_without_weight_transform(data,
weight,
tile_size,
strides=(1, 1),
padding=(0, 0),
dilation=(1, 1),
groups=1,
channels=None,
kernel_size=None,
data_layout="NCHW",
kernel_layout="OIHW",
out_layout="",
out_dtype=""):
r"""2D convolution with winograd algorithm.
The basic parameters are the same as the ones in vanilla conv2d.
It assumes the weight is pre-transformed by nn.contrib_conv2d_winograd_weight_transform.
Parameters
----------
data : tvm.relay.Expr
The input data to the operator.
weight : tvm.relay.Expr
The weight expressions.
tile_size : int
The tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)
strides : tuple of int, optional
The strides of convolution.
padding : tuple of int, optional
The padding of convolution on both sides of inputs before convolution.
dilation : tuple of int, optional
Specifies the dilation rate to be used for dilated convolution.
groups : int, optional
Number of groups for grouped convolution.
channels : int, optional
Number of output channels of this convolution.
kernel_size : tuple of int, optional
The spatial dimensions of the convolution kernel.
data_layout : str, optional
Layout of the input.
kernel_layout : str, optional
Layout of the weight.
out_layout : str, optional
Layout of the output, by default, out_layout is the same as data_layout
out_dtype : str, optional
Specifies the output data type for mixed precision conv2d.
Returns
-------
result : tvm.relay.Expr
The computed result.
"""
return _make.contrib_conv2d_winograd_without_weight_transform(
data, weight, tile_size, strides, padding, dilation,
groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, out_dtype)
def contrib_conv2d_winograd_weight_transform(weight,
tile_size):
r"""Weight Transformation part for 2D convolution with winograd algorithm.
We separate this as a single op so the transformed weight can be pre-computed for inference.
Use this together with nn.contrib_conv2d_winograd_without_weight_transform.
Parameters
----------
weight : tvm.relay.Expr
The weight expressions.
tile_size : int
The tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)
Returns
-------
result : tvm.relay.Expr
The computed result.
"""
return _make.contrib_conv2d_winograd_weight_transform(weight, tile_size)
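A minimal sketch (illustrative shapes) of how the two ops above fit together, with the weight transform split out so it can be pre-computed for inference:

from tvm import relay

data = relay.var("data", shape=(1, 64, 56, 56))
weight = relay.var("weight", shape=(64, 64, 3, 3))

# Pre-transform the 3x3 kernel for F(4x4, 3x3), then run the convolution
# that expects an already-transformed weight.
tw = relay.nn.contrib_conv2d_winograd_weight_transform(weight, tile_size=4)
out = relay.nn.contrib_conv2d_winograd_without_weight_transform(
    data, tw, tile_size=4, channels=64, kernel_size=(3, 3), padding=(1, 1),
    data_layout="NCHW", kernel_layout="OIHW")
func = relay.Function(relay.ir_pass.free_vars(out), out)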
...@@ -5,10 +5,20 @@ from ..base import register_relay_attr_node ...@@ -5,10 +5,20 @@ from ..base import register_relay_attr_node
@register_relay_attr_node @register_relay_attr_node
class Conv2DAttrs(Attrs): class Conv2DAttrs(Attrs):
"""Attribute of a Convolution Operator""" """Attribute of nn.conv2d"""
pass
@register_relay_attr_node
class Conv2DWinogradAttrs(Attrs):
"""Attribute of nn.contrib_conv2d_winograd_without_weight_transform"""
pass
@register_relay_attr_node
class Conv2DWinogradWeightTransformAttrs(Attrs):
"""Attribute of nn.contrib_conv2d_winograd_weight_transform"""
pass pass
@register_relay_attr_node @register_relay_attr_node
class GlobalPool2DAttrs(Attrs): class GlobalPool2DAttrs(Attrs):
"""Attribute of a Global 2D Pooling Operator""" """Attribute of nn.global_pool"""
pass pass
...@@ -11,4 +11,6 @@ from . import inception_v3 ...@@ -11,4 +11,6 @@ from . import inception_v3
from . import squeezenet from . import squeezenet
from . import vgg from . import vgg
from . import densenet from . import densenet
from .config import ctx_list from .config import ctx_list
from .init import create_workload
...@@ -60,7 +60,7 @@ class Layout : public NodeRef { ...@@ -60,7 +60,7 @@ class Layout : public NodeRef {
Layout() : Layout("__undef__") {} // NOLINT(*) Layout() : Layout("__undef__") {} // NOLINT(*)
/*! \brief construct from a string */ /*! \brief construct from a string */
Layout(const char* str) : Layout(std::string(str)) {} // NOLINT(*) Layout(const char* name) : Layout(std::string(name)) {} // NOLINT(*)
/*! /*!
* \brief construct from a string. * \brief construct from a string.
...@@ -70,11 +70,64 @@ class Layout : public NodeRef { ...@@ -70,11 +70,64 @@ class Layout : public NodeRef {
* indicates the split dimension. * indicates the split dimension.
* return undefined layout if "__undef__" is passed. * return undefined layout if "__undef__" is passed.
*/ */
Layout(const std::string& layout) { // NOLINT(*) Layout(const std::string& name) { // NOLINT(*)
if (layout.length() != 0) { node_ = make_node<LayoutNode>();
Parse(layout);
} else { std::vector<uint32_t> superdim_pos(kUniqueDim, -1);
Parse("__undef__"); std::vector<uint32_t> subdim_pos(kUniqueDim, -1);
std::vector<uint32_t> subdim_size(kUniqueDim, -1);
std::vector<char> layout_simplified;
if (name != "__undef__") { // parse layout string
int32_t factor = 0;
uint32_t curr = 0;
for (size_t i = 0; i < name.size(); ++i) {
const LayoutDim c = name.at(i);
if (IsSuperdim(c)) {
int pos = c - 'A';
CHECK_EQ(factor, 0) << "Invalid layout " << name
<< ": invalid factor size " << factor
<< " before dimension " << c;
CHECK_EQ(superdim_pos[pos], -1) << "Invalid layout " << name
<< ": duplicate dimension " << c;
superdim_pos[pos] = curr++;
layout_simplified.push_back(c);
} else if (IsSubdim(c)) {
int pos = c - 'a';
CHECK_GT(factor, 0) << "Invalid layout " << name << ": invalid factor size "
<< factor << " for dimension " << c;
CHECK_EQ(subdim_pos[pos], -1) << "Invalid layout " << name
<< ": duplicate dimension " << c;
CHECK_EQ(subdim_size[pos], -1) << "Invalid layout " << name
<< ": duplicate dimension " << c;
subdim_pos[pos] = curr++;
subdim_size[pos] = factor;
layout_simplified.push_back(c);
factor = 0;
} else if (c >= '0' && c <= '9') {
CHECK(factor >= 0) << "Invalid layout " << name << ": _ is adjacent to a number.";
factor = factor * 10 + c - '0';
} else {
LOG(FATAL) << "Invalid layout " << name;
}
}
for (LayoutDim dim : layout_simplified) {
CHECK(IsSuperdim(dim) || superdim_pos[dim-'a'] >= 0)
<< "Invalid layout " << name << ": missing axis "
<< static_cast<char>(dim - 'a' + 'A');
}
}
LayoutNode *node = operator->();
node->name = name;
for (uint32_t i = 0; i < kUniqueDim; ++i) {
node->superdim_pos.push_back(superdim_pos[i]);
node->subdim_pos.push_back(subdim_pos[i]);
node->subdim_size.push_back(subdim_size[i]);
}
for (LayoutDim dim : layout_simplified) {
node->layout_simplified.push_back(dim);
} }
} }
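As a worked example of the parsing above: for name = "NCHW16c", the upper-case characters N, C, H and W are recorded as super-dimensions at positions 0 through 3, the digits accumulate factor = 16, and the following lower-case c becomes a sub-dimension at position 4 with subdim_size 16; layout_simplified ends up as N, C, H, W, c. Duplicate dimensions, or a non-zero factor immediately followed by an upper-case dimension, fail the CHECKs.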
...@@ -177,7 +230,6 @@ class Layout : public NodeRef { ...@@ -177,7 +230,6 @@ class Layout : public NodeRef {
const Array<Integer>& layout_simplified = operator->()->layout_simplified; const Array<Integer>& layout_simplified = operator->()->layout_simplified;
if (pos > ndim()) return Layout::Undef(); if (pos > ndim()) return Layout::Undef();
if (pos + len > ndim()) len = ndim() - pos; if (pos + len > ndim()) len = ndim() - pos;
if (len == 0) return Layout::Undef();
std::ostringstream new_layout; std::ostringstream new_layout;
for (size_t i = pos; i < pos + len; ++i) { for (size_t i = pos; i < pos + len; ++i) {
if (IsSubdim(layout_simplified[i]->value)) { if (IsSubdim(layout_simplified[i]->value)) {
...@@ -349,69 +401,6 @@ class Layout : public NodeRef { ...@@ -349,69 +401,6 @@ class Layout : public NodeRef {
} }
using ContainerType = LayoutNode; using ContainerType = LayoutNode;
private:
void Parse(const std::string &layout) {
node_ = make_node<LayoutNode>();
std::vector<uint32_t> superdim_pos(kUniqueDim, -1);
std::vector<uint32_t> subdim_pos(kUniqueDim, -1);
std::vector<uint32_t> subdim_size(kUniqueDim, -1);
std::vector<char> layout_simplified;
if (layout != "__undef__") { // parse layout string
int32_t factor = 0;
uint32_t curr = 0;
for (size_t i = 0; i < layout.size(); ++i) {
const LayoutDim c = layout.at(i);
if (IsSuperdim(c)) {
int pos = c - 'A';
CHECK_EQ(factor, 0) << "Invalid layout " << layout
<< ": invalid factor size " << factor
<< " before dimension " << c;
CHECK_EQ(superdim_pos[pos], -1) << "Invalid layout " << layout
<< ": duplicate dimension " << c;
superdim_pos[pos] = curr++;
layout_simplified.push_back(c);
} else if (IsSubdim(c)) {
int pos = c - 'a';
CHECK_GT(factor, 0) << "Invalid layout " << layout << ": invalid factor size "
<< factor << " for dimension " << c;
CHECK_EQ(subdim_pos[pos], -1) << "Invalid layout " << layout
<< ": duplicate dimension " << c;
CHECK_EQ(subdim_size[pos], -1) << "Invalid layout " << layout
<< ": duplicate dimension " << c;
subdim_pos[pos] = curr++;
subdim_size[pos] = factor;
layout_simplified.push_back(c);
factor = 0;
} else if (c >= '0' && c <= '9') {
CHECK(factor >= 0) << "Invalid layout " << layout << ": _ is adjacent to a number.";
factor = factor * 10 + c - '0';
} else {
LOG(FATAL) << "Invalid layout " << layout;
}
}
CHECK(!layout_simplified.empty()) << "Invalid layout " << layout;
for (LayoutDim dim : layout_simplified) {
CHECK(IsSuperdim(dim) || superdim_pos[dim-'a'] >= 0)
<< "Invalid layout " << layout << ": missing axis "
<< static_cast<char>(dim - 'a' + 'A');
}
}
LayoutNode *node = operator->();
node->name = layout;
for (uint32_t i = 0; i < kUniqueDim; ++i) {
node->superdim_pos.push_back(superdim_pos[i]);
node->subdim_pos.push_back(subdim_pos[i]);
node->subdim_size.push_back(subdim_size[i]);
}
for (LayoutDim dim : layout_simplified) {
node->layout_simplified.push_back(dim);
}
}
}; };
/*! /*!
......
...@@ -166,7 +166,7 @@ Call CallAlter(const Call& ref_call, ...@@ -166,7 +166,7 @@ Call CallAlter(const Call& ref_call,
} }
if (!modified) { if (!modified) {
new_e = CallNode::make(ref_call->op, new_args, new_e = CallNode::make(ref_call->op, new_args,
ref_call->attrs, ref_call->type_args); ref_call->attrs);
} }
const CallNode *new_call = new_e.as<CallNode>(); const CallNode *new_call = new_e.as<CallNode>();
...@@ -184,30 +184,35 @@ Expr AlterOpLayoutRewrite(const Call &ref_call, ...@@ -184,30 +184,35 @@ Expr AlterOpLayoutRewrite(const Call &ref_call,
// NOTE: discard the "const" qualifier // NOTE: discard the "const" qualifier
TransformMemorizer memorizer = Downcast<TransformMemorizer>(ctx); TransformMemorizer memorizer = Downcast<TransformMemorizer>(ctx);
// fill incomplete state and expand tuple // fill incomplete state and flatten tuple
for (auto new_arg : new_args) { auto push_back_one_arg = [&inputs, memorizer](Expr arg) {
auto push_back_one_arg = [&](Expr arg) { // We always expect LayoutAlternatedExpr.
// We always expect LayoutAlternatedExpr. // This is used to convert the normal Expr to LayoutAlternatedExpr.
// This is used to convert the normal Expr to LayoutAlternatedExpr. if (const LayoutAlternatedExprNode *inp = arg.as<LayoutAlternatedExprNode>()) {
if (const LayoutAlternatedExprNode *inp = arg.as<LayoutAlternatedExprNode>()) { inputs.push_back(GetRef<LayoutAlternatedExpr>(inp));
inputs.push_back(GetRef<LayoutAlternatedExpr>(inp)); return inp->value;
normal_new_args.push_back(inp->value); } else {
} else { auto inode = make_node<LayoutAlternatedExprNode>();
auto inode = make_node<LayoutAlternatedExprNode>(); inode->value = arg;
inode->value = arg; inode->memorizer = memorizer;
inode->memorizer = memorizer; inputs.push_back(LayoutAlternatedExpr(inode));
inputs.push_back(LayoutAlternatedExpr(inode)); return arg;
normal_new_args.push_back(arg); }
} };
};
for (auto new_arg : new_args) {
// NOTE: do not support nested tuple
if (new_arg->is_type<TupleNode>()) { if (new_arg->is_type<TupleNode>()) {
Tuple tuple_new_arg = Downcast<Tuple>(new_arg); Tuple tuple_new_arg = Downcast<Tuple>(new_arg);
std::vector<Expr> fields;
for (auto x : tuple_new_arg->fields) { for (auto x : tuple_new_arg->fields) {
push_back_one_arg(x); Expr tmp = push_back_one_arg(x);
fields.push_back(tmp);
} }
normal_new_args.push_back(TupleNode::make(fields));
} else { } else {
push_back_one_arg(new_arg); Expr tmp = push_back_one_arg(new_arg);
normal_new_args.push_back(tmp);
} }
} }
...@@ -219,7 +224,7 @@ Expr AlterOpLayoutRewrite(const Call &ref_call, ...@@ -219,7 +224,7 @@ Expr AlterOpLayoutRewrite(const Call &ref_call,
} }
for (auto arg : ref_call->args) { for (auto arg : ref_call->args) {
if (arg->is_type<TupleNode>()) { // expand tuple if (arg->is_type<TupleNode>()) { // flatten tuple
Tuple tuple_arg = Downcast<Tuple>(arg); Tuple tuple_arg = Downcast<Tuple>(arg);
for (auto x : tuple_arg->fields) { for (auto x : tuple_arg->fields) {
input_shapes.push_back(x->type_as<TensorTypeNode>()->shape); input_shapes.push_back(x->type_as<TensorTypeNode>()->shape);
...@@ -263,17 +268,30 @@ Expr AlterOpLayoutRewrite(const Call &ref_call, ...@@ -263,17 +268,30 @@ Expr AlterOpLayoutRewrite(const Call &ref_call,
// if (new_in != new_in2): insert transform (new_in -> new_in2) // if (new_in != new_in2): insert transform (new_in -> new_in2)
Array<Expr> transformed_args; Array<Expr> transformed_args;
for (size_t i = 0; i < inputs.size(); ++i) { size_t pt = 0;
transformed_args.push_back(memorizer.Transform(new_call->args[i], new_in[i], new_in2[i])); for (auto arg : new_call->args) {
if (arg->is_type<TupleNode>()) { // unflatten tuple
Tuple tuple_arg = Downcast<Tuple>(arg);
std::vector<Expr> transformed_tuple_arg;
for (auto arg_item : tuple_arg->fields) {
transformed_tuple_arg.push_back(
memorizer.Transform(arg_item, new_in[pt], new_in2[pt]));
pt++;
}
transformed_args.push_back(TupleNode::make(transformed_tuple_arg));
} else {
transformed_args.push_back(
memorizer.Transform(arg, new_in[pt], new_in2[pt]));
pt++;
}
} }
CHECK_EQ(pt, inputs.size());
// state[node] = (old_out, new_out) // state[node] = (old_out, new_out)
CHECK(ref_call->checked_type_.defined()) // (handle tuple output)
<< "Call infer_type pass before alter_op_layout pass";
if (ref_call->checked_type()->is_type<TupleTypeNode>()) { if (ref_call->checked_type()->is_type<TupleTypeNode>()) {
Expr tuple_output = CallNode::make(new_call->op, transformed_args, Expr tuple_output = CallNode::make(new_call->op, transformed_args,
new_call->attrs, new_call->type_args); new_call->attrs);
Array<Expr> fields; Array<Expr> fields;
for (size_t i = 0; i < new_out.size(); ++i) { for (size_t i = 0; i < new_out.size(); ++i) {
auto rnode = make_node<LayoutAlternatedExprNode>(); auto rnode = make_node<LayoutAlternatedExprNode>();
...@@ -288,7 +306,7 @@ Expr AlterOpLayoutRewrite(const Call &ref_call, ...@@ -288,7 +306,7 @@ Expr AlterOpLayoutRewrite(const Call &ref_call,
auto rnode = make_node<LayoutAlternatedExprNode>(); auto rnode = make_node<LayoutAlternatedExprNode>();
CHECK_EQ(new_out.size(), 1); CHECK_EQ(new_out.size(), 1);
rnode->value = CallNode::make(new_call->op, transformed_args, rnode->value = CallNode::make(new_call->op, transformed_args,
new_call->attrs, new_call->type_args); new_call->attrs);
rnode->old_layout = old_out[0]; rnode->old_layout = old_out[0];
rnode->new_layout = new_out[0]; rnode->new_layout = new_out[0];
rnode->memorizer = memorizer; rnode->memorizer = memorizer;
...@@ -296,6 +314,9 @@ Expr AlterOpLayoutRewrite(const Call &ref_call, ...@@ -296,6 +314,9 @@ Expr AlterOpLayoutRewrite(const Call &ref_call,
} }
} }
// Limitations:
// 1. the altered op should have the same number of arguments as the previous one
// 2. do not support nested tuple arguments
TVM_REGISTER_API("relay._ir_pass.AlterOpLayout") TVM_REGISTER_API("relay._ir_pass.AlterOpLayout")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
TransformMemorizer transformMemorizer(make_node<TransformMemorizerNode>()); TransformMemorizer transformMemorizer(make_node<TransformMemorizerNode>());
......
...@@ -91,13 +91,13 @@ class BranchGroupFinder : private ExprVisitor { ...@@ -91,13 +91,13 @@ class BranchGroupFinder : private ExprVisitor {
CHECK(attrs_b); CHECK(attrs_b);
const auto* tweight_a = a->args[1]->type_as<TensorTypeNode>(); const auto* tweight_a = a->args[1]->type_as<TensorTypeNode>();
const auto* tweight_b = b->args[1]->type_as<TensorTypeNode>(); const auto* tweight_b = b->args[1]->type_as<TensorTypeNode>();
const auto shape_a = ConvertLayout(tweight_a->shape, attrs_a->weight_layout, kOIHW); const auto shape_a = ConvertLayout(tweight_a->shape, attrs_a->kernel_layout, kOIHW);
const auto shape_b = ConvertLayout(tweight_b->shape, attrs_b->weight_layout, kOIHW); const auto shape_b = ConvertLayout(tweight_b->shape, attrs_b->kernel_layout, kOIHW);
return eq(attrs_a->strides, attrs_b->strides) && eq(attrs_a->padding, attrs_b->padding) && return eq(attrs_a->strides, attrs_b->strides) && eq(attrs_a->padding, attrs_b->padding) &&
eq(attrs_a->dilation, attrs_b->dilation) && eq(attrs_a->groups, attrs_b->groups) && eq(attrs_a->dilation, attrs_b->dilation) && eq(attrs_a->groups, attrs_b->groups) &&
eq(attrs_a->data_layout, attrs_b->data_layout) && eq(attrs_a->data_layout, attrs_b->data_layout) &&
eq(attrs_a->weight_layout, attrs_b->weight_layout) && eq(attrs_a->kernel_layout, attrs_b->kernel_layout) &&
eq(attrs_a->out_dtype, attrs_b->out_dtype) && eq(attrs_a->out_dtype, attrs_b->out_dtype) &&
eq(attrs_a->out_layout, attrs_b->out_layout) && eq(shape_a[2], shape_b[2]) && eq(attrs_a->out_layout, attrs_b->out_layout) && eq(shape_a[2], shape_b[2]) &&
eq(shape_a[3], shape_b[3]); eq(shape_a[3], shape_b[3]);
...@@ -159,7 +159,7 @@ class ParallelConv2DCombiner { ...@@ -159,7 +159,7 @@ class ParallelConv2DCombiner {
auto channels = GetConv2DSuperChannelsDim(conv2d); auto channels = GetConv2DSuperChannelsDim(conv2d);
num_filters += channels; num_filters += channels;
} }
auto index = branches[0][0]->attrs.as<Conv2DAttrs>()->weight_layout.find('O'); auto index = branches[0][0]->attrs.as<Conv2DAttrs>()->kernel_layout.find('O');
CHECK_NE(index, std::string::npos); CHECK_NE(index, std::string::npos);
return std::make_tuple(MakeConcatenate(TupleNode::make(weights), index), return std::make_tuple(MakeConcatenate(TupleNode::make(weights), index),
MakeConstScalar(Int(32), num_filters)); MakeConstScalar(Int(32), num_filters));
...@@ -182,7 +182,7 @@ class ParallelConv2DCombiner { ...@@ -182,7 +182,7 @@ class ParallelConv2DCombiner {
new_attrs->groups = attrs->groups; new_attrs->groups = attrs->groups;
new_attrs->kernel_size = attrs->kernel_size; new_attrs->kernel_size = attrs->kernel_size;
new_attrs->data_layout = attrs->data_layout; new_attrs->data_layout = attrs->data_layout;
new_attrs->weight_layout = attrs->weight_layout; new_attrs->kernel_layout = attrs->kernel_layout;
new_attrs->out_layout = attrs->out_layout; new_attrs->out_layout = attrs->out_layout;
new_attrs->out_dtype = attrs->out_dtype; new_attrs->out_dtype = attrs->out_dtype;
new_attrs->channels = new_channels; new_attrs->channels = new_channels;
......
...@@ -384,7 +384,7 @@ Array<AxesSet> Conv2DForwardPrep(const Call& call, AxesSet out) { ...@@ -384,7 +384,7 @@ Array<AxesSet> Conv2DForwardPrep(const Call& call, AxesSet out) {
const auto* param = call->attrs.as<Conv2DAttrs>(); const auto* param = call->attrs.as<Conv2DAttrs>();
CHECK(param != nullptr); CHECK(param != nullptr);
Layout data_layout(param->data_layout); Layout data_layout(param->data_layout);
Layout weight_layout(param->weight_layout); Layout kernel_layout(param->kernel_layout);
int c_big_axis = data_layout.Indexof('C'); int c_big_axis = data_layout.Indexof('C');
int c_small_axis = data_layout.Indexof('c'); int c_small_axis = data_layout.Indexof('c');
...@@ -397,8 +397,8 @@ Array<AxesSet> Conv2DForwardPrep(const Call& call, AxesSet out) { ...@@ -397,8 +397,8 @@ Array<AxesSet> Conv2DForwardPrep(const Call& call, AxesSet out) {
// //
// only handle depthwise or full conv2d. // only handle depthwise or full conv2d.
// TODO(tvm-team) handle grouped conv by reshape + bcast // TODO(tvm-team) handle grouped conv by reshape + bcast
bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, weight_layout); bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
if (weight_layout.Indexof('i') < 0 && if (kernel_layout.Indexof('i') < 0 &&
c_small_axis < 0 && c_small_axis < 0 &&
(param->groups == 1 || is_depthwise_conv2d)) { (param->groups == 1 || is_depthwise_conv2d)) {
data_axes = {c_big_axis}; data_axes = {c_big_axis};
...@@ -418,19 +418,19 @@ Expr Conv2DForwardRewrite(const Call& ref_call, ...@@ -418,19 +418,19 @@ Expr Conv2DForwardRewrite(const Call& ref_call,
const auto* param = ref_call->attrs.as<Conv2DAttrs>(); const auto* param = ref_call->attrs.as<Conv2DAttrs>();
CHECK(param != nullptr); CHECK(param != nullptr);
Layout data_layout(param->data_layout); Layout data_layout(param->data_layout);
Layout weight_layout(param->weight_layout); Layout kernel_layout(param->kernel_layout);
int c_big_axis = data_layout.Indexof('C'); int c_big_axis = data_layout.Indexof('C');
CHECK_GE(c_big_axis, 0); CHECK_GE(c_big_axis, 0);
// For now, we only support simple pattern (no folded weight/data) // For now, we only support simple pattern (no folded weight/data)
// TODO(tvm-team) support general data layout // TODO(tvm-team) support general data layout
CHECK_EQ(weight_layout.Indexof('i'), -1); CHECK_EQ(kernel_layout.Indexof('i'), -1);
CHECK(sdata->axes.size() == 1 && CHECK(sdata->axes.size() == 1 &&
c_big_axis == sdata->axes[0]->value); c_big_axis == sdata->axes[0]->value);
int big_oc_axis = weight_layout.Indexof('O'); int big_oc_axis = kernel_layout.Indexof('O');
int big_ic_axis = weight_layout.Indexof('I'); int big_ic_axis = kernel_layout.Indexof('I');
// Check it must be depthwise or full conv2d. // Check it must be depthwise or full conv2d.
bool is_depthwise_conv2d = IsDepthwiseConv2D(ref_call, param, weight_layout); bool is_depthwise_conv2d = IsDepthwiseConv2D(ref_call, param, kernel_layout);
CHECK(param->groups == 1 || is_depthwise_conv2d); CHECK(param->groups == 1 || is_depthwise_conv2d);
Expr weight = new_args[1]; Expr weight = new_args[1];
...@@ -438,11 +438,11 @@ Expr Conv2DForwardRewrite(const Call& ref_call, ...@@ -438,11 +438,11 @@ Expr Conv2DForwardRewrite(const Call& ref_call,
// match the ic_axis // match the ic_axis
if (is_depthwise_conv2d) { if (is_depthwise_conv2d) {
Expr scale = ExpandBiasToMatchAxis( Expr scale = ExpandBiasToMatchAxis(
sdata->scale, weight_layout.ndim(), {big_oc_axis}); sdata->scale, kernel_layout.ndim(), {big_oc_axis});
weight = Multiply(weight, scale); weight = Multiply(weight, scale);
} else { } else {
Expr scale = ExpandBiasToMatchAxis( Expr scale = ExpandBiasToMatchAxis(
sdata->scale, weight_layout.ndim(), {big_ic_axis}); sdata->scale, kernel_layout.ndim(), {big_ic_axis});
weight = Multiply(weight, scale); weight = Multiply(weight, scale);
} }
// return transformed conv2d // return transformed conv2d
...@@ -799,11 +799,8 @@ RELAY_REGISTER_OP("multiply") ...@@ -799,11 +799,8 @@ RELAY_REGISTER_OP("multiply")
AxesSet Conv2DBackwardPrep(const Call& call, const Array<AxesSet>& in_axes) { AxesSet Conv2DBackwardPrep(const Call& call, const Array<AxesSet>& in_axes) {
const auto* param = call->attrs.as<Conv2DAttrs>(); const auto* param = call->attrs.as<Conv2DAttrs>();
CHECK(param != nullptr); CHECK(param != nullptr);
Layout out_layout(param->out_layout); Layout kernel_layout(param->kernel_layout);
if (!out_layout.defined()) { Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
out_layout = Layout(param->data_layout);
}
Layout weight_layout(param->weight_layout);
int c_big_axis = out_layout.Indexof('C'); int c_big_axis = out_layout.Indexof('C');
int c_small_axis = out_layout.Indexof('c'); int c_small_axis = out_layout.Indexof('c');
...@@ -815,9 +812,9 @@ AxesSet Conv2DBackwardPrep(const Call& call, const Array<AxesSet>& in_axes) { ...@@ -815,9 +812,9 @@ AxesSet Conv2DBackwardPrep(const Call& call, const Array<AxesSet>& in_axes) {
// //
// only handle depthwise or full conv2d. // only handle depthwise or full conv2d.
// TODO(tvm-team) handle grouped conv by reshape + bcast // TODO(tvm-team) handle grouped conv by reshape + bcast
bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, weight_layout); bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
if (weight_layout.Indexof('o') < 0 && if (kernel_layout.Indexof('o') < 0 &&
weight_layout.Indexof('i') < 0 && kernel_layout.Indexof('i') < 0 &&
c_small_axis < 0 && c_small_axis < 0 &&
(param->groups == 1 || is_depthwise_conv2d)) { (param->groups == 1 || is_depthwise_conv2d)) {
return {c_big_axis}; return {c_big_axis};
...@@ -836,23 +833,20 @@ Expr Conv2DBackwardTransform(const Call& call, ...@@ -836,23 +833,20 @@ Expr Conv2DBackwardTransform(const Call& call,
} }
const auto* param = call->attrs.as<Conv2DAttrs>(); const auto* param = call->attrs.as<Conv2DAttrs>();
CHECK(param != nullptr); CHECK(param != nullptr);
Layout out_layout(param->out_layout); Layout kernel_layout(param->kernel_layout);
if (!out_layout.defined()) { Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
out_layout = Layout(param->data_layout);
}
Layout weight_layout(param->weight_layout);
int c_big_axis = out_layout.Indexof('C'); int c_big_axis = out_layout.Indexof('C');
CHECK_GE(c_big_axis, 0); CHECK_GE(c_big_axis, 0);
// For now, we only support simple pattern (no folded weight/data) // For now, we only support simple pattern (no folded weight/data)
// TODO(tvm-team) support general data layout // TODO(tvm-team) support general data layout
CHECK_EQ(weight_layout.Indexof('o'), -1); CHECK_EQ(kernel_layout.Indexof('o'), -1);
CHECK_EQ(weight_layout.Indexof('i'), -1); CHECK_EQ(kernel_layout.Indexof('i'), -1);
CHECK(axes.size() == 1 && CHECK(axes.size() == 1 &&
c_big_axis == axes[0]->value); c_big_axis == axes[0]->value);
int big_oc_axis = weight_layout.Indexof('O'); int big_oc_axis = kernel_layout.Indexof('O');
// Check it must be depthwise or full conv2d. // Check it must be depthwise or full conv2d.
bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, weight_layout); bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
CHECK(param->groups == 1 || is_depthwise_conv2d); CHECK(param->groups == 1 || is_depthwise_conv2d);
Expr data = transformer->Transform( Expr data = transformer->Transform(
...@@ -861,7 +855,7 @@ Expr Conv2DBackwardTransform(const Call& call, ...@@ -861,7 +855,7 @@ Expr Conv2DBackwardTransform(const Call& call,
call->args[1], NullValue<AxesSet>(), NullValue<Expr>()); call->args[1], NullValue<AxesSet>(), NullValue<Expr>());
// scale on input for depthwise. // scale on input for depthwise.
Expr wscale = ExpandBiasToMatchAxis( Expr wscale = ExpandBiasToMatchAxis(
scale, weight_layout.ndim(), {big_oc_axis}); scale, kernel_layout.ndim(), {big_oc_axis});
weight = Multiply(weight, wscale); weight = Multiply(weight, wscale);
return CallNode::make( return CallNode::make(
call->op, {data, weight}, call->attrs, call->type_args); call->op, {data, weight}, call->attrs, call->type_args);
......
...@@ -112,11 +112,11 @@ inline Expr ExpandBiasToMatchAxis(Expr bias, ...@@ -112,11 +112,11 @@ inline Expr ExpandBiasToMatchAxis(Expr bias,
*/ */
inline bool IsDepthwiseConv2D(const Call& call, inline bool IsDepthwiseConv2D(const Call& call,
const Conv2DAttrs* param, const Conv2DAttrs* param,
const Layout& weight_layout) { const Layout& kernel_layout) {
static const Layout kOIHW("OIHW"); static const Layout kOIHW("OIHW");
auto wshape = ConvertLayout( auto wshape = ConvertLayout(
call->args[1]->type_as<TensorTypeNode>()->shape, call->args[1]->type_as<TensorTypeNode>()->shape,
weight_layout, kOIHW); kernel_layout, kOIHW);
return is_const_int(wshape[0], param->groups) && return is_const_int(wshape[0], param->groups) &&
is_const_int(wshape[1], 1); is_const_int(wshape[1], 1);
} }
...@@ -129,7 +129,7 @@ inline bool IsDepthwiseConv2D(const Call& call, ...@@ -129,7 +129,7 @@ inline bool IsDepthwiseConv2D(const Call& call,
inline int64_t GetConv2DSuperChannelsDim(const CallNode* call) { inline int64_t GetConv2DSuperChannelsDim(const CallNode* call) {
auto param = call->attrs.as<Conv2DAttrs>(); auto param = call->attrs.as<Conv2DAttrs>();
auto tweight = call->args[1]->type_as<TensorTypeNode>(); auto tweight = call->args[1]->type_as<TensorTypeNode>();
auto index = param->weight_layout.find('O'); auto index = param->kernel_layout.find('O');
CHECK_NE(index, std::string::npos); CHECK_NE(index, std::string::npos);
auto channels = as_const_int(tweight->shape[index]); auto channels = as_const_int(tweight->shape[index]);
return *channels; return *channels;
......
...@@ -41,7 +41,7 @@ def test_conv2d_infer_type(): ...@@ -41,7 +41,7 @@ def test_conv2d_infer_type():
padding=(1, 1), padding=(1, 1),
channels=16, channels=16,
data_layout="NCHW4n4c", data_layout="NCHW4n4c",
weight_layout="OIHW4o4i", kernel_layout="OIHW4o4i",
out_dtype="int32") out_dtype="int32")
yy = relay.ir_pass.infer_type(y) yy = relay.ir_pass.infer_type(y)
assert yy.checked_type == relay.TensorType( assert yy.checked_type == relay.TensorType(
......
...@@ -91,7 +91,7 @@ def test_alter_layout(): ...@@ -91,7 +91,7 @@ def test_alter_layout():
data, weight = inputs data, weight = inputs
new_attrs = dict(attrs) new_attrs = dict(attrs)
new_attrs['data_layout'] = 'NCHW16c' new_attrs['data_layout'] = 'NCHW16c'
new_attrs['weight_layout'] = 'OIHW16i' new_attrs['kernel_layout'] = 'OIHW16i'
return relay.nn.conv2d(data, weight, **new_attrs) return relay.nn.conv2d(data, weight, **new_attrs)
def expected(): def expected():
...@@ -105,7 +105,7 @@ def test_alter_layout(): ...@@ -105,7 +105,7 @@ def test_alter_layout():
channels=64, channels=64,
kernel_size=(3, 3), kernel_size=(3, 3),
padding=(1, 1), padding=(1, 1),
weight_layout="OIHW16i", kernel_layout="OIHW16i",
data_layout="NCHW16c") data_layout="NCHW16c")
b = relay.expand_dims(bias, axis=1, num_newaxis=2) b = relay.expand_dims(bias, axis=1, num_newaxis=2)
b = relay.layout_transform(b, "CHW", "CHW16c") b = relay.layout_transform(b, "CHW", "CHW16c")
...@@ -269,7 +269,7 @@ def test_alter_layout_broadcast_op(): ...@@ -269,7 +269,7 @@ def test_alter_layout_broadcast_op():
y = relay.Function(free_vars(y), y) y = relay.Function(free_vars(y), y)
return y return y
@register_alter_op_layout("nn.conv2d", level=102) @register_alter_op_layout("nn.conv2d", level=105)
def alter_conv2d(attrs, inputs, tinfos): def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs data, weight = inputs
new_attrs = dict(attrs) new_attrs = dict(attrs)
...@@ -305,6 +305,107 @@ def test_alter_layout_broadcast_op(): ...@@ -305,6 +305,107 @@ def test_alter_layout_broadcast_op():
assert(alpha_equal(a, b)) assert(alpha_equal(a, b))
def test_alter_layout_scalar():
"""Test alternating the layout of a conv2d.
The layout of broadcast operators and the weight should be changed accordingly.
"""
def before():
x = relay.var("x", shape=(1, 64, 56, 56))
weight = relay.var("weight")
y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
y = relay.add(y, relay.const(1, "float32"))
y = relay.Function(free_vars(y), y)
return y
@register_alter_op_layout("nn.conv2d", level=106)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
new_attrs = dict(attrs)
new_attrs['data_layout'] = 'NCHW16c'
return relay.nn.conv2d(data, weight, **new_attrs)
def expected():
x = relay.var("x", shape=(1, 64, 56, 56))
w = relay.var("weight")
y = relay.layout_transform(x, "NCHW", "NCHW16c")
y = relay.nn.conv2d(y, w,
channels=64,
kernel_size=(3, 3),
padding=(1, 1),
data_layout="NCHW16c")
y = relay.add(y, relay.const(1.0, "float32"))
y = relay.layout_transform(y, "NCHW16c", "NCHW")
y = relay.Function(free_vars(y), y)
return y
a = before()
a = infer_type(a)
a = canonicalize_ops(a)
a = infer_type(a)
a = alter_op_layout(a)
a = infer_type(a)
b = expected()
b = infer_type(b)
assert(alpha_equal(a, b))
def test_alter_layout_concatenate():
""" """
def before():
x = relay.var("x", shape=(1, 64, 56, 56))
weight1 = relay.var('weight1')
weight2 = relay.var('weight2')
y = relay.nn.conv2d(x, weight1,
channels=32,
kernel_size=(3, 3),
padding=(1, 1))
y1 = relay.nn.conv2d(y, weight2,
channels=32,
kernel_size=(3, 3),
padding=(1, 1))
ret = relay.concatenate([y, y1], axis=1)
y = relay.Function(free_vars(ret), ret)
return y
@register_alter_op_layout("nn.conv2d", level=107)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
new_attrs = dict(attrs)
new_attrs['data_layout'] = 'NCHW16c'
return relay.nn.conv2d(data, weight, **new_attrs)
def expected():
x = relay.var("x", shape=(1, 64, 56, 56))
weight1 = relay.var('weight1')
weight2 = relay.var('weight2')
y = relay.layout_transform(x, "NCHW", "NCHW16c")
y = relay.nn.conv2d(y, weight1,
channels=32,
kernel_size=(3, 3),
padding=(1, 1),
data_layout="NCHW16c")
y1 = relay.nn.conv2d(y, weight2,
channels=32,
kernel_size=(3, 3),
padding=(1, 1),
data_layout='NCHW16c')
ret = relay.concatenate([y, y1], axis=1)
ret = relay.layout_transform(ret, "NCHW16c", "NCHW")
y = relay.Function(free_vars(ret), ret)
return y
a = before()
a = infer_type(a)
a = alter_op_layout(a)
a = infer_type(a)
b = expected()
b = infer_type(b)
assert(alpha_equal(a, b))
if __name__ == "__main__": if __name__ == "__main__":
test_alter_op() test_alter_op()
...@@ -313,3 +414,5 @@ if __name__ == "__main__": ...@@ -313,3 +414,5 @@ if __name__ == "__main__":
test_alter_layout_dual_path() test_alter_layout_dual_path()
test_alter_layout_resnet() test_alter_layout_resnet()
test_alter_layout_broadcast_op() test_alter_layout_broadcast_op()
test_alter_layout_scalar()
test_alter_layout_concatenate()
...@@ -67,14 +67,14 @@ def test_fold_fwd_dual_path(): ...@@ -67,14 +67,14 @@ def test_fold_fwd_dual_path():
channels=channels, channels=channels,
kernel_size=(3, 3), kernel_size=(3, 3),
data_layout="NHWC", data_layout="NHWC",
weight_layout="HWIO", kernel_layout="HWIO",
groups=channels, groups=channels,
padding=(1, 1)) padding=(1, 1))
y2 = relay.nn.conv2d(x, conv_weight, y2 = relay.nn.conv2d(x, conv_weight,
channels=channels, channels=channels,
kernel_size=(3, 3), kernel_size=(3, 3),
data_layout="NHWC", data_layout="NHWC",
weight_layout="HWIO", kernel_layout="HWIO",
groups=channels, groups=channels,
padding=(1, 1)) padding=(1, 1))
z = relay.add(y1, y2) z = relay.add(y1, y2)
...@@ -90,7 +90,7 @@ def test_fold_fwd_dual_path(): ...@@ -90,7 +90,7 @@ def test_fold_fwd_dual_path():
channels=channels, channels=channels,
kernel_size=(3, 3), kernel_size=(3, 3),
data_layout="NHWC", data_layout="NHWC",
weight_layout="HWIO", kernel_layout="HWIO",
groups=channels, groups=channels,
padding=(1, 1)) padding=(1, 1))
y2 = relay.nn.conv2d(x, y2 = relay.nn.conv2d(x,
...@@ -98,7 +98,7 @@ def test_fold_fwd_dual_path(): ...@@ -98,7 +98,7 @@ def test_fold_fwd_dual_path():
channels=channels, channels=channels,
kernel_size=(3, 3), kernel_size=(3, 3),
data_layout="NHWC", data_layout="NHWC",
weight_layout="HWIO", kernel_layout="HWIO",
groups=channels, groups=channels,
padding=(1, 1)) padding=(1, 1))
z = relay.add(y1, y2) z = relay.add(y1, y2)
......
...@@ -523,9 +523,25 @@ def schedule_conv2d_winograd_without_weight_transform_(cfg, outs): ...@@ -523,9 +523,25 @@ def schedule_conv2d_winograd_without_weight_transform_(cfg, outs):
##### REGISTER ALTER OP LAYOUT ##### ##### REGISTER ALTER OP LAYOUT #####
@conv2d_alter_layout.register(["arm_cpu"]) @conv2d_alter_layout.register(["arm_cpu"])
def _alter_conv2d_layout_arm(attrs, inputs, tinfos): def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
"""Alter op layout for pre-computing kernel transformation""" """Alter op layout for pre-computing kernel transformation
import nnvm.symbol as sym
Parameters
----------
attrs : nnvm.top.AttrDict or tvm.attrs.Attrs
Attributes of current convolution
inputs : nnvm.symbol or tvm.relay.Expr
Grouped input symbols
tinfos : list
Input shape and dtype
F : symbol
The context; it can be either nnvm.sym or relay.op
Note
----
Unlike other TOPI functions, this function operates on both graph level and operator level,
so we have to pass 'F' to make it support our two versions of graph IR, NNVM and Relay.
"""
copy_inputs = [s for s in inputs] copy_inputs = [s for s in inputs]
new_attrs = {k: attrs[k] for k in attrs.keys()} new_attrs = {k: attrs[k] for k in attrs.keys()}
...@@ -534,9 +550,11 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos): ...@@ -534,9 +550,11 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos):
strides = attrs.get_int_tuple("strides") strides = attrs.get_int_tuple("strides")
padding = attrs.get_int_tuple("padding") padding = attrs.get_int_tuple("padding")
groups = attrs.get_int('groups') groups = attrs.get_int('groups')
layout = attrs["layout"] data_layout_key = "data_layout" if "data_layout" in new_attrs else "layout"
layout = attrs[data_layout_key]
out_dtype = attrs["out_dtype"] out_dtype = attrs["out_dtype"]
out_dtype = tinfos[0].dtype if out_dtype == "same" else out_dtype if out_dtype == "" or out_dtype == "same":
out_dtype = tinfos[0].dtype
if layout != 'NCHW' or groups != 1: if layout != 'NCHW' or groups != 1:
return None return None
...@@ -570,7 +588,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos): ...@@ -570,7 +588,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos):
[new_data, new_kernel, strides, padding, dilation, 'NCHW', out_dtype], conv2d) [new_data, new_kernel, strides, padding, dilation, 'NCHW', out_dtype], conv2d)
dispatch_ctx.update(target, new_workload, cfg) dispatch_ctx.update(target, new_workload, cfg)
return sym.conv2d(*copy_inputs, **new_attrs) return F.nn.conv2d(*copy_inputs, **new_attrs)
else: # pre-compute weight transformation in winograd else: # pre-compute weight transformation in winograd
if "-device=arm_cpu" in target.options: if "-device=arm_cpu" in target.options:
tile_size = 4 tile_size = 4
...@@ -580,10 +598,10 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos): ...@@ -580,10 +598,10 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos):
tile_size = _pick_tile_size(tinfos[0], tinfos[1]) tile_size = _pick_tile_size(tinfos[0], tinfos[1])
VC = cfg['tile_bna'].val VC = cfg['tile_bna'].val
weight = sym.contrib.conv2d_winograd_weight_transform(copy_inputs[1], tile_size=tile_size) weight = F.nn.contrib_conv2d_winograd_weight_transform(copy_inputs[1], tile_size=tile_size)
weight = sym.reshape(weight, weight = F.reshape(weight,
shape=(KH + tile_size - 1, KW + tile_size - 1, CO // VC, VC, CI)) newshape=(KH + tile_size - 1, KW + tile_size - 1, CO // VC, VC, CI))
weight = sym.transpose(weight, axes=[0, 1, 2, 4, 3]) weight = F.transpose(weight, axes=[0, 1, 2, 4, 3])
copy_inputs[1] = weight copy_inputs[1] = weight
new_attrs['tile_size'] = tile_size new_attrs['tile_size'] = tile_size
...@@ -594,8 +612,8 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos): ...@@ -594,8 +612,8 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos):
kernel.dtype) kernel.dtype)
new_workload = autotvm.task.args_to_workload( new_workload = autotvm.task.args_to_workload(
[new_data, new_weight, strides, padding, dilation, [new_data, new_weight, strides, padding, dilation,
new_attrs['layout'], out_dtype, tile_size], new_attrs[data_layout_key], out_dtype, tile_size],
conv2d_winograd_without_weight_transform) conv2d_winograd_without_weight_transform)
dispatch_ctx.update(target, new_workload, cfg) dispatch_ctx.update(target, new_workload, cfg)
return sym.contrib.conv2d_winograd_without_weight_transform(*copy_inputs, **new_attrs) return F.nn.contrib_conv2d_winograd_without_weight_transform(*copy_inputs, **new_attrs)
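The winograd branch above pre-computes the kernel transform at compile time and switches the op to the "without weight transform" variant. A hedged Relay-level sketch of the resulting pattern (shapes and tile_size are assumptions; the real arm_cpu path additionally reshapes and transposes the transformed weight as shown above):

```python
from tvm import relay

data = relay.var("data", shape=(1, 64, 56, 56))      # assumed NCHW input
weight = relay.var("weight", shape=(64, 64, 3, 3))   # assumed OIHW 3x3 kernel
tile_size = 4                                        # e.g. F(4x4, 3x3)

# Transform the kernel once, then call the convolution that expects a
# pre-transformed kernel.
w = relay.nn.contrib_conv2d_winograd_weight_transform(weight, tile_size=tile_size)
y = relay.nn.contrib_conv2d_winograd_without_weight_transform(
    data, w, tile_size=tile_size, channels=64,
    kernel_size=(3, 3), padding=(1, 1))
```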
...@@ -330,23 +330,40 @@ def schedule_conv2d_winograd_without_weight_transform_cuda(cfg, outs): ...@@ -330,23 +330,40 @@ def schedule_conv2d_winograd_without_weight_transform_cuda(cfg, outs):
##### REGISTER ALTER OP LAYOUT ##### ##### REGISTER ALTER OP LAYOUT #####
@nn.conv2d_alter_layout.register(["cuda", "gpu"]) @nn.conv2d_alter_layout.register(["cuda", "gpu"])
def _alter_conv2d_layout(attrs, inputs, tinfos): def _alter_conv2d_layout(attrs, inputs, tinfos, F):
"""Alter op layout for pre-computing kernel transformation""" """Alter op layout for pre-computing kernel transformation
Parameters
----------
attrs : nnvm.top.AttrDict or tvm.attrs.Attrs
Attributes of current convolution
inputs : nnvm.symbol or tvm.relay.Expr
Grouped input symbols
tinfos : list
Input shape and dtype
F : symbol
The context; it can be either nnvm.sym or relay.op
Note
----
Unlike other TOPI functions, this function operates on both graph level and operator level,
so we have to pass 'F' to make it support our two versions of graph IR, NNVM and Relay.
"""
if 'cudnn' in tvm.target.current_target().libs or 'miopen' in tvm.target.current_target().libs: if 'cudnn' in tvm.target.current_target().libs or 'miopen' in tvm.target.current_target().libs:
return None return None
import nnvm.symbol as sym
copy_inputs = [s for s in inputs] copy_inputs = [s for s in inputs]
new_attrs = {k: attrs[k] for k in attrs.keys()} new_attrs = {k: attrs[k] for k in attrs.keys()}
strides = attrs.get_int_tuple("strides") strides = attrs.get_int_tuple("strides")
padding = attrs.get_int_tuple("padding") padding = attrs.get_int_tuple("padding")
dilation = attrs.get_int_tuple("dilation") dilation = attrs.get_int_tuple("dilation")
groups = attrs.get_int('groups') groups = attrs.get_int('groups')
layout = attrs["layout"] data_layout_key = "data_layout" if "data_layout" in new_attrs else "layout"
layout = attrs[data_layout_key]
out_dtype = attrs["out_dtype"] out_dtype = attrs["out_dtype"]
out_dtype = tinfos[0].dtype if out_dtype == "same" else out_dtype if out_dtype == "" or out_dtype == "same":
out_dtype = tinfos[0].dtype
data, kernel = tinfos[0:2] data, kernel = tinfos[0:2]
N, CI, H, W = get_const_tuple(data.shape) N, CI, H, W = get_const_tuple(data.shape)
...@@ -371,7 +388,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos): ...@@ -371,7 +388,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos):
if cfg.template_key == 'int8': if cfg.template_key == 'int8':
assert 'cuda' in target.keys assert 'cuda' in target.keys
new_layout = 'NCHW4c' new_layout = 'NCHW4c'
new_attrs['layout'] = new_layout new_attrs[data_layout_key] = new_layout
new_attrs['out_layout'] = new_layout new_attrs['out_layout'] = new_layout
new_attrs['kernel_layout'] = 'OIHW4o4i' new_attrs['kernel_layout'] = 'OIHW4o4i'
ic_block_factor = oc_block_factor = 4 ic_block_factor = oc_block_factor = 4
...@@ -386,7 +403,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos): ...@@ -386,7 +403,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos):
conv2d conv2d
) )
dispatch_ctx.update(target, new_workload, cfg) dispatch_ctx.update(target, new_workload, cfg)
return sym.conv2d(*copy_inputs, **new_attrs) return F.nn.conv2d(*copy_inputs, **new_attrs)
if attrs.get_int_tuple("dilation") != (1, 1): if attrs.get_int_tuple("dilation") != (1, 1):
warnings.warn("Does not support weight pre-transform for dilated convolution.") warnings.warn("Does not support weight pre-transform for dilated convolution.")
...@@ -395,9 +412,9 @@ def _alter_conv2d_layout(attrs, inputs, tinfos): ...@@ -395,9 +412,9 @@ def _alter_conv2d_layout(attrs, inputs, tinfos):
# pre-compute weight transformation in winograd # pre-compute weight transformation in winograd
tile_size = _infer_tile_size(tinfos[0], tinfos[1]) tile_size = _infer_tile_size(tinfos[0], tinfos[1])
weight = sym.contrib.conv2d_winograd_weight_transform(copy_inputs[1], weight = F.nn.contrib_conv2d_winograd_weight_transform(copy_inputs[1],
tile_size=tile_size) tile_size=tile_size)
weight = sym.transpose(weight, axes=[0, 1, 3, 2]) weight = F.transpose(weight, axes=[0, 1, 3, 2])
copy_inputs[1] = weight copy_inputs[1] = weight
new_attrs['tile_size'] = tile_size new_attrs['tile_size'] = tile_size
...@@ -410,7 +427,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos): ...@@ -410,7 +427,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos):
conv2d_winograd_without_weight_transform conv2d_winograd_without_weight_transform
) )
dispatch_ctx.update(target, new_workload, cfg) dispatch_ctx.update(target, new_workload, cfg)
return sym.contrib.conv2d_winograd_without_weight_transform(*copy_inputs, **new_attrs) return F.nn.contrib_conv2d_winograd_without_weight_transform(*copy_inputs, **new_attrs)
elif groups != CI: elif groups != CI:
workload = autotvm.task.args_to_workload( workload = autotvm.task.args_to_workload(
[tinfos[0], tinfos[1], strides, padding, dilation, groups, out_dtype], [tinfos[0], tinfos[1], strides, padding, dilation, groups, out_dtype],
...@@ -424,7 +441,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos): ...@@ -424,7 +441,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos):
if cfg.template_key == 'int8': if cfg.template_key == 'int8':
assert 'cuda' in target.keys assert 'cuda' in target.keys
new_layout = 'NCHW4c' new_layout = 'NCHW4c'
new_attrs['layout'] = new_layout new_attrs[data_layout_key] = new_layout
new_attrs['out_layout'] = new_layout new_attrs['out_layout'] = new_layout
new_attrs['kernel_layout'] = 'OIHW4o4i' new_attrs['kernel_layout'] = 'OIHW4o4i'
ic_block_factor = oc_block_factor = 4 ic_block_factor = oc_block_factor = 4
...@@ -440,7 +457,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos): ...@@ -440,7 +457,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos):
group_conv2d_nchw group_conv2d_nchw
) )
dispatch_ctx.update(target, new_workload, cfg) dispatch_ctx.update(target, new_workload, cfg)
return sym.conv2d(*copy_inputs, **new_attrs) return F.nn.conv2d(*copy_inputs, **new_attrs)
# do nothing for depthwise convolution # do nothing for depthwise convolution
return None return None
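Both hooks read the data layout through data_layout_key because the two frontends name the attribute differently: NNVM's conv2d calls it "layout", while Relay's Conv2DAttrs call it "data_layout". A small illustration of the shim (the attribute dicts are made up for the example):

```python
relay_attrs = {"data_layout": "NCHW", "kernel_layout": "OIHW"}  # Relay naming
nnvm_attrs = {"layout": "NCHW", "kernel_layout": "OIHW"}        # NNVM naming

for new_attrs in (relay_attrs, nnvm_attrs):
    data_layout_key = "data_layout" if "data_layout" in new_attrs else "layout"
    assert new_attrs[data_layout_key] == "NCHW"
```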
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import warnings
import tvm import tvm
from .. import generic from .. import generic
...@@ -37,8 +38,13 @@ def tile_and_bind3d(s, tensor, z, y, x, z_factor=2, y_factor=None, x_factor=None ...@@ -37,8 +38,13 @@ def tile_and_bind3d(s, tensor, z, y, x, z_factor=2, y_factor=None, x_factor=None
return xi, thread_z, thread_y, thread_x return xi, thread_z, thread_y, thread_x
@conv2d_alter_layout.register(["intel_graphics"]) @conv2d_alter_layout.register(["intel_graphics"])
def _alter_conv2d_layout(attrs, inputs, tinfos): def _alter_conv2d_layout(attrs, inputs, tinfos, F):
import nnvm.symbol as sym import nnvm.symbol as sym
if F != sym:
warnings.warn("Only support alter layout for intel graphics in NNVM now. "
"This pass is ignored in relay.")
return None
copy_inputs = [s for s in inputs] copy_inputs = [s for s in inputs]
data = tinfos[0] data = tinfos[0]
......
...@@ -465,9 +465,9 @@ def schedule_conv2d_winograd_without_weight_transform_(cfg, outs): ...@@ -465,9 +465,9 @@ def schedule_conv2d_winograd_without_weight_transform_(cfg, outs):
##### REGISTER ALTER OP LAYOUT ##### ##### REGISTER ALTER OP LAYOUT #####
@conv2d_alter_layout.register(["mali"]) @conv2d_alter_layout.register(["mali"])
def _alter_conv2d_layout(attrs, inputs, tinfos): def _alter_conv2d_layout(attrs, inputs, tinfos, F):
try: try:
return _alter_conv2d_layout_arm(attrs, inputs, tinfos) return _alter_conv2d_layout_arm(attrs, inputs, tinfos, F)
except KeyError: # to filter out fallback opencl templates except KeyError: # to filter out fallback opencl templates
return None return None
......
...@@ -57,17 +57,24 @@ def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=N ...@@ -57,17 +57,24 @@ def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=N
@tvm.target.generic_func @tvm.target.generic_func
def conv2d_alter_layout(attrs, inputs, tinfos): def conv2d_alter_layout(attrs, inputs, tinfos, F):
"""Change Conv2D layout. """Change Conv2D layout.
Parameters Parameters
---------- ----------
attrs : nnvm.top.AttrDict attrs : nnvm.top.AttrDict or tvm.attrs.Attrs
Attributes of current convolution Attributes of current convolution
inputs : nnvm.symbol inputs : nnvm.symbol or tvm.relay.Expr
Grouped input symbols Grouped input symbols
tinfos : list tinfos : list
Input shape and dtype Input shape and dtype
F : symbol
The context; it can be either nnvm.sym or relay.op
Note
----
Unlike other TOPI functions, this function operates on both graph level and operator level,
so we have to pass 'F' to make it support our two versions of graph IR, NNVM and Relay.
""" """
# not to change by default # not to change by default
return None return None
......
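conv2d_alter_layout is a tvm.target.generic_func, so the default above (return None) applies unless a target registers an override, as the arm_cpu, cuda/gpu, intel_graphics, mali and cpu registrations in this change do. A minimal hedged sketch of such an override (the target name is a placeholder):

```python
from topi.nn.conv2d import conv2d_alter_layout

@conv2d_alter_layout.register(["hypothetical_target"])  # placeholder target key
def _alter_conv2d_layout_noop(attrs, inputs, tinfos, F):
    # Returning None keeps the original conv2d untouched.
    return None
```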
# pylint: disable=invalid-name,unused-variable,unused-argument,no-member # pylint: disable=invalid-name,unused-variable,unused-argument,no-member
"""Conv2D schedule on x86""" """Conv2D schedule on x86"""
import warnings
import tvm import tvm
from tvm import autotvm from tvm import autotvm
from tvm.autotvm.task.topi_integration import deserialize_args from tvm.autotvm.task.topi_integration import deserialize_args
...@@ -281,8 +283,13 @@ def _topi_nn_conv2d_NCHWc(*args, **kwargs): ...@@ -281,8 +283,13 @@ def _topi_nn_conv2d_NCHWc(*args, **kwargs):
@conv2d_alter_layout.register("cpu") @conv2d_alter_layout.register("cpu")
def _alter_conv2d_layout(attrs, inputs, tinfo): def _alter_conv2d_layout(attrs, inputs, tinfo, F):
import nnvm.symbol as sym import nnvm.symbol as sym
if F != sym:
warnings.warn("Only support alter layout for x86 in NNVM now. "
"This pass is ignored in relay.")
return None
copy_inputs = [s for s in inputs] copy_inputs = [s for s in inputs]
new_attrs = {k : attrs[k] for k in attrs.keys()} new_attrs = {k : attrs[k] for k in attrs.keys()}
data, kernel = tinfo[0], tinfo[1] data, kernel = tinfo[0], tinfo[1]
......