diff --git a/include/tvm/build_module.h b/include/tvm/build_module.h
index 9c54ab7b45fbd2c602181b9e21d6e2c3f4910715..e9d6123b5e068fda353b8af22ead90477c6b6eb8 100644
--- a/include/tvm/build_module.h
+++ b/include/tvm/build_module.h
@@ -32,6 +32,11 @@ class TargetNode : public Node {
   int max_num_threads = 1;
   /*! \brief The warp size that should be used by the LowerThreadAllreduce pass */
   int thread_warp_size = 1;
+  /*!
+   * \brief The thread index that maps to the warp (the lowest thread
+   *  dimension). In CUDA it is threadIdx.x, but it can differ on other platforms.
+   */
+  int thread_warp_index = 0;
   /*! \brief Keys for this target */
   Array<Expr> keys_array;
   /*! \brief Options for this target */
@@ -48,6 +53,7 @@ class TargetNode : public Node {
     v->Visit("device_type", &device_type);
     v->Visit("max_num_threads", &max_num_threads);
     v->Visit("thread_warp_size", &thread_warp_size);
+    v->Visit("thread_warp_index", &thread_warp_index);
     v->Visit("keys_array", &keys_array);
     v->Visit("options_array", &options_array);
     v->Visit("libs_array", &libs_array);
diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h
index 1ae41032cbb89a14bff1285cafefeba269021d14..44a77aebbfd8e05d8607e7e8f4380f62a8d94b93 100644
--- a/include/tvm/ir_pass.h
+++ b/include/tvm/ir_pass.h
@@ -416,6 +416,30 @@ LoweredFunc LowerThreadAllreduce(LoweredFunc f, int warp_size);
  */
 LoweredFunc LowerWarpMemory(LoweredFunc f, int warp_size);
 
+/*!
+ * \brief Remap the thread axis.
+ *
+ *  This can be used to obtain an equivalent program that uses
+ *  threadIdx.y in place of threadIdx.x by passing
+ *  {"threadIdx.x": thread_axis("threadIdx.y")}.
+ *
+ * \param f The device function to be transformed.
+ * \param axis_map The map from StringImm to IterVar.
+ * \return Transformed function.
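+ *
+ *  For example, a minimal sketch (not part of this patch) that swaps
+ *  threadIdx.x and threadIdx.z, assuming the thread_axis helper from
+ *  tvm/operation.h:
+ *
+ * \code
+ *   Map<Expr, IterVar> axis_map;
+ *   axis_map.Set(StringImm::make("threadIdx.x"), thread_axis(Range(), "threadIdx.z"));
+ *   axis_map.Set(StringImm::make("threadIdx.z"), thread_axis(Range(), "threadIdx.x"));
+ *   f = RemapThreadAxis(f, axis_map);
+ * \endcode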
+ */
+LoweredFunc RemapThreadAxis(LoweredFunc f, Map<Expr, IterVar> axis_map);
+
 /*!
  * \brief Lower packed function call.
  * \param f The function to be lowered.
diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py
index ffb9bd4c7844dbb9ac9caa1d742f43651e13bcdf..6364b32fd64c80ee29011d11ec79b5001f8b5c7f 100755
--- a/python/tvm/build_module.py
+++ b/python/tvm/build_module.py
@@ -98,6 +98,7 @@ class DumpIR(object):
         schedule.ScheduleOps = self._old_sgpass
         DumpIR.scope_level -= 1
 
+
 @register_node
 class BuildConfig(NodeBase):
     """Configuration scope to set a build config option.
@@ -469,6 +470,14 @@ def build(sch,
     for i, func in enumerate(fdevice):
         warp_size = target.thread_warp_size
         fdevice[i] = ir_pass.LowerWarpMemory(func, warp_size)
+        warp_index = target.thread_warp_index
+        if warp_index != 0:
+            # Only threadIdx.z (index 2) is supported as an alternative
+            # warp index for now; swap threadIdx.z and threadIdx.x.
+            assert warp_index == 2
+            tmap = {api.convert("threadIdx.z"): api.thread_axis("threadIdx.x"),
+                    api.convert("threadIdx.x"): api.thread_axis("threadIdx.z")}
+            fdevice[i] = ir_pass.RemapThreadAxis(fdevice[i], tmap)
 
     if "gpu" in target.keys and not fdevice:
         warnings.warn(
diff --git a/python/tvm/target.py b/python/tvm/target.py
index 3ca72bafdc85d1c170bc3ae85010b2e1f69887d4..1dabc862a0ee822c6b9bc8bf7a107d1a83539234 100644
--- a/python/tvm/target.py
+++ b/python/tvm/target.py
@@ -109,6 +109,7 @@ class Target(NodeBase):
     def __exit__(self, ptype, value, trace):
         _api_internal._ExitTargetScope()
 
+
 @register_node
 class GenericFunc(NodeBase):
     """GenericFunc node reference. This represents a generic function
diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc
index 6d59cb3ae505698c706d4182140dadc8e68a6d6e..64c9225592296ace71ac222fb633af97a8e715d9 100644
--- a/src/api/api_pass.cc
+++ b/src/api/api_pass.cc
@@ -126,6 +126,7 @@ REGISTER_PASS2(LiftAttrScope);
 REGISTER_PASS1(NarrowChannelAccess);
 REGISTER_PASS2(LowerThreadAllreduce);
 REGISTER_PASS2(LowerWarpMemory);
+REGISTER_PASS2(RemapThreadAxis);
 REGISTER_PASS2(LowerIntrin);
 REGISTER_PASS1(LowerTVMBuiltin);
 REGISTER_PASS1(CombineContextCall);
diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc
index 318e231fc67c35a57165f9a1714f19136a6cad70..0bb2dfcf0206072bc7a06591f53250163725fe71 100644
--- a/src/codegen/build_module.cc
+++ b/src/codegen/build_module.cc
@@ -78,6 +78,8 @@ Target CreateTarget(const std::string& target_name,
     t->max_num_threads = 256;
     if (t->device_name == "intel_gpu") {
       t->thread_warp_size = 16;
+      // use threadIdx.z as the warp index
+      t->thread_warp_index = 2;
     }
   } else if (target_name == "metal" || target_name == "vulkan") {
     if (target_name == "metal") {
diff --git a/src/pass/remap_thread_axis.cc b/src/pass/remap_thread_axis.cc
new file mode 100644
index 0000000000000000000000000000000000000000..94e4819a1d71b7c768a0f964ce89fdf89d7a4a1f
--- /dev/null
+++ b/src/pass/remap_thread_axis.cc
@@ -0,0 +1,87 @@
+/*!
+ *  Copyright (c) 2018 by Contributors
+ * \file remap_thread_axis.cc
+ */
+#include <tvm/ir.h>
+#include <tvm/ir_mutator.h>
+#include <tvm/ir_visitor.h>
+#include <tvm/ir_pass.h>
+#include <unordered_map>
+
+
+namespace tvm {
+namespace ir {
+
+// Mutator that remaps thread-axis attributes and the variables they bind
+class ThreadAxisRewriter : private IRMutator {
+ public:
+  explicit ThreadAxisRewriter(
+      const std::unordered_map<std::string, IterVar>& tmap)
+      : tmap_(tmap) {
+  }
+
+  Stmt Rewrite(Stmt stmt) {
+    return Mutate(stmt);
+  }
+
+ private:
+  Stmt Mutate_(const AttrStmt* op, const Stmt& stmt) final {
+    if (op->attr_key == attr::thread_extent) {
+      IterVar iv(op->node.node_);
+      CHECK_NE(iv->thread_tag.length(), 0U);
+      auto it = tmap_.find(iv->thread_tag);
+      if (it != tmap_.end()) {
+        const IterVar& new_iv = it->second;
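+        // Bind the old thread variable to the replacement thread's variable;
+        // every occurrence of this tag must resolve to the same variable.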
+        const Variable* v = iv->var.get();
+        if (!vmap_.count(v)) {
+          vmap_[v] = new_iv->var;
+        } else {
+          CHECK(vmap_[v].same_as(new_iv->var));
+        }
+        Stmt body = this->Mutate(op->body);
+        return AttrStmt::make(
+            new_iv, op->attr_key, op->value, body);
+      }
+    }
+    return IRMutator::Mutate_(op, stmt);
+  }
+
+  Expr Mutate_(const Variable* op, const Expr& expr) final {
+    auto it = vmap_.find(op);
+    if (it != vmap_.end()) return it->second;
+    return IRMutator::Mutate_(op, expr);
+  }
+  // The thread map
+  const std::unordered_map<std::string, IterVar>& tmap_;
+  // variable map
+  std::unordered_map<const Variable*, Var> vmap_;
+};
+
+LoweredFunc
+RemapThreadAxis(LoweredFunc f, Map<Expr, IterVar> thread_map) {
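+  // Build a lookup table keyed by thread tag (e.g. "threadIdx.x").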
+  std::unordered_map<std::string, IterVar> tmap;
+  for (const auto& kv : thread_map) {
+    const StringImm* str = kv.first.as<StringImm>();
+    CHECK(str != nullptr);
+    tmap[str->value] = kv.second;
+  }
+
+  CHECK_EQ(f->func_type, kDeviceFunc);
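+  // Copy-on-write: mutate a fresh copy of the function node.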
+  auto n = std::make_shared<LoweredFuncNode>(*f.operator->());
+  // replace the thread axis
+  for (size_t i = 0; i < n->thread_axis.size(); ++i) {
+    auto it = tmap.find(n->thread_axis[i]->thread_tag);
+    if (it != tmap.end()) {
+      n->thread_axis.Set(i, it->second);
+    }
+  }
+  n->body = ThreadAxisRewriter(tmap).Rewrite(n->body);
+  return LoweredFunc(n);
+}
+
+}  // namespace ir
+}  // namespace tvm
diff --git a/tests/python/integration/test_ewise.py b/tests/python/integration/test_ewise.py
index 3ad464f75faaf59dc5736921aaccce6bf85241de..cebacdc2b9a5454699c66e03a7ca43d8adfd872a 100644
--- a/tests/python/integration/test_ewise.py
+++ b/tests/python/integration/test_ewise.py
@@ -34,9 +34,10 @@ def test_exp():
         np.testing.assert_allclose(
             b.asnumpy(), np.exp(a.asnumpy()), rtol=1e-5)
 
+    check_device("opencl -device=intel_gpu")
     check_device("cuda", "llvm")
     check_device("vulkan")
-    check_device("opencl")
+
 
 
 def test_log_pow_llvm():
@@ -196,8 +197,8 @@ def try_warp_memory():
 
 
 if __name__ == "__main__":
+    test_exp()
     try_warp_memory()
     test_add()
     test_log_pow_llvm()
-    test_exp()
     test_popcount()