Commit 5d2ccd66 by Tianqi Chen Committed by GitHub

[SCHEDULE] Fuse support for 0 rank tensor (#1328)

parent 0134fabb
......@@ -130,6 +130,20 @@ class Stage : public NodeRef {
*/
EXPORT Stage& fuse(IterVar outer, IterVar inner, IterVar* p_target); // NOLINT(*)
/*!
* \brief Fuse all the axes together into a single axis.
*
* \param axes All the axes to be fused.
* \param p_target The result target domain.
*
* \note axes can be an empty array,
* in that case, a singleton itervar is created and
* inserted to the outermost loop.
* The fuse of empty array is used to support zero-dimension tensors.
*
* \return reference to self.
*/
EXPORT Stage& fuse(const Array<IterVar>& axes, IterVar* p_target); // NOLINT(*)
/*!
* \brief Reorder the iteration
* \param order The order of iteration variable.
* \return reference to self.
......@@ -151,9 +165,9 @@ class Stage : public NodeRef {
* \return reference to self.
*/
EXPORT Stage& tile(IterVar x_parent, IterVar y_parent, // NOLINT(*)
Expr x_factor, Expr y_factor,
IterVar* p_x_outer, IterVar* p_y_outer,
IterVar* p_x_inner, IterVar* p_y_inner);
Expr x_factor, Expr y_factor,
IterVar* p_x_outer, IterVar* p_y_outer,
IterVar* p_x_inner, IterVar* p_y_inner);
/*!
* \brief Vectorize iteration.
* \param var The axis to be vectorized.
......@@ -674,6 +688,25 @@ class RebaseNode : public IterVarRelationNode {
};
/*!
* \brief Singleton iterator [0, 1)
*/
class SingletonNode : public IterVarRelationNode {
public:
/*! \brief The singleton iterator */
IterVar iter;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("iter", &iter);
}
static IterVarRelation make(IterVar iter);
static constexpr const char* _type_key = "Singleton";
TVM_DECLARE_NODE_TYPE_INFO(SingletonNode, IterVarRelationNode);
};
// implementations
inline const StageNode* Stage::operator->() const {
return static_cast<const StageNode*>(node_.get());
......
......@@ -153,6 +153,12 @@ class Fuse(NodeBase):
@register_node
class Singleton(NodeBase):
"""Singleton axis."""
pass
@register_node
class IterVar(NodeBase, _expr.ExprOp):
"""Represent iteration variable.
......@@ -380,10 +386,7 @@ class Stage(NodeBase):
fused : IterVar
The fused variable of iteration.
"""
assert len(args) >= 1, "Length of the arguments must be >=1 for fuse."
fused = args[0]
for i in range(1, len(args)):
fused = _api_internal._StageFuse(self, fused, args[i])
fused = _api_internal._StageFuse(self, args)
return fused
def set_scope(self, scope):
......
......@@ -350,7 +350,7 @@ TVM_REGISTER_API("_StageFuse")
.set_body([](TVMArgs args, TVMRetValue* ret) {
IterVar fused;
args[0].operator Stage()
.fuse(args[1], args[2], &fused);
.fuse(args[1], &fused);
*ret = fused;
});
......
......@@ -82,6 +82,8 @@ void PassDownDomain(const Stage& stage,
Update(p_state, r->rebased,
Range::make_by_min_extent(
0, state.at(r->parent)->extent));
} else if (const SingletonNode* s = rel.as<SingletonNode>()) {
Update(p_state, s->iter, Range::make_by_min_extent(0, 1));
} else {
LOG(FATAL) << "unknown relation type";
}
......@@ -147,6 +149,7 @@ void PassUpIndex(const Stage& stage,
} else {
state[s->parent] = value;
}
} else if (rel.as<SingletonNode>()) {
} else {
LOG(FATAL) << "unknown relation type";
}
......@@ -192,6 +195,8 @@ void PassDownIndex(const Stage& stage,
Expr parent_min = dom_map.at(s->parent)->min;
CHECK(is_zero(parent_min));
state[s->rebased] = value;
} else if (const SingletonNode* s = rel.as<SingletonNode>()) {
state[s->iter] = make_zero(s->iter->var.type());
} else {
LOG(FATAL) << "unknown relation type";
}
......@@ -296,6 +301,7 @@ void PassUpDomain(const Stage& stage,
state.at(r->rebased),
&parent);
state[r->parent] = parent;
} else if (rel.as<SingletonNode>()) {
} else {
LOG(FATAL) << "unknown relation type";
}
......@@ -344,6 +350,7 @@ void PassUpBitMaskOr(const Stage& stage,
} else {
state[s->parent] |= state[s->rebased];
}
} else if (rel.as<SingletonNode>()) {
} else {
LOG(FATAL) << "unknown relation type";
}
......@@ -390,6 +397,8 @@ void PassDownBitMaskOr(const Stage& stage,
} else {
state[s->rebased] |= state.at(s->parent);
}
} else if (const SingletonNode* s = rel.as<SingletonNode>()) {
state[s->iter] = 0;
} else {
LOG(FATAL) << "unknown relation type";
}
......@@ -438,6 +447,8 @@ void PassUpBoundCheck(const Stage& s,
} else if (rel.as<RebaseNode>()) {
const RebaseNode* s = rel.as<RebaseNode>();
state[s->parent] = state.at(s->rebased);
} else if (rel.as<SingletonNode>()) {
// nop
} else {
LOG(FATAL) << "unknown relation type";
}
......
......@@ -237,7 +237,6 @@ Stage& Stage::fuse(IterVar outer, IterVar inner, IterVar* p_target) { // NOLINT
IterVar fused = IterVarNode::make(
Range(), Var(fused_name, outer->var.type()), iter_type);
*p_target = fused;
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
......@@ -255,6 +254,31 @@ Stage& Stage::fuse(IterVar outer, IterVar inner, IterVar* p_target) { // NOLINT
leaf_vars->data.begin() + pos_inner + 1);
leaf_vars->data.insert(leaf_vars->data.begin() + pos_outer,
fused.node_);
*p_target = fused;
return *this;
}
Stage& Stage::fuse(const Array<IterVar>& axes, IterVar* p_target) { // NOLINT(*)
if (axes.size() != 0) {
IterVar fused = axes[0];
for (size_t i = 1; i < axes.size(); ++i) {
this->fuse(fused, axes[i], &fused);
}
*p_target = std::move(fused);
} else {
StageNode* self = operator->();
// special handle fuse empty array.
// insert at the outer most loop
IterVar singleton = IterVarNode::make(
Range::make_by_min_extent(0, 1),
Var("singleton", Int(32)), kDataPar);
self->relations.push_back(SingletonNode::make(singleton));
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
all_vars->data.push_back(singleton.node_);
leaf_vars->data.insert(leaf_vars->data.begin(), singleton.node_);
*p_target = singleton;
}
return *this;
}
......@@ -732,11 +756,18 @@ IterVarRelation RebaseNode::make(IterVar parent, IterVar rebased) {
return IterVarRelation(n);
}
IterVarRelation SingletonNode::make(IterVar iter) {
auto n = std::make_shared<SingletonNode>();
n->iter = iter;
return IterVarRelation(n);
}
TVM_REGISTER_NODE_TYPE(StageNode);
TVM_REGISTER_NODE_TYPE(IterVarAttrNode);
TVM_REGISTER_NODE_TYPE(SplitNode);
TVM_REGISTER_NODE_TYPE(FuseNode);
TVM_REGISTER_NODE_TYPE(RebaseNode);
TVM_REGISTER_NODE_TYPE(SingletonNode);
TVM_REGISTER_NODE_TYPE(ScheduleNode);
// Printer
......@@ -778,6 +809,11 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->print(op->rebased);
p->stream << ')';
})
.set_dispatch<SingletonNode>([](const SingletonNode *op, IRPrinter *p) {
p->stream << "singleton(";
p->print(op->iter);
p->stream << ')';
})
.set_dispatch<ScheduleNode>([](const ScheduleNode *op, IRPrinter *p) {
p->stream << "schedule(" << op << ")";
});
......
......@@ -44,10 +44,10 @@ def test_multiple_cache_write():
n = tvm.convert(1024)
A0 = tvm.placeholder((n,), name='A0', dtype = "float32")
A1 = tvm.placeholder((n,), name='A1', dtype = "float32")
B0, B1 = tvm.compute((n,),
lambda *i: (A0(*i) + A1(*i), A0(*i) * A1(*i)),
B0, B1 = tvm.compute((n,),
lambda *i: (A0(*i) + A1(*i), A0(*i) * A1(*i)),
name='B')
C = tvm.compute((n,), lambda *i: B0(*i) + B1(*i),
C = tvm.compute((n,), lambda *i: B0(*i) + B1(*i),
name='C')
s = tvm.create_schedule(C.op)
# create iter var and assign them tags.
......@@ -76,7 +76,7 @@ def test_multiple_cache_write():
c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
func(a0, a1, c)
np.testing.assert_allclose(
c.asnumpy(), a0.asnumpy() + a1.asnumpy() + (a0.asnumpy() * a1.asnumpy()),
c.asnumpy(), a0.asnumpy() + a1.asnumpy() + (a0.asnumpy() * a1.asnumpy()),
rtol=1e-5)
check_device("cuda", "llvm")
......@@ -235,7 +235,6 @@ def try_warp_memory():
f(a, b)
np.testing.assert_allclose(
b.asnumpy(), a.asnumpy() + 3, rtol=1e-6)
check_device("cuda")
......
......@@ -84,6 +84,19 @@ def test_fuse():
assert any(isinstance(x, tvm.schedule.Fuse) for x in s[T].relations)
assert tuple(s[T].leaf_iter_vars) == (fused, xi, yi)
def test_singleton():
A = tvm.placeholder((), name='A')
T = tvm.compute((), lambda : A() + 1)
s = tvm.create_schedule(T.op)
fused = s[T].fuse()
assert any(isinstance(x, tvm.schedule.Singleton) for x in s[T].relations)
assert tuple(s[T].leaf_iter_vars) == (fused,)
dump = pkl.dumps(s)
s_loaded = pkl.loads(dump)
assert isinstance(s_loaded, tvm.schedule.Schedule)
def test_vectorize():
m = tvm.var('m')
n = tvm.var('n')
......@@ -174,6 +187,7 @@ def test_tensor_intrin():
if __name__ == "__main__":
test_singleton()
test_pragma()
test_tensor_intrin()
test_rfactor()
......
......@@ -94,6 +94,8 @@ def test_broadcast_to():
def test_add():
verify_broadcast_binary_ele(
(), (), topi.add, np.add)
verify_broadcast_binary_ele(
(5, 2, 3), (2, 1), topi.add, np.add)
def test_subtract():
......@@ -114,6 +116,8 @@ def test_divide():
verify_broadcast_binary_ele(
None, (10,), topi.divide, np.divide, rhs_min=0.0001)
verify_broadcast_binary_ele(
(), None, topi.divide, np.divide, rhs_min=0.0001)
verify_broadcast_binary_ele(
(2, 3, 1, 32), (64, 32), topi.divide, np.divide, rhs_min=0.0001)
def test_maximum_minmum():
......@@ -157,10 +161,10 @@ def test_shift():
if __name__ == "__main__":
test_add()
test_shift()
test_cmp()
test_mod()
test_add()
test_subtract()
test_multiply()
test_divide()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment