diff --git a/topi/python/topi/cuda/__init__.py b/topi/python/topi/cuda/__init__.py
index eaad08de75fef7f7dc3249b441d8a220c8912ba4..d97b7c511bc54f4cd5d0b1d966446f35a475e173 100644
--- a/topi/python/topi/cuda/__init__.py
+++ b/topi/python/topi/cuda/__init__.py
@@ -10,3 +10,5 @@ from .depthwise_conv2d import schedule_depthwise_conv2d_backward_weight_nhwc
 from .reduction import schedule_reduce
 from .softmax import schedule_softmax
 from .injective import schedule_injective, schedule_elemwise, schedule_broadcast
+from .dense import schedule_dense
+from .pooling import schedule_global_pool
diff --git a/topi/python/topi/cuda/dense.py b/topi/python/topi/cuda/dense.py
new file mode 100644
index 0000000000000000000000000000000000000000..207aabc9e7ce3dbebc41aa15ea4baa7feee38668
--- /dev/null
+++ b/topi/python/topi/cuda/dense.py
@@ -0,0 +1,60 @@
+# pylint: disable=invalid-name, unused-variable
+"""Schedule for dense operator"""
+from __future__ import absolute_import as _abs
+import tvm
+from .. import tag
+
+def schedule_dense(outs):
+    """Schedule for dense operator.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of dense
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for dense.
+    """
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+    s = tvm.create_schedule([x.op for x in outs])
+    def _schedule(Dense):
+        num_thread = 64
+        k = Dense.op.reduce_axis[0]
+        ko, kf = s[Dense].split(k, factor=num_thread)
+        DenseF = s.rfactor(Dense, kf)
+
+        if Dense.op in s.outputs:
+            Out = Dense
+        else:
+            Out = outs[0].op.output(0)
+            s[Dense].compute_at(s[Out], s[Out].op.axis[1])
+        s[Out].bind(s[Out].op.axis[0], tvm.thread_axis("blockIdx.y"))
+        s[Out].bind(s[Out].op.axis[1], tvm.thread_axis("blockIdx.x"))
+
+        tx = s[Dense].op.reduce_axis[0]
+        thread_x = tvm.thread_axis("threadIdx.x")
+        s[Dense].bind(tx, thread_x)
+        s[DenseF].compute_at(s[Dense], tx)
+        s[Dense].set_store_predicate(thread_x.var.equal(0))
+        s[Out].set_store_predicate(thread_x.var.equal(0))
+
+    def traverse(OP):
+        # inline all one-to-one-mapping operators except the last stage (output)
+        if tag.is_broadcast(OP.tag):
+            if OP not in s.outputs:
+                s[OP].compute_inline()
+            for tensor in OP.input_tensors:
+                if tensor.op.input_tensors:
+                    traverse(tensor.op)
+        # schedule dense
+        elif OP.tag == 'dense':
+            Dense = OP.output(0)
+            _schedule(Dense)
+        else:
+            raise RuntimeError("Unsupported operator: %s" % OP.tag)
+
+    traverse(outs[0].op)
+    return s
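The schedule above lowers the dense reduction as a split over `in_dim` followed by `rfactor`, so 64 `threadIdx.x` lanes each accumulate a partial sum and the final store is predicated on lane 0. A quick way to see that structure is to lower a small workload and print the IR. This is only a sketch: the shapes and tensor names are illustrative, and it assumes a TVM/TOPI build that already contains this patch.

```python
import tvm
import topi

# Illustrative shapes; any [batch, in_dim] x [out_dim, in_dim] pair works.
A = tvm.placeholder((1, 1024), name='A')      # data
B = tvm.placeholder((1000, 1024), name='B')   # weight
C = tvm.placeholder((1000,), name='C')        # bias
D = topi.nn.dense(A, B, C, use_bias=True)     # matmul tagged 'dense' + broadcast bias add

s = topi.cuda.schedule_dense(D)
# simple_mode keeps the dump readable; look for the threadIdx.x-bound
# reduction and the store guarded by `threadIdx.x == 0`.
print(tvm.lower(s, [A, B, C, D], simple_mode=True))
```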
diff --git a/topi/python/topi/cuda/pooling.py b/topi/python/topi/cuda/pooling.py
new file mode 100644
index 0000000000000000000000000000000000000000..db3714e9bf5f1f3d2ee90df078a38ae539dafd25
--- /dev/null
+++ b/topi/python/topi/cuda/pooling.py
@@ -0,0 +1,66 @@
+# pylint: disable=invalid-name, unused-variable
+"""Schedule for pooling operators"""
+import tvm
+from .. import tag
+
+def schedule_global_pool(outs):
+    """Schedule for global_pool.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of global_pool
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for global_pool.
+    """
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+    s = tvm.create_schedule([x.op for x in outs])
+    def _schedule(Pool):
+        num_thread = 8
+        block_x = tvm.thread_axis("blockIdx.x")
+        block_y = tvm.thread_axis("blockIdx.y")
+        thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
+        thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
+        if Pool.op in s.outputs:
+            Out = Pool
+            OL = s.cache_write(Pool, "local")
+        else:
+            Out = outs[0].op.output(0)
+            s[Pool].set_scope("local")
+        i, c, h, w = s[Out].op.axis
+        dh, dw = s[Pool].op.reduce_axis
+        fuse_index = s[Pool].fuse(dw, dh)
+        s[Pool].unroll(fuse_index)
+        by, ty = s[Out].split(i, factor=num_thread)
+        bx, tx = s[Out].split(c, factor=num_thread)
+        s[Out].reorder(by, bx, ty, tx)
+        s[Out].bind(ty, thread_y)
+        s[Out].bind(tx, thread_x)
+        s[Out].bind(by, block_y)
+        s[Out].bind(bx, block_x)
+        if Pool.op in s.outputs:
+            s[OL].compute_at(s[Out], tx)
+        else:
+            s[Pool].compute_at(s[Out], tx)
+
+    def traverse(OP):
+        # inline all one-to-one-mapping operators except the last stage (output)
+        if tag.is_broadcast(OP.tag):
+            if OP not in s.outputs:
+                s[OP].compute_inline()
+            for tensor in OP.input_tensors:
+                if tensor.op.input_tensors:
+                    traverse(tensor.op)
+        # schedule global_pool
+        elif 'global_pool' in OP.tag:
+            Pool = OP.output(0)
+            _schedule(Pool)
+        else:
+            raise RuntimeError("Unsupported operator: %s" % OP.tag)
+
+    traverse(outs[0].op)
+    return s
diff --git a/topi/python/topi/nn/__init__.py b/topi/python/topi/nn/__init__.py
index 0cc6c9472fe8d91561ac3af9f7d48895afb83a81..46edac975183b21f72987618a66ca76d232a9855 100644
--- a/topi/python/topi/nn/__init__.py
+++ b/topi/python/topi/nn/__init__.py
@@ -8,7 +8,7 @@ from .depthwise_convolution import *
 from .elemwise import *
 from .dilate import *
 from .flatten import *
-from .fully_connected import *
+from .dense import *
 from .mapping import *
 from .pooling import *
 from .softmax import *
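For reference, the computation `schedule_global_pool` targets collapses each [in_height, in_width] plane to one value per (batch, channel). A minimal NumPy sketch of those semantics, mirroring the reference used by the pooling test later in this patch (the helper name is illustrative):

```python
import numpy as np

def global_pool_ref(data, pool_type):
    """NumPy reference for global pooling on NCHW data."""
    if pool_type == 'avg':
        return np.mean(data, axis=(2, 3), keepdims=True)
    elif pool_type == 'max':
        return np.max(data, axis=(2, 3), keepdims=True)
    raise ValueError("Pool type should be 'avg' or 'max'.")

x = np.random.uniform(size=(1, 16, 7, 7)).astype('float32')
print(global_pool_ref(x, 'avg').shape)  # (1, 16, 1, 1)
```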
diff --git a/topi/python/topi/nn/dense.py b/topi/python/topi/nn/dense.py
new file mode 100644
index 0000000000000000000000000000000000000000..e64bd9e9adc0ecb932922bc1e2fac3e9b36e0ea5
--- /dev/null
+++ b/topi/python/topi/nn/dense.py
@@ -0,0 +1,41 @@
+"""TVM operator fully connected compute."""
+from __future__ import absolute_import
+import tvm
+from .. import tag
+
+
+def dense(data, weight, bias, use_bias=True):
+    """Applies a linear transformation: :math:`Y = XW^T + b`.
+
+    Parameters
+    ----------
+    data : tvm.Tensor
+        2-D with shape [batch, in_dim]
+
+    weight : tvm.Tensor
+        2-D with shape [out_dim, in_dim]
+
+    bias : tvm.Tensor
+        1-D with shape [out_dim]
+
+    use_bias : bool, optional, default=True
+        Whether to use bias parameter
+
+    Returns
+    -------
+    output : tvm.Tensor
+        2-D with shape [batch, out_dim]
+    """
+    assert len(data.shape) == 2 and len(weight.shape) == 2 and len(bias.shape) == 1, \
+        "only support 2-dim dense"
+    batch, in_dim = data.shape
+    out_dim, _ = weight.shape
+    k = tvm.reduce_axis((0, in_dim), name='k')
+    matmul = tvm.compute((batch, out_dim), \
+                         lambda i, j: tvm.sum(data[i, k] * weight[j, k], axis=k), \
+                         tag='dense')
+    if not use_bias:
+        return matmul
+    return tvm.compute((batch, out_dim), \
+                       lambda i, j: matmul[i, j] + bias[j], \
+                       tag=tag.BROADCAST)
diff --git a/topi/python/topi/nn/fully_connected.py b/topi/python/topi/nn/fully_connected.py
deleted file mode 100644
index 870df02424f55d0db16216b65cbc9bee9a3031e5..0000000000000000000000000000000000000000
--- a/topi/python/topi/nn/fully_connected.py
+++ /dev/null
@@ -1,62 +0,0 @@
-"""TVM operator fully connected compute."""
-from __future__ import absolute_import
-import tvm
-
-
-@tvm.tag_scope(tag='fully_connected')
-def fully_connected(data, weight):
-    """Matrix multiplication
-
-    Parameters
-    ----------
-    data : tvm.Tensor
-        2-D with shape [batch, in_dim]
-
-    weight : tvm.Tensor
-        2-D with shape [out_dim, in_dim]
-
-    Returns
-    -------
-    output : tvm.Tensor
-        2-D with shape [batch, out_dim]
-    """
-    assert len(data.shape) == 2 and len(weight.shape) == 2, \
-        "only support 2-dim fully_connected"
-    batch, in_dim = data.shape
-    out_dim, _ = weight.shape
-    k = tvm.reduce_axis((0, in_dim), name='k')
-    return tvm.compute((batch, out_dim), lambda i, j: \
-        tvm.sum(data[i][k] * weight[j][k], axis=k))
-
-
-@tvm.tag_scope(tag='fully_connected_with_bias')
-def fully_connected_with_bias(data, weight, bias):
-    """Applies a linear transformation: :math:`Y = XW^T + b`.
-
-    Parameters
-    ----------
-    data : tvm.Tensor
-        2-D with shape [batch, in_dim]
-
-    weight : tvm.Tensor
-        2-D with shape [out_dim, in_dim]
-
-    bias : tvm.Tensor
-        1-D with shape [out_dim]
-
-    Returns
-    -------
-    output : tvm.Tensor
-        2-D with shape [batch, out_dim]
-    """
-    assert len(data.shape) == 2 and len(weight.shape) == 2, \
-        "only support 2-dim fully_connected"
-    assert len(data.shape) == 2 and len(weight.shape) == 2 and len(bias.shape) == 1, \
-        "only support 2-dim fully_connected"
-    batch, in_dim = data.shape
-    out_dim, _ = weight.shape
-    k = tvm.reduce_axis((0, in_dim), name='k')
-    matmul = tvm.compute((batch, out_dim), lambda i, j: \
-        tvm.sum(data[i, k] * weight[j, k], axis=k))
-    return tvm.compute((batch, out_dim), lambda i, j: \
-        matmul[i, j] + bias[j])
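For callers tracking the rename, the two removed entry points fold into the single `dense` call. A hedged before/after sketch (placeholder names are illustrative); note that `dense` still expects a 1-D bias tensor even when `use_bias=False`, it is simply left out of the compute:

```python
import tvm
import topi

data = tvm.placeholder((1, 1024), name='data')
weight = tvm.placeholder((1000, 1024), name='weight')
bias = tvm.placeholder((1000,), name='bias')

# Before this patch:
#   out = topi.nn.fully_connected(data, weight)
#   out = topi.nn.fully_connected_with_bias(data, weight, bias)
# After this patch:
out_no_bias = topi.nn.dense(data, weight, bias, use_bias=False)
out_with_bias = topi.nn.dense(data, weight, bias, use_bias=True)
```

The new compute also tags the matmul stage `'dense'` and the bias add `tag.BROADCAST`, which is what lets `schedule_dense` pattern-match the reduction and inline the elementwise tail.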
diff --git a/topi/python/topi/nn/pooling.py b/topi/python/topi/nn/pooling.py
index 26d3a3e11f0b5e19491772210288074978f7a49a..511d10afce4984e42a610e9d7a990d925936f16f 100644
--- a/topi/python/topi/nn/pooling.py
+++ b/topi/python/topi/nn/pooling.py
@@ -4,6 +4,7 @@ import tvm
 from .pad import pad
 from .util import get_pad_tuple
 from .. import util
+from .. import tag
 
 def max_pool(data, kernel, stride, padding):
     """Perform max pooling on the data
@@ -51,15 +52,17 @@ def max_pool(data, kernel, stride, padding):
         tag="max_pool")
 
 
-@tvm.tag_scope(tag='global_avg_pool')
-def global_avg_pool(data):
-    """Perform global average pooling on the data
+def global_pool(data, pool_type):
+    """Perform global pooling on the data
 
     Parameters
     ----------
     data : tvm.Tensor
        4-D with shape [batch, channel, in_height, in_width]
 
+    pool_type : str
+        Pool type, 'max' or 'avg'
+
     Returns
     -------
     output : tvm.Tensor
@@ -71,7 +74,16 @@ def global_avg_pool(data):
     dheight = tvm.reduce_axis((0, height))
     dwidth = tvm.reduce_axis((0, width))
 
-    tsum = tvm.compute((batch, channel, 1, 1), lambda n, c, h, w: \
-        tvm.sum(data[n, c, dheight, dwidth], axis=[dheight, dwidth]))
-    return tvm.compute((batch, channel, 1, 1), lambda n, c, h, w: \
-        tsum[n, c, h, w] / (height*width))
+    if pool_type == 'max':
+        return tvm.compute((batch, channel, 1, 1), lambda n, c, h, w: \
+            tvm.max(data[n, c, dheight, dwidth], axis=[dheight, dwidth]), \
+            tag="global_pool_max")
+    elif pool_type == 'avg':
+        tsum = tvm.compute((batch, channel, 1, 1), lambda n, c, h, w: \
+            tvm.sum(data[n, c, dheight, dwidth], axis=[dheight, dwidth]), \
+            tag="global_pool_sum")
+        return tvm.compute((batch, channel, 1, 1), lambda n, c, h, w: \
+            tsum[n, c, h, w] / (height*width), \
+            tag=tag.ELEMWISE)
+    else:
+        raise ValueError("Pool type should be 'avg' or 'max'.")
diff --git a/topi/tests/python/test_topi_dense.py b/topi/tests/python/test_topi_dense.py
new file mode 100644
index 0000000000000000000000000000000000000000..20c9c6cc4eaa0cc5d1b00659a35c8ca5ab71823b
--- /dev/null
+++ b/topi/tests/python/test_topi_dense.py
@@ -0,0 +1,54 @@
+"""Test code for dense operator"""
+import numpy as np
+import tvm
+import topi
+from topi.util import get_const_tuple
+from tvm.contrib.pickle_memoize import memoize
+
+
+def verify_dense(batch, in_dim, out_dim, use_bias=True):
+    A = tvm.placeholder((batch, in_dim), name='A')
+    B = tvm.placeholder((out_dim, in_dim), name='B')
+    C = tvm.placeholder((out_dim,), name='C')
+    D = topi.nn.dense(A, B, C, use_bias=use_bias)
+    D = topi.nn.relu(D)
+    s = topi.cuda.schedule_dense(D)
+    dtype = A.dtype
+
+    # use memoize to pickle the test data for next time use
+    @memoize("topi.tests.test_topi_dense")
+    def get_ref_data():
+        a_np = np.random.uniform(size=(batch, in_dim)).astype(dtype)
+        b_np = np.random.uniform(size=(out_dim, in_dim)).astype(dtype)
+        c_np = np.random.uniform(size=(out_dim,)).astype(dtype)
+        if use_bias:
+            d_np = np.maximum(np.dot(a_np, b_np.T) + c_np, 0.0)
+        else:
+            d_np = np.maximum(np.dot(a_np, b_np.T), 0.0)
+        return (a_np, b_np, c_np, d_np)
+    # get the test data
+    a_np, b_np, c_np, d_np = get_ref_data()
+
+    def check_device(device):
+        if not tvm.module.enabled(device):
+            print("Skip because %s is not enabled" % device)
+            return
+        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        a = tvm.nd.array(a_np, ctx)
+        b = tvm.nd.array(b_np, ctx)
+        c = tvm.nd.array(c_np, ctx)
+        d = tvm.nd.array(np.zeros(get_const_tuple(D.shape), dtype=dtype), ctx)
+        f = tvm.build(s, [A, B, C, D], device, name="dense")
+        f(a, b, c, d)
+        np.testing.assert_allclose(d.asnumpy(), d_np, rtol=1e-5)
+
+    for device in ['cuda', 'opencl', 'metal']:
+        check_device(device)
+
+def test_dense():
+    verify_dense(1, 1024, 1000, use_bias=True)
+    verify_dense(1, 1024, 1000, use_bias=False)
+
+
+if __name__ == "__main__":
+    test_dense()
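The reference result memoized in `get_ref_data` above is plain NumPy; as a standalone sketch of what `dense` followed by the fused `relu` is expected to produce (the helper name is illustrative):

```python
import numpy as np

def dense_relu_ref(a, b, c, use_bias=True):
    """relu(a.dot(b.T) [+ c]) for a=[batch, in_dim], b=[out_dim, in_dim], c=[out_dim]."""
    out = np.dot(a, b.T)
    if use_bias:
        out = out + c
    return np.maximum(out, 0.0)

a = np.random.uniform(size=(1, 1024)).astype('float32')
b = np.random.uniform(size=(1000, 1024)).astype('float32')
c = np.random.uniform(size=(1000,)).astype('float32')
print(dense_relu_ref(a, b, c).shape)  # (1, 1000)
```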
diff --git a/topi/tests/python/test_topi_pooling.py b/topi/tests/python/test_topi_pooling.py
new file mode 100644
index 0000000000000000000000000000000000000000..397f5133c66ed1b43af0c4b835de08a0fd0ab6a0
--- /dev/null
+++ b/topi/tests/python/test_topi_pooling.py
@@ -0,0 +1,42 @@
+"""Test code for pooling"""
+import numpy as np
+import tvm
+import topi
+from topi.util import get_const_tuple
+
+def verify_global_pool(n, c, h, w, pool_type):
+    A = tvm.placeholder((n, c, h, w), name='A')
+    B = topi.nn.global_pool(A, pool_type=pool_type)
+    B = topi.nn.relu(B)
+    s = topi.cuda.schedule_global_pool(B)
+
+    a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
+    if pool_type == 'avg':
+        b_np = np.mean(a_np, axis=(2,3), keepdims=True)
+    elif pool_type =='max':
+        b_np = np.max(a_np, axis=(2,3), keepdims=True)
+    b_np = np.maximum(b_np, 0.0)
+
+    def check_device(device):
+        if not tvm.module.enabled(device):
+            print("Skip because %s is not enabled" % device)
+            return
+        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        a = tvm.nd.array(a_np, ctx)
+        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
+        f = tvm.build(s, [A, B], device, name="global_avg_pool")
+        f(a, b)
+        np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
+
+    for device in ['cuda', 'opencl', 'metal']:
+        check_device(device)
+
+def test_global_pool():
+    verify_global_pool(1, 1024, 7, 7, 'avg')
+    verify_global_pool(4, 1024, 7, 7, 'avg')
+    verify_global_pool(1, 1024, 7, 7, 'max')
+    verify_global_pool(4, 1024, 7, 7, 'max')
+
+
+if __name__ == "__main__":
+    test_global_pool()
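Putting the pieces together, a minimal end-to-end run of the new dense op plus its CUDA schedule is just a stripped-down version of the tests above. A hedged sketch, assuming a CUDA-enabled TVM build with this patch applied:

```python
import numpy as np
import tvm
import topi
from topi.util import get_const_tuple

batch, in_dim, out_dim = 1, 1024, 1000
A = tvm.placeholder((batch, in_dim), name='A')
B = tvm.placeholder((out_dim, in_dim), name='B')
C = tvm.placeholder((out_dim,), name='C')
D = topi.nn.dense(A, B, C, use_bias=True)
s = topi.cuda.schedule_dense(D)

if tvm.module.enabled("cuda"):
    ctx = tvm.gpu(0)
    a_np = np.random.uniform(size=(batch, in_dim)).astype(A.dtype)
    b_np = np.random.uniform(size=(out_dim, in_dim)).astype(A.dtype)
    c_np = np.random.uniform(size=(out_dim,)).astype(A.dtype)
    a = tvm.nd.array(a_np, ctx)
    b = tvm.nd.array(b_np, ctx)
    c = tvm.nd.array(c_np, ctx)
    d = tvm.nd.array(np.zeros(get_const_tuple(D.shape), dtype=A.dtype), ctx)
    f = tvm.build(s, [A, B, C, D], "cuda", name="dense")
    f(a, b, c, d)
    np.testing.assert_allclose(d.asnumpy(), np.dot(a_np, b_np.T) + c_np, rtol=1e-5)
```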