From 883779888896453ad6e21f33dcd163d1a3358ae1 Mon Sep 17 00:00:00 2001
From: Tianqi Chen <tqchen@users.noreply.github.com>
Date: Mon, 6 Feb 2017 16:26:15 -0800
Subject: [PATCH] [PASS] Canonical form simplify (#34)

---
 include/tvm/ir_pass.h                       |   7 +
 python/tvm/build.py                         |   8 +-
 src/api/api_pass.cc                         |   1 +
 src/arithmetic/canonical.cc                 | 486 ++++++++++++++++++++
 src/arithmetic/canonical.h                  |  55 +++
 src/arithmetic/int_set.cc                   |  26 +-
 src/arithmetic/int_set.h                    |   4 +
 src/codegen/codegen_cuda.cc                 |   2 +-
 src/codegen/codegen_opencl.cc               |   2 +-
 tests/python/integration/test_gemm.py       |   8 +-
 tests/python/unittest/test_pass_simplify.py |  26 ++
 11 files changed, 614 insertions(+), 11 deletions(-)
 create mode 100644 src/arithmetic/canonical.cc
 create mode 100644 src/arithmetic/canonical.h
 create mode 100644 tests/python/unittest/test_pass_simplify.py

diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h
index f8412dc36..b11486d90 100644
--- a/include/tvm/ir_pass.h
+++ b/include/tvm/ir_pass.h
@@ -62,6 +62,13 @@ bool HasSideEffect(const Expr& e);
  */
 Stmt ConvertSSA(Stmt stmt);
 
+/*!
+ * \brief Simplify by applying canonical form.
+ * \param stmt The statement to be canonically simplifed.
+ * \return Canonicalized statement.
+ */
+Stmt CanonicalSimplify(Stmt stmt);
+
 /*!
  * \brief Substitute the var specified in key->var to be value.
  * \param stmt The source statement to be substituted
diff --git a/python/tvm/build.py b/python/tvm/build.py
index fbed0a33f..bb03e8395 100644
--- a/python/tvm/build.py
+++ b/python/tvm/build.py
@@ -17,7 +17,8 @@ def build(sch,
           target,
           name="default_function",
           binds=None,
-          record_codes=None):
+          record_codes=None,
+          max_auto_unroll_step=8):
     """Build a function with arguments as signiture.
 
     Parameters
@@ -38,6 +39,9 @@ def build(sch,
         Dictionary that maps the binding of symbolic buffer to Tensor.
         By default, a new buffer is created for each tensor in the argument.
 
+    max_auto_unroll_step: int
+        Maximum step to perform automatic unrolling
+
     Returns
     -------
     f : Function, or pair of functions
@@ -64,6 +68,8 @@ def build(sch,
     bounds = schedule.InferBound(sch)
     stmt = schedule.ScheduleOps(sch, bounds)
     stmt = ir_pass.StorageFlatten(stmt, binds)
+    stmt = ir_pass.CanonicalSimplify(stmt)
+    stmt = ir_pass.UnrollLoop(stmt, max_auto_unroll_step)
     stmt = ir_pass.Simplify(stmt)
     fapi = ir_pass.MakeAPI(stmt, name, arg_list, len(arg_list))
     fsplits = ir_pass.SplitHostDevice(fapi)
diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc
index df79996e4..ff67ac7a8 100644
--- a/src/api/api_pass.cc
+++ b/src/api/api_pass.cc
@@ -59,6 +59,7 @@ TVM_REGISTER_API(_pass_PostOrderVisit)
 
 REGISTER_PASS1(ConvertSSA);
 REGISTER_PASS1(VerifySSA);
+REGISTER_PASS1(CanonicalSimplify);
 REGISTER_PASS4(Inline);
 REGISTER_PASS2(StorageFlatten);
 REGISTER_PASS2(UnrollLoop);
diff --git a/src/arithmetic/canonical.cc b/src/arithmetic/canonical.cc
new file mode 100644
index 000000000..2c9909455
--- /dev/null
+++ b/src/arithmetic/canonical.cc
@@ -0,0 +1,486 @@
+/*!
+ *  Copyright (c) 2017 by Contributors
+ * \file canonical.cc
+ * \brief Canonicalize simplification.
+ */
+#include <tvm/ir_mutator.h>
+#include "./int_set.h"
+#include "./canonical.h"
+#include "./compute_expr.h"
+
+namespace tvm {
+namespace arith {
+using namespace ir;
+
+// Canonical entry for communicative ops.
+struct ComExprEntry {
+  // the value of the expression.
+  Expr value;
+  // the level of the expression.
+  int level{0};
+  // The integer scale on value
+  int64_t scale{1};
+
+  ComExprEntry() {}
+  ComExprEntry(Expr value, int level)
+      : value(value), level(level) {}
+  inline bool operator<(const ComExprEntry& other) const {
+    if (level < other.level) return true;
+    if (level > other.level) return false;
+    return value.get() < other.value.get();
+  }
+};
+
+// canonical expression for communicative expression.
+struct ComExprNode {
+  // base constant value.
+  int64_t base{0};
+  // The values to be sumed.
+  std::vector<ComExprEntry> elem;
+};
+
+// canonical communicative expression
+struct ComExpr {
+ public:
+  // constructor
+  ComExpr() {}
+  explicit ComExpr(std::shared_ptr<ComExprNode> ptr) : ptr_(ptr) {}
+  // get member
+  ComExprNode* operator->() const {
+    return ptr_.get();
+  }
+  void reset() {
+    ptr_.reset();
+  }
+  bool defined() const {
+    return ptr_.get() != nullptr;
+  }
+  // comparator
+  bool operator<(const ComExpr& b) const {
+    const ComExpr& a = *this;
+    if (a->base < b->base) return true;
+    if (a->base > b->base) return false;
+    if (a->elem.size() < b->elem.size()) return true;
+    if (a->elem.size() > b->elem.size()) return false;
+    for (size_t i = 0; i < a->elem.size(); ++i) {
+      const ComExprEntry& ea = a->elem[i];
+      const ComExprEntry& eb = b->elem[i];
+      if (ea.level < eb.level) return true;
+      if (ea.level > eb.level) return false;
+      if (ea.value.get() < eb.value.get()) return true;
+      if (ea.value.get() > eb.value.get()) return false;
+      if (ea.scale < eb.scale) return true;
+      if (ea.scale > eb.scale) return false;
+    }
+    return false;
+  }
+  // equality
+  bool operator==(const ComExpr& b) const {
+    const ComExpr& a = *this;
+    if (a->base != b->base) return false;
+    if (a->elem.size() != b->elem.size()) return false;
+    for (size_t i = 0; i < a->elem.size(); ++i) {
+      const ComExprEntry& ea = a->elem[i];
+      const ComExprEntry& eb = b->elem[i];
+      if (ea.level != eb.level) return false;
+      if (ea.value.get() != eb.value.get()) return false;
+      if (ea.scale != eb.scale) return false;
+    }
+    return true;
+  }
+
+ private:
+  std::shared_ptr<ComExprNode> ptr_;
+};
+
+template<typename T>
+inline Expr Binary_(const T* op,
+                    const Expr& e,
+                    Expr a, Expr b) {
+  if (a.same_as(op->a) && b.same_as(op->b)) {
+    return e;
+  } else {
+    return T::make(a, b);
+  }
+}
+
+template<typename T>
+inline Expr Binary(
+    const T* op, const Expr& e, IRMutator* m) {
+  return Binary_(op, e, m->Mutate(op->a), m->Mutate(op->b));
+}
+
+// internal of canonical engine.
+class Canonical::Internal : public IRMutator {
+ public:
+  // stack entry.
+  struct StackEntry {
+    int max_level{0};
+    bool has_side_effect{false};
+  };
+  // aggressively canonicalized expression
+  struct CacheEntry {
+    // The canonical value of the expression.
+    Expr value;
+    // The level of the expression.
+    int max_level{0};
+    // whether the expression might have side effect.
+    bool has_side_effect{false};
+    // if not null, corresponds to to sum
+    ComExpr sum;
+    // reset the return entry.
+    void reset() {
+      sum.reset();
+    }
+    // as sum expr
+    ComExpr AsSum() const {
+      if (sum.defined()) return sum;
+      const int64_t *v1 = as_const_int(value);
+      const uint64_t *v2 = as_const_uint(value);
+      std::shared_ptr<ComExprNode> n = std::make_shared<ComExprNode>();
+      if (v1) {
+        n->base = *v1;
+      } else if (v2) {
+        CHECK_LE(*v2,
+               static_cast<uint64_t>(std::numeric_limits<int64_t>::max()));
+        n->base = static_cast<int64_t>(*v2);
+      } else {
+        n->elem.push_back(ComExprEntry(value, max_level));
+      }
+      return ComExpr(n);
+    }
+  };
+  // Set range and level of var.
+  void SetRange(Var v, Range r, int level) {
+    var_range_[v.get()] = IntSet::range(r);
+    var_level_[v.get()] = level;
+    var_rec_.push_back(v);
+  }
+  // functions
+  Stmt Mutate(Stmt stmt) final {
+    return IRMutator::Mutate(stmt);
+  }
+  Expr MutateExpr_(Expr expr) {
+    static const FMutateExpr& f = Internal::vtable_expr();
+    stack_.push_back(StackEntry());
+    expr =  (f.can_dispatch(expr) ?
+            f(expr, expr, this) : IRMutator::Mutate(expr));
+    // update result of parent automatically during pop
+    if (stack_.size() > 1) {
+      StackEntry& back = stack_[stack_.size() - 1];
+      StackEntry& prev = stack_[stack_.size() - 2];
+      prev.max_level = std::max(prev.max_level, back.max_level);
+      if (back.has_side_effect) prev.has_side_effect = true;
+    }
+    // copy result from stack
+    ret_entry_.has_side_effect = stack_.back().has_side_effect;
+    ret_entry_.max_level = stack_.back().max_level;
+    stack_.pop_back();
+    return expr;
+  }
+  // call produce to get a cache entry.
+  CacheEntry Produce(Expr expr) {
+    ret_entry_.reset();
+    ret_entry_.value = MutateExpr_(expr);
+    CacheEntry ret  = ret_entry_;
+    ret_entry_.reset();
+    return ret;
+  }
+  Expr Mutate(Expr expr) final {
+    ret_entry_.reset();
+    expr = MutateExpr_(expr);
+    ret_entry_.reset();
+    return expr;
+  }
+
+  // Check whether do special canonicalization.
+  bool EnableOpt(Type t) const {
+    return (t.lanes() == 1 && (t.is_int() || t.is_uint()));
+  }
+  // Add
+  Expr Mutate_(const Add* op, const Expr& e) {
+    if (!EnableOpt(op->type)) {
+      return Binary(op, e, this);
+    }
+    CacheEntry a = Produce(op->a);
+    CacheEntry b = Produce(op->b);
+    if (a.has_side_effect || b.has_side_effect) {
+      return Binary_(op, e, a.value, b.value);
+    }
+    return SumAdd(a, b, +1);
+  }
+  // Sub
+  Expr Mutate_(const Sub* op, const Expr& e) {
+    if (!EnableOpt(op->type)) {
+      return Binary(op, e, this);
+    }
+    CacheEntry a = Produce(op->a);
+    CacheEntry b = Produce(op->b);
+    if (a.has_side_effect || b.has_side_effect) {
+      return Binary_(op, e, a.value, b.value);
+    }
+    return SumAdd(a, b, -1);
+  }
+  // Mul
+  Expr Mutate_(const Mul* op, const Expr& e) {
+    if (!EnableOpt(op->type)) {
+      return Binary(op, e, this);
+    }
+    CacheEntry a = Produce(op->a);
+    CacheEntry b = Produce(op->b);
+    if (a.has_side_effect || b.has_side_effect) {
+      return Binary_(op, e, a.value, b.value);
+    }
+    if (is_const(a.value) && is_const(b.value)) {
+      return ComputeExpr<Mul>(a.value, b.value);
+    } else if (is_const(a.value)) {
+      return SumMulConst(b.AsSum(), a.value);
+    } else if (is_const(b.value)) {
+      return SumMulConst(a.AsSum(), b.value);
+    } else {
+      return Binary_(op, e, a.value, b.value);
+    }
+  }
+  // Variable
+  Expr Mutate_(const Variable* op, const Expr& e) final {
+    auto it = var_level_.find(op);
+    if (it != var_level_.end()) {
+      stack_.back().max_level = it->second;
+    }
+    return IRMutator::Mutate_(op, e);
+  }
+  // comparison
+  Expr Mutate_(const LT* op, const Expr& e) {
+    if (!EnableOpt(op->a.type())) {
+      return Binary(op, e, this);
+    }
+    CacheEntry a = Produce(op->a);
+    CacheEntry b = Produce(op->b);
+    if (a.has_side_effect || b.has_side_effect) {
+      return Binary_(op, e, a.value, b.value);
+    }
+    Expr b_sub_a = SumAdd(b, a, -1);
+    if (EvalSet(b_sub_a, var_range_).can_prove_positive()) {
+      return make_const(op->type, true);
+    } else {
+      return Binary_(op, e, a.value, b.value);
+    }
+  }
+  // Call
+  Expr Mutate_(const Call* op, const Expr& e) final {
+    if (!op->is_pure()) {
+      stack_.back().has_side_effect = true;
+    }
+    return IRMutator::Mutate_(op, e);
+  }
+  // For
+  Stmt Mutate_(const For* op, const Stmt& s) {
+    ++level_counter_;
+    Var loop_var(op->loop_var.node_);
+    this->SetRange(loop_var,
+                   Range::make_with_min_extent(op->min, op->extent),
+                   level_counter_);
+    Stmt stmt = IRMutator::Mutate_(op, s);
+    --level_counter_;
+    return stmt;
+  }
+  // AttrStmt
+  Stmt Mutate_(const AttrStmt* op, const Stmt& s) {
+    if (op->type_key == "thread_extent") {
+      ++level_counter_;
+      IterVar iv(op->node.node_);
+      CHECK_NE(iv->thread_tag.length(), 0U);
+      if (!var_level_.count(iv->var.get())) {
+        this->SetRange(iv->var,
+                       Range::make_with_min_extent(0, op->value),
+                       level_counter_);
+      }
+      Stmt stmt = IRMutator::Mutate_(op, s);
+      --level_counter_;
+      return stmt;
+    } else {
+      return IRMutator::Mutate_(op, s);
+    }
+  }
+  // The simplify statement.
+  static FMutateExpr& vtable_expr() {  // NOLINT(*)
+    static FMutateExpr inst; return inst;
+  }
+
+ private:
+  // return entry
+  CacheEntry ret_entry_;
+  // internal information stack
+  std::vector<StackEntry> stack_;
+  // cache sum
+  std::map<ComExpr, CacheEntry> cache_sum_;
+  // range of each var
+  std::unordered_map<const Variable*, IntSet> var_range_;
+  // level of each var
+  std::unordered_map<const Variable*, int> var_level_;
+  // record history vars, to avoid false positive.
+  std::vector<Var> var_rec_;
+  // level counter
+  int level_counter_{0};
+  // subroutine to do produce
+  Expr SumMulConst(ComExpr a, Expr v) {
+    int64_t value = 0;
+    const int64_t *v1 = as_const_int(v);
+    const uint64_t *v2 = as_const_uint(v);
+    CHECK(v1 || v2);
+    if (v1) {
+      value = *v1;
+    } else if (v2) {
+      CHECK_LE(*v2,
+               static_cast<uint64_t>(std::numeric_limits<int64_t>::max()));
+      value = static_cast<int64_t>(*v2);
+    }
+
+    if (value == 0) {
+      return make_zero(v.type());
+    }
+    std::shared_ptr<ComExprNode> vsum =
+        std::make_shared<ComExprNode>(*a.operator->());
+    vsum->base *= value;
+    for (auto& e : vsum->elem) {
+      e.scale *= value;
+    }
+    ret_entry_.max_level = stack_.back().max_level;
+    ret_entry_.has_side_effect = stack_.back().has_side_effect;
+    ret_entry_.sum = ComExpr(vsum);
+    auto it = cache_sum_.find(ret_entry_.sum);
+    if (it != cache_sum_.end()) {
+      ret_entry_ = it->second;
+    } else {
+      ret_entry_.value = Sum2Expr(ret_entry_.sum, v.type());
+      cache_sum_[ret_entry_.sum] = ret_entry_;
+    }
+    return ret_entry_.value;
+  }
+  // add two ComExpr together
+  ComExpr SumAdd_(const ComExpr& suma,
+                  const ComExpr& sumb,
+                  int bscale) {
+    std::shared_ptr<ComExprNode> n = std::make_shared<ComExprNode>();
+    n->base = suma->base + sumb->base;
+    // merge of suma and sumb;
+    size_t i = 0, j = 0;
+    while (i < suma->elem.size() && j < sumb->elem.size()) {
+      const auto& a = suma->elem[i];
+      const auto& b = sumb->elem[j];
+      if (a.value.same_as(b.value)) {
+        CHECK_EQ(a.level, b.level);
+        ComExprEntry e = a;
+        e.scale = a.scale + b.scale * bscale;
+        if (e.scale != 0) {
+          n->elem.push_back(e);
+        }
+        ++i; ++j;
+      } else if (a < b) {
+        n->elem.push_back(a);
+        ++i;
+      } else {
+        ComExprEntry e = b;
+        e.scale *= bscale;
+        n->elem.push_back(e);
+        ++j;
+      }
+    }
+    for (; i < suma->elem.size(); ++i) {
+      n->elem.push_back(suma->elem[i]);
+    }
+    for (; j < sumb->elem.size(); ++j) {
+      ComExprEntry e = sumb->elem[j];
+      e.scale *= bscale;
+      n->elem.push_back(e);
+    }
+    return ComExpr(n);
+  }
+  // subroutine to do produce
+  Expr SumAdd(CacheEntry a, CacheEntry b, int bscale) {
+    ret_entry_.sum = SumAdd_(a.AsSum(), b.AsSum(), bscale);
+    ret_entry_.max_level = stack_.back().max_level;
+    ret_entry_.has_side_effect = stack_.back().has_side_effect;
+    auto it = cache_sum_.find(ret_entry_.sum);
+    if (it != cache_sum_.end()) {
+      ret_entry_ = it->second;
+    } else {
+      ret_entry_.value = Sum2Expr(ret_entry_.sum, a.value.type());
+      cache_sum_[ret_entry_.sum] = ret_entry_;
+    }
+    ret_entry_.value = Sum2Expr(ret_entry_.sum, a.value.type());
+    cache_sum_[ret_entry_.sum] = ret_entry_;
+    return ret_entry_.value;
+  }
+  // convert sum to expr
+  Expr Sum2Expr(const ComExpr& com, Type t) {
+    Expr vsum;
+    if (com->base != 0) {
+      vsum = make_const(t, com->base);
+    }
+    for (const ComExprEntry& e : com->elem) {
+      if (e.scale > 0) {
+        Expr v = e.value;
+        if (e.scale != 1) {
+          v = Mul::make(v, make_const(t, e.scale));
+        }
+        if (vsum.defined()) {
+          vsum = Add::make(vsum, v);
+        } else {
+          vsum = v;
+        }
+      }
+    }
+    for (const ComExprEntry& e : com->elem) {
+      if (e.scale < 0) {
+        Expr v = e.value;
+        if (e.scale != -1) {
+          v = Mul::make(v, make_const(t, -e.scale));
+        }
+        if (vsum.defined()) {
+          vsum = Sub::make(vsum, v);
+        } else {
+          vsum = Sub::make(make_zero(t), v);
+        }
+      }
+    }
+    return vsum;
+  }
+};
+
+using CInternal = Canonical::Internal;
+
+#define DISPATCH_EXPR(OP)                                          \
+  set_dispatch<OP>([](const OP *op, const Expr& e, IRMutator* p) { \
+    return static_cast<CInternal*>(p)->Mutate_(op, e); })
+
+TVM_STATIC_IR_FUNCTOR(CInternal, vtable_expr)
+.DISPATCH_EXPR(Add)
+.DISPATCH_EXPR(Sub)
+.DISPATCH_EXPR(Mul)
+.DISPATCH_EXPR(LT);
+
+
+Canonical::Canonical()
+    : ptr_(std::make_shared<Internal>()) {}
+
+Expr Canonical::Simplify(Expr expr) {
+  return ptr_->Mutate(expr);
+}
+
+Stmt Canonical::Simplify(Stmt stmt) {
+  return ptr_->Mutate(stmt);
+}
+
+void Canonical::SetRange(Var v, Range r, int level) {
+  ptr_->SetRange(v, r, level);
+}
+}  // namespace arith
+
+namespace ir {
+Stmt CanonicalSimplify(Stmt stmt) {
+  return arith::Canonical().Simplify(stmt);
+}
+
+}  // namespace ir
+}  // namespace tvm
diff --git a/src/arithmetic/canonical.h b/src/arithmetic/canonical.h
new file mode 100644
index 000000000..174acc20a
--- /dev/null
+++ b/src/arithmetic/canonical.h
@@ -0,0 +1,55 @@
+/*!
+ *  Copyright (c) 2017 by Contributors
+ * \file canonical.h
+ * \brief Internal canonicalized expression simplification engine.
+ */
+#ifndef TVM_ARITHMETIC_CANONICAL_H_
+#define TVM_ARITHMETIC_CANONICAL_H_
+
+#include <tvm/expr.h>
+#include <tvm/schedule.h>
+
+namespace tvm {
+namespace arith {
+
+/*!
+ * \brief A stateful CanonicalEngine over SSA.
+ *
+ *  Simplify and CSE with canonicalization expressions.
+ *  Each call's result will get cached, so next call will
+ *  simply return the cached result.
+ */
+class Canonical {
+ public:
+  /*! \brief constructor */
+  Canonical();
+  /*!
+   * \brief simplify expression e.
+   * \param expr The expression to be simplified.
+   */
+  Expr Simplify(Expr expr);
+  /*!
+   * \brief simplify stmt.
+   * \param stmt The stmt to be simplified.
+   */
+  Stmt Simplify(Stmt expr);
+  /*!
+   * \brief Set range and level variable
+   * \param v The variable
+   * \param r The range of the variable, can be undefined.
+   * \param level The scope level of the variable,
+   *  affect the order of formula in communicative ops.
+   */
+  void SetRange(Var v, Range r, int level);
+
+  class Internal;
+ private:
+  // Internal pointer
+  std::shared_ptr<Internal> ptr_;
+};
+
+
+}  // namespace arith
+}  // namespace tvm
+
+#endif  // TVM_ARITHMETIC_CANONICAL_H_
diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc
index 04b40191d..d60504f2c 100644
--- a/src/arithmetic/int_set.cc
+++ b/src/arithmetic/int_set.cc
@@ -94,6 +94,11 @@ bool IntSet::is_single_point() const {
   return (s_int && s_int->i.is_single_point());
 }
 
+bool IntSet::can_prove_positive() const {
+  const IntervalSet* s_int = (*this).as<IntervalSet>();
+  return (s_int && is_positive_const(ir::Simplify(s_int->i.min)));
+}
+
 Expr IntSet::point_value() const {
   const IntervalSet* s_int = (*this).as<IntervalSet>();
   CHECK(s_int && s_int->i.is_single_point());
@@ -358,6 +363,9 @@ inline IntSet Combine(const IntSet& a, const IntSet &b) {
 // Evaluator to evalute the epxression.
 class IntSetEvaluator {
  public:
+  explicit IntSetEvaluator(const std::unordered_map<const Variable*, IntSet>& dom_map)
+      : dom_map(dom_map) {}
+
   inline IntSet Eval(Expr expr) {
     static const FType& f = vtable();
     if (f.can_dispatch(expr)) {
@@ -373,7 +381,7 @@ class IntSetEvaluator {
     static FType inst; return inst;
   }
 
-  std::unordered_map<const Variable*, IntSet> dom_map;
+  const std::unordered_map<const Variable*, IntSet>& dom_map;
 };
 
 inline IntSet ConstOp(const NodeRef&, const Expr& e, IntSetEvaluator*) {
@@ -424,21 +432,29 @@ TVM_STATIC_IR_FUNCTOR(IntSetEvaluator, vtable)
 .set_dispatch<And>(Binary<And>)
 .set_dispatch<Or>(Binary<Or>);
 
+
+IntSet EvalSet(Expr e,
+               const std::unordered_map<const Variable*, IntSet>& dom_map) {
+  return IntSetEvaluator(dom_map).Eval(e);
+}
+
 IntSet EvalSet(Expr e,
                const Map<IterVar, IntSet>& dom_map) {
-  IntSetEvaluator m;
+  std::unordered_map<const Variable*, IntSet> dmap;
   for (auto kv : dom_map) {
-    m.dom_map[kv.first->var.as<Variable>()] = kv.second;
+    dmap[kv.first->var.as<Variable>()] = kv.second;
   }
+  IntSetEvaluator m(dmap);
   return m.Eval(e);
 }
 
 IntSet EvalSet(Range r,
                const Map<IterVar, IntSet>& dom_map) {
-  IntSetEvaluator m;
+  std::unordered_map<const Variable*, IntSet> dmap;
   for (auto kv : dom_map) {
-    m.dom_map[kv.first->var.as<Variable>()] = kv.second;
+    dmap[kv.first->var.as<Variable>()] = kv.second;
   }
+  IntSetEvaluator m(dmap);
   IntSet min_set = m.Eval(r->min);
   IntSet ext_set = m.Eval(r->extent).cover_interval();
   const Interval& ei = ext_set.as<IntervalSet>()->i;
diff --git a/src/arithmetic/int_set.h b/src/arithmetic/int_set.h
index 80c2fae79..979d138af 100644
--- a/src/arithmetic/int_set.h
+++ b/src/arithmetic/int_set.h
@@ -44,6 +44,8 @@ class IntSet : public NodeRef {
   bool is_everything() const;
   /*! \return Whether the set is a single point */
   bool is_single_point() const;
+  /*! \return Whether the set is proved to be bigger than 0 */
+  bool can_prove_positive() const;
   /*!
    * \brief The single point value, call only if is_single_point is true
    * \return The point value.
@@ -88,6 +90,8 @@ struct IntSetNode : public Node {
  */
 IntSet EvalSet(Expr e,
                const Map<IterVar, IntSet>& dom_map);
+IntSet EvalSet(Expr e,
+               const std::unordered_map<const Variable*, IntSet>& dom_map);
 
 /*!
  * \brief Find an symbolic integer set that contains is union over
diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc
index b4957a3d5..9098200bd 100644
--- a/src/codegen/codegen_cuda.cc
+++ b/src/codegen/codegen_cuda.cc
@@ -45,7 +45,7 @@ MakeNVRTC(Array<LoweredFunc> funcs) {
   std::ostringstream os;
   os << "typedef int int32_t;\n"
      << "typedef unsigned unt32_t;\n";
-  bool output_ssa = true;
+  bool output_ssa = false;
   for (LoweredFunc f : funcs) {
     os << CodeGenCUDA().Compile(f, output_ssa);
     os << '\n';
diff --git a/src/codegen/codegen_opencl.cc b/src/codegen/codegen_opencl.cc
index 3d54a66a8..bafb56deb 100644
--- a/src/codegen/codegen_opencl.cc
+++ b/src/codegen/codegen_opencl.cc
@@ -57,7 +57,7 @@ MakeOpenCL(Array<LoweredFunc> funcs) {
   std::ostringstream os;
   os << "typedef int int32_t;\n"
      << "typedef unsigned unt32_t;\n";
-  bool output_ssa = true;
+  bool output_ssa = false;
   for (LoweredFunc f : funcs) {
     os << CodeGenOpenCL().Compile(f, output_ssa);
     os << '\n';
diff --git a/tests/python/integration/test_gemm.py b/tests/python/integration/test_gemm.py
index ac5c5c2c4..8b63d8c08 100644
--- a/tests/python/integration/test_gemm.py
+++ b/tests/python/integration/test_gemm.py
@@ -3,9 +3,9 @@ import numpy as np
 
 def test_gemm():
     # graph
-    nn = 1235
+    nn = 1024
     n = tvm.Var('n')
-    #n = tvm.convert(nn)
+    n = tvm.convert(nn)
     m = n
     l = n
     A = tvm.placeholder((n, l), name='A')
@@ -52,12 +52,14 @@ def test_gemm():
     _, xi = s[BB].split(s[BB].op.axis[0], outer=thread_y)
     _, xi = s[BB].split(xi, outer=thread_x)
 
+    max_auto_unroll_step = 0
     # lowering test
     s.normalize()
 
     def check_device(target):
         codes = []
-        f = tvm.build(s, [A, B, C], target, record_codes=codes)
+        f = tvm.build(s, [A, B, C], target, record_codes=codes,
+                      max_auto_unroll_step=max_auto_unroll_step)
         for c in codes[1:]:
             print(c)
         if target == "cuda":
diff --git a/tests/python/unittest/test_pass_simplify.py b/tests/python/unittest/test_pass_simplify.py
new file mode 100644
index 000000000..9002b9686
--- /dev/null
+++ b/tests/python/unittest/test_pass_simplify.py
@@ -0,0 +1,26 @@
+import tvm
+import numpy
+
+def test_simplify():
+    """Not yet working, mock design"""
+    dtype = 'int64'
+    n = tvm.Var('n')
+    Ab = tvm.Buffer((n, ), dtype)
+    i = tvm.Var('i')
+    j = tvm.Var('j')
+    # for i in 0 to n-1:
+    stmt = tvm.make.For(
+        i, 2, n, 0, 0,
+        tvm.make.For(j, 0, n, 0, 0,
+                     tvm.make.IfThenElse(
+                         tvm.make.LT(i + 2, n),
+                         tvm.make.Store(Ab.data,
+                                        tvm.make.Load(dtype, Ab.data, i + 4) + 1,
+                                        (j + 1) * 4 - 4 * j + i),
+                         None)))
+    print(stmt)
+    stmt = tvm.ir_pass.CanonicalSimplify(stmt)
+    print(stmt)
+
+if __name__ == "__main__":
+    test_simplify()
-- 
GitLab