diff --git a/topi/python/topi/cuda/__init__.py b/topi/python/topi/cuda/__init__.py
index eaad08de75fef7f7dc3249b441d8a220c8912ba4..d97b7c511bc54f4cd5d0b1d966446f35a475e173 100644
--- a/topi/python/topi/cuda/__init__.py
+++ b/topi/python/topi/cuda/__init__.py
@@ -10,3 +10,5 @@ from .depthwise_conv2d import schedule_depthwise_conv2d_backward_weight_nhwc
 from .reduction import schedule_reduce
 from .softmax import schedule_softmax
 from .injective import schedule_injective, schedule_elemwise, schedule_broadcast
+from .dense import schedule_dense
+from .pooling import schedule_global_pool
diff --git a/topi/python/topi/cuda/dense.py b/topi/python/topi/cuda/dense.py
new file mode 100644
index 0000000000000000000000000000000000000000..207aabc9e7ce3dbebc41aa15ea4baa7feee38668
--- /dev/null
+++ b/topi/python/topi/cuda/dense.py
@@ -0,0 +1,60 @@
+# pylint: disable=invalid-name, unused-variable
+"""Schedule for dense operator"""
+from __future__ import absolute_import as _abs
+import tvm
+from .. import tag
+
+def schedule_dense(outs):
+    """Schedule for dense operator.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of dense
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for dense.
+    """
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+    s = tvm.create_schedule([x.op for x in outs])
+    def _schedule(Dense):
+        num_thread = 64
+        k = Dense.op.reduce_axis[0]
+        ko, kf = s[Dense].split(k, factor=num_thread)
+        DenseF = s.rfactor(Dense, kf)
+
+        if Dense.op in s.outputs:
+            Out = Dense
+        else:
+            Out = outs[0].op.output(0)
+            s[Dense].compute_at(s[Out], s[Out].op.axis[1])
+        s[Out].bind(s[Out].op.axis[0], tvm.thread_axis("blockIdx.y"))
+        s[Out].bind(s[Out].op.axis[1], tvm.thread_axis("blockIdx.x"))
+
+        tx = s[Dense].op.reduce_axis[0]
+        thread_x = tvm.thread_axis("threadIdx.x")
+        s[Dense].bind(tx, thread_x)
+        s[DenseF].compute_at(s[Dense], tx)
+        s[Dense].set_store_predicate(thread_x.var.equal(0))
+        s[Out].set_store_predicate(thread_x.var.equal(0))
+
+    def traverse(OP):
+        # inline all one-to-one-mapping operators except the last stage (output)
+        if tag.is_broadcast(OP.tag):
+            if OP not in s.outputs:
+                s[OP].compute_inline()
+            for tensor in OP.input_tensors:
+                if tensor.op.input_tensors:
+                    traverse(tensor.op)
+        # schedule dense
+        elif OP.tag == 'dense':
+            Dense = OP.output(0)
+            _schedule(Dense)
+        else:
+            raise RuntimeError("Unsupported operator: %s" % OP.tag)
+
+    traverse(outs[0].op)
+    return s
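
# --- Usage sketch (editorial illustration, not part of the patch) ---
# A minimal sketch of how schedule_dense is meant to be driven, assuming a TVM
# build with the CUDA target enabled. tvm.lower is used only to inspect the IR
# produced by the rfactor/bind structure above: the reduction over in_dim is
# split by num_thread (64) and carried out cooperatively by threadIdx.x, with
# the final store predicated on lane 0.
import tvm
import topi

A = tvm.placeholder((1, 1024), name='A')      # data:   [batch, in_dim]
W = tvm.placeholder((1000, 1024), name='W')   # weight: [out_dim, in_dim]
b = tvm.placeholder((1000,), name='b')        # bias:   [out_dim]
D = topi.nn.dense(A, W, b, use_bias=True)
s = topi.cuda.schedule_dense(D)
print(tvm.lower(s, [A, W, b, D], simple_mode=True))
if tvm.module.enabled("cuda"):
    f = tvm.build(s, [A, W, b, D], "cuda", name="dense")
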
diff --git a/topi/python/topi/cuda/pooling.py b/topi/python/topi/cuda/pooling.py
new file mode 100644
index 0000000000000000000000000000000000000000..db3714e9bf5f1f3d2ee90df078a38ae539dafd25
--- /dev/null
+++ b/topi/python/topi/cuda/pooling.py
@@ -0,0 +1,66 @@
+# pylint: disable=invalid-name, unused-variable
+"""Schedule for pooling operators"""
+import tvm
+from .. import tag
+
+def schedule_global_pool(outs):
+    """Schedule for global_pool.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of global_pool
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for global_pool.
+    """
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+    s = tvm.create_schedule([x.op for x in outs])
+    def _schedule(Pool):
+        num_thread = 8
+        block_x = tvm.thread_axis("blockIdx.x")
+        block_y = tvm.thread_axis("blockIdx.y")
+        thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
+        thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
+        if Pool.op in s.outputs:
+            Out = Pool
+            OL = s.cache_write(Pool, "local")
+        else:
+            Out = outs[0].op.output(0)
+            s[Pool].set_scope("local")
+        i, c, h, w = s[Out].op.axis
+        dh, dw = s[Pool].op.reduce_axis
+        fuse_index = s[Pool].fuse(dw, dh)
+        s[Pool].unroll(fuse_index)
+        by, ty = s[Out].split(i, factor=num_thread)
+        bx, tx = s[Out].split(c, factor=num_thread)
+        s[Out].reorder(by, bx, ty, tx)
+        s[Out].bind(ty, thread_y)
+        s[Out].bind(tx, thread_x)
+        s[Out].bind(by, block_y)
+        s[Out].bind(bx, block_x)
+        if Pool.op in s.outputs:
+            s[OL].compute_at(s[Out], tx)
+        else:
+            s[Pool].compute_at(s[Out], tx)
+
+    def traverse(OP):
+        # inline all one-to-one-mapping operators except the last stage (output)
+        if tag.is_broadcast(OP.tag):
+            if OP not in s.outputs:
+                s[OP].compute_inline()
+            for tensor in OP.input_tensors:
+                if tensor.op.input_tensors:
+                    traverse(tensor.op)
+        # schedule global_pool
+        elif 'global_pool' in OP.tag:
+            Pool = OP.output(0)
+            _schedule(Pool)
+        else:
+            raise RuntimeError("Unsupported operator: %s" % OP.tag)
+
+    traverse(outs[0].op)
+    return s
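
# --- Usage sketch (editorial illustration, not part of the patch) ---
# A minimal end-to-end sketch for schedule_global_pool, assuming the CUDA
# target is enabled. The schedule binds (batch, channel) to a 2-D grid of
# 8x8 thread blocks and unrolls the fused spatial reduction per output element.
import numpy as np
import tvm
import topi

A = tvm.placeholder((1, 256, 7, 7), name='A')
P = topi.nn.global_pool(A, pool_type='avg')
s = topi.cuda.schedule_global_pool(P)
if tvm.module.enabled("cuda"):
    ctx = tvm.gpu(0)
    f = tvm.build(s, [A, P], "cuda", name="global_pool")
    a_np = np.random.uniform(size=(1, 256, 7, 7)).astype('float32')
    a = tvm.nd.array(a_np, ctx)
    p = tvm.nd.array(np.zeros((1, 256, 1, 1), dtype='float32'), ctx)
    f(a, p)
    np.testing.assert_allclose(p.asnumpy(),
                               a_np.mean(axis=(2, 3), keepdims=True),
                               rtol=1e-5)
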
diff --git a/topi/python/topi/nn/__init__.py b/topi/python/topi/nn/__init__.py
index 0cc6c9472fe8d91561ac3af9f7d48895afb83a81..46edac975183b21f72987618a66ca76d232a9855 100644
--- a/topi/python/topi/nn/__init__.py
+++ b/topi/python/topi/nn/__init__.py
@@ -8,7 +8,7 @@ from .depthwise_convolution import *
 from .elemwise import *
 from .dilate import *
 from .flatten import *
-from .fully_connected import *
+from .dense import *
 from .mapping import *
 from .pooling import *
 from .softmax import *
diff --git a/topi/python/topi/nn/dense.py b/topi/python/topi/nn/dense.py
new file mode 100644
index 0000000000000000000000000000000000000000..e64bd9e9adc0ecb932922bc1e2fac3e9b36e0ea5
--- /dev/null
+++ b/topi/python/topi/nn/dense.py
@@ -0,0 +1,41 @@
+"""TVM operator fully connected compute."""
+from __future__ import absolute_import
+import tvm
+from .. import tag
+
+
+def dense(data, weight, bias, use_bias=True):
+    """Applies a linear transformation: :math:`Y = XW^T + b`.
+
+    Parameters
+    ----------
+    data : tvm.Tensor
+        2-D with shape [batch, in_dim]
+
+    weight : tvm.Tensor
+        2-D with shape [out_dim, in_dim]
+
+    bias : tvm.Tensor
+        1-D with shape [out_dim]
+
+    use_bias : bool, optional, default=True
+        Whether to use bias parameter
+
+    Returns
+    -------
+    output : tvm.Tensor
+        2-D with shape [batch, out_dim]
+    """
+    assert len(data.shape) == 2 and len(weight.shape) == 2 and len(bias.shape) == 1, \
+        "only support 2-dim dense"
+    batch, in_dim = data.shape
+    out_dim, _ = weight.shape
+    k = tvm.reduce_axis((0, in_dim), name='k')
+    matmul = tvm.compute((batch, out_dim), \
+                         lambda i, j: tvm.sum(data[i, k] * weight[j, k], axis=k), \
+                         tag='dense')
+    if not use_bias:
+        return matmul
+    return tvm.compute((batch, out_dim), \
+                        lambda i, j: matmul[i, j] + bias[j], \
+                        tag=tag.BROADCAST)
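
# --- CPU cross-check sketch (editorial illustration, not part of the patch) ---
# Verifies the Y = X * W^T + b semantics documented above against NumPy,
# using a default schedule and the 'llvm' target (assumed to be enabled in
# this TVM build).
import numpy as np
import tvm
import topi

A = tvm.placeholder((2, 3), name='A')
W = tvm.placeholder((4, 3), name='W')
b = tvm.placeholder((4,), name='b')
D = topi.nn.dense(A, W, b, use_bias=True)
s = tvm.create_schedule(D.op)
f = tvm.build(s, [A, W, b, D], 'llvm')

a_np = np.random.uniform(size=(2, 3)).astype('float32')
w_np = np.random.uniform(size=(4, 3)).astype('float32')
b_np = np.random.uniform(size=(4,)).astype('float32')
d = tvm.nd.array(np.zeros((2, 4), dtype='float32'))
f(tvm.nd.array(a_np), tvm.nd.array(w_np), tvm.nd.array(b_np), d)
np.testing.assert_allclose(d.asnumpy(), a_np.dot(w_np.T) + b_np, rtol=1e-5)
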
diff --git a/topi/python/topi/nn/fully_connected.py b/topi/python/topi/nn/fully_connected.py
deleted file mode 100644
index 870df02424f55d0db16216b65cbc9bee9a3031e5..0000000000000000000000000000000000000000
--- a/topi/python/topi/nn/fully_connected.py
+++ /dev/null
@@ -1,62 +0,0 @@
-"""TVM operator fully connected compute."""
-from __future__ import absolute_import
-import tvm
-
-
-@tvm.tag_scope(tag='fully_connected')
-def fully_connected(data, weight):
-    """Matrix multiplication
-
-    Parameters
-    ----------
-    data : tvm.Tensor
-        2-D with shape [batch, in_dim]
-
-    weight : tvm.Tensor
-        2-D with shape [out_dim, in_dim]
-
-    Returns
-    -------
-    output : tvm.Tensor
-        2-D with shape [batch, out_dim]
-    """
-    assert len(data.shape) == 2 and len(weight.shape) == 2, \
-        "only support 2-dim fully_connected"
-    batch, in_dim = data.shape
-    out_dim, _ = weight.shape
-    k = tvm.reduce_axis((0, in_dim), name='k')
-    return tvm.compute((batch, out_dim), lambda i, j: \
-        tvm.sum(data[i][k] * weight[j][k], axis=k))
-
-
-@tvm.tag_scope(tag='fully_connected_with_bias')
-def fully_connected_with_bias(data, weight, bias):
-    """Applies a linear transformation: :math:`Y = XW^T + b`.
-
-    Parameters
-    ----------
-    data : tvm.Tensor
-        2-D with shape [batch, in_dim]
-
-    weight : tvm.Tensor
-        2-D with shape [out_dim, in_dim]
-
-    bias : tvm.Tensor
-        1-D with shape [out_dim]
-
-    Returns
-    -------
-    output : tvm.Tensor
-        2-D with shape [batch, out_dim]
-    """
-    assert len(data.shape) == 2 and len(weight.shape) == 2, \
-        "only support 2-dim fully_connected"
-    assert len(data.shape) == 2 and len(weight.shape) == 2 and len(bias.shape) == 1, \
-        "only support 2-dim fully_connected"
-    batch, in_dim = data.shape
-    out_dim, _ = weight.shape
-    k = tvm.reduce_axis((0, in_dim), name='k')
-    matmul = tvm.compute((batch, out_dim), lambda i, j: \
-        tvm.sum(data[i, k] * weight[j, k], axis=k))
-    return tvm.compute((batch, out_dim), lambda i, j: \
-        matmul[i, j] + bias[j])
diff --git a/topi/python/topi/nn/pooling.py b/topi/python/topi/nn/pooling.py
index 26d3a3e11f0b5e19491772210288074978f7a49a..511d10afce4984e42a610e9d7a990d925936f16f 100644
--- a/topi/python/topi/nn/pooling.py
+++ b/topi/python/topi/nn/pooling.py
@@ -4,6 +4,7 @@ import tvm
 from .pad import pad
 from .util import get_pad_tuple
 from .. import util
+from .. import tag
 
 def max_pool(data, kernel, stride, padding):
     """Perform max pooling on the data
@@ -51,15 +52,17 @@ def max_pool(data, kernel, stride, padding):
         tag="max_pool")
 
 
-@tvm.tag_scope(tag='global_avg_pool')
-def global_avg_pool(data):
-    """Perform global average pooling on the data
+def global_pool(data, pool_type):
+    """Perform global pooling on the data
 
     Parameters
     ----------
     data : tvm.Tensor
         4-D with shape [batch, channel, in_height, in_width]
 
+    pool_type : str
+        Pool type, 'max' or 'avg'
+
     Returns
     -------
     output : tvm.Tensor
@@ -71,7 +74,16 @@ def global_avg_pool(data):
     dheight = tvm.reduce_axis((0, height))
     dwidth = tvm.reduce_axis((0, width))
 
-    tsum = tvm.compute((batch, channel, 1, 1), lambda n, c, h, w: \
-        tvm.sum(data[n, c, dheight, dwidth], axis=[dheight, dwidth]))
-    return tvm.compute((batch, channel, 1, 1), lambda n, c, h, w: \
-        tsum[n, c, h, w] / (height*width))
+    if pool_type == 'max':
+        return tvm.compute((batch, channel, 1, 1), lambda n, c, h, w: \
+                            tvm.max(data[n, c, dheight, dwidth], axis=[dheight, dwidth]), \
+                            tag="global_pool_max")
+    elif pool_type == 'avg':
+        tsum = tvm.compute((batch, channel, 1, 1), lambda n, c, h, w: \
+                            tvm.sum(data[n, c, dheight, dwidth], axis=[dheight, dwidth]), \
+                            tag="global_pool_sum")
+        return tvm.compute((batch, channel, 1, 1), lambda n, c, h, w: \
+                            tsum[n, c, h, w] / (height*width), \
+                            tag=tag.ELEMWISE)
+    else:
+        raise ValueError("Pool type should be 'avg' or 'max'.")
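
# --- NumPy reference sketch (editorial illustration, not part of the patch) ---
# The two branches above reduce the spatial axes of an NCHW tensor down to
# 1x1; their NumPy equivalents (useful as a reference check) are max/mean
# over axes (2, 3) with keepdims=True.
import numpy as np
import tvm
import topi
from topi.util import get_const_tuple

A = tvm.placeholder((1, 16, 7, 7), name='A')
P = topi.nn.global_pool(A, pool_type='max')
assert get_const_tuple(P.shape) == (1, 16, 1, 1)

data = np.random.uniform(size=(1, 16, 7, 7)).astype('float32')
ref_max = data.max(axis=(2, 3), keepdims=True)    # pool_type='max'
ref_avg = data.mean(axis=(2, 3), keepdims=True)   # pool_type='avg'
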
diff --git a/topi/tests/python/test_topi_dense.py b/topi/tests/python/test_topi_dense.py
new file mode 100644
index 0000000000000000000000000000000000000000..20c9c6cc4eaa0cc5d1b00659a35c8ca5ab71823b
--- /dev/null
+++ b/topi/tests/python/test_topi_dense.py
@@ -0,0 +1,54 @@
+"""Test code for dense operator"""
+import numpy as np
+import tvm
+import topi
+from topi.util import get_const_tuple
+from tvm.contrib.pickle_memoize import memoize
+
+
+def verify_dense(batch, in_dim, out_dim, use_bias=True):
+    A = tvm.placeholder((batch, in_dim), name='A')
+    B = tvm.placeholder((out_dim, in_dim), name='B')
+    C = tvm.placeholder((out_dim,), name='C')
+    D = topi.nn.dense(A, B, C, use_bias=use_bias)
+    D = topi.nn.relu(D)
+    s = topi.cuda.schedule_dense(D)
+    dtype = A.dtype
+
+    # use memoize to pickle the test data for next time use
+    @memoize("topi.tests.test_topi_dense")
+    def get_ref_data():
+        a_np = np.random.uniform(size=(batch, in_dim)).astype(dtype)
+        b_np = np.random.uniform(size=(out_dim, in_dim)).astype(dtype)
+        c_np = np.random.uniform(size=(out_dim,)).astype(dtype)
+        if use_bias:
+            d_np = np.maximum(np.dot(a_np, b_np.T) + c_np, 0.0)
+        else:
+            d_np = np.maximum(np.dot(a_np, b_np.T), 0.0)
+        return (a_np, b_np, c_np, d_np)
+    # get the test data
+    a_np, b_np, c_np, d_np = get_ref_data()
+
+    def check_device(device):
+        if not tvm.module.enabled(device):
+            print("Skip because %s is not enabled" % device)
+            return
+        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        a = tvm.nd.array(a_np, ctx)
+        b = tvm.nd.array(b_np, ctx)
+        c = tvm.nd.array(c_np, ctx)
+        d = tvm.nd.array(np.zeros(get_const_tuple(D.shape), dtype=dtype), ctx)
+        f = tvm.build(s, [A, B, C, D], device, name="dense")
+        f(a, b, c, d)
+        np.testing.assert_allclose(d.asnumpy(), d_np, rtol=1e-5)
+
+    for device in ['cuda', 'opencl', 'metal']:
+        check_device(device)
+
+def test_dense():
+    verify_dense(1, 1024, 1000, use_bias=True)
+    verify_dense(1, 1024, 1000, use_bias=False)
+
+
+if __name__ == "__main__":
+    test_dense()
diff --git a/topi/tests/python/test_topi_pooling.py b/topi/tests/python/test_topi_pooling.py
new file mode 100644
index 0000000000000000000000000000000000000000..397f5133c66ed1b43af0c4b835de08a0fd0ab6a0
--- /dev/null
+++ b/topi/tests/python/test_topi_pooling.py
@@ -0,0 +1,42 @@
+"""Test code for pooling"""
+import numpy as np
+import tvm
+import topi
+from topi.util import get_const_tuple
+
+def verify_global_pool(n, c, h, w, pool_type):
+    A = tvm.placeholder((n, c, h, w), name='A')
+    B = topi.nn.global_pool(A, pool_type=pool_type)
+    B = topi.nn.relu(B)
+    s = topi.cuda.schedule_global_pool(B)
+
+    a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
+    if pool_type == 'avg':
+        b_np = np.mean(a_np, axis=(2, 3), keepdims=True)
+    elif pool_type == 'max':
+        b_np = np.max(a_np, axis=(2, 3), keepdims=True)
+    b_np = np.maximum(b_np, 0.0)
+
+    def check_device(device):
+        if not tvm.module.enabled(device):
+            print("Skip because %s is not enabled" % device)
+            return
+        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        a = tvm.nd.array(a_np, ctx)
+        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
+        f = tvm.build(s, [A, B], device, name="global_pool")
+        f(a, b)
+        np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
+
+    for device in ['cuda', 'opencl', 'metal']:
+        check_device(device)
+
+def test_global_pool():
+    verify_global_pool(1, 1024, 7, 7, 'avg')
+    verify_global_pool(4, 1024, 7, 7, 'avg')
+    verify_global_pool(1, 1024, 7, 7, 'max')
+    verify_global_pool(4, 1024, 7, 7, 'max')
+
+
+if __name__ == "__main__":
+    test_global_pool()