diff --git a/src/schedule/schedule_lang.cc b/src/schedule/schedule_lang.cc index d503e978887e9991debf7a61f376ce79a8ebee9a..29265f2e94b85b735e65fdac997b409212d6a448 100644 --- a/src/schedule/schedule_lang.cc +++ b/src/schedule/schedule_lang.cc @@ -352,6 +352,13 @@ inline void SetAttrIterType(StageNode* self, IterVar var, IterVarType iter_type) } Stage& Stage::vectorize(IterVar var) { // NOLINT(*) + CHECK(var->iter_type == kDataPar || + var->iter_type == kOpaque || + var->iter_type == kUnrolled || + var->iter_type == kVectorized || + var->iter_type == kTensorized || + var->iter_type == kParallelized) + << "Cannot vectorize on " << IterVarType2String(var->iter_type); SetAttrIterType(operator->(), var, kVectorized); return *this; } diff --git a/tests/python/unittest/test_lang_schedule.py b/tests/python/unittest/test_lang_schedule.py index 1eb42f3f0bca6da05bd6458e6cf79573aa5184f2..a00785dea7afc4df7fccc02fa0feb9e2bc1ad0f6 100644 --- a/tests/python/unittest/test_lang_schedule.py +++ b/tests/python/unittest/test_lang_schedule.py @@ -1,3 +1,4 @@ +from nose.tools import raises import tvm import pickle as pkl @@ -112,6 +113,13 @@ def test_vectorize(): assert s[T].iter_var_attrs[xi].iter_type == UNROLL assert s[T].iter_var_attrs[yi].iter_type == VECTORIZE +@raises(Exception) +def test_vectorize_commreduce(): + V = tvm.placeholder((128,), name='V') + ax = tvm.reduce_axis((0, 128), name='ax') + O = tvm.compute((1,), lambda _: tvm.sum(V[ax], axis=[ax])) + s = tvm.create_schedule(O.op) + s[O].vectorize(ax) # should throw here def test_pragma(): m = 100 @@ -197,3 +205,4 @@ if __name__ == "__main__": test_split() test_fuse() test_vectorize() + test_vectorize_commreduce()