Commit 6961ad14 by Animesh Jain Committed by Yizhi Liu

[Relay][AlterOp] Minor refactor. (#4064)

parent d703fb4e
...@@ -53,10 +53,10 @@ inline Layout AdjustSubordinateFactors(const Layout& src_layout, const Layout& o ...@@ -53,10 +53,10 @@ inline Layout AdjustSubordinateFactors(const Layout& src_layout, const Layout& o
for (auto axis : src_layout->axes) { for (auto axis : src_layout->axes) {
if (!LayoutAxis::Get(axis).IsPrimal()) { if (!LayoutAxis::Get(axis).IsPrimal()) {
// 1) Find the corresponding dual axis // 1) Find the corresponding dual axis
auto dual_axis = LayoutAxis::Get(axis).ToPrimal().name()[0]; const auto& dual_axis = LayoutAxis::Get(axis).ToPrimal();
// 2) Find the index of this dual axis in old_layout // 2) Find the index of this dual axis in old_layout
int old_axis = old_layout.IndexOf(LayoutAxis::Get(dual_axis)); int old_axis = old_layout.IndexOf(dual_axis);
// 3) Find the shape of this index in old_shape // 3) Find the shape of this index in old_shape
auto shape_val = old_shape[old_axis]; auto shape_val = old_shape[old_axis];
...@@ -72,7 +72,7 @@ inline Layout AdjustSubordinateFactors(const Layout& src_layout, const Layout& o ...@@ -72,7 +72,7 @@ inline Layout AdjustSubordinateFactors(const Layout& src_layout, const Layout& o
// 4) b) If shape is not 1, retain the factor. // 4) b) If shape is not 1, retain the factor.
if (!is_shape_one) { if (!is_shape_one) {
auto new_shape_val = src_layout.FactorOf(LayoutAxis::Get(dual_axis)); auto new_shape_val = src_layout.FactorOf(dual_axis);
new_layout += std::to_string(new_shape_val); new_layout += std::to_string(new_shape_val);
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment