diff --git a/src/arithmetic/detect_linear_equation.cc b/src/arithmetic/detect_linear_equation.cc index 63f582160312a7f013d37987f4714f25db25d5c3..642a866866d26963553e1d9cef7af36303d09404 100644 --- a/src/arithmetic/detect_linear_equation.cc +++ b/src/arithmetic/detect_linear_equation.cc @@ -123,25 +123,28 @@ class LinearEqDetector }; Array<Expr> DetectLinearEquation(const Expr& e, const Array<Var>& vars) { - CHECK_GE(vars.size(), 1U); Expr base = e; Array<Expr> coeff; - for (Var v : vars) { - LinearEqEntry ret; - if (!LinearEqDetector(v).Detect(base, &ret)) { - return Array<Expr>(); + if (0 == vars.size()) { + coeff.push_back(make_const(Int(32), 1)); + } else { + for (Var v : vars) { + LinearEqEntry ret; + if (!LinearEqDetector(v).Detect(base, &ret)) { + return Array<Expr>(); + } + coeff.push_back(ret.coeff); + base = std::move(ret.base); } - coeff.push_back(ret.coeff); - base = std::move(ret.base); - } - std::unordered_set<const Variable*> vset; - for (size_t i = vars.size(); i != 1; --i) { - vset.insert(vars[i - 1].get()); - // The previous coeff contains the variable - if (ExprUseVar(coeff[i - 2], vset)) { - return Array<Expr>(); + std::unordered_set<const Variable*> vset; + for (size_t i = vars.size(); i != 1; --i) { + vset.insert(vars[i - 1].get()); + // The previous coeff contains the variable + if (ExprUseVar(coeff[i - 2], vset)) { + return Array<Expr>(); + } } } coeff.push_back(base); diff --git a/src/pass/inject_copy_intrin.cc b/src/pass/inject_copy_intrin.cc index cafcddcb9dde088e8bf74ecf2f1cb599d8ef4e5b..ba44253a0cd5c10e4f31a5b4f55497c29763536e 100644 --- a/src/pass/inject_copy_intrin.cc +++ b/src/pass/inject_copy_intrin.cc @@ -40,6 +40,7 @@ class CopyIntrinInjector : public IRMutator { private: bool MatchCopyPattern(Stmt stmt, Stmt *out) { Stmt body = stmt; + bool is_single_point_copy = false; // strip the loops std::vector<const For*> loops; @@ -53,7 +54,10 @@ class CopyIntrinInjector : public IRMutator { const Select* select = store->value.as<Select>(); const Cast* cast = store->value.as<Cast>(); const Load* load = store->value.as<Load>(); - + if (0 == loops.size()) { + is_single_point_copy = true; + CHECK(select == nullptr); + } // for now only support true condition matching if (select != nullptr) { load = select->true_value.as<Load>(); @@ -74,13 +78,19 @@ class CopyIntrinInjector : public IRMutator { arith::DetectLinearEquation(load->index, loop_vars); if (load_strides.size() == 0 || store_strides.size() == 0) return false; Array<Expr> dst_shape; - for (const For* op : loops) { - dst_shape.push_back(op->extent); + auto loop_var_size = loop_vars.size(); + if (is_single_point_copy) { + loop_var_size = 1; + dst_shape.push_back(make_const(Int(32), 1)); + } else { + for (const For* op : loops) { + dst_shape.push_back(op->extent); + } } Array<Expr> src_shape = dst_shape; Array<Expr> pad_before, pad_after; Expr pad_value; - Expr src_elem_offset = load_strides[loop_vars.size()]; + Expr src_elem_offset = load_strides[loop_var_size]; if (select != nullptr) { Array<Expr> clip_bound = arith::DetectClipBound(select->condition, loop_vars); @@ -114,15 +124,15 @@ class CopyIntrinInjector : public IRMutator { src_elem_offset = Simplify(src_elem_offset); } CHECK_EQ(load_strides.size(), store_strides.size()); - CHECK_EQ(load_strides.size(), loop_vars.size() + 1); - Array<Expr> src_strides(load_strides.begin(), load_strides.begin() + loop_vars.size()); - Array<Expr> dst_strides(store_strides.begin(), store_strides.begin() + loop_vars.size()); + CHECK_EQ(load_strides.size(), loop_var_size + 1); + Array<Expr> src_strides(load_strides.begin(), load_strides.begin() + loop_var_size); + Array<Expr> dst_strides(store_strides.begin(), store_strides.begin() + loop_var_size); Buffer dst = BufferNode::make( Var(store->buffer_var.node_), store->value.type(), dst_shape, dst_strides, - store_strides[loop_vars.size()], + store_strides[loop_var_size], store->buffer_var->name_hint, GetStorageScope(store->buffer_var.get()), 0, 0); diff --git a/tests/python/unittest/test_pass_inject_copy_intrin.py b/tests/python/unittest/test_pass_inject_copy_intrin.py index 08477895b32201bfeebdcb7ff44116abb7a01ae7..c6ed19d65b69ad7ddb768664ae93e1e60ca0dd33 100644 --- a/tests/python/unittest/test_pass_inject_copy_intrin.py +++ b/tests/python/unittest/test_pass_inject_copy_intrin.py @@ -44,6 +44,25 @@ def test_copy_pad(): return tvm.make.Evaluate(0) stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb) +def test_single_point_test(): + A = tvm.placeholder((1,), name='A') + B = tvm.compute((1,), lambda i: + A[i], name='B') + s = tvm.create_schedule(B.op) + s[B].pragma(B.op.axis[0], "memcpy") + bounds = tvm.schedule.InferBound(s) + stmt = tvm.schedule.ScheduleOps(s, bounds) + Ab = tvm.decl_buffer(A.shape, A.dtype, name='A') + Bb = tvm.decl_buffer(B.shape, B.dtype, name='B') + stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64) + def cb(src, dst, pad_before, pad_after, pad_value): + assert tvm.ir_pass.Simplify(src.elem_offset).value == 0 + assert tvm.ir_pass.Simplify(dst.elem_offset).value == 0 + assert tvm.ir_pass.Simplify(src.strides[0]).value == 1 + assert tvm.ir_pass.Simplify(dst.strides[0]).value == 1 + return tvm.make.Evaluate(0) + stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb) + def assert_expr_equal(a, b): assert tvm.ir_pass.Simplify(a - b).value == 0 @@ -80,3 +99,4 @@ if __name__ == "__main__": test_copy2d() test_copy_pad() test_copy_pad_split() + test_single_point_test()