diff --git a/nnvm/docs/top.rst b/nnvm/docs/top.rst
index c34ae5b54b07cef30ec4f85dfa381b1d532e03c2..11d0f6093f83411f7da4dc3273163e5b3fcdb909 100644
--- a/nnvm/docs/top.rst
+++ b/nnvm/docs/top.rst
@@ -28,7 +28,6 @@ This level enables fully connected multi-layer perceptron.
    :nosignatures:
 
    nnvm.symbol.dense
-   nnvm.symbol.matmul
    nnvm.symbol.relu
    nnvm.symbol.tanh
    nnvm.symbol.sigmoid
@@ -40,12 +39,6 @@ This level enables fully connected multi-layer perceptron.
    nnvm.symbol.elemwise_mul
    nnvm.symbol.elemwise_div
    nnvm.symbol.elemwise_sum
-   nnvm.symbol.full
-   nnvm.symbol.full_like
-   nnvm.symbol.ones
-   nnvm.symbol.ones_like
-   nnvm.symbol.zeros
-   nnvm.symbol.zeros_like
    nnvm.symbol.flatten
    nnvm.symbol.concatenate
    nnvm.symbol.expand_dims
@@ -57,7 +50,6 @@ This level enables fully connected multi-layer perceptron.
    nnvm.symbol.log_softmax
    nnvm.symbol.pad
    nnvm.symbol.block_grad
-   nnvm.symbol.indicator
 
 
 **Level 2: Convolutions**
@@ -81,8 +73,6 @@ This level enables typical convnet models.
    :nosignatures:
 
    nnvm.symbol.reshape
-   nnvm.symbol.reshape_like
-   nnvm.symbol.expand_like
    nnvm.symbol.copy
    nnvm.symbol.negative
    nnvm.symbol.leaky_relu
@@ -109,11 +99,21 @@ This level enables typical convnet models.
    nnvm.symbol.broadcast_sub
    nnvm.symbol.broadcast_mul
    nnvm.symbol.broadcast_div
+   nnvm.symbol.clip
+   nnvm.symbol.greater
+   nnvm.symbol.less
+   nnvm.symbol.expand_like
+   nnvm.symbol.reshape_like
+   nnvm.symbol.full
+   nnvm.symbol.full_like
+   nnvm.symbol.ones
+   nnvm.symbol.ones_like
+   nnvm.symbol.zeros
+   nnvm.symbol.zeros_like
 
 Detailed Definitions
 --------------------
 .. autofunction:: nnvm.symbol.dense
-.. autofunction:: nnvm.symbol.matmul
 .. autofunction:: nnvm.symbol.relu
 .. autofunction:: nnvm.symbol.tanh
 .. autofunction:: nnvm.symbol.sigmoid
@@ -125,12 +125,6 @@ Detailed Definitions
 .. autofunction:: nnvm.symbol.elemwise_mul
 .. autofunction:: nnvm.symbol.elemwise_div
 .. autofunction:: nnvm.symbol.elemwise_sum
-.. autofunction:: nnvm.symbol.full
-.. autofunction:: nnvm.symbol.full_like
-.. autofunction:: nnvm.symbol.ones
-.. autofunction:: nnvm.symbol.ones_like
-.. autofunction:: nnvm.symbol.zeros
-.. autofunction:: nnvm.symbol.zeros_like
 .. autofunction:: nnvm.symbol.flatten
 .. autofunction:: nnvm.symbol.concatenate
 .. autofunction:: nnvm.symbol.expand_dims
@@ -142,7 +136,6 @@ Detailed Definitions
 .. autofunction:: nnvm.symbol.log_softmax
 .. autofunction:: nnvm.symbol.pad
 .. autofunction:: nnvm.symbol.block_grad
-.. autofunction:: nnvm.symbol.indicator
 
 .. autofunction:: nnvm.symbol.conv2d
 .. autofunction:: nnvm.symbol.conv2d_transpose
@@ -152,8 +145,6 @@ Detailed Definitions
 .. autofunction:: nnvm.symbol.global_avg_pool2d
 
 .. autofunction:: nnvm.symbol.reshape
-.. autofunction:: nnvm.symbol.reshape_like
-.. autofunction:: nnvm.symbol.expand_like
 .. autofunction:: nnvm.symbol.copy
 .. autofunction:: nnvm.symbol.negative
 .. autofunction:: nnvm.symbol.leaky_relu
@@ -175,3 +166,14 @@ Detailed Definitions
 .. autofunction:: nnvm.symbol.broadcast_sub
 .. autofunction:: nnvm.symbol.broadcast_mul
 .. autofunction:: nnvm.symbol.broadcast_div
+.. autofunction:: nnvm.symbol.clip
+.. autofunction:: nnvm.symbol.greater
+.. autofunction:: nnvm.symbol.less
+.. autofunction:: nnvm.symbol.expand_like
+.. autofunction:: nnvm.symbol.reshape_like
+.. autofunction:: nnvm.symbol.full
+.. autofunction:: nnvm.symbol.full_like
+.. autofunction:: nnvm.symbol.ones
+.. autofunction:: nnvm.symbol.ones_like
+.. autofunction:: nnvm.symbol.zeros
+.. autofunction:: nnvm.symbol.zeros_like
diff --git a/nnvm/include/nnvm/top/tensor.h b/nnvm/include/nnvm/top/tensor.h
index 7f50e49b996cad63706792af236ae7dcca783595..00bad8245713c44367ddf05bcf4da26df92a45d3 100644
--- a/nnvm/include/nnvm/top/tensor.h
+++ b/nnvm/include/nnvm/top/tensor.h
@@ -241,6 +241,16 @@ struct MatMulParam : public dmlc::Parameter<MatMulParam> {
   }
 };
 
+struct ClipParam : public dmlc::Parameter<ClipParam> {
+  double a_min, a_max;
+  DMLC_DECLARE_PARAMETER(ClipParam) {
+    DMLC_DECLARE_FIELD(a_min)
+      .describe("Minimum value such that value smaller then this will be clipped.");
+    DMLC_DECLARE_FIELD(a_max)
+      .describe("Maximum value such that value larger then this will be clipped.");
+  }
+};
+
 }  // namespace top
 }  // namespace nnvm
 
diff --git a/nnvm/python/nnvm/_base.py b/nnvm/python/nnvm/_base.py
index d01bdcaeb5b1a05277b2ae67bc75b30a87514d75..29390a2201bf7b711afe4ea419bfef1f89f93eaf 100644
--- a/nnvm/python/nnvm/_base.py
+++ b/nnvm/python/nnvm/_base.py
@@ -54,6 +54,9 @@ OpHandle = ctypes.c_void_p
 SymbolHandle = ctypes.c_void_p
 GraphHandle = ctypes.c_void_p
 
+# Global dict of str to symbol to initialize variables
+_all_var_init = {}
+
 #----------------------------
 # helper function definition
 #----------------------------
diff --git a/nnvm/python/nnvm/compiler/build_module.py b/nnvm/python/nnvm/compiler/build_module.py
index bce142d21b630706be24cc97913e79c1157549ae..d97a8784dcc275129c4249a4c83dd1183c8be6da 100644
--- a/nnvm/python/nnvm/compiler/build_module.py
+++ b/nnvm/python/nnvm/compiler/build_module.py
@@ -4,9 +4,12 @@ from __future__ import absolute_import as _abs
 
 import logging
 import tvm
+
 from tvm.contrib import graph_runtime
 from . import graph_attr, graph_util
 from .. import graph as _graph
+from .. import symbol as sym
+from .._base import _all_var_init
 
 OPT_PASS_LEVEL = {
     "SimplifyInference": 0,
@@ -201,6 +204,9 @@ def build(graph, target=None, shape=None, dtype="float32", params=None, target_h
         By default, llvm is used if it is enabled,
         otherwise a stackvm intepreter is used.
 
+    initialize : bool, optional
+        Whether to initialize variables in global dict _all_var_init.
+
     Returns
     -------
     graph : Graph
@@ -230,6 +236,10 @@ def build(graph, target=None, shape=None, dtype="float32", params=None, target_h
     if not isinstance(dtype, str):
         idtype, _ = graph_util.infer_dtype(graph, **dtype)
         dtype.update(zip(graph.index.input_names, idtype))
+    # Initialize all variables specified in _all_var_init
+    init_var = {}
+    if _all_var_init:
+        init_var = initialize_variables(shape, dtype)
     # Apply optimization
     graph = optimize(graph, shape, dtype)
     # Precompute prune
@@ -250,6 +260,11 @@ def build(graph, target=None, shape=None, dtype="float32", params=None, target_h
     with target:
         graph = graph.apply("GraphFusePartition").apply("GraphFuseCompile")
     libmod = graph_attr._move_out_module(graph, "module")
+    # Write variable initial values into params
+    if init_var:
+        if params is None:
+            params = {}
+        params.update(init_var)
     return graph, libmod, params
 
 
@@ -329,3 +344,45 @@ def precompute_prune(graph, params):
     with tvm.build_config(auto_unroll_max_step=0):
         out_arrs = _run_graph(pre_graph, params)
     return graph, dict(zip(out_names, out_arrs))
+
+
+def initialize_variables(ishape, idtype):
+    """ Initialize variables stored in _all_var_init dictionary.
+
+    Parameters
+    ----------
+    ishape : dict of str to tuple of int
+        The input shape to the graph
+
+    idtype : str or dict of str to str
+        The input types to the graph
+
+    Returns
+    -------
+    init_var : dict of str to tvm.ndarray
+    """
+    symbol_init_dict = {}
+    const_init_dict = {}
+    init_var = {}
+    for key, value in _all_var_init.items():
+        if isinstance(value, sym.Symbol):
+            symbol_init_dict[key] = value
+        else:
+            const_init_dict[key] = tvm.nd.array(value)
+    # Make sure variables are initialized only once.
+    _all_var_init.clear()
+    if symbol_init_dict:
+        # Create dummy params to run initialization graph
+        params = {}
+        for name, shape in ishape.items():
+            dtype = idtype if isinstance(idtype, str) else idtype[name]
+            params[name] = tvm.nd.empty(shape, dtype, ctx=tvm.cpu())
+        init_group_sym = sym.Group(symbol_init_dict.values())
+        graph = _graph.create(init_group_sym)
+        with tvm.build_config(auto_unroll_max_step=0):
+            init_values = _run_graph(graph, params)
+        init_var.update(dict(zip(symbol_init_dict.keys(), init_values)))
+    init_var.update(const_init_dict)
+    for name, data in init_var.items():
+        ishape[name] = data.shape
+    return init_var
diff --git a/nnvm/python/nnvm/compiler/lr_scheduler.py b/nnvm/python/nnvm/compiler/lr_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..791925e74960f58469b4123e94a72729f3fb8a98
--- /dev/null
+++ b/nnvm/python/nnvm/compiler/lr_scheduler.py
@@ -0,0 +1,58 @@
+# pylint: disable=too-few-public-methods, no-member
+"""API for scheduling learning rate."""
+from .. import symbol as sym
+
+class LRScheduler(object):
+    """Base class of a learning rate scheduler.
+
+    A scheduler returns a new learning rate based on the number of updates that have
+    been performed.
+
+    Parameters
+    ----------
+    base_lr : float, optional
+        The initial learning rate.
+    """
+    def __init__(self, base_lr=0.01, name='LRScheduler'):
+        self.name = name
+        self.base_lr = base_lr
+
+    def __call__(self, num_update):
+        """Return a new learning rate based on number of updates.
+
+        Parameters
+        ----------
+        num_update: nnvm Symbol
+            the number of updates applied to weight.
+        """
+        raise NotImplementedError("__call__ method must be overridden.")
+
+class FactorScheduler(LRScheduler):
+    """Reduce the learning rate by a factor for every *n* steps.
+
+    It returns a new learning rate by::
+
+        base_lr * pow(factor, num_update/step)
+
+    Parameters
+    ----------
+    step : int
+        Changes the learning rate for every n updates.
+    factor : float, optional
+        The factor to change the learning rate.
+    stop_factor_lr : float, optional
+        Stop updating the learning rate if it is less than this value.
+    """
+    def __init__(self, step, factor=1, stop_factor_lr=1e-8, name='FactorScheduler', **kwargs):
+        super(FactorScheduler, self).__init__(name=name, **kwargs)
+        if step < 1:
+            raise ValueError("Schedule step must be greater or equal than 1 round")
+        if factor > 1.0:
+            raise ValueError("Factor must be no more than 1 to make lr reduce")
+        self.step = step
+        self.factor = factor
+        self.stop_factor_lr = stop_factor_lr
+
+    def __call__(self, num_update):
+        updated_lr = self.base_lr * self.factor ** (num_update / self.step)
+        return sym.clip(updated_lr, a_min=self.stop_factor_lr, a_max=self.base_lr)
diff --git a/nnvm/python/nnvm/compiler/optimizer.py b/nnvm/python/nnvm/compiler/optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..bcf7498528af203a89d56825e9441b9acd0b217e
--- /dev/null
+++ b/nnvm/python/nnvm/compiler/optimizer.py
@@ -0,0 +1,131 @@
+# pylint: disable=invalid-name, no-member, too-few-public-methods, too-many-arguments, too-many-locals, protected-access
+"""Optimizer API"""
+from . import graph_util
+from .. import symbol as sym
+
+class Optimizer(object):
+    """Base class inherited by all optimizers.
+
+    Parameters
+    ----------
+    learning_rate : float, optional
+        The initial learning rate.
+
+    lr_scheduler : LRScheduler, optional
+        The learning rate scheduler.
+
+    rescale_grad : float, optional
+        Multiply the gradient with `rescale_grad` before updating. Often
+        choose to be ``1.0/batch_size``.
+
+    clip_gradient : float, optional
+        Clip the gradient by projecting onto the box ``[-clip_gradient, clip_gradient]``.
+
+    wd : float, optional
+        The weight decay (or L2 regularization) coefficient. Modifies objective
+        by adding a penalty for having large weights.
+
+    name : string, optional
+        The name of optimizer.
+    """
+    def __init__(self, learning_rate=0.01, lr_scheduler=None,
+                 rescale_grad=1, clip_gradient=None, wd=0, name="Optimizer"):
+        self.name = name
+        self.lr = learning_rate
+        self.lr_scheduler = lr_scheduler
+        self.rescale_grad = rescale_grad
+        self.clip_gradient = clip_gradient
+        self.wd = wd
+        init_update_t = sym.Variable(name+'_t', init=sym.zeros(shape=(1,), dtype="int32"))
+        self.update_t = sym._assign(init_update_t, init_update_t + 1)
+
+    def minimize(self, obj, var=None):
+        """Minimize given obj symbol respect to var. If var is not set, all input
+        variables of obj will be used.
+
+        Parameters
+        ----------
+        obj : nnvm Symbol or list of nnvm Symbols
+            Symbols to be minimized.
+        var : nnvm Symbol or list of nnvm Symbols, optional
+            Symbols the gradient respect to.
+
+        Returns
+        -------
+        group_sym : nnvm Symbol
+            Group symbol represents update symbols.
+        """
+        raise NotImplementedError()
+
+    def _get_lr(self):
+        """Gets the learning rate with learning rate scheduler.
+
+        Returns
+        -------
+        lr : float
+            Learning rate.
+        """
+        if self.lr_scheduler is not None:
+            lr = self.lr_scheduler(self.update_t)
+        else:
+            lr = self.lr
+        return lr
+
+
+class SGD(Optimizer):
+    """The SGD optimizer
+    """
+    def __init__(self, name='SGD', **kwargs):
+        super(SGD, self).__init__(name=name, **kwargs)
+
+    def minimize(self, obj, var=None):
+        variables = var or obj.list_input_variables()
+        if not isinstance(variables, list):
+            variables = [variables]
+        grads = graph_util.gradients(obj, variables)
+        updates = []
+        lr_t = self._get_lr()
+        for v, g in zip(variables, grads):
+            g = self.rescale_grad * g
+            if self.clip_gradient is not None:
+                g = sym.clip(g, a_min=-1 * self.clip_gradient, a_max=self.clip_gradient)
+            updates.append(sym._assign(v, v - lr_t * (g + self.wd * v)))
+        return sym.Group(updates)
+
+
+class Adam(Optimizer):
+    """The Adam optimizer.
+
+    This class implements the optimizer described in *Adam: A Method for
+    Stochastic Optimization*, available at http://arxiv.org/abs/1412.6980.
+    """
+    def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999,
+                 epsilon=1e-8, name='Adam', **kwargs):
+        super(Adam, self).__init__(learning_rate=learning_rate, name=name, **kwargs)
+        self.beta1 = beta1
+        self.beta2 = beta2
+        self.epsilon = epsilon
+        self.m = []
+        self.v = []
+
+    def minimize(self, obj, var=None):
+        variables = var or obj.list_input_variables()
+        if not isinstance(variables, list):
+            variables = [variables]
+        grads = graph_util.gradients(obj, variables)
+        updates = []
+        for i, v in enumerate(variables):
+            self.m.append(sym.Variable(self.name + '_m' + str(i), init=sym.zeros_like(v)))
+            self.v.append(sym.Variable(self.name + '_v' + str(i), init=sym.zeros_like(v)))
+        rate = sym.sqrt(1 - self.beta2 ** self.update_t) / (1 -  self.beta1 ** self.update_t)
+        lr_t = self._get_lr() * rate
+        for variable, g, m, v in zip(variables, grads, self.m, self.v):
+            g = self.rescale_grad * g
+            if self.clip_gradient is not None:
+                g = sym.clip(g, a_min=-1 * self.clip_gradient, a_max=self.clip_gradient)
+            update_m = sym._assign(m, self.beta1 * m + (1 - self.beta1) * g)
+            update_v = sym._assign(v, self.beta2 * v + (1 - self.beta2) * g * g)
+            update_var = sym._assign(variable, variable - lr_t * (update_m / (sym.sqrt(update_v) \
+                         + self.epsilon) + self.wd * variable))
+            updates.append(update_var)
+        return sym.Group(updates)
diff --git a/nnvm/python/nnvm/symbol.py b/nnvm/python/nnvm/symbol.py
index 0ac51f23fe9428a8ac6531bc18cc8e9fe60e3373..8b390e2cb72f6a59b7a4efb8e49f9ae6fcf1b3b9 100644
--- a/nnvm/python/nnvm/symbol.py
+++ b/nnvm/python/nnvm/symbol.py
@@ -1,4 +1,4 @@
-# pylint: disable=invalid-name, unused-import
+# pylint: disable=invalid-name, unused-import, protected-access
 """Symbolic graph construction API.
 
 This namespace contains most of the registered operators.
@@ -8,10 +8,12 @@ from __future__ import absolute_import as _abs
 import sys as _sys
 import os as _os
 import ctypes as _ctypes
-
 from numbers import Number as _Number
+
+import numpy as np
+
 from . import _base
-from ._base import _LIB, check_call as _check_call, _FFI_MODE
+from ._base import _LIB, check_call as _check_call, _FFI_MODE, _all_var_init
 from .attribute import AttrScope
 from . import _symbol_internal as _internal
 
@@ -309,13 +311,19 @@ class Symbol(SymbolBase):
             self.handle, deps.handle))
 
 
-def Variable(name, **kwargs):
+def Variable(name, init=None, **kwargs):
     """Create a symbolic variable with specified name.
 
     Parameters
     ----------
     name : str
         Name of the variable.
+    init : Symbol or numpy.ndarray
+        Symbol or numpy ndarray of initial value for the variable.
+        Note that for symbolic initialization value, it must be able
+        to be defined through InferShape, such as sym.zeros_like(v),
+        in which v is an input or parameter. Otherwise, pass a numpy
+        ndarray instead.
     kwargs : dict of string -> string
         Additional attributes to set on the variable.
 
@@ -333,6 +341,11 @@ def Variable(name, **kwargs):
     attr = AttrScope.current.get(kwargs)
     if attr:
         ret._set_attr(**attr)
+    if init is not None:
+        if not isinstance(init, (Symbol, np.ndarray)):
+            raise TypeError('Expect a Symbol or numpy ndarray'
+                            'for variable `init`')
+        _all_var_init[name] = init
     return ret
 
 
diff --git a/nnvm/python/nnvm/top/attr_dict.py b/nnvm/python/nnvm/top/attr_dict.py
index 453c2971f73f33cebe55944cf8a67cd07e58fc21..a913a92552b28f2dc737eb16c60a9b725ad573af 100644
--- a/nnvm/python/nnvm/top/attr_dict.py
+++ b/nnvm/python/nnvm/top/attr_dict.py
@@ -123,6 +123,21 @@ class AttrDict(object):
         else:
             raise ValueError("Wrong bool format for key %s" % key)
 
+    def get_string(self, key):
+        """Get string from attr dict
+
+        Parameters
+        ----------
+        key : str
+            The attr key
+
+        Returns
+        -------
+        value : str
+            The result value
+        """
+        return self[key]
+
     def __repr__(self):
         return str({k : self[k] for k in self.keys()})
 
diff --git a/nnvm/python/nnvm/top/tensor.py b/nnvm/python/nnvm/top/tensor.py
index 4f9d3ba9f36ffba717aabe795e5b192ff6a9a815..1e8688f9f2e69afb396f23db2d5e84e88f48dab8 100644
--- a/nnvm/python/nnvm/top/tensor.py
+++ b/nnvm/python/nnvm/top/tensor.py
@@ -143,3 +143,95 @@ reg.register_schedule("broadcast_div", _fschedule_broadcast)
 # broadcast_to
 reg.register_pattern("broadcast_to", OpPattern.BROADCAST)
 reg.register_schedule("broadcast_to", _fschedule_broadcast)
+
+# clip
+reg.register_pattern("clip", OpPattern.ELEMWISE)
+reg.register_schedule("clip", _fschedule_elemwise)
+
+# elemwise sum
+@reg.register_compute("elemwise_sum")
+def compute_elemwise_sum(attrs, inputs, _):
+    """Compute definition of elemwise sum"""
+    num_args = attrs.get_int("num_args")
+    assert num_args == len(inputs), "Number of tensors does not match num_args."
+    return topi.tensor.elemwise_sum(inputs, num_args)
+reg.register_pattern("elemwise_sum", OpPattern.ELEMWISE)
+reg.register_schedule("elemwise_sum", _fschedule_elemwise)
+
+# full
+@reg.register_compute("full")
+def compute_full(attrs, inputs, _):
+    """Compute definition of full"""
+    shape = attrs.get_int_tuple("shape")
+    dtype = attrs.get_string("dtype")
+    fill_value = attrs.get_float("fill_value")
+    return topi.tensor.full(shape, dtype, fill_value)
+reg.register_pattern("full", OpPattern.OUT_ELEMWISE_FUSABLE)
+reg.register_schedule("full", _fschedule_elemwise)
+
+# full_like
+@reg.register_compute("full_like")
+def compute_full_like(attrs, inputs, _):
+    """Compute definition of full_like"""
+    fill_value = attrs.get_float("fill_value")
+    return topi.tensor.full_like(inputs[0], fill_value)
+reg.register_pattern("full_like", OpPattern.ELEMWISE)
+reg.register_schedule("full_like", _fschedule_elemwise)
+
+# zeros
+@reg.register_compute("zeros")
+def compute_zeros(attrs, inputs, _):
+    """Compute definition of zeros"""
+    shape = attrs.get_int_tuple("shape")
+    dtype = attrs.get_string("dtype")
+    return topi.tensor.full(shape, dtype, 0)
+reg.register_pattern("zeros", OpPattern.OUT_ELEMWISE_FUSABLE)
+reg.register_schedule("zeros", _fschedule_elemwise)
+
+# zeros_like
+@reg.register_compute("zeros_like")
+def compute_zeros_like(_, inputs, out_info):
+    """Compute definition of zeros_like"""
+    return topi.tensor.full_like(inputs[0], 0)
+reg.register_pattern("zeros_like", OpPattern.ELEMWISE)
+reg.register_schedule("zeros_like", _fschedule_elemwise)
+
+# ones
+@reg.register_compute("ones")
+def compute_ones(attrs, inputs, _):
+    """Compute definition of ones"""
+    shape = attrs.get_int_tuple("shape")
+    dtype = attrs.get_string("dtype")
+    #tvm.tensor.Tensor()
+    return topi.tensor.full(shape, dtype, 1)
+reg.register_pattern("ones", OpPattern.OUT_ELEMWISE_FUSABLE)
+reg.register_schedule("ones", _fschedule_elemwise)
+
+# ones_like
+@reg.register_compute("ones_like")
+def compute_ones_like(_, inputs, out_info):
+    """Compute definition of ones_like"""
+    return topi.tensor.full_like(inputs[0], 1)
+reg.register_pattern("ones_like", OpPattern.ELEMWISE)
+reg.register_schedule("ones_like", _fschedule_elemwise)
+
+# greater
+@reg.register_compute("greater")
+def compute_greater(_, inputs, out_info):
+    """Compute definition of greater"""
+    return topi.tensor.greater(inputs[0], inputs[1], 'float32')
+reg.register_pattern("greater", OpPattern.ELEMWISE)
+reg.register_schedule("greater", _fschedule_elemwise)
+
+# less
+@reg.register_compute("less")
+def compute_less(_, inputs, out_info):
+    """Compute definition of less"""
+    return topi.tensor.less(inputs[0], inputs[1], 'float32')
+reg.register_pattern("less", OpPattern.ELEMWISE)
+reg.register_schedule("less", _fschedule_elemwise)
+
+# block_grad
+reg.register_compute("block_grad", _compute_unary(topi.identity))
+reg.register_pattern("block_grad", OpPattern.ELEMWISE)
+reg.register_schedule("block_grad", _fschedule_elemwise)
diff --git a/nnvm/python/nnvm/top/transform.py b/nnvm/python/nnvm/top/transform.py
index ec1596e9599bb99256d913aa167c6d05793e305c..c3ceb68682ee5a4f7dd38357663ef6c28e2379a3 100644
--- a/nnvm/python/nnvm/top/transform.py
+++ b/nnvm/python/nnvm/top/transform.py
@@ -2,6 +2,7 @@
 """Tensor transformation ops"""
 from __future__ import absolute_import
 
+import topi
 from .tensor import _fschedule_broadcast, _fschedule_injective
 from . import registry as reg
 from .registry import OpPattern
@@ -10,6 +11,32 @@ from .registry import OpPattern
 reg.register_pattern("expand_dims", OpPattern.BROADCAST)
 reg.register_schedule("expand_dims", _fschedule_broadcast)
 
+# expand_like
+@reg.register_compute("expand_like")
+def compute_expand_like(attrs, inputs, _):
+    """Compute definition of expand_like"""
+    exclude = attrs.get_bool("exclude")
+    axis = attrs.get_int_tuple("axis")
+    if exclude:
+        exclude_axis = (axis,) if isinstance(axis, int) else axis
+        axis = []
+        for item in range(len(inputs[1].shape)):
+            if item not in exclude_axis:
+                axis.append(item)
+        axis = tuple(axis)
+
+    return topi.transform.expand_like(inputs[0], inputs[1], axis)
+reg.register_pattern("expand_like", OpPattern.BROADCAST)
+reg.register_schedule("expand_like", _fschedule_broadcast)
+
+# reshape_like
+@reg.register_compute("reshape_like")
+def compute_reshape_like(attrs, inputs, out_info):
+    """Compute definition of reshape_like"""
+    return topi.reshape(inputs[0], inputs[1].shape)
+reg.register_pattern("reshape_like", OpPattern.INJECTIVE)
+reg.register_schedule("reshape_like", _fschedule_injective)
+
 # transpose
 reg.register_pattern("transpose", OpPattern.INJECTIVE)
 reg.register_schedule("transpose", _fschedule_injective)
diff --git a/nnvm/src/top/nn/nn.cc b/nnvm/src/top/nn/nn.cc
index 79b4c8587a645da0ce287ba6d033fa17771c7f68..5ac4b7662ff5a155a11cf825d2c5e92359ce1251 100644
--- a/nnvm/src/top/nn/nn.cc
+++ b/nnvm/src/top/nn/nn.cc
@@ -130,15 +130,14 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(relu)
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds) {
     // y = relu(x)
-    // grad = indicator(x > 0)
-    NodeEntry zero = MakeNode("zeros_like", n->attrs.name + "_grad_zero",
+    // grad = indicator(x > 0) * ograd
+    NodeEntry sub0 = MakeNode("zeros_like", n->attrs.name + "_sub0",
                               {n->inputs[0]});
+    NodeEntry sub1 = MakeNode("greater", n->attrs.name + "_sub1",
+                              {n->inputs[0], sub0}, {{"exclude", "true"}});
     return std::vector<NodeEntry>{
-      MakeNode("elemwise_mul", n->attrs.name + "_grad", {
-        ograds[0],
-        MakeNode("greater", n->attrs.name + "_grad_mask",
-                 {n->inputs[0], zero}, {{"exclude", "true"}})
-      })
+      MakeNode("elemwise_mul", n->attrs.name + "_grad",
+               {ograds[0], sub1})
     };
 })
 .set_support_level(1);
@@ -358,23 +357,21 @@ NNVM_REGISTER_OP(log_softmax)
     // grad_x = sum(grad_x, keepdim, axis)
     // grad_x = neg grad_x
     // grad_x = grad_x + ones_like(grad_x)
-    // grad_x = expand_dims(grad_x, axis)
     const SoftmaxParam& param = nnvm::get<SoftmaxParam>(n->attrs.parsed);
     NodeEntry output =  NodeEntry{n, 0, 0};
     NodeEntry sub0 = MakeNode("elemwise_mul", n->attrs.name + "_grad_sub0", {ograds[0], output});
     NodeEntry sub1 = MakeNode("sum", n->attrs.name + "_grad_sub1", {sub0},
                               {{"axis", std::to_string(param.axis)}, {"keepdims", "true"}});
-    NodeEntry sub2 = MakeNode("negative", n->attrs.name + "_grad_sub2", {sub1});
-    NodeEntry sub3 = MakeNode("ones_like", n->attrs.name + "_grad_sub3", {sub2});
-    NodeEntry sub4 = MakeNode("elemwise_add", n->attrs.name + "_grad_sub4", {sub2, sub3});
+    NodeEntry sub2 = MakeNode("full_like", n->attrs.name + "_grad_sub2", {n->inputs[0]},
+                              {{"fill_value", "-1"}});
+    NodeEntry sub3 = MakeNode("broadcast_mul", n->attrs.name + "_grad_sub3", {sub1, sub2});
     return std::vector<NodeEntry> {
-      MakeNode("expand_like", n->attrs.name + "_grad", {sub4, output},
-               {{"axis", std::to_string(param.axis)}})
+      MakeNode("elemwise_add", n->attrs.name + "_grad", {sub3, ograds[0]})
     };
 })
 .set_support_level(1);
 
-// leaky_rlu
+// leaky_relu
 DMLC_REGISTER_PARAMETER(LeakyReLUParam);
 
 NNVM_REGISTER_OP(leaky_relu)
@@ -407,14 +404,15 @@ NNVM_REGISTER_OP(leaky_relu)
     NodeEntry zero = MakeNode("zeros_like", n->attrs.name + "_grad_zero",
                               {n->inputs[0]});
     NodeEntry sub0 = MakeNode("greater", n->attrs.name + "_pos_grad",
-                              {n->inputs[0], zero}, {{"exclude", "true"}});
+                              {n->inputs[0], zero});
     NodeEntry sub1 = MakeNode("less", n->attrs.name + "_neg_grad",
-                              {n->inputs[0], zero}, {{"exclude", "true"}});
+                              {n->inputs[0], zero});
     NodeEntry sub2 = MakeNode("__mul_scalar__", n->attrs.name + "_neg_mul_2",
                               {sub1},
                               {{"scalar", std::to_string(param.alpha)}});
+    NodeEntry sub3 = MakeNode("elemwise_add", n->attrs.name + "_sub3", {sub0, sub2});
     return std::vector<NodeEntry>{
-      MakeNode("elemwise_add", n->attrs.name + "_add_grad", {sub0, sub2})
+      MakeNode("elemwise_mul", n->attrs.name + "_grad", {ograds[0], sub3})
     };
 })
 .set_support_level(1);
diff --git a/nnvm/src/top/tensor/elemwise.cc b/nnvm/src/top/tensor/elemwise.cc
index d23d9c4f0a3f7ea9d701c26ba2f41147a22b00da..87fbf5823d9b2f19df805cbc0a4ff4f330953a77 100644
--- a/nnvm/src/top/tensor/elemwise.cc
+++ b/nnvm/src/top/tensor/elemwise.cc
@@ -190,7 +190,10 @@ NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_add)
     // y = n0 + n1
     // grad_0 = grad_y
     // grad_1 = grad_y
-    return std::vector<NodeEntry>{ograds[0], ograds[0]};
+    return std::vector<NodeEntry>{ MakeNode("copy", n->attrs.name + "_grad_0",
+                                            {ograds[0]}),
+                                   MakeNode("copy", n->attrs.name + "_grad_0",
+                                            {ograds[0]}) };
 });
 
 NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_sub)
@@ -311,7 +314,8 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(copy)
                   const std::vector<NodeEntry>& ograds){
     // y = copy(n0)
     // grad_0 = grad_y
-    return std::vector<NodeEntry>{ograds[0]};
+    return std::vector<NodeEntry>{ MakeNode("copy", n->attrs.name + "_grad_0",
+                                            {ograds[0]}) };
 });
 
 DMLC_REGISTER_PARAMETER(InitOpParam);
@@ -329,7 +333,7 @@ NNVM_REGISTER_INIT_OP(full)
 .add_arguments(InitOpWithScalarParam::__FIELDS__())
 .set_attr<FInferShape>("FInferShape", ZeroShape<InitOpWithScalarParam>)
 .set_attr<FInferType>("FInferType", ZeroType<InitOpWithScalarParam>)
-.set_support_level(1);
+.set_support_level(4);
 
 NNVM_REGISTER_INIT_OP(zeros)
 .describe(R"code(Fill target with zeros
@@ -341,7 +345,7 @@ NNVM_REGISTER_INIT_OP(zeros)
 .add_arguments(InitOpParam::__FIELDS__())
 .set_attr<FInferShape>("FInferShape", ZeroShape<InitOpParam>)
 .set_attr<FInferType>("FInferType", ZeroType<InitOpParam>)
-.set_support_level(1);
+.set_support_level(4);
 
 NNVM_REGISTER_INIT_OP(ones)
 .describe(R"code(Fill target with ones
@@ -353,7 +357,7 @@ NNVM_REGISTER_INIT_OP(ones)
 .add_arguments(InitOpParam::__FIELDS__())
 .set_attr<FInferShape>("FInferShape", ZeroShape<InitOpParam>)
 .set_attr<FInferType>("FInferType", ZeroType<InitOpParam>)
-.set_support_level(1);
+.set_support_level(4);
 
 // full_like
 NNVM_REGISTER_INIT_LIKE_OP(full_like)
@@ -364,21 +368,21 @@ as the input array
 .add_arguments(FillValueParam::__FIELDS__())
 .set_attr_parser(ParamParser<FillValueParam>)
 .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<FillValueParam>)
-.set_support_level(1);
+.set_support_level(4);
 
 NNVM_REGISTER_INIT_LIKE_OP(zeros_like)
 .describe(R"code(Return an array of zeros with the same shape and type
 as the input array.
 
 )code")
-.set_support_level(1);
+.set_support_level(4);
 
 NNVM_REGISTER_INIT_LIKE_OP(ones_like)
 .describe(R"code(Return an array of ones with the same shape and type
 as the input array.
 
 )code")
-.set_support_level(1);
+.set_support_level(4);
 
 // unary scalar op
 DMLC_REGISTER_PARAMETER(ScalarParam);
@@ -415,7 +419,8 @@ NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__add_scalar__)
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds){
-    return std::vector<NodeEntry>{ograds[0]};
+    return std::vector<NodeEntry>{ MakeNode("copy", n->attrs.name + "_grad_0",
+                                            {ograds[0]}) };
 });
 
 NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__sub_scalar__)
@@ -601,10 +606,11 @@ NNVM_REGISTER_ELEMWISE_REDUCE_OP(elemwise_sum)
     CHECK_EQ(ograds.size(), 1);
     std::vector<NodeEntry> ret;
     for (size_t i = 0; i < n->inputs.size(); i++) {
-      ret.push_back(ograds[0]);
+      ret.push_back(MakeNode("copy", n->attrs.name + "_grad_0", {ograds[0]}));
     }
     return ret;
-  });
+  })
+.set_support_level(4);
 
 NNVM_REGISTER_ELEMWISE_UNARY_OP(block_grad)
 .describe(R"code(Blocks gradient computation for input.
@@ -614,7 +620,8 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(block_grad)
   "FInplaceIdentity", [](const NodeAttrs& attrs){
     return std::vector<bool>{true};
 })
-.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+.set_support_level(4);
 
 DMLC_REGISTER_PARAMETER(IndicatorParam);
 
@@ -628,7 +635,7 @@ with 1.0 if (left > right), otherwise 0.0 element-wise.
 .add_argument("rhs", "Tensor", "Second input")
 .set_num_inputs(2)
 .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<2, 1>)
-.set_support_level(1);
+.set_support_level(4);
 
 
 NNVM_REGISTER_INDICATOR_OP(less)
@@ -640,7 +647,7 @@ with 1.0 if (left < right), otherwise 0.0 element-wise.
 .add_argument("rhs", "Tensor", "Second input")
 .set_num_inputs(2)
 .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<2, 1>)
-.set_support_level(1);
+.set_support_level(4);
 
 NNVM_REGISTER_INDICATOR_OP(_max_mask)
   .describe(R"code(Function that returns a mask tensor
@@ -668,5 +675,73 @@ with 1.0 if the value is minimum over given axes, otherwise 0.0 element-wise.
 .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
 .set_support_level(1);
 
+
+DMLC_REGISTER_PARAMETER(ClipParam);
+
+NNVM_REGISTER_OP(clip)
+.describe(R"doc(Clips (limits) the values in an array.
+Given an interval, values outside the interval are clipped to the interval edges.
+Clipping ``x`` between `a_min` and `a_x` would be::
+   clip(x, a_min, a_max) = max(min(x, a_max), a_min))
+Example::
+    x = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
+    clip(x,1,8) = [ 1.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  8.]
+)doc" NNVM_ADD_FILELINE)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<ClipParam>)
+.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ClipParam>)
+.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const ClipParam params = get<ClipParam>(attrs.parsed);
+    return Array<Tensor>{
+      topi::clip(inputs[0], tvm::make_const(tvm::Float(32), params.a_min),
+                 tvm::make_const(tvm::Float(32), params.a_max)) };
+  })
+.add_argument("data", "NDArray-or-Symbol", "Input array.")
+.add_arguments(ClipParam::__FIELDS__())
+.set_attr<nnvm::FGradient>(
+  "FGradient", [](const NodePtr& n,
+                  const std::vector<NodeEntry>& ograds){
+    // y = clip(x, a_min, a_max)
+    // min_mask = greater_equal(x, a_min*ones_like(x))
+    //          => ones_like(x) - less(x, a_min)
+    // max_mask = less_equal(x, a_max*ones_like(x))
+    //          => ones_like(x) - greater(x, a_max)
+    // grad_x = min_mask * max_mask * grad_y
+    CHECK_EQ(ograds.size(), 1);
+
+    NodeEntry sub0 = MakeNode("ones_like", n->attrs.name + "_grad_sub_0",
+                              {n->inputs[0]});
+    // min_mask
+    NodeEntry sub1 = MakeNode("__mul_scalar__", n->attrs.name + "_grad_sub_1",
+                              {sub0}, {{"scalar", n->attrs.dict["a_min"]}});
+    NodeEntry sub2 = MakeNode("less", n->attrs.name + "_grad_sub_2",
+                              {n->inputs[0], sub1});
+    NodeEntry sub3 = MakeNode("elemwise_sub", n->attrs.name + "_grad_sub_3",
+                              {sub0, sub2});
+
+    // max_mask
+    NodeEntry sub4 = MakeNode("__mul_scalar__", n->attrs.name + "_grad_sub_4",
+                              {sub0}, {{"scalar", n->attrs.dict["a_max"]}});
+    NodeEntry sub5 = MakeNode("greater", n->attrs.name + "_grad_sub_5",
+                              {n->inputs[0], sub4});
+    NodeEntry sub6 = MakeNode("elemwise_sub", n->attrs.name + "_grad_sub_6",
+                              {sub0, sub5});
+
+    // min_mask * max_mask
+    NodeEntry sub7 = MakeNode("elemwise_mul", n->attrs.name + "_grad_sub_7",
+                              {sub3, sub6});
+    return std::vector<NodeEntry>{
+      MakeNode("elemwise_mul", n->attrs.name + "_grad",
+               {sub7, ograds[0]})
+    };
+  })
+.set_support_level(4);
+
 }  // namespace top
 }  // namespace nnvm
diff --git a/nnvm/src/top/tensor/reduce.cc b/nnvm/src/top/tensor/reduce.cc
index 8eac2449b27187eae64e39f15e8984b2eb4cb804..84a7dd0f0e12f6de7245cdf0824e65714e46ea97 100644
--- a/nnvm/src/top/tensor/reduce.cc
+++ b/nnvm/src/top/tensor/reduce.cc
@@ -137,7 +137,20 @@ Example::
                     const Array<Tensor>& inputs,
                     const Array<Tensor>& out_info) {
     const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
-    auto axis = ShapeToArray(param.axis);
+    Array<Expr> axis;
+    if (param.exclude) {
+      std::set<dim_t> exclude_axis;
+      for (dim_t i = 0; i < param.axis.ndim(); ++i) {
+        exclude_axis.insert(param.axis[i]);
+      }
+      for (dim_t i = 0; i < inputs[0].ndim(); ++i) {
+        if (exclude_axis.count(i) == 0) {
+          axis.push_back(make_const(Int(32), i));
+        }
+      }
+    } else {
+      axis = ShapeToArray(param.axis);
+    }
     return Array<Tensor>{
       topi::sum(inputs[0], axis, param.keepdims) };
 })
@@ -150,7 +163,6 @@ Example::
       MakeNode("expand_like", n->attrs.name + "_grad",
                {ograds[0], n->inputs[0]},
                {{"axis", axis.str()},
-                {"keepdims", std::to_string(param.keepdims)},
                 {"exclude", std::to_string(param.exclude)}})
   };
 });
diff --git a/nnvm/src/top/tensor/state_op.cc b/nnvm/src/top/tensor/state_op.cc
index f275cf309511b79c8a74b4ed1eadad001136697b..ebce07696fe4420de91c6bfe59b057621706d833 100644
--- a/nnvm/src/top/tensor/state_op.cc
+++ b/nnvm/src/top/tensor/state_op.cc
@@ -48,6 +48,15 @@ This is an experimental operator.
 .set_attr<FInplaceOption>(
   "FInplaceOption", [](const NodeAttrs& attrs) {
     return std::vector<std::pair<int, int> >{{1, 0}};
+})
+.set_attr<FGradient>(
+  "FGradient", [](const NodePtr& n,
+                  const std::vector<NodeEntry>& ograds){
+    return std::vector<NodeEntry>{
+      MakeNode("zeros_like", n->attrs.name + "_zero_grad",
+               {n->inputs[0]}),
+      ograds[0]
+    };
 });
 
 }  // namespace top
diff --git a/nnvm/src/top/tensor/transform.cc b/nnvm/src/top/tensor/transform.cc
index dcdedbb867f90cfe9dd7a918093870f63d9455fa..2457747341c933f7c7037c04b158748a44ec2589 100644
--- a/nnvm/src/top/tensor/transform.cc
+++ b/nnvm/src/top/tensor/transform.cc
@@ -229,29 +229,24 @@ will return a new array with shape ``(2,5,3,4)``.
 
 NNVM_REGISTER_OP(expand_like)
   .describe(R"code(Expand an input array with the shape of second array.
-
 This operation can always be composed of unsqueezing and expanding dims.
-
 Examples::
   input = [ 12.  19.  27.]
   input.shape = (3,)
-
   new_shape_array = [[[1,2],[2,3],[1,3]],
                      [[1,4],[4,3],[5,2]],
                      [[7,1],[7,2],[7,3]]]
   new_shape_array.shape = (3, 3, 2)
-
   expand_like(input, [1,2], new_shape_array) =
                     [[[12,12],[12,12],[12,12]],
                      [[19,19],[19,19],[19,19]],
                      [[27,27],[27,27],[27,27]]]
-
 )code" NNVM_ADD_FILELINE)
 .add_argument("input", "Tensor", "Source input")
 .add_argument("shape_like", "Tensor", "Input with new shape")
-.add_arguments(ReduceParam::__FIELDS__())
-.set_attr_parser(ParamParser<ReduceParam>)
-.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ReduceParam>)
+.add_arguments(IndicatorParam::__FIELDS__())
+.set_attr_parser(ParamParser<IndicatorParam>)
+.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<IndicatorParam>)
 .set_attr<nnvm::FInferShape>("FInferShape", AssignOutputAttr<TShape, 1, 0>)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
 .set_num_inputs(2)
@@ -259,7 +254,7 @@ Examples::
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds) {
-    const ReduceParam& param = nnvm::get<ReduceParam>(n->attrs.parsed);
+    const IndicatorParam& param = nnvm::get<IndicatorParam>(n->attrs.parsed);
     std::ostringstream axis;
     axis << param.axis;
 
@@ -267,11 +262,11 @@ Examples::
       MakeNode("sum", n->attrs.name + "_grad",
                {ograds[0]},
                {{"axis", axis.str()},
-                {"keepdims", std::to_string(param.keepdims)},
-                {"exclude", std::to_string(param.exclude)}})
+                {"exclude", std::to_string(param.exclude)}}),
+      MakeNode("zeros_like", n->attrs.name + "_zero_grad", {n->inputs[1]})
     };
-})
-.set_support_level(1);
+  })
+  .set_support_level(4);
 
 // split
 DMLC_REGISTER_PARAMETER(SplitParam);
@@ -564,13 +559,10 @@ The significance of each is explained below:
 
 NNVM_REGISTER_OP(reshape_like)
   .describe(R"code(Reshapes the input array by the size of another array.
-
 For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation reshapes
 the input array into an output array with the same shape as the second input array.
-
 .. note::
     Sizes for both array should be compatible.
-
 )code" NNVM_ADD_FILELINE)
 .add_argument("data", "Tensor", "Input data.")
 .add_argument("shape_like", "Tensor", "Input data.")
@@ -589,10 +581,12 @@ the input array into an output array with the same shape as the second input arr
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds) {
-    return MakeGradNode("reshape_like", n,
-                        {ograds[0], n->inputs[0]});
+    return std::vector<NodeEntry>{
+      MakeNode("reshape_like", n->attrs.name + "_grad", {ograds[0], n->inputs[0]}),
+      MakeNode("zeros_like", n->attrs.name + "_zero_grad", { n->inputs[1]})
+    };
 })
-.set_support_level(3);
+.set_support_level(4);
 
 // squeeze
 DMLC_REGISTER_PARAMETER(SqueezeParam);
@@ -680,7 +674,8 @@ Examples::
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds) {
     return std::vector<NodeEntry>{
-      MakeNode("reshape_like", n->attrs.name + "_grad", {n->inputs[0]})
+      MakeNode("reshape_like", n->attrs.name + "_grad",
+               {ograds[0], n->inputs[0]})
     };
 })
 .set_support_level(1);
diff --git a/nnvm/tests/python/compiler/test_optimizer.py b/nnvm/tests/python/compiler/test_optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd620271d8615cd1d11ed92e713a91502f10d70a
--- /dev/null
+++ b/nnvm/tests/python/compiler/test_optimizer.py
@@ -0,0 +1,118 @@
+import numpy as np
+import tvm
+import nnvm
+import nnvm.compiler.optimizer as optimizer
+import nnvm.compiler.lr_scheduler as lr_scheduler
+
+from nnvm.testing.config import ctx_list
+from tvm.contrib import graph_runtime
+
+
+def helper(symbol, inputs, params, update_func, run_times, target, ctx, dtype="float32"):
+    ishapes = {}
+    np_inputs = {}
+    params_dict = {}
+    for (name, shape, s) in inputs:
+        ishapes.update({name: shape})
+        np_inputs.update({name: np.random.uniform(size=shape).astype(dtype)})
+    for (name, shape, s) in params:
+        np_inputs.update({name: np.random.uniform(size=shape).astype(dtype)})
+        params_dict.update({name: np_inputs[name]})
+
+    graph, lib, rt_params = nnvm.compiler.build(symbol, target, shape=ishapes)
+    m = graph_runtime.create(graph, lib, ctx)
+    m.set_input(**np_inputs)
+    m.set_input(**rt_params)
+    for _ in range(run_times):
+        m.run()
+    y_np = update_func(**np_inputs)
+    out = m.get_output(0, tvm.nd.empty(y_np.shape, dtype))
+    np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
+
+
+def test_sgd():
+    for target, ctx in ctx_list():
+        data = nnvm.sym.Variable("data")
+        weight = nnvm.sym.Variable("weight")
+        out = nnvm.sym.elemwise_mul(data, weight ** 2)
+
+        dshape = (1, 2, 3)
+        wshape = dshape
+
+        base_lr = 0.1
+        lr_factor = 0.5
+        rescale_grad = 0.2
+        wd = 0.1
+        clip_gradient = 0.25
+
+        scheduler = lr_scheduler.FactorScheduler(base_lr=base_lr, step=1, factor=lr_factor)
+        opt = optimizer.SGD(learning_rate=base_lr, lr_scheduler=scheduler,
+                            rescale_grad=rescale_grad, clip_gradient=clip_gradient,
+                            wd=wd)
+        opt_sym = opt.minimize(out, var=weight)
+
+        inputs = [("data", dshape, data)]
+        params = [("weight", wshape, weight)]
+
+        def update_func(data, weight):
+            gradient_0 = data * 2 * weight * rescale_grad
+            gradient_0 = np.clip(gradient_0, -clip_gradient, clip_gradient)
+            weight_0 = weight - base_lr * lr_factor * (gradient_0 + wd * weight)
+            gradient_1 = data * 2 * weight_0 * rescale_grad
+            gradient_1 = np.clip(gradient_1, -clip_gradient, clip_gradient)
+            weight_1 = weight_0 - base_lr * (lr_factor ** 2) * (gradient_1 + wd * weight_0)
+            return weight_1
+
+        helper(opt_sym, inputs, params, update_func, 2, target, ctx)
+
+
+
+def test_adam():
+    for target, ctx in ctx_list():
+        data = nnvm.sym.Variable("data")
+        weight = nnvm.sym.Variable("weight")
+        out = nnvm.sym.elemwise_mul(data, weight ** 2)
+
+        dshape = (1, 2, 3)
+        wshape = dshape
+
+        base_lr = 0.1
+        beta1 = 0.9
+        beta2 = 0.999
+        epsilon = 1e-8
+        lr_factor = 0.5
+        rescale_grad = 0.2
+        wd = 0.1
+        clip_gradient = 0.25
+
+        scheduler = lr_scheduler.FactorScheduler(base_lr=base_lr, step=1, factor=lr_factor)
+        opt = optimizer.Adam(learning_rate=base_lr, beta1=beta1, beta2=beta2, epsilon=epsilon,
+                             lr_scheduler=scheduler, rescale_grad=rescale_grad,
+                             clip_gradient=clip_gradient, wd=wd)
+        opt_sym = opt.minimize(out, var=weight)
+
+        inputs = [("data", dshape, data)]
+        params = [("weight", wshape, weight)]
+
+        def update_func(data, weight):
+            rate_0 = np.sqrt(1 - beta2) / (1 - beta1)
+            lr_0 = base_lr * lr_factor * rate_0
+            gradient_0 = data * 2 * weight * rescale_grad
+            gradient_0 = np.clip(gradient_0, -clip_gradient, clip_gradient)
+            m_0 = (1 - beta1) * gradient_0
+            v_0 = (1 - beta2) * (gradient_0 ** 2)
+            weight_0 = weight - lr_0 * (m_0 / (np.sqrt(v_0) + epsilon) + wd * weight)
+            rate_1 = np.sqrt(1 - beta2 ** 2) / (1 - beta1 ** 2)
+            lr_1 = base_lr * (lr_factor ** 2) * rate_1
+            gradient_1 = data * 2 * weight_0 * rescale_grad
+            gradient_1 = np.clip(gradient_1, -clip_gradient, clip_gradient)
+            m_1 = beta1 * m_0 + (1 - beta1) * gradient_1
+            v_1 = beta2 * v_0 + (1 - beta2) * (gradient_1 ** 2)
+            weight_1 = weight_0 - lr_1 * (m_1 / (np.sqrt(v_1) + epsilon) + wd * weight_0)
+            return weight_1
+
+        helper(opt_sym, inputs, params, update_func, 2, target, ctx)
+
+if __name__ == "__main__":
+    test_sgd()
+    test_adam()
diff --git a/nnvm/tests/python/compiler/test_top_level1.py b/nnvm/tests/python/compiler/test_top_level1.py
index abe30d1ec8d026c4e84891caf65ac322453f57c1..480aa271af8f0f317593a1dc5c836038f66f23b9 100644
--- a/nnvm/tests/python/compiler/test_top_level1.py
+++ b/nnvm/tests/python/compiler/test_top_level1.py
@@ -8,15 +8,14 @@ from nnvm.testing.config import ctx_list
 
 
 def helper(symbol, inputs, dtype,
-           np_forward, np_backward=None):
+           np_forward, np_backward=None, need_input=True, need_head_grads=True):
     ishapes = {}
     input_syms = []
     np_inputs = {}
-    for (k, v) in inputs.items():
-        ishapes.update({k: v[0]})
-        np_inputs.update({k: np.random.uniform(size=v[0]).astype(dtype)})
-        if len(v) > 1:
-            input_syms.append(v[1])
+    for (name, shape, s) in inputs:
+        ishapes.update({name: shape})
+        np_inputs.update({name: np.random.uniform(size=shape).astype(dtype)})
+        input_syms.append(s)
 
     for target, ctx in ctx_list():
         graph, lib, _ = nnvm.compiler.build(symbol, target, ishapes)
@@ -25,23 +24,26 @@ def helper(symbol, inputs, dtype,
         y_np = np_forward(**np_inputs)
         out = m.get_output(0, tvm.nd.empty(y_np.shape, dtype))
         np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
-
         # backward
         if np_backward:
             graph._set_symbol_list_attr("grad_ys", symbol)
-            for x in input_syms:
-                graph._set_symbol_list_attr("grad_xs", x)
-            graph._set_symbol_list_attr("grad_ys_out_grad", sym.Variable("head_grads"))
+            graph._set_symbol_list_attr("grad_xs", input_syms)
+            graph._set_symbol_list_attr("grad_ys_out_grad", sym.Variable("head_grads", shape=y_np.shape))
             graph = graph.apply("Gradient")
             ishapes.update({"head_grads": y_np.shape})
             graph, lib, _ = nnvm.compiler.build(graph, target, ishapes)
             m = graph_runtime.create(graph, lib, ctx)
             head_grads = np.random.uniform(size=y_np.shape).astype(dtype)
-            y_np = head_grads * np_backward(**np_inputs)
-            m.run(head_grads=head_grads, **np_inputs)
-            out = m.get_output(0, tvm.nd.empty(y_np.shape, dtype))
-
-            np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
+            y_np = np_backward(head_grads=head_grads, **np_inputs)
+            b_inputs = {}
+            if need_input:
+                b_inputs.update(np_inputs)
+            if need_head_grads:
+                b_inputs.update({"head_grads":head_grads})
+            m.run(**b_inputs)
+            for i in range(len(y_np)):
+                out = m.get_output(i, tvm.nd.empty(y_np[i].shape, dtype))
+                np.testing.assert_allclose(out.asnumpy(), y_np[i], atol=1e-5, rtol=1e-5)
 
 
 def test_relu():
@@ -52,10 +54,15 @@ def test_relu():
         x = (x < 0) * x * 0.3 + (x > 0) * x - 0.2
         return (x > 0) * x
 
+    def backward(head_grads, x):
+        sub = (x < 0) * x * 0.3 + (x > 0) * x - 0.2
+        return [(sub > 0).astype("float") * \
+                ((x > 0).astype("float") + 0.3 * (x < 0).astype("float")) * head_grads]
+
     dtype = "float32"
     dshape = (1, 3, 32, 32)
-    inputs = {'x': (dshape, x)}
-    helper(y, inputs, dtype, forward)
+    inputs = [('x', dshape, x)]
+    helper(y, inputs, dtype, forward, backward)
 
 
 def test_sym_scalar_pow():
@@ -66,12 +73,12 @@ def test_sym_scalar_pow():
     def forward(x):
         return x**scalar
 
-    def backward(x):
-        return scalar * x**(scalar -  1)
+    def backward(head_grads, x):
+        return [scalar * x**(scalar -  1) * head_grads]
 
     dtype = "float32"
     dshape = (1, 3, 32, 32)
-    inputs = {'x': (dshape, x)}
+    inputs = [('x', dshape, x)]
     helper(y, inputs, dtype, forward, backward)
 
 
@@ -83,12 +90,12 @@ def test_scalar_sym_pow():
     def forward(x):
         return scalar**x
 
-    def backward(x):
-        return np.log(scalar) * scalar**x
+    def backward(head_grads, x):
+        return [np.log(scalar) * scalar**x * head_grads]
 
     dtype = "float32"
     dshape = (1, 3, 32, 32)
-    inputs = {'x': (dshape, x)}
+    inputs = [('x', dshape, x)]
     helper(y, inputs, dtype, forward, backward)
 
 
@@ -99,12 +106,12 @@ def test_exp():
     def forward(x):
         return np.exp(x)
 
-    def backward(x):
-        return np.exp(x)
+    def backward(head_grads, x):
+        return [np.exp(x) * head_grads]
 
     dtype = "float32"
     dshape = (1, 3, 32, 32)
-    inputs = {'x': (dshape, x)}
+    inputs = [('x', dshape, x)]
     helper(y, inputs, dtype, forward, backward)
 
 
@@ -115,12 +122,12 @@ def test_log():
     def forward(x):
         return np.log(x)
 
-    def backward(x):
-        return 1. / x
+    def backward(head_grads, x):
+        return [1. / x * head_grads]
 
     dtype = "float32"
     dshape = (1, 3, 32, 32)
-    inputs = {'x': (dshape, x)}
+    inputs = [('x', dshape, x)]
     helper(y, inputs, dtype, forward, backward)
 
 
@@ -131,13 +138,13 @@ def test_tanh():
     def forward(x):
         return np.sinh(x) / np.cosh(x)
 
-    def backward(x):
+    def backward(head_grads, x):
         y_np = forward(x)
-        return (1 - y_np**2)
+        return [(1 - y_np**2) * head_grads]
 
     dtype = "float32"
     dshape = (1, 3, 32, 32)
-    inputs = {'x': (dshape, x)}
+    inputs = [('x', dshape, x)]
     helper(y, inputs, dtype, forward, backward)
 
 
@@ -148,13 +155,13 @@ def test_sigmoid():
     def forward(x):
         return 1.0 / (1.0 + np.exp(-x))
 
-    def backward(x):
+    def backward(head_grads, x):
         y_np = forward(x)
-        return y_np *(1 - y_np)
+        return [y_np *(1 - y_np) * head_grads]
 
     dtype = "float32"
     dshape = (1, 3, 32, 32)
-    inputs = {'x': (dshape, x)}
+    inputs = [('x', dshape, x)]
     helper(y, inputs, dtype, forward, backward)
 
 
@@ -165,10 +172,15 @@ def test_softmax():
     def forward(x):
         return topi.testing.softmax_python(x)
 
+    def backward(head_grads, x):
+        y = topi.testing.softmax_python(x)
+        grad = y * (head_grads - np.sum(y * head_grads, axis=1, keepdims=True))
+        return [grad]
+
     dtype = "float32"
     dshape = (10, 1000)
-    inputs = {'x': (dshape, x)}
-    helper(y, inputs, dtype, forward)
+    inputs = [('x', dshape, x)]
+    helper(y, inputs, dtype, forward), backward
 
 
 def test_log_softmax():
@@ -178,26 +190,32 @@ def test_log_softmax():
     def forward(x):
         return topi.testing.log_softmax_python(x)
 
+    def backward(head_grads, x):
+        y = topi.testing.log_softmax_python(x)
+        grad = head_grads - np.sum(y * head_grads, axis=1, keepdims=True)
+        return [grad]
+
     dtype = "float32"
     dshape = (10, 1000)
-    inputs = {'x': (dshape, x)}
-    helper(y, inputs, dtype, forward)
+    inputs = [('x', dshape, x)]
+    helper(y, inputs, dtype, forward, backward)
 
 
 def test_dense():
-    x = sym.Variable("x")
-    y = sym.dense(x, units=3, name="dense")
+    x = sym.Variable("x", shape=(10, 100))
+    w = sym.Variable("dense_weight", shape=(3, 100))
+    b = sym.Variable("dense_bias", shape=(3,))
+    y = sym.dense(x, w, b, use_bias=True, units=3, name="dense")
     y = sym.flatten(y)
 
     def forward(x, dense_weight, dense_bias):
         return np.dot(x, dense_weight.T) + dense_bias
-
     dtype = "float32"
-    inputs = {
-        'x': ((10, 100), x),
-        'dense_weight': ((3, 100),),
-        'dense_bias': ((3,),)
-    }
+    inputs = [
+        ('x', (10, 100), x),
+        ('dense_weight', (3, 100), w),
+        ('dense_bias', (3,), b)
+    ]
     helper(y, inputs, dtype, forward)
 
 
@@ -215,13 +233,13 @@ def test_batchnorm():
         return (x - moving_mean) / np.sqrt(moving_var + eps) * gamma + beta
 
     dtype = "float32"
-    inputs = {
-        'x': ((10, 20), x),
-        'gamma': ((20,),),
-        'beta': ((20,),),
-        'moving_mean': ((20,),),
-        'moving_var': ((20,),)
-    }
+    inputs = [
+        ('x', (10, 20), x),
+        ('gamma', (20,), gamma),
+        ('beta', (20,), beta),
+        ('moving_mean', (20,), moving_var),
+        ('moving_var', (20,), moving_mean)
+    ]
 
     helper(y, inputs,  dtype, forward)
 
@@ -283,9 +301,12 @@ def verify_squeeze(dshape, axis):
     def forward(x):
         return np.squeeze(x, axis=axis) + 1
 
+    def backward(head_grads, x):
+        return [np.reshape(head_grads, x.shape)]
+
     dtype = "float32"
-    inputs = {'x': (dshape, x)}
-    helper(y, inputs, dtype, forward)
+    inputs = [('x', dshape, x)]
+    helper(y, inputs, dtype, forward, backward)
 
 
 def test_squeeze():
@@ -304,7 +325,7 @@ def test_pad():
                       mode='constant', constant_values=1.)
 
     dtype = "float32"
-    inputs = {'x': ((1, 3, 28, 28), x)}
+    inputs = [('x', (1, 3, 28, 28), x)]
     helper(y, inputs, dtype, forward)
 
 
diff --git a/nnvm/tests/python/compiler/test_top_level4.py b/nnvm/tests/python/compiler/test_top_level4.py
index ad09d73ec28dadeba199eed8c087ea32b9d35572..c6e8620fc972ddac6d9aa19410de290bcbf81284 100644
--- a/nnvm/tests/python/compiler/test_top_level4.py
+++ b/nnvm/tests/python/compiler/test_top_level4.py
@@ -6,6 +6,46 @@ import nnvm.symbol as sym
 import nnvm.compiler
 from nnvm.testing.config import ctx_list
 
+
+def helper(symbol, inputs, dtype,
+           np_forward, np_backward=None, need_input=True, need_head_grads=True):
+    ishapes = {}
+    input_syms = []
+    np_inputs = {}
+    for (name, shape, s) in inputs:
+        ishapes.update({name: shape})
+        np_inputs.update({name: np.random.uniform(size=shape).astype(dtype)})
+        input_syms.append(s)
+
+    for target, ctx in ctx_list():
+        graph, lib, _ = nnvm.compiler.build(symbol, target, ishapes)
+        m = graph_runtime.create(graph, lib, ctx)
+        m.run(**np_inputs)
+        y_np = np_forward(**np_inputs)
+        out = m.get_output(0, tvm.nd.empty(y_np.shape, dtype))
+        np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
+        # backward
+        if np_backward:
+            graph._set_symbol_list_attr("grad_ys", symbol)
+            graph._set_symbol_list_attr("grad_xs", input_syms)
+            graph._set_symbol_list_attr("grad_ys_out_grad", sym.Variable("head_grads", shape=y_np.shape))
+            graph = graph.apply("Gradient")
+            ishapes.update({"head_grads": y_np.shape})
+            graph, lib, _ = nnvm.compiler.build(graph, target, ishapes)
+            m = graph_runtime.create(graph, lib, ctx)
+            head_grads = np.random.uniform(size=y_np.shape).astype(dtype)
+            y_np = np_backward(head_grads=head_grads, **np_inputs)
+            b_inputs = {}
+            if need_input:
+                b_inputs.update(np_inputs)
+            if need_head_grads:
+                b_inputs.update({"head_grads":head_grads})
+            m.run(**b_inputs)
+            for i in range(len(y_np)):
+                out = m.get_output(i, tvm.nd.empty(y_np[i].shape, dtype))
+                np.testing.assert_allclose(out.asnumpy(), y_np[i], atol=1e-5, rtol=1e-5)
+
+
 def verify_transpose(dshape, axes):
     x = sym.Variable("x")
     if axes:
@@ -66,13 +106,245 @@ def verify_reshape(dshape, oshape):
         out = m.get_output(0, tvm.nd.empty(out_np.shape))
         np.testing.assert_allclose(out.asnumpy(), out_np, atol=1e-5, rtol=1e-5)
 
+
 def test_reshape():
     verify_reshape((2, 3, 4), (-1, 2, 1))
     verify_reshape((2, 3, 4), (8, 3))
     verify_reshape((4, 7), (2, 7, 2))
 
+
+def test_clip():
+    x = sym.Variable("x")
+    a_min=0.2
+    a_max=0.75
+    y = sym.clip(x, a_min=a_min, a_max=a_max)
+
+    def forward(x):
+        return np.clip(x, a_min=a_min, a_max=a_max)
+
+    def backward(head_grads, x):
+        mask1 = np.greater_equal(x, a_min).astype("float")
+        mask2 = np.less_equal(x, a_max).astype("float")
+        return [head_grads * mask1 * mask2]
+
+
+    dtype = "float32"
+    inputs = [('x', (3, 4, 5), x)]
+    helper(y, inputs, dtype, forward, backward)
+
+
+def test_greater():
+    l = sym.Variable("l")
+    r = sym.Variable("r")
+    y = sym.greater(l, r)
+
+    def forward(l, r):
+        return np.greater(l, r).astype("float32")
+
+    def backward(head_grads, l, r):
+        return [np.zeros_like(l)]
+
+
+    dtype = "float32"
+    inputs = [('l', (3, 4, 5), l),
+              ('r', (3, 4, 5), r)]
+    helper(y, inputs, dtype, forward, backward, need_head_grads=False)
+
+
+def test_less():
+    l = sym.Variable("l")
+    r = sym.Variable("r")
+    y = sym.less(l, r)
+
+    def forward(l, r):
+        return np.less(l, r).astype("float32")
+
+    def backward(head_grads, l, r):
+        return [np.zeros_like(l)]
+
+
+    dtype = "float32"
+    inputs = [('l', (3, 4, 5), l),
+              ('r', (3, 4, 5), r)]
+    helper(y, inputs, dtype, forward, backward, need_head_grads=False)
+
+
+def test_reshape_like():
+    x = sym.Variable("x")
+    y = sym.Variable("y")
+    z = sym.reshape_like(x, y)
+
+    def forward(x, y):
+        return np.reshape(x, y.shape)
+
+    def backward(head_grads, x, y):
+        return [np.reshape(head_grads, x.shape),
+                np.zeros_like(y)]
+
+
+    dtype = "float32"
+    inputs = [('x', (3, 4, 5), x),
+              ('y', (5, 4, 3), y)]
+    helper(z, inputs, dtype, forward, backward)
+
+
+def verify_expand_like(in_shape, out_shape, axis, exclude):
+    x = sym.Variable("x")
+    y = sym.Variable("y")
+    z = sym.expand_like(x, y, axis=axis, exclude=exclude)
+
+    def forward(x, y):
+        odim = len(out_shape)
+        real_axis = [i if i >= 0 else i + odim for i in axis]
+        real_axis = sorted(real_axis)
+        if exclude:
+            real_axis = list(set(range(odim)) - set(real_axis))
+        for i in real_axis:
+            x = np.expand_dims(x, i).astype(x.dtype)
+        for i in real_axis:
+            x = np.concatenate([x]*out_shape[i], axis=i).astype(x.dtype)
+
+        return x
+
+    def backward(head_grads, x, y):
+        odim = len(out_shape)
+        real_axis = [i if i >= 0 else i + odim for i in axis]
+        real_axis = sorted(real_axis)
+        if exclude:
+            real_axis = list(set(range(odim)) - set(real_axis))
+        return [np.sum(head_grads, axis=tuple(real_axis)),
+                np.zeros_like(y)]
+
+
+    dtype = "float32"
+    inputs = [('x', in_shape, x),
+              ('y', out_shape, y)]
+    helper(z, inputs, dtype, forward, backward, need_input=False)
+
+
+def test_expand_like():
+    verify_expand_like((3,), (3, 2), [1], False)
+    verify_expand_like((2,), (2, 3), [1], False)
+    verify_expand_like((3, 4), (3, 5, 4), [1], False)
+    verify_expand_like((5, 7), (5, 6, 7, 8), [0, 2], True)
+
+
+def verify_elemwise_sum(num_args):
+    s = [sym.Variable("input" + str(i)) for i in range(num_args)]
+    y = sym.elemwise_sum(*s, num_args=num_args)
+
+    def forward(**inputs):
+        return np.sum(np.array(list(inputs.values())), axis=0)
+
+    def backward(head_grads, **inputs):
+        return [head_grads] * num_args
+
+    dtype = "float32"
+    inputs = [("input" + str(i), (3, 4, 5), s[i])
+              for i in range(num_args)]
+    helper(y, inputs, dtype, forward, backward, need_input=False)
+
+
+def test_elemwise_sum():
+    verify_elemwise_sum(1)
+    verify_elemwise_sum(5)
+    verify_elemwise_sum(7)
+
+
+def test_block_grad():
+    x = sym.Variable("x")
+    y = sym.block_grad(x)
+
+    def forward(x):
+        return x
+
+    def backward(head_grads, x):
+        return [np.zeros_like(head_grads)]
+
+
+    dtype = "float32"
+    inputs = [('x', (3, 4, 5), x)]
+    helper(y, inputs, dtype, forward, backward, need_head_grads=False)
+
+
+def test_full():
+    shape = (3, 4, 5)
+    value = 7
+    dtype = "float32"
+    for target, ctx in ctx_list():
+        data = sym.Variable("data", dtype=dtype)
+        # full_like
+        s = sym.full_like(data=data, fill_value=value, name="s")
+        graph, lib, _ = nnvm.compiler.build(s, target, {"data": shape})
+        m = graph_runtime.create(graph, lib, ctx)
+        m.run(data=np.random.uniform(size=shape).astype(dtype))
+        out = m.get_output(0, tvm.nd.empty(shape, dtype=dtype))
+        np.testing.assert_allclose(
+            out.asnumpy(),
+            np.full(shape, fill_value=value, dtype=dtype),
+            atol=1e-5, rtol=1e-5)
+        # ones_like
+        s = sym.ones_like(data=data, fill_value=value, name="s")
+        graph, lib, _ = nnvm.compiler.build(s, target, {"data": shape})
+        m = graph_runtime.create(graph, lib, ctx)
+        m.run(data=np.random.uniform(size=shape).astype(dtype))
+        out = m.get_output(0, tvm.nd.empty(shape, dtype=dtype))
+        np.testing.assert_allclose(
+            out.asnumpy(),
+            np.full(shape, fill_value=1, dtype=dtype),
+            atol=1e-5, rtol=1e-5)
+        # zeros_like
+        s = sym.zeros_like(data=data, fill_value=value, name="s")
+        graph, lib, _ = nnvm.compiler.build(s, target, {"data": shape})
+        m = graph_runtime.create(graph, lib, ctx)
+        m.run(data=np.random.uniform(size=shape).astype(dtype))
+        out = m.get_output(0, tvm.nd.empty(shape, dtype=dtype))
+        np.testing.assert_allclose(
+            out.asnumpy(),
+            np.full(shape, fill_value=0, dtype=dtype),
+            atol=1e-5, rtol=1e-5)
+        # full
+        s = sym.full(shape=shape, dtype=dtype, fill_value=value, name="s")
+        graph, lib, _ = nnvm.compiler.build(s, target)
+        m = graph_runtime.create(graph, lib, ctx)
+        m.run()
+        out = m.get_output(0, tvm.nd.empty(shape, dtype=dtype))
+        np.testing.assert_allclose(
+            out.asnumpy(),
+            np.full(shape, fill_value=value, dtype=dtype),
+            atol=1e-5, rtol=1e-5)
+        # ones
+        s = sym.ones(shape=shape, dtype=dtype, name="s")
+        graph, lib, _ = nnvm.compiler.build(s, target)
+        m = graph_runtime.create(graph, lib, ctx)
+        m.run()
+        out = m.get_output(0, tvm.nd.empty(shape, dtype=dtype))
+        np.testing.assert_allclose(
+            out.asnumpy(),
+            np.full(shape, fill_value=1, dtype=dtype),
+            atol=1e-5, rtol=1e-5)
+        # zeros
+        s = sym.zeros(shape=shape, dtype=dtype, name="s")
+        graph, lib, _ = nnvm.compiler.build(s, target)
+        m = graph_runtime.create(graph, lib, ctx)
+        m.run()
+        out = m.get_output(0, tvm.nd.empty(shape, dtype=dtype))
+        np.testing.assert_allclose(
+            out.asnumpy(),
+            np.full(shape, fill_value=0, dtype=dtype),
+            atol=1e-5, rtol=1e-5)
+
+
 if __name__ == "__main__":
     test_reshape()
     test_reduce()
     test_tranpose()
+    test_clip()
+    test_greater()
+    test_less()
+    test_reshape_like()
+    test_expand_like()
+    test_elemwise_sum()
+    test_block_grad()
+    test_full()
     print(nnvm.compiler.engine.dump())