Commit 9d002e8e by Yizhi Liu Committed by Tianqi Chen

[Lang] Fix undef BijectiveLayout and add scalar layout support (#3105)

parent 73f87ae0
...@@ -94,12 +94,13 @@ class Layout; ...@@ -94,12 +94,13 @@ class Layout;
// Internal node container Buffer // Internal node container Buffer
class LayoutNode : public Node { class LayoutNode : public Node {
public: public:
/*! \brief string representation of layout */ /*! \brief string representation of layout, "" for scalar. */
std::string name; std::string name;
/*! \brief specify each axis of the layout, /*! \brief specify each axis of the layout,
* in which the variable name is the name of the axis. * in which the variable name is the name of the axis.
* The IterVar's extent indicates the size of the axis, * The IterVar's extent indicates the size of the axis,
* it is a variable for a primal axis, but a constant for a subordinate axis. * it is a variable for a primal axis, but a constant for a subordinate axis.
* Empty for scalar's layout.
*/ */
Array<IterVar> axes; Array<IterVar> axes;
...@@ -122,6 +123,7 @@ class LayoutNode : public Node { ...@@ -122,6 +123,7 @@ class LayoutNode : public Node {
* For example, NCHW16c can describe a 5-D tensor of * For example, NCHW16c can describe a 5-D tensor of
* [batch_size, channel, height, width, channel_block]. * [batch_size, channel, height, width, channel_block].
* Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel). * Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel).
* Layout for scalar is defined, while both its name and axes have size 0.
*/ */
class Layout : public NodeRef { class Layout : public NodeRef {
public: public:
...@@ -175,7 +177,7 @@ class Layout : public NodeRef { ...@@ -175,7 +177,7 @@ class Layout : public NodeRef {
* that starts at dimension \p pos and spans \p len dimensions * that starts at dimension \p pos and spans \p len dimensions
* (or until the end of the layout, whichever comes first). * (or until the end of the layout, whichever comes first).
* \param pos The start position. * \param pos The start position.
* \param len The length of the sub-layout. * \param len The length of the sub-layout. if 0, return layout of scalar
* \return A newly constructed Layout object. * \return A newly constructed Layout object.
*/ */
Layout SubLayout(size_t pos, size_t len) const; Layout SubLayout(size_t pos, size_t len) const;
......
...@@ -88,12 +88,14 @@ Layout::Layout(const Array<IterVar>& axes) { ...@@ -88,12 +88,14 @@ Layout::Layout(const Array<IterVar>& axes) {
} }
Layout::Layout(const std::string& name) { // NOLINT(*) Layout::Layout(const std::string& name) { // NOLINT(*)
if (name.empty() || name == "__undef__") return; if (name == "__undef__") return;
node_ = make_node<LayoutNode>(); node_ = make_node<LayoutNode>();
LayoutNode *node = operator->(); LayoutNode *node = operator->();
node->name = name; node->name = name;
if (name.empty()) return; // scalar
// parse layout string // parse layout string
int32_t factor = 0; int32_t factor = 0;
for (char c : name) { for (char c : name) {
...@@ -146,6 +148,7 @@ Layout LayoutNode::make(const std::string& layout) { ...@@ -146,6 +148,7 @@ Layout LayoutNode::make(const std::string& layout) {
Layout Layout::SubLayout(size_t pos, size_t len) const { Layout Layout::SubLayout(size_t pos, size_t len) const {
if (!defined() || pos > ndim()) return Layout::Undef(); if (!defined() || pos > ndim()) return Layout::Undef();
if (len == 0) return Layout(Array<IterVar>());
if (pos + len > ndim()) len = ndim() - pos; if (pos + len > ndim()) len = ndim() - pos;
Array<IterVar> new_layout; Array<IterVar> new_layout;
const auto axes = operator->()->axes; const auto axes = operator->()->axes;
...@@ -195,6 +198,10 @@ int32_t Layout::FactorOf(const LayoutAxis& axis) const { ...@@ -195,6 +198,10 @@ int32_t Layout::FactorOf(const LayoutAxis& axis) const {
inline bool GetStoreRule(Array<Expr>* rule, inline bool GetStoreRule(Array<Expr>* rule,
const Layout& src_layout, const Layout& src_layout,
const Layout& dst_layout) { const Layout& dst_layout) {
if (!src_layout.defined() || src_layout.name().empty() ||
!dst_layout.defined() || dst_layout.name().empty()) {
return false;
}
for (size_t i = 0; i < dst_layout.ndim(); ++i) { for (size_t i = 0; i < dst_layout.ndim(); ++i) {
const auto& store_axis = dst_layout[i]; const auto& store_axis = dst_layout[i];
const IterVar& store_axis_impl = dst_layout->axes[i]; const IterVar& store_axis_impl = dst_layout->axes[i];
......
...@@ -97,15 +97,19 @@ inline Array<Array<Layout> > BinaryBroadcastLayout(const Attrs& attrs, ...@@ -97,15 +97,19 @@ inline Array<Array<Layout> > BinaryBroadcastLayout(const Attrs& attrs,
if (old_in_shapes[defined_idx].size() >= old_in_shapes[undef_idx].size()) { if (old_in_shapes[defined_idx].size() >= old_in_shapes[undef_idx].size()) {
layouts.Set(undef_idx, layouts.Set(undef_idx,
layouts[defined_idx].SubLayout( layouts[defined_idx].SubLayout(
old_in_shapes[defined_idx].size() - old_in_shapes[undef_idx].size(), old_in_shapes[defined_idx].size() - old_in_shapes[undef_idx].size(),
old_in_shapes[undef_idx].size())); old_in_shapes[undef_idx].size()));
return Array<Array<Layout> > {layouts, {layouts[defined_idx]}}; return Array<Array<Layout> >{layouts, {layouts[defined_idx]}};
} else { } else {
// only know the tensor with smaller dimensions, // only know the tensor with smaller dimensions,
// so we cannot infer the final broadcasted output. // so we cannot infer the final broadcasted output.
// fails in this case. // fails in this case.
return Array<Array<Layout> > {{Layout::Undef()}, {Layout::Undef()}}; return Array<Array<Layout> >{{Layout::Undef()}, {Layout::Undef()}};
} }
} else if (layouts[0].defined() && layouts[1].defined() &&
(layouts[0].ndim() == 0 || layouts[1].ndim() == 0)) {
int scalar = layouts[0].ndim() == 0 ? 0 : 1;
return Array<Array<Layout> >{layouts, {layouts[1-scalar]}};
} else { } else {
// try to broadcast the tensors to the larger dimension // try to broadcast the tensors to the larger dimension
int large_idx = layouts[0].ndim_primal() >= layouts[1].ndim_primal() ? 0 : 1; int large_idx = layouts[0].ndim_primal() >= layouts[1].ndim_primal() ? 0 : 1;
......
...@@ -57,7 +57,7 @@ def test_alter_op(): ...@@ -57,7 +57,7 @@ def test_alter_op():
b = expected() b = expected()
b = infer_type(b) b = infer_type(b)
assert(alpha_equal(a, b)) assert alpha_equal(a, b), "Actual = \n" + str(a)
def test_alter_return_none(): def test_alter_return_none():
...@@ -81,7 +81,7 @@ def test_alter_return_none(): ...@@ -81,7 +81,7 @@ def test_alter_return_none():
b = before() b = before()
b = infer_type(b) b = infer_type(b)
assert(alpha_equal(a, b)) assert alpha_equal(a, b), "Actual = \n" + str(a)
assert(called[0]) assert(called[0])
...@@ -147,7 +147,7 @@ def test_alter_layout(): ...@@ -147,7 +147,7 @@ def test_alter_layout():
b = expected() b = expected()
b = infer_type(b) b = infer_type(b)
assert(alpha_equal(a, b)) assert alpha_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_dual_path(): def test_alter_layout_dual_path():
...@@ -213,7 +213,7 @@ def test_alter_layout_dual_path(): ...@@ -213,7 +213,7 @@ def test_alter_layout_dual_path():
b = expected() b = expected()
b = infer_type(b) b = infer_type(b)
assert(alpha_equal(a, b)) assert alpha_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_resnet(): def test_alter_layout_resnet():
"""Test alternating the layout of a residual block """Test alternating the layout of a residual block
...@@ -273,7 +273,7 @@ def test_alter_layout_resnet(): ...@@ -273,7 +273,7 @@ def test_alter_layout_resnet():
b = expected() b = expected()
b = infer_type(b) b = infer_type(b)
assert(alpha_equal(a, b)) assert alpha_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_broadcast_op(): def test_alter_layout_broadcast_op():
...@@ -323,7 +323,7 @@ def test_alter_layout_broadcast_op(): ...@@ -323,7 +323,7 @@ def test_alter_layout_broadcast_op():
b = expected() b = expected()
b = infer_type(b) b = infer_type(b)
assert(alpha_equal(a, b)) assert alpha_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_scalar(): def test_alter_layout_scalar():
"""Test alternating the layout of a conv2d. """Test alternating the layout of a conv2d.
...@@ -370,7 +370,7 @@ def test_alter_layout_scalar(): ...@@ -370,7 +370,7 @@ def test_alter_layout_scalar():
b = expected() b = expected()
b = infer_type(b) b = infer_type(b)
assert(alpha_equal(a, b)) assert alpha_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_concatenate(): def test_alter_layout_concatenate():
""" """ """ """
...@@ -425,7 +425,7 @@ def test_alter_layout_concatenate(): ...@@ -425,7 +425,7 @@ def test_alter_layout_concatenate():
b = expected() b = expected()
b = infer_type(b) b = infer_type(b)
assert(alpha_equal(a, b)) assert alpha_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_nchw_upsamping_op(): def test_alter_layout_nchw_upsamping_op():
...@@ -469,7 +469,7 @@ def test_alter_layout_nchw_upsamping_op(): ...@@ -469,7 +469,7 @@ def test_alter_layout_nchw_upsamping_op():
b = expected() b = expected()
b = infer_type(b) b = infer_type(b)
assert(alpha_equal(a, b)) assert alpha_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_strided_slice(): def test_alter_layout_strided_slice():
...@@ -511,7 +511,7 @@ def test_alter_layout_strided_slice(): ...@@ -511,7 +511,7 @@ def test_alter_layout_strided_slice():
b = expected() b = expected()
b = infer_type(b) b = infer_type(b)
assert(alpha_equal(a, b)) assert alpha_equal(a, b), "Actual = \n" + str(a)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -52,6 +52,12 @@ def test_layout(): ...@@ -52,6 +52,12 @@ def test_layout():
def test_bilayout_convertible(): def test_bilayout_convertible():
# not convertible # not convertible
assert tvm.bijective_layout("NCHW", "ABCD") is None assert tvm.bijective_layout("NCHW", "ABCD") is None
assert tvm.bijective_layout("__undef__", "NCHW") is None
assert tvm.bijective_layout("NCHW", "__undef__") is None
assert tvm.bijective_layout("__undef__", "__undef__") is None
assert tvm.bijective_layout("", "NCHW") is None
assert tvm.bijective_layout("NCHW", "") is None
assert tvm.bijective_layout("", "") is None
# convertible # convertible
assert tvm.bijective_layout("NCHW", "NCHW16c") is not None assert tvm.bijective_layout("NCHW", "NCHW16c") is not None
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment