diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst
index 47cab696a8e14ea4d6133fea8adf2d705e117173..0b937f6636bfd0bb1e4e1f8492a0b74417d187a9 100644
--- a/docs/langref/relay_op.rst
+++ b/docs/langref/relay_op.rst
@@ -38,6 +38,8 @@ This level enables fully connected multi-layer perceptron.
    tvm.relay.tanh
    tvm.relay.sigmoid
    tvm.relay.nn.relu
+   tvm.relay.nn.dropout
+   tvm.relay.nn.batch_norm
 
 
 **Level 2: Convolutions**
diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h
index de0da7477a35b73e76573ace8875b87112e69338..0be85d3d1bb935513824466b04a3fd7ac034a52d 100644
--- a/include/tvm/relay/attrs/nn.h
+++ b/include/tvm/relay/attrs/nn.h
@@ -237,6 +237,41 @@ struct PadAttrs : public tvm::AttrsNode<PadAttrs> {
   }
 };
 
+/*! \brief Attributes used in dropout operator */
+struct DropoutAttrs : public tvm::AttrsNode<DropoutAttrs> {
+  double rate;
+  TVM_DECLARE_ATTRS(DropoutAttrs, "relay.attrs.DropoutAttrs") {
+    TVM_ATTR_FIELD(rate)
+      .describe("Fraction of the input that gets dropped out during training time")
+      .set_default(0.5);
+  }
+};  // struct DropoutAttrs
+
+/*! \brief Attributes used in batch_norm operator */
+struct BatchNormAttrs : public tvm::AttrsNode<BatchNormAttrs> {
+  int axis;
+  double epsilon;
+  bool center;
+  bool scale;
+
+  TVM_DECLARE_ATTRS(BatchNormAttrs, "relay.attrs.BatchNormAttrs") {
+    TVM_ATTR_FIELD(axis)
+      .describe("Specify which shape axis denotes the channel.")
+      .set_default(1);
+    TVM_ATTR_FIELD(epsilon)
+      .describe("Small float added to variance to avoid dividing by zero")
+      .set_default(1e-5);
+    TVM_ATTR_FIELD(center)
+      .describe("If True, add offset of beta to normalized tensor. If False, beta is ignored")
+      .set_default(true);
+    TVM_ATTR_FIELD(scale)
+      .describe("If True, multiply by gamma. If False, gamma is not used. "
+                "When the next layer is piecewise linear (also, e.g., nn.relu), "
+                "this can be disabled since the scaling will be done by the next layer.")
+      .set_default(true);
+  }
+};  // struct BatchNormAttrs
+
 /*! \brief Attributes for LRN operator */
 struct LRNAttrs : public tvm::AttrsNode<LRNAttrs> {
   IndexExpr size;
diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py
index a429aea7d5ea6b1ab71865e23835fa99abf44c8a..42a29b29b7d7c735e80c90daf6e225132efb830f 100644
--- a/python/tvm/relay/ir_builder.py
+++ b/python/tvm/relay/ir_builder.py
@@ -11,6 +11,32 @@ from .expr import Expr, Constant, Let, Var, Function, If
 from .env import Environment
 
 
+class TupleWrapper(tvm._ffi.node.NodeGeneric):
+    """TupleWrapper.
+
+    This class is a Python wrapper for a Relay tuple of known size.
+    It allows for accessing the fields of the Relay tuple as though
+    it were a Python tuple.
+    """
+
+    def __init__(self, tuple_value, size):
+        self.tuple_value = tuple_value
+        self.size = size
+
+
+    def asnode(self):
+        """Returns the underlying Relay tuple if this wrapper is passed
+        as an argument to an FFI function."""
+
+        return self.tuple_value
+
+    def __getitem__(self, key):
+        return self.tuple_value.fields[key]
+
+    def __len__(self):
+        return len(self.tuple_value.fields)
+
+
 def _convert_to_value(arg, ctxt=tvm.cpu(0)):
     # type: (Any, tvm.Context) -> tvm.nd.NDArray
     """Convert Python values into the appropriate types
@@ -61,6 +87,8 @@ def convert(arg):
         return relay.Tuple([convert(el) for el in arg])
     elif isinstance(arg, PartialFunc):
         return arg.to_func()
+    elif isinstance(arg, tvm._ffi.node.NodeGeneric):
+        return arg.asnode()
     else:
         value = _convert_to_value(arg)
         return Constant(value)
diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py
index e95e3e9b715dd4a48ebbb35a234f28b7f2a8a64e..313c26da0234d26d67b29bbf1ccf7469ea078f13 100644
--- a/python/tvm/relay/op/nn/nn.py
+++ b/python/tvm/relay/op/nn/nn.py
@@ -1,5 +1,6 @@
 """Neural network operations."""
 from __future__ import absolute_import as _abs
+from tvm.relay.ir_builder import TupleWrapper
 from . import _make
 
 
@@ -484,6 +485,7 @@ def lrn(data, size=5, axis=1, bias=2, alpha=.00001, beta=0.75):
 
     .. math::
         (data / (bias + (alpha * sum_data ^2 /size))^beta)
+
     Parameters
     ----------
     data : relay.Expr
@@ -535,3 +537,103 @@ def l2_normalize(data, eps, axis=None):
         The computed result.
     """
     return _make.l2_normalize(data, eps, axis)
+
+def dropout(data, rate=0.5):
+    """Applies the dropout operation to the input array.
+
+    During training, each element of the input is set to zero with
+    probability ``p``. The whole array is rescaled by ``1/(1-p)``
+    to keep the expected sum of the input unchanged.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data to the operator.
+
+    rate : float, optional (default=0.5)
+        The probability for an element to be reset to 0.
+
+    Returns
+    -------
+    result : relay.Tuple([relay.Expr, relay.Expr])
+        The first member of the tuple is the result of dropping elements from ``data``
+        and rescaling. The second member is a "mask" tensor, which is of the same
+        shape and data type as ``data`` and, for each element in ``data``, is 1.0
+        if the element was not dropped and 0.0 if it was.
+    """
+    result = _make.dropout(data, rate)
+    return TupleWrapper(result, 2)
+
+def batch_norm(data, gamma, beta, moving_mean, moving_var,
+               axis=1, epsilon=1e-5, center=True, scale=True):
+    r"""
+    Batch normalization layer (Ioffe and Szegedy, 2014).
+    Normalizes the input at each batch, i.e. applies a transformation
+    that maintains the mean activation close to 0 and the activation
+    standard deviation close to 1.
+
+    .. math::
+
+        data\_mean[i] = mean(data[:,i,:,...]) \\
+        data\_var[i] = var(data[:,i,:,...])
+
+    Then compute the normalized output, which has the same shape as input, as following:
+
+    .. math::
+
+        out[:,i,:,...] = \frac{data[:,i,:,...] - data\_mean[i]}{\sqrt{data\_var[i]+\epsilon}}
+            * gamma[i] + beta[i]
+
+    Both *mean* and *var* returns a scalar by treating the input as a vector.
+
+    Assume the input has size *k* on axis 1, then both ``gamma`` and ``beta``
+    have shape *(k,)*.
+
+    Besides the inputs and the outputs, this operator accepts two auxiliary
+    states, ``moving_mean`` and ``moving_var``, which are *k*-length
+    vectors. They are global statistics for the whole dataset, which are updated by::
+
+    moving_mean = moving_mean * momentum + data_mean * (1 - momentum)
+    moving_var = moving_var * momentum + data_var * (1 - momentum)
+
+    The parameter ``axis`` specifies which axis of the input shape denotes
+    the 'channel' (separately normalized groups).  The default is 1.
+    Specifying -1 sets the channel axis to be the last item in the input shape.
+
+    .. note::
+
+        This operator can be optimized away for inference.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        Input to which batch_norm will be applied.
+    gamma : relay.Expr
+        The gamma scale factor.
+    beta : relay.Expr
+        The beta offset factor.
+    moving_mean : relay.Expr
+        Running mean of input,
+    moving_var : relay.Expr
+        Running variance of input.
+    axis : int, optional, default=1
+        Specify along which shape axis the channel is specified.
+    epsilon : double, optional, default=1e-5
+        Small float added to variance to avoid diving by zero.
+    center : boolean, optional, default=True
+        If True, add offset of beta to normalized tensor, If False,
+        beta is ignored.
+    scale : boolean, optional, default=True
+        If true, multiply by gamma. If False, gamma is not used.
+        When the next layer is piecewise linear (also e.g. nn.relu),
+        this can be disabled since the scalingwill be done by the next layer.
+
+    Returns
+    -------
+    result : relay.Tuple([relay.Expr, relay.Expr, relay.Expr])
+        Tuple of normed data (same shape as input), new running mean (k-length vector),
+        and new running variance (k-length vector)
+    """
+    result = _make.batch_norm(data, gamma, beta, moving_mean, moving_var,
+                              axis, epsilon, center, scale)
+    return TupleWrapper(result, 3)
diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index f2439b9fb7ca83463461e8fa9b747f481d9d7362..23dfe90eebf0aef7d490338a7cabdaf25341d4bd 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -217,5 +217,177 @@ Normalizes along dimension axis using an L2 norm
 .set_support_level(2)
 .add_type_rel("Identity", IdentityRel);
 
+// Dropout
+TVM_REGISTER_NODE_TYPE(DropoutAttrs);
+
+bool DropoutRel(const Array<Type>& types,
+                int num_inputs,
+                const Attrs& attrs,
+                const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+
+  // dropout returns the original tensor with dropout applied
+  // and a mask tensor (1.0 where element not dropped, 0.0 where dropped)
+  auto ret_type = TensorTypeNode::make(data->shape, data->dtype);
+  reporter->Assign(types[1], TupleTypeNode::make(Array<Type>({ret_type, ret_type})));
+  return true;
+}
+
+Expr MakeDropout(Expr data, double rate) {
+  auto attrs = make_node<DropoutAttrs>();
+  attrs->rate = rate;
+  static const Op& op = Op::Get("nn.dropout");
+  return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_API("relay.op.nn._make.dropout")
+.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+    runtime::detail::unpack_call<Expr, 2>(MakeDropout, args, rv);
+  });
+
+RELAY_REGISTER_OP("nn.dropout")
+.describe(R"code(Applies the dropout operation to the input array.
+
+During training, each element of the input is set to zero with probability ``p``.
+The whole array is rescaled by ``1/(1-p)`` to keep the expected sum of the input unchanged.
+
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.add_argument("data", "Tensor", "Input to which dropout will be applied.")
+.set_support_level(1)
+.add_type_rel("Dropout", DropoutRel);
+
+// batch_norm
+TVM_REGISTER_NODE_TYPE(BatchNormAttrs);
+
+bool CheckVectorLength(int64_t dim, const DataType& dtype, Type vector, const char* name) {
+  const auto* candidate = vector.as<TensorTypeNode>();
+  CHECK(candidate != nullptr)
+    << name << " should be a vector but is not a tensor type,";
+  CHECK_EQ(dtype, candidate->dtype)
+    << name << " should be of the same data type as the original but it is not.";
+  CHECK_EQ(candidate->shape.size(), 1)
+    << name << " should be a vector but has a shape of "
+    << candidate->shape.size() << " dimensions instead of 1.";
+
+  const int64_t* length = as_const_int(candidate->shape[0]);
+  if (length == nullptr) return false;
+  CHECK(*length == dim)
+    << name << " should be as long as the channel but has length "
+    << *length << " instead of " << dim << ".";
+  return true;
+}
+
+bool BatchNormRel(const Array<Type>& types,
+                  int num_inputs,
+                  const Attrs& attrs,
+                  const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 6);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+  if (data->shape.size() == 0) return false;
+
+  const BatchNormAttrs* param = attrs.as<BatchNormAttrs>();
+
+  // axis of -1 means use the last dimension
+  CHECK(param->axis >= -1 && param->axis < (int)data->shape.size());
+  int axis = (param->axis != -1) ? param->axis : data->shape.size() - 1;
+
+  auto dim = as_const_int(data->shape[axis]);
+  if (dim == nullptr) return false;
+
+  // if we are using beta and gamma, they need to be of shape (dim,)
+  if (param->scale && !CheckVectorLength(*dim, data->dtype, types[1], "The gamma scale factor")) {
+    return false;
+  }
+
+  if (param->center && !CheckVectorLength(*dim, data->dtype, types[2], "The beta offset factor")) {
+    return false;
+  }
+
+  // the two running averages must also be vectors of length dim
+  if (!CheckVectorLength(*dim, data->dtype, types[3], "The moving mean")) {
+    return false;
+  }
+  if (!CheckVectorLength(*dim, data->dtype, types[4], "The moving variance")) {
+    return false;
+  }
+
+  // output is a tuple of the normed data (same shape as input), new running mean,
+  // and new running average (the latter two are both vectors of length dim)
+  std::vector<Type> fields;
+  auto vec_ty = TensorTypeNode::make(Array<IndexExpr>({data->shape[axis]}),
+                                     data->dtype);
+  fields.push_back(TensorTypeNode::make(data->shape, data->dtype));
+  fields.push_back(vec_ty);
+  fields.push_back(vec_ty);
+  reporter->Assign(types[5], TupleTypeNode::make(Array<Type>(fields)));
+  return true;
+}
+
+Expr MakeBatchNorm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_var,
+                   int axis, double epsilon, bool center, bool scale) {
+  auto attrs = make_node<BatchNormAttrs>();
+  attrs->axis = axis;
+  attrs->epsilon = epsilon;
+  attrs->center = center;
+  attrs->scale = scale;
+  static const Op& op = Op::Get("nn.batch_norm");
+  return CallNode::make(op, {data, gamma, beta, moving_mean, moving_var}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_API("relay.op.nn._make.batch_norm")
+.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+    runtime::detail::unpack_call<Expr, 9>(MakeBatchNorm, args, rv);
+  });
+
+RELAY_REGISTER_OP("nn.batch_norm")
+.describe(R"code(Batch normalization layer (Ioffe and Szegedy, 2014).
+Normalizes the input at each batch, i.e. applies a transformation
+that maintains the mean activation close to 0 and the activation
+standard deviation close to 1.
+
+.. math::
+
+  data\_mean[i] = mean(data[:,i,:,...]) \\
+  data\_var[i] = var(data[:,i,:,...])
+
+Then compute the normalized output, which has the same shape as input, as following:
+
+.. math::
+
+  out[:,i,:,...] = \frac{data[:,i,:,...] - data\_mean[i]}{\sqrt{data\_var[i]+\epsilon}} \
+* gamma[i] + beta[i]
+
+Both *mean* and *var* returns a scalar by treating the input as a vector.
+
+Assume the input has size *k* on axis 1, then both ``gamma`` and ``beta`` have shape *(k,)*.
+
+Besides the inputs and the outputs, this operator accepts two auxiliary
+states, ``moving_mean`` and ``moving_var``, which are *k*-length
+vectors. They are global statistics for the whole dataset, which are updated
+by::
+
+  moving_mean = moving_mean * momentum + data_mean * (1 - momentum)
+  moving_var = moving_var * momentum + data_var * (1 - momentum)
+
+The parameter ``axis`` specifies which axis of the input shape denotes
+the 'channel' (separately normalized groups).  The default is 1.  Specifying -1 sets the channel
+axis to be the last item in the input shape.
+
+.. note::
+    This operator can be optimized away for inference.
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(5)
+.add_argument("data", "Tensor", "Input to which batch_norm will be applied.")
+.add_argument("gamma", "Tensor", "The gamma scale factor.")
+.add_argument("beta", "Tensor", "The beta offset factor.")
+.add_argument("moving_mean", "Tensor", "Running mean of input.")
+.add_argument("moving_var", "Tensor", "Running variance of input.")
+.set_support_level(1)
+.add_type_rel("BatchNorm", BatchNormRel);
+
 }  // namespace relay
 }  // namespace tvm
diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py
index 05c02ab5d197a957555db998f0e9ecf92896bac6..914eafeb57a96b485c83319f7cf48845179e3ba6 100644
--- a/tests/python/relay/test_op_level1.py
+++ b/tests/python/relay/test_op_level1.py
@@ -196,6 +196,93 @@ def test_l2_normalize():
     ftype = func.checked_type
     assert ftype.ret_type == relay.ty.TensorType((n, c , h, w), "float32")
 
+def test_dropout():
+    ib = relay.ir_builder.IRBuilder()
+    input_ty = relay.ty.TensorType((3, 4, 5), "int8")
+    x = ib.param("x", input_ty)
+    with ib.function(x) as func:
+        ib.ret(relay.nn.dropout(x))
+    ib.ret(func)
+
+    func = relay.ir_pass.infer_type(ib.env, func.to_func())
+    ftype = func.checked_type
+    assert ftype.ret_type == relay.ty.TupleType([input_ty, input_ty])
+
+    ib = relay.ir_builder.IRBuilder()
+    n, t, d = tvm.var("n"), tvm.var("t"), tvm.var("d")
+    input_ty = relay.ty.TensorType((n, t, d), "float32")
+    x = ib.param("x", input_ty)
+    with ib.function(x) as func:
+        ib.ret(relay.nn.dropout(x, rate=0.75))
+    ib.ret(func)
+
+    func = relay.ir_pass.infer_type(ib.env, func.to_func())
+    ftype = func.checked_type
+    assert ftype.ret_type == relay.ty.TupleType([input_ty, input_ty])
+
+
+def test_batch_norm():
+    # beta and gamma ignored
+    ib = relay.ir_builder.IRBuilder()
+    data = ib.param("data", relay.ty.TensorType((3, 2, 1), "float32"))
+    gamma = ib.param("gamma", relay.ty.TensorType((5,), "int8"))
+    beta = ib.param("beta", relay.ty.TensorType((12, 16), "int64"))
+    moving_mean = ib.param("moving_mean", relay.ty.TensorType((2,), "float32"))
+    moving_var = ib.param("moving_var", relay.ty.TensorType((2,), "float32"))
+    with ib.function(data, gamma, beta, moving_mean, moving_var) as func:
+        ib.ret(relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var,
+                                   center=False, scale=False))
+    ib.ret(func)
+
+    func = relay.ir_pass.infer_type(ib.env, func.to_func())
+    ftype = func.checked_type
+    assert ftype.ret_type == relay.ty.TupleType(tvm.convert([
+        relay.ty.TensorType((3, 2, 1), "float32"),
+        relay.ty.TensorType((2,), "float32"),
+        relay.ty.TensorType((2,), "float32")
+    ]))
+
+    # with beta and gamma, different axis
+    ib = relay.ir_builder.IRBuilder()
+    data = ib.param("data", relay.ty.TensorType((3, 2, 1), "float32"))
+    gamma = ib.param("gamma", relay.ty.TensorType((3,), "float32"))
+    beta = ib.param("beta", relay.ty.TensorType((3,), "float32"))
+    moving_mean = ib.param("moving_mean", relay.ty.TensorType((3,), "float32"))
+    moving_var = ib.param("moving_var", relay.ty.TensorType((3,), "float32"))
+    with ib.function(data, gamma, beta, moving_mean, moving_var) as func:
+        ib.ret(relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var,
+                                   axis=0, center=False, scale=False))
+    ib.ret(func)
+
+    func = relay.ir_pass.infer_type(ib.env, func.to_func())
+    ftype = func.checked_type
+    assert ftype.ret_type == relay.ty.TupleType(tvm.convert([
+        relay.ty.TensorType((3, 2, 1), "float32"),
+        relay.ty.TensorType((3,), "float32"),
+        relay.ty.TensorType((3,), "float32")
+    ]))
+
+    # axis=-1
+    ib = relay.ir_builder.IRBuilder()
+    data = ib.param("data", relay.ty.TensorType((1, 2, 3), "float32"))
+    gamma = ib.param("gamma", relay.ty.TensorType((3,), "float32"))
+    beta = ib.param("beta", relay.ty.TensorType((3,), "float32"))
+    moving_mean = ib.param("moving_mean", relay.ty.TensorType((3,), "float32"))
+    moving_var = ib.param("moving_var", relay.ty.TensorType((3,), "float32"))
+    with ib.function(data, gamma, beta, moving_mean, moving_var) as func:
+        ib.ret(relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var,
+                                   axis=-1, center=False, scale=False))
+    ib.ret(func)
+
+    func = relay.ir_pass.infer_type(ib.env, func.to_func())
+    ftype = func.checked_type
+    assert ftype.ret_type == relay.ty.TupleType(tvm.convert([
+        relay.ty.TensorType((1, 2, 3), "float32"),
+        relay.ty.TensorType((3,), "float32"),
+        relay.ty.TensorType((3,), "float32")
+    ]))
+
+
 if __name__ == "__main__":
     test_unary_op()
     test_single_op()
@@ -207,3 +294,5 @@ if __name__ == "__main__":
     test_binary_broadcast_op()
     test_lrn()
     test_l2_normalize()
+    test_dropout()
+    test_batch_norm()