Commit a64a0f5f by Tianqi Chen Committed by GitHub

[BUGFIX/REGRESSION] Complex inline call, regression test on lstm cell (#128)

parent de6dd0cb
...@@ -147,7 +147,7 @@ def sigmoid(x): ...@@ -147,7 +147,7 @@ def sigmoid(x):
y : Expr y : Expr
The result. The result.
""" """
return 1.0 / (1.0 + exp(-x)) return call_pure_intrin(x.dtype, "sigmoid", x)
def log(x): def log(x):
...@@ -265,3 +265,6 @@ def _rule_float_direct(op): ...@@ -265,3 +265,6 @@ def _rule_float_direct(op):
register_intrin_rule("opencl", "exp", _rule_float_direct, override=True) register_intrin_rule("opencl", "exp", _rule_float_direct, override=True)
# default pattern for exp # default pattern for exp
register_intrin_rule("default", "exp", _rule_float_suffix, override=True) register_intrin_rule("default", "exp", _rule_float_suffix, override=True)
# default pattern for sigmoid
register_intrin_rule("default", "sigmoid", lambda op: 1.0 / (1.0 + exp(-op.args[0])))
...@@ -23,14 +23,13 @@ class IRInline : public IRMutator { ...@@ -23,14 +23,13 @@ class IRInline : public IRMutator {
if (op->func == f_) { if (op->func == f_) {
CHECK_EQ(op->value_index, 0); CHECK_EQ(op->value_index, 0);
Expr expr = body_; expr = body_;
CHECK_EQ(args_.size(), op->args.size()); CHECK_EQ(args_.size(), op->args.size());
bool has_side_effect = false; bool has_side_effect = false;
for (size_t i = 0; i < op->args.size(); ++i) { for (size_t i = 0; i < op->args.size(); ++i) {
if (HasSideEffect(op->args[i])) has_side_effect = true; if (HasSideEffect(op->args[i])) has_side_effect = true;
} }
if (has_side_effect) { if (has_side_effect) {
for (size_t i = 0; i < args_.size(); ++i) { for (size_t i = 0; i < args_.size(); ++i) {
expr = Let::make(args_[i], op->args[i], expr); expr = Let::make(args_[i], op->args[i], expr);
...@@ -45,7 +44,7 @@ class IRInline : public IRMutator { ...@@ -45,7 +44,7 @@ class IRInline : public IRMutator {
} }
return expr; return expr;
} else { } else {
return e; return expr;
} }
} }
......
...@@ -197,6 +197,7 @@ void RebaseNonZeroMinLoop(const Schedule& sch) { ...@@ -197,6 +197,7 @@ void RebaseNonZeroMinLoop(const Schedule& sch) {
void InjectInline(ScheduleNode* sch) { void InjectInline(ScheduleNode* sch) {
sch->InvalidateCache(); sch->InvalidateCache();
std::vector<Expr> new_body(sch->stages.size()); std::vector<Expr> new_body(sch->stages.size());
// inline all the ops // inline all the ops
for (size_t i = sch->stages.size(); i != 0; --i) { for (size_t i = sch->stages.size(); i != 0; --i) {
...@@ -231,19 +232,32 @@ void InjectInline(ScheduleNode* sch) { ...@@ -231,19 +232,32 @@ void InjectInline(ScheduleNode* sch) {
std::unordered_map<Tensor, Tensor> repl; std::unordered_map<Tensor, Tensor> repl;
// rewrite dataflow // rewrite dataflow
for (size_t i = 0; i < sch->stages.size(); ++i) { for (size_t i = 0; i < sch->stages.size(); ++i) {
Stage s = sch->stages[i];
if (s->attach_type == kInlinedAlready) continue;
if (new_body[i].defined()) { if (new_body[i].defined()) {
// Logic from ReplaceDataFlow
const ComputeOpNode* compute = sch->stages[i]->op.as<ComputeOpNode>(); const ComputeOpNode* compute = sch->stages[i]->op.as<ComputeOpNode>();
CHECK(compute); CHECK(compute);
Operation op = s->op;
if (!new_body[i].same_as(compute->body)) { if (!new_body[i].same_as(compute->body)) {
Operation op = ComputeOpNode::make( op = ComputeOpNode::make(
compute->name, compute->axis, new_body[i]); compute->name, compute->axis, new_body[i]);
Stage s = sch->stages[i]; }
op = op->ReplaceInputs(op, repl);
if (!op.same_as(s->op)) {
repl[s->op.output(0)] = op.output(0); repl[s->op.output(0)] = op.output(0);
s->op = op; s->op = op;
} }
} else {
Operation op = s->op->ReplaceInputs(s->op, repl);
if (!op.same_as(s->op)) {
for (int j = 0; j < op->num_outputs(); ++j) {
repl[s->op.output(j)] = op.output(j);
}
s->op = op;
}
} }
} }
ReplaceDataFlow(sch->stages, &repl);
} }
Schedule Schedule::normalize() { Schedule Schedule::normalize() {
......
...@@ -19,7 +19,19 @@ def test_inline(): ...@@ -19,7 +19,19 @@ def test_inline():
except tvm.TVMError: except tvm.TVMError:
pass pass
def test_inline2():
m = tvm.var('m')
A = tvm.placeholder((m,), name='A')
T = tvm.compute((m,), lambda i,: A[i] + 10, name='T')
stmt = tvm.make.Evaluate(tvm.exp(T[10]) + 11 * T[100])
stmt = tvm.ir_pass.Inline(
stmt, T.op, [x.var for x in T.op.axis], T.op.body)
def check(op):
if isinstance(op, tvm.expr.Call):
assert op.func != T.op
tvm.ir_pass.PostOrderVisit(stmt, check)
if __name__ == "__main__": if __name__ == "__main__":
test_inline2()
test_inline() test_inline()
import tvm
def test_lstm_cell_inline():
num_step = 128
num_input = 256
num_hidden = 1152
batch_size = 4
# Global transition matrix
X = tvm.placeholder((num_step - 1, batch_size, num_input), name="X")
Wi2h = tvm.placeholder((4, num_hidden, num_input), name="Wi2h")
Wh2h = tvm.placeholder((4, num_hidden, num_hidden), name="Wh2h")
# h: output hidden state, c: cell state.
s_state_h = tvm.placeholder((num_step, batch_size, num_hidden))
s_state_c = tvm.placeholder((num_step, batch_size, num_hidden))
s_init_c = tvm.compute((1, batch_size, num_hidden),
lambda *i: 0.0, name="init_c")
s_init_h = tvm.compute((1, batch_size, num_hidden),
lambda *i: 0.0, name="init_h")
# LSTM transition
k = tvm.reduce_axis((0, num_input), name="ki2h")
s_i2h = tvm.compute(
(num_step, 4, batch_size, num_hidden),
lambda t, x, i, j: tvm.sum(X[t - 1, i, k] * Wi2h[x, j, k], axis=k),
name="s_i2h")
k = tvm.reduce_axis((0, num_hidden), name="ki2h")
s_h2h = tvm.compute(
(num_step, 4, batch_size, num_hidden),
lambda t, x, i, j: tvm.sum(s_state_h[t - 1, i, k] * Wh2h[x, j, k], axis=k),
name="s_h2h")
# Gate rules
gates = tvm.compute(s_i2h.shape, lambda *i:
s_i2h(*i) + s_h2h(*i), name="gates")
gshape = (num_step, batch_size, num_hidden)
in_gate = tvm.compute(gshape, lambda t, i, j: tvm.sigmoid(gates[t, 0, i, j]), name="in_gate")
in_transform = tvm.compute(gshape, lambda t, i, j: tvm.tanh(gates[t, 1, i, j]), name="in_transform")
forget_gate = tvm.compute(gshape, lambda t, i, j: tvm.sigmoid(gates[t, 2, i, j]), name="forget_gate")
out_gate = tvm.compute(gshape, lambda t, i, j: tvm.sigmoid(gates[t, 3, i, j]), name="out_gate")
next_c = tvm.compute(gshape,
lambda t, i, j:
forget_gate[t, i, j] * s_state_c[t - 1, i, j] +
in_gate[t, i, j] * in_transform[t, i, j], name="next_c")
next_h = tvm.compute(gshape,
lambda t, i, j: out_gate[t, i, j] * tvm.tanh(next_c[t, i, j]), name="next_h")
update_c = tvm.compute(gshape, lambda *i: next_c(*i), name="update_c")
update_h = tvm.compute(gshape, lambda *i: next_h(*i), name="update_h")
# schedule
scan_h, scan_c = tvm.scan(
[s_init_h, s_init_c],
[update_h, update_c],
[s_state_h, s_state_c],
inputs=[X],
name="lstm_scan")
# schedule
s = tvm.create_schedule(scan_h.op)
# Inline gate computations
s[gates].compute_inline()
s[in_gate].compute_inline()
s[in_transform].compute_inline()
s[forget_gate].compute_inline()
s[out_gate].compute_inline()
# verify we can lower correctly
tvm.lower(s, [X, Wi2h, Wh2h, scan_h, scan_c], with_api_wrapper=False)
# Script entry point: run the LSTM inline regression when executed directly.
if __name__ == "__main__":
    test_lstm_cell_inline()
...@@ -92,7 +92,7 @@ def test_inline_mixed(): ...@@ -92,7 +92,7 @@ def test_inline_mixed():
tvm.ir_pass.PostOrderVisit(s[C].op.body, check) tvm.ir_pass.PostOrderVisit(s[C].op.body, check)
def test_scan_inline(): def test_scan_inline1():
m = tvm.var("m") m = tvm.var("m")
n = tvm.var("n") n = tvm.var("n")
x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x") x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
...@@ -111,6 +111,28 @@ def test_scan_inline(): ...@@ -111,6 +111,28 @@ def test_scan_inline():
s[s_x1].compute_inline() s[s_x1].compute_inline()
stmt = tvm.lower(s, [x, res1, res2], with_api_wrapper=False) stmt = tvm.lower(s, [x, res1, res2], with_api_wrapper=False)
def test_scan_inline2():
m = tvm.var("m")
n = tvm.var("n")
x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
s_state1 = tvm.placeholder((m, n))
s_state2 = tvm.placeholder((m, n))
s_init1 = tvm.compute((1, n), lambda _, i: x[0, i])
s_init2 = tvm.compute((1, n), lambda _, i: x[0, i])
s_xx = tvm.compute((m, n), lambda t, i: s_state1[t-1, i] + x[t, i], name="xx")
s_x1 = tvm.compute((m, n), lambda t, i: s_xx[t, i] + 1, name="x1")
s_x2 = tvm.compute((m, n), lambda t, i: s_xx[t, i] + s_state2[t-1, 2], name="x2")
s_update1 = tvm.compute((m, n), lambda t, i: s_x1[t, i], "u1")
s_update2 = tvm.compute((m, n), lambda t, i: s_x2[t, i], "u2")
res1, res2 = tvm.scan([s_init1, s_init2],
[s_update1, s_update2],
[s_state1, s_state2])
s = tvm.create_schedule(res1.op)
s[s_xx].compute_inline()
s[s_x1].compute_inline()
s[s_x2].compute_inline()
stmt = tvm.lower(s, [x, res1, res2], with_api_wrapper=False)
def test_schedule_cache(): def test_schedule_cache():
m = tvm.var('m') m = tvm.var('m')
...@@ -128,7 +150,8 @@ def test_schedule_cache(): ...@@ -128,7 +150,8 @@ def test_schedule_cache():
if __name__ == "__main__": if __name__ == "__main__":
test_scan_inline() test_scan_inline1()
test_scan_inline2()
test_inline_mixed() test_inline_mixed()
test_auto_inline() test_auto_inline()
test_schedule_scan() test_schedule_scan()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment