Commit 1071e242 authored by Animesh Jain, committed by Yizhi Liu

[Relay][AlterLayout] Broadcast with scalar shape (#4577)

parent 73dda6be
...
@@ -200,6 +200,28 @@ inline int64_t GetConv2DSuperChannelsDim(const CallNode* call) {
}

/*!
 * \brief Check whether the expression is a single-value tensor (scalar).
 * \param expr The expression to check.
 * \return True if expr is a tensor whose every dimension is the constant 1.
 */
inline bool IsScalar(const Expr& expr) {
  if (auto tensor_type = expr->checked_type().as<TensorTypeNode>()) {
    for (auto dim_index_expr : tensor_type->shape) {
      if (auto dim_index = dim_index_expr.as<IntImm>()) {
        // Every dimension must be the constant 1.
        if (dim_index->value != 1) {
          return false;
        }
      } else {
        // A symbolic dimension cannot be proven to be 1.
        return false;
      }
    }
  } else {
    // Not a tensor type at all.
    return false;
  }
  return true;
}
/*!
 * \brief Create a Constant with a scalar
 *
 * \param dtype The data type.
...
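For intuition, the new IsScalar helper returns true only when every dimension of the tensor's shape is the constant 1, so rank-0 tensors qualify trivially and any symbolic dimension disqualifies. Below is a minimal Python analogue, written purely for illustration; the real helper is the C++ function above:

    # Illustrative Python analogue of the C++ IsScalar helper above.
    # A tensor counts as "scalar" only if every dimension is the constant 1;
    # a symbolic (non-constant) dimension fails the as<IntImm>() check in C++,
    # which the isinstance() test mirrors here.
    def is_scalar(shape):
        for dim in shape:
            if not isinstance(dim, int) or dim != 1:
                return False
        return True

    assert is_scalar(())          # rank-0 tensor
    assert is_scalar((1,))        # shape of multiplier1 in the new test below
    assert is_scalar((1, 1))      # shape of multiplier2 in the new test below
    assert not is_scalar((64,))   # an ordinary bias vector is not a scalar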
...
@@ -119,6 +119,11 @@ class TransformMemorizer : public NodeRef {
    Expr input_expr = raw;
    Layout new_src_layout = src_layout;
    if (src_layout.ndim_primal() < dst_layout.ndim_primal()) {
      // If the input is a scalar, no layout transform is needed: a scalar can
      // be broadcast easily even if the other operand has a transformed layout.
      if (IsScalar(input_expr)) {
        return raw;
      }
      int num_new_axis = dst_layout.ndim_primal() - src_layout.ndim_primal();
      new_src_layout = src_layout.ExpandPrimal(dst_layout);
      input_expr = MakeExpandDims(input_expr, 0, num_new_axis);
...
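The effect of the early return above, sketched with the Relay Python API (the variable names here are illustrative, not part of the patch): a non-scalar broadcast operand still receives the usual expand_dims plus layout_transform rewrite, while a scalar operand is returned untouched, since a single value broadcasts correctly against any layout.

    from tvm import relay

    # Non-scalar operand: still expanded and layout-transformed so that it
    # broadcasts against the NCHW16c-laid-out convolution output.
    bias = relay.var("bias", shape=(64,))
    b = relay.expand_dims(bias, axis=0, num_newaxis=3)  # -> shape (1, 1, 1, 64)
    b = relay.layout_transform(b, "NHWC", "NCHW16c")

    # Scalar operand: after this patch it is used as-is; no expand_dims or
    # layout_transform is inserted, because broadcasting a single value is
    # valid regardless of the other operand's layout.
    multiplier = relay.var("multiplier", shape=(1,))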
...
@@ -318,6 +318,70 @@ def test_alter_layout_broadcast_op():
    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_broadcast_scalar_op():
    """Test altering the layout of a conv2d.
    The layout of broadcast operators and the weight should be changed
    accordingly, while scalar-shaped operands are broadcast without a
    layout transform.
    """
    def before():
        x = relay.var("x", shape=(1, 500, 500, 64))
        kernel = relay.var('kernel', shape=(3, 3, 64, 64), dtype='float32')
        bias = relay.var("bias", shape=(64,))
        multiplier1 = relay.var('multiplier1', shape=(1, ), dtype='float32')
        multiplier2 = relay.var('multiplier2', shape=(1, 1), dtype='float32')
        y = relay.nn.conv2d(x, kernel,
                            data_layout='NHWC',
                            kernel_layout="HWIO",
                            kernel_size=(3, 3))
        y = relay.add(bias, y)
        y = relay.nn.relu(y)
        y = relay.multiply(multiplier1, y)
        y = relay.multiply(y, multiplier2)
        y = relay.Function(analysis.free_vars(y), y)
        return y

    def alter_conv2d(attrs, inputs, tinfos):
        data, weight = inputs
        new_attrs = dict(attrs)
        new_attrs['data_layout'] = 'NCHW16c'
        return relay.nn.conv2d(data, weight, **new_attrs)

    def expected():
        x = relay.var("x", shape=(1, 500, 500, 64))
        kernel = relay.var('kernel', shape=(3, 3, 64, 64), dtype='float32')
        bias = relay.var("bias", shape=(64,))
        multiplier1 = relay.var('multiplier1', shape=(1, ), dtype='float32')
        multiplier2 = relay.var('multiplier2', shape=(1, 1), dtype='float32')

        b = relay.expand_dims(bias, axis=0, num_newaxis=3)
        b = relay.layout_transform(b, "NHWC", "NCHW16c")

        y = relay.layout_transform(x, "NHWC", "NCHW16c")
        y = relay.nn.conv2d(y, kernel,
                            data_layout='NCHW16c',
                            kernel_layout="HWIO",
                            kernel_size=(3, 3))
        y = relay.add(b, y)
        y = relay.nn.relu(y)
        # Scalar multipliers are used directly; no layout transform is expected.
        y = relay.multiply(multiplier1, y)
        y = relay.multiply(y, multiplier2)
        y = relay.layout_transform(y, "NCHW16c", "NHWC")
        y = relay.Function(analysis.free_vars(y), y)
        return y

    with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
        a = before()
        a = run_opt_pass(a, [transform.CanonicalizeOps(),
                             transform.AlterOpLayout()])
        b = run_opt_pass(expected(), transform.InferType())

    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_scalar():
    """Test altering the layout of a conv2d.
    The layout of broadcast operators and the weight should be changed accordingly.
...
@@ -980,6 +1044,7 @@ if __name__ == "__main__":
    test_alter_layout_dual_path()
    test_alter_layout_resnet()
    test_alter_layout_broadcast_op()
    test_alter_layout_broadcast_scalar_op()
    test_alter_layout_scalar()
    test_alter_layout_concatenate()
    test_alter_layout_nchw_upsamping_op()
...