Commit 5d2ccd66 by Tianqi Chen, committed by GitHub

[SCHEDULE] Fuse support for 0 rank tensor (#1328)

parent 0134fabb
@@ -130,6 +130,20 @@ class Stage : public NodeRef {
*/
EXPORT Stage& fuse(IterVar outer, IterVar inner, IterVar* p_target); // NOLINT(*)
/*!
* \brief Fuse all the axes together into a single axis.
*
* \param axes All the axes to be fused.
* \param p_target The result target domain.
*
* \note axes can be an empty array;
* in that case, a singleton itervar is created and
* inserted at the outermost position of the loop nest.
* Fusing an empty array is used to support zero-dimensional tensors.
*
* \return reference to self.
*/
EXPORT Stage& fuse(const Array<IterVar>& axes, IterVar* p_target); // NOLINT(*)
/*!
* \brief Reorder the iteration
* \param order The order of iteration variable.
* \return reference to self.
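The \note above about fusing an empty axis array maps directly to the Python API. A minimal usage sketch (not part of the committed diff; it assumes the tvm Python API at this revision): a zero-rank compute has no axes, so fuse is called with no arguments and returns the freshly created singleton itervar.

import tvm

A = tvm.placeholder((), name='A')               # 0-rank tensor
B = tvm.compute((), lambda: A() + 1, name='B')  # scalar compute, no axes
s = tvm.create_schedule(B.op)
axis = s[B].fuse()   # empty fuse: a singleton itervar becomes the outermost loop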
@@ -151,9 +165,9 @@ class Stage : public NodeRef {
* \return reference to self.
*/
EXPORT Stage& tile(IterVar x_parent, IterVar y_parent, // NOLINT(*)
Expr x_factor, Expr y_factor,
IterVar* p_x_outer, IterVar* p_y_outer,
IterVar* p_x_inner, IterVar* p_y_inner);
/*!
* \brief Vectorize iteration.
* \param var The axis to be vectorized.
@@ -674,6 +688,25 @@ class RebaseNode : public IterVarRelationNode {
};
/*!
* \brief Singleton iterator [0, 1)
*/
class SingletonNode : public IterVarRelationNode {
public:
/*! \brief The singleton iterator */
IterVar iter;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("iter", &iter);
}
static IterVarRelation make(IterVar iter);
static constexpr const char* _type_key = "Singleton";
TVM_DECLARE_NODE_TYPE_INFO(SingletonNode, IterVarRelationNode);
};
// implementations
inline const StageNode* Stage::operator->() const {
return static_cast<const StageNode*>(node_.get());
......
@@ -153,6 +153,12 @@ class Fuse(NodeBase):
@register_node
class Singleton(NodeBase):
"""Singleton axis."""
pass
@register_node
class IterVar(NodeBase, _expr.ExprOp):
"""Represent iteration variable.
@@ -380,10 +386,7 @@ class Stage(NodeBase):
fused : IterVar
The fused variable of iteration.
"""
-        assert len(args) >= 1, "Length of the arguments must be >=1 for fuse."
-        fused = args[0]
-        for i in range(1, len(args)):
-            fused = _api_internal._StageFuse(self, fused, args[i])
+        fused = _api_internal._StageFuse(self, args)
return fused
def set_scope(self, scope):
......
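With the change above, Stage.fuse forwards all of its arguments to a single _StageFuse call instead of folding them pairwise in Python, so any number of axes (including zero) can be fused in one step. A short sketch (not part of the diff; it assumes the tvm Python API at this revision):

import tvm

m, n = tvm.var('m'), tvm.var('n')
A = tvm.placeholder((m, n), name='A')
B = tvm.compute((m, n), lambda i, j: A[i, j] * 2, name='B')
s = tvm.create_schedule(B.op)
fused = s[B].fuse(B.op.axis[0], B.op.axis[1])  # one call, one IterVar covering i and j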
@@ -350,7 +350,7 @@ TVM_REGISTER_API("_StageFuse")
.set_body([](TVMArgs args, TVMRetValue* ret) {
IterVar fused;
args[0].operator Stage()
-        .fuse(args[1], args[2], &fused);
+        .fuse(args[1], &fused);
*ret = fused;
});
......
@@ -82,6 +82,8 @@ void PassDownDomain(const Stage& stage,
Update(p_state, r->rebased,
Range::make_by_min_extent(
0, state.at(r->parent)->extent));
} else if (const SingletonNode* s = rel.as<SingletonNode>()) {
Update(p_state, s->iter, Range::make_by_min_extent(0, 1));
} else {
LOG(FATAL) << "unknown relation type";
}
@@ -147,6 +149,7 @@ void PassUpIndex(const Stage& stage,
} else {
state[s->parent] = value;
}
} else if (rel.as<SingletonNode>()) {
} else {
LOG(FATAL) << "unknown relation type";
}
@@ -192,6 +195,8 @@ void PassDownIndex(const Stage& stage,
Expr parent_min = dom_map.at(s->parent)->min;
CHECK(is_zero(parent_min));
state[s->rebased] = value;
} else if (const SingletonNode* s = rel.as<SingletonNode>()) {
state[s->iter] = make_zero(s->iter->var.type());
} else {
LOG(FATAL) << "unknown relation type";
}
@@ -296,6 +301,7 @@ void PassUpDomain(const Stage& stage,
state.at(r->rebased),
&parent);
state[r->parent] = parent;
} else if (rel.as<SingletonNode>()) {
} else {
LOG(FATAL) << "unknown relation type";
}
@@ -344,6 +350,7 @@ void PassUpBitMaskOr(const Stage& stage,
} else {
state[s->parent] |= state[s->rebased];
}
} else if (rel.as<SingletonNode>()) {
} else {
LOG(FATAL) << "unknown relation type";
}
@@ -390,6 +397,8 @@ void PassDownBitMaskOr(const Stage& stage,
} else {
state[s->rebased] |= state.at(s->parent);
}
} else if (const SingletonNode* s = rel.as<SingletonNode>()) {
state[s->iter] = 0;
} else {
LOG(FATAL) << "unknown relation type";
}
@@ -438,6 +447,8 @@ void PassUpBoundCheck(const Stage& s,
} else if (rel.as<RebaseNode>()) {
const RebaseNode* s = rel.as<RebaseNode>();
state[s->parent] = state.at(s->rebased);
} else if (rel.as<SingletonNode>()) {
// nop
} else {
LOG(FATAL) << "unknown relation type";
}
......
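The hunks above wire the new Singleton relation into the schedule message-passing helpers. A toy model (not TVM code, just the rule the C++ implements): going down, the singleton iterator receives the domain [0, 1) and the constant index 0; going up, it contributes nothing, which is why the upward passes simply fall through to an empty branch.

def pass_down_domain(rel, state):
    if rel["type"] == "Singleton":
        state[rel["iter"]] = (0, 1)      # Range::make_by_min_extent(0, 1)

def pass_down_index(rel, state):
    if rel["type"] == "Singleton":
        state[rel["iter"]] = 0           # make_zero(iter->var.type())

def pass_up_index(rel, state):
    if rel["type"] == "Singleton":
        pass                             # nothing to propagate upward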
@@ -237,7 +237,6 @@ Stage& Stage::fuse(IterVar outer, IterVar inner, IterVar* p_target) { // NOLINT
IterVar fused = IterVarNode::make(
Range(), Var(fused_name, outer->var.type()), iter_type);
-  *p_target = fused;
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
@@ -255,6 +254,31 @@ Stage& Stage::fuse(IterVar outer, IterVar inner, IterVar* p_target) { // NOLINT
leaf_vars->data.begin() + pos_inner + 1);
leaf_vars->data.insert(leaf_vars->data.begin() + pos_outer,
fused.node_);
*p_target = fused;
return *this;
}
Stage& Stage::fuse(const Array<IterVar>& axes, IterVar* p_target) { // NOLINT(*)
if (axes.size() != 0) {
IterVar fused = axes[0];
for (size_t i = 1; i < axes.size(); ++i) {
this->fuse(fused, axes[i], &fused);
}
*p_target = std::move(fused);
} else {
StageNode* self = operator->();
// Special case: fusing an empty array.
// Insert a singleton itervar at the outermost position of the loop nest.
IterVar singleton = IterVarNode::make(
Range::make_by_min_extent(0, 1),
Var("singleton", Int(32)), kDataPar);
self->relations.push_back(SingletonNode::make(singleton));
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
all_vars->data.push_back(singleton.node_);
leaf_vars->data.insert(leaf_vars->data.begin(), singleton.node_);
*p_target = singleton;
}
return *this;
}
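A toy model (not TVM code) of the overload added above: a non-empty axis list is folded left to right through the existing pairwise fuse, while an empty list creates a fresh singleton axis placed at the outermost position of the loop nest.

from functools import reduce

def fuse_all(stage, axes):
    if axes:
        # mirrors the C++ loop: fused = axes[0]; then fused = fuse(fused, axes[i])
        return reduce(lambda outer, inner: stage.fuse(outer, inner), axes)
    # empty case: a singleton itervar is created and becomes the outermost loop
    return stage.fuse()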
@@ -732,11 +756,18 @@ IterVarRelation RebaseNode::make(IterVar parent, IterVar rebased) {
return IterVarRelation(n);
}
IterVarRelation SingletonNode::make(IterVar iter) {
auto n = std::make_shared<SingletonNode>();
n->iter = iter;
return IterVarRelation(n);
}
TVM_REGISTER_NODE_TYPE(StageNode);
TVM_REGISTER_NODE_TYPE(IterVarAttrNode);
TVM_REGISTER_NODE_TYPE(SplitNode);
TVM_REGISTER_NODE_TYPE(FuseNode);
TVM_REGISTER_NODE_TYPE(RebaseNode);
TVM_REGISTER_NODE_TYPE(SingletonNode);
TVM_REGISTER_NODE_TYPE(ScheduleNode);
// Printer
@@ -778,6 +809,11 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->print(op->rebased);
p->stream << ')';
})
.set_dispatch<SingletonNode>([](const SingletonNode *op, IRPrinter *p) {
p->stream << "singleton(";
p->print(op->iter);
p->stream << ')';
})
.set_dispatch<ScheduleNode>([](const ScheduleNode *op, IRPrinter *p) {
p->stream << "schedule(" << op << ")";
});
......
@@ -44,10 +44,10 @@ def test_multiple_cache_write():
n = tvm.convert(1024)
A0 = tvm.placeholder((n,), name='A0', dtype = "float32")
A1 = tvm.placeholder((n,), name='A1', dtype = "float32")
B0, B1 = tvm.compute((n,),
lambda *i: (A0(*i) + A1(*i), A0(*i) * A1(*i)),
name='B')
C = tvm.compute((n,), lambda *i: B0(*i) + B1(*i),
name='C')
s = tvm.create_schedule(C.op)
# create iter var and assign them tags.
@@ -76,7 +76,7 @@ def test_multiple_cache_write():
c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
func(a0, a1, c)
np.testing.assert_allclose(
c.asnumpy(), a0.asnumpy() + a1.asnumpy() + (a0.asnumpy() * a1.asnumpy()),
rtol=1e-5)
check_device("cuda", "llvm")
@@ -235,7 +235,6 @@ def try_warp_memory():
f(a, b)
np.testing.assert_allclose(
b.asnumpy(), a.asnumpy() + 3, rtol=1e-6)
check_device("cuda")
......
@@ -84,6 +84,19 @@ def test_fuse():
assert any(isinstance(x, tvm.schedule.Fuse) for x in s[T].relations)
assert tuple(s[T].leaf_iter_vars) == (fused, xi, yi)
def test_singleton():
A = tvm.placeholder((), name='A')
T = tvm.compute((), lambda : A() + 1)
s = tvm.create_schedule(T.op)
fused = s[T].fuse()
assert any(isinstance(x, tvm.schedule.Singleton) for x in s[T].relations)
assert tuple(s[T].leaf_iter_vars) == (fused,)
dump = pkl.dumps(s)
s_loaded = pkl.loads(dump)
assert isinstance(s_loaded, tvm.schedule.Schedule)
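A follow-up sketch (not part of the test above; it assumes tvm.lower with simple_mode from this era) for inspecting what the zero-rank schedule lowers to. The fused singleton axis spans [0, 1), so after simplification the body is expected to reduce to a single store with no surrounding loops.

import tvm

A = tvm.placeholder((), name='A')
T = tvm.compute((), lambda: A() + 1, name='T')
s = tvm.create_schedule(T.op)
s[T].fuse()                                   # empty fuse -> Singleton relation
print(tvm.lower(s, [A, T], simple_mode=True))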
def test_vectorize():
m = tvm.var('m')
n = tvm.var('n')
@@ -174,6 +187,7 @@ def test_tensor_intrin():
if __name__ == "__main__":
test_singleton()
test_pragma()
test_tensor_intrin()
test_rfactor()
......
@@ -94,6 +94,8 @@ def test_broadcast_to():
def test_add():
verify_broadcast_binary_ele(
(), (), topi.add, np.add)
verify_broadcast_binary_ele(
(5, 2, 3), (2, 1), topi.add, np.add)
def test_subtract():
@@ -114,6 +116,8 @@ def test_divide():
verify_broadcast_binary_ele(
None, (10,), topi.divide, np.divide, rhs_min=0.0001)
verify_broadcast_binary_ele(
(), None, topi.divide, np.divide, rhs_min=0.0001)
verify_broadcast_binary_ele(
(2, 3, 1, 32), (64, 32), topi.divide, np.divide, rhs_min=0.0001)
def test_maximum_minmum():
@@ -157,10 +161,10 @@ def test_shift():
if __name__ == "__main__":
test_add()
test_shift()
test_cmp()
test_mod()
test_add()
test_subtract()
test_multiply()
test_divide()
......