From 56c50d2d072d976b058ba7e29a276e4e9b1e8839 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng <mercy_zheng@sjtu.edu.cn>
Date: Wed, 22 Aug 2018 20:19:43 -0700
Subject: [PATCH] trigger ci (#1620)

---
 topi/python/topi/x86/nn.py | 43 +++++++++++++++++---------------------
 1 file changed, 19 insertions(+), 24 deletions(-)

diff --git a/topi/python/topi/x86/nn.py b/topi/python/topi/x86/nn.py
index 03e07222c..6802d4c01 100644
--- a/topi/python/topi/x86/nn.py
+++ b/topi/python/topi/x86/nn.py
@@ -2,8 +2,9 @@
 """x86 nn operators"""
 from __future__ import absolute_import as _abs
 import tvm
+
 from .. import generic
-from .. import tag
+from ..util import traverse_inline
 
 @generic.schedule_softmax.register(["cpu"])
 def schedule_softmax(outs):
@@ -53,44 +54,38 @@ def schedule_dense(outs):
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     s = tvm.create_schedule([x.op for x in outs])
 
-    scheduled_ops = []
-
-    def traverse(op):
-        """Traverse operators from computation graph"""
-        # inline all one-to-one-mapping operators except the last stage (output)
-        if tag.is_broadcast(op.tag):
-            if op not in s.outputs:
-                s[op].compute_inline()
-            for tensor in op.input_tensors:
-                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
-                    traverse(tensor.op)
+    def _callback(op):
         if 'dense' in op.tag:
-            C = op.output(0)
-            x, y = C.op.axis
+            output = outs[0]
+            dense = op.output(0)
 
             # Write cache for blocks
-            CC = s.cache_write(C, 'global')
+            if dense.op in s.outputs:
+                CC = s.cache_write(dense, 'local')
+            else:
+                CC = dense
 
             # Tile
             bnx = 1
             bny = 4
-            _, yo, _, yi = s[C].tile(x, y, bnx, bny)
-            s[CC].compute_at(s[C], yo)
+            x, y = output.op.axis
+            xo, yo, xi, yi = s[output].tile(x, y, bnx, bny)
 
+            xc, yc = s[CC].op.axis
             k, = s[CC].op.reduce_axis
             ko, ki = s[CC].split(k, factor=4)
             s[CC].reorder(ko, xc, ki, yc)
+            s[CC].unroll(ki)
             s[CC].vectorize(yc)
 
-            # Vectorization
-            s[C].vectorize(yi)
-
-            # Parallelization
-            s[C].parallel(yo)
+            s[output].unroll(xi)
+            s[output].vectorize(yi)
 
-            scheduled_ops.append(op)
+            fused = s[output].fuse(xo, yo)
+            s[output].parallel(fused)
+            s[CC].compute_at(s[output], fused)
 
-    traverse(outs[0].op)
+    traverse_inline(s, outs[0].op, _callback)
 
     return s
-- 
GitLab