From d87c94d475e5eac3b98dbdf1ac7811985905f6c6 Mon Sep 17 00:00:00 2001
From: Liangfu Chen <liangfu.chen@icloud.com>
Date: Fri, 7 Sep 2018 01:29:47 +0800
Subject: [PATCH] [Sparse] add sparse tensor computation support (#1289)

---
 python/tvm/autotvm/task/dispatcher.py |   2 +-
 python/tvm/contrib/sparse.py          | 188 +++++++++++++++++++++++
 tests/python/contrib/test_sparse.py   | 100 +++++++++++++
 topi/python/topi/__init__.py          |   1 +
 topi/python/topi/sparse/__init__.py   |   7 +
 topi/python/topi/sparse/csrmm.py      | 105 +++++++++++++
 topi/python/topi/sparse/csrmv.py      | 100 ++++++++++++
 topi/python/topi/sparse/dense.py      | 186 +++++++++++++++++++++++
 topi/tests/python/test_topi_sparse.py | 203 +++++++++++++++++++++++++
 9 files changed, 891 insertions(+), 1 deletion(-)
 create mode 100644 python/tvm/contrib/sparse.py
 create mode 100644 tests/python/contrib/test_sparse.py
 create mode 100644 topi/python/topi/sparse/__init__.py
 create mode 100644 topi/python/topi/sparse/csrmm.py
 create mode 100644 topi/python/topi/sparse/csrmv.py
 create mode 100644 topi/python/topi/sparse/dense.py
 create mode 100644 topi/tests/python/test_topi_sparse.py

diff --git a/python/tvm/autotvm/task/dispatcher.py b/python/tvm/autotvm/task/dispatcher.py
index ec1dcc44f..398e850d8 100644
--- a/python/tvm/autotvm/task/dispatcher.py
+++ b/python/tvm/autotvm/task/dispatcher.py
@@ -16,8 +16,8 @@ from __future__ import absolute_import as _abs
 
 import logging
 
-from decorator import decorate
 import numpy as np
+from decorator import decorate
 
 from tvm import target as _target
 
diff --git a/python/tvm/contrib/sparse.py b/python/tvm/contrib/sparse.py
new file mode 100644
index 000000000..523039912
--- /dev/null
+++ b/python/tvm/contrib/sparse.py
@@ -0,0 +1,188 @@
+"""Tensor and Operation class for computation declaration."""
+# pylint: disable=invalid-name
+from __future__ import absolute_import as _abs
+import numpy as _np
+from .. import expr as _expr
+from .. import api as _api
+from .. import tensor as _tensor
+from .. import ndarray as _nd
+
+float32 = "float32"
+itype = 'int32'
+
+class CSRNDArray(object):
+    """Sparse tensor object in CSR format."""
+    def __init__(self, arg1, ctx=None, shape=None):
+        """Construct a sparse matrix in CSR format.
+
+        Parameters
+        ----------
+        arg1 : numpy.ndarray or a tuple of (data, indices, indptr)
+            The dense numpy array to convert,
+            or a tuple for constructing a sparse matrix directly.
+
+        ctx: tvm.TVMContext
+            The corresponding context.
+
+        shape : tuple of int
+            The shape of the array
+        """
+        if isinstance(arg1, tuple):
+            assert len(arg1) == 3
+            self.data, self.indices, self.indptr = arg1
+            self.shape = shape
+        elif isinstance(arg1, _np.ndarray):
+            source_array = arg1
+            ridx, cidx = _np.nonzero(source_array)
+            data = source_array[ridx, cidx]
+            self.data = _nd.array(data, ctx)
+            indices = cidx.astype(itype)
+            self.indices = _nd.array(indices, ctx)
+            indptr = [0]+_np.apply_along_axis(_np.count_nonzero, axis=1, arr=source_array).tolist()
+            indptr = _np.cumsum(_np.array(indptr, itype)).astype(itype)
+            self.indptr = _nd.array(indptr, ctx)
+            self.shape = source_array.shape
+        else:
+            raise RuntimeError("Construct CSRNDArray with either a tuple (data, indices, indptr) "
+                               "or a numpy.ndarray, can't handle type %s." % (type(arg1),))
+        self.stype = 'csr'
+        self.dtype = self.data.dtype
+        assert self.shape is not None
+        assert isinstance(self.data, _nd.NDArray)
+        assert isinstance(self.indices, _nd.NDArray)
+        assert str(self.indices.dtype) == 'int32' or \
+            str(self.indices.dtype) == 'int64', str(self.indices.dtype)
+        assert isinstance(self.indptr, _nd.NDArray)
+        assert str(self.indptr.dtype) == 'int32' or \
+            str(self.indptr.dtype) == 'int64', str(self.indptr.dtype)
+
+    def asnumpy(self):
+        """Construct a full matrix and convert it to numpy array."""
+        full = _np.zeros(self.shape, self.dtype)
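+        # expand indptr into per-element row indices: np.diff gives the
+        # number of nonzeros in each row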
+        ridx = _np.diff(self.indptr.asnumpy())
+        ridx = _np.hstack([_np.ones((v,), itype)*i for i, v in enumerate(ridx)])
+        full[ridx, self.indices.asnumpy().astype(itype)] = self.data.asnumpy()
+        return full
+
+def array(source_array, ctx=None, shape=None, stype='csr'):
+    """Construct a sparse NDArray from numpy.ndarray"""
+    ret = None
+    if stype == 'csr':
+        ret = CSRNDArray(source_array, shape=shape, ctx=ctx)
+    else:
+        raise NotImplementedError('stype=%s is not supported yet.' % (stype,))
+    return ret
+
+class SparsePlaceholderOp(object):
+    """Placeholder class for sparse tensor representations."""
+    def __init__(self, shape, nonzeros, dtype, name):
+        # pylint: disable=unused-argument
+        """Contructing a bare bone structure for a sparse matrix
+
+        Parameters
+        ----------
+        shape: Tuple of Expr
+            The shape of the tensor
+
+        nonzeros: int
+            The number of non-zero values
+
+        dtype: str, optional
+            The data type of the tensor
+
+        name: str, optional
+            The name hint of the tensor
+        """
+        self.shape = shape
+        self.dtype = dtype
+        self.name = name
+        self.stype = 'unknown'
+
+class CSRPlaceholderOp(SparsePlaceholderOp):
+    """Placeholder class for CSR based sparse tensor representation."""
+    def __init__(self, shape, nonzeros, dtype, name):
+        """Contructing a bare bone structure for a csr_matrix
+
+        Parameters
+        ----------
+        shape: Tuple of Expr
+            The shape of the tensor
+
+        nonzeros: int
+            The number of non-zero values
+
+        dtype: str, optional
+            The data type of the tensor
+
+        name: str, optional
+            The name hint of the tensor
+        """
+        SparsePlaceholderOp.__init__(self, shape, nonzeros, dtype, name)
+        self.stype = 'csr'
+        self.data = _api.placeholder((nonzeros,), dtype=dtype, name=self.name+'_data')
+        self.indices = _api.placeholder((nonzeros,), dtype=itype, name=self.name+'_indices')
+        self.indptr = _api.placeholder((self.shape[0]+1,), dtype=itype, name=self.name+'_indptr')
+        assert isinstance(self.data, _tensor.Tensor)
+        assert isinstance(self.indices, _tensor.Tensor)
+        assert isinstance(self.indptr, _tensor.Tensor)
+
+def placeholder(shape, nonzeros=None, dtype=None, name="placeholder", stype=None):
+    """Construct an empty sparse tensor object.
+
+    Parameters
+    ----------
+    shape: Tuple of Expr
+        The shape of the tensor
+
+    nonzeros: int
+        The number of non-zero values
+
+    dtype: str, optional
+        The data type of the tensor
+
+    name: str, optional
+        The name hint of the tensor
+
+    stype: str, optional
+        The storage type of the sparse tensor (e.g. csr, coo, ell)
+
+    Returns
+    -------
+    tensor: SparsePlaceholderOp
+        The created sparse tensor placeholder
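+
+    Examples
+    --------
+    A sketch declaring a CSR placeholder with a symbolic shape (assuming
+    ``tvm`` and ``tvm.contrib.sparse as tvmsp`` are imported):
+
+    >>> m, n = tvm.var('m'), tvm.var('n')
+    >>> A = tvmsp.placeholder(shape=(m, n), nonzeros=tvm.var('nnz'), name='A')
+    >>> A.stype
+    'csr'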
+    """
+    shape = (shape,) if isinstance(shape, _expr.Expr) else shape
+    nonzeros = 0 if nonzeros is None else nonzeros
+    dtype = float32 if dtype is None else dtype
+    stype = 'csr' if stype is None else stype
+    ret = None
+    if stype == 'csr':
+        ret = CSRPlaceholderOp(shape=shape, nonzeros=nonzeros, dtype=dtype, name=name)
+    else:
+        raise NotImplementedError('stype=%s is not supported yet.' % (stype,))
+    return ret
diff --git a/tests/python/contrib/test_sparse.py b/tests/python/contrib/test_sparse.py
new file mode 100644
index 000000000..f7a0d1d13
--- /dev/null
+++ b/tests/python/contrib/test_sparse.py
@@ -0,0 +1,100 @@
+import tvm
+import tvm.contrib.sparse as tvmsp
+import tvm.ndarray as _nd
+import numpy as np
+from collections import namedtuple
+
+def test_static_tensor():
+    dtype = 'float32'
+    stype = 'csr'
+    target = 'llvm'
+    ctx = tvm.context(target, 0)
+    m = tvm.var('m')
+    n = tvm.var('n')
+    A = tvmsp.placeholder(shape=(m, n), name='A', dtype=dtype)
+    assert(A.stype == 'csr')
+    n = 3
+    a = np.maximum(np.random.uniform(size=(n,n)).astype(dtype)-.6, 0.)
+    a = tvmsp.array(a, ctx)
+    A.data = tvm.placeholder(a.data.shape, dtype, name='A_data')
+    Ab = tvm.decl_buffer(a.data.shape, dtype, name='A_data')
+    binds = {A.data: Ab}
+    C = tvm.compute(A.data.shape, lambda i: A.data[i] * 2., tag='cs_scatter')
+    s = tvm.create_schedule(C.op)
+    f = tvm.build(s, [A.data, C], target, binds=binds)
+    c = tvmsp.array(np.zeros((n,n), dtype), ctx)
+    c.data = tvm.nd.empty(a.data.shape, dtype)
+    c.indices = a.indices
+    c.indptr = a.indptr
+    f(a.data, c.data)
+    np.testing.assert_allclose(c.asnumpy(), a.asnumpy() * 2., rtol=1e-5)
+
+def test_dynamic_tensor():
+    dtype = 'float32'
+    stype = 'csr'
+    target = 'llvm'
+    ctx = tvm.context(target, 0)
+    nr, nc, n = tvm.var('nr'), tvm.var('nc'), tvm.var('n')
+    A = tvmsp.placeholder(shape=(nr, nc), nonzeros=n, name='A', dtype=dtype)
+    assert(A.stype == 'csr')
+    C = tvm.compute(A.data.shape, lambda i: A.data[i] * 2., tag='cs_scatter')
+    s = tvm.create_schedule(C.op)
+    _nr, _nc = 3, 5
+    a = np.maximum(np.random.uniform(size=(_nr, _nc)).astype(dtype)-.6, 0.)
+    a = tvmsp.array(a, ctx)
+    assert a.data.dtype == a.dtype
+    Ab = namedtuple('CSRBuffer', ['data', 'indices', 'indptr'])
+    Ab.data = tvm.decl_buffer(a.data.shape, a.data.dtype, name='A_data')
+    Ab.indices = tvm.decl_buffer(a.indices.shape, a.indices.dtype, name='A_indices')
+    binds = {A.data: Ab.data, A.indices: Ab.indices}
+    f = tvm.build(s, [nr, A.data, C], target, binds=binds)
+    c = tvmsp.array(np.zeros((_nr, _nc), dtype), ctx)
+    c.data = tvm.nd.empty(a.data.shape, dtype)
+    c.indices = a.indices
+    c.indptr = a.indptr
+    f(a.data.shape[0], a.data, c.data)
+    np.testing.assert_allclose(c.asnumpy(), a.asnumpy() * 2., rtol=1e-5)
+
+def test_sparse_array_tuple():
+    dtype, itype = 'float32', 'int32'
+    stype = 'csr'
+    target = 'llvm'
+    ctx = tvm.context(target, 0)
+    nr, nc, n = tvm.var('nr'), tvm.var('nc'), tvm.var('n')
+    A = tvmsp.placeholder(shape=(nr, nc), nonzeros=n, name='A', dtype=dtype)
+    assert(A.stype == 'csr')
+    C = tvm.compute(A.data.shape, lambda i: A.data[i] * 2., tag='cs_scatter')
+    s = tvm.create_schedule(C.op)
+    _nr, _nc = 3, 5
+    a = np.maximum(np.random.uniform(size=(_nr, _nc)).astype(dtype)-.6, 0.)
+    # convert to sparse array tuple
+    source_array = a
+    ridx, cidx = np.nonzero(source_array)
+    data = source_array[ridx, cidx]
+    a_data = _nd.array(data, ctx)
+    indices = cidx.astype(itype)
+    a_indices = _nd.array(indices, ctx)
+    indptr = [0]+np.apply_along_axis(np.count_nonzero, axis=1, arr=source_array).tolist()
+    indptr = np.cumsum(np.array(indptr, itype)).astype(itype)
+    a_indptr = _nd.array(indptr, ctx)
+    a_init = (a_data, a_indices, a_indptr)
+    # construct tvm sparse array with tuple
+    a = tvmsp.array(a_init, shape=source_array.shape, ctx=ctx)
+    assert a.data.dtype == a.dtype
+    Ab = namedtuple('CSRBuffer', ['data', 'indices', 'indptr'])
+    Ab.data = tvm.decl_buffer(a.data.shape, a.data.dtype, name='A_data')
+    Ab.indices = tvm.decl_buffer(a.indices.shape, a.indices.dtype, name='A_indices')
+    binds = {A.data: Ab.data, A.indices: Ab.indices}
+    f = tvm.build(s, [nr, A.data, C], target, binds=binds)
+    c = tvmsp.array(np.zeros((_nr, _nc), dtype), ctx)
+    c.data = tvm.nd.empty(a.data.shape, dtype)
+    c.indices = a.indices
+    c.indptr = a.indptr
+    f(a.data.shape[0], a.data, c.data)
+    np.testing.assert_allclose(c.asnumpy(), a.asnumpy() * 2., rtol=1e-5)
+
+if __name__ == "__main__":
+    test_static_tensor()
+    test_dynamic_tensor()
+    test_sparse_array_tuple()
+
diff --git a/topi/python/topi/__init__.py b/topi/python/topi/__init__.py
index 3ef59913e..2eb460d15 100644
--- a/topi/python/topi/__init__.py
+++ b/topi/python/topi/__init__.py
@@ -32,6 +32,7 @@ from . import util
 from . import rocm
 from . import vision
 from . import image
+from . import sparse
 from . import hls
 # not import testing by default
 # because testing can have extra deps that are not necessary
diff --git a/topi/python/topi/sparse/__init__.py b/topi/python/topi/sparse/__init__.py
new file mode 100644
index 000000000..bfac967d2
--- /dev/null
+++ b/topi/python/topi/sparse/__init__.py
@@ -0,0 +1,7 @@
+# pylint: disable=wildcard-import
+"""Sparse operators"""
+from __future__ import absolute_import as _abs
+
+from .csrmv import csrmv
+from .csrmm import csrmm
+from .dense import dense
diff --git a/topi/python/topi/sparse/csrmm.py b/topi/python/topi/sparse/csrmm.py
new file mode 100644
index 000000000..f0574bf3d
--- /dev/null
+++ b/topi/python/topi/sparse/csrmm.py
@@ -0,0 +1,105 @@
+"""TVM operator compute SpMM in CSR format."""
+from __future__ import absolute_import
+import tvm
+from .. import tag
+from ..util import simplify
+
+def csrmm_default(data, indices, indptr, weight, bias=None):
+    # pylint: disable=invalid-name
+    """The default implementation of csrmm in topi.
+
+    Parameters
+    ----------
+    data : tvm.Tensor
+        1-D with shape [nonzeros]
+
+    indices : tvm.Tensor
+        1-D with shape [nonzeros]
+
+    indptr : tvm.Tensor
+        1-D with shape [m+1]
+
+    weight : tvm.Tensor
+        2-D with shape [k, n]
+
+    bias : tvm.Tensor, optional
+        1-D with shape [m]
+
+    Returns
+    -------
+    output : tvm.Tensor
+        2-D with shape [m, n]
+    """
+    assert len(data.shape) == 1 and len(indices.shape) == 1 and len(indptr.shape) == 1 \
+        and len(weight.shape) == 2, "only support 2-dim csrmm"
+    assert isinstance(weight, tvm.tensor.Tensor), \
+        "weight matrix is assumed to be tvm.Tensor, but weight is `%s`" % (type(weight))
+    if bias is not None:
+        assert len(bias.shape) == 1
+    M = simplify(indptr.shape[0]-1)
+    _, N = weight.shape
+    def csrmm_default_ir(data, indices, indptr, weight, out):
+        """define ir for csrmm"""
+        irb = tvm.ir_builder.create()
+        data_ptr = irb.buffer_ptr(data)
+        indices_ptr = irb.buffer_ptr(indices)
+        indptr_ptr = irb.buffer_ptr(indptr)
+        weight_ptr = irb.buffer_ptr(weight)
+        out_ptr = irb.buffer_ptr(out)
+        M = simplify(indptr.shape[0]-1)
+        _, N = weight.shape
+        with irb.for_range(0, N, for_type="vectorize", name='n') as n:
+            with irb.for_range(0, M, for_type="parallel", name='row') as row:
+                dot = irb.allocate('float32', (1,), name='dot', scope='local')
+                out_ptr[row*N+n] = 0.
+                dot[0] = 0.
+                row_start = indptr_ptr[row]
+                row_end = indptr_ptr[row+1]
+                row_elems = row_end-row_start
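+                # accumulate data[elem] * weight[indices[elem], n] over the
+                # nonzero entries of this row (CSR gather)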
+                with irb.for_range(0, row_elems, name='idx') as idx:
+                    elem = row_start+idx
+                    dot[0] += data_ptr[elem] * weight_ptr[indices_ptr[elem]*N+n]
+                out_ptr[row*N+n] += dot[0]
+        return irb.get()
+    oshape = (M, N)
+    matmul = tvm.extern(oshape, [data, indices, indptr, weight],
+                        lambda ins, outs: csrmm_default_ir(ins[0], ins[1], ins[2], ins[3], outs[0]),
+                        tag="csrmm", dtype='float32', name='out')
+    if bias is not None:
+        matmul = tvm.compute(oshape, lambda i, j: matmul[i, j] + bias[i], \
+                             tag=tag.BROADCAST)
+    return matmul
+
+
+def csrmm(a, b, c=None):
+    """The `csrmm` routine performs a matrix-matrix operation defined as :math:`C := A*B + C`,
+    where `B` and `C` are dense matrices, `A` is an m-by-k sparse matrix in the CSR format.
+
+    Parameters
+    ----------
+    a : tvm.contrib.sparse.CSRPlaceholderOp
+        2-D sparse matrix with shape [m, k]
+
+    b : tvm.Tensor
+        2-D dense matrix with shape [k, n]
+
+    c : tvm.Tensor, optional
+        1-D dense vector with shape [m]
+
+    Returns
+    -------
+    output : tvm.Tensor
+        2-D with shape [m, n]
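+
+    Examples
+    --------
+    A sketch declaring the computation (shapes and names are illustrative):
+
+    >>> m, k, nnz = tvm.var('m'), tvm.var('k'), tvm.var('nnz')
+    >>> A = tvmsp.placeholder(shape=(m, k), nonzeros=nnz, name='A')
+    >>> B = tvm.placeholder((k, 16), name='B')
+    >>> C = topi.sparse.csrmm(A, B)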
+    """
+    return csrmm_default(a.data, a.indices, a.indptr, b, c)
diff --git a/topi/python/topi/sparse/csrmv.py b/topi/python/topi/sparse/csrmv.py
new file mode 100644
index 000000000..7cd101711
--- /dev/null
+++ b/topi/python/topi/sparse/csrmv.py
@@ -0,0 +1,100 @@
+"""TVM operator compute SpMV in CSR format."""
+from __future__ import absolute_import
+import tvm
+from .. import tag
+
+def csrmv_default(data, indices, indptr, weight, bias=None):
+    """The default implementation of csrmv in topi.
+
+    Parameters
+    ----------
+    data : tvm.Tensor
+        1-D with shape [nonzeros]
+
+    indices : tvm.Tensor
+        1-D with shape [nonzeros]
+
+    indptr : tvm.Tensor
+        1-D with shape [m+1]
+
+    weight : tvm.Tensor
+        2-D with shape [k, 1]
+
+    bias : tvm.Tensor, optional
+        1-D with shape [m]
+
+    Returns
+    -------
+    output : tvm.Tensor
+        2-D with shape [m, 1]
+    """
+    assert len(data.shape) == 1 and len(weight.shape) == 2, \
+        "only support 2-dim csrmv"
+    assert isinstance(weight, tvm.tensor.Tensor), \
+        "weight matrix is assumed to be tvm.Tensor, but weight is `%s`" % (type(weight))
+    if bias is not None:
+        assert len(bias.shape) == 1
+    batch = indptr.shape[0]-1
+    def csrmv_default_ir(data, indices, indptr, weight, out):
+        """define ir for csrmv"""
+        irb = tvm.ir_builder.create()
+        data_ptr = irb.buffer_ptr(data)
+        indices_ptr = irb.buffer_ptr(indices)
+        indptr_ptr = irb.buffer_ptr(indptr)
+        weight_ptr = irb.buffer_ptr(weight)
+        out_ptr = irb.buffer_ptr(out)
+        num_rows = indptr.shape[0]-1
+        with irb.for_range(0, num_rows, for_type="parallel", name='row') as row:
+            dot = irb.allocate('float32', (1,), name='dot', scope='local')
+            out_ptr[row] = 0.
+            dot[0] = 0.
+            row_start = indptr_ptr[row]
+            row_end = indptr_ptr[row+1]
+            row_elems = row_end-row_start
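+            # dot product of the sparse row with the dense vector:
+            # accumulate data[elem] * weight[indices[elem]]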
+            with irb.for_range(0, row_elems, name='elemidx') as elemidx:
+                elem = row_start+elemidx
+                dot[0] += data_ptr[elem] * weight_ptr[indices_ptr[elem]]
+            out_ptr[row] += dot[0]
+        return irb.get()
+    oshape = (batch, 1)
+    matmul = tvm.extern(oshape, [data, indices, indptr, weight],
+                        lambda ins, outs: csrmv_default_ir(ins[0], ins[1], ins[2], ins[3], outs[0]),
+                        tag="csrmv", dtype='float32', name='csrmv')
+    if bias is not None:
+        matmul = tvm.compute((batch, 1), lambda i, j: matmul[i, 0] + bias[i], \
+                             tag=tag.BROADCAST)
+    return matmul
+
+
+def csrmv(a, x, y=None):
+    """The `csrmv` routine performs a matrix-vector operation defined as :math:`y := A*x + y`,
+    where `x` and `y` are vectors, `A` is an m-by-k sparse matrix in the CSR format.
+
+    Parameters
+    ----------
+    a : tvm.contrib.sparse.CSRPlaceholderOp
+        2-D sparse matrix with shape [m, k]
+
+    x : tvm.Tensor
+        2-D dense matrix with shape [k, 1]
+
+    y : tvm.Tensor, optional
+        1-D dense vector with shape [m]
+
+    Returns
+    -------
+    output : tvm.Tensor
+        2-D dense matrix with shape [m, 1]
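+
+    Examples
+    --------
+    A sketch declaring the computation (shapes and names are illustrative):
+
+    >>> m, k, nnz = tvm.var('m'), tvm.var('k'), tvm.var('nnz')
+    >>> A = tvmsp.placeholder(shape=(m, k), nonzeros=nnz, name='A')
+    >>> x = tvm.placeholder((k, 1), name='x')
+    >>> y = topi.sparse.csrmv(A, x)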
+    """
+    return csrmv_default(a.data, a.indices, a.indptr, x, y)
diff --git a/topi/python/topi/sparse/dense.py b/topi/python/topi/sparse/dense.py
new file mode 100644
index 000000000..01f323bc8
--- /dev/null
+++ b/topi/python/topi/sparse/dense.py
@@ -0,0 +1,186 @@
+"""TVM operator compute Dense in CSR format."""
+from __future__ import absolute_import
+import tvm
+from .. import tag
+from ..util import simplify
+
+def dense_si(data, indices, indptr, weight, bias=None):
+    # pylint: disable=invalid-name
+    """The implementation of dense in topi, assuming sparse input.
+
+    Parameters
+    ----------
+    data : tvm.Tensor
+        1-D with shape [num_nonzeros]
+
+    indices : tvm.Tensor
+        1-D with shape [num_nonzeros]
+
+    indptr : tvm.Tensor
+        1-D with shape [m+1]
+
+    weight : tvm.Tensor
+        2-D with shape [n, k]
+
+    bias : tvm.Tensor, optional
+        1-D with shape [n]
+
+    Returns
+    -------
+    output : tvm.Tensor
+        2-D with shape [m, n]
+    """
+    assert len(data.shape) == 1 and len(indices.shape) == 1 and len(indptr.shape) == 1 \
+        and len(weight.shape) == 2, "only support 2-dim dense"
+    assert isinstance(weight, tvm.tensor.Tensor), \
+        "weight matrix is assumed to be tvm.Tensor, but weight is `%s`" % (type(weight))
+    if bias is not None:
+        assert len(bias.shape) == 1
+    dtype = data.dtype
+    M = simplify(indptr.shape[0]-1)
+    N, _ = weight.shape
+    def dense_default_ir(data, indices, indptr, weight, out):
+        """Define IR for Dense"""
+        dtype = data.dtype
+        irb = tvm.ir_builder.create()
+        data_ptr = irb.buffer_ptr(data)
+        indices_ptr = irb.buffer_ptr(indices)
+        indptr_ptr = irb.buffer_ptr(indptr)
+        weight_ptr = irb.buffer_ptr(weight)
+        out_ptr = irb.buffer_ptr(out)
+        M = simplify(indptr.shape[0]-1)
+        N, K = weight.shape
+        with irb.for_range(0, N, for_type="vectorize", name='n') as n:
+            with irb.for_range(0, M, for_type="parallel", name='m') as m:
+                dot = irb.allocate(dtype, (1,), name='dot', scope='local')
+                out_ptr[m*N+n] = tvm.const(0, dtype)
+                dot[0] = tvm.const(0, dtype)
+                row_start = indptr_ptr[m]
+                row_elems = indptr_ptr[m+1]-row_start
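+                # weight is [N, K] row-major, so element (n, indices[elem])
+                # sits at flat offset indices[elem] + n*K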
+                with irb.for_range(0, row_elems, name='k') as k:
+                    elem = row_start+k
+                    dot[0] += data_ptr[elem] * weight_ptr[indices_ptr[elem]+n*K]
+                out_ptr[m*N+n] += dot[0]
+        return irb.get()
+    oshape = (M, N)
+    matmul = tvm.extern(oshape, [data, indices, indptr, weight],
+                        lambda ins, outs: dense_default_ir(ins[0], ins[1], ins[2], ins[3], outs[0]),
+                        tag="dense", dtype=dtype, name='out')
+    if bias is not None:
+        matmul = tvm.compute(oshape, lambda i, j: matmul[i, j] + bias[j], \
+                             tag=tag.BROADCAST)
+    return matmul
+
+
+def dense_sw(data, w_data, w_indices, w_indptr, bias=None):
+    # pylint: disable=invalid-name
+    """The implementation of dense in topi, assuming sparse weight.
+
+    Parameters
+    ----------
+    data : tvm.Tensor
+        2-D with shape [m, k]
+
+    w_data : tvm.Tensor
+        1-D with shape [nonzeros]
+
+    w_indices : tvm.Tensor
+        1-D with shape [nonzeros]
+
+    w_indptr : tvm.Tensor
+        1-D with shape [n+1]
+
+    bias : tvm.Tensor, optional
+        1-D with shape [n]
+
+    Returns
+    -------
+    output : tvm.Tensor
+        2-D with shape [m, n]
+    """
+    assert len(w_data.shape) == 1 and len(w_indices.shape) == 1 and len(w_indptr.shape) == 1 \
+        and len(data.shape) == 2, "only support 2-dim dense"
+    assert isinstance(data, tvm.tensor.Tensor), \
+        "data matrix is assumed to be tvm.Tensor, but data is `%s`" % (type(data))
+    if bias is not None:
+        assert len(bias.shape) == 1
+    dtype = data.dtype
+    M, _ = data.shape
+    N = simplify(w_indptr.shape[0]-1)
+    def dense_default_ir(data, w_data, w_indices, w_indptr, out):
+        """Define IR for Dense"""
+        dtype = data.dtype
+        irb = tvm.ir_builder.create()
+        data_ptr = irb.buffer_ptr(data)
+        w_data_ptr = irb.buffer_ptr(w_data)
+        w_indices_ptr = irb.buffer_ptr(w_indices)
+        w_indptr_ptr = irb.buffer_ptr(w_indptr)
+        out_ptr = irb.buffer_ptr(out)
+        M, K = data.shape
+        N = simplify(w_indptr.shape[0]-1)
+        with irb.for_range(0, M, for_type="vectorize", name='m') as m:
+            with irb.for_range(0, N, for_type="parallel", name='n') as n:
+                dot = irb.allocate(dtype, (1,), name='dot', scope='local')
+                out_ptr[m*N+n] = tvm.const(0, dtype)
+                dot[0] = tvm.const(0, dtype)
+                row_start = w_indptr_ptr[n]
+                row_elems = w_indptr_ptr[n+1]-row_start
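+                # data is [M, K] row-major, so element (m, w_indices[elem])
+                # sits at flat offset w_indices[elem] + m*K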
+                with irb.for_range(0, row_elems, name='k') as k:
+                    elem = row_start+k
+                    dot[0] += w_data_ptr[elem] * data_ptr[w_indices_ptr[elem]+m*K]
+                out_ptr[m*N+n] += dot[0]
+        return irb.get()
+    oshape = (M, N)
+    matmul = tvm.extern(oshape, [data, w_data, w_indices, w_indptr],
+                        lambda ins, outs: dense_default_ir(ins[0], ins[1], ins[2], ins[3], outs[0]),
+                        tag="dense", dtype=dtype, name='out')
+    if bias is not None:
+        matmul = tvm.compute(oshape, lambda i, j: matmul[i, j] + bias[j], \
+                             tag=tag.BROADCAST)
+    return matmul
+
+
+def dense(data, weight, bias=None):
+    """Applies a linear transformation: :math:`Y = XW^T + b`.
+    Either data or weight should be a tvm.contrib.sparse.CSRPlaceholderOp.
+
+    Parameters
+    ----------
+    data : tvm.contrib.sparse.CSRPlaceholderOp or tvm.tensor.Tensor
+        2-D with shape [batch, in_dim]
+
+    weight : tvm.tensor.Tensor or tvm.contrib.sparse.CSRPlaceholderOp
+        2-D with shape [out_dim, in_dim]
+
+    bias : tvm.tensor.Tensor, optional
+        1-D with shape [out_dim]
+
+    Returns
+    -------
+    output : tvm.Tensor
+        2-D with shape [batch, out_dim]
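+
+    Examples
+    --------
+    A sketch with sparse data and a dense weight (shapes are illustrative):
+
+    >>> nnz = tvm.var('nnz')
+    >>> A = tvmsp.placeholder(shape=(64, 128), nonzeros=nnz, name='A')
+    >>> W = tvm.placeholder((32, 128), name='W')
+    >>> Y = topi.sparse.dense(A, W)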
+    """
+    ret = None
+    if isinstance(data, tvm.contrib.sparse.CSRPlaceholderOp) and \
+       isinstance(weight, tvm.tensor.Tensor):
+        ret = dense_si(data.data, data.indices, data.indptr, weight, bias)
+    elif isinstance(data, tvm.tensor.Tensor) and \
+       isinstance(weight, tvm.contrib.sparse.CSRPlaceholderOp):
+        ret = dense_sw(data, weight.data, weight.indices, weight.indptr, bias)
+    else:
+        raise NotImplementedError("implementation for %s as data and %s as weight "
+                                  "is not supported yet." % (type(data), type(weight)))
+    return ret
diff --git a/topi/tests/python/test_topi_sparse.py b/topi/tests/python/test_topi_sparse.py
new file mode 100644
index 000000000..deb3a08ea
--- /dev/null
+++ b/topi/tests/python/test_topi_sparse.py
@@ -0,0 +1,203 @@
+"""Test code for sparse operator"""
+import numpy as np
+import tvm
+import topi
+import topi.testing
+from topi.util import get_const_tuple
+import tvm.contrib.sparse as tvmsp
+
+def verify_dynamic_csrmv(batch, in_dim, out_dim, use_bias=True):
+    nr, nc, n = tvm.var("nr"), tvm.var("nc"), tvm.var("n")
+    dtype = 'float32'
+    A = tvmsp.placeholder(shape=(nr, nc), nonzeros=n, dtype=dtype, name='A')
+    B = tvm.placeholder((in_dim, 1), name='B')
+    C = tvm.placeholder((nr,), name='C')
+    D = topi.sparse.csrmv(A, B, C if use_bias else None)
+    s = tvm.create_schedule(D.op)
+    dtype = A.dtype
+
+    # get the test data
+    def get_ref_data():
+        a_np = np.maximum(np.random.uniform(size=(batch, in_dim)).astype(dtype)-0.5, 0.)
+        b_np = np.random.uniform(size=(in_dim, 1)).astype(dtype)-0.5
+        c_np = np.random.uniform(size=(batch, )).astype(dtype)
+        if use_bias:
+            d_np = np.dot(a_np, b_np) + c_np.reshape((batch, 1))
+        else:
+            d_np = np.dot(a_np, b_np)
+        return (a_np, b_np, c_np, d_np)
+    a_np, b_np, c_np, d_np = get_ref_data()
+
+    def check_device(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
+            print("Skip because %s is not enabled" % device)
+            return
+        print("Running on target: %s" % device)
+        a = tvmsp.array(a_np, ctx)
+        _nr, _nc, _n = a.shape[0], a.shape[1], a.data.shape[0]
+        assert a.shape[0] == a.indptr.shape[0]-1
+        b = tvm.nd.array(b_np, ctx)
+        c = tvm.nd.array(c_np, ctx)
+        d = tvm.nd.array(np.zeros((_nr, 1), dtype=dtype), ctx)
+        assert a.data.dtype == A.data.dtype
+        assert a.indices.dtype == A.indices.dtype
+        assert a.indptr.dtype == A.indptr.dtype
+        f = tvm.build(s, [nr, A.data, A.indices, A.indptr, B, C, D], device, name="csrmv")
+        f(_nr, a.data, a.indices, a.indptr, b, c, d)
+        np.testing.assert_allclose(d.asnumpy(), d_np, rtol=1e-4)
+
+    for device in ["llvm"]:
+        check_device(device)
+
+def verify_dynamic_csrmm(batch, in_dim, out_dim, use_bias=True):
+    nr, nc, n = tvm.var("nr"), tvm.var("nc"), tvm.var("n")
+    dtype = 'float32'
+    A = tvmsp.placeholder(shape=(nr, nc), nonzeros=n, dtype=dtype, name='A')
+    B = tvm.placeholder((in_dim, out_dim), name='B')
+    C = tvm.placeholder((nr,), name='C')
+    D = topi.sparse.csrmm(A, B, C if use_bias else None)
+    s = tvm.create_schedule(D.op)
+    dtype = A.dtype
+
+    # get the test data
+    def get_ref_data():
+        a_np = np.maximum(np.random.uniform(size=(batch, in_dim)).astype(dtype)-0.5, 0.)
+        b_np = np.random.uniform(size=(in_dim, out_dim)).astype(dtype)-0.5
+        c_np = np.random.uniform(size=(batch, )).astype(dtype)
+        if use_bias:
+            d_np = np.dot(a_np, b_np) + c_np.reshape((batch, 1))
+        else:
+            d_np = np.dot(a_np, b_np)
+        return (a_np, b_np, c_np, d_np)
+    a_np, b_np, c_np, d_np = get_ref_data()
+
+    def check_device(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
+            print("Skip because %s is not enabled" % device)
+            return
+        print("Running on target: %s" % device)
+        a = tvmsp.array(a_np, ctx)
+        _nr, _nc, _n = a.shape[0], a.shape[1], a.data.shape[0]
+        assert a.shape[0] == a.indptr.shape[0]-1
+        b = tvm.nd.array(b_np, ctx)
+        c = tvm.nd.array(c_np, ctx)
+        d = tvm.nd.array(np.zeros((_nr, out_dim), dtype=dtype), ctx)
+        f = tvm.build(s, [nr, A.data, A.indices, A.indptr, B, C, D], device, name="csrmm")
+
+        f(_nr, a.data, a.indices, a.indptr, b, c, d)
+        np.testing.assert_allclose(d.asnumpy(), d_np, rtol=1e-2)
+
+    for device in ["llvm"]:
+        check_device(device)
+
+def verify_dense_si(batch, in_dim, out_dim, use_bias=True, dtype='float32'):
+    nonzeros = tvm.var('nonzeros')
+    A = tvmsp.placeholder(shape=(batch, in_dim), nonzeros=nonzeros, dtype=dtype, name='A')
+    B = tvm.placeholder((out_dim, in_dim), dtype=dtype, name='B')
+    C = tvm.placeholder((out_dim,), dtype=dtype, name='C')
+    D = topi.sparse.dense(A, B, C if use_bias else None)
+    s = tvm.create_schedule(D.op)
+
+    # get the test data
+    def get_ref_data():
+        mag = 10.
+        a_np = np.maximum(mag*(np.random.uniform(size=(batch, in_dim)).astype('float32')-0.5), 0.).astype(dtype)
+        b_np = (mag*(np.random.uniform(size=(out_dim, in_dim)).astype('float32')-.5)).astype(dtype)
+        c_np = (mag*(np.random.uniform(size=(out_dim,)).astype('float32')-.5)).astype(dtype)
+        if use_bias:
+            d_np = np.dot(a_np, b_np.T) + c_np
+        else:
+            d_np = np.dot(a_np, b_np.T)
+        return (a_np, b_np, c_np, d_np)
+    a_np, b_np, c_np, d_np = get_ref_data()
+
+    def check_device(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
+            print("Skip because %s is not enabled" % device)
+            return
+        print("Running on target: %s" % device)
+        a = tvmsp.array(a_np, ctx)
+        b = tvm.nd.array(b_np, ctx)
+        c = tvm.nd.array(c_np, ctx)
+        d = tvm.nd.array(np.zeros(get_const_tuple(D.shape), dtype=dtype), ctx)
+        f = tvm.build(s, [A.data, A.indices, A.indptr, B, C, D], device, name="dense")
+        f(a.data, a.indices, a.indptr, b, c, d)
+        np.testing.assert_allclose(d.asnumpy(), d_np, rtol=1e-5)
+
+    check_device('llvm')
+
+def verify_dense_sw(batch, in_dim, out_dim, use_bias=True, dtype='float32'):
+    nonzeros = tvm.var('nonzeros')
+    A = tvm.placeholder((batch, in_dim), dtype=dtype, name='A')
+    B = tvmsp.placeholder(shape=(out_dim, in_dim), nonzeros=nonzeros, dtype=dtype, name='B')
+    C = tvm.placeholder((out_dim,), dtype=dtype, name='C')
+    D = topi.sparse.dense(A, B, C if use_bias else None)
+    s = tvm.create_schedule(D.op)
+
+    # get the test data
+    def get_ref_data():
+        mag = 10.
+        a_np = (mag*(np.random.uniform(size=(batch, in_dim)).astype('float32')-.5)).astype(dtype)
+        b_np = np.maximum(mag*(np.random.uniform(size=(out_dim, in_dim)).astype('float32')-0.5), 0.).astype(dtype)
+        c_np = (mag*(np.random.uniform(size=(out_dim,)).astype('float32')-.5)).astype(dtype)
+        if use_bias:
+            d_np = np.dot(a_np, b_np.T) + c_np
+        else:
+            d_np = np.dot(a_np, b_np.T)
+        return (a_np, b_np, c_np, d_np)
+    a_np, b_np, c_np, d_np = get_ref_data()
+
+    def check_device(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
+            print("Skip because %s is not enabled" % device)
+            return
+        print("Running on target: %s" % device)
+        a = tvm.nd.array(a_np, ctx)
+        b = tvmsp.array(b_np, ctx)
+        c = tvm.nd.array(c_np, ctx)
+        d = tvm.nd.array(np.zeros(get_const_tuple(D.shape), dtype=dtype), ctx)
+        f = tvm.build(s, [A, B.data, B.indices, B.indptr, C, D], device, name="dense")
+        f(a, b.data, b.indices, b.indptr, c, d)
+        np.testing.assert_allclose(d.asnumpy(), d_np, rtol=1e-5)
+
+    check_device('llvm')
+
+def test_csrmv():
+    verify_dynamic_csrmv(batch=5, in_dim=7, out_dim=1, use_bias=False)
+    verify_dynamic_csrmv(batch=5, in_dim=7, out_dim=1, use_bias=True)
+
+def test_csrmm():
+    M, K, N = 5, 7, 2
+    verify_dynamic_csrmm(batch=M, in_dim=K, out_dim=N, use_bias=False)
+    verify_dynamic_csrmm(batch=M, in_dim=K, out_dim=N, use_bias=True)
+
+def test_dense_si():
+    M, K, N = 3, 5, 2
+    verify_dense_si(batch=M, in_dim=K, out_dim=N, use_bias=False, dtype='float32')
+    verify_dense_si(batch=M, in_dim=K, out_dim=N, use_bias=True, dtype='float32')
+    verify_dense_si(batch=M, in_dim=K, out_dim=N, use_bias=False, dtype='int32')
+    verify_dense_si(batch=M, in_dim=K, out_dim=N, use_bias=True, dtype='int32')
+    verify_dense_si(batch=M, in_dim=K, out_dim=N, use_bias=False, dtype='int16')
+    verify_dense_si(batch=M, in_dim=K, out_dim=N, use_bias=True, dtype='int16')
+
+def test_dense_sw():
+    M, K, N = 3, 5, 2
+    verify_dense_sw(batch=M, in_dim=K, out_dim=N, use_bias=False, dtype='float32')
+    verify_dense_sw(batch=M, in_dim=K, out_dim=N, use_bias=True, dtype='float32')
+    verify_dense_sw(batch=M, in_dim=K, out_dim=N, use_bias=False, dtype='int32')
+    verify_dense_sw(batch=M, in_dim=K, out_dim=N, use_bias=True, dtype='int32')
+    verify_dense_sw(batch=M, in_dim=K, out_dim=N, use_bias=False, dtype='int16')
+    verify_dense_sw(batch=M, in_dim=K, out_dim=N, use_bias=True, dtype='int16')
+
+def test_dense():
+    test_dense_si()
+    test_dense_sw()
+
+if __name__ == "__main__":
+    test_csrmv()
+    test_csrmm()
+    test_dense()
-- 
GitLab