diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 8d53e27892bc876adfdd658a839518152069d103..e30cf8ba2ccf6ca2397f8989422f519a2652997e 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -9,7 +9,6 @@ from ..op import OpPattern, schedule_injective reg.register_schedule("nn.relu", schedule_injective) reg.register_pattern("nn.relu", OpPattern.ELEMWISE) - @reg.register_schedule("nn.softmax") def schedule_softmax(_, outputs, target): """Schedule definition of softmax""" @@ -19,6 +18,15 @@ def schedule_softmax(_, outputs, target): reg.register_pattern("nn.softmax", OpPattern.OPAQUE) +@reg.register_schedule("nn.log_softmax") +def schedule_log_softmax(_, outputs, target): + """Schedule definition of log_softmax""" + with target: + return topi.generic.schedule_softmax(outputs) + +reg.register_pattern("nn.log_softmax", OpPattern.OPAQUE) + + # dense @reg.register_compute("nn.dense") def compute_dense(attrs, inputs, out_type, target): diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 16b65aeeab7f3ec7984857434c84727c1de72e33..dfa68197819b99a1141cc35c73678f2437065e2c 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -291,7 +291,18 @@ RELAY_REGISTER_OP("nn.log_softmax") .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") .set_support_level(1) -.add_type_rel("Identity", IdentityRel); +.add_type_rel("Identity", IdentityRel) +.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs, + const Array<Tensor>& inputs, + const Type& out_type, + const Target& target) { + const auto* param = attrs.as<SoftmaxAttrs>(); + CHECK(param != nullptr); + CHECK(param->axis == -1 || param->axis == static_cast<int32_t>(inputs[0].ndim()) - 1) + << "log_softmax currently only works on last dimension"; + return Array<Tensor>{ topi::nn::log_softmax(inputs[0]) }; +}); + // BatchFlatten diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index 53de7aa262797f92e8f315e617bda6de7c623c7c..35844ddd4a3fbcf809f8b2fd2848573a7d00753a 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -137,12 +137,19 @@ def test_softmax(): def test_log_softmax(): - n, d = tvm.var("n"), tvm.var("d") - x = relay.var("x", shape=(n, d)) - y = relay.nn.log_softmax(x, axis=0) + shape = (10, 4) + x = relay.var("x", shape=shape) + y = relay.nn.log_softmax(x, axis=1) assert "nn.log_softmax" in y.astext() yy = relay.ir_pass.infer_type(y) - assert yy.checked_type == relay.TensorType((n, d)) + assert yy.checked_type == relay.TensorType(shape) + func = relay.Function([x], y) + x_data = np.random.uniform(size=shape).astype("float32") + ref_res = topi.testing.log_softmax_python(x_data) + for target, ctx in ctx_list(): + intrp = relay.create_executor("graph", ctx=ctx, target=target) + op_res = intrp.evaluate(func)(x_data) + np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) def test_concatenate():