From 5884cd01f3ba67ff1323d6e3480f9f67f61d6a4e Mon Sep 17 00:00:00 2001
From: alex-weaver <awsweaver@gmail.com>
Date: Tue, 6 Feb 2018 17:39:59 +0000
Subject: [PATCH] Change TOPI ops to use C++ implementation where applicable
 (#357)

* Updated TVM version. Implemented fix for nnvm_compiler crash on exit on windows. Changed TOPI ops from using python to using C++ where applicable.

* Fix lint

* Fix lint

* Fix macro

* Fix reshape

* Update TVM to fix test fails
---
 nnvm/CMakeLists.txt                 |   1 +
 nnvm/Makefile                       |   2 +-
 nnvm/include/nnvm/compiler/util.h   |  33 +++++++
 nnvm/include/nnvm/symbolic.h        |   2 +-
 nnvm/python/nnvm/top/nn.py          |  85 -----------------
 nnvm/python/nnvm/top/reduction.py   |   3 -
 nnvm/python/nnvm/top/tensor.py      |  35 -------
 nnvm/python/nnvm/top/transform.py   |  43 +--------
 nnvm/src/compiler/compile_engine.cc |   2 +-
 nnvm/src/compiler/graph_hash.cc     |   2 +-
 nnvm/src/top/nn/nn.cc               |  73 +++++++++++++++
 nnvm/src/top/nn/pooling.cc          |  53 +++++++++++
 nnvm/src/top/tensor/broadcast.cc    |  22 +++++
 nnvm/src/top/tensor/elemwise.cc     | 138 ++++++++++++++++++++++++++++
 nnvm/src/top/tensor/reduce.cc       |  32 +++++++
 nnvm/src/top/tensor/transform.cc    |  64 +++++++++++++
 16 files changed, 421 insertions(+), 169 deletions(-)
 create mode 100644 nnvm/include/nnvm/compiler/util.h

diff --git a/nnvm/CMakeLists.txt b/nnvm/CMakeLists.txt
index a37bc5f1f..3747be06a 100644
--- a/nnvm/CMakeLists.txt
+++ b/nnvm/CMakeLists.txt
@@ -22,6 +22,7 @@ include_directories(BEFORE "include")
 include_directories("tvm/include")
 include_directories("tvm/dlpack/include")
 include_directories("tvm/HalideIR/src")
+include_directories("tvm/topi/include")
 
 set(NNVM_LINKER_LIBS "")
 set(NNVM_COMPILER_LINKER_LIBS "")
diff --git a/nnvm/Makefile b/nnvm/Makefile
index 3f86eac40..4779e95b3 100644
--- a/nnvm/Makefile
+++ b/nnvm/Makefile
@@ -11,7 +11,7 @@ include $(config)
 
 export LDFLAGS = -pthread -lm
 export CFLAGS = -std=c++11 -Wall -O2 -Iinclude -fPIC
-CFLAGS += -Itvm/include -Itvm/dlpack/include -Itvm/HalideIR/src
+CFLAGS += -Itvm/include -Itvm/dlpack/include -Itvm/HalideIR/src -Itvm/topi/include
 
 ifdef DMLC_CORE_PATH
   CFLAGS += -I$(DMLC_CORE_PATH)/include
diff --git a/nnvm/include/nnvm/compiler/util.h b/nnvm/include/nnvm/compiler/util.h
new file mode 100644
index 000000000..04095c783
--- /dev/null
+++ b/nnvm/include/nnvm/compiler/util.h
@@ -0,0 +1,33 @@
+/*!
+*  Copyright (c) 2016 by Contributors
+* \file util.h
+* \brief Utility functions for nnvm compiler
+*/
+#ifndef NNVM_COMPILER_UTIL_H_
+#define NNVM_COMPILER_UTIL_H_
+
+#include <tvm/expr.h>
+#include <nnvm/tuple.h>
+
+namespace nnvm {
+namespace compiler {
+
+/*
+ * \brief Helper function to convert TShape to TVM array. Useful for
+ * passing data from NNVM param structures to TOPI ops.
+ *
+ * \param shape The shape to convert
+ *
+ * \return An Array of Expr, where each element is a constant int32
+ */
+inline tvm::Array<tvm::Expr> ShapeToArray(TShape shape) {
+  tvm::Array<tvm::Expr> result;
+  for (auto i : shape) {
+    result.push_back(tvm::make_const(tvm::Int(32), i));
+  }
+  return result;
+}
+
+}  // namespace compiler
+}  // namespace nnvm
+#endif  // NNVM_COMPILER_UTIL_H_
diff --git a/nnvm/include/nnvm/symbolic.h b/nnvm/include/nnvm/symbolic.h
index 02b1dd947..f290a1683 100644
--- a/nnvm/include/nnvm/symbolic.h
+++ b/nnvm/include/nnvm/symbolic.h
@@ -28,7 +28,7 @@ namespace nnvm {
  *  symbol is the final operation of a graph and thus including all the information
  *  required (the graph) to evaluate its output value.
  */
-class Symbol {
+class NNVM_DLL Symbol {
  public:
   /*! \brief option passed to ListAttr */
   enum ListAttrOption {
diff --git a/nnvm/python/nnvm/top/nn.py b/nnvm/python/nnvm/top/nn.py
index 88aafec62..f91f77dfd 100644
--- a/nnvm/python/nnvm/top/nn.py
+++ b/nnvm/python/nnvm/top/nn.py
@@ -10,59 +10,26 @@ from . import registry as reg
 from .registry import OpPattern
 
 # relu
-@reg.register_compute("relu")
-def compute_relu(attrs, inputs, _):
-    """Compute definition of relu"""
-    return topi.nn.relu(inputs[0])
-
 reg.register_schedule("relu", _fschedule_broadcast)
 reg.register_pattern("relu", OpPattern.ELEMWISE)
 
 
 # leaky_relu
-@reg.register_compute("leaky_relu")
-def compute_leaky_relu(attrs, inputs, _):
-    """Compute definition of relu"""
-    return topi.nn.leaky_relu(inputs[0], attrs.get_float("alpha"))
-
 reg.register_schedule("leaky_relu", _fschedule_broadcast)
 reg.register_pattern("leaky_relu", OpPattern.ELEMWISE)
 
 
 # flatten
-@reg.register_compute("flatten")
-def compute_flatten(attrs, inputs, _):
-    """Compute definition of flatten"""
-    return topi.nn.flatten(inputs[0])
-
 reg.register_schedule("flatten", _fschedule_broadcast)
 reg.register_pattern("flatten", OpPattern.INJECTIVE)
 
 
 # pad
-@reg.register_compute("pad")
-def compute_pad(attrs, inputs, _):
-    """Compute definition of pad"""
-    pad_width = attrs.get_int_pair_tuple('pad_width')
-    assert len(pad_width) == len(inputs[0].shape) and \
-        len(pad_width[0]) == 2, "illegal pad_width"
-    pad_before = [x[0] for x in pad_width]
-    pad_after = [x[1] for x in pad_width]
-    pad_value = attrs.get_int('pad_value')
-    return topi.nn.pad(inputs[0], pad_before, pad_after, pad_value)
-
 reg.register_schedule("pad", _fschedule_broadcast)
 reg.register_pattern("pad", OpPattern.INJECTIVE)
 
 
 # softmax
-@reg.register_compute("softmax")
-def compute_softmax(attrs, inputs, _):
-    """Compute definition of softmax"""
-    axis = attrs.get_int("axis")
-    assert axis == -1, "only support axis == -1 for now"
-    return topi.nn.softmax(inputs[0])
-
 @reg.register_schedule("softmax")
 def schedule_softmax(_, outs, target):
     """Schedule definition of softmax"""
@@ -73,13 +40,6 @@ reg.register_pattern("softmax", OpPattern.OPAQUE)
 
 
 # log softmax
-@reg.register_compute("log_softmax")
-def compute_log_softmax(attrs, inputs, _):
-    """Compute definition of softmax"""
-    axis = attrs.get_int("axis")
-    assert axis == -1, "only support axis == -1 for now"
-    return topi.nn.log_softmax(inputs[0])
-
 @reg.register_schedule("log_softmax")
 def schedule_log_softmax(_, outs, target):
     """Schedule definition of softmax"""
@@ -91,13 +51,6 @@ reg.register_pattern("log_softmax", OpPattern.OPAQUE)
 
 
 # dense
-@reg.register_compute("dense")
-def compute_dense(attrs, inputs, _):
-    """Compute definition of dense"""
-    if attrs.get_bool("use_bias"):
-        return topi.nn.dense(inputs[0], inputs[1], bias=inputs[2])
-    return topi.nn.dense(inputs[0], inputs[1])
-
 @reg.register_schedule("dense")
 def schedule_dense(_, outs, target):
     """Schedule definition of dense"""
@@ -175,18 +128,6 @@ reg.register_pattern("conv2d_transpose", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
 # max_pool2d
-@reg.register_compute("max_pool2d")
-def compute_max_pool2d(attrs, inputs, _):
-    """Compute definition of max_pool2d"""
-    pool_size = attrs.get_int_tuple("pool_size")
-    strides = attrs.get_int_tuple("strides")
-    padding = attrs.get_int_tuple("padding")
-    layout = attrs["layout"]
-    ceil_mode = attrs.get_bool("ceil_mode")
-    assert layout == "NCHW", "only support nchw for now"
-    return topi.nn.pool(inputs[0], pool_size, strides, padding,
-                        pool_type='max', ceil_mode=ceil_mode)
-
 @reg.register_schedule("max_pool2d")
 def schedule_max_pool2d(_, outs, target):
     """Schedule definition of max_pool2d"""
@@ -197,18 +138,6 @@ reg.register_pattern("max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
 # avg_pool2d
-@reg.register_compute("avg_pool2d")
-def compute_avg_pool2d(attrs, inputs, _):
-    """Compute definition of avg_pool2d"""
-    pool_size = attrs.get_int_tuple("pool_size")
-    strides = attrs.get_int_tuple("strides")
-    padding = attrs.get_int_tuple("padding")
-    layout = attrs["layout"]
-    ceil_mode = attrs.get_bool("ceil_mode")
-    assert layout == "NCHW", "only support nchw for now"
-    return topi.nn.pool(inputs[0], pool_size, strides, padding,
-                        pool_type='avg', ceil_mode=ceil_mode)
-
 @reg.register_schedule("avg_pool2d")
 def schedule_avg_pool2d(_, outs, target):
     """Schedule definition of avg_pool2d"""
@@ -219,13 +148,6 @@ reg.register_pattern("avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
 # global_max_pool2d
-@reg.register_compute("global_max_pool2d")
-def compute_global_max_pool2d(attrs, inputs, _):
-    """Compute definition of global_max_pool2d"""
-    layout = attrs["layout"]
-    assert layout == "NCHW", "only support nchw for now"
-    return topi.nn.global_pool(inputs[0], pool_type='max')
-
 @reg.register_schedule("global_max_pool2d")
 def schedule_global_max_pool2d(_, outs, target):
     """Schedule definition of global_max_pool2d"""
@@ -236,13 +158,6 @@ reg.register_pattern("global_max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
 # global_avg_pool2d
-@reg.register_compute("global_avg_pool2d")
-def compute_global_avg_pool2d(attrs, inputs, _):
-    """Compute definition of global_avg_pool2d"""
-    layout = attrs["layout"]
-    assert layout == "NCHW", "only support nchw for now"
-    return topi.nn.global_pool(inputs[0], pool_type='avg')
-
 @reg.register_schedule("global_avg_pool2d")
 def schedule_global_avg_pool2d(_, outs, target):
     """Schedule definition of global_avg_pool2d"""
diff --git a/nnvm/python/nnvm/top/reduction.py b/nnvm/python/nnvm/top/reduction.py
index 7003d3a52..193ea6038 100644
--- a/nnvm/python/nnvm/top/reduction.py
+++ b/nnvm/python/nnvm/top/reduction.py
@@ -27,16 +27,13 @@ def _compute_reduce(f):
     return _compute
 
 # sum
-reg.register_compute("sum", _compute_reduce(topi.sum))
 reg.register_pattern("sum", OpPattern.COMM_REDUCE)
 reg.register_schedule("sum", _fschedule_reduce)
 
 # max
-reg.register_compute("max", _compute_reduce(topi.max))
 reg.register_pattern("max", OpPattern.COMM_REDUCE)
 reg.register_schedule("max", _fschedule_reduce)
 
 # min
-reg.register_compute("min", _compute_reduce(topi.min))
 reg.register_pattern("min", OpPattern.COMM_REDUCE)
 reg.register_schedule("min", _fschedule_reduce)
diff --git a/nnvm/python/nnvm/top/tensor.py b/nnvm/python/nnvm/top/tensor.py
index 7ca800e5b..565733775 100644
--- a/nnvm/python/nnvm/top/tensor.py
+++ b/nnvm/python/nnvm/top/tensor.py
@@ -43,132 +43,97 @@ _fschedule_broadcast = _fschedule_injective
 _fschedule_elemwise = _fschedule_injective
 
 # copy
-reg.register_compute("copy", _compute_unary(topi.identity))
 reg.register_pattern("copy", OpPattern.ELEMWISE)
 reg.register_schedule("copy", _fschedule_broadcast)
 
 # exp
-reg.register_compute("exp", _compute_unary(topi.exp))
 reg.register_pattern("exp", OpPattern.ELEMWISE)
 reg.register_schedule("exp", _fschedule_broadcast)
 
 # sqrt
-reg.register_compute("sqrt", _compute_unary(topi.sqrt))
 reg.register_pattern("sqrt", OpPattern.ELEMWISE)
 reg.register_schedule("sqrt", _fschedule_broadcast)
 
 # log
-reg.register_compute("log", _compute_unary(topi.log))
 reg.register_pattern("log", OpPattern.ELEMWISE)
 reg.register_schedule("log", _fschedule_broadcast)
 
 # tanh
-reg.register_compute("tanh", _compute_unary(topi.tanh))
 reg.register_pattern("tanh", OpPattern.ELEMWISE)
 reg.register_schedule("tanh", _fschedule_broadcast)
 
 # negative
-reg.register_compute("negative", _compute_unary(topi.negative))
 reg.register_pattern("negative", OpPattern.ELEMWISE)
 reg.register_schedule("negative", _fschedule_broadcast)
 
 # sigmoid
-reg.register_compute("sigmoid", _compute_unary(topi.sigmoid))
 reg.register_pattern("sigmoid", OpPattern.ELEMWISE)
 reg.register_schedule("sigmoid", _fschedule_broadcast)
 
 # add_scalar
-reg.register_compute("__add_scalar__",
-                     _compute_binary_scalar(lambda x, y: x + y))
 reg.register_pattern("__add_scalar__", OpPattern.ELEMWISE)
 reg.register_schedule("__add_scalar__", _fschedule_broadcast)
 
 # sub_calar
-reg.register_compute("__sub_scalar__",
-                     _compute_binary_scalar(lambda x, y: x - y))
 reg.register_pattern("__sub_scalar__", OpPattern.ELEMWISE)
 reg.register_schedule("__sub_scalar__", _fschedule_broadcast)
 
 # rsub_scalar
-reg.register_compute("__rsub_scalar__",
-                     _compute_binary_scalar(lambda x, y: y - x))
 reg.register_pattern("__rsub_scalar__", OpPattern.ELEMWISE)
 reg.register_schedule("__rsub_scalar__", _fschedule_broadcast)
 
 # mul_scalar
-reg.register_compute("__mul_scalar__",
-                     _compute_binary_scalar(lambda x, y: x * y))
 reg.register_pattern("__mul_scalar__", OpPattern.ELEMWISE)
 reg.register_schedule("__mul_scalar__", _fschedule_broadcast)
 
 # div_scalar
-reg.register_compute("__div_scalar__",
-                     _compute_binary_scalar(lambda x, y: x / y))
 reg.register_pattern("__div_scalar__", OpPattern.ELEMWISE)
 reg.register_schedule("__div_scalar__", _fschedule_broadcast)
 
 # rdiv_scalar
-reg.register_compute("__rdiv_scalar__",
-                     _compute_binary_scalar(lambda x, y: y / x))
 reg.register_pattern("__rdiv_scalar__", OpPattern.ELEMWISE)
 reg.register_schedule("__rdiv_scalar__", _fschedule_broadcast)
 
 # pow_scalar
-reg.register_compute("__pow_scalar__",
-                     _compute_binary_scalar(tvm.power))
 reg.register_pattern("__pow_scalar__", OpPattern.ELEMWISE)
 reg.register_schedule("__pow_scalar__", _fschedule_broadcast)
 
 # rpow_scalar
-reg.register_compute("__rpow_scalar__",
-                     _compute_binary_scalar(lambda x, y: tvm.power(y, x)))
 reg.register_pattern("__rpow_scalar__", OpPattern.ELEMWISE)
 reg.register_schedule("__rpow_scalar__", _fschedule_broadcast)
 
 # elemwise_add
-reg.register_compute("elemwise_add", _compute_binary(topi.broadcast_add))
 reg.register_pattern("elemwise_add", OpPattern.BROADCAST)
 reg.register_schedule("elemwise_add", _fschedule_broadcast)
 
 # elemwise_sub
-reg.register_compute("elemwise_sub", _compute_binary(topi.broadcast_sub))
 reg.register_pattern("elemwise_sub", OpPattern.BROADCAST)
 reg.register_schedule("elemwise_sub", _fschedule_broadcast)
 
 # elemwise_mul
-reg.register_compute("elemwise_mul", _compute_binary(topi.broadcast_mul))
 reg.register_pattern("elemwise_mul", OpPattern.BROADCAST)
 reg.register_schedule("elemwise_mul", _fschedule_broadcast)
 
 # elemwise_div
-reg.register_compute("elemwise_div", _compute_binary(topi.broadcast_div))
 reg.register_pattern("elemwise_div", OpPattern.BROADCAST)
 reg.register_schedule("elemwise_div", _fschedule_broadcast)
 
 # broadcast_add
-reg.register_compute("broadcast_add", _compute_binary(topi.broadcast_add))
 reg.register_pattern("broadcast_add", OpPattern.BROADCAST)
 reg.register_schedule("broadcast_add", _fschedule_broadcast)
 
 # broadcast_sub
-reg.register_compute("broadcast_sub", _compute_binary(topi.broadcast_sub))
 reg.register_pattern("broadcast_sub", OpPattern.BROADCAST)
 reg.register_schedule("broadcast_sub", _fschedule_broadcast)
 
 # broadcast_mul
-reg.register_compute("broadcast_mul", _compute_binary(topi.broadcast_mul))
 reg.register_pattern("broadcast_mul", OpPattern.BROADCAST)
 reg.register_schedule("broadcast_mul", _fschedule_broadcast)
 
 # broadcast_div
-reg.register_compute("broadcast_div", _compute_binary(topi.broadcast_div))
 reg.register_pattern("broadcast_div", OpPattern.BROADCAST)
 reg.register_schedule("broadcast_div", _fschedule_broadcast)
 
 # broadcast_to
-@reg.register_compute("broadcast_to")
-def compute_broadcast_to(attrs, inputs, out_info):
-    """Compute definition of softmax"""
-    return topi.broadcast_to(inputs[0], shape=out_info[0].shape)
 reg.register_pattern("broadcast_to", OpPattern.BROADCAST)
 reg.register_schedule("broadcast_to", _fschedule_broadcast)
diff --git a/nnvm/python/nnvm/top/transform.py b/nnvm/python/nnvm/top/transform.py
index 21c10f5cc..ec1596e95 100644
--- a/nnvm/python/nnvm/top/transform.py
+++ b/nnvm/python/nnvm/top/transform.py
@@ -2,71 +2,30 @@
 """Tensor transformation ops"""
 from __future__ import absolute_import
 
-import topi
 from .tensor import _fschedule_broadcast, _fschedule_injective
 from . import registry as reg
 from .registry import OpPattern
 
 # expand_dims
-@reg.register_compute("expand_dims")
-def compute_expand_dims(attrs, inputs, out_info):
-    """Compute definition of expand_dims"""
-    return topi.expand_dims(
-        inputs[0], attrs.get_int("axis"),
-        num_newaxis=attrs.get_int("num_newaxis"))
 reg.register_pattern("expand_dims", OpPattern.BROADCAST)
 reg.register_schedule("expand_dims", _fschedule_broadcast)
 
 # transpose
-@reg.register_compute("transpose")
-def compute_transpose(attrs, inputs, out_info):
-    """Compute definition of transpose"""
-    axes = attrs.get_int_tuple("axes")
-    axes = tuple(axes) if axes else None
-    return topi.transpose(inputs[0], axes)
 reg.register_pattern("transpose", OpPattern.INJECTIVE)
 reg.register_schedule("transpose", _fschedule_injective)
 
 # reshape
-@reg.register_compute("reshape")
-def compute_reshape(attrs, inputs, out_info):
-    """Compute definition of reshape"""
-    oshape = out_info[0].shape
-    return topi.reshape(inputs[0], oshape)
 reg.register_pattern("reshape", OpPattern.INJECTIVE)
 reg.register_schedule("reshape", _fschedule_injective)
 
-# reshape
-@reg.register_compute("squeeze")
-def compute_squeeze(attrs, inputs, out_info):
-    """Compute definition of reshape"""
-    axis = attrs.get_int_tuple("axis")
-    axis = tuple(axis) if axis else None
-    return topi.squeeze(inputs[0], axis)
+# squeeze
 reg.register_pattern("squeeze", OpPattern.INJECTIVE)
 reg.register_schedule("squeeze", _fschedule_injective)
 
 # concatenate
-@reg.register_compute("concatenate")
-def compute_concatenate(attrs, inputs, out_info):
-    """Compute definition of concatenate"""
-    axis = attrs.get_int("axis")
-    return topi.concatenate([x for x in inputs], axis=axis)
-
 reg.register_pattern("concatenate", OpPattern.INJECTIVE)
 reg.register_schedule("concatenate", _fschedule_injective)
 
 # split
-@reg.register_compute("split")
-def compute_split(attrs, inputs, out_info):
-    """Compute definition of split"""
-    x = attrs["indices_or_sections"]
-    if x.startswith("(") or x.startswith("["):
-        indices_or_sections = attrs.get_int_tuple("indices_or_sections")
-    else:
-        indices_or_sections = attrs.get_int("indices_or_sections")
-    return topi.split(inputs[0], indices_or_sections, axis=attrs.get_int("axis"))
-
-
 reg.register_pattern("split", OpPattern.INJECTIVE)
 reg.register_schedule("split", _fschedule_injective)
diff --git a/nnvm/src/compiler/compile_engine.cc b/nnvm/src/compiler/compile_engine.cc
index 5902abdc6..23c4d4c9b 100644
--- a/nnvm/src/compiler/compile_engine.cc
+++ b/nnvm/src/compiler/compile_engine.cc
@@ -344,7 +344,7 @@ TVM_REGISTER_GLOBAL("nnvm.compiler.CacheItem2ScheduleArgs")
     *rv = ret;
   });
 
-TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<GraphFuncNode>([](const GraphFuncNode *op, IRPrinter *p) {
     p->stream << "GraphFunc(name=" << op->func_name
               << ", addr=" << op << ")";
diff --git a/nnvm/src/compiler/graph_hash.cc b/nnvm/src/compiler/graph_hash.cc
index f493bc042..d881130f7 100644
--- a/nnvm/src/compiler/graph_hash.cc
+++ b/nnvm/src/compiler/graph_hash.cc
@@ -80,7 +80,7 @@ GraphKey GraphKeyNode::make(Graph graph,
   return GraphKey(n);
 }
 
-TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<GraphKeyNode>([](const GraphKeyNode *op, IRPrinter *p) {
     p->stream << "GraphKeyNode("<< op << ")";
 });
diff --git a/nnvm/src/top/nn/nn.cc b/nnvm/src/top/nn/nn.cc
index bf169acc7..88ed7fd58 100644
--- a/nnvm/src/top/nn/nn.cc
+++ b/nnvm/src/top/nn/nn.cc
@@ -3,17 +3,27 @@
  * \file nn.cc
  * \brief Property def of nn operators.
  */
+#include <tvm/expr.h>
+#include <tvm/packed_func_ext.h>
 #include <nnvm/op.h>
 #include <nnvm/node.h>
 #include <nnvm/op_attr_types.h>
+#include <nnvm/compiler/op_attr_types.h>
 #include <nnvm/top/nn.h>
 #include "./nn_common.h"
 #include "../op_common.h"
 #include "../elemwise_op_common.h"
+#include "topi/nn/dense.h"
+#include "topi/nn.h"
+#include "topi/nn/softmax.h"
 
 namespace nnvm {
 namespace top {
 
+using tvm::Tensor;
+using tvm::Array;
+using nnvm::compiler::FTVMCompute;
+
 // dense
 DMLC_REGISTER_PARAMETER(DenseParam);
 
@@ -72,6 +82,21 @@ If ``use_bias`` is set to be false, then the ``bias`` term is ignored.
 .set_attr<FListInputNames>("FListInputNames", UseBiasListInputNames<DenseParam>)
 .set_attr<FInferShape>("FInferShape", DenseInferShape)
 .set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    Tensor bias_val;
+    Tensor* bias;
+    const DenseParam& param = nnvm::get<DenseParam>(attrs.parsed);
+    if (param.use_bias) {
+      bias_val = inputs[2];
+      bias = &bias_val;
+    } else {
+      bias = nullptr;
+    }
+    return Array<Tensor>{ topi::nn::dense(inputs[0], inputs[1], bias) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds) {
@@ -110,6 +135,12 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(relu)
    max(input, 0)
 
 )code" NNVM_ADD_FILELINE)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    return Array<Tensor>{ topi::relu(inputs[0], 0.0f) };
+  })
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds) {
@@ -258,6 +289,14 @@ NNVM_REGISTER_OP(softmax)
 .set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>)
 .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
 .set_support_level(1)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
+    CHECK_EQ(param.axis, -1) << "Currently only axis=-1 is supported";
+    return Array<Tensor>{ topi::nn::softmax(inputs[0]) };
+  })
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds) {
@@ -306,6 +345,14 @@ NNVM_REGISTER_OP(log_softmax)
 .set_num_outputs(1)
 .set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>)
 .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
+    CHECK_EQ(param.axis, -1) << "Currently only axis=-1 is supported";
+    return Array<Tensor>{ topi::nn::log_softmax(inputs[0]) };
+  })
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds) {
@@ -357,6 +404,13 @@ NNVM_REGISTER_OP(leaky_relu)
 .set_num_outputs(1)
 .set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>)
 .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const LeakyReLUParam& param = nnvm::get<LeakyReLUParam>(attrs.parsed);
+    return Array<Tensor>{ topi::leaky_relu<float>(inputs[0], 0.0, param.alpha) };
+  })
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds) {
@@ -413,6 +467,25 @@ NNVM_REGISTER_OP(pad)
 .set_num_inputs(1)
 .set_attr<FInferShape>("FInferShape", PadInferShape)
 .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const PadParam& param = nnvm::get<PadParam>(attrs.parsed);
+    auto pad_width = param.pad_width;
+    CHECK(pad_width.ndim() == inputs[0]->shape.size() &&
+      pad_width[0].ndim() == 2)
+      << "Illegal pad_width";
+    Array<tvm::Expr> pad_before;
+    for (size_t i = 0; i < pad_width.ndim(); ++i) {
+      pad_before.push_back(tvm::make_const(tvm::Int(32), pad_width[i][0]));
+    }
+    Array<tvm::Expr> pad_after;
+    for (size_t i = 0; i < pad_width.ndim(); ++i) {
+      pad_after.push_back(tvm::make_const(tvm::Int(32), pad_width[i][1]));
+    }
+    return Array<Tensor>{ topi::pad(inputs[0], pad_before, pad_after, param.pad_value) };
+})
 .set_support_level(1);
 
 }  // namespace top
diff --git a/nnvm/src/top/nn/pooling.cc b/nnvm/src/top/nn/pooling.cc
index 0f0eb6817..ae8d872de 100644
--- a/nnvm/src/top/nn/pooling.cc
+++ b/nnvm/src/top/nn/pooling.cc
@@ -6,13 +6,18 @@
 #include <nnvm/op.h>
 #include <nnvm/node.h>
 #include <nnvm/op_attr_types.h>
+#include <nnvm/compiler/op_attr_types.h>
+#include <nnvm/compiler/util.h>
 #include <nnvm/top/nn.h>
 #include "./nn_common.h"
 #include "../op_common.h"
 #include "../elemwise_op_common.h"
+#include "topi/nn/pooling.h"
 
 namespace nnvm {
 namespace top {
+using namespace tvm;
+using namespace nnvm::compiler;
 
 DMLC_REGISTER_PARAMETER(Pool2DParam);
 
@@ -77,6 +82,20 @@ NNVM_REGISTER_OP(max_pool2d)
 .set_num_inputs(1)
 .set_attr<FInferShape>("FInferShape", Pool2DInferShape)
 .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const Pool2DParam& param = nnvm::get<Pool2DParam>(attrs.parsed);
+    auto pool_size = ShapeToArray(param.pool_size);
+    auto strides = ShapeToArray(param.strides);
+    auto padding = ShapeToArray(param.padding);
+    auto ceil_mode = param.ceil_mode;
+    CHECK_EQ(param.layout, kNCHW)
+      << "max_pool2d currently only supports NCHW layout";
+    return Array<Tensor>{
+      topi::nn::pool(inputs[0], pool_size, strides, padding, topi::nn::kMaxPool, ceil_mode) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds) {
@@ -124,6 +143,20 @@ NNVM_REGISTER_OP(avg_pool2d)
 .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<Pool2DParam>)
 .set_attr<FInferShape>("FInferShape", Pool2DInferShape)
 .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const Pool2DParam& param = nnvm::get<Pool2DParam>(attrs.parsed);
+    auto pool_size = ShapeToArray(param.pool_size);
+    auto strides = ShapeToArray(param.strides);
+    auto padding = ShapeToArray(param.padding);
+    auto ceil_mode = param.ceil_mode;
+    CHECK_EQ(param.layout, kNCHW)
+      << "avg_pool2d currently only supports NCHW layout";
+    return Array<Tensor>{
+      topi::nn::pool(inputs[0], pool_size, strides, padding, topi::nn::kAvgPool, ceil_mode) };
+})
 .set_num_outputs(1)
 .set_num_inputs(1)
 .set_support_level(2);
@@ -162,6 +195,16 @@ NNVM_REGISTER_OP(global_max_pool2d)
 .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<GlobalPool2DParam>)
 .set_attr<FInferShape>("FInferShape", GlobalPool2DInferShape)
 .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const GlobalPool2DParam& param = nnvm::get<GlobalPool2DParam>(attrs.parsed);
+    CHECK_EQ(param.layout, kNCHW)
+      << "global_max_pool2d currently only supports NCHW layout";
+    return Array<Tensor>{
+      topi::nn::global_pool(inputs[0], topi::nn::kMaxPool) };
+})
 .set_num_outputs(1)
 .set_num_inputs(1)
 .set_support_level(2);
@@ -182,6 +225,16 @@ NNVM_REGISTER_OP(global_avg_pool2d)
 .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<GlobalPool2DParam>)
 .set_attr<FInferShape>("FInferShape", GlobalPool2DInferShape)
 .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const GlobalPool2DParam& param = nnvm::get<GlobalPool2DParam>(attrs.parsed);
+    CHECK_EQ(param.layout, kNCHW)
+      << "global_avg_pool2d currently only supports NCHW layout";
+    return Array<Tensor>{
+      topi::nn::global_pool(inputs[0], topi::nn::kAvgPool) };
+})
 .set_num_outputs(1)
 .set_num_inputs(1)
 .set_support_level(2);
diff --git a/nnvm/src/top/tensor/broadcast.cc b/nnvm/src/top/tensor/broadcast.cc
index 945195fd0..773281450 100644
--- a/nnvm/src/top/tensor/broadcast.cc
+++ b/nnvm/src/top/tensor/broadcast.cc
@@ -3,15 +3,22 @@
  * \file broadcast.cc
  * \brief broadcast operator.
  */
+#include <tvm/expr.h>
+#include <tvm/packed_func_ext.h>
 #include <nnvm/op.h>
 #include <nnvm/node.h>
 #include <nnvm/op_attr_types.h>
+#include <nnvm/compiler/op_attr_types.h>
+#include <nnvm/compiler/util.h>
 #include <nnvm/top/tensor.h>
 #include "../op_common.h"
 #include "../elemwise_op_common.h"
+#include "topi/broadcast.h"
 
 namespace nnvm {
 namespace top {
+using namespace tvm;
+using namespace nnvm::compiler;
 
 // broadcast_to
 DMLC_REGISTER_PARAMETER(BroadcastToParam);
@@ -67,6 +74,14 @@ So with `shape=(2,0)`, we will obtain the same result as in the above example.
 .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<BroadcastToParam>)
 .set_attr<FInferShape>("FInferShape", BroadcastToInferShape)
 .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+    const Array<Tensor>& inputs,
+    const Array<Tensor>& out_info) {
+      const BroadcastToParam& param = nnvm::get<BroadcastToParam>(attrs.parsed);
+      auto shape = ShapeToArray(param.shape);
+      return Array<Tensor>{ topi::broadcast_to(inputs[0], shape) };
+  })
 .set_num_inputs(1)
 .set_num_outputs(1)
 .set_support_level(4);
@@ -122,6 +137,13 @@ inline bool BinaryBroadcastShape(const nnvm::NodeAttrs& attrs,
     [](const NodeAttrs& attrs) {                                    \
       return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}};     \
     })                                                              \
+  .set_attr<FTVMCompute>(                                           \
+    "FTVMCompute", [](const NodeAttrs& attrs,                       \
+      const Array<Tensor>& inputs,                                  \
+      const Array<Tensor>& out_info) {                              \
+        return Array<Tensor>{                                       \
+          topi::name(inputs[0], inputs[1]) };                       \
+    })                                                              \
   .add_argument("lhs", "Tensor", "first input")                     \
   .add_argument("rhs", "Tensor", "second input")
 
diff --git a/nnvm/src/top/tensor/elemwise.cc b/nnvm/src/top/tensor/elemwise.cc
index 73a958322..ac66e0c3f 100644
--- a/nnvm/src/top/tensor/elemwise.cc
+++ b/nnvm/src/top/tensor/elemwise.cc
@@ -6,13 +6,19 @@
 #include <nnvm/op.h>
 #include <nnvm/node.h>
 #include <nnvm/op_attr_types.h>
+#include <nnvm/compiler/op_attr_types.h>
 #include <nnvm/top/tensor.h>
 #include <cmath>
 #include "../op_common.h"
 #include "../elemwise_op_common.h"
+#include "topi/broadcast.h"
+#include "topi/elemwise.h"
+#include "topi/tags.h"
 
 namespace nnvm {
 namespace top {
+using namespace tvm;
+using namespace nnvm::compiler;
 // undefined op
 NNVM_REGISTER_ELEMWISE_UNARY_OP(__undef__)
 .describe(R"code(undefined op.
@@ -32,6 +38,12 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(sigmoid)
 
 )code" NNVM_ADD_FILELINE)
 .set_support_level(1)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+      return Array<Tensor>{ topi::sigmoid(inputs[0]) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds) {
@@ -56,6 +68,12 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(tanh)
 
 )code" NNVM_ADD_FILELINE)
 .set_support_level(1)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+      return Array<Tensor>{ topi::tanh(inputs[0]) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds) {
@@ -80,6 +98,12 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(exp)
 
 )code" NNVM_ADD_FILELINE)
 .set_support_level(1)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+      return Array<Tensor>{ topi::exp(inputs[0]) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds) {
@@ -100,6 +124,12 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(log)
 
 )code" NNVM_ADD_FILELINE)
 .set_support_level(1)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+      return Array<Tensor>{ topi::log(inputs[0]) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds) {
@@ -120,6 +150,12 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(sqrt)
 
 )code" NNVM_ADD_FILELINE)
 .set_support_level(1)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+      return Array<Tensor>{ topi::sqrt(inputs[0]) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds) {
@@ -140,6 +176,12 @@ NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_add)
 
 )code")
 .set_support_level(1)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+      return Array<Tensor>{ topi::broadcast_add(inputs[0], inputs[1]) };
+  })
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds){
@@ -154,6 +196,12 @@ NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_sub)
 
 )code"  NNVM_ADD_FILELINE)
 .set_support_level(1)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+      return Array<Tensor>{ topi::broadcast_sub(inputs[0], inputs[1]) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds){
@@ -171,6 +219,12 @@ NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_mul)
 
 )code"  NNVM_ADD_FILELINE)
 .set_support_level(1)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+      return Array<Tensor>{ topi::broadcast_mul(inputs[0], inputs[1]) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds){
@@ -190,6 +244,12 @@ NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_div)
 
 )code"  NNVM_ADD_FILELINE)
 .set_support_level(1)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+      return Array<Tensor>{ topi::broadcast_div(inputs[0], inputs[1]) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds){
@@ -216,6 +276,12 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(negative)
 
 )code"  NNVM_ADD_FILELINE)
 .set_support_level(3)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+      return Array<Tensor>{ topi::negative(inputs[0]) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds){
@@ -232,6 +298,12 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(copy)
 
 )code"  NNVM_ADD_FILELINE)
 .set_support_level(3)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+      return Array<Tensor>{ topi::identity(inputs[0]) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds){
@@ -315,12 +387,29 @@ DMLC_REGISTER_PARAMETER(ScalarParam);
   .set_attr_parser(ParamParser<ScalarParam>)                            \
   .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ScalarParam>)
 
+inline Tensor binary_scalar_op(const NodeAttrs& attrs,
+                               const Tensor& x,
+                               std::function<Expr(Expr, Expr)> f) {
+  const ScalarParam& param = nnvm::get<ScalarParam>(attrs.parsed);
+  auto scalar_val = static_cast<float>(param.scalar);
+  return compute(x->shape, [&](const Array<Var>& i) {
+    auto scalar_const = make_const(x->dtype, scalar_val);
+    return f(x(i), scalar_const);
+    }, "tensor", topi::kElementWise);
+}
 
 NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__add_scalar__)
 .describe(R"code(Tensor add scalar
 
 )code"  NNVM_ADD_FILELINE)
 .set_support_level(3)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    return Array<Tensor>{ binary_scalar_op(attrs, inputs[0],
+      [](Expr x, Expr y) { return x + y; }) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds){
@@ -332,6 +421,13 @@ NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__sub_scalar__)
 
 )code"  NNVM_ADD_FILELINE)
 .set_support_level(3)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    return Array<Tensor>{ binary_scalar_op(attrs, inputs[0],
+      [](Expr x, Expr y) { return x - y; }) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds){
@@ -343,6 +439,13 @@ NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__rsub_scalar__)
 
 )code"  NNVM_ADD_FILELINE)
 .set_support_level(3)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    return Array<Tensor>{ binary_scalar_op(attrs, inputs[0],
+      [](Expr x, Expr y) { return y - x; }) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds){
@@ -356,6 +459,13 @@ NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__mul_scalar__)
 
 )code"  NNVM_ADD_FILELINE)
 .set_support_level(3)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    return Array<Tensor>{ binary_scalar_op(attrs, inputs[0],
+      [](Expr x, Expr y) { return x * y; }) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds){
@@ -372,6 +482,13 @@ NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__div_scalar__)
 
 )code"  NNVM_ADD_FILELINE)
 .set_support_level(3)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    return Array<Tensor>{ binary_scalar_op(attrs, inputs[0],
+      [](Expr x, Expr y) { return x / y; }) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds){
@@ -388,6 +505,13 @@ NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__rdiv_scalar__)
 
 )code"  NNVM_ADD_FILELINE)
 .set_support_level(3)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    return Array<Tensor>{ binary_scalar_op(attrs, inputs[0],
+      [](Expr x, Expr y) { return y / x; }) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds){
@@ -411,6 +535,13 @@ NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__pow_scalar__)
 
 )code"  NNVM_ADD_FILELINE)
 .set_support_level(3)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    return Array<Tensor>{ binary_scalar_op(attrs, inputs[0],
+      [](Expr x, Expr y) { return tvm::pow(x, y); }) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds){
@@ -434,6 +565,13 @@ NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__rpow_scalar__)
 
 )code"  NNVM_ADD_FILELINE)
 .set_support_level(3)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    return Array<Tensor>{ binary_scalar_op(attrs, inputs[0],
+      [](Expr x, Expr y) { return tvm::pow(y, x); }) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds){
diff --git a/nnvm/src/top/tensor/reduce.cc b/nnvm/src/top/tensor/reduce.cc
index 429800fe5..8eac2449b 100644
--- a/nnvm/src/top/tensor/reduce.cc
+++ b/nnvm/src/top/tensor/reduce.cc
@@ -6,12 +6,17 @@
 #include <nnvm/op.h>
 #include <nnvm/node.h>
 #include <nnvm/op_attr_types.h>
+#include <nnvm/compiler/op_attr_types.h>
+#include <nnvm/compiler/util.h>
 #include <nnvm/top/tensor.h>
 #include "../op_common.h"
 #include "../elemwise_op_common.h"
+#include "topi/reduction.h"
 
 namespace nnvm {
 namespace top {
+using namespace tvm;
+using namespace nnvm::compiler;
 
 // reduce
 DMLC_REGISTER_PARAMETER(ReduceParam);
@@ -127,6 +132,15 @@ Example::
   [ 12.  19.  27.]
 
 )code" NNVM_ADD_FILELINE)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
+    auto axis = ShapeToArray(param.axis);
+    return Array<Tensor>{
+      topi::sum(inputs[0], axis, param.keepdims) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds){
@@ -145,6 +159,15 @@ NNVM_REGISTER_REDUCE_OP(max)
 .describe(R"code(Computes the max of array elements over given axes.
 
 )code" NNVM_ADD_FILELINE)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
+    auto axis = ShapeToArray(param.axis);
+    return Array<Tensor>{
+      topi::max(inputs[0], axis, param.keepdims) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds){
@@ -168,6 +191,15 @@ NNVM_REGISTER_REDUCE_OP(min)
 .describe(R"code(Computes the min of array elements over given axes.
 
 )code" NNVM_ADD_FILELINE)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
+    auto axis = ShapeToArray(param.axis);
+    return Array<Tensor>{
+      topi::min(inputs[0], axis, param.keepdims) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds){
diff --git a/nnvm/src/top/tensor/transform.cc b/nnvm/src/top/tensor/transform.cc
index 90457389c..0bf1a91ec 100644
--- a/nnvm/src/top/tensor/transform.cc
+++ b/nnvm/src/top/tensor/transform.cc
@@ -6,13 +6,19 @@
 #include <nnvm/op.h>
 #include <nnvm/node.h>
 #include <nnvm/op_attr_types.h>
+#include <nnvm/compiler/op_attr_types.h>
+#include <nnvm/compiler/util.h>
 #include <nnvm/top/tensor.h>
 #include <cctype>
 #include "../op_common.h"
 #include "../elemwise_op_common.h"
+#include "topi/nn/flatten.h"
+#include "topi/transform.h"
 
 namespace nnvm {
 namespace top {
+using namespace tvm;
+using namespace nnvm::compiler;
 
 // flatten
 inline bool FlattenInferShape(const NodeAttrs& attrs,
@@ -58,6 +64,12 @@ Example::
 .set_attr<FInferShape>("FInferShape", FlattenInferShape)
 .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
 .add_argument("data", "Tensor", "Input data.")
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    return Array<Tensor>{ topi::nn::flatten(inputs[0]) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds){
@@ -144,6 +156,13 @@ Example::
 .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ConcatenateParam>)
 .set_attr<FInferShape>("FInferShape", ConcatenateInferShape)
 .set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const ConcatenateParam& param = nnvm::get<ConcatenateParam>(attrs.parsed);
+    return Array<Tensor>{ topi::concatenate(inputs, param.axis) };
+})
 .set_num_outputs(1)
 .set_num_inputs(kVarg)
 .set_support_level(1);
@@ -190,6 +209,13 @@ will return a new array with shape ``(2,5,3,4)``.
 .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
 .set_num_inputs(1)
 .set_num_outputs(1)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const ExpandDimsParam& param = nnvm::get<ExpandDimsParam>(attrs.parsed);
+    return Array<Tensor>{ topi::expand_dims(inputs[0], param.axis, param.num_newaxis) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds){
@@ -326,6 +352,22 @@ along which to split the array.
 .set_attr<FInferType>("FInferType", ElemwiseType<1, -1>)
 .set_num_inputs(1)
 .set_num_outputs(SplitNumOutputs)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const SplitParam& param = nnvm::get<SplitParam>(attrs.parsed);
+    if (param.equal_split) {
+      return Array<Tensor>{
+        topi::split_sections(inputs[0], param.indices_or_sections[0], param.axis) };
+    } else {
+      Array<Expr> indices;
+      for (auto i : param.indices_or_sections) {
+        indices.push_back(tvm::make_const(tvm::Int(32), i));
+      }
+      return Array<Tensor>{ topi::split(inputs[0], indices, param.axis) };
+    }
+})
 .set_support_level(1);
 
 // cast
@@ -504,6 +546,12 @@ The significance of each is explained below:
 .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
 .set_num_inputs(1)
 .set_num_outputs(1)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    return Array<Tensor>{ topi::reshape(inputs[0], out_info[0]->shape) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds) {
@@ -620,6 +668,14 @@ Examples::
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
 .set_num_inputs(1)
 .set_num_outputs(1)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const SqueezeParam& param = nnvm::get<SqueezeParam>(attrs.parsed);
+    auto axis = ShapeToArray(param.axis);
+    return Array<Tensor>{ topi::squeeze(inputs[0], axis) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds) {
@@ -695,6 +751,14 @@ Examples::
 .set_num_inputs(1)
 .set_num_outputs(1)
 .set_support_level(4)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const TransposeParam& param = nnvm::get<TransposeParam>(attrs.parsed);
+    auto axes = ShapeToArray(param.axes);
+    return Array<Tensor>{ topi::transpose(inputs[0], axes) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds) {
-- 
GitLab