diff --git a/include/tvm/arithmetic.h b/include/tvm/arithmetic.h index 6f7afa0b75d4e10cb2f501fd90922514859512bc..c4f338f1cd470dbb9d55338b83ff18205eaf16e1 100644 --- a/include/tvm/arithmetic.h +++ b/include/tvm/arithmetic.h @@ -159,8 +159,8 @@ struct IntSetNode : public Node { }; /*! - * \brief Detect if e can be rewritten as e = sum_{i=0}^n var[i] * coeff[i] + coeff[n] - * Where coeff and base are invariant of var. + * \brief Detect if e can be rewritten as e = sum_{i=0}^{n-1} var[i] * coeff[i] + coeff[n] + * Where coeff[i] and base are invariant of var[j] for all i and j. * * \param e The expression to be detected. * \param vars List of variables to be used in detection. diff --git a/include/tvm/buffer.h b/include/tvm/buffer.h index 25f5dde86b2e7fd2392b2fd726de27bdbb18d3d3..610532e261a32850e4ba57f00caa0d505abcbb69 100644 --- a/include/tvm/buffer.h +++ b/include/tvm/buffer.h @@ -128,7 +128,7 @@ class BufferNode : public Node { Type dtype, Array<Expr> shape, Array<Expr> strides, - Expr byte_offset, + Expr elem_offset, std::string name, std::string scope, int data_alignment, diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index 4b499465240564ec1360f040cdab725515369232..6b95bd26865244d48dd77a256f4414c46ff95fbe 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -239,6 +239,24 @@ Stmt InjectPrefetch(Stmt stmt); */ Stmt InjectDoubleBuffer(Stmt stmt, int split_loop); +/*! + * \brief Inject copy intrinsics with optional pad. + * + * \param stmt The statement to be transformed. + * \param pragma_key The pragma key for hint of copy. + * \param fintrin The function with signature + * + * Stmt fintrin(Buffer src, + * Buffer dst, + * Array<Expr> pad_before, + * Array<Expr> pad_after, + * Expr pad_value) + * \return Transformed stmt. + */ +Stmt InjectCopyIntrin(Stmt stmt, + const std::string& pragma_key, + const runtime::PackedFunc& fintrin); + /*! * \brief Rewrite storage allocation pattern. * Moves the allocation to outer most possible scope. 
diff --git a/include/tvm/packed_func_ext.h b/include/tvm/packed_func_ext.h
index 39d94155a9900479f751c7a6b59c7aa16f41a0cc..1f66232baacc744e1b80ffc3374e9d064f373552 100644
--- a/include/tvm/packed_func_ext.h
+++ b/include/tvm/packed_func_ext.h
@@ -171,8 +171,13 @@ inline TNodeRef TVMRetValue::AsNodeRef() const {
 }
 
 inline void TVMArgsSetter::operator()(size_t i, const NodeRef& other) const {  // NOLINT(*)
-  values_[i].v_handle = const_cast<std::shared_ptr<Node>*>(&(other.node_));
-  type_codes_[i] = kNodeHandle;
+  if (other.defined()) {
+    values_[i].v_handle = const_cast<std::shared_ptr<Node>*>(&(other.node_));
+    type_codes_[i] = kNodeHandle;
+  } else {
+    values_[i].v_handle = nullptr;
+    type_codes_[i] = kNull;
+  }
 }
 
 // type related stuffs
diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc
index 644ace9d9855ef77d230adf4590648e194b44ef3..2dacb32e54f70172795e44b7101d387709975451 100644
--- a/src/api/api_pass.cc
+++ b/src/api/api_pass.cc
@@ -92,6 +92,7 @@ REGISTER_PASS3(StorageFlatten);
 REGISTER_PASS4(IRTransform);
 REGISTER_PASS1(VectorizeLoop);
 REGISTER_PASS4(UnrollLoop);
+REGISTER_PASS3(InjectCopyIntrin);
 REGISTER_PASS2(ThreadSync);
 REGISTER_PASS5(MakeAPI);
 REGISTER_PASS2(BindDeviceType);
diff --git a/src/arithmetic/canonical.cc b/src/arithmetic/canonical.cc
index c904b92d8ccb06a642e194e0d69aae6d1cc49b7d..518e7b3587b746cf090cb7ee40ec554927b4b053 100644
--- a/src/arithmetic/canonical.cc
+++ b/src/arithmetic/canonical.cc
@@ -307,7 +307,13 @@ class Canonical::Internal : public IRMutator {
     if (!op->is_pure()) {
       stack_.back().has_side_effect = true;
     }
-    return IRMutator::Mutate_(op, e);
+    Expr expr = IRMutator::Mutate_(op, e);
+    op = expr.as<Call>();
+    if (op->is_intrinsic(Call::likely) && is_const(op->args[0])) {
+      return op->args[0];
+    } else {
+      return expr;
+    }
   }
   // For
   Stmt Mutate_(const For* op, const Stmt& s) {
@@ -320,6 +326,13 @@ class Canonical::Internal : public IRMutator {
     --level_counter_;
     return stmt;
   }
+  // IfThenElse
+  Stmt Mutate_(const IfThenElse* op, const Stmt& s) {
+    Stmt stmt = IRMutator::Mutate_(op, s);
+    op = stmt.as<IfThenElse>();
+    if (is_one(op->condition)) return op->then_case;
+    return stmt;
+  }
   // AttrStmt
   Stmt Mutate_(const AttrStmt* op, const Stmt& s) {
     if (op->attr_key == attr::thread_extent ||
diff --git a/src/pass/inject_copy_intrin.cc b/src/pass/inject_copy_intrin.cc
new file mode 100644
index 0000000000000000000000000000000000000000..a7151ed0aeb5a449440b6f9912a1faff16735ec2
--- /dev/null
+++ b/src/pass/inject_copy_intrin.cc
@@ -0,0 +1,185 @@
+/*!
+ *  Copyright (c) 2017 by Contributors
+ * \brief Replace certain copy with copy intrinsics.
+ * \file inject_copy_intrin.cc
+ */
+#include <tvm/arithmetic.h>
+#include <tvm/ir.h>
+#include <tvm/packed_func_ext.h>
+#include <tvm/ir_mutator.h>
+#include <tvm/ir_pass.h>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+namespace tvm {
+namespace ir {
+
+using runtime::PackedFunc;
+
+/*!
+ * \brief Replace the body of every matching copy pragma by the
+ *  statement returned from a user supplied lowering function.
+ */
+class CopyIntrinInjector : public IRMutator {
+ public:
+  CopyIntrinInjector(const std::string& pragma_key,
+                     const PackedFunc& flower_copy_fromto)
+      : pragma_key_(pragma_key),
+        flower_copy_fromto_(flower_copy_fromto) {
+  }
+
+  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
+    if (op->attr_key == attr::storage_scope) {
+      // Record the scope so it can be attached to the Buffer objects
+      // that are handed to the lowering function.
+      const Variable* buf = op->node.as<Variable>();
+      const StringImm* scope = op->value.as<StringImm>();
+      CHECK(scope != nullptr)
+          << "storage_scope attribute expects a string value";
+      storage_scope_[buf] = scope->value;
+    } else if (op->attr_key == ir::attr::pragma_scope) {
+      // A pragma whose value is not a string can never equal pragma_key_.
+      const StringImm* pname = op->value.as<StringImm>();
+      if (pname != nullptr && pname->value == pragma_key_) {
+        Stmt ret;
+        CHECK(MatchCopyPattern(op->body, &ret))
+            << "Cannot match copy pattern of " << op->body;
+        return ret;
+      }
+    }
+    return IRMutator::Mutate_(op, s);
+  }
+
+ private:
+  /*!
+   * \brief Match stmt against a copy loop nest with optional padding.
+   * \param stmt The body found under the copy pragma.
+   * \param out The lowered statement, written only on success.
+   * \return Whether stmt matches the copy pattern.
+   */
+  bool MatchCopyPattern(Stmt stmt, Stmt *out) {
+    Stmt body = stmt;
+
+    // strip the loops
+    std::vector<const For*> loops;
+    while (const For* op = body.as<For>()) {
+      if (!is_zero(op->min)) return false;
+      loops.push_back(op);
+      body = op->body;
+    }
+    const Store* store = body.as<Store>();
+    if (store == nullptr) return false;
+    const Select* select = store->value.as<Select>();
+    const Load* load = store->value.as<Load>();
+
+    // for now only support true condition matching
+    if (select != nullptr) {
+      load = select->true_value.as<Load>();
+    }
+    if (load == nullptr) return false;
+    if (load->type.lanes() != 1) return false;
+    Array<Var> loop_vars;
+    for (const For* op : loops) {
+      loop_vars.push_back(Var(op->loop_var.node_));
+    }
+    const size_t ndim = loop_vars.size();
+    Array<Expr> store_strides =
+        arith::DetectLinearEquation(store->index, loop_vars);
+    Array<Expr> load_strides =
+        arith::DetectLinearEquation(load->index, loop_vars);
+    if (load_strides.size() == 0 || store_strides.size() == 0) return false;
+    // Validate the coefficient arrays before indexing into them: one
+    // stride per loop variable followed by the constant offset.
+    CHECK_EQ(load_strides.size(), store_strides.size());
+    CHECK_EQ(load_strides.size(), ndim + 1);
+    Array<Expr> dst_shape;
+    for (const For* op : loops) {
+      dst_shape.push_back(op->extent);
+    }
+    Array<Expr> src_shape = dst_shape;
+    Array<Expr> pad_before, pad_after;
+    Expr pad_value;
+    Expr src_elem_offset = load_strides[ndim];
+    if (select != nullptr) {
+      Array<Expr> clip_bound =
+          arith::DetectClipBound(select->condition, loop_vars);
+      pad_value = select->false_value;
+      if (clip_bound.size() == 0) return false;
+      CHECK_EQ(src_shape.size(), ndim);
+      CHECK_EQ(clip_bound.size(), ndim * 2);
+      for (size_t i = 0; i < ndim; ++i) {
+        Expr min_value = clip_bound[2 * i];
+        Expr max_value = clip_bound[2 * i + 1];
+        Type t = loop_vars[i].type();
+        Expr svalue = src_shape[i];
+        if (min_value.defined()) {
+          Expr pbefore = Simplify(Max::make(min_value, make_zero(t)));
+          src_elem_offset = src_elem_offset + pbefore * load_strides[i];
+          svalue = svalue - pbefore;
+          pad_before.push_back(pbefore);
+        } else {
+          pad_before.push_back(make_zero(t));
+        }
+        if (max_value.defined()) {
+          Expr pafter = Simplify(Max::make(loops[i]->extent - max_value - make_const(t, 1),
+                                           make_zero(t)));
+          svalue = svalue - pafter;
+          pad_after.push_back(pafter);
+        } else {
+          pad_after.push_back(make_zero(t));
+        }
+        src_shape.Set(i, Simplify(svalue));
+      }
+      src_elem_offset = Simplify(src_elem_offset);
+    }
+    Array<Expr> src_strides(load_strides.begin(), load_strides.begin() + ndim);
+    Array<Expr> dst_strides(store_strides.begin(), store_strides.begin() + ndim);
+    Buffer dst = BufferNode::make(
+        Var(store->buffer_var.node_),
+        load->type,
+        dst_shape,
+        dst_strides,
+        store_strides[ndim],
+        store->buffer_var->name_hint,
+        GetStorageScope(store->buffer_var.get()),
+        0, 0);
+    Buffer src = BufferNode::make(
+        Var(load->buffer_var.node_),
+        load->type,
+        src_shape,
+        src_strides,
+        src_elem_offset,
+        load->buffer_var->name_hint,
+        GetStorageScope(load->buffer_var.get()),
+        0, 0);
+    *out = flower_copy_fromto_(src, dst, pad_before, pad_after, pad_value);
+    CHECK(out->defined()) << "flower function did not return correct stmt";
+    return true;
+  }
+  // Get storage scope
+  std::string GetStorageScope(const Variable* var) const {
+    auto it = storage_scope_.find(var);
+    if (it != storage_scope_.end()) {
+      return it->second;
+    } else {
+      return "";
+    }
+  }
+  // pragma key, held by value so it cannot dangle.
+  const std::string pragma_key_;
+  // function to lower copy intrinsics, held by value so it cannot dangle.
+  const PackedFunc flower_copy_fromto_;
+  // Storage scope
+  std::unordered_map<const Variable*, std::string> storage_scope_;
+};
+
+Stmt InjectCopyIntrin(Stmt stmt,
+                      const std::string& pragma_key,
+                      const PackedFunc& flower_copy_fromto) {
+  return CopyIntrinInjector(pragma_key, flower_copy_fromto)
+      .Mutate(stmt);
+}
+
+}  // namespace ir
+}  // namespace tvm
diff --git a/tests/python/unittest/test_pass_inject_copy_intrin.py b/tests/python/unittest/test_pass_inject_copy_intrin.py
new file mode 100644
index 0000000000000000000000000000000000000000..08477895b32201bfeebdcb7ff44116abb7a01ae7
--- /dev/null
+++ b/tests/python/unittest/test_pass_inject_copy_intrin.py
@@ -0,0 +1,82 @@
+import tvm
+
+def test_copy2d():
+    m = tvm.var('m')
+    l = tvm.var('l')
+    A = tvm.placeholder((m, l), name='A')
+    B = tvm.compute((m, l), lambda i, j: A[i, j], name='B')
+    s = tvm.create_schedule(B.op)
+    s[B].pragma(B.op.axis[0], "memcpy")
+    bounds = tvm.schedule.InferBound(s)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
+    Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
+    Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
+    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
+    def cb(src, dst, pad_before, pad_after, pad_value):
+        assert dst.strides[0] == l
+        assert dst.strides[1].value == 1
+        assert src.strides[0] == l
+        assert tuple(src.shape) == (m, l)
+        return tvm.make.Evaluate(0)
+    stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
+
+def test_copy_pad():
+    m = tvm.var('m')
+    l = tvm.var('l')
+    A = tvm.placeholder((m, l), name='A')
+    B = tvm.compute((m + 2, l), lambda i, j:
+                    tvm.select(tvm.all(i >= 1, i < m + 1),
+                               A[i - 1, j], 1.0), name='B')
+    s = tvm.create_schedule(B.op)
+    s[B].pragma(B.op.axis[0], "memcpy")
+    bounds = tvm.schedule.InferBound(s)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
+    Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
+    Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
+    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
+    def cb(src, dst, pad_before, pad_after, pad_value):
+        assert tvm.ir_pass.Simplify(src.elem_offset).value == 0
+        assert pad_before[0].value == 1
+        assert pad_before[1].value == 0
+        assert pad_after[0].value == 1
+        assert pad_after[1].value == 0
+        assert pad_value.value == 1.0
+        return tvm.make.Evaluate(0)
+    stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
+
+def assert_expr_equal(a, b):
+    assert tvm.ir_pass.Simplify(a - b).value == 0
+
+def test_copy_pad_split():
+    m = 4 * 3
+    A = tvm.placeholder((m, ), name="A")
+    Apad = tvm.compute((m + 2,), lambda i:
+                       tvm.select(tvm.all(i >= 1, i <= m),
+                                  A[i - 1], 0.0), "Apad")
+    B = tvm.compute((m,), lambda i: Apad[i] + Apad[i + 1] + Apad[i + 2])
+    s = tvm.create_schedule(B.op)
+    xo, xi = s[B].split(B.op.axis[0], factor=4)
+    s[Apad].compute_at(s[B], xo)
+    s[Apad].pragma(s[Apad].op.axis[0], "memcpy")
+    bounds = tvm.schedule.InferBound(s)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
+    Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
+    Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
+    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
+    stmt = tvm.ir_pass.CanonicalSimplify(stmt)
+    def cb(src, dst, pad_before, pad_after, pad_value):
+        assert(dst.elem_offset.value == 0)
+        assert_expr_equal(src.elem_offset, tvm.max(xo * 4, 1) - 1)
+        rpad_before = tvm.max(1 - xo * 4, 0)
+        rpad_after = tvm.max(xo * 4 - 7, 0)
+        assert_expr_equal(pad_before[0], rpad_before)
+        assert_expr_equal(pad_after[0], rpad_after)
+        assert_expr_equal(src.shape[0], 6 - rpad_before - rpad_after)
+        return tvm.make.Evaluate(0)
+    stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
+
+
+if __name__ == "__main__":
+    test_copy2d()
+    test_copy_pad()
+    test_copy_pad_split()