From fbb472b8f5916b0e73a912a2ae5c08828e892200 Mon Sep 17 00:00:00 2001
From: libing4752 <libing475211023@sjtu.edu.cn>
Date: Mon, 5 Feb 2018 01:07:30 +0800
Subject: [PATCH] enhance pragma to support single point copy  (#863)

* modified schedule_dataflow_rewrite.cc to fix losing tensor problem

* modified schedule_dataflow_rewrite.cc for lint scan

* modified schedule_dataflow_rewrite.cc for lint scan

* using tensor's value_index to index output of stage op

* repair address offset for different kinds of dtype

* bc

* aaa

* aaaaa

* repair address for different dtypes

* remove nonsense files

* add whitespace of line 581

* use base alloc elem_type

* enhance the test case to cover basic buffers of 64, 32, 16 and 8 bits

* use extends[0]->type() as dtype of offset

* clear program writes

* enhance inject_copy_intrin to support pragma stmt with no loops

* fix cpplint errors

* fix cpplint error of !

* enhance DetectLinearEquation to support the case with no loop vars

* fix cpplint errors
---
 src/arithmetic/detect_linear_equation.cc      | 31 ++++++++++---------
 src/pass/inject_copy_intrin.cc                | 26 +++++++++++-----
 .../unittest/test_pass_inject_copy_intrin.py  | 20 ++++++++++++
 3 files changed, 55 insertions(+), 22 deletions(-)

diff --git a/src/arithmetic/detect_linear_equation.cc b/src/arithmetic/detect_linear_equation.cc
index 63f582160..642a86686 100644
--- a/src/arithmetic/detect_linear_equation.cc
+++ b/src/arithmetic/detect_linear_equation.cc
@@ -123,25 +123,28 @@ class LinearEqDetector
 };
 
 Array<Expr> DetectLinearEquation(const Expr& e, const Array<Var>& vars) {
-  CHECK_GE(vars.size(), 1U);
   Expr base = e;
   Array<Expr> coeff;
 
-  for (Var v : vars) {
-    LinearEqEntry ret;
-    if (!LinearEqDetector(v).Detect(base, &ret)) {
-      return Array<Expr>();
+  if (0 == vars.size()) {
+    coeff.push_back(make_const(Int(32), 1));
+  } else {
+    for (Var v : vars) {
+      LinearEqEntry ret;
+      if (!LinearEqDetector(v).Detect(base, &ret)) {
+        return Array<Expr>();
+      }
+      coeff.push_back(ret.coeff);
+      base = std::move(ret.base);
     }
-    coeff.push_back(ret.coeff);
-    base = std::move(ret.base);
-  }
 
-  std::unordered_set<const Variable*> vset;
-  for (size_t i = vars.size(); i != 1; --i) {
-    vset.insert(vars[i - 1].get());
-    // The previous coeff contains the variable
-    if (ExprUseVar(coeff[i - 2], vset)) {
-      return Array<Expr>();
+    std::unordered_set<const Variable*> vset;
+    for (size_t i = vars.size(); i != 1; --i) {
+      vset.insert(vars[i - 1].get());
+      // The previous coeff contains the variable
+      if (ExprUseVar(coeff[i - 2], vset)) {
+        return Array<Expr>();
+      }
     }
   }
   coeff.push_back(base);
diff --git a/src/pass/inject_copy_intrin.cc b/src/pass/inject_copy_intrin.cc
index cafcddcb9..ba44253a0 100644
--- a/src/pass/inject_copy_intrin.cc
+++ b/src/pass/inject_copy_intrin.cc
@@ -40,6 +40,7 @@ class CopyIntrinInjector : public IRMutator {
  private:
   bool MatchCopyPattern(Stmt stmt, Stmt *out) {
     Stmt body = stmt;
+    bool is_single_point_copy = false;
 
     // strip the loops
     std::vector<const For*> loops;
@@ -53,7 +54,10 @@ class CopyIntrinInjector : public IRMutator {
     const Select* select = store->value.as<Select>();
     const Cast* cast = store->value.as<Cast>();
     const Load* load = store->value.as<Load>();
-
+    if (0 == loops.size()) {
+      is_single_point_copy = true;
+      CHECK(select == nullptr);
+    }
     // for now only support true condition matching
     if (select != nullptr) {
       load = select->true_value.as<Load>();
@@ -74,13 +78,19 @@ class CopyIntrinInjector : public IRMutator {
         arith::DetectLinearEquation(load->index, loop_vars);
     if (load_strides.size()  == 0 || store_strides.size() == 0) return false;
     Array<Expr> dst_shape;
-    for (const For* op : loops) {
-      dst_shape.push_back(op->extent);
+    auto loop_var_size = loop_vars.size();
+    if (is_single_point_copy) {
+      loop_var_size = 1;
+      dst_shape.push_back(make_const(Int(32), 1));
+    } else {
+      for (const For* op : loops) {
+        dst_shape.push_back(op->extent);
+      }
     }
     Array<Expr> src_shape = dst_shape;
     Array<Expr> pad_before, pad_after;
     Expr pad_value;
-    Expr src_elem_offset = load_strides[loop_vars.size()];
+    Expr src_elem_offset = load_strides[loop_var_size];
     if (select != nullptr) {
       Array<Expr> clip_bound =
           arith::DetectClipBound(select->condition, loop_vars);
@@ -114,15 +124,15 @@ class CopyIntrinInjector : public IRMutator {
       src_elem_offset = Simplify(src_elem_offset);
     }
     CHECK_EQ(load_strides.size(), store_strides.size());
-    CHECK_EQ(load_strides.size(), loop_vars.size() + 1);
-    Array<Expr> src_strides(load_strides.begin(), load_strides.begin() + loop_vars.size());
-    Array<Expr> dst_strides(store_strides.begin(), store_strides.begin() + loop_vars.size());
+    CHECK_EQ(load_strides.size(), loop_var_size + 1);
+    Array<Expr> src_strides(load_strides.begin(), load_strides.begin() + loop_var_size);
+    Array<Expr> dst_strides(store_strides.begin(), store_strides.begin() + loop_var_size);
     Buffer dst = BufferNode::make(
         Var(store->buffer_var.node_),
         store->value.type(),
         dst_shape,
         dst_strides,
-        store_strides[loop_vars.size()],
+        store_strides[loop_var_size],
         store->buffer_var->name_hint,
         GetStorageScope(store->buffer_var.get()),
         0, 0);
diff --git a/tests/python/unittest/test_pass_inject_copy_intrin.py b/tests/python/unittest/test_pass_inject_copy_intrin.py
index 08477895b..c6ed19d65 100644
--- a/tests/python/unittest/test_pass_inject_copy_intrin.py
+++ b/tests/python/unittest/test_pass_inject_copy_intrin.py
@@ -44,6 +44,25 @@ def test_copy_pad():
         return tvm.make.Evaluate(0)
     stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
 
+def test_single_point_test():
+    A = tvm.placeholder((1,), name='A')
+    B = tvm.compute((1,), lambda i:
+                    A[i], name='B')
+    s = tvm.create_schedule(B.op)
+    s[B].pragma(B.op.axis[0], "memcpy")
+    bounds = tvm.schedule.InferBound(s)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
+    Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
+    Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
+    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
+    def cb(src, dst, pad_before, pad_after, pad_value):
+        assert tvm.ir_pass.Simplify(src.elem_offset).value == 0
+        assert tvm.ir_pass.Simplify(dst.elem_offset).value == 0
+        assert tvm.ir_pass.Simplify(src.strides[0]).value == 1
+        assert tvm.ir_pass.Simplify(dst.strides[0]).value == 1
+        return tvm.make.Evaluate(0)
+    stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
+
 def assert_expr_equal(a, b):
     assert tvm.ir_pass.Simplify(a - b).value == 0
 
@@ -80,3 +99,4 @@ if __name__ == "__main__":
     test_copy2d()
     test_copy_pad()
     test_copy_pad_split()
+    test_single_point_test()
-- 
GitLab