From 33910970444a9a0839d2a9846354740dd779c2a2 Mon Sep 17 00:00:00 2001
From: Salem Derisavi <33945117+derisavi-huawei@users.noreply.github.com>
Date: Tue, 12 Dec 2017 20:10:35 -0500
Subject: [PATCH] 1) Make unroll code reusable 2) reduce non-determinisim in
 CanonicalSimplify (#701)

* 1) Refactored some parts of the unrolling code into their own methods so we can reuse unrolling functionality in other parts of the code. E.g., to explicitly unroll loops with count of 1 when they are programmatically created.
2) Reorder based on top operator before resorting to pointers, which causes non-determinism.

* Fixed lint errors
---
 src/arithmetic/canonical.cc |  2 +
 src/pass/unroll_loop.cc     | 79 +++++++++++++++++++++++--------------
 2 files changed, 52 insertions(+), 29 deletions(-)

diff --git a/src/arithmetic/canonical.cc b/src/arithmetic/canonical.cc
index 24369db02..473e330de 100644
--- a/src/arithmetic/canonical.cc
+++ b/src/arithmetic/canonical.cc
@@ -29,6 +29,8 @@ struct ComExprEntry {
   inline bool operator<(const ComExprEntry& other) const {
     if (level < other.level) return true;
     if (level > other.level) return false;
+    if (value.type_index() < other.value.type_index()) return true;
+    if (value.type_index() > other.value.type_index()) return false;
     return value.get() < other.value.get();
   }
 };
diff --git a/src/pass/unroll_loop.cc b/src/pass/unroll_loop.cc
index 01c5e6ebf..22fd38945 100644
--- a/src/pass/unroll_loop.cc
+++ b/src/pass/unroll_loop.cc
@@ -30,17 +30,7 @@ class LoopUnroller : public IRMutator {
   Stmt Mutate_(const For* op, const Stmt& s) {
     Stmt stmt = IRMutator::Mutate_(op, s);
     op = stmt.as<For>();
-    // constant folding.
-    Expr extent = ir::Simplify(op->extent);
-    const IntImm* v1 = extent.as<IntImm>();
-    const UIntImm* v2 = extent.as<UIntImm>();
-    int value = -1;
-    if (v1 != nullptr) {
-      value = static_cast<int>(v1->value);
-    }
-    if (v2 != nullptr) {
-      value = static_cast<int>(v2->value);
-    }
+    int value = GetExtent(op);
     // condition for auto unroll
     bool auto_unroll = (
         op->for_type == ForType::Serial &&
@@ -66,24 +56,7 @@ class LoopUnroller : public IRMutator {
     }
 
     if (auto_unroll && explicit_unroll_) {
-      using arith::ComputeExpr;
-      if (value == 0) return Evaluate::make(0);
-      Stmt body = op->body;
-      Map<Var, Expr> vmap;
-      Stmt unrolled;
-      for (int i = 0; i < value; ++i) {
-        Var lv(op->loop_var.node_);
-        vmap.Set(lv,
-                 ComputeExpr<Add>(
-                     op->min, make_const(op->loop_var.type(), i)));
-        Stmt step = Substitute(body, vmap);
-        if (unrolled.defined()) {
-          unrolled = Block::make(unrolled, step);
-        } else {
-          unrolled = step;
-        }
-      }
-      return unrolled;
+      return Unroll(op);
     } else {
       if (auto_unroll) {
         if (op->for_type != ForType::Unrolled) {
@@ -128,7 +101,47 @@ class LoopUnroller : public IRMutator {
     }
   }
 
+  Stmt Unroll(const For* op) {
+    using arith::ComputeExpr;
+    int value = GetExtent(op);
+    // For loop must have a constant integer extent
+    CHECK_NE(value, -1) << "loop doesn't have a constant integer extent";
+    if (value == 0) return Evaluate::make(0);
+    Stmt body = op->body;
+    Map<Var, Expr> vmap;
+    Stmt unrolled;
+    for (int i = 0; i < value; ++i) {
+      Var lv(op->loop_var.node_);
+      vmap.Set(lv,
+               ComputeExpr<Add>(
+                       op->min, make_const(op->loop_var.type(), i)));
+      Stmt step = Substitute(body, vmap);
+      if (unrolled.defined()) {
+        unrolled = Block::make(unrolled, step);
+      } else {
+        unrolled = step;
+      }
+    }
+    return unrolled;
+  }
+
  private:
+  // returns the extent of the loop if it's a constant integer, otherwise return -1
+  int GetExtent(const For* op) {
+    // constant folding.
+    Expr extent = ir::Simplify(op->extent);
+    const IntImm  *v1 = extent.as<IntImm>();
+    const UIntImm *v2 = extent.as<UIntImm>();
+    int value = -1;
+    if (v1 != nullptr) {
+      value = static_cast<int>(v1->value);
+    }
+    if (v2 != nullptr) {
+      value = static_cast<int>(v2->value);
+    }
+    return value;
+  }
+
   // maximum number of step to perform auto unroll.
   int auto_max_step_;
   int auto_max_depth_;
@@ -162,5 +175,13 @@ Stmt UnrollLoop(Stmt stmt,
   }
 }
 
+Stmt UnrollLoopExplicitly(Stmt stmt) {
+  const For* op = stmt.as<For>();
+  if (!op) {
+    LOG(FATAL) << "attempted to unroll a non-loop statement";
+  }
+  return LoopUnroller(0, 0, 0, false).Unroll(op);
+}
+
 }  // namespace ir
 }  // namespace tvm
-- 
GitLab