Commit 35af4c8b by Animesh Jain Committed by Yizhi Liu

[Relay][Convert Layout] Handling batch norm layout change. (#4600)

parent 55bd786f
......@@ -617,6 +617,34 @@ The whole array is rescaled by ``1/(1-p)`` to keep the expected sum of the input
// batch_norm
TVM_REGISTER_NODE_TYPE(BatchNormAttrs);
Array<Array<Layout>> BatchNormInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<Array<IndexExpr>>& old_in_shapes) {
BatchNormAttrs* param = const_cast<BatchNormAttrs*>(attrs.as<BatchNormAttrs>());
size_t axis =
param->axis < 0 ? param->axis + old_in_shapes[0].size() : static_cast<size_t>(param->axis);
Layout ret = Layout::Undef();
// If new_in_layouts are defined, this code tries to modify the layout.
if (new_in_layouts.defined() && old_in_layouts.defined()) {
// Get the new C axis. Extract the dim in old layout. Find the index of that dim in next layout.
const auto& bn_dim = old_in_layouts[0][axis];
auto new_index = new_in_layouts[0].IndexOf(bn_dim);
param->axis = new_index;
ret = new_in_layouts[0];
} else if (old_in_layouts.defined()) {
ret = old_in_layouts[0];
}
// BN has 5 inputs, 3 outputs. The last 4 inputs and last 2 outputs have "C" layout.
Layout c_layout = Layout("C");
return Array<Array<Layout>>{{ret, c_layout, c_layout, c_layout, c_layout},
{ret, c_layout, c_layout}};
}
bool BatchNormRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
......@@ -708,6 +736,7 @@ axis to be the last item in the input shape.
.add_argument("beta", "Tensor", "The beta offset factor.")
.add_argument("moving_mean", "Tensor", "Running mean of input.")
.add_argument("moving_var", "Tensor", "Running variance of input.")
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", BatchNormInferCorrectLayout)
.set_support_level(1)
.add_type_rel("BatchNorm", BatchNormRel);
......
......@@ -134,7 +134,7 @@ Pass ConvertLayout(const std::string& desired_layout) {
};
return CreateFunctionPass(
pass_func, 3, "ConvertLayout",
{ir::StringImm::make("InferType"), ir::StringImm::make("SimplifyInference"),
{ir::StringImm::make("InferType"),
ir::StringImm::make("CanonicalizeOps")});
}
......
......@@ -349,6 +349,54 @@ def test_scalar_convert_layout():
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
def test_conv_bn_convert_layout():
""" Check that layout transforms are propagated through bn. """
def before():
x = relay.var("x", shape=(1, 56, 56, 64))
weight = relay.var("weight", shape=(3, 3, 64, 64))
y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1),
data_layout='NHWC', kernel_layout='HWIO')
dtype = "float32"
beta = relay.var("beta", relay.TensorType((64,), dtype))
gamma = relay.var("gamma", relay.TensorType((64,), dtype))
moving_mean = relay.var("moving_mean", relay.TensorType((64,), dtype))
moving_var = relay.var("moving_var", relay.TensorType((64,), dtype))
y = relay.nn.batch_norm(y, gamma, beta, moving_mean, moving_var, axis=3)
y = relay.nn.relu(y[0])
y = relay.Function(analysis.free_vars(y), y)
return y
def expected():
x = relay.var("x", shape=(1, 56, 56, 64))
w = relay.var("weight", shape=(3, 3, 64, 64))
x = relay.layout_transform(x, 'NHWC', 'NCHW')
w = relay.layout_transform(w, 'HWIO', 'OIHW')
y = relay.nn.conv2d(x, w,
channels=64,
kernel_size=(3, 3),
padding=(1, 1))
dtype = "float32"
beta = relay.var("beta", relay.TensorType((64,), dtype))
gamma = relay.var("gamma", relay.TensorType((64,), dtype))
moving_mean = relay.var("moving_mean", relay.TensorType((64,), dtype))
moving_var = relay.var("moving_var", relay.TensorType((64,), dtype))
y = relay.nn.batch_norm(y, gamma, beta, moving_mean, moving_var, axis=1)
y = relay.nn.relu(y[0])
y = relay.layout_transform(y, "NCHW", "NHWC")
y = relay.Function(analysis.free_vars(y), y)
return y
a = before()
a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
if __name__ == "__main__":
test_no_convert_layout()
test_conv_convert_layout()
......@@ -358,3 +406,4 @@ if __name__ == "__main__":
test_bn_convert_layout()
test_resnet_convert_layout()
test_scalar_convert_layout()
test_conv_bn_convert_layout()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment