From 58398d38620fc984a3f8143f68b2b5de4b8dcb1e Mon Sep 17 00:00:00 2001
From: Jared Roesch <roeschinc@gmail.com>
Date: Tue, 4 Dec 2018 12:54:01 -0800
Subject: [PATCH] Port from_nnvm to NNVM as to_relay (#2144)

---
 nnvm/python/nnvm/to_relay.py                | 506 ++++++++++++++++++++
 nnvm/tests/python/compiler/test_to_relay.py |  41 ++
 python/tvm/relay/frontend/common.py         |  55 ++-
 python/tvm/relay/frontend/mxnet.py          | 137 +-----
 python/tvm/relay/frontend/nnvm_common.py    | 132 +++++
 python/tvm/relay/op/_transform.py           |   1 +
 python/tvm/relay/op/nn/_nn.py               |   7 +-
 src/relay/backend/graph_plan_memory.cc      |   3 +
 src/relay/ir/alpha_equal.cc                 |  10 +-
 src/relay/op/nn/upsampling.cc               |  48 +-
 tests/python/relay/frontend/test_keras.py   | 332 +++++++++++++
 topi/include/topi/image/resize.h            |   3 +-
 12 files changed, 1116 insertions(+), 159 deletions(-)
 create mode 100644 nnvm/python/nnvm/to_relay.py
 create mode 100644 nnvm/tests/python/compiler/test_to_relay.py
 create mode 100644 python/tvm/relay/frontend/nnvm_common.py
 create mode 100644 tests/python/relay/frontend/test_keras.py

diff --git a/nnvm/python/nnvm/to_relay.py b/nnvm/python/nnvm/to_relay.py
new file mode 100644
index 000000000..318ff1ee9
--- /dev/null
+++ b/nnvm/python/nnvm/to_relay.py
@@ -0,0 +1,506 @@
+# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-argument
+"""Convert an NNVM graph to Relay."""
+import json
+import numpy
+from tvm import relay, nd
+from tvm.relay import op, expr, var
+from tvm.relay.frontend.common import StrAttrsDict
+from tvm.relay.frontend.nnvm_common import _rename
+from .symbol import Symbol
+from .compiler import graph_attr
+from .graph import create as graph_create
+
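+# Each converter below maps one NNVM operator onto Relay. They share the
+# signature (children, attrs, odtype): `children` are the already-converted
+# Relay inputs, `attrs` wraps the node's NNVM attributes as a StrAttrsDict,
+# and `odtype` is the node's inferred output dtype.
+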
+def _nn_batch_flatten(children, attrs, odtype='float32'):
+    assert len(children) == 1
+    return op.nn.batch_flatten(children[0])
+
+
+def _dense(children, attrs, odtype='float32'):
+    use_bias = attrs.get_bool('use_bias', True)
+    units = attrs.get_int('units')
+    dense = op.nn.dense(children[0], children[1], units=units)
+    if use_bias:
+        return op.nn.bias_add(dense, children[2])
+    else:
+        return dense
+
+def _nn_softmax(children, attrs, odtype='float32'):
+    assert len(children) == 1
+    axis = attrs.get_int('axis', 1)
+    return op.nn.softmax(children[0], axis)
+
+def _conv2d(children, attrs, odtype='float32'):
+    use_bias = attrs.get_bool('use_bias', False)
+
+    if use_bias:
+        data, weight, bias = children
+    else:
+        data, weight = children
+
+    strides = attrs.get_int_tuple('strides', (1, 1))
+    padding = attrs.get_int_tuple('padding', (0, 0))
+    dilation = attrs.get_int_tuple('dilation', (1, 1))
+    groups = attrs.get_int('groups', 1)
+    data_layout = attrs.get_str('layout', 'NCHW')
+    weight_layout = attrs.get_str('kernel_layout', 'OIHW')
+    out_layout = ''
+    out_dtype = attrs.get_str('out_dtype', '')
+
+    conv_out = op.nn.conv2d(
+        data,
+        weight,
+        strides=strides,
+        padding=padding,
+        dilation=dilation,
+        groups=groups,
+        data_layout=data_layout,
+        weight_layout=weight_layout,
+        out_layout=out_layout,
+        out_dtype=out_dtype)
+
+    if use_bias:
+        return op.nn.bias_add(conv_out, bias)
+    else:
+        return conv_out
+
+
+def _conv2d_transpose(children, attrs, odtype='float32'):
+    use_bias = attrs.get_bool('use_bias', False)
+
+    if use_bias:
+        data, weight, bias = children
+    else:
+        data, weight = children
+
+    strides = attrs.get_int_tuple('strides', (1, 1))
+    padding = attrs.get_int_tuple('padding', (0, 0))
+    dilation = attrs.get_int_tuple('dilation', (1, 1))
+    groups = attrs.get_int('groups', 1)
+    data_layout = attrs.get_str('layout', 'NCHW')
+    weight_layout = attrs.get_str('kernel_layout', 'OIHW')
+    out_dtype = attrs.get_str('out_dtype', '')
+
+    out_conv2d = op.nn.conv2d_transpose(
+        data,
+        weight,
+        strides=strides,
+        padding=padding,
+        dilation=dilation,
+        groups=groups,
+        data_layout=data_layout,
+        weight_layout=weight_layout,
+        out_dtype=out_dtype)
+
+    if use_bias:
+        return op.nn.bias_add(out_conv2d, bias)
+    else:
+        return out_conv2d
+
+
+def _batch_norm(children, attrs, odtype='float32'):
+    data, gamma, beta, moving_mean, moving_var = children
+    axis = attrs.get_int('axis', 1)
+    epsilon = attrs.get_float('epsilon', 1e-05)
+    center = attrs.get_bool('center', True)
+    scale = attrs.get_bool('scale', True)
+
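+    # batch_norm returns (output, running_mean, running_var); only the
+    # normalized output is used here.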
+    return op.nn.batch_norm(
+        data,
+        gamma,
+        beta,
+        moving_mean,
+        moving_var,
+        axis=axis,
+        epsilon=epsilon,
+        center=center,
+        scale=scale)[0]
+
+
+def _max_pool2d(children, attrs, odtype='float32'):
+    assert len(children) == 1
+    data = children[0]
+    pool_size = attrs.get_int_tuple('pool_size', (1, 1))
+    strides = attrs.get_int_tuple('strides', (1, 1))
+    padding = attrs.get_int_tuple('padding', (0, 0))
+    layout = attrs.get_str('layout', 'NCHW')
+    ceil_mode = attrs.get_bool('ceil_mode', False)
+
+    return op.nn.max_pool2d(
+        data,
+        pool_size=pool_size,
+        strides=strides,
+        padding=padding,
+        layout=layout,
+        ceil_mode=ceil_mode)
+
+
+def _reshape(children, attrs, odtype='float32'):
+    data = children[0]
+    shape = attrs.get_int_list('shape')
+    return op.reshape(data, shape)
+
+
+def _transpose(children, attrs, odtype='float32'):
+    axes = attrs.get_int_list('axes', None)
+    return op.transpose(children[0], axes=axes)
+
+
+def _add(children, attrs, odtype='float32'):
+    if len(children) == 1:
+        left = children[0]
+        scalar = attrs.get_float('scalar')
+        right = relay.const(scalar, dtype=odtype)
+    else:
+        assert len(children) == 2
+        left = children[0]
+        right = children[1]
+
+    return op.add(left, right)
+
+
+def _subtract(children, attrs, odtype='float32'):
+    if len(children) == 1:
+        left = children[0]
+        scalar = attrs.get_float('scalar')
+        right = relay.const(scalar, dtype=odtype)
+    else:
+        assert len(children) == 2
+        left = children[0]
+        right = children[1]
+
+    return op.subtract(left, right)
+
+
+def _rsubtract(children, attrs, odtype='float32'):
+    if len(children) == 1:
+        left = children[0]
+        scalar = attrs.get_float('scalar')
+        right = relay.const(scalar, dtype=odtype)
+    else:
+        assert len(children) == 2
+        left = children[0]
+        right = children[1]
+
+    return op.subtract(right, left)
+
+
+def _multiply(children, attrs, odtype='float32'):
+    if len(children) == 1:
+        left = children[0]
+        scalar = attrs.get_float('scalar')
+        right = relay.const(scalar, dtype=odtype)
+    else:
+        assert len(children) == 2
+        left = children[0]
+        right = children[1]
+
+    return op.multiply(left, right)
+
+
+def _divide(children, attrs, odtype='float32'):
+    if len(children) == 1:
+        left = children[0]
+        scalar = attrs.get_float('scalar')
+        right = relay.const(scalar, dtype=odtype)
+    else:
+        assert len(children) == 2
+        left = children[0]
+        right = children[1]
+
+    return op.divide(left, right)
+
+
+def _rshift(children, attrs, odtype='float32'):
+    if len(children) == 1:
+        left = children[0]
+        scalar = attrs.get_float('scalar')
+        right = relay.const(scalar, dtype='int32')
+    else:
+        assert len(children) == 2
+        left = children[0]
+        right = children[1]
+
+    return op.right_shift(left, right)
+
+
+def _clip(children, attrs, odtype='float32'):
+    a_min = attrs.get_float('a_min')
+    a_max = attrs.get_float('a_max')
+    return op.clip(children[0], a_min, a_max)
+
+
+def _cast(children, attrs, odtype='float32'):
+    data = children[0]
+    dtype = attrs.get_str('dtype')
+    return data.astype(dtype)
+
+
+def _expand_dims(children, attrs, odtype='float32'):
+    data = children[0]
+    axis = attrs.get_int('axis')
+    num_newaxis = attrs.get_int('num_newaxis', 1)
+    return op.transform.expand_dims(data, axis, num_newaxis=num_newaxis)
+
+
+def broadcast_to(children, attrs, odtype='float32'):
+    # TODO(@jroesch) export broadcast to?
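+    # broadcast_to is not exposed here, so materialize a constant with the
+    # target shape and use broadcast_to_like against it instead.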
+    data = children[0]
+    shape = attrs.get_int_tuple('shape')
+    array = numpy.zeros(shape).astype(odtype)
+    rconst = relay.Constant(nd.array(array))
+    return op.broadcast_to_like(data, rconst)
+
+def _copy(children, attrs, odtype='float32'):
+    return op.copy(children[0])
+
+
+def _global_avg_pool2d(children, attrs, odtype='float32'):
+    data = children[0]
+    layout = attrs.get_str('layout', "NCHW")
+    return op.nn.global_avg_pool2d(data, layout)
+
+
+def _avg_pool2d(children, attrs, odtype='float32'):
+    data = children[0]
+    pool_size = attrs.get_int_tuple('pool_size', (1, 1))
+    strides = attrs.get_int_tuple('strides', (1, 1))
+    padding = attrs.get_int_tuple('padding', (0, 0))
+    layout = attrs.get_str('layout', "NCHW")
+    ceil_mode = attrs.get_bool('ceil_mode', False)
+    count_include_pad = attrs.get_bool('count_include_pad', False)
+    return op.nn.avg_pool2d(
+        data,
+        pool_size=pool_size,
+        strides=strides,
+        padding=padding,
+        layout=layout,
+        ceil_mode=ceil_mode,
+        count_include_pad=count_include_pad)
+
+
+def _upsampling(children, attrs, odtype='float32'):
+    scale = attrs.get_int('scale')
+    layout = attrs.get_str('layout', 'NCHW')
+    method = attrs.get_str('method', 'NEAREST_NEIGHBOR')
+    return op.nn.upsampling(
+        children[0],
+        scale=scale,
+        layout=layout,
+        method=method)
+
+
+def _pad(children, attrs, odtype='float32'):
+    pad_value = attrs.get_float('pad_value', 0.0)
+    pad_width = attrs.get_tuple_tuple_int('pad_width')
+    return op.nn.pad(children[0], pad_width, pad_value=pad_value)
+
+def _leaky_relu(children, attrs, odtype='float32'):
+    alpha = attrs.get_float('alpha')
+    return op.nn.leaky_relu(children[0], alpha)
+
+
+def _full_like(children, attrs, odtype='float32'):
+    fill_value = relay.const(attrs.get_float('fill_value'), dtype=odtype)
+    return op.full_like(children[0], fill_value)
+
+
+def _greater(children, attrs, odtype='float32'):
+    out_type = attrs.get_str('out_type', None)
+    if out_type:
+        return op.greater(children[0], children[1]).astype(out_type)
+    else:
+        return op.greater(children[0], children[1])
+
+
+def _greater_equal(children, attrs, odtype='float32'):
+    out_type = attrs.get_str('out_type', None)
+    if out_type:
+        return op.greater_equal(children[0], children[1]).astype(out_type)
+    else:
+        return op.greater_equal(children[0], children[1])
+
+
+def _less(children, attrs, odtype='float32'):
+    out_type = attrs.get_str('out_type', None)
+    if out_type:
+        return op.less(children[0], children[1]).astype(out_type)
+    else:
+        return op.less(children[0], children[1])
+
+
+def _less_equal(children, attrs, odtype='float32'):
+    out_type = attrs.get_str('out_type', None)
+    if out_type:
+        return op.less_equal(children[0], children[1]).astype(out_type)
+    else:
+        return op.less_equal(children[0], children[1])
+
+
+def _strided_slice(children, attrs, odtype='float32'):
+    begin = attrs.get_int_list('begin')
+    end = attrs.get_int_list('end')
+    strides = attrs.get_int_list('strides', None)
+    return op.strided_slice(children[0], begin, end, strides=strides)
+
+
+def _split(children, attrs, odtype='float32'):
+    indices_or_sections = None
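+    # 'indices_or_sections' is either a single int (number of equal sections)
+    # or a tuple of split points; get_int raises ValueError on the tuple
+    # form, so fall back to get_int_tuple.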
+    try:
+        indices_or_sections = attrs.get_int('indices_or_sections', None)
+    except ValueError:
+        indices_or_sections = indices_or_sections or attrs.get_int_tuple(
+            'indices_or_sections')
+
+    axis = attrs.get_int('axis', 0)
+
+    return op.split(children[0], indices_or_sections, axis)
+
+def _squeeze(children, attrs, odtype='float32'):
+    axis = None
+    try:
+        axis = attrs.get_int('axis', None)
+        if axis is not None:
+            axis = [axis]
+    except ValueError:
+        axis = attrs.get_int_tuple('axis', None)
+
+    return op.squeeze(children[0], axis)
+
+NNVM_OP_2_RELAY_OP = {
+    'flatten': _nn_batch_flatten,
+    'dense': _dense,
+    'softmax': _nn_softmax,
+    'conv2d': _conv2d,
+    'batch_norm': _batch_norm,
+    'max_pool2d': _max_pool2d,
+    'reshape': _reshape,
+    'transpose': _transpose,
+    # Addition
+    '__add_scalar__': _add,
+    'broadcast_add': _add,
+    'elemwise_add': _add,
+    # Subtraction
+    '__sub_scalar__': _subtract,
+    '__rsub_scalar__': _rsubtract,
+    'broadcast_sub': _subtract,
+    'elemwise_sub': _subtract,
+    # Multiply
+    '__mul_scalar__': _multiply,
+    'broadcast_mul': _multiply,
+    'elemwise_mul': _multiply,
+    # Division
+    '__div_scalar__': _divide,
+    'broadcast_div': _divide,
+    'elemwise_div': _divide,
+    # Negative
+    'negative': _rename("negative"),
+
+    # Comparison
+    'greater': _greater,
+    'greater_equal': _greater_equal,
+    'less': _less,
+    'less_equal': _less_equal,
+
+    # Activations
+    'sigmoid': _rename('sigmoid'),
+    'relu': _rename('nn.relu'),
+    'exp': _rename('exp'),
+    'log': _rename('log'),
+    'tanh': _rename('tanh'),
+    'leaky_relu': _leaky_relu,
+    'clip': _clip,
+    'round': _rename('round'),
+    'cast': _cast,
+    'expand_dims': _expand_dims,
+    'broadcast_to': broadcast_to,
+    '__rshift_scalar__': _rshift,
+    'copy': _copy,
+    'global_avg_pool2d': _global_avg_pool2d,
+    'avg_pool2d': _avg_pool2d,
+    'conv2d_transpose': _conv2d_transpose,
+    'upsampling': _upsampling,
+    'pad': _pad,
+    'full_like': _full_like,
+    'strided_slice': _strided_slice,
+    'split': _split,
+    'squeeze': _squeeze,
+}
+
+
+def to_relay(graph, shape_dict, dtype_dict, params):
+    """Convert an NNVM graph into the corresponding Relay expression.
+
+    Parameters
+    ----------
+    graph : Graph
+       The input graph.
+
+    shape_dict : dict of str to shape
+       The input shapes.
+
+    dtype_dict : dict of str to str
+       The input dtypes.
+
+    params : dict of str to array
+        The parameters.
+
+    Returns
+    -------
+    (expr, params) : Tuple[relay.Expr, dict of str to array]
+        The corresponding Relay expression and parameters.
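+
+    Examples
+    --------
+    A minimal sketch, assuming an existing NNVM symbol ``sym`` and its
+    ``params`` dictionary (names are illustrative)::
+
+        func, params = to_relay(sym, {'data': (1, 3, 224, 224)},
+                                {'data': 'float32'}, params)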
+    """
+    if isinstance(graph, Symbol):
+        graph = graph_create(graph)
+
+    param_shapes = dict((k, params[k].shape) for k in params)
+    shape_dict = shape_dict.copy()
+    shape_dict.update(param_shapes)
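+    # Merge the parameter shapes into the user-supplied input shapes, then
+    # run NNVM shape/dtype inference so every node has a concrete type.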
+    graph = graph_attr.set_shape_inputs(graph, shape_dict)
+    graph = graph_attr.set_dtype_inputs(graph, dtype_dict)
+    graph = graph.apply(["InferShape", "InferType"])
+    shape = graph.json_attr("shape")
+    dtype = [graph_attr.TCODE_TO_DTYPE[di] for di in graph.json_attr("dtype")]
+    heads = [x[0] for x in json.loads(graph.json())['heads']]
+
+    gidx = graph.index
+    relay_map = {}
+    fn_params = []
+    output_ids = []
+
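+    # gidx.nodes is topologically sorted, so every input of a node has
+    # already been converted (and stored in relay_map) when we reach it.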
+    for nid, node in enumerate(gidx.nodes):
+        children = []
+        for i in node['inputs']:
+            child = relay_map[i[0]]
+            if isinstance(child, expr.TupleWrapper):
+                children.append(child[i[1]])
+            else:
+                children.append(child)
+
+        oshape = shape[gidx.entry_id(nid, 0)]
+        odtype = dtype[gidx.entry_id(nid, 0)]
+        attrs = node.get("attrs", {})
+        node_name = node["name"]
+        op_name = node["op"]
+
+        if op_name == "null":
+            v = var(node_name, shape=oshape, dtype=odtype)
+            fn_params.append(v)
+            relay_map[nid] = v
+        else:
+            if nid in heads:
+                output_ids.append(nid)
+
+            if op_name in NNVM_OP_2_RELAY_OP:
+                str_attrs = StrAttrsDict(attrs)
+                call = NNVM_OP_2_RELAY_OP[op_name](children, str_attrs, odtype)
+                relay_map[nid] = call
+            else:
+                raise RuntimeError(
+                    "nnvm.to_relay: unsupported operator: {0}".format(op_name))
+
+    outputs = [relay_map[nid] for nid in output_ids]
+    if len(outputs) == 1:
+        body = outputs[0]
+    else:
+        body = expr.Tuple(outputs)
+
+    func = relay.Function(fn_params, body)
+    return func, params
diff --git a/nnvm/tests/python/compiler/test_to_relay.py b/nnvm/tests/python/compiler/test_to_relay.py
new file mode 100644
index 000000000..25037cfd3
--- /dev/null
+++ b/nnvm/tests/python/compiler/test_to_relay.py
@@ -0,0 +1,41 @@
+import nnvm
+from nnvm import testing
+from nnvm import to_relay
+import tvm
+from tvm.relay import ir_pass
+from tvm.relay import create_executor
+from tvm.contrib import graph_runtime
+import numpy as np
+
+def check_model(sym, shapes, dtypes, params):
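+    # Build and run the graph through the NNVM compiler first, then convert
+    # it with to_relay and run it through Relay; the outputs must match.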
+    net = nnvm.graph.create(sym)
+    graph_json, mod, params = nnvm.compiler.build(
+        net,
+        'llvm',
+        shape=shapes,
+        dtype=dtypes,
+        params=params)
+    nnvm_rts = graph_runtime.create(graph_json, mod, tvm.cpu(0))
+    inputs = {}
+    for name in shapes:
+        np_array = np.random.rand(*shapes[name]).astype('float32')
+        inputs[name] = tvm.nd.array(np_array)
+
+    nnvm_rts.set_input(**params)
+    nnvm_rts.run(**inputs)
+    nnvm_out = nnvm_rts.get_output(0)
+    relay_model, params = to_relay.to_relay(net, shapes, dtypes, params)
+    relay_model = ir_pass.infer_type(relay_model)
+    relay_rts = create_executor(kind='graph', ctx=tvm.cpu(0), target='llvm')
+    inputs.update(params)
+    relay_out = relay_rts.evaluate(relay_model)(*list(inputs.values()))
+    np.testing.assert_allclose(nnvm_out.asnumpy(), relay_out.asnumpy())
+
+def test_mlp():
+    mlp, params = testing.mlp.get_workload(1)
+    shapes = {"data": (10, 3, 224, 224)}
+    dtypes = {"data": 'float32'}
+    check_model(mlp, shapes, dtypes, params)
+
+if __name__ == "__main__":
+    test_mlp()
diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py
index 8e037d4bc..95633a4d4 100644
--- a/python/tvm/relay/frontend/common.py
+++ b/python/tvm/relay/frontend/common.py
@@ -101,11 +101,64 @@ class StrAttrsDict(object):
         """
         if key in self.attrs:
             tshape = self.attrs[key]
-            return tuple(int(x.strip()) for x in tshape.strip('()').split(','))
+            return tuple(int(x.strip()) for x in tshape.strip('()[]').split(','))
         if isinstance(default, RequiredAttr):
             raise AttributeError("Required attribute {} not found.".format(key))
         return default
 
+    def get_tuple_tuple_int(self, key, default=RequiredAttr()):
+        """Get int list attribute
+
+        Parameters
+        ----------
+        key : str
+            The attribute key
+
+        default : tuple of tuples of int
+            The default value.
+
+        Returns
+        -------
+        value : The result
+        """
+        if key in self.attrs:
+            value = self.attrs[key]
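+            # Parse a string such as "((1, 1), (2, 2))" into ((1, 1), (2, 2)).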
+            seq = []
+            for tup in value.strip('()').split('),'):
+                tup = tup.strip('[]()')
+                els = [int(x.strip('( ')) for x in tup.split(',')]
+                seq.append(tuple(els))
+
+            return tuple(seq)
+
+        if isinstance(default, RequiredAttr):
+            raise AttributeError("Required attribute {} not found.".format(key))
+        return default
+
+    def get_int_list(self, key, default=RequiredAttr()):
+        """Get int list attribute
+
+        Parameters
+        ----------
+        key : str
+            The attribute key
+
+        default : list of int
+            The default value.
+
+        Returns
+        -------
+        value : The result
+        """
+        if key in self.attrs:
+            tshape = self.attrs[key]
+            return tuple(int(x.strip()) for x in tshape.strip('[]()').split(','))
+        if isinstance(default, RequiredAttr):
+            raise AttributeError("Required attribute {} not found.".format(key))
+        return default
+
     def get_bool(self, key, default=RequiredAttr()):
         """Get bool tuple attribute
 
diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py
index b0b1e7009..77e97d26e 100644
--- a/python/tvm/relay/frontend/mxnet.py
+++ b/python/tvm/relay/frontend/mxnet.py
@@ -8,138 +8,14 @@ from .. import expr as _expr
 from .. import op as _op
 from ... import nd as _nd
 from .common import StrAttrsDict
+from .nnvm_common import _rename, _binop_scalar, _rbinop_scalar, _reduce
+from .nnvm_common import _arg_reduce, _init_op, _softmax_op, _cast
+from .nnvm_common import _clip, _transpose, _upsampling
+from .nnvm_common import _elemwise_sum, _reshape
+from .nnvm_common import _warn_not_used
 
 __all__ = ['from_mxnet']
 
-
-def _get_relay_op(op_name):
-    op = getattr(_op, op_name)
-    if not op:
-        raise RuntimeError("Unable to map op_name {} to relay".format(op_name))
-    return op
-
-
-def _warn_not_used(attr, op='nnvm'):
-    import warnings
-    err = "{} is ignored in {}.".format(attr, op)
-    warnings.warn(err)
-
-
-def _rename(new_op):
-    if isinstance(new_op, str):
-        new_op = _get_relay_op(new_op)
-    # attrs are ignored.
-    def impl(inputs, _):
-        return new_op(*inputs)
-    return impl
-
-
-def _reshape(inputs, attrs):
-    if attrs.get_bool("reverse", False):
-        raise RuntimeError("reshape do not support option reverse")
-    shape = attrs.get_int_tuple("shape")
-    return _op.reshape(inputs[0], newshape=shape)
-
-
-def _init_op(new_op):
-    """Init ops like zeros/ones"""
-    def _impl(inputs, attrs):
-        assert len(inputs) == 0
-        shape = attrs.get_int_tuple("shape")
-        dtype = attrs.get_str("dtype", "float32")
-        return new_op(shape=shape, dtype=dtype)
-    return _impl
-
-
-def _softmax_op(new_op):
-    """softmax/log_softmax"""
-    def _impl(inputs, attrs):
-        assert len(inputs) == 1
-        axis = attrs.get_int("axis", -1)
-        return new_op(inputs[0], axis=axis)
-    return _impl
-
-
-def _reduce(new_op):
-    """Reduction ops like sum/min/max"""
-    def _impl(inputs, attrs):
-        assert len(inputs) == 1
-        axis = attrs.get_int_tuple("axis", [])
-        keepdims = attrs.get_bool("keepdims", False)
-        # use None for reduce over all axis.
-        axis = None if len(axis) == 0 else axis
-        return new_op(inputs[0], axis=axis, keepdims=keepdims)
-    return _impl
-
-
-def _arg_reduce(new_op):
-    """Arg Reduction ops like argmin/argmax"""
-    def _impl(inputs, attrs):
-        assert len(inputs) == 1
-        axis = attrs.get_int("axis", None)
-        keepdims = attrs.get_bool("keepdims", False)
-        res = new_op(inputs[0], axis=[axis], keepdims=keepdims)
-        # cast to dtype.
-        res = res.astype("float32")
-        return res
-    return _impl
-
-
-def _cast(inputs, attrs):
-    """Type cast"""
-    dtype = attrs.get_str("dtype")
-    return _op.cast(inputs[0], dtype=dtype)
-
-
-def _clip(inputs, attrs):
-    a_min = attrs.get_float("a_min")
-    a_max = attrs.get_float("a_max")
-    return _op.clip(inputs[0], a_min=a_min, a_max=a_max)
-
-
-def _transpose(inputs, attrs):
-    axes = attrs.get_int_tuple("axes", None)
-    # translate default case
-    axes = None if len(axes) == 0 else axes
-    return _op.transpose(inputs[0], axes=axes)
-
-
-def _upsampling(inputs, attrs):
-    scale = attrs.get_int("scale")
-    return _op.nn.upsampling(inputs[0], scale=scale)
-
-
-def _elemwise_sum(inputs, _):
-    assert len(inputs) > 0
-    res = inputs[0]
-    for x in inputs[1:]:
-        res = _op.add(res, x)
-    return res
-
-
-def _binop_scalar(new_op):
-    def _impl(inputs, attrs):
-        assert len(inputs) == 1
-        scalar = attrs.get_float("scalar")
-        # Note: binary scalar only works for float op for now
-        scalar = _expr.const(scalar, dtype="float32")
-        return new_op(inputs[0], scalar)
-    return _impl
-
-
-def _rbinop_scalar(new_op):
-    def _impl(inputs, attrs):
-        assert len(inputs) == 1
-        scalar = attrs.get_float("scalar")
-        # Note: binary scalar only works for float op for now
-        scalar = _expr.const(scalar, dtype="float32")
-        return new_op(scalar, inputs[0])
-    return _impl
-
-# All the functions with _mx prefix specific to MXNet.
-# The functions without _mx prefix can be reused for
-# NNVMv1 conversion to _op.
-
 def _mx_fully_connected(inputs, attrs):
     import mxnet as mx
     units = attrs.get_int("num_hidden")
@@ -493,6 +369,7 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info):
     jnodes = jgraph["nodes"]
     node_map = {}
 
+
     for nid, node in enumerate(jnodes):
         children = [node_map[e[0]][e[1]] for e in node["inputs"]]
         attrs = StrAttrsDict(node.get("attrs", {}))
@@ -501,7 +378,7 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info):
         if op_name == "null":
             shape = shape_dict[node_name] if node_name in shape_dict else None
             if isinstance(dtype_info, dict):
-                dtype = dtype_info[node_name] if node_name in dtype_dict else "float32"
+                dtype = dtype_info[node_name] if node_name in dtype_info else "float32"
             else:
                 dtype = dtype_info
             node_map[nid] = [_expr.var(node_name, shape=shape, dtype=dtype)]
diff --git a/python/tvm/relay/frontend/nnvm_common.py b/python/tvm/relay/frontend/nnvm_common.py
new file mode 100644
index 000000000..17502dbaa
--- /dev/null
+++ b/python/tvm/relay/frontend/nnvm_common.py
@@ -0,0 +1,132 @@
+# pylint: disable=invalid-name, import-self, len-as-condition
+"""Utility functions common to NNVM and MxNet conversion."""
+from __future__ import absolute_import as _abs
+
+from .. import expr as _expr
+from .. import op as _op
+
+def _get_relay_op(op_name):
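+    # Resolve a possibly dotted name such as "nn.relu" against the Relay
+    # op namespace.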
+    op = _op
+    for path in op_name.split("."):
+        op = getattr(op, path)
+    if not op:
+        raise RuntimeError("Unable to map op_name {} to relay".format(op_name))
+    return op
+
+
+def _warn_not_used(attr, op='nnvm'):
+    import warnings
+    err = "{} is ignored in {}.".format(attr, op)
+    warnings.warn(err)
+
+
+def _rename(new_op):
+    if isinstance(new_op, str):
+        new_op = _get_relay_op(new_op)
+    # attrs are ignored.
+    def impl(inputs, _, _dtype='float32'):
+        return new_op(*inputs)
+    return impl
+
+
+def _reshape(inputs, attrs):
+    if attrs.get_bool("reverse", False):
+        raise RuntimeError("reshape do not support option reverse")
+    shape = attrs.get_int_tuple("shape")
+    return _op.reshape(inputs[0], newshape=shape)
+
+
+def _init_op(new_op):
+    """Init ops like zeros/ones"""
+    def _impl(inputs, attrs):
+        assert len(inputs) == 0
+        shape = attrs.get_int_tuple("shape")
+        dtype = attrs.get_str("dtype", "float32")
+        return new_op(shape=shape, dtype=dtype)
+    return _impl
+
+
+def _softmax_op(new_op):
+    """softmax/log_softmax"""
+    def _impl(inputs, attrs):
+        assert len(inputs) == 1
+        axis = attrs.get_int("axis", -1)
+        return new_op(inputs[0], axis=axis)
+    return _impl
+
+
+def _reduce(new_op):
+    """Reduction ops like sum/min/max"""
+    def _impl(inputs, attrs):
+        assert len(inputs) == 1
+        axis = attrs.get_int_tuple("axis", [])
+        keepdims = attrs.get_bool("keepdims", False)
+        # use None for reduce over all axis.
+        axis = None if len(axis) == 0 else axis
+        return new_op(inputs[0], axis=axis, keepdims=keepdims)
+    return _impl
+
+
+def _arg_reduce(new_op):
+    """Arg Reduction ops like argmin/argmax"""
+    def _impl(inputs, attrs):
+        assert len(inputs) == 1
+        axis = attrs.get_int("axis", None)
+        keepdims = attrs.get_bool("keepdims", False)
+        res = new_op(inputs[0], axis=[axis], keepdims=keepdims)
+        # cast to dtype.
+        res = res.astype("float32")
+        return res
+    return _impl
+
+
+def _cast(inputs, attrs):
+    """Type cast"""
+    dtype = attrs.get_str("dtype")
+    return inputs[0].astype(dtype=dtype)
+
+
+def _clip(inputs, attrs):
+    a_min = attrs.get_float("a_min")
+    a_max = attrs.get_float("a_max")
+    return _op.clip(inputs[0], a_min=a_min, a_max=a_max)
+
+
+def _transpose(inputs, attrs):
+    axes = attrs.get_int_tuple("axes", None)
+    # translate default case
+    axes = None if len(axes) == 0 else axes
+    return _op.transpose(inputs[0], axes=axes)
+
+
+def _upsampling(inputs, attrs):
+    scale = attrs.get_int("scale")
+    return _op.nn.upsampling(inputs[0], scale=scale)
+
+
+def _elemwise_sum(inputs, _):
+    assert len(inputs) > 0
+    res = inputs[0]
+    for x in inputs[1:]:
+        res = _op.add(res, x)
+    return res
+
+
+def _binop_scalar(new_op):
+    def _impl(inputs, attrs):
+        assert len(inputs) == 1
+        scalar = attrs.get_float("scalar")
+        # Note: binary scalar only works for float op for now
+        scalar = _expr.const(scalar, dtype="float32")
+        return new_op(inputs[0], scalar)
+    return _impl
+
+
+def _rbinop_scalar(new_op):
+    def _impl(inputs, attrs):
+        assert len(inputs) == 1
+        scalar = attrs.get_float("scalar")
+        # Note: binary scalar only works for float op for now
+        scalar = _expr.const(scalar, dtype="float32")
+        return new_op(scalar, inputs[0])
+    return _impl
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index 1aaf376a7..c1e71e913 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -9,6 +9,7 @@ from .op import schedule_injective, OpPattern
 schedule_injective = _reg.schedule_injective
 schedule_broadcast = _reg.schedule_injective
 
+
 _reg.register_schedule("collapse_sum_like", _schedule_reduce)
 _reg.register_schedule("broadcast_to_like", schedule_broadcast)
 _reg.register_schedule("expand_dims", schedule_broadcast)
diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index 007888996..f5f76e6af 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -243,14 +243,11 @@ def schedule_l2_normalize(attrs, outs, target):
 
 reg.register_pattern("nn.l2_normalize", OpPattern.OUT_ELEMWISE_FUSABLE)
 
-
-@reg.register_schedule("nn.upsampling")
+# Upsampling
+reg.register_schedule("nn.upsampling", reg.schedule_injective)
 def schedule_upsampling(_, outs, target):
     """Schedule definition of upsampling"""
     with target:
         return topi.generic.schedule_injective(outs)
-
-reg.register_pattern("nn.upsampling", OpPattern.INJECTIVE)
-
 # pad
 reg.register_schedule("nn.pad", schedule_broadcast)
diff --git a/src/relay/backend/graph_plan_memory.cc b/src/relay/backend/graph_plan_memory.cc
index 5001e2cd4..4a5aa4ea0 100644
--- a/src/relay/backend/graph_plan_memory.cc
+++ b/src/relay/backend/graph_plan_memory.cc
@@ -253,6 +253,9 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
     size_t size = 1;
     for (IndexExpr dim : ttype->shape) {
       const int64_t* pval = as_const_int(dim);
       CHECK(pval != nullptr)
           << "Cannot allocate memory symbolic tensor shape "
           << ttype->shape;
+      CHECK_GE(*pval, 0) <<
+        "Cannot allocate memory for tensor with negative shape " <<
+        *pval;
diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc
index 16af572a9..064343c83 100644
--- a/src/relay/ir/alpha_equal.cc
+++ b/src/relay/ir/alpha_equal.cc
@@ -13,7 +13,7 @@
 namespace tvm {
 namespace relay {
 
-// Alpha equal handler for relay.
+// Alpha Equal handler for Relay.
 class AlphaEqualHandler:
       public AttrsEqualHandler,
       public TypeFunctor<bool(const Type&, const Type&)>,
@@ -26,7 +26,7 @@ class AlphaEqualHandler:
    * Check equality of two nodes.
    * \param lhs The left hand operand.
    * \param rhs The right hand operand.
-   * \return The compare result.
+   * \return The comparison result.
    */
   bool Equal(const NodeRef& lhs, const NodeRef& rhs) {
     if (lhs.same_as(rhs)) return true;
@@ -46,7 +46,7 @@ class AlphaEqualHandler:
    * Check equality of two attributes.
    * \param lhs The left hand operand.
    * \param rhs The right hand operand.
-   * \return The compare result.
+   * \return The comparison result.
    */
   bool AttrEqual(const NodeRef& lhs, const NodeRef& rhs) {
     return AttrsEqualHandler::Equal(lhs, rhs);
@@ -55,7 +55,7 @@ class AlphaEqualHandler:
    * Check equality of two types.
    * \param lhs The left hand operand.
    * \param rhs The right hand operand.
-   * \return The compare result.
+   * \return The comparison result.
    */
   bool TypeEqual(const Type& lhs, const Type& rhs) {
     if (lhs.same_as(rhs)) return true;
@@ -72,7 +72,7 @@ class AlphaEqualHandler:
    *
    * \param lhs The left hand operand.
    * \param rhs The right hand operand.
-   * \return The compare result.
+   * \return The comparison result.
    */
   bool ExprEqual(const Expr& lhs, const Expr& rhs) {
     if (lhs.same_as(rhs)) return true;
diff --git a/src/relay/op/nn/upsampling.cc b/src/relay/op/nn/upsampling.cc
index 6a98d2884..d386437ae 100644
--- a/src/relay/op/nn/upsampling.cc
+++ b/src/relay/op/nn/upsampling.cc
@@ -6,8 +6,11 @@
 #include <tvm/relay/op.h>
 #include <tvm/relay/attrs/nn.h>
 #include <tvm/relay/op_attr_types.h>
+#include <tvm/build_module.h>
 #include <topi/elemwise.h>
 #include <topi/nn/upsampling.h>
+#include <vector>
+#include "../op_common.h"
 #include "../layout.h"
 
 namespace tvm {
@@ -86,26 +89,37 @@ RELAY_REGISTER_OP("nn.upsampling")
 .add_argument("data", "Tensor", "The input tensor.")
 .set_support_level(2)
 .add_type_rel("UpSampling", UpSamplingRel)
+.set_attr<TOpPattern>("TOpPattern", kInjective)
 .set_attr<FTVMCompute>(
   "FTVMCompute", [](const Attrs& attrs,
-          const Array<Tensor>& inputs,
-          const Type& out_type,
-          const Target& target) {
-  const auto* param = attrs.as<UpSamplingAttrs>();
-  const auto* out_ttype = out_type.as<TensorTypeNode>();
-  CHECK(param != nullptr);
-  CHECK(param->layout == "NCHW" || param->layout == "NHWC");
-  CHECK(out_ttype != nullptr);
-  Array<IndexExpr> oshape;
-  if (param->layout == "NCHW") {
-    oshape.push_back(out_ttype->shape[2]);
-    oshape.push_back(out_ttype->shape[3]);
-  } else if (param->layout == "NHWC") {
-    oshape.push_back(out_ttype->shape[1]);
-    oshape.push_back(out_ttype->shape[2]);
-  }
-  return Array<Tensor>{ topi::nn::upsampling(inputs[0], oshape, param->layout, param->method)};
+                    const Array<Tensor>& inputs,
+                    const Type& out_type,
+                    const Target& target) {
+    const auto* uattrs = attrs.as<UpSamplingAttrs>();
+    CHECK(uattrs != nullptr);
+    auto out_tt = out_type.as<TensorTypeNode>();
+    CHECK(out_tt) << "expected a tensor type: " << out_type;
+    CHECK(uattrs->layout == "NCHW" || uattrs->layout == "NHWC")
+      << "unknown layout: " << uattrs->layout;
+
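+    // The output height/width are read from the inferred output type,
+    // which already has the upsampling scale applied by the type relation.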
+    Array<HalideIR::Expr> oshape;
+    if (uattrs->layout == "NCHW") {
+      oshape.push_back(out_tt->shape[2]);
+      oshape.push_back(out_tt->shape[3]);
+    } else if (uattrs->layout == "NHWC") {
+      oshape.push_back(out_tt->shape[1]);
+      oshape.push_back(out_tt->shape[2]);
+    }
+
+    return Array<Tensor>{
+      topi::nn::upsampling(
+        inputs[0],
+        oshape,
+        uattrs->layout,
+        uattrs->method)
+    };
 });
 
+
 }  // namespace relay
 }  // namespace tvm
diff --git a/tests/python/relay/frontend/test_keras.py b/tests/python/relay/frontend/test_keras.py
new file mode 100644
index 000000000..f508c5b44
--- /dev/null
+++ b/tests/python/relay/frontend/test_keras.py
@@ -0,0 +1,332 @@
+import numpy as np
+import nnvm
+from nnvm import to_relay
+import tvm
+from tvm import relay
+from tvm.contrib import graph_runtime
+from nnvm.testing.config import ctx_list
+import keras
+
+# prevent keras from using up all gpu memory
+import tensorflow as tf
+from keras.backend.tensorflow_backend import set_session
+config = tf.ConfigProto()
+config.gpu_options.per_process_gpu_memory_fraction = 0.5
+set_session(tf.Session(config=config))
+
+
+def verify_keras_frontend(keras_model, need_transpose=True):
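+    # Run the model through Keras and through the NNVM -> Relay -> TVM
+    # pipeline, then compare outputs. Keras inputs are NHWC, so data is
+    # transposed to NCHW (and outputs back) when need_transpose is set.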
+    # The Keras frontend currently supports the TensorFlow backend only.
+    assert keras.backend.backend() == 'tensorflow'
+
+    in_shapes = []
+    for layer in keras_model._input_layers:
+        in_shapes.append(tuple(dim.value if dim.value is not None else 1 for dim in layer.input.shape))
+
+    def get_keras_output(xs, dtype='float32'):
+        return keras_model.predict(xs)
+
+    def get_tvm_output(xs, target, ctx, dtype='float32'):
+        sym, params = nnvm.frontend.from_keras(keras_model)
+        shape_dict = {name: x.shape for (name, x) in zip(keras_model.input_names, xs)}
+        with relay.build_module.build_config(opt_level=2):
+            func, params = to_relay.to_relay(sym, shape_dict, dtype, params)
+            graph, lib, params = relay.build(func, target=target, params=params)
+        m = graph_runtime.create(graph, lib, ctx)
+        for name, x in zip(keras_model.input_names, xs):
+            m.set_input(name, tvm.nd.array(x.astype(dtype)))
+        m.set_input(**params)
+        m.run()
+
+        return [m.get_output(i).asnumpy() for i in range(m.get_num_outputs())]
+
+    def to_channels_first(arr):
+        return arr.transpose([0, -1] + list(range(1, arr.ndim - 1)))
+
+    def to_channels_last(arr):
+        return arr.transpose([0] + list(range(2, arr.ndim)) + [1])
+
+    xs = [np.random.uniform(size=shape, low=-1.0, high=1.0) for shape in in_shapes]
+    keras_out = get_keras_output(xs)
+
+    keras_out = keras_out if isinstance(keras_out, list) else [keras_out]
+    for target, ctx in ctx_list():
+        inputs = [to_channels_first(x) for x in xs] if need_transpose else xs
+        tvm_out = get_tvm_output(inputs, target, ctx)
+        for kout, tout in zip(keras_out, tvm_out):
+            if need_transpose:
+                tout = to_channels_last(tout)
+            tvm.testing.assert_allclose(kout, tout, rtol=1e-5, atol=1e-5)
+
+def test_forward_elemwise_add():
+    r = []
+    data = keras.layers.Input(shape=(32,32,3))
+    x = keras.layers.Conv2D(8, (3, 3), padding="same")(data)
+    r.append(x)
+    x = keras.layers.Conv2D(8, (3, 3), padding="same")(x)
+    r.append(x)
+    x = keras.layers.Conv2D(8, (3, 3), padding="same")(x)
+    # add two symbols
+    y = keras.layers.add([keras.layers.add([x, r[0]]), r[1]])
+    y = keras.layers.GlobalAveragePooling2D()(y)
+    keras_model = keras.models.Model(data, y)
+    verify_keras_frontend(keras_model)
+    # add three symbols
+    y = keras.layers.add([x, r[0], r[1]])
+    y = keras.layers.GlobalAveragePooling2D()(y)
+    keras_model = keras.models.Model(data, y)
+    verify_keras_frontend(keras_model)
+
+
+def test_forward_dense():
+    data = keras.layers.Input(shape=(32,32,1))
+    x = keras.layers.Flatten()(data)
+    x = keras.layers.Dropout(0.5)(x)
+    x = keras.layers.Dense(10, activation='relu', kernel_initializer='uniform')(x)
+    keras_model = keras.models.Model(data, x)
+    verify_keras_frontend(keras_model)
+
+
+def test_forward_pool():
+    data = keras.layers.Input(shape=(32,32,1))
+    # maxpool
+    x = keras.layers.MaxPooling2D((3, 3), strides=(1, 1), padding='same')(data)
+    keras_model = keras.models.Model(data, x)
+    verify_keras_frontend(keras_model)
+    # avgpool
+    y = keras.layers.AveragePooling2D((3, 3), strides=(1, 1), padding='same')(data)
+    keras_model = keras.models.Model(data, y)
+    verify_keras_frontend(keras_model)
+
+
+def test_forward_conv():
+    data = keras.layers.Input(shape=(32,32,3))
+    conv_funcs = [keras.layers.Conv2D(filters=10, kernel_size=(3,3),
+                                      strides=(2,2), padding='same'),
+                  keras.layers.Conv2D(filters=10, kernel_size=(3,3),
+                                      dilation_rate=(2,2), padding='same'),
+                  keras.layers.DepthwiseConv2D(kernel_size=(3,3), padding='same'),
+                  keras.layers.Conv2DTranspose(filters=10, kernel_size=(3,3), padding='valid'),
+                  keras.layers.SeparableConv2D(filters=10, kernel_size=(3,3), padding='same')]
+    for conv_func in conv_funcs:
+        x = conv_func(data)
+        keras_model = keras.models.Model(data, x)
+        verify_keras_frontend(keras_model)
+
+
+def test_forward_upsample():
+    data = keras.layers.Input(shape=(32,32,3))
+    x = keras.layers.UpSampling2D(size=(3,3))(data)
+    keras_model = keras.models.Model(data, x)
+    verify_keras_frontend(keras_model)
+
+
+def test_forward_reshape():
+    data = keras.layers.Input(shape=(32,32,3))
+    x = keras.layers.Reshape(target_shape=(32,32,3))(data)
+    keras_model = keras.models.Model(data, x)
+    verify_keras_frontend(keras_model)
+
+
+def test_forward_crop():
+    data = keras.layers.Input(shape=(32,32,3))
+    x = keras.layers.Cropping2D(cropping=((1, 1), (1, 1)))(data)
+    x = keras.layers.Cropping2D(cropping=(1, 1))(x)
+    x = keras.layers.Cropping2D(cropping=1)(x)
+    x = keras.layers.Cropping2D(cropping=((0, 1), (1, 0)))(x)
+    x = keras.layers.Cropping2D(cropping=(1, 0))(x)
+    x = keras.layers.Cropping2D(cropping=0)(x)
+    x = keras.layers.Add()([x, x])
+    keras_model = keras.models.Model(data, x)
+    verify_keras_frontend(keras_model)
+
+
+def test_forward_vgg16():
+    keras_model = keras.applications.vgg16.VGG16(include_top=True, weights='imagenet',
+        input_shape=(224,224,3), classes=1000)
+    verify_keras_frontend(keras_model)
+
+
+def test_forward_xception():
+    keras_model = keras.applications.xception.Xception(include_top=True, weights='imagenet',
+        input_shape=(299,299,3), classes=1000)
+    verify_keras_frontend(keras_model)
+
+
+def test_forward_resnet50():
+    keras_model = keras.applications.resnet50.ResNet50(include_top=True, weights='imagenet',
+        input_shape=(224,224,3), classes=1000)
+    verify_keras_frontend(keras_model)
+
+
+def test_forward_mobilenet():
+    keras_model = keras.applications.mobilenet.MobileNet(include_top=True, weights='imagenet',
+        input_shape=(224,224,3), classes=1000)
+    verify_keras_frontend(keras_model)
+
+
+def test_forward_activations():
+    data = keras.layers.Input(shape=(32,32,3))
+    weights = np.random.rand(1, 32, 32, 3)
+    act_funcs = [keras.layers.Activation('softmax'),
+                 keras.layers.Activation('softplus'),
+                 keras.layers.ReLU(),
+                 keras.layers.ReLU(max_value=6.),
+                 keras.layers.LeakyReLU(alpha=0.3),
+                 keras.layers.PReLU(weights=weights, alpha_initializer="zero"),
+                 keras.layers.ELU(alpha=0.5),
+                 keras.layers.Activation('selu'),
+                 keras.layers.ThresholdedReLU(theta=0.5),
+                 keras.layers.Activation('softsign'),
+                 keras.layers.Activation('hard_sigmoid'),
+                 keras.layers.Activation('sigmoid'),
+                 keras.layers.Activation('tanh'),
+                 keras.layers.Activation('linear')]
+    for act_func in act_funcs:
+        x = act_func(data)
+        keras_model = keras.models.Model(data, x)
+        verify_keras_frontend(keras_model)
+
+
+def test_forward_multi_inputs():
+    data1 = keras.layers.Input(shape=(32,32,3))
+    data2 = keras.layers.Input(shape=(32,32,3))
+    x = keras.layers.Conv2D(8, (3, 3), padding="same")(data1)
+    y = keras.layers.Conv2D(8, (3, 3), padding="same")(data2)
+    z = keras.layers.add([x, y])
+    z = keras.layers.GlobalAveragePooling2D()(z)
+    keras_model = keras.models.Model([data1, data2], z)
+    verify_keras_frontend(keras_model)
+
+
+def test_forward_multi_outputs():
+    data = keras.layers.Input(shape=(32,32,3))
+    x = keras.layers.Conv2D(8, (3, 3), padding="same")(data)
+    x = keras.layers.GlobalAveragePooling2D()(x)
+    y = keras.layers.Conv2D(8, (3, 3), padding="same")(data)
+    y = keras.layers.GlobalAveragePooling2D()(y)
+    keras_model = keras.models.Model(data, [x, y])
+    verify_keras_frontend(keras_model)
+
+
+def test_forward_reuse_layers():
+    # reuse conv2d
+    data = keras.layers.Input(shape=(32,32,3))
+    conv2d = keras.layers.Conv2D(8, (3, 3), padding="same")
+    x = conv2d(data)
+    y = conv2d(data)
+    z = keras.layers.add([x, y])
+    z = keras.layers.GlobalAveragePooling2D()(z)
+    keras_model = keras.models.Model(data, z)
+    verify_keras_frontend(keras_model)
+
+    # reuse add
+    data = keras.layers.Input(shape=(32,32,3))
+    x = keras.layers.Conv2D(8, (3, 3), padding="same")(data)
+    add = keras.layers.Add()
+    x = add([x, x])
+    x = add([x, x])
+    z = keras.layers.GlobalAveragePooling2D()(x)
+    keras_model = keras.models.Model(data, z)
+    verify_keras_frontend(keras_model)
+
+def _test_LSTM(inputs, hidden, return_state=True):
+    data = keras.layers.Input(shape=(1, inputs))
+    lstm_out = keras.layers.LSTM(hidden,
+                                 return_state=return_state,
+                                 recurrent_activation='sigmoid',
+                                 activation='tanh')
+    x = lstm_out(data)
+    keras_model = keras.models.Model(data, x)
+    verify_keras_frontend(keras_model, need_transpose=False)
+
+def _test_LSTM_MultiLayer(inputs, hidden):
+    inputs = keras.layers.Input(shape=(1, inputs))
+    layer = keras.layers.LSTM(hidden, return_state=True, return_sequences=True,
+                              recurrent_activation='sigmoid',
+                              activation='tanh')
+    outputs = layer(inputs)
+    output, state = outputs[0], outputs[1:]
+    output = keras.layers.LSTM(hidden, recurrent_activation='sigmoid',
+                               activation='tanh')(output, initial_state=state)
+    keras_model = keras.models.Model(inputs, output)
+    verify_keras_frontend(keras_model, need_transpose=False)
+
+
+def test_forward_LSTM():
+    # TODO(@jroesch): need to modify compile engine to fix return_state=True
+    _test_LSTM(8, 8, return_state=False)
+    _test_LSTM(4, 4, return_state=False)
+    _test_LSTM_MultiLayer(4, 4)
+
+def _test_RNN(inputs, units):
+    data = keras.layers.Input(shape=(1, inputs))
+    rnn_out = keras.layers.SimpleRNN(units, return_state=True,
+                                     activation='tanh')
+    x = rnn_out(data)
+    keras_model = keras.models.Model(data, x)
+    verify_keras_frontend(keras_model, need_transpose=False)
+
+def _test_RNN_MultiLayer(inputs, units):
+    inputs = keras.layers.Input(shape=(1, inputs))
+    layer = keras.layers.SimpleRNN(units, return_state=True, return_sequences=True,
+                                   activation='tanh')
+    outputs = layer(inputs)
+    output, state = outputs[0], outputs[1:]
+    output = keras.layers.SimpleRNN(units, activation='tanh')(output, initial_state=state)
+    keras_model = keras.models.Model(inputs, output)
+    verify_keras_frontend(keras_model, need_transpose=False)
+
+def test_forward_RNN():
+    _test_RNN(2, 4)
+    _test_RNN(4, 3)
+    _test_RNN_MultiLayer(4, 12)
+
+def _test_GRU(inputs, units):
+    data = keras.layers.Input(shape=(1, inputs))
+    gru_out = keras.layers.GRU(units,
+                               return_state=True,
+                               recurrent_activation='sigmoid',
+                               activation='tanh')
+    x = gru_out(data)
+    keras_model = keras.models.Model(data, x)
+    verify_keras_frontend(keras_model, need_transpose=False)
+
+def _test_GRU_MultiLayer(inputs, units):
+    inputs = keras.layers.Input(shape=(1, inputs))
+    layer = keras.layers.GRU(units,
+                             return_state=True,
+                             return_sequences=True,
+                             recurrent_activation='sigmoid',
+                             activation='tanh')
+    outputs = layer(inputs)
+    output, state = outputs[0], outputs[1:]
+    output = keras.layers.GRU(units, recurrent_activation='sigmoid',
+                              activation='tanh')(output, initial_state=state)
+    keras_model = keras.models.Model(inputs, output)
+    verify_keras_frontend(keras_model, need_transpose=False)
+
+def test_forward_GRU():
+    _test_GRU(2, 4)
+    _test_GRU(4, 3)
+    _test_GRU_MultiLayer(4, 4)
+
+if __name__ == '__main__':
+    test_forward_elemwise_add()
+    test_forward_activations()
+    test_forward_dense()
+    test_forward_pool()
+    test_forward_conv()
+    test_forward_upsample()
+    test_forward_reshape()
+    test_forward_crop()
+    test_forward_vgg16()
+    test_forward_xception()
+    test_forward_resnet50()
+    test_forward_mobilenet()
+    test_forward_multi_inputs()
+    test_forward_multi_outputs()
+    test_forward_reuse_layers()
+    test_forward_LSTM()
+    test_forward_RNN()
+    test_forward_GRU()
diff --git a/topi/include/topi/image/resize.h b/topi/include/topi/image/resize.h
index b6bd51ef0..2ffe4f453 100644
--- a/topi/include/topi/image/resize.h
+++ b/topi/include/topi/image/resize.h
@@ -12,6 +12,7 @@
 #include <algorithm>
 
 #include "topi/tags.h"
+#include "topi/elemwise.h"
 #include "topi/detail/ravel_unravel.h"
 #include "topi/detail/constant_utils.h"
 #include "tvm/tvm.h"
@@ -288,7 +289,7 @@ inline Tensor resize_bilinear_nchw(const Tensor& input,
 * \return A Tensor resized to given shape
 */
 inline Tensor resize_bilinear(const Tensor& input,
-                              const Array<Expr>& shape,
+                              const Array<tvm::Expr>& shape,
                               std::string layout = "NCHW",
                               bool align_corners = false,
                               std::string name = "tensor",
-- 
GitLab