From 0ec70800dde06d636b6d510c63363b3ceab5fae6 Mon Sep 17 00:00:00 2001
From: Yizhi Liu <liuyizhi@apache.org>
Date: Mon, 9 Jul 2018 13:19:54 -0700
Subject: [PATCH] fix CorrectLayout for softmax & log_softmax (#1401)

---
 nnvm/src/top/nn/nn.cc                            |  4 ++--
 .../tests/python/unittest/test_correct_layout.py | 16 +++++++++++++++-
 2 files changed, 17 insertions(+), 3 deletions(-)

diff --git a/nnvm/src/top/nn/nn.cc b/nnvm/src/top/nn/nn.cc
index ba89e5ceb..322d77b6d 100644
--- a/nnvm/src/top/nn/nn.cc
+++ b/nnvm/src/top/nn/nn.cc
@@ -345,7 +345,7 @@ NNVM_REGISTER_OP(softmax)
 .set_num_outputs(1)
 .set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>)
 .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
-.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseArbitraryLayout<1, 1>)
+.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseFixedLayoutCopyToOut<1, 1>)
 .set_support_level(1)
 .set_attr<FTVMCompute>(
   "FTVMCompute", [](const NodeAttrs& attrs,
@@ -404,7 +404,7 @@ NNVM_REGISTER_OP(log_softmax)
 .set_num_outputs(1)
 .set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>)
 .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
-.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseArbitraryLayout<1, 1>)
+.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseFixedLayoutCopyToOut<1, 1>)
 .set_attr<FTVMCompute>(
   "FTVMCompute", [](const NodeAttrs& attrs,
                     const Array<Tensor>& inputs,
diff --git a/nnvm/tests/python/unittest/test_correct_layout.py b/nnvm/tests/python/unittest/test_correct_layout.py
index c428a2f83..617658628 100644
--- a/nnvm/tests/python/unittest/test_correct_layout.py
+++ b/nnvm/tests/python/unittest/test_correct_layout.py
@@ -3,7 +3,6 @@ import nnvm.symbol as sym
 import nnvm.graph as graph
 from nnvm.compiler import graph_attr
 
-# Level 1
 def correct_layout(g, layout=None):
     if isinstance(g, nnvm.symbol.Symbol):
         g = graph.create(g)
@@ -19,6 +18,7 @@ def correct_layout(g, layout=None):
     return g, ldict
 
 
+# Level 1
 def test_dense():
     x = sym.Variable("data", shape=(10, 20))
     y = sym.dense(x, units=30, name="fc")
@@ -169,6 +169,19 @@ def test_flatten():
     assert(ldict["y"][0] == "__undef__")
 
 
+def test_softmax():
+    x = sym.Variable("x", shape=(10, 20, 10, 10))
+    y = sym.softmax(x, name="y")
+    g, ldict = correct_layout(y, "NCHW")
+    assert(ldict["x"][0] == "NCHW")
+    assert(ldict["y"][0] == "NCHW")
+    # second pass will insert layout transform
+    _, ldict = correct_layout(g, "NCHW16c")
+    assert(ldict["x"][0] == "NCHW16c")
+    assert(ldict["x_NCHW"][0] == "NCHW")
+    assert(ldict["y"][0] == "NCHW")
+
+
 # Level 2
 def test_conv2d():
     x = sym.Variable("data", shape=(1, 32, 512, 512))
@@ -327,6 +340,7 @@ if __name__ == "__main__":
     test_split()
     test_batchnorm()
     test_flatten()
+    test_softmax()
     test_conv2d()
     test_conv2d_transpose()
     test_max_pool2d()
-- 
GitLab