From 5884cd01f3ba67ff1323d6e3480f9f67f61d6a4e Mon Sep 17 00:00:00 2001
From: alex-weaver <awsweaver@gmail.com>
Date: Tue, 6 Feb 2018 17:39:59 +0000
Subject: [PATCH] Change TOPI ops to use C++ implementation where applicable
 (#357)

* Updated TVM version. Implemented fix for nnvm_compiler crash on exit on
Windows. Changed TOPI ops from using Python to using C++ where applicable.

* Fix lint

* Fix lint

* Fix macro

* Fix reshape

* Update TVM to fix test failures
---
 nnvm/CMakeLists.txt                 |   1 +
 nnvm/Makefile                       |   2 +-
 nnvm/include/nnvm/compiler/util.h   |  33 +++++++
 nnvm/include/nnvm/symbolic.h        |   2 +-
 nnvm/python/nnvm/top/nn.py          |  85 -----------------
 nnvm/python/nnvm/top/reduction.py   |   3 -
 nnvm/python/nnvm/top/tensor.py      |  35 -------
 nnvm/python/nnvm/top/transform.py   |  43 +--------
 nnvm/src/compiler/compile_engine.cc |   2 +-
 nnvm/src/compiler/graph_hash.cc     |   2 +-
 nnvm/src/top/nn/nn.cc               |  73 +++++++++++++
 nnvm/src/top/nn/pooling.cc          |  53 +++++++++++
 nnvm/src/top/tensor/broadcast.cc    |  22 +++++
 nnvm/src/top/tensor/elemwise.cc     | 138 ++++++++++++++++++++++++++++
 nnvm/src/top/tensor/reduce.cc       |  32 +++++++
 nnvm/src/top/tensor/transform.cc    |  64 +++++++++++++
 16 files changed, 421 insertions(+), 169 deletions(-)
 create mode 100644 nnvm/include/nnvm/compiler/util.h

diff --git a/nnvm/CMakeLists.txt b/nnvm/CMakeLists.txt
index a37bc5f1f..3747be06a 100644
--- a/nnvm/CMakeLists.txt
+++ b/nnvm/CMakeLists.txt
@@ -22,6 +22,7 @@ include_directories(BEFORE "include")
 include_directories("tvm/include")
 include_directories("tvm/dlpack/include")
 include_directories("tvm/HalideIR/src")
+include_directories("tvm/topi/include")
 
 set(NNVM_LINKER_LIBS "")
 set(NNVM_COMPILER_LINKER_LIBS "")
diff --git a/nnvm/Makefile b/nnvm/Makefile
index 3f86eac40..4779e95b3 100644
--- a/nnvm/Makefile
+++ b/nnvm/Makefile
@@ -11,7 +11,7 @@ include $(config)
 
 export LDFLAGS = -pthread -lm
 export CFLAGS = -std=c++11 -Wall -O2 -Iinclude -fPIC
-CFLAGS += -Itvm/include -Itvm/dlpack/include -Itvm/HalideIR/src
+CFLAGS += -Itvm/include -Itvm/dlpack/include -Itvm/HalideIR/src -Itvm/topi/include
 
 ifdef DMLC_CORE_PATH
 CFLAGS += -I$(DMLC_CORE_PATH)/include
diff --git a/nnvm/include/nnvm/compiler/util.h b/nnvm/include/nnvm/compiler/util.h
new file mode 100644
index 000000000..04095c783
--- /dev/null
+++ b/nnvm/include/nnvm/compiler/util.h
@@ -0,0 +1,33 @@
+/*!
+* Copyright (c) 2016 by Contributors
+* \file util.h
+* \brief Utility functions for nnvm compiler
+*/
+#ifndef NNVM_COMPILER_UTIL_H_
+#define NNVM_COMPILER_UTIL_H_
+
+#include <tvm/expr.h>
+#include <nnvm/tuple.h>
+
+namespace nnvm {
+namespace compiler {
+
+/*
+ * \brief Helper function to convert TShape to TVM array. Useful for
+ * passing data from NNVM param structures to TOPI ops.
+ *
+ * \param shape The shape to convert
+ *
+ * \return An Array of Expr, where each element is a constant int32
+ */
+inline tvm::Array<tvm::Expr> ShapeToArray(TShape shape) {
+  tvm::Array<tvm::Expr> result;
+  for (auto i : shape) {
+    result.push_back(tvm::make_const(tvm::Int(32), i));
+  }
+  return result;
+}
+
+} // namespace compiler
+} // namespace nnvm
+#endif // NNVM_COMPILER_UTIL_H_
diff --git a/nnvm/include/nnvm/symbolic.h b/nnvm/include/nnvm/symbolic.h
index 02b1dd947..f290a1683 100644
--- a/nnvm/include/nnvm/symbolic.h
+++ b/nnvm/include/nnvm/symbolic.h
@@ -28,7 +28,7 @@ namespace nnvm {
 * symbol is the final operation of a graph and thus including all the information
 * required (the graph) to evaluate its output value.
 */
-class Symbol {
+class NNVM_DLL Symbol {
  public:
   /*! \brief option passed to ListAttr */
   enum ListAttrOption {
diff --git a/nnvm/python/nnvm/top/nn.py b/nnvm/python/nnvm/top/nn.py
index 88aafec62..f91f77dfd 100644
--- a/nnvm/python/nnvm/top/nn.py
+++ b/nnvm/python/nnvm/top/nn.py
@@ -10,59 +10,26 @@ from . import registry as reg
 from .registry import OpPattern
 
 # relu
-@reg.register_compute("relu")
-def compute_relu(attrs, inputs, _):
-    """Compute definition of relu"""
-    return topi.nn.relu(inputs[0])
-
 reg.register_schedule("relu", _fschedule_broadcast)
 reg.register_pattern("relu", OpPattern.ELEMWISE)
 
 # leaky_relu
-@reg.register_compute("leaky_relu")
-def compute_leaky_relu(attrs, inputs, _):
-    """Compute definition of relu"""
-    return topi.nn.leaky_relu(inputs[0], attrs.get_float("alpha"))
-
 reg.register_schedule("leaky_relu", _fschedule_broadcast)
 reg.register_pattern("leaky_relu", OpPattern.ELEMWISE)
 
 # flatten
-@reg.register_compute("flatten")
-def compute_flatten(attrs, inputs, _):
-    """Compute definition of flatten"""
-    return topi.nn.flatten(inputs[0])
-
 reg.register_schedule("flatten", _fschedule_broadcast)
 reg.register_pattern("flatten", OpPattern.INJECTIVE)
 
 # pad
-@reg.register_compute("pad")
-def compute_pad(attrs, inputs, _):
-    """Compute definition of pad"""
-    pad_width = attrs.get_int_pair_tuple('pad_width')
-    assert len(pad_width) == len(inputs[0].shape) and \
-        len(pad_width[0]) == 2, "illegal pad_width"
-    pad_before = [x[0] for x in pad_width]
-    pad_after = [x[1] for x in pad_width]
-    pad_value = attrs.get_int('pad_value')
-    return topi.nn.pad(inputs[0], pad_before, pad_after, pad_value)
-
 reg.register_schedule("pad", _fschedule_broadcast)
 reg.register_pattern("pad", OpPattern.INJECTIVE)
 
 # softmax
-@reg.register_compute("softmax")
-def compute_softmax(attrs, inputs, _):
-    """Compute definition of softmax"""
-    axis = attrs.get_int("axis")
-    assert axis == -1, "only support axis == -1 for now"
-    return topi.nn.softmax(inputs[0])
-
 @reg.register_schedule("softmax")
 def schedule_softmax(_, outs, target):
     """Schedule definition of softmax"""
@@ -73,13 +40,6 @@ reg.register_pattern("softmax", OpPattern.OPAQUE)
 
 # log softmax
-@reg.register_compute("log_softmax")
-def compute_log_softmax(attrs, inputs, _):
-    """Compute definition of softmax"""
-    axis = attrs.get_int("axis")
-    assert axis == -1, "only support axis == -1 for now"
-    return topi.nn.log_softmax(inputs[0])
-
 @reg.register_schedule("log_softmax")
 def schedule_log_softmax(_, outs, target):
     """Schedule definition of softmax"""
@@ -91,13 +51,6 @@ reg.register_pattern("log_softmax", OpPattern.OPAQUE)
 
 # dense
-@reg.register_compute("dense")
-def compute_dense(attrs, inputs, _):
-    """Compute definition of dense"""
-    if attrs.get_bool("use_bias"):
-        return topi.nn.dense(inputs[0], inputs[1], bias=inputs[2])
-    return topi.nn.dense(inputs[0], inputs[1])
-
 @reg.register_schedule("dense")
 def schedule_dense(_, outs, target):
     """Schedule definition of dense"""
@@ -175,18 +128,6 @@ reg.register_pattern("conv2d_transpose", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 # max_pool2d
-@reg.register_compute("max_pool2d")
-def compute_max_pool2d(attrs, inputs, _):
-    """Compute definition of max_pool2d"""
-    pool_size = attrs.get_int_tuple("pool_size")
-    strides = attrs.get_int_tuple("strides")
-    padding = attrs.get_int_tuple("padding")
-    layout = attrs["layout"]
-    ceil_mode = attrs.get_bool("ceil_mode")
-    assert layout == "NCHW", "only support nchw for now"
-    return topi.nn.pool(inputs[0], pool_size, strides, padding,
-                        pool_type='max', ceil_mode=ceil_mode)
-
 @reg.register_schedule("max_pool2d")
 def schedule_max_pool2d(_, outs, target):
     """Schedule definition of max_pool2d"""
@@ -197,18 +138,6 @@ reg.register_pattern("max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 # avg_pool2d
-@reg.register_compute("avg_pool2d")
-def compute_avg_pool2d(attrs, inputs, _):
-    """Compute definition of avg_pool2d"""
-    pool_size = attrs.get_int_tuple("pool_size")
-    strides = attrs.get_int_tuple("strides")
-    padding = attrs.get_int_tuple("padding")
-    layout = attrs["layout"]
-    ceil_mode = attrs.get_bool("ceil_mode")
-    assert layout == "NCHW", "only support nchw for now"
-    return topi.nn.pool(inputs[0], pool_size, strides, padding,
-                        pool_type='avg', ceil_mode=ceil_mode)
-
 @reg.register_schedule("avg_pool2d")
 def schedule_avg_pool2d(_, outs, target):
     """Schedule definition of avg_pool2d"""
@@ -219,13 +148,6 @@ reg.register_pattern("avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 # global_max_pool2d
-@reg.register_compute("global_max_pool2d")
-def compute_global_max_pool2d(attrs, inputs, _):
-    """Compute definition of global_max_pool2d"""
-    layout = attrs["layout"]
-    assert layout == "NCHW", "only support nchw for now"
-    return topi.nn.global_pool(inputs[0], pool_type='max')
-
 @reg.register_schedule("global_max_pool2d")
 def schedule_global_max_pool2d(_, outs, target):
     """Schedule definition of global_max_pool2d"""
@@ -236,13 +158,6 @@ reg.register_pattern("global_max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 # global_avg_pool2d
-@reg.register_compute("global_avg_pool2d")
-def compute_global_avg_pool2d(attrs, inputs, _):
-    """Compute definition of global_avg_pool2d"""
-    layout = attrs["layout"]
-    assert layout == "NCHW", "only support nchw for now"
-    return topi.nn.global_pool(inputs[0], pool_type='avg')
-
 @reg.register_schedule("global_avg_pool2d")
 def schedule_global_avg_pool2d(_, outs, target):
     """Schedule definition of global_avg_pool2d"""
diff --git a/nnvm/python/nnvm/top/reduction.py b/nnvm/python/nnvm/top/reduction.py
index 7003d3a52..193ea6038 100644
--- a/nnvm/python/nnvm/top/reduction.py
+++ b/nnvm/python/nnvm/top/reduction.py
@@ -27,16 +27,13 @@ def _compute_reduce(f):
     return _compute
 
 # sum
-reg.register_compute("sum", _compute_reduce(topi.sum))
 reg.register_pattern("sum", OpPattern.COMM_REDUCE)
 reg.register_schedule("sum", _fschedule_reduce)
 
 # max
-reg.register_compute("max", _compute_reduce(topi.max))
 reg.register_pattern("max", OpPattern.COMM_REDUCE)
 reg.register_schedule("max", _fschedule_reduce)
 
 # min
-reg.register_compute("min", _compute_reduce(topi.min))
 reg.register_pattern("min", OpPattern.COMM_REDUCE)
 reg.register_schedule("min", _fschedule_reduce)
diff --git a/nnvm/python/nnvm/top/tensor.py b/nnvm/python/nnvm/top/tensor.py
index 7ca800e5b..565733775 100644
--- a/nnvm/python/nnvm/top/tensor.py
+++ b/nnvm/python/nnvm/top/tensor.py
@@ -43,132 +43,97 @@ _fschedule_broadcast = _fschedule_injective
 _fschedule_elemwise = _fschedule_injective
 
 # copy
-reg.register_compute("copy", _compute_unary(topi.identity))
 reg.register_pattern("copy", OpPattern.ELEMWISE)
 reg.register_schedule("copy", _fschedule_broadcast)
 
 # exp
-reg.register_compute("exp", _compute_unary(topi.exp))
 reg.register_pattern("exp", OpPattern.ELEMWISE)
 reg.register_schedule("exp", _fschedule_broadcast)
 
 # sqrt
-reg.register_compute("sqrt", _compute_unary(topi.sqrt))
 reg.register_pattern("sqrt", OpPattern.ELEMWISE)
 reg.register_schedule("sqrt", _fschedule_broadcast)
 
 # log
-reg.register_compute("log", _compute_unary(topi.log))
 reg.register_pattern("log", OpPattern.ELEMWISE)
 reg.register_schedule("log", _fschedule_broadcast)
 
 # tanh
-reg.register_compute("tanh", _compute_unary(topi.tanh))
 reg.register_pattern("tanh", OpPattern.ELEMWISE)
 reg.register_schedule("tanh", _fschedule_broadcast)
 
 # negative
-reg.register_compute("negative", _compute_unary(topi.negative))
 reg.register_pattern("negative", OpPattern.ELEMWISE)
 reg.register_schedule("negative", _fschedule_broadcast)
 
 # sigmoid
-reg.register_compute("sigmoid", _compute_unary(topi.sigmoid))
 reg.register_pattern("sigmoid", OpPattern.ELEMWISE)
 reg.register_schedule("sigmoid", _fschedule_broadcast)
 
 # add_scalar
-reg.register_compute("__add_scalar__",
-                     _compute_binary_scalar(lambda x, y: x + y))
 reg.register_pattern("__add_scalar__", OpPattern.ELEMWISE)
 reg.register_schedule("__add_scalar__", _fschedule_broadcast)
 
 # sub_calar
-reg.register_compute("__sub_scalar__",
-                     _compute_binary_scalar(lambda x, y: x - y))
 reg.register_pattern("__sub_scalar__", OpPattern.ELEMWISE)
 reg.register_schedule("__sub_scalar__", _fschedule_broadcast)
 
 # rsub_scalar
-reg.register_compute("__rsub_scalar__",
-                     _compute_binary_scalar(lambda x, y: y - x))
 reg.register_pattern("__rsub_scalar__", OpPattern.ELEMWISE)
 reg.register_schedule("__rsub_scalar__", _fschedule_broadcast)
 
 # mul_scalar
-reg.register_compute("__mul_scalar__",
-                     _compute_binary_scalar(lambda x, y: x * y))
 reg.register_pattern("__mul_scalar__", OpPattern.ELEMWISE)
 reg.register_schedule("__mul_scalar__", _fschedule_broadcast)
 
 # div_scalar
-reg.register_compute("__div_scalar__",
-                     _compute_binary_scalar(lambda x, y: x / y))
 reg.register_pattern("__div_scalar__", OpPattern.ELEMWISE)
 reg.register_schedule("__div_scalar__", _fschedule_broadcast)
 
 # rdiv_scalar
-reg.register_compute("__rdiv_scalar__",
-                     _compute_binary_scalar(lambda x, y: y / x))
 reg.register_pattern("__rdiv_scalar__", OpPattern.ELEMWISE)
 reg.register_schedule("__rdiv_scalar__", _fschedule_broadcast)
 
 # pow_scalar
-reg.register_compute("__pow_scalar__",
-                     _compute_binary_scalar(tvm.power))
 reg.register_pattern("__pow_scalar__", OpPattern.ELEMWISE)
 reg.register_schedule("__pow_scalar__", _fschedule_broadcast)
 
 # rpow_scalar
-reg.register_compute("__rpow_scalar__",
-                     _compute_binary_scalar(lambda x, y: tvm.power(y, x)))
 reg.register_pattern("__rpow_scalar__", OpPattern.ELEMWISE)
 reg.register_schedule("__rpow_scalar__", _fschedule_broadcast)
 
 # elemwise_add
-reg.register_compute("elemwise_add", _compute_binary(topi.broadcast_add))
 reg.register_pattern("elemwise_add", OpPattern.BROADCAST)
 reg.register_schedule("elemwise_add", _fschedule_broadcast)
 
 # elemwise_sub
-reg.register_compute("elemwise_sub", _compute_binary(topi.broadcast_sub))
 reg.register_pattern("elemwise_sub", OpPattern.BROADCAST)
 reg.register_schedule("elemwise_sub", _fschedule_broadcast)
 
 # elemwise_mul
-reg.register_compute("elemwise_mul", _compute_binary(topi.broadcast_mul))
 reg.register_pattern("elemwise_mul", OpPattern.BROADCAST)
 reg.register_schedule("elemwise_mul", _fschedule_broadcast)
 
 # elemwise_div
-reg.register_compute("elemwise_div", _compute_binary(topi.broadcast_div))
 reg.register_pattern("elemwise_div", OpPattern.BROADCAST)
 reg.register_schedule("elemwise_div", _fschedule_broadcast)
 
 # broadcast_add
-reg.register_compute("broadcast_add", _compute_binary(topi.broadcast_add))
 reg.register_pattern("broadcast_add", OpPattern.BROADCAST)
 reg.register_schedule("broadcast_add", _fschedule_broadcast)
 
 # broadcast_sub
-reg.register_compute("broadcast_sub", _compute_binary(topi.broadcast_sub))
 reg.register_pattern("broadcast_sub", OpPattern.BROADCAST)
 reg.register_schedule("broadcast_sub", _fschedule_broadcast)
 
 # broadcast_mul
-reg.register_compute("broadcast_mul", _compute_binary(topi.broadcast_mul))
 reg.register_pattern("broadcast_mul", OpPattern.BROADCAST)
 reg.register_schedule("broadcast_mul", _fschedule_broadcast)
 
 # broadcast_div
-reg.register_compute("broadcast_div", _compute_binary(topi.broadcast_div))
 reg.register_pattern("broadcast_div", OpPattern.BROADCAST)
 reg.register_schedule("broadcast_div", _fschedule_broadcast)
 
 # broadcast_to
-@reg.register_compute("broadcast_to")
-def compute_broadcast_to(attrs, inputs, out_info):
-    """Compute definition of softmax"""
-    return topi.broadcast_to(inputs[0], shape=out_info[0].shape)
 reg.register_pattern("broadcast_to", OpPattern.BROADCAST)
 reg.register_schedule("broadcast_to", _fschedule_broadcast)
diff --git a/nnvm/python/nnvm/top/transform.py b/nnvm/python/nnvm/top/transform.py
index 21c10f5cc..ec1596e95 100644
--- a/nnvm/python/nnvm/top/transform.py
+++ b/nnvm/python/nnvm/top/transform.py
@@ -2,71 +2,30 @@
 """Tensor transformation ops"""
 from __future__ import absolute_import
 
-import topi
 from .tensor import _fschedule_broadcast, _fschedule_injective
 from . import registry as reg
 from .registry import OpPattern
 
 # expand_dims
-@reg.register_compute("expand_dims")
-def compute_expand_dims(attrs, inputs, out_info):
-    """Compute definition of expand_dims"""
-    return topi.expand_dims(
-        inputs[0], attrs.get_int("axis"),
-        num_newaxis=attrs.get_int("num_newaxis"))
 reg.register_pattern("expand_dims", OpPattern.BROADCAST)
 reg.register_schedule("expand_dims", _fschedule_broadcast)
 
 # transpose
-@reg.register_compute("transpose")
-def compute_transpose(attrs, inputs, out_info):
-    """Compute definition of transpose"""
-    axes = attrs.get_int_tuple("axes")
-    axes = tuple(axes) if axes else None
-    return topi.transpose(inputs[0], axes)
 reg.register_pattern("transpose", OpPattern.INJECTIVE)
 reg.register_schedule("transpose", _fschedule_injective)
 
 # reshape
-@reg.register_compute("reshape")
-def compute_reshape(attrs, inputs, out_info):
-    """Compute definition of reshape"""
-    oshape = out_info[0].shape
-    return topi.reshape(inputs[0], oshape)
 reg.register_pattern("reshape", OpPattern.INJECTIVE)
 reg.register_schedule("reshape", _fschedule_injective)
 
-# reshape
-@reg.register_compute("squeeze")
-def compute_squeeze(attrs, inputs, out_info):
-    """Compute definition of reshape"""
-    axis = attrs.get_int_tuple("axis")
-    axis = tuple(axis) if axis else None
-    return topi.squeeze(inputs[0], axis)
+# squeeze
 reg.register_pattern("squeeze", OpPattern.INJECTIVE)
 reg.register_schedule("squeeze", _fschedule_injective)
 
 # concatenate
-@reg.register_compute("concatenate")
-def compute_concatenate(attrs, inputs, out_info):
-    """Compute definition of concatenate"""
-    axis = attrs.get_int("axis")
-    return topi.concatenate([x for x in inputs], axis=axis)
-
 reg.register_pattern("concatenate", OpPattern.INJECTIVE)
 reg.register_schedule("concatenate", _fschedule_injective)
 
 # split
-@reg.register_compute("split")
-def compute_split(attrs, inputs, out_info):
-    """Compute definition of split"""
-    x = attrs["indices_or_sections"]
-    if x.startswith("(") or x.startswith("["):
-        indices_or_sections = attrs.get_int_tuple("indices_or_sections")
-    else:
-        indices_or_sections = attrs.get_int("indices_or_sections")
-    return topi.split(inputs[0], indices_or_sections, axis=attrs.get_int("axis"))
-
-
 reg.register_pattern("split", OpPattern.INJECTIVE)
 reg.register_schedule("split", _fschedule_injective)
diff --git a/nnvm/src/compiler/compile_engine.cc b/nnvm/src/compiler/compile_engine.cc
index 5902abdc6..23c4d4c9b 100644
--- a/nnvm/src/compiler/compile_engine.cc
+++ b/nnvm/src/compiler/compile_engine.cc
@@ -344,7 +344,7 @@ TVM_REGISTER_GLOBAL("nnvm.compiler.CacheItem2ScheduleArgs")
     *rv = ret;
   });
 
-TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<GraphFuncNode>([](const GraphFuncNode *op, IRPrinter *p) {
     p->stream << "GraphFunc(name=" << op->func_name
               << ", addr=" << op << ")";
diff --git a/nnvm/src/compiler/graph_hash.cc b/nnvm/src/compiler/graph_hash.cc
index f493bc042..d881130f7 100644
--- a/nnvm/src/compiler/graph_hash.cc
+++ b/nnvm/src/compiler/graph_hash.cc
@@ -80,7 +80,7 @@ GraphKey GraphKeyNode::make(Graph graph,
   return GraphKey(n);
 }
 
-TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<GraphKeyNode>([](const GraphKeyNode *op, IRPrinter *p) {
     p->stream << "GraphKeyNode("<< op << ")";
   });
diff --git a/nnvm/src/top/nn/nn.cc b/nnvm/src/top/nn/nn.cc
index bf169acc7..88ed7fd58 100644
--- a/nnvm/src/top/nn/nn.cc
+++ b/nnvm/src/top/nn/nn.cc
@@ -3,17 +3,27 @@
 * \file nn.cc
 * \brief Property def of nn operators.
 */
+#include <tvm/expr.h>
+#include <tvm/packed_func_ext.h>
 #include <nnvm/op.h>
 #include <nnvm/node.h>
 #include <nnvm/op_attr_types.h>
+#include <nnvm/compiler/op_attr_types.h>
 #include <nnvm/top/nn.h>
 #include "./nn_common.h"
 #include "../op_common.h"
 #include "../elemwise_op_common.h"
+#include "topi/nn/dense.h"
+#include "topi/nn.h"
+#include "topi/nn/softmax.h"
 
 namespace nnvm {
 namespace top {
 
+using tvm::Tensor;
+using tvm::Array;
+using nnvm::compiler::FTVMCompute;
+
 // dense
 DMLC_REGISTER_PARAMETER(DenseParam);
@@ -72,6 +82,21 @@ If ``use_bias`` is set to be false, then the ``bias`` term is ignored.
 .set_attr<FListInputNames>("FListInputNames", UseBiasListInputNames<DenseParam>)
 .set_attr<FInferShape>("FInferShape", DenseInferShape)
 .set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    Tensor bias_val;
+    Tensor* bias;
+    const DenseParam& param = nnvm::get<DenseParam>(attrs.parsed);
+    if (param.use_bias) {
+      bias_val = inputs[2];
+      bias = &bias_val;
+    } else {
+      bias = nullptr;
+    }
+    return Array<Tensor>{ topi::nn::dense(inputs[0], inputs[1], bias) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds) {
@@ -110,6 +135,12 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(relu)
    max(input, 0)
 
 )code" NNVM_ADD_FILELINE)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    return Array<Tensor>{ topi::relu(inputs[0], 0.0f) };
+  })
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds) {
@@ -258,6 +289,14 @@ NNVM_REGISTER_OP(softmax)
 .set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>)
 .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
 .set_support_level(1)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
+    CHECK_EQ(param.axis, -1) << "Currently only axis=-1 is supported";
+    return Array<Tensor>{ topi::nn::softmax(inputs[0]) };
+  })
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds) {
@@ -306,6 +345,14 @@ NNVM_REGISTER_OP(log_softmax)
 .set_num_outputs(1)
 .set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>)
 .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
+    CHECK_EQ(param.axis, -1) << "Currently only axis=-1 is supported";
+    return Array<Tensor>{ topi::nn::log_softmax(inputs[0]) };
+  })
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds) {
@@ -357,6 +404,13 @@ NNVM_REGISTER_OP(leaky_relu)
 .set_num_outputs(1)
 .set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>)
 .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const LeakyReLUParam& param = nnvm::get<LeakyReLUParam>(attrs.parsed);
+    return Array<Tensor>{ topi::leaky_relu<float>(inputs[0], 0.0, param.alpha) };
+  })
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds) {
@@ -413,6 +467,25 @@ NNVM_REGISTER_OP(pad)
 .set_num_inputs(1)
 .set_attr<FInferShape>("FInferShape", PadInferShape)
 .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const PadParam& param = nnvm::get<PadParam>(attrs.parsed);
+    auto pad_width = param.pad_width;
+    CHECK(pad_width.ndim() == inputs[0]->shape.size() &&
+          pad_width[0].ndim() == 2)
+      << "Illegal pad_width";
+    Array<tvm::Expr> pad_before;
+    for (size_t i = 0; i < pad_width.ndim(); ++i) {
+      pad_before.push_back(tvm::make_const(tvm::Int(32), pad_width[i][0]));
+    }
+    Array<tvm::Expr> pad_after;
+    for (size_t i = 0; i < pad_width.ndim(); ++i) {
+      pad_after.push_back(tvm::make_const(tvm::Int(32), pad_width[i][1]));
+    }
+    return Array<Tensor>{ topi::pad(inputs[0], pad_before, pad_after, param.pad_value) };
+})
 .set_support_level(1);
 
 } // namespace top
diff --git a/nnvm/src/top/nn/pooling.cc b/nnvm/src/top/nn/pooling.cc
index 0f0eb6817..ae8d872de 100644
--- a/nnvm/src/top/nn/pooling.cc
+++ b/nnvm/src/top/nn/pooling.cc
@@ -6,13 +6,18 @@
 #include <nnvm/op.h>
 #include <nnvm/node.h>
 #include <nnvm/op_attr_types.h>
+#include <nnvm/compiler/op_attr_types.h>
+#include <nnvm/compiler/util.h>
 #include <nnvm/top/nn.h>
 #include "./nn_common.h"
 #include "../op_common.h"
 #include "../elemwise_op_common.h"
+#include "topi/nn/pooling.h"
 
 namespace nnvm {
 namespace top {
+using namespace tvm;
+using namespace nnvm::compiler;
 
 DMLC_REGISTER_PARAMETER(Pool2DParam);
@@ -77,6 +82,20 @@ NNVM_REGISTER_OP(max_pool2d)
 .set_num_inputs(1)
 .set_attr<FInferShape>("FInferShape", Pool2DInferShape)
 .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const Pool2DParam& param = nnvm::get<Pool2DParam>(attrs.parsed);
+    auto pool_size = ShapeToArray(param.pool_size);
+    auto strides = ShapeToArray(param.strides);
+    auto padding = ShapeToArray(param.padding);
+    auto ceil_mode = param.ceil_mode;
+    CHECK_EQ(param.layout, kNCHW)
+      << "max_pool2d currently only supports NCHW layout";
+    return Array<Tensor>{
+      topi::nn::pool(inputs[0], pool_size, strides, padding, topi::nn::kMaxPool, ceil_mode) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds) {
@@ -124,6 +143,20 @@ NNVM_REGISTER_OP(avg_pool2d)
 .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<Pool2DParam>)
 .set_attr<FInferShape>("FInferShape", Pool2DInferShape)
 .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const Pool2DParam& param = nnvm::get<Pool2DParam>(attrs.parsed);
+    auto pool_size = ShapeToArray(param.pool_size);
+    auto strides = ShapeToArray(param.strides);
+    auto padding = ShapeToArray(param.padding);
+    auto ceil_mode = param.ceil_mode;
+    CHECK_EQ(param.layout, kNCHW)
+      << "avg_pool2d currently only supports NCHW layout";
+    return Array<Tensor>{
+      topi::nn::pool(inputs[0], pool_size, strides, padding, topi::nn::kAvgPool, ceil_mode) };
+})
 .set_num_outputs(1)
 .set_num_inputs(1)
 .set_support_level(2);
@@ -162,6 +195,16 @@ NNVM_REGISTER_OP(global_max_pool2d)
 .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<GlobalPool2DParam>)
 .set_attr<FInferShape>("FInferShape", GlobalPool2DInferShape)
 .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const GlobalPool2DParam& param = nnvm::get<GlobalPool2DParam>(attrs.parsed);
+    CHECK_EQ(param.layout, kNCHW)
+      << "global_max_pool2d currently only supports NCHW layout";
+    return Array<Tensor>{
+      topi::nn::global_pool(inputs[0], topi::nn::kMaxPool) };
+})
 .set_num_outputs(1)
 .set_num_inputs(1)
 .set_support_level(2);
@@ -182,6 +225,16 @@ NNVM_REGISTER_OP(global_avg_pool2d)
 .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<GlobalPool2DParam>)
 .set_attr<FInferShape>("FInferShape", GlobalPool2DInferShape)
 .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const GlobalPool2DParam& param = nnvm::get<GlobalPool2DParam>(attrs.parsed);
+    CHECK_EQ(param.layout, kNCHW)
+      << "global_avg_pool2d currently only supports NCHW layout";
+    return Array<Tensor>{
+      topi::nn::global_pool(inputs[0], topi::nn::kAvgPool) };
+})
 .set_num_outputs(1)
 .set_num_inputs(1)
 .set_support_level(2);
diff --git a/nnvm/src/top/tensor/broadcast.cc b/nnvm/src/top/tensor/broadcast.cc
index 945195fd0..773281450 100644
--- a/nnvm/src/top/tensor/broadcast.cc
+++ b/nnvm/src/top/tensor/broadcast.cc
@@ -3,15 +3,22 @@
 * \file broadcast.cc
 * \brief broadcast operator.
 */
+#include <tvm/expr.h>
+#include <tvm/packed_func_ext.h>
 #include <nnvm/op.h>
 #include <nnvm/node.h>
 #include <nnvm/op_attr_types.h>
+#include <nnvm/compiler/op_attr_types.h>
+#include <nnvm/compiler/util.h>
 #include <nnvm/top/tensor.h>
 #include "../op_common.h"
 #include "../elemwise_op_common.h"
+#include "topi/broadcast.h"
 
 namespace nnvm {
 namespace top {
+using namespace tvm;
+using namespace nnvm::compiler;
 
 // broadcast_to
 DMLC_REGISTER_PARAMETER(BroadcastToParam);
@@ -67,6 +74,14 @@ So with `shape=(2,0)`, we will obtain the same result as in the above example.
 .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<BroadcastToParam>)
 .set_attr<FInferShape>("FInferShape", BroadcastToInferShape)
 .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const BroadcastToParam& param = nnvm::get<BroadcastToParam>(attrs.parsed);
+    auto shape = ShapeToArray(param.shape);
+    return Array<Tensor>{ topi::broadcast_to(inputs[0], shape) };
+  })
 .set_num_inputs(1)
 .set_num_outputs(1)
 .set_support_level(4);
@@ -122,6 +137,13 @@ inline bool BinaryBroadcastShape(const nnvm::NodeAttrs& attrs,
     [](const NodeAttrs& attrs) {                                \
       return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}}; \
     })                                                          \
+  .set_attr<FTVMCompute>(                                       \
+    "FTVMCompute", [](const NodeAttrs& attrs,                   \
+                      const Array<Tensor>& inputs,              \
+                      const Array<Tensor>& out_info) {          \
+      return Array<Tensor>{                                     \
+        topi::name(inputs[0], inputs[1]) };                     \
+    })                                                          \
  .add_argument("lhs", "Tensor", "first input")                  \
  .add_argument("rhs", "Tensor", "second input")
diff --git a/nnvm/src/top/tensor/elemwise.cc b/nnvm/src/top/tensor/elemwise.cc
index 73a958322..ac66e0c3f 100644
--- a/nnvm/src/top/tensor/elemwise.cc
+++ b/nnvm/src/top/tensor/elemwise.cc
@@ -6,13 +6,19 @@
 #include <nnvm/op.h>
 #include <nnvm/node.h>
 #include <nnvm/op_attr_types.h>
+#include <nnvm/compiler/op_attr_types.h>
 #include <nnvm/top/tensor.h>
 #include <cmath>
 #include "../op_common.h"
 #include "../elemwise_op_common.h"
+#include "topi/broadcast.h"
+#include "topi/elemwise.h"
+#include "topi/tags.h"
 
 namespace nnvm {
 namespace top {
+using namespace tvm;
+using namespace nnvm::compiler;
 // undefined op
 NNVM_REGISTER_ELEMWISE_UNARY_OP(__undef__)
 .describe(R"code(undefined op.
@@ -32,6 +38,12 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(sigmoid)
 
 )code" NNVM_ADD_FILELINE)
 .set_support_level(1)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    return Array<Tensor>{ topi::sigmoid(inputs[0]) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds) {
@@ -56,6 +68,12 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(tanh)
 
 )code" NNVM_ADD_FILELINE)
 .set_support_level(1)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    return Array<Tensor>{ topi::tanh(inputs[0]) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds) {
@@ -80,6 +98,12 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(exp)
 
 )code" NNVM_ADD_FILELINE)
 .set_support_level(1)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    return Array<Tensor>{ topi::exp(inputs[0]) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds) {
@@ -100,6 +124,12 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(log)
 
 )code" NNVM_ADD_FILELINE)
 .set_support_level(1)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    return Array<Tensor>{ topi::log(inputs[0]) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds) {
@@ -120,6 +150,12 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(sqrt)
 
 )code" NNVM_ADD_FILELINE)
 .set_support_level(1)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    return Array<Tensor>{ topi::sqrt(inputs[0]) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds) {
@@ -140,6 +176,12 @@ NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_add)
 
 )code")
 .set_support_level(1)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    return Array<Tensor>{ topi::broadcast_add(inputs[0], inputs[1]) };
+  })
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds){
@@ -154,6 +196,12 @@ NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_sub)
 
 )code" NNVM_ADD_FILELINE)
 .set_support_level(1)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    return Array<Tensor>{ topi::broadcast_sub(inputs[0], inputs[1]) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds){
@@ -171,6 +219,12 @@ NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_mul)
 
 )code" NNVM_ADD_FILELINE)
 .set_support_level(1)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    return Array<Tensor>{ topi::broadcast_mul(inputs[0], inputs[1]) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds){
@@ -190,6 +244,12 @@ NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_div)
 
 )code" NNVM_ADD_FILELINE)
 .set_support_level(1)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    return Array<Tensor>{ topi::broadcast_div(inputs[0], inputs[1]) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds){
@@ -216,6 +276,12 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(negative)
 
 )code" NNVM_ADD_FILELINE)
 .set_support_level(3)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    return Array<Tensor>{ topi::negative(inputs[0]) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds){
@@ -232,6 +298,12 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(copy)
 
 )code" NNVM_ADD_FILELINE)
 .set_support_level(3)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    return Array<Tensor>{ topi::identity(inputs[0]) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds){
@@ -315,12 +387,29 @@ DMLC_REGISTER_PARAMETER(ScalarParam);
   .set_attr_parser(ParamParser<ScalarParam>)                    \
   .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ScalarParam>)
 
+inline Tensor binary_scalar_op(const NodeAttrs& attrs,
+                               const Tensor& x,
+                               std::function<Expr(Expr, Expr)> f) {
+  const ScalarParam& param = nnvm::get<ScalarParam>(attrs.parsed);
+  auto scalar_val = static_cast<float>(param.scalar);
+  return compute(x->shape, [&](const Array<Var>& i) {
+    auto scalar_const = make_const(x->dtype, scalar_val);
+    return f(x(i), scalar_const);
+  }, "tensor", topi::kElementWise);
+}
 
 NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__add_scalar__)
 .describe(R"code(Tensor add scalar
 
 )code" NNVM_ADD_FILELINE)
 .set_support_level(3)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    return Array<Tensor>{ binary_scalar_op(attrs, inputs[0],
+      [](Expr x, Expr y) { return x + y; }) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds){
@@ -332,6 +421,13 @@ NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__sub_scalar__)
 
 )code" NNVM_ADD_FILELINE)
 .set_support_level(3)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    return Array<Tensor>{ binary_scalar_op(attrs, inputs[0],
+      [](Expr x, Expr y) { return x - y; }) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds){
@@ -343,6 +439,13 @@ NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__rsub_scalar__)
 
 )code" NNVM_ADD_FILELINE)
 .set_support_level(3)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    return Array<Tensor>{ binary_scalar_op(attrs, inputs[0],
+      [](Expr x, Expr y) { return y - x; }) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds){
@@ -356,6 +459,13 @@ NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__mul_scalar__)
 
 )code" NNVM_ADD_FILELINE)
 .set_support_level(3)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    return Array<Tensor>{ binary_scalar_op(attrs, inputs[0],
+      [](Expr x, Expr y) { return x * y; }) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds){
@@ -372,6 +482,13 @@ NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__div_scalar__)
 
 )code" NNVM_ADD_FILELINE)
 .set_support_level(3)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    return Array<Tensor>{ binary_scalar_op(attrs, inputs[0],
+      [](Expr x, Expr y) { return x / y; }) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds){
@@ -388,6 +505,13 @@ NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__rdiv_scalar__)
 
 )code" NNVM_ADD_FILELINE)
 .set_support_level(3)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    return Array<Tensor>{ binary_scalar_op(attrs, inputs[0],
+      [](Expr x, Expr y) { return y / x; }) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds){
@@ -411,6 +535,13 @@ NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__pow_scalar__)
 
 )code" NNVM_ADD_FILELINE)
 .set_support_level(3)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    return Array<Tensor>{ binary_scalar_op(attrs, inputs[0],
+      [](Expr x, Expr y) { return tvm::pow(x, y); }) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds){
@@ -434,6 +565,13 @@ NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__rpow_scalar__)
 
 )code" NNVM_ADD_FILELINE)
 .set_support_level(3)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    return Array<Tensor>{ binary_scalar_op(attrs, inputs[0],
+      [](Expr x, Expr y) { return tvm::pow(y, x); }) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds){
diff --git a/nnvm/src/top/tensor/reduce.cc b/nnvm/src/top/tensor/reduce.cc
index 429800fe5..8eac2449b 100644
--- a/nnvm/src/top/tensor/reduce.cc
+++ b/nnvm/src/top/tensor/reduce.cc
@@ -6,12 +6,17 @@
 #include <nnvm/op.h>
 #include <nnvm/node.h>
 #include <nnvm/op_attr_types.h>
+#include <nnvm/compiler/op_attr_types.h>
+#include <nnvm/compiler/util.h>
 #include <nnvm/top/tensor.h>
 #include "../op_common.h"
 #include "../elemwise_op_common.h"
+#include "topi/reduction.h"
 
 namespace nnvm {
 namespace top {
+using namespace tvm;
+using namespace nnvm::compiler;
 
 // reduce
 DMLC_REGISTER_PARAMETER(ReduceParam);
@@ -127,6 +132,15 @@ Example::
   [ 12.  19.  27.]
 
 )code" NNVM_ADD_FILELINE)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
+    auto axis = ShapeToArray(param.axis);
+    return Array<Tensor>{
+      topi::sum(inputs[0], axis, param.keepdims) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds){
@@ -145,6 +159,15 @@ NNVM_REGISTER_REDUCE_OP(max)
 .describe(R"code(Computes the max of array elements over given axes.
 
 )code" NNVM_ADD_FILELINE)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
+    auto axis = ShapeToArray(param.axis);
+    return Array<Tensor>{
+      topi::max(inputs[0], axis, param.keepdims) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds){
@@ -168,6 +191,15 @@ NNVM_REGISTER_REDUCE_OP(min)
 .describe(R"code(Computes the min of array elements over given axes.
 
 )code" NNVM_ADD_FILELINE)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
+    auto axis = ShapeToArray(param.axis);
+    return Array<Tensor>{
+      topi::min(inputs[0], axis, param.keepdims) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds){
diff --git a/nnvm/src/top/tensor/transform.cc b/nnvm/src/top/tensor/transform.cc
index 90457389c..0bf1a91ec 100644
--- a/nnvm/src/top/tensor/transform.cc
+++ b/nnvm/src/top/tensor/transform.cc
@@ -6,13 +6,19 @@
 #include <nnvm/op.h>
 #include <nnvm/node.h>
 #include <nnvm/op_attr_types.h>
+#include <nnvm/compiler/op_attr_types.h>
+#include <nnvm/compiler/util.h>
 #include <nnvm/top/tensor.h>
 #include <cctype>
 #include "../op_common.h"
 #include "../elemwise_op_common.h"
+#include "topi/nn/flatten.h"
+#include "topi/transform.h"
 
 namespace nnvm {
 namespace top {
+using namespace tvm;
+using namespace nnvm::compiler;
 
 // flatten
 inline bool FlattenInferShape(const NodeAttrs& attrs,
@@ -58,6 +64,12 @@ Example::
 .set_attr<FInferShape>("FInferShape", FlattenInferShape)
 .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
 .add_argument("data", "Tensor", "Input data.")
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    return Array<Tensor>{ topi::nn::flatten(inputs[0]) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds){
@@ -144,6 +156,13 @@ Example::
 .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ConcatenateParam>)
 .set_attr<FInferShape>("FInferShape", ConcatenateInferShape)
 .set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const ConcatenateParam& param = nnvm::get<ConcatenateParam>(attrs.parsed);
+    return Array<Tensor>{ topi::concatenate(inputs, param.axis) };
+})
 .set_num_outputs(1)
 .set_num_inputs(kVarg)
 .set_support_level(1);
@@ -190,6 +209,13 @@ will return a new array with shape ``(2,5,3,4)``.
 .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
 .set_num_inputs(1)
 .set_num_outputs(1)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const ExpandDimsParam& param = nnvm::get<ExpandDimsParam>(attrs.parsed);
+    return Array<Tensor>{ topi::expand_dims(inputs[0], param.axis, param.num_newaxis) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds){
@@ -326,6 +352,22 @@ along which to split the array.
 .set_attr<FInferType>("FInferType", ElemwiseType<1, -1>)
 .set_num_inputs(1)
 .set_num_outputs(SplitNumOutputs)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const SplitParam& param = nnvm::get<SplitParam>(attrs.parsed);
+    if (param.equal_split) {
+      return Array<Tensor>{
+        topi::split_sections(inputs[0], param.indices_or_sections[0], param.axis) };
+    } else {
+      Array<Expr> indices;
+      for (auto i : param.indices_or_sections) {
+        indices.push_back(tvm::make_const(tvm::Int(32), i));
+      }
+      return Array<Tensor>{ topi::split(inputs[0], indices, param.axis) };
+    }
+})
 .set_support_level(1);
 
 // cast
@@ -504,6 +546,12 @@ The significance of each is explained below:
 .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
 .set_num_inputs(1)
 .set_num_outputs(1)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    return Array<Tensor>{ topi::reshape(inputs[0], out_info[0]->shape) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds) {
@@ -620,6 +668,14 @@ Examples::
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
 .set_num_inputs(1)
 .set_num_outputs(1)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const SqueezeParam& param = nnvm::get<SqueezeParam>(attrs.parsed);
+    auto axis = ShapeToArray(param.axis);
+    return Array<Tensor>{ topi::squeeze(inputs[0], axis) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds) {
@@ -695,6 +751,14 @@ Examples::
 .set_num_inputs(1)
 .set_num_outputs(1)
 .set_support_level(4)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const TransposeParam& param = nnvm::get<TransposeParam>(attrs.parsed);
+    auto axes = ShapeToArray(param.axes);
+    return Array<Tensor>{ topi::transpose(inputs[0], axes) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds) {
-- 
GitLab
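
For readers adapting this change out of tree: every FTVMCompute hunk above follows the same recipe. Parse the op's parameter struct from attrs.parsed, convert any TShape fields with the new ShapeToArray helper from util.h, and delegate the compute definition to the corresponding TOPI C++ function. Below is a minimal sketch of that hook written as a free function, assuming the same 2018-era NNVM/TVM headers the patch includes; the free-function form and the name SumCompute are illustrative, not part of the patch.

#include <nnvm/node.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/compiler/op_attr_types.h>
#include <nnvm/compiler/util.h>
#include <nnvm/top/tensor.h>
#include "topi/reduction.h"

namespace nnvm {
namespace top {

// Mirrors the lambda registered for "sum" above: read the parsed
// ReduceParam off the node, turn its TShape axis into a constant-int
// Array<Expr> via ShapeToArray, and let TOPI's C++ reduction build
// the compute.
inline tvm::Array<tvm::Tensor> SumCompute(
    const NodeAttrs& attrs,
    const tvm::Array<tvm::Tensor>& inputs,
    const tvm::Array<tvm::Tensor>& out_info) {
  const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
  tvm::Array<tvm::Expr> axis = nnvm::compiler::ShapeToArray(param.axis);
  return tvm::Array<tvm::Tensor>{ topi::sum(inputs[0], axis, param.keepdims) };
}

} // namespace top
} // namespace nnvm

Attaching it is the same call the hunks above make inline with a lambda: .set_attr<FTVMCompute>("FTVMCompute", SumCompute) on the op's NNVM_REGISTER_OP block.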