Commit de6dd0cb by Tianqi Chen Committed by GitHub

[BUGFIX] Fix schedule dataflow rewrite with multiple scan states (#126)

parent be1a29ed
...@@ -118,6 +118,38 @@ def exp(x): ...@@ -118,6 +118,38 @@ def exp(x):
return call_pure_intrin(x.dtype, "exp", x) return call_pure_intrin(x.dtype, "exp", x)
def tanh(x):
"""Take hyperbolic tanh of input x.
Parameters
----------
x : Expr
Input argument.
Returns
-------
y : Expr
The result.
"""
return call_pure_intrin(x.dtype, "tanh", x)
def sigmoid(x):
"""Quick function to get sigmoid
Parameters
----------
x : Expr
Input argument.
Returns
-------
y : Expr
The result.
"""
return 1.0 / (1.0 + exp(-x))
def log(x): def log(x):
"""Take log of input x. """Take log of input x.
......
...@@ -90,7 +90,9 @@ class StorageFlattener : public IRMutator { ...@@ -90,7 +90,9 @@ class StorageFlattener : public IRMutator {
buf_map_[key].released = true; buf_map_[key].released = true;
// deduce current storage scope. // deduce current storage scope.
auto it = storage_scope_.find(op->func.get()); auto it = storage_scope_.find(op->func.get());
CHECK(it != storage_scope_.end()); CHECK(it != storage_scope_.end())
<< "Cannot find storage scope of " << op->func
<< " value_index=" << op->value_index;
StorageScope skey; StorageScope skey;
const std::string& strkey = it->second; const std::string& strkey = it->second;
if (strkey.length() == 0) { if (strkey.length() == 0) {
......
...@@ -231,15 +231,16 @@ void InjectInline(ScheduleNode* sch) { ...@@ -231,15 +231,16 @@ void InjectInline(ScheduleNode* sch) {
std::unordered_map<Tensor, Tensor> repl; std::unordered_map<Tensor, Tensor> repl;
// rewrite dataflow // rewrite dataflow
for (size_t i = 0; i < sch->stages.size(); ++i) { for (size_t i = 0; i < sch->stages.size(); ++i) {
if (new_body[i].defined() && if (new_body[i].defined()) {
!new_body[i].same_as(sch->stages[i]->op)) {
const ComputeOpNode* compute = sch->stages[i]->op.as<ComputeOpNode>(); const ComputeOpNode* compute = sch->stages[i]->op.as<ComputeOpNode>();
CHECK(compute); CHECK(compute);
Operation op = ComputeOpNode::make( if (!new_body[i].same_as(compute->body)) {
compute->name, compute->axis, new_body[i]); Operation op = ComputeOpNode::make(
repl[sch->stages[i]->op.output(0)] = op.output(0); compute->name, compute->axis, new_body[i]);
Stage s = sch->stages[i]; Stage s = sch->stages[i];
s->op = op; repl[s->op.output(0)] = op.output(0);
s->op = op;
}
} }
} }
ReplaceDataFlow(sch->stages, &repl); ReplaceDataFlow(sch->stages, &repl);
......
...@@ -252,9 +252,11 @@ class SchedulePostProc : public IRMutator { ...@@ -252,9 +252,11 @@ class SchedulePostProc : public IRMutator {
} }
// This must be checked for all ops, including scan. // This must be checked for all ops, including scan.
if (!s->op.same_as(s->origin_op)) { if (!s->op.same_as(s->origin_op)) {
Tensor target = s->origin_op.output(0); for (int i = 0; i < s->op->num_outputs(); ++i) {
AddReplace(s->op.output(0), target, Tensor target = s->origin_op.output(0);
target, s->origin_op); AddReplace(s->op.output(i), target,
target, s->origin_op);
}
} }
// Specially add replacements for scan op. // Specially add replacements for scan op.
if (s->op.as<ScanOpNode>()) { if (s->op.as<ScanOpNode>()) {
......
...@@ -126,11 +126,11 @@ def rnn_matexp(): ...@@ -126,11 +126,11 @@ def rnn_matexp():
Whh_a = tvm.nd.array(Whh_np, ctx) Whh_a = tvm.nd.array(Whh_np, ctx)
# Skip first pass as it is compilation # Skip first pass as it is compilation
f(res_a, Whh_a) f(res_a, Whh_a)
tvm.nd.sync(ctx) ctx.sync()
# measure time cost of second step. # measure time cost of second step.
tstart = time.time() tstart = time.time()
f(res_a, Whh_a) f(res_a, Whh_a)
tvm.nd.sync(ctx) ctx.sync()
tgap = time.time() - tstart tgap = time.time() - tstart
print("Time cost=%g" % tgap) print("Time cost=%g" % tgap)
# correctness # correctness
......
...@@ -86,7 +86,30 @@ def test_inline_mixed(): ...@@ -86,7 +86,30 @@ def test_inline_mixed():
s = s.normalize() s = s.normalize()
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.schedule.ScheduleOps(s, bounds)
print(stmt) def check(x):
if isinstance(x, tvm.expr.Call):
assert x.func != A2
tvm.ir_pass.PostOrderVisit(s[C].op.body, check)
def test_scan_inline():
m = tvm.var("m")
n = tvm.var("n")
x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
s_state1 = tvm.placeholder((m, n))
s_state2 = tvm.placeholder((m, n))
s_init1 = tvm.compute((1, n), lambda _, i: x[0, i])
s_init2 = tvm.compute((1, n), lambda _, i: x[0, i])
s_x1 = tvm.compute((m, n), lambda t, i: s_state1[t-1, i] + x[t, i], name="x1")
s_x2 = tvm.compute((m, n), lambda t, i: s_state2[t-1, i] + 1 , name="x2")
s_update1 = tvm.compute((m, n), lambda t, i: s_x1[t, i], "u1")
s_update2 = tvm.compute((m, n), lambda t, i: s_x2[t, i], "u2")
res1, res2 = tvm.scan([s_init1, s_init2],
[s_update1, s_update2],
[s_state1, s_state2])
s = tvm.create_schedule(res1.op)
s[s_x1].compute_inline()
stmt = tvm.lower(s, [x, res1, res2], with_api_wrapper=False)
def test_schedule_cache(): def test_schedule_cache():
...@@ -105,6 +128,7 @@ def test_schedule_cache(): ...@@ -105,6 +128,7 @@ def test_schedule_cache():
if __name__ == "__main__": if __name__ == "__main__":
test_scan_inline()
test_inline_mixed() test_inline_mixed()
test_auto_inline() test_auto_inline()
test_schedule_scan() test_schedule_scan()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment