diff --git a/nnvm/src/top/nn/nn.cc b/nnvm/src/top/nn/nn.cc index ba89e5ceba58907fa59a50c2e5dc85cc5011be1c..322d77b6d032f49d04c9623444e6e55021820273 100644 --- a/nnvm/src/top/nn/nn.cc +++ b/nnvm/src/top/nn/nn.cc @@ -345,7 +345,7 @@ NNVM_REGISTER_OP(softmax) .set_num_outputs(1) .set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) -.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseArbitraryLayout<1, 1>) +.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseFixedLayoutCopyToOut<1, 1>) .set_support_level(1) .set_attr<FTVMCompute>( "FTVMCompute", [](const NodeAttrs& attrs, @@ -404,7 +404,7 @@ NNVM_REGISTER_OP(log_softmax) .set_num_outputs(1) .set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) -.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseArbitraryLayout<1, 1>) +.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseFixedLayoutCopyToOut<1, 1>) .set_attr<FTVMCompute>( "FTVMCompute", [](const NodeAttrs& attrs, const Array<Tensor>& inputs, diff --git a/nnvm/tests/python/unittest/test_correct_layout.py b/nnvm/tests/python/unittest/test_correct_layout.py index c428a2f837ac6b7272fdf2d8da71c4ad17635954..6176586284a7ee3ecb3cf10d0ceca116a232befc 100644 --- a/nnvm/tests/python/unittest/test_correct_layout.py +++ b/nnvm/tests/python/unittest/test_correct_layout.py @@ -3,7 +3,6 @@ import nnvm.symbol as sym import nnvm.graph as graph from nnvm.compiler import graph_attr -# Level 1 def correct_layout(g, layout=None): if isinstance(g, nnvm.symbol.Symbol): g = graph.create(g) @@ -19,6 +18,7 @@ def correct_layout(g, layout=None): return g, ldict +# Level 1 def test_dense(): x = sym.Variable("data", shape=(10, 20)) y = sym.dense(x, units=30, name="fc") @@ -169,6 +169,19 @@ def test_flatten(): assert(ldict["y"][0] == "__undef__") +def test_softmax(): + x = sym.Variable("x", shape=(10, 20, 10, 10)) + y = sym.softmax(x, name="y") + g, ldict = correct_layout(y, "NCHW") + assert(ldict["x"][0] == "NCHW") + assert(ldict["y"][0] == "NCHW") + # second pass will insert layout transform + _, ldict = correct_layout(g, "NCHW16c") + assert(ldict["x"][0] == "NCHW16c") + assert(ldict["x_NCHW"][0] == "NCHW") + assert(ldict["y"][0] == "NCHW") + + # Level 2 def test_conv2d(): x = sym.Variable("data", shape=(1, 32, 512, 512)) @@ -327,6 +340,7 @@ if __name__ == "__main__": test_split() test_batchnorm() test_flatten() + test_softmax() test_conv2d() test_conv2d_transpose() test_max_pool2d()