From a9313787ab9431a191d7c43ab7f57b14eb9302ae Mon Sep 17 00:00:00 2001
From: larrywyang <larrywyang@gmail.com>
Date: Tue, 5 Jun 2018 13:51:38 -0700
Subject: [PATCH] [WIP] [NNVM] Fix softmax gradient (#1201)

[NNVM] Fix softmax gradient
---
 nnvm/src/top/nn/nn.cc                         | 38 ++++++++++---------
 nnvm/tests/python/compiler/test_top_level1.py |  4 +-
 2 files changed, 23 insertions(+), 19 deletions(-)

diff --git a/nnvm/src/top/nn/nn.cc b/nnvm/src/top/nn/nn.cc
index e1dbe3cfc..cedfb2108 100644
--- a/nnvm/src/top/nn/nn.cc
+++ b/nnvm/src/top/nn/nn.cc
@@ -366,22 +366,24 @@ NNVM_REGISTER_OP(softmax)
     // [ ...                  ,-ynyn + yn]
     //
     // grad_x =
-    // [-y1*(ograd1*y1 - 1 + ograd2*y2 + ..., -y2*(ograd1*y1 - 1 + ograd2*y2, ..., ...]]
+    // [-y1*(ograd1*y1 - ograd1 + ograd2*y2 + ...),
+    //  -y2*(ograd1*y1 - ograd2 + ograd2*y2 + ...),
+    //  ...
+    //  -yn*(ograd1*y1 - ogradn + ograd2*y2 + ...)]
 
     // grad_x = ograd elemwise_mul output
     // grad_x = sum(grad_x, keepdim, axis)
     // grad_x = grad_x broadcast_mul output
     // grad_x = neg grad_x
-    // grad_x = grad_x + output
+    // grad_x = grad_x + ograd elemwise_mul output
     const SoftmaxParam& param = nnvm::get<SoftmaxParam>(n->attrs.parsed);
     NodeEntry output =  NodeEntry{n, 0, 0};
     NodeEntry sub0 = MakeNode("elemwise_mul", n->attrs.name + "_grad_sub0", {ograds[0], output});
     NodeEntry sub1 = MakeNode("sum", n->attrs.name + "_grad_sub1", {sub0},
                               {{"axis", std::to_string(param.axis)}, {"keepdims", "true"}});
     NodeEntry sub2 = MakeNode("broadcast_mul", n->attrs.name + "_grad_sub2", {sub1, output});
-    NodeEntry sub3 = MakeNode("negative", n->attrs.name + "_grad_sub3", {sub2});
     return std::vector<NodeEntry> {
-      MakeNode("elemwise_add", n->attrs.name + "_grad", {sub3, output})
+      MakeNode("elemwise_sub", n->attrs.name + "_grad", {sub0, sub2})
     };
 });
 
@@ -414,31 +416,33 @@ NNVM_REGISTER_OP(log_softmax)
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds) {
-    // grad_x = grad_y dot jacobian of softmax
+    // grad_x = grad_y dot jacobian of logsoftmax
     //
-    // jacobian of softmax
+    // jacobian of logsoftmax
     // [-y1 + 1, -y2,        ...    ]
     // [ ...   , -y2 + 1,    ...    ]
     // [ ...                 ...    ]
     // [ ...                ,-yn + 1]
     //
     // grad_x =
-    // [-(ograd1*y1 - 1 + ograd2*y2 + ..., -(ograd1*y1 - 1 + ograd2*y2, ..., ...]]
-
-    // grad_x = ograd elemwise_mul output
-    // grad_x = sum(grad_x, keepdim, axis)
+    // [ograd1 - exp(y1)*(ograd1 + ... + ogradn),
+    //  ograd2 - exp(y2)*(ograd1 + ... + ogradn),
+    //  ...
+    //  ogradn - exp(yn)*(ograd1 + ... + ogradn)]
+
+    // grad_x = sum(ograd, keepdim, axis)
+    // sigma = exp(output)
+    // grad_x = grad_x elemwise_mul sigma
     // grad_x = neg grad_x
-    // grad_x = grad_x + ones_like(grad_x)
+    // grad_x = grad_x + ograd
     const SoftmaxParam& param = nnvm::get<SoftmaxParam>(n->attrs.parsed);
     NodeEntry output =  NodeEntry{n, 0, 0};
-    NodeEntry sub0 = MakeNode("elemwise_mul", n->attrs.name + "_grad_sub0", {ograds[0], output});
-    NodeEntry sub1 = MakeNode("sum", n->attrs.name + "_grad_sub1", {sub0},
+    NodeEntry sub0 = MakeNode("sum", n->attrs.name + "_grad_sub0", {ograds[0]},
                               {{"axis", std::to_string(param.axis)}, {"keepdims", "true"}});
-    NodeEntry sub2 = MakeNode("full_like", n->attrs.name + "_grad_sub2", {n->inputs[0]},
-                              {{"fill_value", "-1"}});
-    NodeEntry sub3 = MakeNode("broadcast_mul", n->attrs.name + "_grad_sub3", {sub1, sub2});
+    NodeEntry sub1 = MakeNode("exp", n->attrs.name + "_grad_sub1", {output});
+    NodeEntry sub2 = MakeNode("broadcast_mul", n->attrs.name + "_grad_sub2", {sub0, sub1});
     return std::vector<NodeEntry> {
-      MakeNode("elemwise_add", n->attrs.name + "_grad", {sub3, ograds[0]})
+      MakeNode("elemwise_sub", n->attrs.name + "_grad", {ograds[0], sub2})
     };
 })
 .set_support_level(1);
diff --git a/nnvm/tests/python/compiler/test_top_level1.py b/nnvm/tests/python/compiler/test_top_level1.py
index ebf8c6ce6..3058d6ccf 100644
--- a/nnvm/tests/python/compiler/test_top_level1.py
+++ b/nnvm/tests/python/compiler/test_top_level1.py
@@ -217,7 +217,7 @@ def test_softmax():
     dtype = "float32"
     dshape = (10, 1000)
     inputs = [('x', dshape, x)]
-    helper(y, inputs, dtype, forward), backward
+    helper(y, inputs, dtype, forward, backward)
 
 
 def test_log_softmax():
@@ -229,7 +229,7 @@ def test_log_softmax():
 
     def backward(head_grads, x):
         y = topi.testing.log_softmax_python(x)
-        grad = head_grads - np.sum(y * head_grads, axis=1, keepdims=True)
+        grad = head_grads - np.exp(y) * np.sum(head_grads, axis=1, keepdims=True)
         return [grad]
 
     dtype = "float32"
-- 
GitLab