From 59a8d099360e98389ab98e018213fc4dfbf6d07a Mon Sep 17 00:00:00 2001 From: nhynes <nhynes@berkeley.edu> Date: Mon, 25 Jun 2018 20:41:01 -0700 Subject: [PATCH] [NNVM][TOPI] Add FTVMCompute for matmul (#1239) --- nnvm/python/nnvm/top/nn.py | 3 +++ nnvm/src/top/tensor/matrix_op.cc | 13 +++++++++++++ topi/include/topi/nn.h | 6 +++--- topi/include/topi/tags.h | 2 +- 4 files changed, 20 insertions(+), 4 deletions(-) diff --git a/nnvm/python/nnvm/top/nn.py b/nnvm/python/nnvm/top/nn.py index 2432cc84f..78253bc5b 100644 --- a/nnvm/python/nnvm/top/nn.py +++ b/nnvm/python/nnvm/top/nn.py @@ -73,6 +73,9 @@ def schedule_dense(_, outs, target): reg.register_pattern("dense", OpPattern.OUT_ELEMWISE_FUSABLE) +#matmul +reg.register_pattern("matmul", OpPattern.OUT_ELEMWISE_FUSABLE) +reg.register_schedule("matmul", _fschedule_injective) # conv2d @reg.register_compute("conv2d") diff --git a/nnvm/src/top/tensor/matrix_op.cc b/nnvm/src/top/tensor/matrix_op.cc index d28097b10..c881e683a 100644 --- a/nnvm/src/top/tensor/matrix_op.cc +++ b/nnvm/src/top/tensor/matrix_op.cc @@ -3,9 +3,11 @@ * \file matrix_op.cc * \brief Matrix operators */ +#include <topi/nn.h> #include <nnvm/op.h> #include <nnvm/node.h> #include <nnvm/op_attr_types.h> +#include <nnvm/compiler/op_attr_types.h> #include <nnvm/top/tensor.h> #include "../op_common.h" #include "../elemwise_op_common.h" @@ -13,6 +15,8 @@ namespace nnvm { namespace top { +using namespace nnvm::compiler; + DMLC_REGISTER_PARAMETER(MatMulParam); inline bool DotShape(const nnvm::NodeAttrs& attrs, @@ -93,6 +97,15 @@ NNVM_REGISTER_OP(matmul) .set_attr<FInferShape>("FInferShape", DotShape) .set_attr<FInferType>("FInferType", ElemwiseType<2, 1>) .set_attr<FCorrectLayout>("FCorrectLayout", DotCorrectLayout) +.set_attr<FTVMCompute>( + "FTVMCompute", [](const NodeAttrs& attrs, + const Array<Tensor>& inputs, + const Array<Tensor>& out_info) { + const MatMulParam& param = nnvm::get<MatMulParam>(attrs.parsed); + return Array<Tensor>{ + topi::matmul(inputs[0], inputs[1], param.transpose_a, param.transpose_b) + }; + }) .set_attr<FGradient>( "FGradient", [](const NodePtr& n, const std::vector<NodeEntry>& ograds) { diff --git a/topi/include/topi/nn.h b/topi/include/topi/nn.h index 2459eb515..ee3101c4c 100644 --- a/topi/include/topi/nn.h +++ b/topi/include/topi/nn.h @@ -214,14 +214,14 @@ inline tvm::Tensor pad(const tvm::Tensor& t, * \param name The name of the operation * \param tag The tag to mark the operation * - * \return A Tensor whose op member is the matmult operation + * \return A Tensor whose op member is the matmul operation */ -inline tvm::Tensor matmult(const tvm::Tensor& A, +inline tvm::Tensor matmul(const tvm::Tensor& A, const tvm::Tensor& B, bool trans_a = false, bool trans_b = false, std::string name = "tensor", - std::string tag = kMatMult) { + std::string tag = kMatMul) { tvm::Array<tvm::Expr> output_shape{A->shape[trans_a ? 1 : 0], B->shape[trans_b ? 0 : 1]}; auto k = tvm::reduce_axis(tvm::Range{0, A->shape[trans_a ? 0 : 1]}, "k"); diff --git a/topi/include/topi/tags.h b/topi/include/topi/tags.h index 8ba9955be..8c92644d9 100644 --- a/topi/include/topi/tags.h +++ b/topi/include/topi/tags.h @@ -15,7 +15,7 @@ constexpr auto kInjective = "injective"; constexpr auto kCommReduce = "comm_reduce"; constexpr auto kCommReduceIdx = "comm_reduce_idx"; constexpr auto kBroadcast = "broadcast"; -constexpr auto kMatMult = "matmult"; +constexpr auto kMatMul = "matmul"; constexpr auto kConv2dNCHW = "conv2d_nchw"; constexpr auto kConv2dHWCN = "conv2d_hwcn"; constexpr auto kDepthwiseConv2dNCHW = "depthwise_conv2d_nchw"; -- GitLab