From b71edd76bdb73cc188e4e985e987b8e083b2bcf2 Mon Sep 17 00:00:00 2001
From: Animesh Jain <anijain@umich.edu>
Date: Mon, 19 Nov 2018 14:23:37 -0500
Subject: [PATCH] Relay Op sprint (part 2) - Level 1 - log_softmax (#2128)

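Registers a schedule and pattern for nn.log_softmax in Python (reusing the
generic softmax schedule), adds an FTVMCompute implementation backed by
topi::nn::log_softmax on the C++ side (currently limited to the last axis),
and extends the level-1 test to check the op against
topi.testing.log_softmax_python on real data.

A rough usage sketch, mirroring the updated test; the "llvm" target and
tvm.cpu(0) context are illustrative choices, not part of this change:

    import numpy as np
    import tvm
    from tvm import relay

    x = relay.var("x", shape=(10, 4))
    func = relay.Function([x], relay.nn.log_softmax(x, axis=1))
    x_data = np.random.uniform(size=(10, 4)).astype("float32")
    # Graph executor on CPU; any supported target should behave the same way.
    intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
    print(intrp.evaluate(func)(x_data).asnumpy())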
---
 python/tvm/relay/op/nn/_nn.py        | 10 +++++++++-
 src/relay/op/nn/nn.cc                | 13 ++++++++++++-
 tests/python/relay/test_op_level1.py | 15 +++++++++++----
 3 files changed, 32 insertions(+), 6 deletions(-)

diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index 8d53e2789..e30cf8ba2 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -9,7 +9,6 @@ from ..op import OpPattern, schedule_injective
 reg.register_schedule("nn.relu", schedule_injective)
 reg.register_pattern("nn.relu", OpPattern.ELEMWISE)
 
-
 @reg.register_schedule("nn.softmax")
 def schedule_softmax(_, outputs, target):
     """Schedule definition of softmax"""
@@ -19,6 +18,15 @@ def schedule_softmax(_, outputs, target):
 reg.register_pattern("nn.softmax", OpPattern.OPAQUE)
 
 
+@reg.register_schedule("nn.log_softmax")
+def schedule_log_softmax(_, outputs, target):
+    """Schedule definition of log_softmax"""
+    with target:
+        return topi.generic.schedule_softmax(outputs)
+
+reg.register_pattern("nn.log_softmax", OpPattern.OPAQUE)
+
+
 # dense
 @reg.register_compute("nn.dense")
 def compute_dense(attrs, inputs, out_type, target):
diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index 16b65aeea..dfa681978 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -291,7 +291,18 @@ RELAY_REGISTER_OP("nn.log_softmax")
 .set_num_inputs(1)
 .add_argument("data", "Tensor", "The input tensor.")
 .set_support_level(1)
-.add_type_rel("Identity", IdentityRel);
+.add_type_rel("Identity", IdentityRel)
+.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
+                                         const Array<Tensor>& inputs,
+                                         const Type& out_type,
+                                         const Target& target) {
+  const auto* param = attrs.as<SoftmaxAttrs>();
+  CHECK(param != nullptr);
+  CHECK(param->axis == -1 || param->axis == static_cast<int32_t>(inputs[0].ndim()) - 1)
+      << "log_softmax currently only works on last dimension";
+  return Array<Tensor>{ topi::nn::log_softmax(inputs[0]) };
+});
+
 
 
 // BatchFlatten
diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py
index 53de7aa26..35844ddd4 100644
--- a/tests/python/relay/test_op_level1.py
+++ b/tests/python/relay/test_op_level1.py
@@ -137,12 +137,19 @@ def test_softmax():
 
 
 def test_log_softmax():
-    n, d = tvm.var("n"), tvm.var("d")
-    x = relay.var("x", shape=(n, d))
-    y = relay.nn.log_softmax(x, axis=0)
+    shape = (10, 4)
+    x = relay.var("x", shape=shape)
+    y = relay.nn.log_softmax(x, axis=1)
     assert "nn.log_softmax" in y.astext()
     yy = relay.ir_pass.infer_type(y)
-    assert yy.checked_type == relay.TensorType((n, d))
+    assert yy.checked_type == relay.TensorType(shape)
+    func = relay.Function([x], y)
+    x_data = np.random.uniform(size=shape).astype("float32")
+    ref_res = topi.testing.log_softmax_python(x_data)
+    for target, ctx in ctx_list():
+        intrp = relay.create_executor("graph", ctx=ctx, target=target)
+        op_res = intrp.evaluate(func)(x_data)
+        np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
 
 
 def test_concatenate():
-- 
GitLab