diff --git a/python/tvm/relay/testing/__init__.py b/python/tvm/relay/testing/__init__.py
index 913f97ecd4a19cea58cb24052b53e58c8314f8df..43160d64549c95d72cec4088850797cbc4771e29 100644
--- a/python/tvm/relay/testing/__init__.py
+++ b/python/tvm/relay/testing/__init__.py
@@ -6,4 +6,5 @@ from . import resnet
 from . import dqn
 from . import dcgan
 from . import mobilenet
+from . import lstm
 from .config import ctx_list
diff --git a/python/tvm/relay/testing/layers.py b/python/tvm/relay/testing/layers.py
index 1b279d9e72af7ecc5be48b073f3865f1bd1b2e55..9d4d3b3b4e133400f874607ce24d4d3cf95cce76 100644
--- a/python/tvm/relay/testing/layers.py
+++ b/python/tvm/relay/testing/layers.py
@@ -105,7 +105,7 @@ def conv2d_transpose(data, weight=None, **kwargs):
         weight = relay.var(name + "_weight")
     return relay.nn.conv2d_transpose(data, weight, **kwargs)
 
-def dense_add_bias(data, weight=None, bias=None, **kwargs):
+def dense_add_bias(data, weight=None, bias=None, units=None, **kwargs):
     """Wrapper of dense which automatically creates weights if not given.
 
     Parameters
@@ -133,6 +133,6 @@ def dense_add_bias(data, weight=None, bias=None, **kwargs):
         weight = relay.var(name + "_weight")
     if not bias:
         bias = relay.var(name + "_bias")
-    data = relay.nn.dense(data, weight, **kwargs)
+    data = relay.nn.dense(data, weight, units, **kwargs)
     data = relay.nn.bias_add(data, bias)
     return data
diff --git a/python/tvm/relay/testing/lstm.py b/python/tvm/relay/testing/lstm.py
new file mode 100644
index 0000000000000000000000000000000000000000..47e68a988dab7166cba5cca2c2bacd92601025f3
--- /dev/null
+++ b/python/tvm/relay/testing/lstm.py
@@ -0,0 +1,182 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Implementation of a Long Short-Term Memory (LSTM) cell.
+
+Adapted from:
+https://gist.github.com/merrymercy/5eb24e3b019f84200645bd001e9caae9
+"""
+
+from tvm import relay
+from . import layers
+from .init import create_workload
+
+def lstm_cell(num_hidden, batch_size=1, dtype="float32", name=""):
+    """Long Short-Term Memory (LSTM) network cell.
+
+    Parameters
+    ----------
+    num_hidden : int
+        Number of units in output symbol.
+
+    batch_size : int
+        Batch size (length of states).
+
+    Returns
+    -------
+    result : tvm.relay.Function
+        A Relay function that evaluates an LSTM cell.
+        The function takes in a tensor of input data, a tuple of two
+        states, and weights and biases for dense operations on the
+        inputs and on the state. It returns a tuple with two members,
+        an output tensor and a tuple of two new states.
+    """
+    builder = relay.ScopeBuilder()
+
+    input_type = relay.TensorType((batch_size, num_hidden), dtype)
+    weight_type = relay.TensorType((num_hidden, 4*num_hidden), dtype)
+    bias_type = relay.TensorType((4*num_hidden,), dtype)
+
+    dense_type = relay.TensorType((batch_size, 4*num_hidden), dtype)
+    slice_type = relay.TupleType([input_type, input_type,
+                                  input_type, input_type])
+    ret_type = relay.TupleType([input_type,
+                                relay.TupleType([input_type, input_type])])
+
+    inputs = relay.Var("inputs", input_type)
+    states = relay.Var("states",
+                       relay.TupleType([input_type, input_type]))
+
+    i2h_weight = relay.Var("i2h_weight", weight_type)
+    i2h_bias = relay.Var("i2h_bias", bias_type)
+
+    h2h_weight = relay.Var("h2h_weight", weight_type)
+    h2h_bias = relay.Var("h2h_bias", bias_type)
+
+    i2h = builder.let(("i2h", dense_type),
+                      layers.dense_add_bias(
+                          data=inputs,
+                          units=num_hidden * 4,
+                          weight=i2h_weight, bias=i2h_bias,
+                          name="%si2h" % name))
+    h2h = builder.let(("h2h", dense_type),
+                      layers.dense_add_bias(
+                          data=relay.TupleGetItem(states, 0),
+                          units=num_hidden * 4,
+                          weight=h2h_weight, bias=h2h_bias,
+                          name="%sh2h" % name))
+
+    gates = builder.let(("gates", dense_type), relay.add(i2h, h2h))
+    slice_gates = builder.let(("slice_gates", slice_type),
+                              relay.split(gates,
+                                          indices_or_sections=4,
+                                          axis=1).astuple())
+
+    in_gate = builder.let(("in_gate", input_type),
+                          relay.sigmoid(relay.TupleGetItem(slice_gates, 0)))
+    forget_gate = builder.let(("forget_gate", input_type),
+                              relay.sigmoid(relay.TupleGetItem(slice_gates, 1)))
+    in_transform = builder.let(("in_transform", input_type),
+                               relay.tanh(relay.TupleGetItem(slice_gates, 2)))
+    out_gate = builder.let(("out_gate", input_type),
+                           relay.sigmoid(relay.TupleGetItem(slice_gates, 3)))
+
+    next_c = builder.let(("next_c", input_type),
+                         relay.add(relay.multiply(forget_gate,
+                                                  relay.TupleGetItem(states, 1)),
+                                   relay.multiply(in_gate, in_transform)))
+    next_h = builder.let(("next_h", input_type),
+                         relay.multiply(out_gate, relay.tanh(next_c)))
+    ret = builder.let(("ret", ret_type),
+                      relay.Tuple([next_h, relay.Tuple([next_h, next_c])]))
+    builder.ret(ret)
+
+    body = builder.get()
+
+    return relay.Function([inputs, states, i2h_weight,
+                           i2h_bias, h2h_weight, h2h_bias],
+                          body, ret_type)
+
+
+def get_net(iterations, num_hidden, batch_size=1, dtype="float32"):
+    '''Constructs an unrolled RNN with LSTM cells'''
+    input_type = relay.TensorType((batch_size, num_hidden), dtype)
+    weight_type = relay.TensorType((num_hidden, 4*num_hidden), dtype)
+    bias_type = relay.TensorType((4*num_hidden,), dtype)
+
+    state_type = relay.TupleType([input_type, input_type])
+    cell_type = relay.TupleType([input_type, state_type])
+
+    builder = relay.ScopeBuilder()
+
+    zeros = builder.let(("zeros", input_type),
+                        relay.zeros((batch_size, num_hidden), dtype))
+    init_states = builder.let(("init_states", state_type),
+                              relay.Tuple([zeros, zeros]))
+
+    states = init_states
+    out = None
+
+    for i in range(iterations):
+        inputs = relay.Var("data", input_type)
+        i2h_weight = relay.Var("i2h_%s_weight" % i, weight_type)
+        i2h_bias = relay.Var("i2h_%s_bias" % i, bias_type)
+        h2h_weight = relay.Var("h2h_%s_weight" % i, weight_type)
+        h2h_bias = relay.Var("h2h_%s_bias" % i, bias_type)
+
+        cell_fn = lstm_cell(num_hidden, batch_size, dtype, "lstm_%s" % i)
+
+        call = builder.let(("call_%s" % i, cell_type),
+                           relay.Call(cell_fn,
+                                      [inputs, states, i2h_weight,
+                                       i2h_bias, h2h_weight, h2h_bias]))
+        new_out = builder.let(("out_%s" % i, input_type),
+                              relay.TupleGetItem(call, 0))
+        new_states = builder.let(("states_%s" % i, state_type),
+                                 relay.TupleGetItem(call, 1))
+        states = new_states
+        out = new_out
+
+    builder.ret(out)
+    body = builder.get()
+    args = relay.ir_pass.free_vars(body)
+    return relay.Function(args, body, input_type)
+
+
+def get_workload(iterations, num_hidden, batch_size=1, dtype="float32"):
+    """Get benchmark workload for an LSTM RNN.
+
+    Parameters
+    ----------
+    iterations : int
+        The number of iterations in the desired LSTM RNN.
+    num_hidden : int
+        The size of the hidden state
+    batch_size : int, optional (default 1)
+        The batch size used in the model
+    dtype : str, optional (default "float32")
+        The data type
+    Returns
+    -------
+    net : tvm.relay.Function
+        The computational graph of an unrolled LSTM RNN
+    params : dict of str to NDArray
+        The parameters.
+    """
+    net = get_net(iterations, num_hidden, batch_size, dtype)
+    return create_workload(net)
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 7a3a2151158d2963643295706c79d56ec67986ab..a9e0a969fc5bbdf4d25283fe099dfc96215bb58c 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -1078,7 +1078,7 @@ bool SplitRel(const Array<Type>& types,
   }
   CHECK_LT(axis, data->shape.size())
     << "axis should be within the input dimension range.";
-  CHECK_GT(axis, 0)
+  CHECK_GE(axis, 0)
     << "axis should be within the input dimension range.";
 
   if (const IntImm* sections = param->indices_or_sections.as<IntImm>()) {
diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py
index d12804d512f09321e2d733917f2fd1666ce6e9eb..30130fd7bcac7c7451012428069c6e0624f9f206 100644
--- a/tests/python/relay/test_ir_text_printer.py
+++ b/tests/python/relay/test_ir_text_printer.py
@@ -97,10 +97,12 @@ def test_variable_name():
     v1 = relay.var("1")
     assert "%v1" in v1.astext()
 
+
 def test_mlp():
     net, params = tvm.relay.testing.mlp.get_workload(batch_size=1)
     net.astext()
 
+
 def test_resnet():
     net, params = tvm.relay.testing.resnet.get_workload(batch_size=1)
     net.astext()
@@ -117,6 +119,12 @@ def test_dcgan():
     net, params = tvm.relay.testing.dcgan.get_workload(batch_size=1)
     net.astext()
 
+
+def test_lstm():
+    net, params = tvm.relay.testing.lstm.get_workload(4, 4)
+    net.astext()
+
+
 if __name__ == "__main__":
     do_print[0] = True
     test_resnet()
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index 22469cc7fdbe91ea6a52a62726897440242b49a5..6f8fbd55129372cc3e833e9658c9b5b2afc06f45 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -161,6 +161,14 @@ def test_split_infer_type():
                      relay.ty.TensorType((5, 1, 2, 2), "float32"),
                      relay.ty.TensorType((5, 1, 2, 2), "float32")])),
                   axis=1)
+    verify_split((5, 5, 2, 2), 5,
+                 relay.ty.TupleType(tvm.convert([
+                     relay.ty.TensorType((1, 5, 2, 2), "float32"),
+                     relay.ty.TensorType((1, 5, 2, 2), "float32"),
+                     relay.ty.TensorType((1, 5, 2, 2), "float32"),
+                     relay.ty.TensorType((1, 5, 2, 2), "float32"),
+                     relay.ty.TensorType((1, 5, 2, 2), "float32")])),
+                  axis=0)
     verify_split((d1, d2, d3, d4), 4,
                  relay.ty.TupleType(tvm.convert([
                      relay.ty.TensorType((d1, d2, d3/4, d4), "float32"),
@@ -168,6 +176,11 @@ def test_split_infer_type():
                      relay.ty.TensorType((d1, d2, d3/4, d4), "float32"),
                      relay.ty.TensorType((d1, d2, d3/4, d4), "float32")])),
                   axis=2)
+    verify_split((d1, d2, d3, d4), 2,
+                 relay.ty.TupleType(tvm.convert([
+                     relay.ty.TensorType((d1/2, d2, d3, d4), "float32"),
+                     relay.ty.TensorType((d1/2, d2, d3, d4), "float32")])),
+                  axis=0)
     verify_split((d1, d2, d3, d4), (2, 4, 7),
                  relay.ty.TupleType(tvm.convert([
                      relay.ty.TensorType((d1, 2, d3, d4), "float32"),