Commit 61cdf903 by libing4752 Committed by Tianqi Chen

[SCHEDULE] Add factor_axis to rfactor (#895)

parent 12d1ab5a
......@@ -313,10 +313,12 @@ class Schedule : public NodeRef {
*
* \param tensor The tensor to be factored.
* \param axis The reduction axis in tensor's schedule to be factored.
* \param factor_axis The position where the new axis is placed.
* \return The created factored tensors.
*/
EXPORT Array<Tensor> rfactor(const Tensor& tensor,
const IterVar& axis);
const IterVar& axis,
int factor_axis = 0);
/*!
* \brief Normalize the schedule.
* This is needed before bound inference.
......
......@@ -279,7 +279,7 @@ class Schedule(NodeBase):
"""
return _api_internal._ScheduleCacheWrite(self, tensor, scope)
def rfactor(self, tensor, axis):
def rfactor(self, tensor, axis, factor_axis=0):
""" Factor a reduction axis in tensor's schedule to be an explicit axis.
This will create a new stage that generated the new tensor with axis
......@@ -292,13 +292,15 @@ class Schedule(NodeBase):
The tensor to be factored.
axis : IterVar
The reduction axis in the schedule to be factored.
factor_axis : int
The position where the new axis is placed.
Returns
-------
tfactor : Tensor or Array of Tensor
The created factored tensor.
"""
factored = _api_internal._ScheduleRFactor(self, tensor, axis)
factored = _api_internal._ScheduleRFactor(self, tensor, axis, factor_axis)
return factored[0] if len(factored) == 1 else factored
......
......@@ -432,7 +432,7 @@ TVM_REGISTER_API("_ScheduleCacheWrite")
TVM_REGISTER_API("_ScheduleRFactor")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Schedule()
.rfactor(args[1], args[2]);
.rfactor(args[1], args[2], args[3]);
});
TVM_REGISTER_API("_CommReducerCombine")
......
......@@ -395,7 +395,8 @@ Schedule Schedule::normalize() {
// Handle reduction factor.
Array<Tensor> Schedule::rfactor(const Tensor& tensor,
const IterVar& axis) {
const IterVar& axis,
int factor_axis) {
(*this)->InvalidateCache();
using ir::Reduce;
CHECK_EQ(axis->iter_type, kCommReduce)
......@@ -448,6 +449,9 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
reduce_stage, dom_map, value_map, true, skip_bound_check);
// Get the factored op node.
const int factor_axis_pos = \
factor_axis >= 0 ? factor_axis : static_cast<int>(compute_op->axis.size() + 1) + factor_axis;
CHECK_LE(factor_axis_pos, compute_op->axis.size());
auto n = std::make_shared<ComputeOpNode>();
n->name = compute_op->name + ".rf";
{
......@@ -458,10 +462,16 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
<< "Can only factor reduction domain starting from 0";
iv_node->var = axis->var;
iv_node->iter_type = kDataPar;
n->axis.push_back(IterVar(iv_node));
for (IterVar iv : compute_op->axis) {
n->axis.push_back(iv);
const int size = compute_op->axis.size();
for (int idx = 0; idx < size; ++idx) {
if (factor_axis_pos == idx) {
n->axis.push_back(IterVar(iv_node));
}
n->axis.push_back(compute_op->axis[idx]);
}
if (factor_axis_pos == size) {
n->axis.push_back(IterVar(iv_node));
}
}
// predicate generation, copy not touched axis.
......@@ -548,9 +558,15 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
Array<Tensor> repl_tensors = compute(old_tensors[0]->shape,
[&](const Array<Var>& i) {
Array<Expr> indices;
indices.push_back(repl_red_axis->var);
for (Var v : i) {
indices.push_back(v);
const int idx_size = static_cast<int>(i.size());
for (int idx = 0; idx < idx_size; ++idx) {
if (factor_axis_pos == idx) {
indices.push_back(repl_red_axis->var);
}
indices.push_back(i[idx]);
}
if (factor_axis_pos == idx_size) {
indices.push_back(repl_red_axis->var);
}
Array<Expr> factor_exprs;
for (int idx = 0; idx < size; ++idx) {
......
......@@ -83,6 +83,36 @@ def test_rfactor():
check_target()
def test_rfactor_factor_axis():
n = tvm.convert(1027)
A = tvm.placeholder((n,), name='A')
k = tvm.reduce_axis((0, n))
B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k), name='B')
# schedule
s = tvm.create_schedule(B.op)
kf, ki = s[B].split(k, nparts=4)
BF = s.rfactor(B, kf, 1)
s[BF].parallel(BF.op.axis[0])
# one line to build the function.
def check_target(target="llvm"):
if not tvm.module.enabled(target):
return
ctx = tvm.cpu(0)
fapi = tvm.lower(s, args=[A, B])
fsum = tvm.build(fapi,
target=target,
name="mysum")
# launch the kernel.
n = 1027
a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(1, dtype=B.dtype), ctx)
fsum(a, b)
res = np.sum(a.asnumpy(), axis=0)
np.testing.assert_allclose(
b.asnumpy(), res, rtol=1e-4)
check_target()
def test_rfactor_threads():
nn = 1027
......@@ -294,6 +324,7 @@ def test_rfactor_argmax():
if __name__ == "__main__":
test_rfactor_elemwise_threads()
test_rfactor_threads()
test_rfactor_factor_axis()
test_rfactor()
test_reduce_prims()
test_argmax()
......
......@@ -137,6 +137,16 @@ def test_rfactor():
assert(BF.op.body[0].axis[0] == k2)
assert(BF.op.body[0].axis[1].var == ko.var)
assert(s[B].op.body[0].axis[0].dom.extent.value == 4)
# schedule with factor_axis
s = tvm.create_schedule(B.op)
ko, ki = s[B].split(k1, factor=4)
xo, xi = s[B].split(B.op.axis[0], factor=8)
BF = s.rfactor(B, ki, 1)
assert(n == BF.shape[0])
assert(BF.shape[1].value == 4)
assert(BF.op.body[0].axis[0] == k2)
assert(BF.op.body[0].axis[1].var == ko.var)
assert(s[B].op.body[0].axis[0].dom.extent.value == 4)
def test_tensor_intrin():
n = 16
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment