From f9281241c221deb414f55bf73d237246262e0347 Mon Sep 17 00:00:00 2001 From: Zhen Zhang <7168454+izgzhen@users.noreply.github.com> Date: Thu, 18 Oct 2018 22:18:00 -0700 Subject: [PATCH] Check iter_type in vectorize (#1921) --- src/schedule/schedule_lang.cc | 7 +++++++ tests/python/unittest/test_lang_schedule.py | 9 +++++++++ 2 files changed, 16 insertions(+) diff --git a/src/schedule/schedule_lang.cc b/src/schedule/schedule_lang.cc index d503e9788..29265f2e9 100644 --- a/src/schedule/schedule_lang.cc +++ b/src/schedule/schedule_lang.cc @@ -352,6 +352,13 @@ inline void SetAttrIterType(StageNode* self, IterVar var, IterVarType iter_type) } Stage& Stage::vectorize(IterVar var) { // NOLINT(*) + CHECK(var->iter_type == kDataPar || + var->iter_type == kOpaque || + var->iter_type == kUnrolled || + var->iter_type == kVectorized || + var->iter_type == kTensorized || + var->iter_type == kParallelized) + << "Cannot vectorize on " << IterVarType2String(var->iter_type); SetAttrIterType(operator->(), var, kVectorized); return *this; } diff --git a/tests/python/unittest/test_lang_schedule.py b/tests/python/unittest/test_lang_schedule.py index 1eb42f3f0..a00785dea 100644 --- a/tests/python/unittest/test_lang_schedule.py +++ b/tests/python/unittest/test_lang_schedule.py @@ -1,3 +1,4 @@ +from nose.tools import raises import tvm import pickle as pkl @@ -112,6 +113,13 @@ def test_vectorize(): assert s[T].iter_var_attrs[xi].iter_type == UNROLL assert s[T].iter_var_attrs[yi].iter_type == VECTORIZE +@raises(Exception) +def test_vectorize_commreduce(): + V = tvm.placeholder((128,), name='V') + ax = tvm.reduce_axis((0, 128), name='ax') + O = tvm.compute((1,), lambda _: tvm.sum(V[ax], axis=[ax])) + s = tvm.create_schedule(O.op) + s[O].vectorize(ax) # should throw here def test_pragma(): m = 100 @@ -197,3 +205,4 @@ if __name__ == "__main__": test_split() test_fuse() test_vectorize() + test_vectorize_commreduce() -- GitLab