Commit 0c72ca97 by tqchen

Finish schedule operation

parent 59bb0dd4
Subproject commit 29fd3defa3dbf810e52dbc2ecd3933604989dcc8
Subproject commit ea1a81be8baa43665f6ebd4d75d51c081283ebc8
@@ -50,16 +50,19 @@ class Schedule : public NodeRef {
    * \brief specify the schedule to be computed at the parent schedule's scope.
    * \param parent The parent schedule.
    * \param scope The iteration point to carry the schedule.
+   * \return reference to self.
    */
   Schedule& compute_at(Schedule parent, IterVar scope);   // NOLINT(*)
   /*!
    * \brief Compute the function inline, attach it at parent.
    * \param parent The parent schedule to be attached to.
+   * \return reference to self.
    */
   Schedule& compute_inline(Schedule parent);   // NOLINT(*)
   /*!
    * \brief Compute the function at root, attach it to its parent.
    * \param parent The parent schedule to be attached to.
+   * \return reference to self.
    */
   Schedule& compute_root(Schedule parent);  // NOLINT(*)
   /*!
@@ -68,7 +71,7 @@ class Schedule : public NodeRef {
    * \param p_outer The result outer domain
    * \param p_inner The result inner domain.
    * \param factor The split factor of the loop.
-   * \param outer The generated
+   * \return reference to self.
    */
   Schedule& split(IterVar parent, IterVar* p_outer, IterVar* p_inner, Expr factor);  // NOLINT(*)
   /*!
@@ -80,6 +83,7 @@ class Schedule : public NodeRef {
    * \param p_inner The result inner domain.
    * \param factor Optional, the factor of the split,
    *  factor must be provided such that factor * outer.extent >= parent.extent.
+   * \return reference to self.
    */
   Schedule& split(IterVar parent, IterVar outer, IterVar* p_inner, Expr factor = Expr());  // NOLINT(*)
   /*!
@@ -87,11 +91,13 @@ class Schedule : public NodeRef {
    * \param inner The inner domain to be fused
    * \param outer The outer domain to be fused.
    * \param p_target The result target domain.
+   * \return reference to self.
    */
   Schedule& fuse(IterVar inner, IterVar outer, IterVar* p_target);  // NOLINT(*)
   /*!
    * \brief Reorder the iteration
    * \param order The order of iteration variable.
+   * \return reference to self.
    */
   Schedule& reorder(const Array<IterVar>& order);  // NOLINT(*)
 };
...
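The second split overload requires factor * outer.extent >= parent.extent, i.e. the provided outer loop must be large enough to cover the parent extent. A quick numeric check of that relation (illustrative values, not taken from this commit):

    # illustrative numbers only: splitting an axis of extent 100 with factor 8
    parent_extent = 100
    factor = 8
    outer_extent = -(-parent_extent // factor)       # ceiling division -> 13
    assert factor * outer_extent >= parent_extent    # 104 >= 100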
@@ -79,6 +79,9 @@ def compute(shape, fcompute, name="TensorCompute"):
     tensor: tensor.Tensor
         The created tensor
     """
+    if isinstance(shape, _expr.Expr):
+        shape = (shape, )
     ndim = len(shape)
     arg_names = fcompute.__code__.co_varnames
     if ndim != len(arg_names):
@@ -86,6 +89,7 @@ def compute(shape, fcompute, name="TensorCompute"):
     dim_var = [IterVar((0, s), x) for x, s in zip(arg_names, shape)]
     body = fcompute(*[v.var for v in dim_var])
+    body = convert(body)
     op_node = _function_internal._ComputeOp(
         name, dim_var, body)
     return _function_internal._Tensor(
@@ -174,8 +178,4 @@ def Schedule(tensor, scope="global"):
     return _function_internal._Schedule(tensor, scope)

-def Split(dim, factor, over_rdom=False):
-    return _function_internal._DimSplit(dim, factor, over_rdom)

 _init_function_module("tvm")
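With the new isinstance check at the top of compute, a bare Expr can now be passed as the shape and is normalized to a one-element tuple. A short sketch of the intended call, mirroring the new test added later in this commit:

    m = tvm.Var('m')
    A = tvm.placeholder((m,), name='A')
    # shape passed as a plain Expr; compute() wraps it into the tuple (m,)
    T = tvm.compute(m, lambda i: A[i + 1])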
@@ -4,13 +4,106 @@ from ._ctypes._api import NodeBase, register_node
 from . import _function_internal

 @register_node
-class DimSplit(NodeBase):
+class Split(NodeBase):
     pass

 @register_node
-class AttachSpec(NodeBase):
+class Fuse(NodeBase):
     pass

 @register_node
 class Schedule(NodeBase):
-    pass
+    def split(self, parent, factor=None, outer=None):
+        """Split the parent iteration, either by a factor or over a provided outer IterVar.
+
+        Parameters
+        ----------
+        parent : IterVar
+            The parent iter var.
+        factor : Expr, optional
+            The splitting factor.
+        outer : IterVar, optional
+            The outer split variable.
+
+        Returns
+        -------
+        outer : IterVar
+            The outer variable of iteration.
+        inner : IterVar
+            The inner variable of iteration.
+        """
+        if outer is not None:
+            if outer.thread_tag == '':
+                raise ValueError("split by outer must have special thread_tag")
+            if outer.dom is None:
+                raise ValueError("split by outer must have specified domain")
+            inner = _function_internal._ScheduleSplitByOuter(self, parent, outer, factor)
+        else:
+            if factor is None:
+                raise ValueError("either outer or factor need to be provided")
+            outer, inner = _function_internal._ScheduleSplitByFactor(self, parent, factor)
+        return outer, inner
+
+    def fuse(self, inner, outer):
+        """Fuse inner and outer into a single iteration variable.
+
+        Parameters
+        ----------
+        inner : IterVar
+            The inner variable of iteration.
+        outer : IterVar
+            The outer variable of iteration.
+
+        Returns
+        -------
+        fused : IterVar
+            The fused variable of iteration.
+        """
+        return _function_internal._ScheduleFuse(self, inner, outer)
+
+    def compute_at(self, parent, scope):
+        """Attach the schedule at parent's scope.
+
+        Parameters
+        ----------
+        parent : Schedule
+            The parent schedule.
+        scope : IterVar
+            The loop scope to be attached to.
+        """
+        _function_internal._ScheduleComputeAt(self, parent, scope)
+
+    def compute_inline(self, parent):
+        """Attach the schedule at parent, and mark it as inline.
+
+        Parameters
+        ----------
+        parent : Schedule
+            The parent schedule.
+        """
+        _function_internal._ScheduleComputeInline(self, parent)
+
+    def compute_root(self, parent):
+        """Attach the schedule at parent, and mark it as root.
+
+        Parameters
+        ----------
+        parent : Schedule
+            The parent schedule.
+        """
+        _function_internal._ScheduleComputeRoot(self, parent)
+
+    def reorder(self, *args):
+        """Reorder the iteration variables into the specified order.
+
+        Parameters
+        ----------
+        args : list of IterVar
+            The new order of the iteration variables.
+        """
+        _function_internal._ScheduleReorder(self, args)
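Together these methods form the new Python-side scheduling API. A brief usage sketch, assuming the tvm.Schedule and tvm.compute entry points exercised by the tests in this commit (the factor values are illustrative):

    m = tvm.Var('m')
    A = tvm.placeholder((m,), name='A')
    T = tvm.compute(m, lambda i: A[i+1])

    sch = tvm.Schedule(T.op, scope="shared")
    xo, xi = sch.split(T.op.dim_var[0], factor=10)   # leaf vars: (xo, xi)
    xi1, xi2 = sch.split(xi, factor=2)               # leaf vars: (xo, xi1, xi2)
    fused = sch.fuse(xi2, xi1)                       # fuses the two consecutive inner vars
    sch.reorder(fused, xo)                           # new leaf order: (fused, xo)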
@@ -7,6 +7,8 @@ from . import expr as _expr
 class TensorSlice(SliceBase, _expr.ExprOp):
     """Auxiliary data structure to enable slicing syntax on tensors."""

     def __init__(self, tensor, indices):
+        if not isinstance(indices, tuple):
+            indices = (indices,)
         self.tensor = tensor
         self.indices = indices
...
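The added tuple check means a single index no longer needs to be wrapped by the caller; A[i] and A[i, j] both end up with indices stored as a tuple. A minimal sketch (the compute body here is an assumed example, not code from this commit):

    m = tvm.Var('m')
    A = tvm.placeholder((m,), name='A')
    # A[i] builds TensorSlice(tensor=A, indices=(i,)) even with one index,
    # so downstream code can treat indices uniformly as a tuple.
    T = tvm.compute(m, lambda i: A[i] + A[i])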
@@ -103,4 +103,53 @@ TVM_REGISTER_API(_Schedule)
     *ret = Schedule(args.at(0), args.at(1));
   });

+TVM_REGISTER_API(_ScheduleSplitByFactor)
+.set_body([](const ArgStack& args, RetValue *ret) {
+    IterVar outer, inner;
+    args.at(0).operator Schedule()
+        .split(args.at(1), &outer, &inner, args.at(2));
+    *ret = Array<IterVar>({outer, inner});
+  });
+
+TVM_REGISTER_API(_ScheduleSplitByOuter)
+.set_body([](const ArgStack& args, RetValue *ret) {
+    IterVar inner;
+    args.at(0).operator Schedule()
+        .split(args.at(1), args.at(2), &inner, args.at(3));
+    *ret = inner;
+  });
+
+TVM_REGISTER_API(_ScheduleFuse)
+.set_body([](const ArgStack& args, RetValue *ret) {
+    IterVar fused;
+    args.at(0).operator Schedule()
+        .fuse(args.at(1), args.at(2), &fused);
+    *ret = fused;
+  });
+
+TVM_REGISTER_API(_ScheduleComputeAt)
+.set_body([](const ArgStack& args, RetValue *ret) {
+    args.at(0).operator Schedule()
+        .compute_at(args.at(1), args.at(2));
+  });
+
+TVM_REGISTER_API(_ScheduleComputeInline)
+.set_body([](const ArgStack& args, RetValue *ret) {
+    args.at(0).operator Schedule()
+        .compute_inline(args.at(1));
+  });
+
+TVM_REGISTER_API(_ScheduleComputeRoot)
+.set_body([](const ArgStack& args, RetValue *ret) {
+    args.at(0).operator Schedule()
+        .compute_root(args.at(1));
+  });
+
+TVM_REGISTER_API(_ScheduleReorder)
+.set_body([](const ArgStack& args, RetValue *ret) {
+    args.at(0).operator Schedule()
+        .reorder(args.at(1));
+  });
+
 }  // namespace tvm
@@ -115,7 +115,7 @@ class APIVariantValue {
     CHECK_EQ(type_id, kNodeHandle);
     // use dynamic RTTI for safety
     CHECK(dynamic_cast<typename T::ContainerType*>(sptr.get()))
-        << "wrong type specified";
+        << "wrong type specified, expected " << typeid(typename T::ContainerType).name();
     return T(sptr);
   }
   inline operator Expr() const {
...
@@ -57,7 +57,9 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
     if (op->var->name_hint.length() != 0) {
       p->stream << op->var->name_hint << ", ";
     }
+    if (op->dom.defined()) {
       p->stream << op->dom;
+    }
     if (op->thread_tag.length() != 0) {
       p->stream << ", " << op->thread_tag;
     }
...
@@ -17,12 +17,38 @@ size_t FindIterVar(ArrayNode* array_node, const IterVar& v) {
   return array_node->data.size();
 }

 size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* leaf_vars, const IterVar& v) {
-  size_t pos = Find(leaf_iter_vars, parent);
+  size_t pos = FindIterVar(leaf_vars, v);
+  if (pos < leaf_vars->data.size()) return pos;
+
+  if (FindIterVar(all_vars, v) < all_vars->data.size()) {
+    LOG(FATAL) << "Operate on iter var " << v
+               << " that has already been split";
+  } else {
+    LOG(FATAL) << "Operate on iter var " << v
+               << " that is not part of the schedule";
+  }
+  return 0;
 }
+
+void Split(ScheduleNode* self, IterVar parent,
+           IterVar outer, IterVar inner, Expr factor) {
+  ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
+  ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
+  size_t pos = FindLeafVar(all_vars, leaf_vars, parent);
+
+  self->relations.push_back(SplitNode::make(parent, outer, inner, factor));
+  // record the new vars in all_iter_vars
+  all_vars->data.push_back(outer.node_);
+  all_vars->data.push_back(inner.node_);
+  // replace parent with (outer, inner) at its position in the leaf vars
+  leaf_vars->data.erase(leaf_vars->data.begin() + pos);
+  leaf_vars->data.insert(leaf_vars->data.begin() + pos, inner.node_);
+  leaf_vars->data.insert(leaf_vars->data.begin() + pos, outer.node_);
+}
 }  // namespace

 Schedule::Schedule(Operation op, std::string scope) {
   auto n = std::make_shared<ScheduleNode>();
   n->op = op;
@@ -36,6 +62,14 @@ Schedule& Schedule::compute_at(Schedule parent, IterVar scope) {  // NOLINT(*)
   CHECK_EQ((*this)->attach_type, kNone);
   (*this)->attach_type = kScope;
   (*this)->attach_parent = scope;
+  bool found = false;
+  for (size_t i = 0; i < parent->leaf_iter_vars.size(); ++i) {
+    if (scope == parent->leaf_iter_vars[i]) {
+      found = true; break;
+    }
+  }
+  CHECK(found)
+      << "Cannot compute at an iteration variable that is not part of parent leaf vars";
   parent->children.push_back(*this);
   return *this;
 }
@@ -56,17 +90,63 @@ Schedule& Schedule::compute_root(Schedule parent) {  // NOLINT(*)
 Schedule& Schedule::split(
     IterVar parent, IterVar* p_outer, IterVar* p_inner, Expr factor) {  // NOLINT(*)
-  ScheduleNode* self = operator->();
-  ArrayNode* leaf_iter_vars = self->leaf_iter_vars.CopyOnWrite();
-  CHECK(pos != leaf_iter_vars->data.size())
-      << "Cannot find IterVar " << parent << " in the active leaf vars"
-      << " this means "
+  // placeholder for the split results.
+  IterVar outer(Range(), parent->var->name_hint + ".outer");
+  IterVar inner(Range(), parent->var->name_hint + ".inner");
+  *p_outer = outer; *p_inner = inner;
+  Split(operator->(), parent, outer, inner, factor);
   return *this;
 }

+Schedule& Schedule::split(IterVar parent, IterVar outer, IterVar* p_inner, Expr factor) {  // NOLINT(*)
+  // placeholder for the split result.
+  IterVar inner(Range(), parent->var->name_hint + ".inner");
+  *p_inner = inner;
+  Split(operator->(), parent, outer, inner, factor);
+  return *this;
+}
+
+Schedule& Schedule::fuse(IterVar inner, IterVar outer, IterVar* p_target) {  // NOLINT(*)
+  IterVar fused(Range(), outer->var->name_hint + "." + inner->var->name_hint + ".fused");
+  ScheduleNode* self = operator->();
+  ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
+  ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
+
+  self->relations.push_back(FuseNode::make(inner, outer, fused));
+  all_vars->data.push_back(fused.node_);
+  size_t pos_inner = FindLeafVar(all_vars, leaf_vars, inner);
+  size_t pos_outer = FindLeafVar(all_vars, leaf_vars, outer);
+  CHECK_EQ(pos_inner, pos_outer + 1)
+      << "Can only fuse iterations that are consecutive between each other";
+  // replace (outer, inner) with the fused var at outer's position
+  leaf_vars->data.erase(leaf_vars->data.begin() + pos_outer,
+                        leaf_vars->data.begin() + pos_inner + 1);
+  leaf_vars->data.insert(leaf_vars->data.begin() + pos_outer,
+                         fused.node_);
+  *p_target = fused;
   return *this;
 }

+Schedule& Schedule::reorder(const Array<IterVar>& order) {  // NOLINT(*)
+  ScheduleNode* self = operator->();
+  ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
+  ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
+  std::vector<size_t> pos;
+  for (size_t i = 0; i < order.size(); ++i) {
+    pos.push_back(FindLeafVar(all_vars, leaf_vars, order[i]));
+  }
+  std::vector<std::shared_ptr<Node> > temp;
+  for (size_t i = 0; i < pos.size(); ++i) {
+    temp.emplace_back(leaf_vars->data[pos[i]]);
+  }
+  std::sort(pos.begin(), pos.end());
+  for (size_t i = 0; i < pos.size(); ++i) {
+    leaf_vars->data[pos[i]] = temp[i];
+  }
+  return *this;
+}

 IterVarRelation SplitNode::make(
     IterVar parent, IterVar outer,
...
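The helpers above keep leaf_iter_vars in sync with the current loop nest: split replaces the parent var with (outer, inner) at the same position, fuse collapses an adjacent (outer, inner) pair into one var, and reorder permutes the chosen positions. A small Python model of that bookkeeping (plain lists standing in for the ArrayNode data, names illustrative):

    leaf = ["i"]                                 # initial leaf iteration vars

    # split "i": parent is replaced in place by (outer, inner)
    pos = leaf.index("i")
    leaf[pos:pos + 1] = ["i.outer", "i.inner"]
    assert leaf == ["i.outer", "i.inner"]

    # fuse: outer must sit immediately before inner; both are replaced by the fused var
    po, pi = leaf.index("i.outer"), leaf.index("i.inner")
    assert pi == po + 1
    leaf[po:pi + 1] = ["i.outer.i.inner.fused"]
    assert leaf == ["i.outer.i.inner.fused"]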
@@ -6,28 +6,36 @@ def test_schedule_create():
     l = tvm.Var('l')
     A = tvm.placeholder((m, l), name='A')
     B = tvm.placeholder((n, l), name='B')
+    AA = tvm.compute((m, l), lambda i, j: A[i, j])
     T = tvm.compute((m, n, l), lambda i, j, k: A(i, k) * B(j, k))

-    Tsch = tvm.Schedule(T.op, scope="shared")
-    Asch = tvm.Schedule(A.op)
-    T.op.
-    xo, xi = sch.split(sch.dim_var[0], factor)
-    Asch.compute_at(Tsch, xi)
-    xf = sch.fuse(xo, xi)
-    tk1 = tvm.Split(T.op.dim_var[0], 10)
-    assert isinstance(sch, tvm.schedule.Schedule)
-    assert isinstance(tk1, tvm.schedule.DimSplit)
-    print(tk1.var)
-    print(sch.scope)
-    print(sch.attachs)
+    sch_T = tvm.Schedule(T.op, scope="shared")
+    sch_A = tvm.Schedule(AA.op, scope="global")
+    xo, xi = sch_T.split(T.op.dim_var[0], factor=10)
+    xi1, xi2 = sch_T.split(xi, factor=2)
+    sch_A.compute_at(sch_T, xi1)
+    xo, xi = sch_A.split(AA.op.dim_var[0], factor=10)
+    sch_T.reorder(xi2, xi1)
+    assert T.op.dim_var[1] in sch_T.leaf_iter_vars
+
+def test_reorder():
+    m = tvm.Var('m')
+    A = tvm.placeholder((m,), name='A')
+    T = tvm.compute(m, lambda i: A[i+1])
+
+    sch_T = tvm.Schedule(T.op, scope="shared")
+    xo, xi = sch_T.split(T.op.dim_var[0], factor=10)
+    xi1, xi2 = sch_T.split(xi, factor=2)
+    order = (xi2, xi1, xo)
+    assert tuple(sch_T.leaf_iter_vars) != order
+    sch_T.reorder(*order)
+    assert tuple(sch_T.leaf_iter_vars) == order

 if __name__ == "__main__":
     test_schedule_create()
+    test_reorder()