diff --git a/include/tvm/build_module.h b/include/tvm/build_module.h
index 9c54ab7b45fbd2c602181b9e21d6e2c3f4910715..e9d6123b5e068fda353b8af22ead90477c6b6eb8 100644
--- a/include/tvm/build_module.h
+++ b/include/tvm/build_module.h
@@ -32,6 +32,11 @@ class TargetNode : public Node {
   int max_num_threads = 1;
   /*! \brief The warp size that should be used by the LowerThreadAllreduce pass */
   int thread_warp_size = 1;
+  /*!
+   * \brief The thread axis index that corresponds to the warp.
+   *  In CUDA it is threadIdx.x (index 0), but it can differ on some platforms.
+   */
+  int thread_warp_index = 0;
   /*! \brief Keys for this target */
   Array<Expr> keys_array;
   /*! \brief Options for this target */
@@ -48,6 +53,7 @@ class TargetNode : public Node {
     v->Visit("device_type", &device_type);
     v->Visit("max_num_threads", &max_num_threads);
     v->Visit("thread_warp_size", &thread_warp_size);
+    v->Visit("thread_warp_index", &thread_warp_index);
     v->Visit("keys_array", &keys_array);
     v->Visit("options_array", &options_array);
     v->Visit("libs_array", &libs_array);
diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h
index 1ae41032cbb89a14bff1285cafefeba269021d14..44a77aebbfd8e05d8607e7e8f4380f62a8d94b93 100644
--- a/include/tvm/ir_pass.h
+++ b/include/tvm/ir_pass.h
@@ -416,6 +416,19 @@ LoweredFunc LowerThreadAllreduce(LoweredFunc f, int warp_size);
  */
 LoweredFunc LowerWarpMemory(LoweredFunc f, int warp_size);
 
+/*!
+ * \brief Remap the thread axis.
+ *
+ *  This can be used to get an equivalent program that uses
+ *  threadIdx.y in place of threadIdx.x by passing
+ *  {"threadIdx.x": thread_axis("threadIdx.y")}.
+ *
+ * \param f The device function to be transformed.
+ * \param axis_map The map from StringImm -> IterVar.
+ * \return Transformed function.
+ */
+LoweredFunc RemapThreadAxis(LoweredFunc f, Map<Expr, IterVar> axis_map);
+
 /*!
  * \brief Lower packed function call.
  * \param f The function to be lowered.
diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py
index ffb9bd4c7844dbb9ac9caa1d742f43651e13bcdf..6364b32fd64c80ee29011d11ec79b5001f8b5c7f 100755
--- a/python/tvm/build_module.py
+++ b/python/tvm/build_module.py
@@ -98,6 +98,7 @@ class DumpIR(object):
         schedule.ScheduleOps = self._old_sgpass
         DumpIR.scope_level -= 1
 
+
 @register_node
 class BuildConfig(NodeBase):
     """Configuration scope to set a build config option.
@@ -469,6 +470,13 @@ def build(sch,
     for i, func in enumerate(fdevice):
         warp_size = target.thread_warp_size
         fdevice[i] = ir_pass.LowerWarpMemory(func, warp_size)
+        warp_index = target.thread_warp_index
+        if warp_index != 0:
+            assert warp_index == 2
+            # swap threadIdx.x and threadIdx.z
+            tmap = {api.convert("threadIdx.z"): api.thread_axis("threadIdx.x"),
+                    api.convert("threadIdx.x"): api.thread_axis("threadIdx.z")}
+            fdevice[i] = ir_pass.RemapThreadAxis(fdevice[i], tmap)
 
     if "gpu" in target.keys and not fdevice:
         warnings.warn(
diff --git a/python/tvm/target.py b/python/tvm/target.py
index 3ca72bafdc85d1c170bc3ae85010b2e1f69887d4..1dabc862a0ee822c6b9bc8bf7a107d1a83539234 100644
--- a/python/tvm/target.py
+++ b/python/tvm/target.py
@@ -109,6 +109,7 @@ class Target(NodeBase):
     def __exit__(self, ptype, value, trace):
         _api_internal._ExitTargetScope()
 
+
 @register_node
 class GenericFunc(NodeBase):
     """GenericFunc node reference. This represents a generic function
diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc
index 6d59cb3ae505698c706d4182140dadc8e68a6d6e..64c9225592296ace71ac222fb633af97a8e715d9 100644
--- a/src/api/api_pass.cc
+++ b/src/api/api_pass.cc
@@ -126,6 +126,7 @@ REGISTER_PASS2(LiftAttrScope);
 REGISTER_PASS1(NarrowChannelAccess);
 REGISTER_PASS2(LowerThreadAllreduce);
 REGISTER_PASS2(LowerWarpMemory);
+REGISTER_PASS2(RemapThreadAxis);
 REGISTER_PASS2(LowerIntrin);
 REGISTER_PASS1(LowerTVMBuiltin);
 REGISTER_PASS1(CombineContextCall);
diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc
index 318e231fc67c35a57165f9a1714f19136a6cad70..0bb2dfcf0206072bc7a06591f53250163725fe71 100644
--- a/src/codegen/build_module.cc
+++ b/src/codegen/build_module.cc
@@ -78,6 +78,8 @@ Target CreateTarget(const std::string& target_name,
     t->max_num_threads = 256;
     if (t->device_name == "intel_gpu") {
       t->thread_warp_size = 16;
+      // use threadIdx.z as the warp index
+      t->thread_warp_index = 2;
     }
   } else if (target_name == "metal" || target_name == "vulkan") {
     if (target_name == "metal") {
diff --git a/src/pass/remap_thread_axis.cc b/src/pass/remap_thread_axis.cc
new file mode 100644
index 0000000000000000000000000000000000000000..94e4819a1d71b7c768a0f964ce89fdf89d7a4a1f
--- /dev/null
+++ b/src/pass/remap_thread_axis.cc
@@ -0,0 +1,83 @@
+/*!
+ *  Copyright (c) 2018 by Contributors
+ * \file remap_thread_axis.cc
+ */
+#include <tvm/ir.h>
+#include <tvm/ir_mutator.h>
+#include <tvm/ir_visitor.h>
+#include <tvm/ir_pass.h>
+#include <unordered_map>
+
+
+namespace tvm {
+namespace ir {
+
+// Mutator that remaps thread_extent annotations and reads of their loop variables.
+class ThreadAxisRewriter : private IRMutator {
+ public:
+  explicit ThreadAxisRewriter(
+      const std::unordered_map<std::string, IterVar>& tmap)
+      : tmap_(tmap) {
+  }
+
+  Stmt Rewrite(Stmt stmt) {
+    return Mutate(stmt);
+  }
+
+ private:
+  Stmt Mutate_(const AttrStmt* op, const Stmt& stmt) final {
+    if (op->attr_key == attr::thread_extent) {
+      IterVar iv(op->node.node_);
+      CHECK_NE(iv->thread_tag.length(), 0U);
+      auto it = tmap_.find(iv->thread_tag);
+      if (it != tmap_.end()) {
+        const IterVar& new_iv = it->second;
+        const Variable* v = iv->var.get();
+        if (!vmap_.count(v)) {
+          vmap_[v] = new_iv->var;
+        } else {
+          CHECK(vmap_[v].same_as(new_iv->var));
+        }
+        Stmt body = this->Mutate(op->body);
+        return AttrStmt::make(
+            new_iv, op->attr_key, op->value, body);
+      }
+    }
+    return IRMutator::Mutate_(op, stmt);
+  }
+
+  Expr Mutate_(const Variable* op, const Expr& expr) final {
+    auto it = vmap_.find(op);
+    if (it != vmap_.end()) return it->second;
+    return IRMutator::Mutate_(op, expr);
+  }
+  // Map from thread tag (e.g. "threadIdx.x") to the replacement IterVar.
+  const std::unordered_map<std::string, IterVar>& tmap_;
+  // Map from the original loop variable to the replacement variable.
+  std::unordered_map<const Variable*, Var> vmap_;
+};
+
+LoweredFunc
+RemapThreadAxis(LoweredFunc f, Map<Expr, IterVar> thread_map) {
+  std::unordered_map<std::string, IterVar> tmap;
+  for (const auto& kv : thread_map) {
+    const StringImm* str = kv.first.as<StringImm>();
+    CHECK(str != nullptr);
+    tmap[str->value] = kv.second;
+  }
+
+  CHECK_EQ(f->func_type, kDeviceFunc);
+  auto n = std::make_shared<LoweredFuncNode>(*f.operator->());
+  // replace the thread axis in the function signature
+  for (size_t i = 0; i < n->thread_axis.size(); ++i) {
+    auto it = tmap.find(n->thread_axis[i]->thread_tag);
+    if (it != tmap.end()) {
+      n->thread_axis.Set(i, it->second);
+    }
+  }
+  n->body = ThreadAxisRewriter(tmap).Rewrite(n->body);
+  return LoweredFunc(n);
+}
+
+}  // namespace ir
+}  // namespace tvm
diff --git a/tests/python/integration/test_ewise.py b/tests/python/integration/test_ewise.py
index 3ad464f75faaf59dc5736921aaccce6bf85241de..cebacdc2b9a5454699c66e03a7ca43d8adfd872a 100644
--- a/tests/python/integration/test_ewise.py
+++ b/tests/python/integration/test_ewise.py
@@ -34,9 +34,10 @@ def test_exp():
         np.testing.assert_allclose(
             b.asnumpy(), np.exp(a.asnumpy()), rtol=1e-5)
 
+    check_device("opencl -device=intel_gpu")
     check_device("cuda", "llvm")
     check_device("vulkan")
-    check_device("opencl")
+
 
 
 def test_log_pow_llvm():
@@ -196,8 +197,8 @@ def try_warp_memory():
 
 
 if __name__ == "__main__":
+    test_exp()
    try_warp_memory()
    test_add()
    test_log_pow_llvm()
-    test_exp()
    test_popcount()
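
Usage note (not part of the patch): the sketch below shows one way the new RemapThreadAxis pass could be exercised directly from Python, mirroring the threadIdx.x/threadIdx.z swap that build() performs when target.thread_warp_index == 2. It assumes the existing TVM Python API used elsewhere in this patch (tvm.lower, ir_pass.SplitHostDevice, tvm.convert, tvm.thread_axis); the schedule and names such as add_one, A and B are illustrative only.

import tvm

# A tiny elementwise kernel bound to blockIdx.x / threadIdx.x.
n = tvm.var("n")
A = tvm.placeholder((n,), name="A")
B = tvm.compute((n,), lambda i: A[i] + 1.0, name="B")
s = tvm.create_schedule(B.op)
bx, tx = s[B].split(B.op.axis[0], factor=64)
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))

# Lower to a mixed function, then split into host and device functions.
fmixed = tvm.lower(s, [A, B], name="add_one")
fsplits = tvm.ir_pass.SplitHostDevice(fmixed)
fdev = fsplits[1]

# Swap threadIdx.x and threadIdx.z on the device function; this is the same
# map that build() constructs for targets with thread_warp_index == 2.
tmap = {tvm.convert("threadIdx.z"): tvm.thread_axis("threadIdx.x"),
        tvm.convert("threadIdx.x"): tvm.thread_axis("threadIdx.z")}
fdev = tvm.ir_pass.RemapThreadAxis(fdev, tmap)
print(fdev.body)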