Commit 0c72ca97 by tqchen

Finish schedule operation

parent 59bb0dd4
Subproject commit 29fd3defa3dbf810e52dbc2ecd3933604989dcc8
Subproject commit ea1a81be8baa43665f6ebd4d75d51c081283ebc8
......@@ -50,16 +50,19 @@ class Schedule : public NodeRef {
* \brief specify the schedule to be computed at the parent schedule's scope.
* \param parent The parent schedule.
* \param scope The iteration point to carry the schedule.
* \return reference to self.
*/
Schedule& compute_at(Schedule parent, IterVar scope); // NOLINT(*)
/*!
* \brief Compute the function inline, attach it at parent.
* \param parent The parent schedule to be attached to.
* \return reference to self.
*/
Schedule& compute_inline(Schedule parent); // NOLINT(*)
/*!
* \brief Compute the function at root, attach it to its parent.
* \param parent The parent schedule to be attached to.
* \return reference to self.
*/
Schedule& compute_root(Schedule parent); // NOLINT(*)
/*!
......@@ -68,7 +71,7 @@ class Schedule : public NodeRef {
* \param p_outer The result outer domain
* \param p_inner The result inner domain.
* \param factor The split factor of the loop.
* \param outer The generated
* \return reference to self.
*/
Schedule& split(IterVar parent, IterVar* p_outer, IterVar* p_inner, Expr factor); // NOLINT(*)
/*!
......@@ -80,6 +83,7 @@ class Schedule : public NodeRef {
* \param p_inner The result inner domain.
* \param factor Optional, the factor of the split,
* factor must be provided such that factor * outer.extent >= parent.extent.
* \return reference to self.
*/
Schedule& split(IterVar parent, IterVar outer, IterVar* p_inner, Expr factor = Expr()); // NOLINT(*)
/*!
......@@ -87,11 +91,13 @@ class Schedule : public NodeRef {
* \param inner The inner domain to be fused
* \param outer The outer domain to be fused.
* \param p_target The result target domain.
* \return reference to self.
*/
Schedule& fuse(IterVar inner, IterVar outer, IterVar* p_target); // NOLINT(*)
/*!
* \brief Reorder the iteration
* \param order The order of iteration variable.
* \return reference to self.
*/
Schedule& reorder(const Array<IterVar>& order); // NOLINT(*)
};
......
......@@ -79,6 +79,9 @@ def compute(shape, fcompute, name="TensorCompute"):
tensor: tensor.Tensor
The created tensor
"""
if isinstance(shape, _expr.Expr):
shape = (shape, )
ndim = len(shape)
arg_names = fcompute.__code__.co_varnames
if ndim != len(arg_names):
......@@ -86,6 +89,7 @@ def compute(shape, fcompute, name="TensorCompute"):
dim_var = [IterVar((0, s), x) for x, s in zip(arg_names, shape)]
body = fcompute(*[v.var for v in dim_var])
body = convert(body)
op_node = _function_internal._ComputeOp(
name, dim_var, body)
return _function_internal._Tensor(
......@@ -174,8 +178,4 @@ def Schedule(tensor, scope="global"):
return _function_internal._Schedule(tensor, scope)
def Split(dim, factor, over_rdom=False):
return _function_internal._DimSplit(dim, factor, over_rdom)
_init_function_module("tvm")
......@@ -4,13 +4,106 @@ from ._ctypes._api import NodeBase, register_node
from . import _function_internal
@register_node
class DimSplit(NodeBase):
class Split(NodeBase):
pass
@register_node
class AttachSpec(NodeBase):
class Fuse(NodeBase):
pass
@register_node
class Schedule(NodeBase):
pass
def split(self, parent, factor=None, outer=None):
"""Split the schedule either by factor providing outer scope, or both
Parameters
----------
parent : IterVar
The parent iter var.
factor : Expr, optional
The splitting factor
outer : IterVar, optional
The outer split variable
Returns
-------
outer : IterVar
The outer variable of iteration.
inner : IterVar
The inner variable of iteration.
"""
if outer is not None:
if outer.thread_tag == '':
raise ValueError("split by outer must have special thread_tag")
if outer.dom is None:
raise ValueError("split by outer must have specified domain")
inner = _function_internal._ScheduleSplitByOuter(self, parent, outer, factor)
else:
if factor is None:
raise ValueError("either outer or factor need to be provided")
outer, inner = _function_internal._ScheduleSplitByFactor(self, parent, factor)
return outer, inner
def fuse(self, inner, outer):
"""Fuse inner and outer to a single iteration variable.
Parameters
----------
outer : IterVar
The outer variable of iteration.
inner : IterVar
The inner variable of iteration.
Returns
-------
inner : IterVar
The fused variable of iteration.
"""
return _function_internal._ScheduleFuse(self, inner, outer)
def compute_at(self, parent, scope):
"""Attach the schedule at parent's scope
Parameters
----------
parent : Schedule
The parent schedule
scope : IterVar
The loop scope t be attached to.
"""
_function_internal._ScheduleComputeAt(self, parent, scope)
def compute_inline(self, parent):
"""Attach the schedule at parent, and mark it as inline
Parameters
----------
parent : Schedule
The parent schedule
"""
_function_internal._ScheduleComputeInline(self, parent)
def compute_root(self, parent):
"""Attach the schedule at parent, and mark it as root
Parameters
----------
parent : Schedule
The parent schedule
"""
_function_internal._ScheduleComputeInline(self, parent)
def reorder(self, *args):
"""reorder the arguments in the specified order.
Parameters
----------
args : list of IterVar
The order to be ordered
"""
_function_internal._ScheduleReorder(self, args)
......@@ -7,6 +7,8 @@ from . import expr as _expr
class TensorSlice(SliceBase, _expr.ExprOp):
"""Auxiliary data structure for enable slicing syntax from tensor."""
def __init__(self, tensor, indices):
if not isinstance(indices, tuple):
indices = (indices,)
self.tensor = tensor
self.indices = indices
......
......@@ -103,4 +103,53 @@ TVM_REGISTER_API(_Schedule)
*ret = Schedule(args.at(0), args.at(1));
});
TVM_REGISTER_API(_ScheduleSplitByFactor)
.set_body([](const ArgStack& args, RetValue *ret) {
IterVar outer, inner;
args.at(0).operator Schedule()
.split(args.at(1), &outer, &inner, args.at(2));
*ret = Array<IterVar>({outer, inner});
});
TVM_REGISTER_API(_ScheduleSplitByOuter)
.set_body([](const ArgStack& args, RetValue *ret) {
IterVar inner;
args.at(0).operator Schedule()
.split(args.at(1), args.at(2), &inner, args.at(3));
*ret = inner;
});
TVM_REGISTER_API(_ScheduleFuse)
.set_body([](const ArgStack& args, RetValue *ret) {
IterVar fused;
args.at(0).operator Schedule()
.split(args.at(1), args.at(2), &fused);
*ret = fused;
});
TVM_REGISTER_API(_ScheduleComputeAt)
.set_body([](const ArgStack& args, RetValue *ret) {
args.at(0).operator Schedule()
.compute_at(args.at(1), args.at(2));
});
TVM_REGISTER_API(_ScheduleComputeInline)
.set_body([](const ArgStack& args, RetValue *ret) {
args.at(0).operator Schedule()
.compute_inline(args.at(1));
});
TVM_REGISTER_API(_ScheduleComputeRoot)
.set_body([](const ArgStack& args, RetValue *ret) {
args.at(0).operator Schedule()
.compute_root(args.at(1));
});
TVM_REGISTER_API(_ScheduleReorder)
.set_body([](const ArgStack& args, RetValue *ret) {
args.at(0).operator Schedule()
.reorder(args.at(1));
});
} // namespace tvm
......@@ -115,7 +115,7 @@ class APIVariantValue {
CHECK_EQ(type_id, kNodeHandle);
// use dynamic RTTI for safety
CHECK(dynamic_cast<typename T::ContainerType*>(sptr.get()))
<< "wrong type specified";
<< "wrong type specified, expected " << typeid(typename T::ContainerType).name();
return T(sptr);
}
inline operator Expr() const {
......
......@@ -57,7 +57,9 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
if (op->var->name_hint.length() != 0) {
p->stream << op->var->name_hint << ", ";
}
if (op->dom.defined()) {
p->stream << op->dom;
}
if (op->thread_tag.length() != 0) {
p->stream << ", " << op->thread_tag;
}
......
......@@ -17,12 +17,38 @@ size_t FindIterVar(ArrayNode* array_node, const IterVar& v) {
return array_node->data.size();
}
size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* const IterVar& v) {
size_t pos = Find(leaf_iter_vars, parent);
size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* leaf_vars, const IterVar& v) {
size_t pos = FindIterVar(leaf_vars, v);
if (pos < leaf_vars->data.size()) return pos;
if (FindIterVar(all_vars, v) < all_vars->data.size()) {
LOG(FATAL) << "Operate on iter var " << v
<< "that has already been splitted";
} else {
LOG(FATAL) << "Operate on iter var " << v
<< "that is not part of the schedule";
}
return 0;
}
void Split(ScheduleNode* self, IterVar parent,
IterVar outer, IterVar inner, Expr factor) {
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
size_t pos = FindLeafVar(all_vars, leaf_vars, parent);
self->relations.push_back(SplitNode::make(parent, outer, inner, factor));
// add vars to all vars
all_vars->data.push_back(outer.node_);
all_vars->data.push_back(inner.node_);
// replace the position.
leaf_vars->data.erase(leaf_vars->data.begin() + pos);
leaf_vars->data.insert(leaf_vars->data.begin() + pos, inner.node_);
leaf_vars->data.insert(leaf_vars->data.begin() + pos, outer.node_);
}
} // namespace
Schedule::Schedule(Operation op, std::string scope) {
auto n = std::make_shared<ScheduleNode>();
n->op = op;
......@@ -36,6 +62,14 @@ Schedule& Schedule::compute_at(Schedule parent, IterVar scope) { // NOLINT(*)
CHECK_EQ((*this)->attach_type, kNone);
(*this)->attach_type = kScope;
(*this)->attach_parent = scope;
bool found = false;
for (size_t i = 0; i < parent->leaf_iter_vars.size(); ++i) {
if (scope == parent->leaf_iter_vars[i]) {
found = true; break;
}
}
CHECK(found)
<< "Cannot compute at a iteration variable that is not part of parent leaf vars";
parent->children.push_back(*this);
return *this;
}
......@@ -56,17 +90,63 @@ Schedule& Schedule::compute_root(Schedule parent) { // NOLINT(*)
Schedule& Schedule::split(
IterVar parent, IterVar* p_outer, IterVar* p_inner, Expr factor) { // NOLINT(*)
ScheduleNode* self = operator->();
ArrayNode* leaf_iter_vars = self->leaf_iter_vars.CopyOnWrite();
// place holder for the splitted results.
IterVar outer(Range(), parent->var->name_hint + ".outer");
IterVar inner(Range(), parent->var->name_hint + ".inner");
*p_outer = outer; *p_inner = inner;
Split(operator->(), parent, outer, inner, factor);
return *this;
}
Schedule& Schedule::split(IterVar parent, IterVar outer, IterVar* p_inner, Expr factor) { // NOLINT(*)
// place holder for the splitted results.
IterVar inner(Range(), parent->var->name_hint + ".inner");
*p_inner = inner;
Split(operator->(), parent, outer, inner, factor);
CHECK(pos != leaf_iter_vars->data.size())
<< "Cannot find IterVar " << parent << " in the active leaf vars"
<< " this means "
return *this;
}
Schedule& Schedule::fuse(IterVar inner, IterVar outer, IterVar* p_target) { // NOLINT(*)
IterVar fused(Range(), outer->var->name_hint + "." + inner->var->name_hint + ".fused");
ScheduleNode* self = operator->();
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
self->relations.push_back(FuseNode::make(inner, outer, fused));
all_vars->data.push_back(fused.node_);
size_t pos_inner = FindLeafVar(all_vars, leaf_vars, inner);
size_t pos_outer = FindLeafVar(all_vars, leaf_vars, outer);
CHECK_EQ(pos_inner, pos_outer + 1)
<< "Can only fuse iterations that are consecutive between each other";
leaf_vars->data.erase(leaf_vars->data.begin() + pos_outer,
leaf_vars->data.begin() + pos_inner);
leaf_vars->data.insert(leaf_vars->data.begin() + pos_outer,
fused.node_);
return *this;
}
Schedule& Schedule::reorder(const Array<IterVar>& order) { // NOLINT(*)
ScheduleNode* self = operator->();
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
std::vector<size_t> pos;
for (size_t i = 0; i < order.size(); ++i) {
pos.push_back(FindLeafVar(all_vars, leaf_vars, order[i]));
}
std::vector<std::shared_ptr<Node> > temp;
for (size_t i = 0; i < pos.size(); ++i) {
temp.emplace_back(leaf_vars->data[pos[i]]);
}
std::sort(pos.begin(), pos.end());
for (size_t i = 0; i < pos.size(); ++i) {
leaf_vars->data[pos[i]] = temp[i];
}
return *this;
}
IterVarRelation SplitNode::make(
IterVar parent, IterVar outer,
......
......@@ -6,28 +6,36 @@ def test_schedule_create():
l = tvm.Var('l')
A = tvm.placeholder((m, l), name='A')
B = tvm.placeholder((n, l), name='B')
AA = tvm.compute((m, l), lambda i, j: A[i, j])
T = tvm.compute((m, n, l), lambda i, j, k: A(i, k) * B(j, k))
Tsch = tvm.Schedule(T.op, scope="shared")
Asch = tvm.Schedule(A.op)
sch_T = tvm.Schedule(T.op, scope="shared")
sch_A = tvm.Schedule(AA.op, scope="global")
T.op.
xo, xi = sch_T.split(T.op.dim_var[0], factor=10)
xi1, xi2 = sch_T.split(xi, factor=2)
sch_A.compute_at(sch_T, xi1)
xo, xi = sch_A.split(AA.op.dim_var[0], factor=10)
xo, xi = sch.split(sch.dim_var[0], factor)
Asch.compute_at(Tsch, xi)
sch_T.reorder(xi2, xi1)
assert T.op.dim_var[1] in sch_T.leaf_iter_vars
xf = sch.fuse(xo, xi)
tk1 = tvm.Split(T.op.dim_var[0], 10)
assert isinstance(sch, tvm.schedule.Schedule)
assert isinstance(tk1, tvm.schedule.DimSplit)
def test_reorder():
m = tvm.Var('m')
A = tvm.placeholder((m,), name='A')
T = tvm.compute(m, lambda i: A[i+1])
print(tk1.var)
print(sch.scope)
print(sch.attachs)
sch_T = tvm.Schedule(T.op, scope="shared")
xo, xi = sch_T.split(T.op.dim_var[0], factor=10)
xi1, xi2 = sch_T.split(xi, factor=2)
order = (xi2, xi1, xo)
assert tuple(sch_T.leaf_iter_vars) != order
sch_T.reorder(*order)
assert tuple(sch_T.leaf_iter_vars) == order
if __name__ == "__main__":
test_schedule_create()
test_reorder()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment