Commit 83d98042 by kun-zh Committed by Tianqi Chen

[SCHEDULE] New Reduction Mode for Tensorize (#727)

* when there is no intrin func, using body for initialization. For issue 714.

* Refine code per review comments, and add a test case.

* Fix lint issues.
parent 2a8e0746
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include "./op_util.h" #include "./op_util.h"
#include "./compute_op.h" #include "./compute_op.h"
#include "../schedule/message_passing.h" #include "../schedule/message_passing.h"
#include "../arithmetic/compute_expr.h"
namespace tvm { namespace tvm {
...@@ -322,6 +323,50 @@ void VerifyTensorizeBody( ...@@ -322,6 +323,50 @@ void VerifyTensorizeBody(
} }
} }
/*!
* \brief Transform the update part when there is no init func in tensorizing
* \param stage The stage for tensorizing.
* \param dom_map The range of each iter var.
* \param n The loop nest structured used in compute.
* \param body The body func in tensorize intrin
* \param update The update func in tensorize intrin
* \return Transformed result.
*/
Stmt TransformUpdate(const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
const ComputeLoopNest& n,
Stmt body,
Stmt update) {
Array<Expr> conds;
std::unordered_set<const Variable*> banned;
for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
IterVar iv = stage->leaf_iter_vars[i];
auto iit = stage->iter_var_attrs.find(iv);
if (iit != stage->iter_var_attrs.end()) {
const IterVarAttr& attr = (*iit).second;
if (attr->iter_type == kTensorized) {
break;
}
}
if (iv->iter_type == kCommReduce) {
auto vit = dom_map.find(iv);
CHECK(vit != dom_map.end());
const Range& vrange = vit->second;
conds.push_back(likely(iv->var > vrange->min));
banned.insert(iv->var.get());
}
}
for (const Expr& pred : n.main_predicates) {
if (ir::ExprUseVar(pred, banned)) {
LOG(FATAL) << "Tensorize update transform failed, the condition "
<< pred << " has a conflict with the reset condition";
}
}
return IfThenElse::make(arith::ComputeReduce<ir::Or>(conds, const_true(1)),
update, body);
}
Stmt MakeTensorize(const ComputeOpNode* self, Stmt MakeTensorize(const ComputeOpNode* self,
const Stage& stage, const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) { const std::unordered_map<IterVar, Range>& dom_map) {
...@@ -416,32 +461,47 @@ Stmt MakeTensorize(const ComputeOpNode* self, ...@@ -416,32 +461,47 @@ Stmt MakeTensorize(const ComputeOpNode* self,
return MergeNest(nest, body); return MergeNest(nest, body);
} else { } else {
// Need to split reduction // Need to split reduction
CHECK(intrin->reduce_init.defined())
<< "Reduction init op for intrin " << intrin << " is not defined";
CHECK(intrin->reduce_update.defined()) CHECK(intrin->reduce_update.defined())
<< "Reduction update op for intrin " << intrin << " is not defined"; << "Reduction update op for intrin " << intrin << " is not defined";
// Need init and update steps // Need init and update steps
CHECK_NE(self->reduce_axis.size(), 0U); CHECK_NE(self->reduce_axis.size(), 0U);
std::vector<std::vector<Stmt> > common( std::vector<std::vector<Stmt> > common(
n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1); n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1);
// init nest
std::vector<std::vector<Stmt> > init_nest(
n.init_nest.begin(), n.init_nest.begin() + tloc + 1);
init_nest.emplace_back(op::MakeIfNest(n.init_predicates));
Stmt init = MergeNest(output_bind_nest, intrin->reduce_init);
init = Substitute(init, n.init_vmap);
init = MergeNest(init_nest, init);
// The update
std::vector<std::vector<Stmt> > update_nest( std::vector<std::vector<Stmt> > update_nest(
n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.begin() + tloc + 1); n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.begin() + tloc + 1);
update_nest.emplace_back(op::MakeIfNest(n.main_predicates)); update_nest.emplace_back(op::MakeIfNest(n.main_predicates));
Stmt update = MergeNest(output_bind_nest, intrin->reduce_update);
update = MergeNest(input_bind_nest, update); if (intrin->reduce_init.defined()) {
update = Substitute(update, vmap); // init nest
update = MergeNest(binder.asserts(), update); std::vector<std::vector<Stmt> > init_nest(
update = Substitute(update, n.main_vmap); n.init_nest.begin(), n.init_nest.begin() + tloc + 1);
update = MergeNest(update_nest, update); init_nest.emplace_back(op::MakeIfNest(n.init_predicates));
return MergeNest(common, Block::make(init, update)); Stmt init = MergeNest(output_bind_nest, intrin->reduce_init);
init = Substitute(init, n.init_vmap);
init = MergeNest(init_nest, init);
// The update
Stmt update = MergeNest(output_bind_nest, intrin->reduce_update);
update = MergeNest(input_bind_nest, update);
update = Substitute(update, vmap);
update = MergeNest(binder.asserts(), update);
update = Substitute(update, n.main_vmap);
update = MergeNest(update_nest, update);
return MergeNest(common, Block::make(init, update));
} else {
// When init op is not available, use body op for reset in the first iter.
CHECK(intrin->body.defined())
<< "Normal body op for intrin " << intrin << " is not defined";
Stmt update = TransformUpdate(stage, dom_map, n,
intrin->body,
intrin->reduce_update);
update = MergeNest(output_bind_nest, update);
update = MergeNest(input_bind_nest, update);
update = Substitute(update, vmap);
update = MergeNest(binder.asserts(), update);
update = Substitute(update, n.main_vmap);
update = MergeNest(update_nest, update);
return MergeNest(common, update);
}
} }
} }
......
import tvm
def intrin_gemv(m, n):
w = tvm.placeholder((m, n), name='w')
x = tvm.placeholder((n,), name='x')
k = tvm.reduce_axis((0, n), name='k')
z = tvm.compute((m,), lambda i:
tvm.sum(w[i, k] * x[k], axis=k), name='z')
Wb = tvm.decl_buffer(w.shape, w.dtype,
name="W",
offset_factor=16,
strides=[tvm.var('ldw'), 1])
def intrin_func(ins, outs):
ww, xx = ins
zz = outs[0]
ww_ptr = ww.access_ptr("r")
xx_ptr = xx.access_ptr("r")
zz_ptr = zz.access_ptr("w")
body = tvm.call_packed(
"gemv", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0])
update = tvm.call_packed(
"gemv_add", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0])
return body, None, update
with tvm.build_config(data_alignment=16,
offset_factor=16):
return tvm.decl_tensor_intrin(z.op, intrin_func,
binds={w: Wb})
def test_tensorize_matmul():
n = 1024
m = n
l = n
A = tvm.placeholder((n, l), name='A')
B = tvm.placeholder((m, l), name='B')
k = tvm.reduce_axis((0, l), name='k')
C = tvm.compute((n, m), lambda i, j:
tvm.sum(B[j, k] * A[i, k], axis=k), name='C')
def check(factor):
s = tvm.create_schedule(C.op)
x, y = C.op.axis
yo, yi = s[C].split(y, factor=factor)
gemv = intrin_gemv(factor, l)
s[C].tensorize(yi, gemv)
s = s.normalize()
dom_map = tvm.schedule.InferBound(s)
finfer = tvm.get_global_func("test.op.InferTensorizeRegion")
out_dom, in_dom = finfer(s[C], dom_map)
assert tvm.ir_pass.Equal(out_dom[x].extent, 1)
assert tvm.ir_pass.Equal(out_dom[y].extent, factor)
assert tvm.ir_pass.Equal(out_dom[y].min, yo * factor)
fmatch = tvm.get_global_func("test.op.MatchTensorizeBody")
body = fmatch(s[C], out_dom, in_dom, gemv)
assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(body[0]),
tvm.ir_pass.CanonicalSimplify(gemv.op.body[0]))
stmt = tvm.schedule.ScheduleOps(s, dom_map)
tvm.lower(s, [A, B, C])
def check_rfactor(factor, rfactor):
s = tvm.create_schedule(C.op)
x, y = C.op.axis
rk = C.op.reduce_axis[0]
yo, yi = s[C].split(y, factor=factor)
ro, ri = s[C].split(rk, factor=rfactor)
s[C].reorder(yo, ro, yi, ri)
gemv = intrin_gemv(factor, rfactor)
s[C].tensorize(yi, gemv)
s = s.normalize()
dom_map = tvm.schedule.InferBound(s)
finfer = tvm.get_global_func("test.op.InferTensorizeRegion")
out_dom, in_dom = finfer(s[C], dom_map)
assert tvm.ir_pass.Equal(out_dom[x].extent, 1)
assert tvm.ir_pass.Equal(out_dom[y].extent, factor)
assert tvm.ir_pass.Equal(out_dom[y].min, yo * factor)
fmatch = tvm.get_global_func("test.op.MatchTensorizeBody")
body = fmatch(s[C], out_dom, in_dom, gemv)
assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(body[0]),
tvm.ir_pass.CanonicalSimplify(gemv.op.body[0]))
stmt = tvm.schedule.ScheduleOps(s, dom_map)
tvm.lower(s, [A, B, C])
check(16)
check_rfactor(16, 16)
if __name__ == "__main__":
test_tensorize_matmul()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment