From 10d9da486518a0128f17156caf023c5339b96087 Mon Sep 17 00:00:00 2001
From: Salem Derisavi <33945117+derisavi-huawei@users.noreply.github.com>
Date: Thu, 30 Nov 2017 18:44:08 -0500
Subject: [PATCH] Consider variable range information during simplification of
 tensorize expressions (#674)

---
 src/arithmetic/canonical.cc |  2 ++
 src/op/tensorize.cc         | 22 +++++++++++++++-------
 2 files changed, 17 insertions(+), 7 deletions(-)

diff --git a/src/arithmetic/canonical.cc b/src/arithmetic/canonical.cc
index e7f9da1b4..24369db02 100644
--- a/src/arithmetic/canonical.cc
+++ b/src/arithmetic/canonical.cc
@@ -5,6 +5,7 @@
  */
 #include <tvm/ir_mutator.h>
 #include <tvm/arithmetic.h>
+#include <tvm/ir_pass.h>
 #include "./canonical.h"
 #include "./compute_expr.h"
 #include "arithmetic/Simplify.h"
@@ -612,6 +613,7 @@ void Canonical::SetRange(Var v, Range r, int level) {
 }  // namespace arith
 
 namespace ir {
+
 Stmt CanonicalSimplify(Stmt stmt, Map<Var, Range> vrange) {
   return arith::Canonical(vrange).Simplify(stmt);
 }
diff --git a/src/op/tensorize.cc b/src/op/tensorize.cc
index 9715fcbab..243b7931d 100644
--- a/src/op/tensorize.cc
+++ b/src/op/tensorize.cc
@@ -187,7 +187,8 @@ class TensorIntrinMatcher final : public IRMutator {
             const Stage& stage,
             const std::unordered_map<IterVar, Range>& out_dom,
             const std::unordered_map<Tensor, Array<Range> >& in_region,
-            const TensorIntrin& intrin) {
+            const TensorIntrin& intrin,
+            Map<Var, Range>* compute_intrin_iter_space) {
     CHECK(self == stage->op.get());
     // input remap.
     Array<Tensor> inputs = self->InputTensors();
@@ -232,6 +233,7 @@ class TensorIntrinMatcher final : public IRMutator {
       Range r = out_dom.at(iv);
       var_remap_[iv->var.get()] = target_iv->var + r->min;
       axis_remap_[iv] = target_iv;
+      compute_intrin_iter_space->Set(target_iv->var, target_iv->dom);
     }
     // Remap reduction axis
     CHECK_GE(self->reduce_axis.size(), intrin_compute->reduce_axis.size())
@@ -251,6 +253,7 @@ class TensorIntrinMatcher final : public IRMutator {
       Range r = out_dom.at(iv);
       var_remap_[iv->var.get()] = target_iv->var + r->min;
       axis_remap_[iv] = target_iv;
+      compute_intrin_iter_space->Set(target_iv->var, target_iv->dom);
     }
   }
 
@@ -275,9 +278,10 @@ Array<Expr> MatchTensorizeBody(
     const Stage& stage,
     const std::unordered_map<IterVar, Range>& out_dom,
     const std::unordered_map<Tensor, Array<Range> >& in_region,
-    const TensorIntrin& intrin) {
+    const TensorIntrin& intrin,
+    Map<Var, Range>* compute_intrin_iter_space) {
   TensorIntrinMatcher matcher;
-  matcher.Init(self, stage, out_dom, in_region, intrin);
+  matcher.Init(self, stage, out_dom, in_region, intrin, compute_intrin_iter_space);
   Array<Expr> ret;
   for (Expr expr : self->body) {
     ret.push_back(matcher.Mutate(expr));
@@ -291,14 +295,16 @@ void VerifyTensorizeBody(
     const std::unordered_map<IterVar, Range>& out_dom,
     const std::unordered_map<Tensor, Array<Range> >& in_region,
     const TensorIntrin& intrin) {
-  Array<Expr> body = MatchTensorizeBody(self, stage, out_dom, in_region, intrin);
+  Map<Var, Range> compute_intrin_iter_space;
+  Array<Expr> body = MatchTensorizeBody(self, stage, out_dom, in_region, intrin,
+                                        &compute_intrin_iter_space);
   const ComputeOpNode* intrin_compute = intrin->op.as<ComputeOpNode>();
   CHECK(intrin_compute) << "Only support compute intrinsic for now";
   CHECK_EQ(body.size(), intrin_compute->body.size())
       << "Tensorize failed: body size mismatch";
   for (size_t i = 0; i < body.size(); ++i) {
-    Expr lhs = CanonicalSimplify(body[i]);
-    Expr rhs = CanonicalSimplify(intrin_compute->body[i]);
+    Expr lhs = CanonicalSimplify(body[i], compute_intrin_iter_space);
+    Expr rhs = CanonicalSimplify(intrin_compute->body[i], compute_intrin_iter_space);
     if (lhs.type() != rhs.type()) {
       LOG(FATAL)
           << "Failed to match the data type with TensorIntrin "
@@ -459,11 +465,13 @@ TVM_REGISTER_API("test.op.MatchTensorizeBody")
     Map<IterVar, Range> out_dom = args[1];
     Map<Tensor, Array<Range> > in_region = args[2];
     TensorIntrin intrin = args[3];
+    Map<Var, Range> vrange;
     CHECK(stage->op.as<ComputeOpNode>());
     *ret = MatchTensorizeBody(stage->op.as<ComputeOpNode>(),
                               stage,
                               as_unordered_map(out_dom),
                               as_unordered_map(in_region),
-                              intrin);
+                              intrin,
+                              &vrange);
   });
 }  // namespace tvm
-- 
GitLab