From de6dd0cb10b56e3a3186e900553e6811a8806e12 Mon Sep 17 00:00:00 2001 From: Tianqi Chen <tqchen@users.noreply.github.com> Date: Sun, 7 May 2017 08:44:53 -0700 Subject: [PATCH] [BUGFIX] Fix schedule dataflow rewrite with multiple scan states (#126) --- python/tvm/intrin.py | 32 +++++++++++++++++++ src/pass/storage_flatten.cc | 4 ++- src/schedule/schedule_dataflow_rewrite.cc | 15 +++++---- src/schedule/schedule_ops.cc | 8 +++-- tests/python/perf/rnn_matexp.py | 4 +-- .../unittest/test_schedule_schedule_ops.py | 26 ++++++++++++++- 6 files changed, 75 insertions(+), 14 deletions(-) diff --git a/python/tvm/intrin.py b/python/tvm/intrin.py index 4476fe0c3..2900cfdff 100644 --- a/python/tvm/intrin.py +++ b/python/tvm/intrin.py @@ -118,6 +118,38 @@ def exp(x): return call_pure_intrin(x.dtype, "exp", x) +def tanh(x): + """Take hyperbolic tanh of input x. + + Parameters + ---------- + x : Expr + Input argument. + + Returns + ------- + y : Expr + The result. + """ + return call_pure_intrin(x.dtype, "tanh", x) + + +def sigmoid(x): + """Quick function to get sigmoid + + Parameters + ---------- + x : Expr + Input argument. + + Returns + ------- + y : Expr + The result. + """ + return 1.0 / (1.0 + exp(-x)) + + def log(x): """Take log of input x. diff --git a/src/pass/storage_flatten.cc b/src/pass/storage_flatten.cc index a90c86758..ff4c5912d 100644 --- a/src/pass/storage_flatten.cc +++ b/src/pass/storage_flatten.cc @@ -90,7 +90,9 @@ class StorageFlattener : public IRMutator { buf_map_[key].released = true; // deduce current storage scope. auto it = storage_scope_.find(op->func.get()); - CHECK(it != storage_scope_.end()); + CHECK(it != storage_scope_.end()) + << "Cannot find storage scope of " << op->func + << " value_index=" << op->value_index; StorageScope skey; const std::string& strkey = it->second; if (strkey.length() == 0) { diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc index 7534c7019..70b68e358 100644 --- a/src/schedule/schedule_dataflow_rewrite.cc +++ b/src/schedule/schedule_dataflow_rewrite.cc @@ -231,15 +231,16 @@ void InjectInline(ScheduleNode* sch) { std::unordered_map<Tensor, Tensor> repl; // rewrite dataflow for (size_t i = 0; i < sch->stages.size(); ++i) { - if (new_body[i].defined() && - !new_body[i].same_as(sch->stages[i]->op)) { + if (new_body[i].defined()) { const ComputeOpNode* compute = sch->stages[i]->op.as<ComputeOpNode>(); CHECK(compute); - Operation op = ComputeOpNode::make( - compute->name, compute->axis, new_body[i]); - repl[sch->stages[i]->op.output(0)] = op.output(0); - Stage s = sch->stages[i]; - s->op = op; + if (!new_body[i].same_as(compute->body)) { + Operation op = ComputeOpNode::make( + compute->name, compute->axis, new_body[i]); + Stage s = sch->stages[i]; + repl[s->op.output(0)] = op.output(0); + s->op = op; + } } } ReplaceDataFlow(sch->stages, &repl); diff --git a/src/schedule/schedule_ops.cc b/src/schedule/schedule_ops.cc index dab1318c1..edbc2878a 100644 --- a/src/schedule/schedule_ops.cc +++ b/src/schedule/schedule_ops.cc @@ -252,9 +252,11 @@ class SchedulePostProc : public IRMutator { } // This must be checked for all ops, including scan. if (!s->op.same_as(s->origin_op)) { - Tensor target = s->origin_op.output(0); - AddReplace(s->op.output(0), target, - target, s->origin_op); + for (int i = 0; i < s->op->num_outputs(); ++i) { + Tensor target = s->origin_op.output(0); + AddReplace(s->op.output(i), target, + target, s->origin_op); + } } // Specially add replacements for scan op. if (s->op.as<ScanOpNode>()) { diff --git a/tests/python/perf/rnn_matexp.py b/tests/python/perf/rnn_matexp.py index 1d58d1faf..c7a939ef9 100644 --- a/tests/python/perf/rnn_matexp.py +++ b/tests/python/perf/rnn_matexp.py @@ -126,11 +126,11 @@ def rnn_matexp(): Whh_a = tvm.nd.array(Whh_np, ctx) # Skip first pass as it is compilation f(res_a, Whh_a) - tvm.nd.sync(ctx) + ctx.sync() # measure time cost of second step. tstart = time.time() f(res_a, Whh_a) - tvm.nd.sync(ctx) + ctx.sync() tgap = time.time() - tstart print("Time cost=%g" % tgap) # correctness diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py index ea02c60b8..76b3d9622 100644 --- a/tests/python/unittest/test_schedule_schedule_ops.py +++ b/tests/python/unittest/test_schedule_schedule_ops.py @@ -86,7 +86,30 @@ def test_inline_mixed(): s = s.normalize() bounds = tvm.schedule.InferBound(s) stmt = tvm.schedule.ScheduleOps(s, bounds) - print(stmt) + def check(x): + if isinstance(x, tvm.expr.Call): + assert x.func != A2 + tvm.ir_pass.PostOrderVisit(s[C].op.body, check) + + +def test_scan_inline(): + m = tvm.var("m") + n = tvm.var("n") + x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x") + s_state1 = tvm.placeholder((m, n)) + s_state2 = tvm.placeholder((m, n)) + s_init1 = tvm.compute((1, n), lambda _, i: x[0, i]) + s_init2 = tvm.compute((1, n), lambda _, i: x[0, i]) + s_x1 = tvm.compute((m, n), lambda t, i: s_state1[t-1, i] + x[t, i], name="x1") + s_x2 = tvm.compute((m, n), lambda t, i: s_state2[t-1, i] + 1 , name="x2") + s_update1 = tvm.compute((m, n), lambda t, i: s_x1[t, i], "u1") + s_update2 = tvm.compute((m, n), lambda t, i: s_x2[t, i], "u2") + res1, res2 = tvm.scan([s_init1, s_init2], + [s_update1, s_update2], + [s_state1, s_state2]) + s = tvm.create_schedule(res1.op) + s[s_x1].compute_inline() + stmt = tvm.lower(s, [x, res1, res2], with_api_wrapper=False) def test_schedule_cache(): @@ -105,6 +128,7 @@ def test_schedule_cache(): if __name__ == "__main__": + test_scan_inline() test_inline_mixed() test_auto_inline() test_schedule_scan() -- GitLab