From fbb472b8f5916b0e73a912a2ae5c08828e892200 Mon Sep 17 00:00:00 2001 From: libing4752 <libing475211023@sjtu.edu.cn> Date: Mon, 5 Feb 2018 01:07:30 +0800 Subject: [PATCH] enhance pragma to support single point copy (#863) * modified schedule_dataflow_rewrite.cc to fix losing tensor problem * modified schedule_dataflow_rewrite.cc for lint scan * modified schedule_dataflow_rewrite.cc for lint scan * using tensor's value_index to index output of stage op * repair address offset for different kinds of dtype * bc * aaa * aaaaa * repair address for different dtypes * remove nonsense files * add whitespace of line 581 * use base alloc elem_type * enhance the testcase of basic buffer is 64bits,32bits,16bits,8bits * use extends[0]->type() as dtype of offset * clear program writes * enhance inject_copy_intrin to support pragma stmt with no loops * fix cpplint errors * fix cpplint error of ! * enhance detectLinearEquation to support calls with no loop vars * fix cpplint errors --- src/arithmetic/detect_linear_equation.cc | 31 ++++++++++--------- src/pass/inject_copy_intrin.cc | 26 +++++++++++----- .../unittest/test_pass_inject_copy_intrin.py | 20 ++++++++++++ 3 files changed, 55 insertions(+), 22 deletions(-) diff --git a/src/arithmetic/detect_linear_equation.cc b/src/arithmetic/detect_linear_equation.cc index 63f582160..642a86686 100644 --- a/src/arithmetic/detect_linear_equation.cc +++ b/src/arithmetic/detect_linear_equation.cc @@ -123,25 +123,28 @@ class LinearEqDetector }; Array<Expr> DetectLinearEquation(const Expr& e, const Array<Var>& vars) { - CHECK_GE(vars.size(), 1U); Expr base = e; Array<Expr> coeff; - for (Var v : vars) { - LinearEqEntry ret; - if (!LinearEqDetector(v).Detect(base, &ret)) { - return Array<Expr>(); + if (0 == vars.size()) { + coeff.push_back(make_const(Int(32), 1)); + } else { + for (Var v : vars) { + LinearEqEntry ret; + if (!LinearEqDetector(v).Detect(base, &ret)) { + return Array<Expr>(); + } + coeff.push_back(ret.coeff); + base = 
std::move(ret.base); } - coeff.push_back(ret.coeff); - base = std::move(ret.base); - } - std::unordered_set<const Variable*> vset; - for (size_t i = vars.size(); i != 1; --i) { - vset.insert(vars[i - 1].get()); - // The previous coeff contains the variable - if (ExprUseVar(coeff[i - 2], vset)) { - return Array<Expr>(); + std::unordered_set<const Variable*> vset; + for (size_t i = vars.size(); i != 1; --i) { + vset.insert(vars[i - 1].get()); + // The previous coeff contains the variable + if (ExprUseVar(coeff[i - 2], vset)) { + return Array<Expr>(); + } } } coeff.push_back(base); diff --git a/src/pass/inject_copy_intrin.cc b/src/pass/inject_copy_intrin.cc index cafcddcb9..ba44253a0 100644 --- a/src/pass/inject_copy_intrin.cc +++ b/src/pass/inject_copy_intrin.cc @@ -40,6 +40,7 @@ class CopyIntrinInjector : public IRMutator { private: bool MatchCopyPattern(Stmt stmt, Stmt *out) { Stmt body = stmt; + bool is_single_point_copy = false; // strip the loops std::vector<const For*> loops; @@ -53,7 +54,10 @@ class CopyIntrinInjector : public IRMutator { const Select* select = store->value.as<Select>(); const Cast* cast = store->value.as<Cast>(); const Load* load = store->value.as<Load>(); - + if (0 == loops.size()) { + is_single_point_copy = true; + CHECK(select == nullptr); + } // for now only support true condition matching if (select != nullptr) { load = select->true_value.as<Load>(); @@ -74,13 +78,19 @@ class CopyIntrinInjector : public IRMutator { arith::DetectLinearEquation(load->index, loop_vars); if (load_strides.size() == 0 || store_strides.size() == 0) return false; Array<Expr> dst_shape; - for (const For* op : loops) { - dst_shape.push_back(op->extent); + auto loop_var_size = loop_vars.size(); + if (is_single_point_copy) { + loop_var_size = 1; + dst_shape.push_back(make_const(Int(32), 1)); + } else { + for (const For* op : loops) { + dst_shape.push_back(op->extent); + } } Array<Expr> src_shape = dst_shape; Array<Expr> pad_before, pad_after; Expr pad_value; - Expr 
src_elem_offset = load_strides[loop_vars.size()]; + Expr src_elem_offset = load_strides[loop_var_size]; if (select != nullptr) { Array<Expr> clip_bound = arith::DetectClipBound(select->condition, loop_vars); @@ -114,15 +124,15 @@ class CopyIntrinInjector : public IRMutator { src_elem_offset = Simplify(src_elem_offset); } CHECK_EQ(load_strides.size(), store_strides.size()); - CHECK_EQ(load_strides.size(), loop_vars.size() + 1); - Array<Expr> src_strides(load_strides.begin(), load_strides.begin() + loop_vars.size()); - Array<Expr> dst_strides(store_strides.begin(), store_strides.begin() + loop_vars.size()); + CHECK_EQ(load_strides.size(), loop_var_size + 1); + Array<Expr> src_strides(load_strides.begin(), load_strides.begin() + loop_var_size); + Array<Expr> dst_strides(store_strides.begin(), store_strides.begin() + loop_var_size); Buffer dst = BufferNode::make( Var(store->buffer_var.node_), store->value.type(), dst_shape, dst_strides, - store_strides[loop_vars.size()], + store_strides[loop_var_size], store->buffer_var->name_hint, GetStorageScope(store->buffer_var.get()), 0, 0); diff --git a/tests/python/unittest/test_pass_inject_copy_intrin.py b/tests/python/unittest/test_pass_inject_copy_intrin.py index 08477895b..c6ed19d65 100644 --- a/tests/python/unittest/test_pass_inject_copy_intrin.py +++ b/tests/python/unittest/test_pass_inject_copy_intrin.py @@ -44,6 +44,25 @@ def test_copy_pad(): return tvm.make.Evaluate(0) stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb) +def test_single_point_test(): + A = tvm.placeholder((1,), name='A') + B = tvm.compute((1,), lambda i: + A[i], name='B') + s = tvm.create_schedule(B.op) + s[B].pragma(B.op.axis[0], "memcpy") + bounds = tvm.schedule.InferBound(s) + stmt = tvm.schedule.ScheduleOps(s, bounds) + Ab = tvm.decl_buffer(A.shape, A.dtype, name='A') + Bb = tvm.decl_buffer(B.shape, B.dtype, name='B') + stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64) + def cb(src, dst, pad_before, pad_after, pad_value): + assert 
tvm.ir_pass.Simplify(src.elem_offset).value == 0 + assert tvm.ir_pass.Simplify(dst.elem_offset).value == 0 + assert tvm.ir_pass.Simplify(src.strides[0]).value == 1 + assert tvm.ir_pass.Simplify(dst.strides[0]).value == 1 + return tvm.make.Evaluate(0) + stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb) + def assert_expr_equal(a, b): assert tvm.ir_pass.Simplify(a - b).value == 0 @@ -80,3 +99,4 @@ if __name__ == "__main__": test_copy2d() test_copy_pad() test_copy_pad_split() + test_single_point_test() -- GitLab