From 59a8d099360e98389ab98e018213fc4dfbf6d07a Mon Sep 17 00:00:00 2001
From: nhynes <nhynes@berkeley.edu>
Date: Mon, 25 Jun 2018 20:41:01 -0700
Subject: [PATCH] [NNVM][TOPI] Add FTVMCompute for matmul (#1239)

---
 nnvm/python/nnvm/top/nn.py       |  3 +++
 nnvm/src/top/tensor/matrix_op.cc | 13 +++++++++++++
 topi/include/topi/nn.h           |  6 +++---
 topi/include/topi/tags.h         |  2 +-
 4 files changed, 20 insertions(+), 4 deletions(-)

diff --git a/nnvm/python/nnvm/top/nn.py b/nnvm/python/nnvm/top/nn.py
index 2432cc84f..78253bc5b 100644
--- a/nnvm/python/nnvm/top/nn.py
+++ b/nnvm/python/nnvm/top/nn.py
@@ -73,6 +73,9 @@ def schedule_dense(_, outs, target):
 
 reg.register_pattern("dense", OpPattern.OUT_ELEMWISE_FUSABLE)
 
+#matmul
+reg.register_pattern("matmul", OpPattern.OUT_ELEMWISE_FUSABLE)
+reg.register_schedule("matmul", _fschedule_injective)
 
 # conv2d
 @reg.register_compute("conv2d")
diff --git a/nnvm/src/top/tensor/matrix_op.cc b/nnvm/src/top/tensor/matrix_op.cc
index d28097b10..c881e683a 100644
--- a/nnvm/src/top/tensor/matrix_op.cc
+++ b/nnvm/src/top/tensor/matrix_op.cc
@@ -3,9 +3,11 @@
  * \file matrix_op.cc
  * \brief Matrix operators
  */
+#include <topi/nn.h>
 #include <nnvm/op.h>
 #include <nnvm/node.h>
 #include <nnvm/op_attr_types.h>
+#include <nnvm/compiler/op_attr_types.h>
 #include <nnvm/top/tensor.h>
 #include "../op_common.h"
 #include "../elemwise_op_common.h"
@@ -13,6 +15,8 @@
 namespace nnvm {
 namespace top {
 
+using namespace nnvm::compiler;
+
 DMLC_REGISTER_PARAMETER(MatMulParam);
 
 inline bool DotShape(const nnvm::NodeAttrs& attrs,
@@ -93,6 +97,15 @@ NNVM_REGISTER_OP(matmul)
 .set_attr<FInferShape>("FInferShape", DotShape)
 .set_attr<FInferType>("FInferType", ElemwiseType<2, 1>)
 .set_attr<FCorrectLayout>("FCorrectLayout", DotCorrectLayout)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const MatMulParam& param = nnvm::get<MatMulParam>(attrs.parsed);
+    return Array<Tensor>{
+      topi::matmul(inputs[0], inputs[1], param.transpose_a, param.transpose_b)
+    };
+  })
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds) {
diff --git a/topi/include/topi/nn.h b/topi/include/topi/nn.h
index 2459eb515..ee3101c4c 100644
--- a/topi/include/topi/nn.h
+++ b/topi/include/topi/nn.h
@@ -214,14 +214,14 @@ inline tvm::Tensor pad(const tvm::Tensor& t,
  * \param name The name of the operation
  * \param tag The tag to mark the operation
  *
- * \return A Tensor whose op member is the matmult operation
+ * \return A Tensor whose op member is the matmul operation
  */
-inline tvm::Tensor matmult(const tvm::Tensor& A,
+inline tvm::Tensor matmul(const tvm::Tensor& A,
                            const tvm::Tensor& B,
                            bool trans_a = false,
                            bool trans_b = false,
                            std::string name = "tensor",
-                           std::string tag = kMatMult) {
+                           std::string tag = kMatMul) {
   tvm::Array<tvm::Expr> output_shape{A->shape[trans_a ? 1 : 0],
                                      B->shape[trans_b ? 0 : 1]};
   auto k = tvm::reduce_axis(tvm::Range{0, A->shape[trans_a ? 0 : 1]}, "k");
diff --git a/topi/include/topi/tags.h b/topi/include/topi/tags.h
index 8ba9955be..8c92644d9 100644
--- a/topi/include/topi/tags.h
+++ b/topi/include/topi/tags.h
@@ -15,7 +15,7 @@ constexpr auto kInjective = "injective";
 constexpr auto kCommReduce = "comm_reduce";
 constexpr auto kCommReduceIdx = "comm_reduce_idx";
 constexpr auto kBroadcast = "broadcast";
-constexpr auto kMatMult = "matmult";
+constexpr auto kMatMul = "matmul";
 constexpr auto kConv2dNCHW = "conv2d_nchw";
 constexpr auto kConv2dHWCN = "conv2d_hwcn";
 constexpr auto kDepthwiseConv2dNCHW = "depthwise_conv2d_nchw";
-- 
GitLab