From b30ae8ac6df3a8ba6a7479722a987465bf4e8650 Mon Sep 17 00:00:00 2001 From: Pariksheet Pinjari <pariksheet.pinjari@huawei.com> Date: Mon, 9 Jul 2018 10:20:52 +0530 Subject: [PATCH] [TOPI][DARKNET]Yolo op added (#1372) --- nnvm/python/nnvm/frontend/darknet.py | 4 +- nnvm/python/nnvm/top/vision.py | 14 ++--- nnvm/src/top/vision/{yolo2 => yolo}/region.cc | 2 +- nnvm/src/top/vision/{yolo2 => yolo}/region.h | 6 +- nnvm/src/top/vision/{yolo2 => yolo}/reorg.cc | 2 +- nnvm/src/top/vision/{yolo2 => yolo}/reorg.h | 6 +- .../topi/vision/{yolo2 => yolo}/region.h | 12 ++-- topi/include/topi/vision/yolo/yolo.h | 58 +++++++++++++++++++ topi/python/topi/cpp.py | 4 +- topi/python/topi/testing/__init__.py | 1 + topi/python/topi/testing/yolo_python.py | 43 ++++++++++++++ topi/python/topi/vision/__init__.py | 2 +- .../topi/vision/{yolo2 => yolo}/__init__.py | 1 + .../topi/vision/{yolo2 => yolo}/region.py | 2 +- topi/python/topi/vision/yolo/yolo.py | 30 ++++++++++ topi/src/topi.cc | 12 +++- topi/tests/python/test_topi_region.py | 2 +- topi/tests/python_cpp/test_topi_region.py | 2 +- topi/tests/python_cpp/test_topi_yolo.py | 49 ++++++++++++++++ 19 files changed, 220 insertions(+), 32 deletions(-) rename nnvm/src/top/vision/{yolo2 => yolo}/region.cc (96%) rename nnvm/src/top/vision/{yolo2 => yolo}/region.h (96%) rename nnvm/src/top/vision/{yolo2 => yolo}/reorg.cc (98%) rename nnvm/src/top/vision/{yolo2 => yolo}/reorg.h (96%) rename topi/include/topi/vision/{yolo2 => yolo}/region.h (92%) create mode 100644 topi/include/topi/vision/yolo/yolo.h create mode 100644 topi/python/topi/testing/yolo_python.py rename topi/python/topi/vision/{yolo2 => yolo}/__init__.py (87%) rename topi/python/topi/vision/{yolo2 => yolo}/region.py (91%) create mode 100644 topi/python/topi/vision/yolo/yolo.py create mode 100644 topi/tests/python_cpp/test_topi_yolo.py diff --git a/nnvm/python/nnvm/frontend/darknet.py b/nnvm/python/nnvm/frontend/darknet.py index f2c744bea..c18585a27 100644 --- a/nnvm/python/nnvm/frontend/darknet.py +++ b/nnvm/python/nnvm/frontend/darknet.py @@ -276,14 +276,14 @@ def _darknet_route(inputs, attrs): def _darknet_reorg(inputs, attrs): """Process the reorg operation.""" - op_name, new_attrs = 'yolo2_reorg', {} + op_name, new_attrs = 'yolo_reorg', {} if 'stride' in attrs: new_attrs = {'stride': attrs.get('stride', 1)} return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None def _darknet_region(inputs, attrs): """Process the region operation.""" - op_name, new_attrs = 'yolo2_region', {} + op_name, new_attrs = 'yolo_region', {} if 'n' in attrs: new_attrs['n'] = attrs.get('n', 1) if 'classes' in attrs: diff --git a/nnvm/python/nnvm/top/vision.py b/nnvm/python/nnvm/top/vision.py index edbf72320..f2e12c0f3 100644 --- a/nnvm/python/nnvm/top/vision.py +++ b/nnvm/python/nnvm/top/vision.py @@ -7,20 +7,20 @@ import topi from . import registry as reg from .registry import OpPattern -@reg.register_compute("yolo2_reorg") +@reg.register_compute("yolo_reorg") def compute_reorg(attrs, inputs, _): """Compute definition of reorg""" return topi.vision.reorg(inputs[0], attrs.get_int("stride")) -@reg.register_schedule("yolo2_reorg") +@reg.register_schedule("yolo_reorg") def schedule_reorg(attrs, outs, target): """Schedule definition of reorg""" with tvm.target.create(target): return topi.generic.schedule_injective(outs) -reg.register_pattern("yolo2_reorg", OpPattern.INJECTIVE) +reg.register_pattern("yolo_reorg", OpPattern.INJECTIVE) -@reg.register_compute("yolo2_region") +@reg.register_compute("yolo_region") def compute_region(attrs, inputs, _): """Compute definition of region""" n = attrs.get_int("n") @@ -28,15 +28,15 @@ def compute_region(attrs, inputs, _): coords = attrs.get_int("coords") background = attrs.get_int("background") softmax = attrs.get_int("softmax") - return topi.vision.yolo2.region(inputs[0], n, classes, coords, background, softmax) + return topi.vision.yolo.region(inputs[0], n, classes, coords, background, softmax) -@reg.register_schedule("yolo2_region") +@reg.register_schedule("yolo_region") def schedule_region(attrs, outs, target): """Schedule definition of region""" with tvm.target.create(target): return topi.generic.vision.schedule_region(outs) -reg.register_pattern("yolo2_region", OpPattern.OPAQUE) +reg.register_pattern("yolo_region", OpPattern.OPAQUE) # multibox_prior @reg.register_schedule("multibox_prior") diff --git a/nnvm/src/top/vision/yolo2/region.cc b/nnvm/src/top/vision/yolo/region.cc similarity index 96% rename from nnvm/src/top/vision/yolo2/region.cc rename to nnvm/src/top/vision/yolo/region.cc index 87860be3d..182c9b2ab 100644 --- a/nnvm/src/top/vision/yolo2/region.cc +++ b/nnvm/src/top/vision/yolo/region.cc @@ -13,7 +13,7 @@ namespace nnvm { namespace top { -NNVM_REGISTER_OP(yolo2_region) +NNVM_REGISTER_OP(yolo_region) .describe(R"code(Region layer )code" NNVM_ADD_FILELINE) .set_num_inputs(1) diff --git a/nnvm/src/top/vision/yolo2/region.h b/nnvm/src/top/vision/yolo/region.h similarity index 96% rename from nnvm/src/top/vision/yolo2/region.h rename to nnvm/src/top/vision/yolo/region.h index cc816eab6..f9dc87c59 100644 --- a/nnvm/src/top/vision/yolo2/region.h +++ b/nnvm/src/top/vision/yolo/region.h @@ -2,8 +2,8 @@ * Copyright (c) 2018 by Contributors * \file region.h */ -#ifndef NNVM_TOP_VISION_YOLO2_REGION_H_ -#define NNVM_TOP_VISION_YOLO2_REGION_H_ +#ifndef NNVM_TOP_VISION_YOLO_REGION_H_ +#define NNVM_TOP_VISION_YOLO_REGION_H_ #include <string> #include <vector> @@ -98,4 +98,4 @@ inline bool RegionType(const NodeAttrs &attrs, } } // namespace top } // namespace nnvm -#endif // NNVM_TOP_VISION_YOLO2_REGION_H_ +#endif // NNVM_TOP_VISION_YOLO_REGION_H_ diff --git a/nnvm/src/top/vision/yolo2/reorg.cc b/nnvm/src/top/vision/yolo/reorg.cc similarity index 98% rename from nnvm/src/top/vision/yolo2/reorg.cc rename to nnvm/src/top/vision/yolo/reorg.cc index e58940eb2..e44d77c07 100644 --- a/nnvm/src/top/vision/yolo2/reorg.cc +++ b/nnvm/src/top/vision/yolo/reorg.cc @@ -34,7 +34,7 @@ inline bool ReorgInferShape(const nnvm::NodeAttrs &attrs, return true; } -NNVM_REGISTER_OP(yolo2_reorg) +NNVM_REGISTER_OP(yolo_reorg) .describe(R"(Perform reorg operation on input array based on the stride value. - **data**: Input is 4D array of shape (batch_size, channels, in_height, in_width). - **out**: Output is 4D array of shape (batch_size, channels/(stride*stride), in_height*stride, in_width*stride). diff --git a/nnvm/src/top/vision/yolo2/reorg.h b/nnvm/src/top/vision/yolo/reorg.h similarity index 96% rename from nnvm/src/top/vision/yolo2/reorg.h rename to nnvm/src/top/vision/yolo/reorg.h index 87e0510e2..a16edecea 100644 --- a/nnvm/src/top/vision/yolo2/reorg.h +++ b/nnvm/src/top/vision/yolo/reorg.h @@ -2,8 +2,8 @@ * Copyright (c) 2018 by Contributors * \file reorg.h */ -#ifndef NNVM_TOP_VISION_YOLO2_REORG_H_ -#define NNVM_TOP_VISION_YOLO2_REORG_H_ +#ifndef NNVM_TOP_VISION_YOLO_REORG_H_ +#define NNVM_TOP_VISION_YOLO_REORG_H_ #include <string> #include <vector> @@ -107,4 +107,4 @@ struct ReorgParam : public dmlc::Parameter<ReorgParam> { }; } // namespace top } // namespace nnvm -#endif // NNVM_TOP_VISION_YOLO2_REORG_H_ +#endif // NNVM_TOP_VISION_YOLO_REORG_H_ diff --git a/topi/include/topi/vision/yolo2/region.h b/topi/include/topi/vision/yolo/region.h similarity index 92% rename from topi/include/topi/vision/yolo2/region.h rename to topi/include/topi/vision/yolo/region.h index a77ff49b5..88553fc29 100644 --- a/topi/include/topi/vision/yolo2/region.h +++ b/topi/include/topi/vision/yolo/region.h @@ -1,10 +1,10 @@ /*! * Copyright (c) 2018 by Contributors * \brief Region op constructions - * \file vision/yolo2/region.h + * \file vision/yolo/region.h */ -#ifndef TOPI_VISION_YOLO2_REGION_H_ -#define TOPI_VISION_YOLO2_REGION_H_ +#ifndef TOPI_VISION_YOLO_REGION_H_ +#define TOPI_VISION_YOLO_REGION_H_ #include <algorithm> #include <string> @@ -19,7 +19,7 @@ namespace topi { namespace vision { -namespace yolo2 { +namespace yolo { using namespace tvm; using namespace nn; @@ -75,7 +75,7 @@ inline Tensor region(const Tensor &data, Tensor out = concatenate(split_res, 2); return reshape(out, input_shape); } -} // namespace yolo2 +} // namespace yolo } // namespace vision } // namespace topi -#endif // TOPI_VISION_YOLO2_REGION_H_ +#endif // TOPI_VISION_YOLO_REGION_H_ diff --git a/topi/include/topi/vision/yolo/yolo.h b/topi/include/topi/vision/yolo/yolo.h new file mode 100644 index 000000000..d2e24c01b --- /dev/null +++ b/topi/include/topi/vision/yolo/yolo.h @@ -0,0 +1,58 @@ +/*! + * Copyright (c) 2018 by Contributors + * \brief YOLO op constructions + * \file vision/yolo/yolo.h + */ +#ifndef TOPI_VISION_YOLO_YOLO_H_ +#define TOPI_VISION_YOLO_YOLO_H_ + +#include <algorithm> +#include <string> + +#include "topi/detail/constant_utils.h" +#include "topi/tags.h" +#include "topi/transform.h" +#include "tvm/tvm.h" + + +namespace topi { +namespace vision { +namespace yolo { +using namespace tvm; +using namespace nn; + +/*! +* \brief yolo operation +* +* \param data The input tensor. +* \param num Darknet layer parameter n +* \param classes number of classes in the yolo model +* \param name The name of the operation +* \param tag The tag to mark the operation +* +* \return A Tensor whose op member is the yolo operation +*/ +inline Tensor yolo(const Tensor &data, + int num, + int classes, + std::string name = "tensor", + std::string tag = "yolo_output") { + auto input_shape = data->shape; + int split_size = classes + 5; + Array <Expr> intermediate_shape = {input_shape[0], + num, + split_size, + input_shape[2], + input_shape[3]}; + auto data_block = reshape(data, intermediate_shape); + Array <Expr> split_indices = {2, 4}; + Array <Tensor> split_res = split(data_block, split_indices, 2); + split_res.Set(0, sigmoid(split_res[0])); + split_res.Set(2, sigmoid(split_res[2])); + Tensor out = concatenate(split_res, 2); + return reshape(out, input_shape); +} +} // namespace yolo +} // namespace vision +} // namespace topi +#endif // TOPI_VISION_YOLO_YOLO_H_ diff --git a/topi/python/topi/cpp.py b/topi/python/topi/cpp.py index cf7314be4..85f203387 100644 --- a/topi/python/topi/cpp.py +++ b/topi/python/topi/cpp.py @@ -46,7 +46,7 @@ x86 = _create_module("x86") _init_api_prefix("topi.cpp.x86", "topi.x86") vision = _create_module("vision") _init_api_prefix("topi.cpp.vision", "topi.vision") -yolo2 = _create_module("vision.yolo2") -_init_api_prefix("topi.cpp.vision.yolo2", "topi.vision.yolo2") +yolo = _create_module("vision.yolo") +_init_api_prefix("topi.cpp.vision.yolo", "topi.vision.yolo") image = _create_module("image") _init_api_prefix("topi.cpp.image", "topi.image") diff --git a/topi/python/topi/testing/__init__.py b/topi/python/topi/testing/__init__.py index c91eea795..c9d995a38 100644 --- a/topi/python/topi/testing/__init__.py +++ b/topi/python/topi/testing/__init__.py @@ -15,6 +15,7 @@ from .upsampling_python import upsampling_python from .bilinear_resize_python import bilinear_resize_python from .reorg_python import reorg_python from .region_python import region_python +from .yolo_python import yolo_python from .shortcut_python import shortcut_python from .lrn_python import lrn_python from .l2_normalize_python import l2_normalize_python diff --git a/topi/python/topi/testing/yolo_python.py b/topi/python/topi/testing/yolo_python.py new file mode 100644 index 000000000..a6b3a4120 --- /dev/null +++ b/topi/python/topi/testing/yolo_python.py @@ -0,0 +1,43 @@ +# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals +"""Yolo operator in python""" +import numpy as np + +def entry_index(batch, w, h, outputs, classes, coords, location, entry): + n = int(location/(w*h)) + loc = location%(w*h) + return batch*outputs + n*w*h*(coords+classes+1) + entry*w*h + loc + +def yolo_python(a_np, N, classes): + """Yolo operator + Parameters + ---------- + a_np : numpy.ndarray + 4-D with shape [batch, in_channel, in_height, in_width] + + N : int + Darknet layer parameter n + + classes : int + Darknet layer parameter classes + + Returns + ------- + b_np : np.ndarray + 4-D with shape [batch, out_channel, out_height, out_width] + """ + + batch, in_channel, in_height, in_width = a_np.shape + a_np_temp = np.reshape(a_np, batch*in_channel*in_height*in_width) + outputs = batch*in_channel*in_height*in_width + b_np = np.zeros(batch*in_channel*in_height*in_width) + for i in range(batch*in_channel*in_height*in_width): + b_np[i] = a_np_temp[i] + for b in range(batch): + for n in range(N): + index = entry_index(b, in_width, in_height, outputs, classes, 4, n*in_width*in_height, 0) + b_np[index: index+2*in_width*in_height] = 1/(1+np.exp(-1*b_np[index: index+2*in_width*in_height])) + index = entry_index(b, in_width, in_height, outputs, classes, 4, n*in_width*in_height, 4) + b_np[index: index+(1+classes)*in_width*in_height] = 1/(1+np.exp(-1*b_np[index: index+(1+classes)*in_width*in_height])) + + b_np = np.reshape(b_np, (batch, in_channel, in_height, in_width)) + return b_np diff --git a/topi/python/topi/vision/__init__.py b/topi/python/topi/vision/__init__.py index 9b6d9bf5e..a94cd9c7c 100644 --- a/topi/python/topi/vision/__init__.py +++ b/topi/python/topi/vision/__init__.py @@ -2,7 +2,7 @@ """VISION network operators""" from __future__ import absolute_import as _abs -from . import yolo2, ssd +from . import yolo, ssd from .shortcut import * from .reorg import * from .nms import * diff --git a/topi/python/topi/vision/yolo2/__init__.py b/topi/python/topi/vision/yolo/__init__.py similarity index 87% rename from topi/python/topi/vision/yolo2/__init__.py rename to topi/python/topi/vision/yolo/__init__.py index c0e9899a4..2c0a165f8 100644 --- a/topi/python/topi/vision/yolo2/__init__.py +++ b/topi/python/topi/vision/yolo/__init__.py @@ -3,3 +3,4 @@ from __future__ import absolute_import as _abs from .region import * +from .yolo import * diff --git a/topi/python/topi/vision/yolo2/region.py b/topi/python/topi/vision/yolo/region.py similarity index 91% rename from topi/python/topi/vision/yolo2/region.py rename to topi/python/topi/vision/yolo/region.py index 79dfd6961..77c1c86a8 100644 --- a/topi/python/topi/vision/yolo2/region.py +++ b/topi/python/topi/vision/yolo/region.py @@ -36,4 +36,4 @@ def region(data, num, classes, coords, background, softmax=True): out : tvm.Tensor 4-D with shape [batch, c_in, h_in, w_in] """ - return cpp.yolo2.region(data, num, classes, coords, background, softmax) + return cpp.yolo.region(data, num, classes, coords, background, softmax) diff --git a/topi/python/topi/vision/yolo/yolo.py b/topi/python/topi/vision/yolo/yolo.py new file mode 100644 index 000000000..6ae630a86 --- /dev/null +++ b/topi/python/topi/vision/yolo/yolo.py @@ -0,0 +1,30 @@ +# pylint: disable=invalid-name, unused-variable +""" +YOLO Operator +============= +YOLO operator, used in darknet. +""" +from __future__ import absolute_import as _abs +import tvm +from ... import cpp + +@tvm.target.generic_func +def yolo(data, num, classes): + """YOLO forward operators. + Parameters + ---------- + data : tvm.Tensor + 4-D with shape [batch, c_in, h_in, w_in] + + num : int + Darknet layer parameter n + + classes : int + Darknet layer parameter classes + + Returns + ------- + out : tvm.Tensor + 4-D with shape [batch, c_in, h_in, w_in] + """ + return cpp.yolo.yolo(data, num, classes) diff --git a/topi/src/topi.cc b/topi/src/topi.cc index a4a87b309..c08bd5f56 100644 --- a/topi/src/topi.cc +++ b/topi/src/topi.cc @@ -29,7 +29,8 @@ #include <topi/vision/reorg.h> #include <topi/image/resize.h> -#include <topi/vision/yolo2/region.h> +#include <topi/vision/yolo/region.h> +#include <topi/vision/yolo/yolo.h> #include <topi/generic/default.h> #include <topi/generic/extern.h> #include <topi/generic/injective.h> @@ -386,9 +387,14 @@ TVM_REGISTER_GLOBAL("topi.vision.reorg") *rv = vision::reorg(args[0], args[1]); }); -TVM_REGISTER_GLOBAL("topi.vision.yolo2.region") +TVM_REGISTER_GLOBAL("topi.vision.yolo.region") .set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = vision::yolo2::region(args[0], args[1], args[2], args[3], args[4], args[5]); + *rv = vision::yolo::region(args[0], args[1], args[2], args[3], args[4], args[5]); + }); + +TVM_REGISTER_GLOBAL("topi.vision.yolo.yolo") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = vision::yolo::yolo(args[0], args[1], args[2]); }); /* Ops from image/resize.h */ diff --git a/topi/tests/python/test_topi_region.py b/topi/tests/python/test_topi_region.py index f83001535..a2835339e 100644 --- a/topi/tests/python/test_topi_region.py +++ b/topi/tests/python/test_topi_region.py @@ -10,7 +10,7 @@ def verify_region(batch, in_size, in_channel, n, classes, coords, background, l_ in_height = in_width = in_size A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A') - B = topi.vision.yolo2.region(A, n, classes, coords, background, l_softmax) + B = topi.vision.yolo.region(A, n, classes, coords, background, l_softmax) a_shape = get_const_tuple(A.shape) dtype = A.dtype diff --git a/topi/tests/python_cpp/test_topi_region.py b/topi/tests/python_cpp/test_topi_region.py index 8d40fa1e6..a37cf6610 100644 --- a/topi/tests/python_cpp/test_topi_region.py +++ b/topi/tests/python_cpp/test_topi_region.py @@ -11,7 +11,7 @@ def verify_region(batch, in_size, in_channel, n, classes, coords, background, l_ in_height = in_width = in_size A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A') - B = topi.cpp.yolo2.region(A, n, classes, coords, background, l_softmax) + B = topi.cpp.yolo.region(A, n, classes, coords, background, l_softmax) a_shape = get_const_tuple(A.shape) dtype = A.dtype diff --git a/topi/tests/python_cpp/test_topi_yolo.py b/topi/tests/python_cpp/test_topi_yolo.py new file mode 100644 index 000000000..ed234b7bd --- /dev/null +++ b/topi/tests/python_cpp/test_topi_yolo.py @@ -0,0 +1,49 @@ +"""Test code for yolo op""" +import logging +import numpy as np +import tvm +import topi +import topi.testing +from topi.util import get_const_tuple + +def verify_yolo(ishape, n, classes): + '''Verify yolo operator by comparing outputs from tvm and numpy implementation''' + + A = tvm.placeholder(ishape, name='A') + B = topi.cpp.yolo.yolo(A, n, classes) + dtype = A.dtype + + def get_ref_data_yolo(): + '''Randomly initialize the data variables and get refernce output for the yolo operation''' + a_np = np.random.uniform(size=ishape).astype(dtype) + b_np = topi.testing.yolo_python(a_np, n, classes) + return a_np, b_np + + a_np, b_np = get_ref_data_yolo() + def check_device(device): + '''Check the device is available and if so, build and run the program''' + if not tvm.module.enabled(device): + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + target = topi.cpp.TEST_create_target(device) + if device == "llvm": + s = topi.cpp.generic.default_schedule(target, [B], False) + else: + s = topi.cpp.cuda.schedule_injective(target, [B]) + ctx = tvm.context(device, 0) + a = tvm.nd.array(a_np, ctx) + b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) + func = tvm.build(s, [A, B], device, name="yolo") + func(a, b) + np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) + + for device in ['cuda', 'opencl', 'metal', 'rocm', 'llvm', 'vulkan']: + check_device(device) + +def test_yolo(): + verify_yolo((1, 425, 19, 19), 5, 80) + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + test_yolo() -- GitLab