diff --git a/nnvm/include/nnvm/compiler/op_attr_types.h b/nnvm/include/nnvm/compiler/op_attr_types.h
index 231e85093433beea70d503861623c4a05214235a..497a520db78e4e50a245585cc243173bab0b8ede 100644
--- a/nnvm/include/nnvm/compiler/op_attr_types.h
+++ b/nnvm/include/nnvm/compiler/op_attr_types.h
@@ -80,11 +80,14 @@ using FTVMSchedule = std::function<
  * \param attrs The attribute of the original node.
  * \param inputs The input symbols of the original node.
  * \param tinfos The inferred shape and dtype of the inputs.
+ * \param ret The replacement operator (set only when the function returns true).
+ * \return Whether to replace the current operator.
  */
 using FTVMAlterOpLayout = std::function<
-  Symbol(const NodeAttrs& attrs,
-         const Symbol& inputs,
-         const Array<Tensor>& tinfos)>;
+  bool(const NodeAttrs& attrs,
+       const Symbol& inputs,
+       const Array<Tensor>& tinfos,
+       Symbol* ret)>;
 
 /*!
  * \brief Transform from normal operator to vectorized operator
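The practical effect of the new signature is that a layout-alteration callback can now decline to rewrite an operator: it returns false (None from Python, via the packed_func_ext.cc change below) and the original node is kept. A minimal sketch of a frontend callback under the new contract, mirroring the test at the end of this diff; the NCHW16c value and the priority level are illustrative only, and the bias input is omitted for brevity:

    import nnvm.symbol as sym
    from nnvm.top import registry as reg

    @reg.register_alter_op_layout("conv2d", level=100)
    def _alter_conv2d_layout_sketch(attrs, inputs, tinfos):
        new_attrs = {k: attrs[k] for k in attrs.keys()}
        if new_attrs.get("layout", "NCHW") != "NCHW":
            return None                       # decline: the conv2d is left unchanged
        new_attrs["layout"] = "NCHW16c"       # illustrative replacement layout
        return sym.conv2d(inputs[0], inputs[1], **new_attrs)
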
diff --git a/nnvm/python/nnvm/top/nn.py b/nnvm/python/nnvm/top/nn.py
index 80e2fb698163d771ae8a4d1fd2856e439d4acf3a..8135474c9c67e432056960cfd6ffee7b34338452 100644
--- a/nnvm/python/nnvm/top/nn.py
+++ b/nnvm/python/nnvm/top/nn.py
@@ -120,6 +120,10 @@ def schedule_conv2d(attrs, outs, target):
             return topi.generic.schedule_conv2d_nhwc(outs)
         return topi.generic.schedule_depthwise_conv2d_nchw(outs)
 
+@reg.register_alter_op_layout("conv2d")
+def alter_conv2d_layout(attrs, inputs, tinfos):
+    return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos)
+
 reg.register_pattern("conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 # convolution NCHWc
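The new default registration above only forwards to topi.nn.conv2d_alter_layout, which acts as a target-dispatched hook. Presumably it is a tvm.target.generic_func whose fallback declines to alter (an assumption here, since the TOPI side is not part of this diff), roughly:

    import tvm

    @tvm.target.generic_func
    def conv2d_alter_layout(attrs, inputs, tinfos):
        # Fallback: return None, i.e. keep conv2d in its original layout.
        # Backends that want a packed layout register an override for their target.
        return None
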
diff --git a/nnvm/src/compiler/alter_op_layout.cc b/nnvm/src/compiler/alter_op_layout.cc
index 369338f19ee2599151eee8e7a1d545d52f5cd4da..bf28df3d04f85a9ca084a2022cf729d29610fee1 100644
--- a/nnvm/src/compiler/alter_op_layout.cc
+++ b/nnvm/src/compiler/alter_op_layout.cc
@@ -103,9 +103,11 @@ Graph AlterOpLayout(const Graph& src) {
       tensor_infos.push_back(op_output_tinfos[input.index]);
     }
-    // call the registered function to get a replacement operator.
-    auto op = fn_alter_op_layout(n->attrs, Symbol::CreateGroup(op_inputs), tensor_infos);
-    *ret = op.outputs;
-    return true;
+    // call the registered function to get a replacement operator.
+    Symbol op;
+    bool do_alter =
+      fn_alter_op_layout(n->attrs, Symbol::CreateGroup(op_inputs), tensor_infos, &op);
+    if (do_alter) *ret = op.outputs;
+    return do_alter;
   };
 
   Graph ret = nnvm::compiler::GraphTransform(src, transform);
diff --git a/nnvm/src/compiler/fold_scale_axis.cc b/nnvm/src/compiler/fold_scale_axis.cc
index f9524eb8ed345e41cdc9fcce151af922f46417b4..28796a2b0bcd71e81a1e11a9731847cf7d1e0bb8 100644
--- a/nnvm/src/compiler/fold_scale_axis.cc
+++ b/nnvm/src/compiler/fold_scale_axis.cc
@@ -466,8 +466,8 @@ bool Conv2DScaleAxisBackward(
   using top::Conv2DParam;
   const Conv2DParam& param = nnvm::get<Conv2DParam>(attrs.parsed);
   if (out_info.kind != kPending) return false;
-  // only optimize for nchw for now
-  if (param.layout == "NCHW" && out_info.axis == 1) {
+  // only optimize for kernel layout OIHW for now
+  if (param.kernel_layout == "OIHW" && out_info.axis == 1) {
     (*in_axis)[1].kind = kMulConsumer;
     (*in_axis)[1].axis = 0;
     (*in_axis)[1].source = out_info.source;
@@ -492,7 +492,7 @@ bool Conv2DScaleAxisForward(
   const Conv2DParam& param = nnvm::get<Conv2DParam>(attrs.parsed);
   if ((*in_info)[0].kind != kPending) return false;
-  // only optimize for nchw for now
-  if (param.layout == "NCHW" && (*in_info)[0].axis == 1) {
+  // only optimize for kernel layout OIHW for now
+  if (param.kernel_layout == "OIHW" && (*in_info)[0].axis == 1) {
     (*in_info)[1].kind = kMulConsumer;
     (*in_info)[1].axis = 1;
     (*in_info)[1].source = (*in_info)[0].source;
diff --git a/nnvm/src/compiler/packed_func_ext.cc b/nnvm/src/compiler/packed_func_ext.cc
index 2587534d7d07cd2b6280249576e1d55d6bf947ed..d549f9e2004f7cfa811e4d3c23146577d2fc68dc 100644
--- a/nnvm/src/compiler/packed_func_ext.cc
+++ b/nnvm/src/compiler/packed_func_ext.cc
@@ -70,12 +70,17 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._register_alter_op_layout")
   Op& op = ::dmlc::Registry<nnvm::Op>::Get()->__REGISTER_OR_GET__(args[0]);
   auto fpack = [f](const NodeAttrs& attrs,
                    const Symbol& inputs,
-                   const Array<Tensor>& tinfos) {
+                   const Array<Tensor>& tinfos,
+                   Symbol* ret_symbol) {
     TVMRetValue ret = (*f)(GetAttrDict(attrs), inputs, tinfos);
+    if (ret.type_code() == TVMTypeCode::kNull) {
+      return false;
+    }
     CHECK_EQ(ret.type_code(), tvm::runtime::extension_class_info<Symbol>::code)
       << " expected " << "Symbol (code = " << tvm::runtime::extension_class_info<Symbol>::code
       << ") but get code = " << ret.type_code();
-    return *(static_cast<Symbol*>(ret.value().v_handle));
+    *ret_symbol = *(static_cast<Symbol*>(ret.value().v_handle));
+    return true;
   };
   op.set_attr<FTVMAlterOpLayout>("FTVMAlterOpLayout", fpack, args[2]);
 });
diff --git a/nnvm/src/top/nn/nn_common.h b/nnvm/src/top/nn/nn_common.h
index 49a020348485088f159681f9f5d8c080f5ceff84..4dc9f7db54c023ac6c22caa61cd1e5e328d2ca9d 100644
--- a/nnvm/src/top/nn/nn_common.h
+++ b/nnvm/src/top/nn/nn_common.h
@@ -75,7 +75,7 @@ inline TShape ConvertLayout(TShape src, const Layout& src_layout, const Layout&
         CHECK_GT(dst_factor, 0);
         CHECK_LE(dst_factor, src_dim_size) << "Converting " << src
                                            << " from " << src_layout
-                                           << " to " << dst_factor
+                                           << " to " << dst_layout
                                            << ": cannot split dimension size of "
                                            << src_dim_size << " by " << dst_factor;
         dst[dst_major_pos] /= dst_factor;
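For context, this CHECK fires when the destination layout asks to split a dimension by a factor larger than the dimension itself, e.g. converting NCHW with C = 8 to NCHW16c; the fix makes the message print the destination layout instead of repeating the split factor. A rough plain-Python illustration of the channel split (not the actual TShape code):

    def split_channel(shape_nchw, factor=16):
        # NCHW -> NCHW{factor}c: the C dimension is divided by the split factor
        n, c, h, w = shape_nchw
        assert factor <= c and c % factor == 0, \
            "cannot split dimension size of %d by %d" % (c, factor)
        return (n, c // factor, h, w, factor)

    print(split_channel((1, 32, 224, 224)))   # (1, 2, 224, 224, 16)
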
diff --git a/nnvm/tests/python/compiler/test_alter_op_layout.py b/nnvm/tests/python/compiler/test_alter_op_layout.py
index bfda1807dc21b9d6d4047fb56adec42d617ce282..0fbf5ad3b479e3bfb99b47cb3e437f2a2b02d9cd 100644
--- a/nnvm/tests/python/compiler/test_alter_op_layout.py
+++ b/nnvm/tests/python/compiler/test_alter_op_layout.py
@@ -32,7 +32,7 @@ def test_alter_conv2d_layout():
     g = g.apply(["InferShape", "InferType"])
     layouts_origin = get_layouts(g)
 
-    @reg.register_alter_op_layout("conv2d")
+    @reg.register_alter_op_layout("conv2d", level=100)
     def alter_conv2d_layout(attrs, inputs, tinfos):
         new_attrs = {k : attrs[k] for k in attrs.keys()}
         new_attrs["layout"] = "NCHW16c"