From f9281241c221deb414f55bf73d237246262e0347 Mon Sep 17 00:00:00 2001
From: Zhen Zhang <7168454+izgzhen@users.noreply.github.com>
Date: Thu, 18 Oct 2018 22:18:00 -0700
Subject: [PATCH] Check iter_type in vectorize (#1921)

---
 src/schedule/schedule_lang.cc               | 7 +++++++
 tests/python/unittest/test_lang_schedule.py | 9 +++++++++
 2 files changed, 16 insertions(+)

diff --git a/src/schedule/schedule_lang.cc b/src/schedule/schedule_lang.cc
index d503e9788..29265f2e9 100644
--- a/src/schedule/schedule_lang.cc
+++ b/src/schedule/schedule_lang.cc
@@ -352,6 +352,13 @@ inline void SetAttrIterType(StageNode* self, IterVar var, IterVarType iter_type)
 }
 
 Stage& Stage::vectorize(IterVar var) {   // NOLINT(*)
+  CHECK(var->iter_type == kDataPar ||
+        var->iter_type == kOpaque ||
+        var->iter_type == kUnrolled ||
+        var->iter_type == kVectorized ||
+        var->iter_type == kTensorized ||
+        var->iter_type == kParallelized)
+      << "Cannot vectorize on " << IterVarType2String(var->iter_type);
   SetAttrIterType(operator->(), var, kVectorized);
   return *this;
 }
diff --git a/tests/python/unittest/test_lang_schedule.py b/tests/python/unittest/test_lang_schedule.py
index 1eb42f3f0..a00785dea 100644
--- a/tests/python/unittest/test_lang_schedule.py
+++ b/tests/python/unittest/test_lang_schedule.py
@@ -1,3 +1,4 @@
+from nose.tools import raises
 import tvm
 import pickle as pkl
 
@@ -112,6 +113,13 @@ def test_vectorize():
     assert s[T].iter_var_attrs[xi].iter_type == UNROLL
     assert s[T].iter_var_attrs[yi].iter_type == VECTORIZE
 
+@raises(Exception)
+def test_vectorize_commreduce():
+    V = tvm.placeholder((128,), name='V')
+    ax = tvm.reduce_axis((0, 128), name='ax')
+    O = tvm.compute((1,), lambda _: tvm.sum(V[ax], axis=[ax]))
+    s = tvm.create_schedule(O.op)
+    s[O].vectorize(ax) # should throw here
 
 def test_pragma():
     m = 100
@@ -197,3 +205,4 @@ if __name__ == "__main__":
     test_split()
     test_fuse()
     test_vectorize()
+    test_vectorize_commreduce()
-- 
GitLab