Commit 9d002e8e by Yizhi Liu, committed by Tianqi Chen

[Lang] Fix undef BijectiveLayout and add scalar layout support (#3105)

parent 73f87ae0
@@ -94,12 +94,13 @@ class Layout;
 // Internal node container for Layout
 class LayoutNode : public Node {
  public:
-  /*! \brief string representation of layout */
+  /*! \brief string representation of layout, "" for scalar. */
   std::string name;
   /*! \brief specify each axis of the layout,
    *  in which the variable name is the name of the axis.
    *  The IterVar's extent indicates the size of the axis;
    *  it is a variable for a primal axis, but a constant for a subordinate axis.
+   *  Empty for a scalar's layout.
    */
   Array<IterVar> axes;
@@ -122,6 +123,7 @@ class LayoutNode : public Node {
  * For example, NCHW16c can describe a 5-D tensor of
  * [batch_size, channel, height, width, channel_block].
  * Here the subordinate axis channel_block=16 is the factor size of the primal axis C (channel).
+ * The layout of a scalar is defined, with both its name and axes having size 0.
  */
 class Layout : public NodeRef {
  public:
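To make the new scalar semantics concrete, here is a short sketch against the Python bindings exercised by the tests below (`tvm.layout`); the attribute names follow this version's API as exercised in `test_lang_data_layout.py`:

```python
import tvm

# NCHW16c: four primal axes (N, C, H, W) plus one subordinate axis (16c),
# i.e. a 5-D tensor [batch_size, channel, height, width, channel_block].
layout = tvm.layout("NCHW16c")
assert len(layout) == 5
assert layout.factor_of("c") == 16   # 16c is the factor of the primal axis C
assert layout.factor_of("N") == -1   # N has no subordinate axis

# With this patch, "" constructs a *defined* scalar layout: empty name, no axes.
scalar = tvm.layout("")
assert len(scalar) == 0
```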
@@ -175,7 +177,7 @@ class Layout : public NodeRef {
    * that starts at dimension \p pos and spans \p len dimensions
    * (or until the end of the layout, whichever comes first).
    * \param pos The start position.
-   * \param len The length of the sub-layout.
+   * \param len The length of the sub-layout. If 0, returns the layout of a scalar.
    * \return A newly constructed Layout object.
    */
   Layout SubLayout(size_t pos, size_t len) const;
......
@@ -88,12 +88,14 @@ Layout::Layout(const Array<IterVar>& axes) {
 }
 Layout::Layout(const std::string& name) { // NOLINT(*)
-  if (name.empty() || name == "__undef__") return;
+  if (name == "__undef__") return;
   node_ = make_node<LayoutNode>();
   LayoutNode *node = operator->();
   node->name = name;
+  if (name.empty()) return;  // scalar
   // parse layout string
   int32_t factor = 0;
   for (char c : name) {
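A sketch of the constructor's new contract, assuming the Python wrapper maps an undefined C++ Layout to None in the same way `tvm.bijective_layout` does in the tests below:

```python
import tvm

# "__undef__" still yields an undefined layout ...
assert tvm.layout("__undef__") is None
# ... while "" now allocates the node and merely skips the axis-parsing
# loop, producing the defined scalar layout.
assert tvm.layout("") is not None
```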
@@ -146,6 +148,7 @@ Layout LayoutNode::make(const std::string& layout) {
 Layout Layout::SubLayout(size_t pos, size_t len) const {
   if (!defined() || pos > ndim()) return Layout::Undef();
+  if (len == 0) return Layout(Array<IterVar>());
   if (pos + len > ndim()) len = ndim() - pos;
   Array<IterVar> new_layout;
   const auto axes = operator->()->axes;
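The SubLayout behavior with the new len == 0 case, restated as a plain-Python illustration over primal-only layout strings (a sketch of the semantics, not a TVM API):

```python
def sublayout_str(layout, pos, length):
    # Mirrors Layout::SubLayout for a primal-only layout string.
    if pos > len(layout):
        return None                  # undefined, like Layout::Undef()
    if length == 0:
        return ""                    # new: the scalar layout
    return layout[pos:pos + length]  # slicing clamps at the end, as the C++ does

assert sublayout_str("NCHW", 1, 2) == "CH"
assert sublayout_str("NCHW", 2, 0) == ""     # scalar
assert sublayout_str("NCHW", 2, 10) == "HW"  # clamped to the layout's end
```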
@@ -195,6 +198,10 @@ int32_t Layout::FactorOf(const LayoutAxis& axis) const {
 inline bool GetStoreRule(Array<Expr>* rule,
                          const Layout& src_layout,
                          const Layout& dst_layout) {
+  if (!src_layout.defined() || src_layout.name().empty() ||
+      !dst_layout.defined() || dst_layout.name().empty()) {
+    return false;
+  }
   for (size_t i = 0; i < dst_layout.ndim(); ++i) {
     const auto& store_axis = dst_layout[i];
     const IterVar& store_axis_impl = dst_layout->axes[i];
......
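The guard's effect is observable through `tvm.bijective_layout`, which the new test cases at the end of this commit exercise: a conversion involving an undefined or scalar layout now fails cleanly as an undefined BijectiveLayout (None in Python). A hedged sketch, assuming the `forward_shape` binding of this version:

```python
import tvm

# Undefined or scalar on either side -> no bijection.
assert tvm.bijective_layout("__undef__", "NCHW") is None
assert tvm.bijective_layout("NCHW", "") is None

# A well-formed pair still yields a usable bijection.
bl = tvm.bijective_layout("NCHW", "NCHW16c")
dst_shape = bl.forward_shape([1, 32, 7, 7])  # maps to [1, 2, 7, 7, 16]
```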
@@ -99,13 +99,17 @@ inline Array<Array<Layout> > BinaryBroadcastLayout(const Attrs& attrs,
       layouts[defined_idx].SubLayout(
           old_in_shapes[defined_idx].size() - old_in_shapes[undef_idx].size(),
           old_in_shapes[undef_idx].size()));
-      return Array<Array<Layout> > {layouts, {layouts[defined_idx]}};
+      return Array<Array<Layout> >{layouts, {layouts[defined_idx]}};
     } else {
       // we only know the layout of the tensor with fewer dimensions,
       // so we cannot infer the final broadcasted output layout;
       // fail in this case.
-      return Array<Array<Layout> > {{Layout::Undef()}, {Layout::Undef()}};
+      return Array<Array<Layout> >{{Layout::Undef()}, {Layout::Undef()}};
     }
+  } else if (layouts[0].defined() && layouts[1].defined() &&
+             (layouts[0].ndim() == 0 || layouts[1].ndim() == 0)) {
+    int scalar = layouts[0].ndim() == 0 ? 0 : 1;
+    return Array<Array<Layout> >{layouts, {layouts[1-scalar]}};
   } else {
     // try to broadcast the tensors to the larger dimension
     int large_idx = layouts[0].ndim_primal() >= layouts[1].ndim_primal() ? 0 : 1;
......
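The new branch, restated as a plain-Python sketch of the decision logic (a hypothetical helper, not a TVM function): when both layouts are defined but one side is a scalar (ndim == 0), keep both input layouts and propagate the non-scalar one as the output layout.

```python
def binary_broadcast_layout(lhs, rhs):
    # Layouts modeled as strings; "" is the scalar layout, None is undefined.
    if lhs is not None and rhs is not None and (len(lhs) == 0 or len(rhs) == 0):
        scalar = 0 if len(lhs) == 0 else 1
        out = (lhs, rhs)[1 - scalar]  # the non-scalar side wins
        return [lhs, rhs], [out]
    raise NotImplementedError("remaining cases elided in this sketch")

ins, outs = binary_broadcast_layout("NCHW", "")
assert ins == ["NCHW", ""] and outs == ["NCHW"]
```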
@@ -57,7 +57,7 @@ def test_alter_op():
     b = expected()
     b = infer_type(b)
-    assert(alpha_equal(a, b))
+    assert alpha_equal(a, b), "Actual = \n" + str(a)
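One note on the new assert form (general Python semantics, not specific to this patch): the failure message must follow a comma outside any parentheses, since parenthesizing condition and message together builds a 2-tuple, which is always truthy.

```python
assert alpha_equal(a, b), "Actual = \n" + str(a)    # fails with a message
# assert(alpha_equal(a, b), "Actual = \n" + str(a)) # bug: a tuple, never fails
```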
def test_alter_return_none():
@@ -81,7 +81,7 @@ def test_alter_return_none():
     b = before()
     b = infer_type(b)
-    assert(alpha_equal(a, b))
+    assert alpha_equal(a, b), "Actual = \n" + str(a)
assert(called[0])
@@ -147,7 +147,7 @@ def test_alter_layout():
     b = expected()
     b = infer_type(b)
-    assert(alpha_equal(a, b))
+    assert alpha_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_dual_path():
@@ -213,7 +213,7 @@ def test_alter_layout_dual_path():
     b = expected()
     b = infer_type(b)
-    assert(alpha_equal(a, b))
+    assert alpha_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_resnet():
"""Test alternating the layout of a residual block
@@ -273,7 +273,7 @@ def test_alter_layout_resnet():
     b = expected()
     b = infer_type(b)
-    assert(alpha_equal(a, b))
+    assert alpha_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_broadcast_op():
@@ -323,7 +323,7 @@ def test_alter_layout_broadcast_op():
     b = expected()
     b = infer_type(b)
-    assert(alpha_equal(a, b))
+    assert alpha_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_scalar():
"""Test alternating the layout of a conv2d.
@@ -370,7 +370,7 @@ def test_alter_layout_scalar():
     b = expected()
     b = infer_type(b)
-    assert(alpha_equal(a, b))
+    assert alpha_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_concatenate():
""" """
@@ -425,7 +425,7 @@ def test_alter_layout_concatenate():
     b = expected()
     b = infer_type(b)
-    assert(alpha_equal(a, b))
+    assert alpha_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_nchw_upsamping_op():
@@ -469,7 +469,7 @@ def test_alter_layout_nchw_upsamping_op():
     b = expected()
     b = infer_type(b)
-    assert(alpha_equal(a, b))
+    assert alpha_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_strided_slice():
@@ -511,7 +511,7 @@ def test_alter_layout_strided_slice():
     b = expected()
     b = infer_type(b)
-    assert(alpha_equal(a, b))
+    assert alpha_equal(a, b), "Actual = \n" + str(a)
if __name__ == "__main__":
......
@@ -52,6 +52,12 @@ def test_layout():
 def test_bilayout_convertible():
     # not convertible
     assert tvm.bijective_layout("NCHW", "ABCD") is None
+    assert tvm.bijective_layout("__undef__", "NCHW") is None
+    assert tvm.bijective_layout("NCHW", "__undef__") is None
+    assert tvm.bijective_layout("__undef__", "__undef__") is None
+    assert tvm.bijective_layout("", "NCHW") is None
+    assert tvm.bijective_layout("NCHW", "") is None
+    assert tvm.bijective_layout("", "") is None
     # convertible
     assert tvm.bijective_layout("NCHW", "NCHW16c") is not None
......
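As a follow-up to the convertible case, a hedged sketch of actually using the resulting bijection, assuming the forward_index/backward_index bindings of this version:

```python
import tvm

bl = tvm.bijective_layout("NCHW", "NCHW16c")
# Map an NCHW index into NCHW16c coordinates and back.
dst_index = bl.forward_index([0, 18, 6, 6])  # channel 18 -> block 1, lane 2
src_index = bl.backward_index(dst_index)     # round-trips to [0, 18, 6, 6]
```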