From b840e9602e357c50124d0c7fb131c52321062570 Mon Sep 17 00:00:00 2001
From: Siju <sijusamuel@gmail.com>
Date: Tue, 30 Oct 2018 08:15:44 +0530
Subject: [PATCH] [YOLO]yolo op added in frontend and removed from topi (#1974)

---
 nnvm/python/nnvm/frontend/darknet.py     | 20 +++++---
 nnvm/python/nnvm/top/vision.py           | 15 ------
 nnvm/src/top/vision/yolo/yolo.cc         | 33 --------------
 topi/include/topi/vision/yolo/yolo.h     | 58 ------------------------
 topi/python/topi/testing/__init__.py     |  1 -
 topi/python/topi/testing/yolo_python.py  | 43 ------------------
 topi/python/topi/vision/yolo/__init__.py |  1 -
 topi/python/topi/vision/yolo/yolo.py     | 30 ------------
 topi/src/topi.cc                         |  6 ---
 topi/tests/python_cpp/test_topi_yolo.py  | 49 --------------------
 10 files changed, 14 insertions(+), 242 deletions(-)
 delete mode 100644 nnvm/src/top/vision/yolo/yolo.cc
 delete mode 100644 topi/include/topi/vision/yolo/yolo.h
 delete mode 100644 topi/python/topi/testing/yolo_python.py
 delete mode 100644 topi/python/topi/vision/yolo/yolo.py
 delete mode 100644 topi/tests/python_cpp/test_topi_yolo.py

diff --git a/nnvm/python/nnvm/frontend/darknet.py b/nnvm/python/nnvm/frontend/darknet.py
index 4da2e90bc..18d07d07a 100644
--- a/nnvm/python/nnvm/frontend/darknet.py
+++ b/nnvm/python/nnvm/frontend/darknet.py
@@ -317,12 +317,19 @@ def _darknet_region(inputs, attrs):
 
 def _darknet_yolo(inputs, attrs):
     """Process the yolo operation."""
-    op_name, new_attrs = 'yolov3_yolo', {}
-    if 'n' in attrs:
-        new_attrs['n'] = attrs.get('n', 1)
-    if 'classes' in attrs:
-        new_attrs['classes'] = attrs.get('classes', 1)
-    return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None
+    num = attrs.get('n', 1)
+    classes = attrs.get('classes', 1)
+    input_shape = attrs.get('shape')
+    split_size = classes + 5
+    intermediate_shape = (input_shape[0], num, split_size, input_shape[2], input_shape[3])
+    data_block = _sym.reshape(inputs[0], shape=intermediate_shape)
+    split_indices = (2, 4)
+    split_res = _sym.split(data_block, indices_or_sections=split_indices, axis=2)
+    split_res0 = _sym.sigmoid(split_res[0])
+    split_res2 = _sym.sigmoid(split_res[2])
+    concat_list = [split_res0, split_res[1], split_res2]
+    out = _sym.concatenate(*concat_list, axis=2)
+    return _sym.reshape(out, shape=input_shape), None
 
 def _darknet_activations(inputs, attrs):
     """Process the activation function."""
@@ -635,6 +642,7 @@ class GraphProto(object):
         elif LAYERTYPE.YOLO == layer.type:
             attr.update({'n' : layer.n})
             attr.update({'classes' : layer.classes})
+            attr.update({'shape' : (1, layer.c, layer.h, layer.w)})
 
         elif LAYERTYPE.UPSAMPLE == layer.type:
             attr.update({'scale' : layer.stride})
diff --git a/nnvm/python/nnvm/top/vision.py b/nnvm/python/nnvm/top/vision.py
index e59b2bdfe..f2e12c0f3 100644
--- a/nnvm/python/nnvm/top/vision.py
+++ b/nnvm/python/nnvm/top/vision.py
@@ -38,21 +38,6 @@ def schedule_region(attrs, outs, target):
 
 reg.register_pattern("yolo_region", OpPattern.OPAQUE)
 
-@reg.register_compute("yolov3_yolo")
-def compute_yolo(attrs, inputs, _):
-    """Compute definition of yolo"""
-    n = attrs.get_int("n")
-    classes = attrs.get_int("classes")
-    return topi.vision.yolo.yolo(inputs[0], n, classes)
-
-@reg.register_schedule("yolov3_yolo")
-def schedule_yolo(attrs, outs, target):
-    """Schedule definition of yolo"""
-    with tvm.target.create(target):
-        return topi.generic.schedule_injective(outs)
-
-reg.register_pattern("yolov3_yolo", OpPattern.OPAQUE)
-
 # multibox_prior
 @reg.register_schedule("multibox_prior")
 def schedule_multibox_prior(_, outs, target):
diff --git a/nnvm/src/top/vision/yolo/yolo.cc b/nnvm/src/top/vision/yolo/yolo.cc
deleted file mode 100644
index 4800f4371..000000000
--- a/nnvm/src/top/vision/yolo/yolo.cc
+++ /dev/null
@@ -1,33 +0,0 @@
-/*!
- *  Copyright (c) 2018 by Contributors
- * \file yolo.cc
- * \brief Property def of yolo operators.
- */
-#include <nnvm/op.h>
-#include <nnvm/node.h>
-#include <nnvm/op_attr_types.h>
-#include <nnvm/top/nn.h>
-#include "../../elemwise_op_common.h"
-
-namespace nnvm {
-namespace top {
-
-NNVM_REGISTER_OP(yolov3_yolo)
-.describe(R"code(Yolo layer
-)code" NNVM_ADD_FILELINE)
-.set_num_inputs(1)
-.set_num_outputs(1)
-.set_support_level(5)
-.add_argument("data", "Tensor", "Input data")
-.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>)
-.set_attr<FInplaceOption>(
-    "FInplaceOption",
-    [](const NodeAttrs &attrs) {
-      return std::vector<std::pair<int, int>>{{0, 0}, {1, 0}};
-    })
-.set_attr<FGradient>("FGradient", [](const NodePtr &n,
-                                     const std::vector<NodeEntry> &ograds) {
-  return std::vector<NodeEntry>{ograds[0], ograds[0]};
-});
-}  // namespace top
-}  // namespace nnvm
diff --git a/topi/include/topi/vision/yolo/yolo.h b/topi/include/topi/vision/yolo/yolo.h
deleted file mode 100644
index d2e24c01b..000000000
--- a/topi/include/topi/vision/yolo/yolo.h
+++ /dev/null
@@ -1,58 +0,0 @@
-/*!
- *  Copyright (c) 2018 by Contributors
- * \brief YOLO op constructions
- * \file vision/yolo/yolo.h
- */
-#ifndef TOPI_VISION_YOLO_YOLO_H_
-#define TOPI_VISION_YOLO_YOLO_H_
-
-#include <algorithm>
-#include <string>
-
-#include "topi/detail/constant_utils.h"
-#include "topi/tags.h"
-#include "topi/transform.h"
-#include "tvm/tvm.h"
-
-
-namespace topi {
-namespace vision {
-namespace yolo {
-using namespace tvm;
-using namespace nn;
-
-/*!
-* \brief yolo operation
-*
-* \param data The input tensor.
-* \param num Darknet layer parameter n
-* \param classes number of classes in the yolo model
-* \param name The name of the operation
-* \param tag The tag to mark the operation
-*
-* \return A Tensor whose op member is the yolo operation
-*/
-inline Tensor yolo(const Tensor &data,
-                   int num,
-                   int classes,
-                   std::string name = "tensor",
-                   std::string tag = "yolo_output") {
-  auto input_shape = data->shape;
-  int split_size = classes + 5;
-  Array <Expr> intermediate_shape = {input_shape[0],
-                                     num,
-                                     split_size,
-                                     input_shape[2],
-                                     input_shape[3]};
-  auto data_block = reshape(data, intermediate_shape);
-  Array <Expr> split_indices = {2, 4};
-  Array <Tensor> split_res = split(data_block, split_indices, 2);
-  split_res.Set(0, sigmoid(split_res[0]));
-  split_res.Set(2, sigmoid(split_res[2]));
-  Tensor out = concatenate(split_res, 2);
-  return reshape(out, input_shape);
-}
-}  // namespace yolo
-}  // namespace vision
-}  // namespace topi
-#endif  // TOPI_VISION_YOLO_YOLO_H_
diff --git a/topi/python/topi/testing/__init__.py b/topi/python/topi/testing/__init__.py
index c9d995a38..c91eea795 100644
--- a/topi/python/topi/testing/__init__.py
+++ b/topi/python/topi/testing/__init__.py
@@ -15,7 +15,6 @@ from .upsampling_python import upsampling_python
 from .bilinear_resize_python import bilinear_resize_python
 from .reorg_python import reorg_python
 from .region_python import region_python
-from .yolo_python import yolo_python
 from .shortcut_python import shortcut_python
 from .lrn_python import lrn_python
 from .l2_normalize_python import l2_normalize_python
diff --git a/topi/python/topi/testing/yolo_python.py b/topi/python/topi/testing/yolo_python.py
deleted file mode 100644
index a6b3a4120..000000000
--- a/topi/python/topi/testing/yolo_python.py
+++ /dev/null
@@ -1,43 +0,0 @@
-# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
-"""Yolo operator in python"""
-import numpy as np
-
-def entry_index(batch, w, h, outputs, classes, coords, location, entry):
-    n = int(location/(w*h))
-    loc = location%(w*h)
-    return batch*outputs + n*w*h*(coords+classes+1) + entry*w*h + loc
-
-def yolo_python(a_np, N, classes):
-    """Yolo operator
-    Parameters
-    ----------
-    a_np : numpy.ndarray
-        4-D with shape [batch, in_channel, in_height, in_width]
-
-    N : int
-        Darknet layer parameter n
-
-    classes : int
-        Darknet layer parameter classes
-
-    Returns
-    -------
-    b_np : np.ndarray
-        4-D with shape [batch, out_channel, out_height, out_width]
-    """
-
-    batch, in_channel, in_height, in_width = a_np.shape
-    a_np_temp = np.reshape(a_np, batch*in_channel*in_height*in_width)
-    outputs = batch*in_channel*in_height*in_width
-    b_np = np.zeros(batch*in_channel*in_height*in_width)
-    for i in range(batch*in_channel*in_height*in_width):
-        b_np[i] = a_np_temp[i]
-    for b in range(batch):
-        for n in range(N):
-            index = entry_index(b, in_width, in_height, outputs, classes, 4, n*in_width*in_height, 0)
-            b_np[index: index+2*in_width*in_height] = 1/(1+np.exp(-1*b_np[index: index+2*in_width*in_height]))
-            index = entry_index(b, in_width, in_height, outputs, classes, 4, n*in_width*in_height, 4)
-            b_np[index: index+(1+classes)*in_width*in_height] = 1/(1+np.exp(-1*b_np[index: index+(1+classes)*in_width*in_height]))
-
-    b_np = np.reshape(b_np, (batch, in_channel, in_height, in_width))
-    return b_np
diff --git a/topi/python/topi/vision/yolo/__init__.py b/topi/python/topi/vision/yolo/__init__.py
index 2c0a165f8..c0e9899a4 100644
--- a/topi/python/topi/vision/yolo/__init__.py
+++ b/topi/python/topi/vision/yolo/__init__.py
@@ -3,4 +3,3 @@
 from __future__ import absolute_import as _abs
 
 from .region import *
-from .yolo import *
diff --git a/topi/python/topi/vision/yolo/yolo.py b/topi/python/topi/vision/yolo/yolo.py
deleted file mode 100644
index 6ae630a86..000000000
--- a/topi/python/topi/vision/yolo/yolo.py
+++ /dev/null
@@ -1,30 +0,0 @@
-# pylint: disable=invalid-name, unused-variable
-"""
-YOLO Operator
-=============
-YOLO operator, used in darknet.
-"""
-from __future__ import absolute_import as _abs
-import tvm
-from ... import cpp
-
-@tvm.target.generic_func
-def yolo(data, num, classes):
-    """YOLO forward operators.
-    Parameters
-    ----------
-    data : tvm.Tensor
-        4-D with shape [batch, c_in, h_in, w_in]
-
-    num : int
-        Darknet layer parameter n
-
-    classes : int
-        Darknet layer parameter classes
-
-    Returns
-    -------
-    out : tvm.Tensor
-        4-D with shape [batch, c_in, h_in, w_in]
-    """
-    return cpp.yolo.yolo(data, num, classes)
diff --git a/topi/src/topi.cc b/topi/src/topi.cc
index ae1ad5755..2d9f2fd6c 100644
--- a/topi/src/topi.cc
+++ b/topi/src/topi.cc
@@ -32,7 +32,6 @@
 #include <topi/vision/reorg.h>
 #include <topi/image/resize.h>
 #include <topi/vision/yolo/region.h>
-#include <topi/vision/yolo/yolo.h>
 #include <topi/generic/default.h>
 #include <topi/generic/extern.h>
 #include <topi/generic/injective.h>
@@ -413,11 +412,6 @@ TVM_REGISTER_GLOBAL("topi.vision.yolo.region")
   *rv = vision::yolo::region(args[0], args[1], args[2], args[3], args[4], args[5]);
   });
 
-TVM_REGISTER_GLOBAL("topi.vision.yolo.yolo")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
-  *rv = vision::yolo::yolo(args[0], args[1], args[2]);
-  });
-
 /* Ops from image/resize.h */
 TVM_REGISTER_GLOBAL("topi.image.resize")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
diff --git a/topi/tests/python_cpp/test_topi_yolo.py b/topi/tests/python_cpp/test_topi_yolo.py
deleted file mode 100644
index 293de4fca..000000000
--- a/topi/tests/python_cpp/test_topi_yolo.py
+++ /dev/null
@@ -1,49 +0,0 @@
-"""Test code for yolo op"""
-import logging
-import numpy as np
-import tvm
-import topi
-import topi.testing
-from topi.util import get_const_tuple
-
-def verify_yolo(ishape, n, classes):
-    '''Verify yolo operator by comparing outputs from tvm and numpy implementation'''
-    
-    A = tvm.placeholder(ishape, name='A')
-    B = topi.cpp.yolo.yolo(A, n, classes)
-    dtype = A.dtype
-
-    def get_ref_data_yolo():
-        '''Randomly initialize the data variables and get refernce output for the yolo operation'''
-        a_np = np.random.uniform(size=ishape).astype(dtype)
-        b_np = topi.testing.yolo_python(a_np, n, classes)
-        return a_np, b_np
-
-    a_np, b_np = get_ref_data_yolo()
-    def check_device(device):
-        '''Check the device is available and if so, build and run the program'''
-        if not tvm.module.enabled(device):
-            print("Skip because %s is not enabled" % device)
-            return
-        print("Running on target: %s" % device)
-        target = topi.cpp.TEST_create_target(device)
-        if device == "llvm":
-            s = topi.cpp.generic.default_schedule(target, [B], False)
-        else:
-            s = topi.cpp.cuda.schedule_injective(target, [B])
-        ctx = tvm.context(device, 0)
-        a = tvm.nd.array(a_np, ctx)
-        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
-        func = tvm.build(s, [A, B], device, name="yolo")
-        func(a, b)
-        tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
-
-    for device in ['cuda', 'opencl', 'metal', 'rocm', 'llvm', 'vulkan']:
-        check_device(device)
-
-def test_yolo():
-    verify_yolo((1, 425, 19, 19), 5, 80)
-
-if __name__ == "__main__":
-    logging.basicConfig(level=logging.DEBUG)
-    test_yolo()
-- 
GitLab