diff --git a/nnvm/include/nnvm/compiler/op_attr_types.h b/nnvm/include/nnvm/compiler/op_attr_types.h
index 231e85093433beea70d503861623c4a05214235a..497a520db78e4e50a245585cc243173bab0b8ede 100644
--- a/nnvm/include/nnvm/compiler/op_attr_types.h
+++ b/nnvm/include/nnvm/compiler/op_attr_types.h
@@ -80,11 +80,14 @@ using FTVMSchedule = std::function<
  * \param attrs The attribute of the original node.
  * \param inputs The input symbols of the original node.
  * \param tinfos The inferred shape and dtype of the inputs.
+ * \param ret The replaced operator.
+ * \return Whether to replace the current operator.
  */
 using FTVMAlterOpLayout = std::function<
-  Symbol(const NodeAttrs& attrs,
-         const Symbol& inputs,
-         const Array<Tensor>& tinfos)>;
+  bool(const NodeAttrs& attrs,
+       const Symbol& inputs,
+       const Array<Tensor>& tinfos,
+       Symbol* ret)>;
 
 /*!
  * \brief Transform from normal operator to vectorized operator
diff --git a/nnvm/python/nnvm/top/nn.py b/nnvm/python/nnvm/top/nn.py
index 80e2fb698163d771ae8a4d1fd2856e439d4acf3a..8135474c9c67e432056960cfd6ffee7b34338452 100644
--- a/nnvm/python/nnvm/top/nn.py
+++ b/nnvm/python/nnvm/top/nn.py
@@ -120,6 +120,10 @@ def schedule_conv2d(attrs, outs, target):
         return topi.generic.schedule_conv2d_nhwc(outs)
     return topi.generic.schedule_depthwise_conv2d_nchw(outs)
 
+@reg.register_alter_op_layout("conv2d")
+def alter_conv2d_layout(attrs, inputs, tinfos):
+    return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos)
+
 reg.register_pattern("conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 # convolution NCHWc
diff --git a/nnvm/src/compiler/alter_op_layout.cc b/nnvm/src/compiler/alter_op_layout.cc
index 369338f19ee2599151eee8e7a1d545d52f5cd4da..bf28df3d04f85a9ca084a2022cf729d29610fee1 100644
--- a/nnvm/src/compiler/alter_op_layout.cc
+++ b/nnvm/src/compiler/alter_op_layout.cc
@@ -103,9 +103,11 @@ Graph AlterOpLayout(const Graph& src) {
       tensor_infos.push_back(op_output_tinfos[input.index]);
     }
     // callback registered function to get a new operator.
-    auto op = fn_alter_op_layout(n->attrs, Symbol::CreateGroup(op_inputs), tensor_infos);
-    *ret = op.outputs;
-    return true;
+    Symbol op;
+    bool do_alter =
+        fn_alter_op_layout(n->attrs, Symbol::CreateGroup(op_inputs), tensor_infos, &op);
+    if (do_alter) *ret = op.outputs;
+    return do_alter;
   };
 
   Graph ret = nnvm::compiler::GraphTransform(src, transform);
diff --git a/nnvm/src/compiler/fold_scale_axis.cc b/nnvm/src/compiler/fold_scale_axis.cc
index f9524eb8ed345e41cdc9fcce151af922f46417b4..28796a2b0bcd71e81a1e11a9731847cf7d1e0bb8 100644
--- a/nnvm/src/compiler/fold_scale_axis.cc
+++ b/nnvm/src/compiler/fold_scale_axis.cc
@@ -466,8 +466,8 @@ bool Conv2DScaleAxisBackward(
   using top::Conv2DParam;
   const Conv2DParam& param = nnvm::get<Conv2DParam>(attrs.parsed);
   if (out_info.kind != kPending) return false;
-  // only optimize for nchw for now
-  if (param.layout == "NCHW" && out_info.axis == 1) {
+  // only optimize for kernel layout OIHW for now
+  if (param.kernel_layout == "OIHW" && out_info.axis == 1) {
     (*in_axis)[1].kind = kMulConsumer;
     (*in_axis)[1].axis = 0;
     (*in_axis)[1].source = out_info.source;
@@ -492,7 +492,7 @@ bool Conv2DScaleAxisForward(
   const Conv2DParam& param = nnvm::get<Conv2DParam>(attrs.parsed);
   if ((*in_info)[0].kind != kPending) return false;
-  // only optimize for nchw for now
-  if (param.layout == "NCHW" && (*in_info)[0].axis == 1) {
+  // only optimize for kernel layout OIHW for now
+  if (param.kernel_layout == "OIHW" && (*in_info)[0].axis == 1) {
     (*in_info)[1].kind = kMulConsumer;
     (*in_info)[1].axis = 1;
     (*in_info)[1].source = (*in_info)[0].source;
diff --git a/nnvm/src/compiler/packed_func_ext.cc b/nnvm/src/compiler/packed_func_ext.cc
index 2587534d7d07cd2b6280249576e1d55d6bf947ed..d549f9e2004f7cfa811e4d3c23146577d2fc68dc 100644
--- a/nnvm/src/compiler/packed_func_ext.cc
+++ b/nnvm/src/compiler/packed_func_ext.cc
@@ -70,12 +70,17 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._register_alter_op_layout")
     Op& op = ::dmlc::Registry<nnvm::Op>::Get()->__REGISTER_OR_GET__(args[0]);
     auto fpack = [f](const NodeAttrs& attrs,
                      const Symbol& inputs,
-                     const Array<Tensor>& tinfos) {
+                     const Array<Tensor>& tinfos,
+                     Symbol* ret_symbol) {
       TVMRetValue ret = (*f)(GetAttrDict(attrs), inputs, tinfos);
+      if (ret.type_code() == TVMTypeCode::kNull) {
+        return false;
+      }
       CHECK_EQ(ret.type_code(), tvm::runtime::extension_class_info<Symbol>::code)
           << " expected " << "Symbol (code = " << tvm::runtime::extension_class_info<Symbol>::code
           << ") but get code = " << ret.type_code();
-      return *(static_cast<Symbol*>(ret.value().v_handle));
+      *ret_symbol = *(static_cast<Symbol*>(ret.value().v_handle));
+      return true;
     };
     op.set_attr<FTVMAlterOpLayout>("FTVMAlterOpLayout", fpack, args[2]);
   });
diff --git a/nnvm/src/top/nn/nn_common.h b/nnvm/src/top/nn/nn_common.h
index 49a020348485088f159681f9f5d8c080f5ceff84..4dc9f7db54c023ac6c22caa61cd1e5e328d2ca9d 100644
--- a/nnvm/src/top/nn/nn_common.h
+++ b/nnvm/src/top/nn/nn_common.h
@@ -75,7 +75,7 @@ inline TShape ConvertLayout(TShape src, const Layout& src_layout, const Layout&
       CHECK_GT(dst_factor, 0);
       CHECK_LE(dst_factor, src_dim_size) << "Converting " << src
                                          << " from " << src_layout
-                                         << " to " << dst_factor
+                                         << " to " << dst_layout
                                          << ": cannot split dimension size of " << src_dim_size
                                          << " by " << dst_factor;
       dst[dst_major_pos] /= dst_factor;
diff --git a/nnvm/tests/python/compiler/test_alter_op_layout.py b/nnvm/tests/python/compiler/test_alter_op_layout.py
index bfda1807dc21b9d6d4047fb56adec42d617ce282..0fbf5ad3b479e3bfb99b47cb3e437f2a2b02d9cd 100644
--- a/nnvm/tests/python/compiler/test_alter_op_layout.py
+++ b/nnvm/tests/python/compiler/test_alter_op_layout.py
@@ -32,7 +32,7 @@ def test_alter_conv2d_layout():
     g = g.apply(["InferShape", "InferType"])
     layouts_origin = get_layouts(g)
 
-    @reg.register_alter_op_layout("conv2d")
+    @reg.register_alter_op_layout("conv2d", level=100)
     def alter_conv2d_layout(attrs, inputs, tinfos):
         new_attrs = {k : attrs[k] for k in attrs.keys()}
         new_attrs["layout"] = "NCHW16c"
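
Note: below is a minimal sketch of the Python-side callback contract after this change, not part of the diff itself. The function name, the layout check, and the assumption of a bias-free conv2d (two inputs) are illustrative only; the registration mirrors the test above, where `level=100` overrides the default registration added in nn.py. Returning a new Symbol replaces the operator, while returning None is forwarded through the new kNull branch in packed_func_ext.cc as `false`, so the original operator is kept.

# Sketch only: illustrates the return convention introduced by this PR.
import nnvm.symbol as sym
from nnvm.top import registry as reg

@reg.register_alter_op_layout("conv2d", level=100)
def example_alter_conv2d_layout(attrs, inputs, tinfos):
    new_attrs = {k: attrs[k] for k in attrs.keys()}
    if new_attrs.get("layout", "NCHW") != "NCHW":
        return None  # becomes `return false` in fpack: keep the original conv2d
    new_attrs["layout"] = "NCHW16c"
    # Returning a Symbol fills `*ret_symbol` and yields `return true`:
    # the node is replaced. Assumes use_bias=False, i.e. inputs are (data, weight).
    return sym.conv2d(inputs[0], inputs[1], **new_attrs)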