From 7ec898d55ed628cb06cef6448e715b818eeb75d5 Mon Sep 17 00:00:00 2001
From: Pariksheet Pinjari <pariksheet.pinjari@huawei.com>
Date: Fri, 6 Apr 2018 10:23:29 +0530
Subject: [PATCH] [FRONTEND] DarkNet Yolo2 Frontend Support  (#377)

---
 nnvm/Makefile                                 |   2 +-
 nnvm/python/nnvm/frontend/__init__.py         |   1 +
 nnvm/python/nnvm/frontend/darknet.py          | 637 ++++++++++++++++++
 nnvm/python/nnvm/testing/__init__.py          |   2 +
 nnvm/python/nnvm/testing/darknet.py           | 494 ++++++++++++++
 nnvm/python/nnvm/testing/yolo2_detection.py   | 246 +++++++
 nnvm/python/nnvm/top/__init__.py              |   1 +
 nnvm/python/nnvm/top/vision.py                |  40 ++
 nnvm/src/top/vision/yolo2/region.cc           |  35 +
 nnvm/src/top/vision/yolo2/region.h            | 101 +++
 nnvm/src/top/vision/yolo2/reorg.cc            |  52 ++
 nnvm/src/top/vision/yolo2/reorg.h             | 110 +++
 nnvm/tests/ci_build/Dockerfile.gpu            |   3 +
 .../install/ubuntu_install_darknet.sh         |   4 +
 .../python/frontend/darknet/test_forward.py   | 257 +++++++
 nnvm/tutorials/from_darknet.py                | 227 +++++++
 16 files changed, 2211 insertions(+), 1 deletion(-)
 create mode 100644 nnvm/python/nnvm/frontend/darknet.py
 create mode 100644 nnvm/python/nnvm/testing/darknet.py
 create mode 100644 nnvm/python/nnvm/testing/yolo2_detection.py
 create mode 100644 nnvm/python/nnvm/top/vision.py
 create mode 100644 nnvm/src/top/vision/yolo2/region.cc
 create mode 100644 nnvm/src/top/vision/yolo2/region.h
 create mode 100644 nnvm/src/top/vision/yolo2/reorg.cc
 create mode 100644 nnvm/src/top/vision/yolo2/reorg.h
 create mode 100644 nnvm/tests/ci_build/install/ubuntu_install_darknet.sh
 create mode 100644 nnvm/tests/python/frontend/darknet/test_forward.py
 create mode 100644 nnvm/tutorials/from_darknet.py

diff --git a/nnvm/Makefile b/nnvm/Makefile
index 4779e95b3..62a4fadad 100644
--- a/nnvm/Makefile
+++ b/nnvm/Makefile
@@ -56,7 +56,7 @@ endif
 all: lib/libnnvm.a lib/libnnvm_compiler.$(SHARED_LIBRARY_SUFFIX)
 
 SRC = $(wildcard src/*.cc src/c_api/*.cc src/core/*.cc src/pass/*.cc)
-SRC_COMPILER = $(wildcard src/top/*/*.cc src/compiler/*.cc src/compiler/*/*.cc)
+SRC_COMPILER = $(wildcard src/top/*/*.cc src/top/vision/*/*.cc src/compiler/*.cc src/compiler/*/*.cc)
 ALL_OBJ = $(patsubst %.cc, build/%.o, $(SRC))
 TOP_OBJ = $(patsubst %.cc, build/%.o, $(SRC_COMPILER))
 ALL_DEP = $(ALL_OBJ)
diff --git a/nnvm/python/nnvm/frontend/__init__.py b/nnvm/python/nnvm/frontend/__init__.py
index 100d4115b..00ed9e51f 100644
--- a/nnvm/python/nnvm/frontend/__init__.py
+++ b/nnvm/python/nnvm/frontend/__init__.py
@@ -4,3 +4,4 @@ from .mxnet import from_mxnet
 from .onnx import from_onnx
 from .coreml import from_coreml
 from .keras import from_keras
+from .darknet import from_darknet
diff --git a/nnvm/python/nnvm/frontend/darknet.py b/nnvm/python/nnvm/frontend/darknet.py
new file mode 100644
index 000000000..413b07d64
--- /dev/null
+++ b/nnvm/python/nnvm/frontend/darknet.py
@@ -0,0 +1,637 @@
+"""
+DarkNet symbol frontend.
+"""
+
+from __future__ import absolute_import as _abs
+from enum import IntEnum
+import numpy as np
+import tvm
+from .. import symbol as _sym
+
+class LAYERTYPE(IntEnum):
+    """Darknet LAYERTYPE Class constant."""
+    CONVOLUTIONAL = 0
+    DECONVOLUTIONAL = 1
+    CONNECTED = 2
+    MAXPOOL = 3
+    SOFTMAX = 4
+    DETECTION = 5
+    DROPOUT = 6
+    CROP = 7
+    ROUTE = 8
+    COST = 9
+    NORMALIZATION = 10
+    AVGPOOL = 11
+    LOCAL = 12
+    SHORTCUT = 13
+    ACTIVE = 14
+    RNN = 15
+    GRU = 16
+    LSTM = 17
+    CRNN = 18
+    BATCHNORM = 19
+    NETWORK = 20
+    XNOR = 21
+    REGION = 22
+    REORG = 23
+    BLANK = 24
+
+class ACTIVATION(IntEnum):
+    """Darknet ACTIVATION Class constant."""
+    LOGISTIC = 0
+    RELU = 1
+    RELIE = 2
+    LINEAR = 3
+    RAMP = 4
+    TANH = 5
+    PLSE = 6
+    LEAKY = 7
+    ELU = 8
+    LOGGY = 9
+    STAIR = 10
+    HARDTAN = 11
+    LHTAN = 12
+
+__all__ = ['from_darknet']
+
+def _darknet_get_nnvm_op(op_name):
+    """Get the nnvm operation from opname, raise error if not supported."""
+    op = getattr(_sym, op_name)
+    if not op:
+        raise RuntimeError("Not to map op_name {} to nnvm.sym".format(op_name))
+    return op
+
+def _darknet_required_attr(attr, key):
+    """Check the attribute exists and return if exists, if not return error."""
+    assert isinstance(attr, dict)
+    if key not in attr:
+        raise AttributeError("Required attribute {} not found.".format(key))
+    return attr[key]
+
+def _darknet_raise_not_supported(attr, op='nnvm'):
+    """Raise error if any operation is not supported."""
+    err = "{} is not supported in {}.".format(attr, op)
+    raise NotImplementedError(err)
+
+def _darknet_warn_not_used(attr, op='nnvm'):
+    """Raise warning if any operation not supported."""
+    import warnings
+    err = "{} is ignored in {}.".format(attr, op)
+    warnings.warn(err)
+
+def _darknet_parse_tshape(tshape):
+    """Parse tshape in string."""
+    return [int(x.strip()) for x in tshape.strip('()').split(',')]
+
+def _darknet_parse_bool_str(attr, key, default='False'):
+    """Parse bool string to boolean."""
+    return attr.get(key, default).strip().lower() in \
+                                    ['true', '1', 't', 'y', 'yes']
+
+def _darknet_maxpooling(inputs, attrs):
+    """Process the max pool 2d operation."""
+    kernel = _darknet_parse_tshape(_darknet_required_attr(attrs, 'kernel'))
+    if len(kernel) != 1:
+        _darknet_raise_not_supported('non-2d kernel', 'pool_2d')
+
+    op_name, new_attrs = 'max_pool2d', {}
+    strides = int(attrs.get('stride', 1))
+    pads = int(attrs.get('pad', 0))
+    new_attrs['pool_size'] = [kernel[0], kernel[0]]
+    new_attrs['strides'] = str((strides, strides))
+    new_attrs['padding'] = str((pads, pads))
+
+    return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None
+
+def _darknet_avgpooling(inputs, attrs):
+    """Process the average pool 2d operation."""
+    kernel = _darknet_parse_tshape(_darknet_required_attr(attrs, 'kernel'))
+    if len(kernel) != 1:
+        _darknet_raise_not_supported('non-2d kernel', 'pool_2d')
+
+    op_name, new_attrs = 'avg_pool2d', {}
+    strides = int(attrs.get('stride', 1))
+    pads = int(attrs.get('pad', 0))
+    new_attrs['pool_size'] = [kernel[0], kernel[0]]
+    new_attrs['strides'] = str((strides, strides))
+    new_attrs['padding'] = str((pads, pads))
+
+    return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None
+
+def _darknet_batch_norm(inputs, attrs):
+    """Process the batchnormalization operation."""
+    op_name, new_attrs = 'darknet_batch_norm', {}
+    new_attrs['axis'] = attrs.get('axis', 1)
+    new_attrs['epsilon'] = attrs.get('eps', 0.000001)
+    new_attrs['center'] = True
+    new_attrs['scale'] = True
+    return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None
+
+def _darknet_conv2d(inputs, attrs):
+    """Process the convolution 2d operation."""
+    kernel = _darknet_parse_tshape(_darknet_required_attr(attrs, 'kernel'))
+    if len(kernel) != 1:
+        _darknet_raise_not_supported('non 2d kernel', 'conv2d')
+    layout = attrs.get('layout', 'NCHW')
+    if layout not in ['NCHW', 'NHWC']:
+        _darknet_raise_not_supported('layout: ' + layout, 'conv2d')
+    strides = int(attrs.get('stride', 1))
+    pads = int(attrs.get('pad', 0))
+
+    op_name, new_attrs = 'conv2d', {}
+    new_attrs['channels'] = _darknet_required_attr(attrs, 'num_filter')
+    new_attrs['kernel_size'] = [kernel[0], kernel[0]]
+    new_attrs['strides'] = (strides, strides)
+    new_attrs['padding'] = (pads, pads)
+    new_attrs['dilation'] = attrs.get('dilate', (1, 1))
+    new_attrs['groups'] = attrs.get('num_group', 1)
+    new_attrs['layout'] = layout
+    if attrs.get('use_batchNorm', False) is True:
+        new_attrs['use_bias'] = False
+    else:
+        new_attrs['use_bias'] = True
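+    # Note: when batch norm follows the convolution, darknet folds the conv
+    # bias into the batch-norm beta; _get_convolution_weights loads it that way.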
+    out_name = {}
+    sym = _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs)
+    out_name[0] = sym.list_output_names()[0].replace('_output', '')
+
+    if attrs.get('use_batchNorm', False) is True:
+        op_name, new_attrs = 'batch_norm', {}
+        new_attrs['epsilon'] = 0.000001
+        sym = _darknet_get_nnvm_op(op_name)(*sym, **new_attrs)
+        out_name[1] = sym.list_output_names()[0].replace('_output', '')
+    if 'activation' in attrs:
+        new_attrs = {}
+        new_attrs['activation'] = attrs['activation']
+        new_attrs['slope'] = 0.1
+        sym, _ = _darknet_activations(sym, new_attrs)
+    return sym, out_name
+
+
+def _darknet_conv2d_transpose(inputs, attrs):
+    """Process the convolution 2d transpose operation."""
+    if 'target_shape' in attrs:
+        _darknet_raise_not_supported('target_shape', 'conv2d_transpose')
+    kernel = _darknet_parse_tshape(_darknet_required_attr(attrs, 'kernel'))
+    if len(kernel) != 2:
+        _darknet_raise_not_supported('non-2d kernel', 'conv2d_transpose')
+    layout = attrs.get('layout', 'NCHW')
+    if layout not in ['NCHW', 'NHWC']:
+        _darknet_raise_not_supported('layout: ' + layout, 'conv2d_transpose')
+    op_name, new_attrs = 'conv2d_transpose', {}
+    new_attrs['channels'] = _darknet_required_attr(attrs, 'num_filter')
+    new_attrs['kernel_size'] = kernel
+    new_attrs['strides'] = attrs.get('stride', (1, 1))
+    new_attrs['output_padding'] = attrs.get('adj', (0, 0))
+    new_attrs['padding'] = attrs.get('pad', (0, 0))
+    new_attrs['dilation'] = attrs.get('dilate', (1, 1))
+    new_attrs['groups'] = attrs.get('num_group', 1)
+    new_attrs['layout'] = layout
+    new_attrs['use_bias'] = not _darknet_parse_bool_str(attrs, 'no_bias')
+    return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None
+
+def _darknet_shortcut(inputs, attrs):
+    """Process the shortcut operation."""
+    op_name, new_attrs = 'elemwise_add', {}
+    input_0 = inputs[0]
+    input_1 = inputs[1]
+    input_0_channel = int(attrs['out_channel'])
+    input_1_channel = int(attrs['add_out_channel'])
+    input_0_size = int(attrs['out_size'])
+    input_1_size = int(attrs['add_out_size'])
+
+    if input_0_size > input_1_size:
+        scale = int(input_0_size/input_1_size)
+        input_1 = _sym.upsampling(input_1, scale=scale, name="_upsampling")
+    elif input_0_size < input_1_size:
+        stride = int(input_1_size/input_0_size)
+        input_1 = _sym.avg_pool2d(input_1, pool_size=(1, 1),
+                                  strides=(stride, stride), padding=(0, 0), name="_downsampling")
+
+    if input_0_channel != input_1_channel:
+        pad_channel = input_0_channel - input_1_channel
+        input_1 = _sym.pad(input_1, pad_width=((0, 0), (0, pad_channel), (0, 0), (0, 0)),
+                           pad_value=0.)
+
+    new_inputs = _as_list([input_0, input_1])
+    sym = _darknet_get_nnvm_op(op_name)(*new_inputs, **new_attrs)
+    out_name = sym.list_output_names()[0].replace('_output', '')
+    if 'activation' in attrs:
+        new_attrs['activation'] = attrs['activation']
+        sym, _ = _darknet_activations(sym, new_attrs)
+    return sym, out_name
+
+def _darknet_dense(inputs, attrs):
+    """Process the dense operation."""
+    op_name, new_attrs = 'dense', {}
+    new_attrs['units'] = _darknet_required_attr(attrs, 'num_hidden')
+
+    if attrs.get('use_bias', False) is True:
+        new_attrs['use_bias'] = True
+    if attrs.get('use_flatten', False) is True:
+        inputs[0] = _sym.flatten(inputs[0])
+    sym = _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs)
+    out_name = sym.list_output_names()[0].replace('_output', '')
+    if 'activation' in attrs:
+        new_attrs = {}
+        new_attrs['activation'] = attrs['activation']
+        sym, _ = _darknet_activations(sym, new_attrs)
+    return sym, out_name
+
+def _darknet_dropout(inputs, attrs):
+    """Process the dropout operation, its a blank operation."""
+    op_name, new_attrs = 'dropout', {}
+    new_attrs['rate'] = attrs.get('p', 0.5)
+    return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None
+
+def _darknet_reshape(inputs, attrs):
+    """Process the reshape operation."""
+    if _darknet_parse_bool_str(attrs, 'reverse'):
+        _darknet_raise_not_supported('reverse', 'reshape')
+    op_name, new_attrs = 'reshape', {}
+    new_attrs['shape'] = _darknet_required_attr(attrs, 'shape')
+    return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None
+
+def _darknet_softmax_output(inputs, attrs):
+    """Process the softmax operation."""
+    op_name, new_attrs = 'softmax', {}
+    if _darknet_parse_bool_str(attrs, 'multi_output'):
+        new_attrs['axis'] = 1
+
+    if attrs.get('use_flatten', False) is True:
+        inputs[0] = _sym.flatten(inputs[0])
+    return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None
+
+def _darknet_route(inputs, attrs):
+    """Process the route operation, which is equivalent to concat."""
+    op_name = 'concatenate'
+    new_attrs = {'axis': attrs.get('dim', 1)}
+    return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None
+
+def _darknet_reorg(inputs, attrs):
+    """Process the reorg operation."""
+    op_name, new_attrs = 'yolo2_reorg', {}
+    if 'stride' in attrs:
+        new_attrs = {'stride': attrs.get('stride', 1)}
+    return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None
+
+def _darknet_region(inputs, attrs):
+    """Process the region operation."""
+    op_name, new_attrs = 'yolo2_region', {}
+    if 'n' in attrs:
+        new_attrs['n'] = attrs.get('n', 1)
+    if 'classes' in attrs:
+        new_attrs['classes'] = attrs.get('classes', 1)
+    if 'coords' in attrs:
+        new_attrs['coords'] = attrs.get('coords', 0)
+    if 'background' in attrs:
+        new_attrs['background'] = attrs.get('background', 0)
+    if 'softmax' in attrs:
+        new_attrs['softmax'] = attrs.get('softmax', 0)
+    return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None
+
+def _darknet_activations(inputs, attrs):
+    """Process the activation function."""
+    act = _darknet_required_attr(attrs, 'activation')
+    if ACTIVATION.RELU == act:
+        act_type = 'relu'
+    elif ACTIVATION.TANH == act:
+        act_type = 'tanh'
+    elif ACTIVATION.LINEAR == act:
+        return inputs, None
+    elif ACTIVATION.LEAKY == act:
+        act_type = 'leaky_relu'
+    else:
+        _darknet_raise_not_supported('act: ' + act)
+
+    if act_type in ['relu', 'tanh']:
+        op_name, new_attrs = act_type, {}
+        sym = _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs)
+    elif act_type in ['leaky_relu']:
+        op_name, new_attrs = act_type, {}
+        new_attrs['alpha'] = attrs.get('slope', 0.1)
+        sym = _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs)
+    else:
+        _darknet_raise_not_supported('act_type: ' + act_type)
+    return sym, None
+
+def _darknet_op_not_support(inputs, attrs):
+    """Raise exception if the operation is not supported."""
+    err = "{} is not supported in {}.".format(attrs, inputs)
+    raise NotImplementedError(err)
+
+_DARKNET_CONVERT_MAP = {
+    'CONVOLUTIONAL'   : _darknet_conv2d,
+    'DECONVOLUTIONAL' : _darknet_conv2d_transpose,
+    'CONNECTED'       : _darknet_dense,
+    'MAXPOOL'         : _darknet_maxpooling,
+    'SOFTMAX'         : _darknet_softmax_output,
+    'DROPOUT'         : _darknet_dropout,
+    'AVGPOOL'         : _darknet_avgpooling,
+    'BATCHNORM'       : _darknet_batch_norm,
+    'RESHAPE'         : _darknet_reshape,
+    'ROUTE'           : _darknet_route,
+    'REORG'           : _darknet_reorg,
+    'REGION'          : _darknet_region,
+    'ACTIVATION'      : _darknet_activations,
+    'SHORTCUT'        : _darknet_shortcut,
+    'DETECTION'       : _darknet_op_not_support,
+    'CROP'            : _darknet_op_not_support,
+    'COST'            : _darknet_op_not_support,
+    'NORMALIZATION'   : _darknet_op_not_support,
+    'LOCAL'           : _darknet_op_not_support,
+    'ACTIVE'          : _darknet_op_not_support,
+    'RNN'             : _darknet_op_not_support,
+    'GRU'             : _darknet_op_not_support,
+    'LSTM'            : _darknet_op_not_support,
+    'CRNN'            : _darknet_op_not_support,
+    'NETWORK'         : _darknet_op_not_support,
+    'XNOR'            : _darknet_op_not_support,
+    'BLANK'           : _darknet_op_not_support,
+}
+
+def _darknet_convert_symbol(op_name, inputs, attrs):
+    """Convert from darknet op to nnvm op.
+    The converter must specify some conversions explicitly to
+    support darknet ops such as convolutional, connected, etc.
+
+    Parameters
+    ----------
+    op_name : str
+        Operator name, such as Convolution, Connected, etc
+    inputs : list of nnvm.Symbol
+        List of input symbols.
+    attrs : dict
+        Dict of operator attributes
+
+    Returns
+    -------
+    out_name : str
+        Name of the output of the converted operation
+    sym : nnvm.Symbol
+        Converted nnvm Symbol
+    """
+
+    if op_name in _DARKNET_CONVERT_MAP:
+        sym, out_name = _DARKNET_CONVERT_MAP[op_name](inputs, attrs)
+    else:
+        _darknet_raise_not_supported('Operator: ' + op_name)
+    if out_name is None:
+        out_name = sym.list_output_names()[0].replace('_output', '')
+    return out_name, sym
+
+
+def _as_list(arr):
+    """Force being a list, ignore if already is."""
+    if isinstance(arr, list):
+        return arr
+    return [arr]
+
+def _read_memory_buffer(shape, data, dtype):
+    length = 1
+    for x in shape:
+        length *= x
+    data_np = np.zeros(length, dtype=dtype)
+    for i in range(length):
+        data_np[i] = data[i]
+    return data_np.reshape(shape)
+
+def _get_darknet_layername(layer_type):
+    """Get the layer name from the darknet enums."""
+    return str(LAYERTYPE(layer_type)).replace('LAYERTYPE.', '')
+
+def _get_convolution_weights(layer, opname, params, dtype):
+    """Get the convolution layer weights and biases."""
+    if layer.nweights == 0:
+        return
+
+    if (layer.n * layer.c * layer.size * layer.size) != layer.nweights:
+        raise RuntimeError("layer weights size not matching with n c h w")
+
+    weights = _read_memory_buffer((layer.n, layer.c, layer.size, layer.size), layer.weights, dtype)
+
+    biases = _read_memory_buffer((layer.n, ), layer.biases, dtype)
+
+    k = _get_tvm_params_name(opname[0], 'weight')
+    params[k] = tvm.nd.array(weights)
+
+    if layer.batch_normalize == 1 and layer.dontloadscales != 1:
+        _get_batchnorm_weights(layer, opname[1], params, layer.n, dtype)
+        k = _get_tvm_params_name(opname[1], 'beta')
+        params[k] = tvm.nd.array(biases)
+    else:
+        k = _get_tvm_params_name(opname[0], 'bias')
+        params[k] = tvm.nd.array(biases)
+
+def _get_connected_weights(layer, opname, params, dtype):
+    """Parse the weights and biases for fully connected or dense layer."""
+    size = layer.outputs * layer.inputs
+    if size == 0:
+        return
+
+    weights = _read_memory_buffer((layer.outputs, layer.inputs), layer.weights, dtype)
+    biases = _read_memory_buffer((layer.outputs, ), layer.biases, dtype)
+
+    k = _get_tvm_params_name(opname, 'weight')
+    params[k] = tvm.nd.array(weights)
+    k = _get_tvm_params_name(opname, 'bias')
+    params[k] = tvm.nd.array(biases)
+
+    if layer.batch_normalize == 1 and layer.dontloadscales != 1:
+        _get_batchnorm_weights(layer, opname, params, layer.outputs, dtype)
+
+def _get_batchnorm_weights(layer, opname, params, size, dtype):
+    """Parse the weights for batchnorm, which includes, scales, moving mean
+    and moving variances."""
+    scales = _read_memory_buffer((size, ), layer.scales, dtype)
+    rolling_mean = _read_memory_buffer((size, ), layer.rolling_mean, dtype)
+    rolling_variance = _read_memory_buffer((size, ), layer.rolling_variance, dtype)
+
+    k = _get_tvm_params_name(opname, 'moving_mean')
+    params[k] = tvm.nd.array(rolling_mean)
+    k = _get_tvm_params_name(opname, 'moving_var')
+    params[k] = tvm.nd.array(rolling_variance)
+    k = _get_tvm_params_name(opname, 'gamma')
+    params[k] = tvm.nd.array(scales)
+
+def _get_darknet_attrs(net, layer_num):
+    """Parse attributes of each layer and return."""
+    attr = {}
+    use_flatten = True
+    layer = net.layers[layer_num]
+    op_name = _get_darknet_layername(layer.type)
+
+    if LAYERTYPE.CONVOLUTIONAL == layer.type:
+        attr.update({'layout' : 'NCHW'})
+        attr.update({'pad' : str(layer.pad)})
+        attr.update({'num_group' : str(layer.groups)})
+        attr.update({'num_filter' : str(layer.n)})
+        attr.update({'stride' : str(layer.stride)})
+        attr.update({'kernel' : str(layer.size)})
+        attr.update({'activation' : (layer.activation)})
+
+        if layer.nbiases == 0:
+            attr.update({'use_bias' : False})
+        else:
+            attr.update({'use_bias' : True})
+
+        if layer.batch_normalize == 1 and layer.dontloadscales != 1:
+            attr.update({'use_batchNorm' : True})
+            attr.update({'use_scales' : True})
+
+    #elif LAYERTYPE.BATCHNORM == layer.type:
+    #    attr.update({'flatten' : str('True')})
+
+    elif LAYERTYPE.CONNECTED == layer.type:
+        attr.update({'num_hidden' : str(layer.outputs)})
+        attr.update({'activation' : (layer.activation)})
+        if layer_num != 0:
+            layer_prev = net.layers[layer_num - 1]
+            if (layer_prev.out_h == layer.h and
+                    layer_prev.out_w == layer.w and
+                    layer_prev.out_c == layer.c):
+                use_flatten = False
+        attr.update({'use_flatten' : use_flatten})
+        if layer.nbiases == 0:
+            attr.update({'use_bias' : False})
+        else:
+            attr.update({'use_bias' : True})
+        if layer.batch_normalize == 1 and layer.dontloadscales != 1:
+            attr.update({'use_batchNorm' : True})
+            attr.update({'use_scales' : True})
+
+    elif LAYERTYPE.MAXPOOL == layer.type:
+        attr.update({'pad' : str(layer.pad)})
+        attr.update({'stride' : str(layer.stride)})
+        attr.update({'kernel' : str(layer.size)})
+
+    elif LAYERTYPE.AVGPOOL == layer.type:
+        attr.update({'pad' : str(layer.pad)})
+        if layer.stride == 0:
+            attr.update({'stride' : str(1)})
+        else:
+            attr.update({'stride' : str(layer.stride)})
+        if layer.size == 0 and layer.h == layer.w:
+            attr.update({'kernel' : str(layer.h)})
+        else:
+            attr.update({'kernel' : str(layer.size)})
+
+    elif LAYERTYPE.DROPOUT == layer.type:
+        attr.update({'p' : str(layer.probability)})
+
+    elif LAYERTYPE.SOFTMAX == layer.type:
+        attr.update({'axis' : 1})
+        attr.update({'use_flatten' : True})
+
+    elif LAYERTYPE.SHORTCUT == layer.type:
+        add_layer = net.layers[layer.index]
+        attr.update({'activation' : (layer.activation)})
+        attr.update({'out_channel' : (layer.out_c)})
+        attr.update({'out_size' : (layer.out_h)})
+        attr.update({'add_out_channel' : (add_layer.out_c)})
+        attr.update({'add_out_size' : (add_layer.out_h)})
+
+    elif LAYERTYPE.ROUTE == layer.type:
+        pass
+
+    elif LAYERTYPE.COST == layer.type:
+        pass
+
+    elif LAYERTYPE.REORG == layer.type:
+        attr.update({'stride' : layer.stride})
+
+    elif LAYERTYPE.REGION == layer.type:
+        attr.update({'n' : layer.n})
+        attr.update({'classes' : layer.classes})
+        attr.update({'coords' : layer.coords})
+        attr.update({'background' : layer.background})
+        attr.update({'softmax' : layer.softmax})
+    else:
+        err = "Darknet layer {} is not supported in nnvm.".format(op_name)
+        raise NotImplementedError(err)
+
+    return op_name, attr
+
+def _get_tvm_params_name(opname, arg_name):
+    """Makes the params name for the k,v pair."""
+    return opname + '_'+ arg_name
+
+def _get_darknet_params(layer, opname, tvmparams, dtype='float32'):
+    """To parse and get the darknet params."""
+    if LAYERTYPE.CONVOLUTIONAL == layer.type:
+        _get_convolution_weights(layer, opname, tvmparams, dtype)
+
+    #elif LAYERTYPE.BATCHNORM == layer.type:
+    #   size = layer.outputs
+    #   _get_batchnorm_weights(layer, opname, tvmparams, size, dtype)
+
+    elif LAYERTYPE.CONNECTED == layer.type:
+        _get_connected_weights(layer, opname, tvmparams, dtype)
+
+def _preproc_layer(net, i, sym_array):
+    """To preprocess each darknet layer, some layer doesnt need processing."""
+    layer = net.layers[i]
+    if i == 0:
+        name = 'data'
+        attribute = {}
+        sym = [_sym.Variable(name, **attribute)]
+    else:
+        sym = sym_array[i - 1]
+    skip_layer = False
+
+    if LAYERTYPE.ROUTE == layer.type:
+        sym = []
+        for j in range(layer.n):
+            sym.append(sym_array[layer.input_layers[j]])
+        if layer.n == 1:
+            skip_layer = True
+
+    elif LAYERTYPE.COST == layer.type:
+        skip_layer = True
+
+    elif LAYERTYPE.SHORTCUT == layer.type:
+        sym = [sym, sym_array[layer.index]]
+
+    elif LAYERTYPE.BLANK == layer.type:
+        skip_layer = True
+
+    if skip_layer is True:
+        sym_array[i] = sym
+
+    return skip_layer, sym
+
+def _from_darknet(net, dtype='float32'):
+    """To convert the darknet symbol to nnvm symbols."""
+    sym_array = {}
+    tvmparams = {}
+    for i in range(net.n):
+        need_skip, sym = _preproc_layer(net, i, sym_array)
+        if need_skip is True:
+            continue
+        op_name, attr = _get_darknet_attrs(net, i)
+        layer_name, sym = _darknet_convert_symbol(op_name, _as_list(sym), attr)
+        _get_darknet_params(net.layers[i], layer_name, tvmparams, dtype)
+        sym_array[i] = sym
+
+    return sym, tvmparams
+
+def from_darknet(net, dtype='float32'):
+    """Convert from darknet's model into compatible NNVM format.
+    Reconstruct a nnvm symbol by traversing the darknet input.
+
+    Parameters
+    ----------
+    net : cffi pointer to network
+        The parsed darknet network
+
+    dtype : str
+        Datatype of the input net structure, default is float32
+
+    Returns
+    -------
+    sym : nnvm.Symbol
+        Compatible nnvm symbol
+
+    params : dict of str to tvm.NDArray
+        The parameter dict to be used by nnvm
+    """
+
+    return _from_darknet(net, dtype)
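+
+# Example usage (a sketch; the library and file names are placeholders,
+# loaded the same way as in the from_darknet tutorial):
+#   from nnvm.testing.darknet import __darknetffi__
+#   lib = __darknetffi__.dlopen('./libdarknet.so')
+#   net = lib.load_network('yolo.cfg'.encode('utf-8'),
+#                          'yolo.weights'.encode('utf-8'), 0)
+#   sym, params = from_darknet(net, dtype='float32')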
diff --git a/nnvm/python/nnvm/testing/__init__.py b/nnvm/python/nnvm/testing/__init__.py
index 84c51a4df..19c37f7ac 100644
--- a/nnvm/python/nnvm/testing/__init__.py
+++ b/nnvm/python/nnvm/testing/__init__.py
@@ -7,3 +7,5 @@ from . import mobilenet
 from . import mlp
 from . import resnet
 from . import vgg
+from . import darknet
+from . import yolo2_detection
diff --git a/nnvm/python/nnvm/testing/darknet.py b/nnvm/python/nnvm/testing/darknet.py
new file mode 100644
index 000000000..30b790bb4
--- /dev/null
+++ b/nnvm/python/nnvm/testing/darknet.py
@@ -0,0 +1,494 @@
+# pylint: disable=invalid-name, unused-variable, unused-argument, no-init
+"""
+Compile DarkNet Models
+======================
+DarkNet helper functions for darknet model parsing and image loading.
+These functions are not loaded by default; they are utilities used by
+the tests and the tutorial.
+"""
+from __future__ import division
+from enum import IntEnum
+import math
+import numpy as np
+import cv2
+from cffi import FFI
+
+def _resize_image(img, w_in, h_in):
+    """Resize the image to the given height and width."""
+    imc, imh, imw = img.shape
+    h_in = int(h_in)
+    w_in = int(w_in)
+    part = np.zeros((imc, imh, w_in))
+    resized = np.zeros((imc, h_in, w_in))
+    w_scale = (imw - 1) / (w_in - 1)
+    h_scale = (imh - 1) / (h_in - 1)
+    for k in range(imc):
+        for j in range(imh):
+            for c in range(w_in):
+                if c == w_in - 1 or imw == 1:
+                    part[k][j][c] = img[k][j][imw - 1]
+                else:
+                    fdx, idx = math.modf(c * w_scale)
+                    part[k][j][c] = (1 - fdx) * img[k][j][int(idx)] + \
+                                            fdx * img[k][j][int(idx) + 1]
+    for k in range(imc):
+        for j in range(h_in):
+            fdy, idy = math.modf(j * h_scale)
+            for c in range(w_in):
+                resized[k][j][c] = (1 - fdy)*part[k][int(idy)][c]
+            if (j == h_in - 1) or (imh == 1):
+                continue
+            for c in range(w_in):
+                resized[k][j][c] += fdy * part[k][int(idy) + 1][c]
+    return resized
+
+def load_image_color(test_image):
+    """To load the image using opencv api and do preprocessing."""
+    imagex = cv2.imread(test_image)
+    imagex = np.array(imagex)
+    imagex = imagex.transpose((2, 0, 1))
+    imagex = np.divide(imagex, 255.0)
+    imagex = np.flip(imagex, 0)
+    return imagex
+
+def _letterbox_image(img, w_in, h_in):
+    """To get the image in boxed format."""
+    imc, imh, imw = img.shape
+    if (w_in / imw) < (h_in / imh):
+        new_w = w_in
+        new_h = imh * w_in / imw
+    else:
+        new_h = h_in
+        new_w = imw * h_in/imh
+    resized = _resize_image(img, new_w, new_h)
+    boxed = np.full((imc, h_in, w_in), 0.5, dtype=float)
+    _, resizedh, resizedw = resized.shape
+    boxed[:, int((h_in - new_h) / 2)
+          :int((h_in - new_h) / 2) + resizedh, int((w_in - new_w) / 2)
+          :int((w_in - new_w) / 2) + resizedw] = resized
+    return boxed
+
+def load_image(image, resize_width, resize_height):
+    """Load the image and convert to the darknet model format.
+    The image processing of darknet is different from normal.
+    Parameters
+    ----------
+    image : string
+        The image file name with path
+
+    resize_width : integer
+        The width to which the image needs to be resized
+
+    resize_height : integer
+        The height to which the image needs to be resized
+
+    Returns
+    -------
+    img : float array
+        The processed image in CHW layout
+    """
+
+    img = load_image_color(image)
+    return _letterbox_image(img, resize_width, resize_height)
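+
+# Example (a sketch; 'dog.jpg' is a placeholder path):
+#   data = load_image('dog.jpg', 416, 416)  # CHW float array in [0, 1]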
+
+class LAYERTYPE(IntEnum):
+    """Darknet LAYERTYPE Class constant."""
+    CONVOLUTIONAL = 0
+    DECONVOLUTIONAL = 1
+    CONNECTED = 2
+    MAXPOOL = 3
+    SOFTMAX = 4
+    DETECTION = 5
+    DROPOUT = 6
+    CROP = 7
+    ROUTE = 8
+    COST = 9
+    NORMALIZATION = 10
+    AVGPOOL = 11
+    LOCAL = 12
+    SHORTCUT = 13
+    ACTIVE = 14
+    RNN = 15
+    GRU = 16
+    LSTM = 17
+    CRNN = 18
+    BATCHNORM = 19
+    NETWORK = 20
+    XNOR = 21
+    REGION = 22
+    REORG = 23
+    BLANK = 24
+
+class ACTIVATION(IntEnum):
+    """Darknet ACTIVATION Class constant."""
+    LOGISTIC = 0
+    RELU = 1
+    RELIE = 2
+    LINEAR = 3
+    RAMP = 4
+    TANH = 5
+    PLSE = 6
+    LEAKY = 7
+    ELU = 8
+    LOGGY = 9
+    STAIR = 10
+    HARDTAN = 11
+    LHTAN = 12
+
+__darknetffi__ = FFI()
+
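+# The cdef below mirrors the subset of darknet.h (enums, the layer/network
+# structs and the API functions) needed to parse darknet models via cffi.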
+__darknetffi__.cdef("""
+typedef struct network network;
+typedef struct layer layer;
+
+typedef struct{
+    int *leaf;
+    int n;
+    int *parent;
+    int *child;
+    int *group;
+    char **name;
+
+    int groups;
+    int *group_size;
+    int *group_offset;
+} tree;
+
+typedef enum{
+    LOGISTIC, RELU, RELIE, LINEAR, RAMP, TANH, PLSE, LEAKY, ELU, LOGGY, STAIR, HARDTAN, LHTAN
+} ACTIVATION;
+
+
+typedef enum {
+    CONVOLUTIONAL,
+    DECONVOLUTIONAL,
+    CONNECTED,
+    MAXPOOL,
+    SOFTMAX,
+    DETECTION,
+    DROPOUT,
+    CROP,
+    ROUTE,
+    COST,
+    NORMALIZATION,
+    AVGPOOL,
+    LOCAL,
+    SHORTCUT,
+    ACTIVE,
+    RNN,
+    GRU,
+    LSTM,
+    CRNN,
+    BATCHNORM,
+    NETWORK,
+    XNOR,
+    REGION,
+    REORG,
+    BLANK
+} LAYERTYPE;
+
+typedef enum{
+    SSE, MASKED, LONE, SEG, SMOOTH
+} COSTTYPE;
+
+
+struct layer{
+    LAYERTYPE type;
+    ACTIVATION activation;
+    COSTTYPE cost_type;
+    void (*forward);
+    void (*backward);
+    void (*update);
+    void (*forward_gpu);
+    void (*backward_gpu);
+    void (*update_gpu);
+    int batch_normalize;
+    int shortcut;
+    int batch;
+    int forced;
+    int flipped;
+    int inputs;
+    int outputs;
+    int nweights;
+    int nbiases;
+    int extra;
+    int truths;
+    int h,w,c;
+    int out_h, out_w, out_c;
+    int n;
+    int max_boxes;
+    int groups;
+    int size;
+    int side;
+    int stride;
+    int reverse;
+    int flatten;
+    int spatial;
+    int pad;
+    int sqrt;
+    int flip;
+    int index;
+    int binary;
+    int xnor;
+    int steps;
+    int hidden;
+    int truth;
+    float smooth;
+    float dot;
+    float angle;
+    float jitter;
+    float saturation;
+    float exposure;
+    float shift;
+    float ratio;
+    float learning_rate_scale;
+    int softmax;
+    int classes;
+    int coords;
+    int background;
+    int rescore;
+    int objectness;
+    int does_cost;
+    int joint;
+    int noadjust;
+    int reorg;
+    int log;
+    int tanh;
+
+    float alpha;
+    float beta;
+    float kappa;
+
+    float coord_scale;
+    float object_scale;
+    float noobject_scale;
+    float mask_scale;
+    float class_scale;
+    int bias_match;
+    int random;
+    float thresh;
+    int classfix;
+    int absolute;
+
+    int onlyforward;
+    int stopbackward;
+    int dontload;
+    int dontloadscales;
+
+    float temperature;
+    float probability;
+    float scale;
+
+    char  * cweights;
+    int   * indexes;
+    int   * input_layers;
+    int   * input_sizes;
+    int   * map;
+    float * rand;
+    float * cost;
+    float * state;
+    float * prev_state;
+    float * forgot_state;
+    float * forgot_delta;
+    float * state_delta;
+    float * combine_cpu;
+    float * combine_delta_cpu;
+
+    float * concat;
+    float * concat_delta;
+
+    float * binary_weights;
+
+    float * biases;
+    float * bias_updates;
+
+    float * scales;
+    float * scale_updates;
+
+    float * weights;
+    float * weight_updates;
+
+    float * delta;
+    float * output;
+    float * squared;
+    float * norms;
+
+    float * spatial_mean;
+    float * mean;
+    float * variance;
+
+    float * mean_delta;
+    float * variance_delta;
+
+    float * rolling_mean;
+    float * rolling_variance;
+
+    float * x;
+    float * x_norm;
+
+    float * m;
+    float * v;
+
+    float * bias_m;
+    float * bias_v;
+    float * scale_m;
+    float * scale_v;
+
+
+    float *z_cpu;
+    float *r_cpu;
+    float *h_cpu;
+    float * prev_state_cpu;
+
+    float *temp_cpu;
+    float *temp2_cpu;
+    float *temp3_cpu;
+
+    float *dh_cpu;
+    float *hh_cpu;
+    float *prev_cell_cpu;
+    float *cell_cpu;
+    float *f_cpu;
+    float *i_cpu;
+    float *g_cpu;
+    float *o_cpu;
+    float *c_cpu;
+    float *dc_cpu;
+
+    float * binary_input;
+
+    struct layer *input_layer;
+    struct layer *self_layer;
+    struct layer *output_layer;
+
+    struct layer *reset_layer;
+    struct layer *update_layer;
+    struct layer *state_layer;
+
+    struct layer *input_gate_layer;
+    struct layer *state_gate_layer;
+    struct layer *input_save_layer;
+    struct layer *state_save_layer;
+    struct layer *input_state_layer;
+    struct layer *state_state_layer;
+
+    struct layer *input_z_layer;
+    struct layer *state_z_layer;
+
+    struct layer *input_r_layer;
+    struct layer *state_r_layer;
+
+    struct layer *input_h_layer;
+    struct layer *state_h_layer;
+
+    struct layer *wz;
+    struct layer *uz;
+    struct layer *wr;
+    struct layer *ur;
+    struct layer *wh;
+    struct layer *uh;
+    struct layer *uo;
+    struct layer *wo;
+    struct layer *uf;
+    struct layer *wf;
+    struct layer *ui;
+    struct layer *wi;
+    struct layer *ug;
+    struct layer *wg;
+
+    tree *softmax_tree;
+
+    size_t workspace_size;
+};
+
+
+typedef enum {
+    CONSTANT, STEP, EXP, POLY, STEPS, SIG, RANDOM
+} LEARNINGRATEPOLICY;
+
+typedef struct network{
+    int n;
+    int batch;
+    size_t *seen;
+    int *t;
+    float epoch;
+    int subdivisions;
+    layer *layers;
+    float *output;
+    LEARNINGRATEPOLICY policy;
+
+    float learning_rate;
+    float momentum;
+    float decay;
+    float gamma;
+    float scale;
+    float power;
+    int time_steps;
+    int step;
+    int max_batches;
+    float *scales;
+    int   *steps;
+    int num_steps;
+    int burn_in;
+
+    int adam;
+    float B1;
+    float B2;
+    float eps;
+
+    int inputs;
+    int outputs;
+    int truths;
+    int notruth;
+    int h, w, c;
+    int max_crop;
+    int min_crop;
+    float max_ratio;
+    float min_ratio;
+    int center;
+    float angle;
+    float aspect;
+    float exposure;
+    float saturation;
+    float hue;
+    int random;
+
+    int gpu_index;
+    tree *hierarchy;
+
+    float *input;
+    float *truth;
+    float *delta;
+    float *workspace;
+    int train;
+    int index;
+    float *cost;
+} network;
+
+
+typedef struct {
+    int w;
+    int h;
+    int c;
+    float *data;
+} image;
+
+network *load_network(char *cfg, char *weights, int clear);
+image letterbox_image(image im, int w, int h);
+int resize_network(network *net, int w, int h);
+void top_predictions(network *net, int n, int *index);
+void free_image(image m);
+image load_image_color(char *filename, int w, int h);
+float *network_predict_image(network *net, image im);
+network *make_network(int n);
+layer make_convolutional_layer(int batch, int h, int w, int c, int n, int groups, int size, int stride, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam);
+layer make_connected_layer(int batch, int inputs, int outputs, ACTIVATION activation, int batch_normalize, int adam);
+layer make_maxpool_layer(int batch, int h, int w, int c, int size, int stride, int padding);
+layer make_avgpool_layer(int batch, int w, int h, int c);
+layer make_shortcut_layer(int batch, int index, int w, int h, int c, int w2, int h2, int c2);
+layer make_batchnorm_layer(int batch, int w, int h, int c);
+layer make_reorg_layer(int batch, int w, int h, int c, int stride, int reverse, int flatten, int extra);
+layer make_region_layer(int batch, int w, int h, int n, int classes, int coords);
+void free_network(network *net);
+"""
+                   )
diff --git a/nnvm/python/nnvm/testing/yolo2_detection.py b/nnvm/python/nnvm/testing/yolo2_detection.py
new file mode 100644
index 000000000..b7744c45c
--- /dev/null
+++ b/nnvm/python/nnvm/testing/yolo2_detection.py
@@ -0,0 +1,246 @@
+# pylint: disable=invalid-name, unused-variable, unused-argument, no-init
+"""
+Yolo detection boxes helper functions
+=====================================
+DarkNet helper functions for yolo detection and image annotation.
+These functions are not loaded by default; they are utilities used by
+the tests and the tutorial.
+"""
+from __future__ import division
+import math
+from collections import namedtuple
+import numpy as np
+from PIL import Image
+from PIL import ImageDraw
+from PIL import ImageFont
+
+def _entry_index(batch, w, h, outputs, classes, coords, location, entry):
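+    # Flattened offset into the region-layer output: each of the n anchors
+    # owns a (coords + classes + 1) x h x w block; `entry` picks the channel
+    # (0..coords-1: box coords, coords: objectness, coords+1+j: class j).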
+    n = int(location/(w*h))
+    loc = location%(w*h)
+    return batch*outputs + n*w*h*(coords+classes+1) + entry*w*h + loc
+
+Box = namedtuple('Box', ['x', 'y', 'w', 'h'])
+def _get_region_box(x, biases, n, index, i, j, w, h, stride):
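+    # Decode one YOLOv2 box: x, y are cell offsets normalized by the grid
+    # size; w, h scale the anchor (biases) by the exponentiated prediction.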
+    b = Box(0, 0, 0, 0)
+    b = b._replace(x=(i + x[index + 0*stride]) / w)
+    b = b._replace(y=(j + x[index + 1*stride]) / h)
+    b = b._replace(w=np.exp(x[index + 2*stride]) * biases[2*n] / w)
+    b = b._replace(h=np.exp(x[index + 3*stride]) * biases[2*n+1] / h)
+    return b
+
+def _correct_region_boxes(boxes, n, w, h, netw, neth, relative):
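+    # Undo the letterbox transform: map boxes from the padded network input
+    # frame back to the original image frame.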
+    new_w, new_h = (netw, (h*netw)/w) if (netw/w < neth/h) else ((w*neth/h), neth)
+    for i in range(n):
+        b = boxes[i]
+        b = b._replace(x=(b.x - (netw - new_w)/2/netw) / (new_w/netw))
+        b = b._replace(y=(b.y - (neth - new_h)/2/neth) / (new_h/neth))
+        b = b._replace(w=b.w * netw/new_w)
+        b = b._replace(h=b.h * neth/new_h)
+        if not relative:
+            b = b._replace(x=b.x * w)
+            b = b._replace(w=b.w * w)
+            b = b._replace(y=b.y * h)
+            b = b._replace(h=b.h * h)
+        boxes[i] = b
+
+def _overlap(x1, w1, x2, w2):
+    l1 = x1 - w1/2
+    l2 = x2 - w2/2
+    left = l1 if l1 > l2 else l2
+    r1 = x1 + w1/2
+    r2 = x2 + w2/2
+    right = r1 if r1 < r2 else r2
+    return right - left
+
+def _box_intersection(a, b):
+    w = _overlap(a.x, a.w, b.x, b.w)
+    h = _overlap(a.y, a.h, b.y, b.h)
+    if w < 0 or h < 0:
+        return 0
+    return w*h
+
+def _box_union(a, b):
+    i = _box_intersection(a, b)
+    u = a.w*a.h + b.w*b.h - i
+    return u
+
+def _box_iou(a, b):
+    return _box_intersection(a, b)/_box_union(a, b)
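+# e.g. two unit squares whose centres are 0.5 apart horizontally intersect
+# in area 0.5, giving IoU = 0.5 / (1 + 1 - 0.5) = 1/3.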
+
+def get_region_boxes(layer_in, imw, imh, netw, neth, thresh, probs,
+                     boxes, relative, tvm_out):
+    "To get the boxes for the image based on the prediction"
+    lw = layer_in.w
+    lh = layer_in.h
+    probs = [[0 for i in range(layer_in.classes + 1)] for y in range(lw*lh*layer_in.n)]
+    boxes = [Box(0, 0, 0, 0) for i in range(lw*lh*layer_in.n)]
+    for i in range(lw*lh):
+        row = int(i / lw)
+        col = int(i % lw)
+        for n in range(layer_in.n):
+            index = n*lw*lh + i
+            obj_index = _entry_index(0, lw, lh, layer_in.outputs, layer_in.classes,
+                                     layer_in.coords, n*lw*lh + i, layer_in.coords)
+            box_index = _entry_index(0, lw, lh, layer_in.outputs, layer_in.classes,
+                                     layer_in.coords, n*lw*lh + i, 0)
+            mask_index = _entry_index(0, lw, lh, layer_in.outputs, layer_in.classes,
+                                      layer_in.coords, n*lw*lh + i, 4)
+            scale = 1 if layer_in.background else tvm_out[obj_index]
+            boxes[index] = _get_region_box(tvm_out, layer_in.biases, n, box_index, col,
+                                           row, lw, lh, lw*lh)
+            if not layer_in.softmax_tree:
+                max_element = 0
+                for j in range(layer_in.classes):
+                    class_index = _entry_index(0, lw, lh, layer_in.outputs, layer_in.classes,
+                                               layer_in.coords, n*lw*lh + i, layer_in.coords+1+j)
+                    prob = scale*tvm_out[class_index]
+                    probs[index][j] = prob if prob > thresh else 0
+                    max_element = max(max_element, prob)
+                probs[index][layer_in.classes] = max_element
+
+    _correct_region_boxes(boxes, lw*lh*layer_in.n, imw, imh, netw, neth, relative)
+    return boxes, probs
+
+
+def do_nms_sort(boxes, probs, total, classes, thresh):
+    "Does the sorting based on the threshold values"
+    SortableBbox = namedtuple('SortableBbox', ['index_var', 'class_var', 'probs'])
+
+    s = [SortableBbox(0, 0, []) for i in range(total)]
+    for i in range(total):
+        s[i] = s[i]._replace(index_var=i)
+        s[i] = s[i]._replace(class_var=0)
+        s[i] = s[i]._replace(probs=probs)
+
+    for k in range(classes):
+        for i in range(total):
+            s[i] = s[i]._replace(class_var=k)
+        s = sorted(s, key=lambda x: x.probs[x.index_var][x.class_var], reverse=True)
+        for i in range(total):
+            if probs[s[i].index_var][k] == 0:
+                continue
+            a = boxes[s[i].index_var]
+            for j in range(i+1, total):
+                b = boxes[s[j].index_var]
+                if _box_iou(a, b) > thresh:
+                    probs[s[j].index_var][k] = 0
+    return boxes, probs
+
+def draw_detections(im, num, thresh, boxes, probs, names, classes):
+    "Draw the markings around the detected region"
+    for i in range(num):
+        labelstr = []
+        category = -1
+        for j in range(classes):
+            if probs[i][j] > thresh:
+                if category == -1:
+                    category = j
+                labelstr.append(names[j])
+        if category > -1:
+            imc, imh, imw = im.shape
+            width = int(imh * 0.006)
+            offset = category*123457 % classes
+            red = _get_color(2, offset, classes)
+            green = _get_color(1, offset, classes)
+            blue = _get_color(0, offset, classes)
+            rgb = [red, green, blue]
+            b = boxes[i]
+            left = int((b.x-b.w/2.)*imw)
+            right = int((b.x+b.w/2.)*imw)
+            top = int((b.y-b.h/2.)*imh)
+            bot = int((b.y+b.h/2.)*imh)
+
+            if left < 0:
+                left = 0
+            if right > imw-1:
+                right = imw-1
+            if top < 0:
+                top = 0
+            if bot > imh-1:
+                bot = imh-1
+            _draw_box_width(im, left, top, right, bot, width, red, green, blue)
+            label = _get_label(''.join(labelstr), rgb)
+            _draw_label(im, top + width, left, label, rgb)
+
+def _get_pixel(im, x, y, c):
+    return im[c][y][x]
+
+def _set_pixel(im, x, y, c, val):
+    if x < 0 or y < 0 or c < 0 or x >= im.shape[2] or y >= im.shape[1] or c >= im.shape[0]:
+        return
+    im[c][y][x] = val
+
+def _draw_label(im, r, c, label, rgb):
+    w = label.shape[2]
+    h = label.shape[1]
+    if (r - h) >= 0:
+        r = r - h
+
+    for j in range(h):
+        if j < h and (j + r) < im.shape[1]:
+            for i in range(w):
+                if i < w and (i + c) < im.shape[2]:
+                    for k in range(label.shape[0]):
+                        val = _get_pixel(label, i, j, k)
+                        _set_pixel(im, i+c, j+r, k, val)
+
+def _get_label(labelstr, rgb):
+    text = labelstr
+    colorText = "black"
+    testDraw = ImageDraw.Draw(Image.new('RGB', (1, 1)))
+    font = ImageFont.truetype("arial.ttf", 25)
+    width, height = testDraw.textsize(labelstr, font=font)
+    img = Image.new('RGB', (width, height), color=(int(rgb[0]*255), int(rgb[1]*255),
+                                                   int(rgb[2]*255)))
+    d = ImageDraw.Draw(img)
+    d.text((0, 0), text, fill=colorText, font=font)
+    opencvImage = np.divide(np.asarray(img), 255)
+    return opencvImage.transpose(2, 0, 1)
+
+def _get_color(c, x, max_value):
+    c = int(c)
+    colors = [[1, 0, 1], [0, 0, 1], [0, 1, 1], [0, 1, 0], [1, 1, 0], [1, 0, 0]]
+    ratio = (float(x)/float(max_value)) * 5
+    i = int(math.floor(ratio))
+    j = int(math.ceil(ratio))
+    ratio -= i
+    r = (1-ratio) * colors[i][c] + ratio*colors[j][c]
+    return r
+
+def _draw_box(im, x1, y1, x2, y2, r, g, b):
+    y1 = int(y1)
+    y2 = int(y2)
+    x1 = int(x1)
+    x2 = int(x2)
+    ac, ah, aw = im.shape
+    if x1 < 0:
+        x1 = 0
+    if x1 >= aw:
+        y1 = 0
+    if y1 >= ah:
+        y1 = ah - 1
+    if y2 < 0:
+        y2 = 0
+    if y2 >= ah:
+        y2 = ah - 1
+
+    for i in range(x1, x2):
+        im[0][y1][i] = r
+        im[0][y2][i] = r
+        im[1][y1][i] = g
+        im[1][y2][i] = g
+        im[2][y1][i] = b
+        im[2][y2][i] = b
+
+    for i in range(y1, y2):
+        im[0][i][x1] = r
+        im[0][i][x2] = r
+        im[1][i][x1] = g
+        im[1][i][x2] = g
+        im[2][i][x1] = b
+        im[2][i][x2] = b
+
+def _draw_box_width(im, x1, y1, x2, y2, w, r, g, b):
+    for i in range(int(w)):
+        _draw_box(im, x1+i, y1+i, x2-i, y2-i, r, g, b)
diff --git a/nnvm/python/nnvm/top/__init__.py b/nnvm/python/nnvm/top/__init__.py
index 273324d1f..12294fa0d 100644
--- a/nnvm/python/nnvm/top/__init__.py
+++ b/nnvm/python/nnvm/top/__init__.py
@@ -7,6 +7,7 @@ from . import tensor
 from . import nn
 from . import transform
 from . import reduction
+from . import vision
 
 from .registry import OpPattern
 from .registry import register_compute, register_schedule, register_pattern
diff --git a/nnvm/python/nnvm/top/vision.py b/nnvm/python/nnvm/top/vision.py
new file mode 100644
index 000000000..89409de62
--- /dev/null
+++ b/nnvm/python/nnvm/top/vision.py
@@ -0,0 +1,40 @@
+
+# pylint: disable=invalid-name, unused-argument
+"""Definition of nn ops"""
+from __future__ import absolute_import
+
+import topi
+import tvm
+from . import registry as reg
+from .registry import OpPattern
+
+@reg.register_compute("yolo2_reorg")
+def compute_reorg(attrs, inputs, _):
+    """Compute definition of reorg"""
+    return topi.vision.reorg(inputs[0], attrs.get_int("stride"))
+
+@reg.register_schedule("yolo2_reorg")
+def schedule_reorg(attrs, outs, target):
+    """Schedule definition of reorg"""
+    with tvm.target.create(target):
+        return topi.generic.schedule_injective(outs)
+
+reg.register_pattern("yolo2_reorg", OpPattern.INJECTIVE)
+
+@reg.register_compute("yolo2_region")
+def compute_region(attrs, inputs, _):
+    """Compute definition of region"""
+    n = attrs.get_int("n")
+    classes = attrs.get_int("classes")
+    coords = attrs.get_int("coords")
+    background = attrs.get_int("background")
+    softmax = attrs.get_int("softmax")
+    return topi.vision.yolo2.region(inputs[0], n, classes, coords, background, softmax)
+
+@reg.register_schedule("yolo2_region")
+def schedule_region(attrs, outs, target):
+    """Schedule definition of region"""
+    with tvm.target.create(target):
+        return topi.generic.vision.schedule_region(outs)
+
+reg.register_pattern("yolo2_region", OpPattern.OPAQUE)
diff --git a/nnvm/src/top/vision/yolo2/region.cc b/nnvm/src/top/vision/yolo2/region.cc
new file mode 100644
index 000000000..87860be3d
--- /dev/null
+++ b/nnvm/src/top/vision/yolo2/region.cc
@@ -0,0 +1,35 @@
+/*!
+ *  Copyright (c) 2018 by Contributors
+ * \file region.cc
+ * \brief Property def of region operator.
+ */
+#include <nnvm/op.h>
+#include <nnvm/node.h>
+#include <nnvm/op_attr_types.h>
+#include <nnvm/top/nn.h>
+#include "../../op_common.h"
+#include "region.h"
+
+namespace nnvm {
+namespace top {
+
+NNVM_REGISTER_OP(yolo2_region)
+.describe(R"code(Region layer
+)code" NNVM_ADD_FILELINE)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_support_level(5)
+.add_argument("data", "Tensor", "Input data")
+.set_attr<FInferType>("FInferType", RegionType<1, 1>)
+.set_attr<FInferShape>("FInferShape", RegionShape<1, 1>)
+.set_attr<FInplaceOption>(
+    "FInplaceOption",
+    [](const NodeAttrs &attrs) {
+      return std::vector<std::pair<int, int>>{{0, 0}, {1, 0}};
+    })
+.set_attr<FGradient>("FGradient", [](const NodePtr &n,
+                                     const std::vector<NodeEntry> &ograds) {
+  return std::vector<NodeEntry>{ograds[0], ograds[0]};
+});
+}  // namespace top
+}  // namespace nnvm
diff --git a/nnvm/src/top/vision/yolo2/region.h b/nnvm/src/top/vision/yolo2/region.h
new file mode 100644
index 000000000..cc816eab6
--- /dev/null
+++ b/nnvm/src/top/vision/yolo2/region.h
@@ -0,0 +1,101 @@
+/*!
+ *  Copyright (c) 2018 by Contributors
+ * \file region.h
+ */
+#ifndef NNVM_TOP_VISION_YOLO2_REGION_H_
+#define NNVM_TOP_VISION_YOLO2_REGION_H_
+
+#include <string>
+#include <vector>
+#include <utility>
+#include <iostream>
+#include <sstream>
+
+namespace nnvm {
+namespace top {
+
+template <typename AttrType,
+          bool (*is_none)(const AttrType &),
+          bool (*assign)(AttrType *,
+          const AttrType &),
+          bool reverse_infer,
+          std::string (*attr_string)(const AttrType &),
+          int n_in = -1,
+          int n_out = -1>
+inline bool RegionAttr(const nnvm::NodeAttrs &attrs,
+                       std::vector<AttrType> *in_attrs,
+                       std::vector<AttrType> *out_attrs,
+                       const AttrType &none) {
+  AttrType dattr = none;
+  size_t in_size = in_attrs->size();
+  size_t out_size = out_attrs->size();
+  if (n_in != -1) {
+    in_size = static_cast<size_t>(n_in);
+  }
+  if (n_out != -1) {
+    out_size = static_cast<size_t>(n_out);
+  }
+
+  auto deduce = [&](std::vector<AttrType> *vec, size_t size, const char *name) {
+    for (size_t i = 0; i < size; ++i) {
+      if (i == 0)
+        CHECK(assign(&dattr, (*vec)[i]))
+            << "Incompatible attr in node " << attrs.name << " at " << i
+            << "-th " << name << ": "
+            << "expected " << attr_string(dattr) << ", got "
+            << attr_string((*vec)[i]);
+    }
+  };
+  deduce(in_attrs, in_size, "input");
+
+  auto write = [&](std::vector<AttrType> *vec, size_t size, const char *name) {
+    for (size_t i = 0; i < size; ++i) {
+      CHECK(assign(&(*vec)[i], dattr))
+          << "Incompatible attr in node " << attrs.name << " at " << i << "-th "
+          << name << ": "
+          << "expected " << attr_string(dattr) << ", got "
+          << attr_string((*vec)[i]);
+    }
+  };
+  write(out_attrs, out_size, "output");
+
+  if (is_none(dattr)) {
+    return false;
+  }
+  return true;
+}
+
+template <int n_in, int n_out>
+inline bool RegionShape(const NodeAttrs &attrs,
+                        std::vector<TShape> *in_attrs,
+                        std::vector<TShape> *out_attrs) {
+  if (n_in != -1) {
+    CHECK_EQ(in_attrs->size(), static_cast<size_t>(n_in))
+        << " in operator " << attrs.name;
+  }
+  if (n_out != -1) {
+    CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out))
+        << " in operator " << attrs.name;
+  }
+  return RegionAttr<TShape, shape_is_none, shape_assign, true, shape_string>(
+      attrs, in_attrs, out_attrs, TShape());
+}
+
+template <int n_in, int n_out>
+inline bool RegionType(const NodeAttrs &attrs,
+                       std::vector<int> *in_attrs,
+                       std::vector<int> *out_attrs) {
+  if (n_in != -1) {
+    CHECK_EQ(in_attrs->size(), static_cast<size_t>(n_in))
+        << " in operator " << attrs.name;
+  }
+  if (n_out != -1) {
+    CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out))
+        << " in operator " << attrs.name;
+  }
+  return RegionAttr<int, type_is_none, type_assign, true, type_string>(
+      attrs, in_attrs, out_attrs, -1);
+}
+}  // namespace top
+}  // namespace nnvm
+#endif  // NNVM_TOP_VISION_YOLO2_REGION_H_
diff --git a/nnvm/src/top/vision/yolo2/reorg.cc b/nnvm/src/top/vision/yolo2/reorg.cc
new file mode 100644
index 000000000..e58940eb2
--- /dev/null
+++ b/nnvm/src/top/vision/yolo2/reorg.cc
@@ -0,0 +1,52 @@
+/*!
+ *  Copyright (c) 2018 by Contributors
+ * \file reorg.cc
+ */
+#include <nnvm/op.h>
+#include <nnvm/node.h>
+#include <nnvm/op_attr_types.h>
+#include <nnvm/top/nn.h>
+#include "../../op_common.h"
+#include "../../elemwise_op_common.h"
+#include "reorg.h"
+
+namespace nnvm {
+namespace top {
+
+// reorg
+DMLC_REGISTER_PARAMETER(ReorgParam);
+
+inline bool ReorgInferShape(const nnvm::NodeAttrs &attrs,
+                            std::vector<TShape> *in_shape,
+                            std::vector<TShape> *out_shape) {
+  const ReorgParam &param = nnvm::get<ReorgParam>(attrs.parsed);
+  TShape dshape = in_shape->at(0);
+  if (dshape.ndim() == 0)
+    return false;
+  NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, 0, dshape);
+  CHECK_EQ(dshape.ndim(), 4) << "Input data should be 4D";
+  CHECK_GT(param.stride, 0U) << "Stride value cannot be 0";
+  TShape oshape({dshape[0], 0, 0, 0});
+  oshape[1] = dshape[1] * param.stride * param.stride;
+  oshape[2] = dshape[2] / param.stride;
+  oshape[3] = dshape[3] / param.stride;
+  NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape);
+  return true;
+}
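+// For example, with stride = 2 an input of shape (1, 64, 26, 26) is
+// reorganized to (1, 64*2*2, 26/2, 26/2) = (1, 256, 13, 13).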
+
+NNVM_REGISTER_OP(yolo2_reorg)
+.describe(R"(Perform reorg operation on input array based on the stride value.
+- **data**: Input is 4D array of shape (batch_size, channels, in_height, in_width).
+- **out**: Output is 4D array of shape (batch_size, channels*stride*stride, in_height/stride, in_width/stride).
+)" NNVM_ADD_FILELINE)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_support_level(5)
+.add_argument("data", "Tensor", "Data input to reorganize")
+.set_attr_parser(ParamParser<ReorgParam>)
+.add_arguments(ReorgParam::__FIELDS__())
+.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ReorgParam>)
+.set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>)
+.set_attr<FInferShape>("FInferShape", ReorgInferShape);
+}  // namespace top
+}  // namespace nnvm
diff --git a/nnvm/src/top/vision/yolo2/reorg.h b/nnvm/src/top/vision/yolo2/reorg.h
new file mode 100644
index 000000000..87e0510e2
--- /dev/null
+++ b/nnvm/src/top/vision/yolo2/reorg.h
@@ -0,0 +1,110 @@
+/*!
+ *  Copyright (c) 2018 by Contributors
+ * \file reorg.h
+ */
+#ifndef NNVM_TOP_VISION_YOLO2_REORG_H_
+#define NNVM_TOP_VISION_YOLO2_REORG_H_
+
+#include <string>
+#include <vector>
+#include <utility>
+#include <iostream>
+#include <sstream>
+
+namespace nnvm {
+namespace top {
+
+template <typename AttrType,
+          bool (*is_none)(const AttrType &),
+          bool (*assign)(AttrType *,
+          const AttrType &),
+          bool reverse_infer,
+          std::string (*attr_string)(const AttrType &),
+          int n_in = -1,
+          int n_out = -1>
+inline bool ReorgAttr(const nnvm::NodeAttrs &attrs,
+                      std::vector<AttrType> *in_attrs,
+                      std::vector<AttrType> *out_attrs,
+                      const AttrType &none) {
+  AttrType dattr = none;
+  size_t in_size = in_attrs->size();
+  size_t out_size = out_attrs->size();
+  if (n_in != -1) {
+    in_size = static_cast<size_t>(n_in);
+  }
+  if (n_out != -1) {
+    out_size = static_cast<size_t>(n_out);
+  }
+
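+  // Deduce the attribute from the first input, write it to every output,
+  // and fail if it is still unknown afterwards.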
+  auto deduce = [&](std::vector<AttrType> *vec, size_t size, const char *name) {
+    for (size_t i = 0; i < size; ++i) {
+      if (i == 0) {
+        CHECK(assign(&dattr, (*vec)[i]))
+            << "Incompatible attr in node " << attrs.name << " at " << i
+            << "-th " << name << ": "
+            << "expected " << attr_string(dattr) << ", got "
+            << attr_string((*vec)[i]);
+      }
+    }
+  };
+  deduce(in_attrs, in_size, "input");
+
+  auto write = [&](std::vector<AttrType> *vec, size_t size, const char *name) {
+    for (size_t i = 0; i < size; ++i) {
+      CHECK(assign(&(*vec)[i], dattr))
+          << "Incompatible attr in node " << attrs.name << " at " << i << "-th "
+          << name << ": "
+          << "expected " << attr_string(dattr) << ", got "
+          << attr_string((*vec)[i]);
+    }
+  };
+  write(out_attrs, out_size, "output");
+
+  if (is_none(dattr)) {
+    return false;
+  }
+  return true;
+}
+
+template <int n_in, int n_out>
+inline bool ReorgShape(const NodeAttrs &attrs,
+                       std::vector<TShape> *in_attrs,
+                       std::vector<TShape> *out_attrs) {
+  if (n_in != -1) {
+    CHECK_EQ(in_attrs->size(), static_cast<size_t>(n_in))
+        << " in operator " << attrs.name;
+  }
+  if (n_out != -1) {
+    CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out))
+        << " in operator " << attrs.name;
+  }
+  return ReorgAttr<TShape, shape_is_none, shape_assign, true, shape_string>(
+      attrs, in_attrs, out_attrs, TShape());
+}
+
+template <int n_in, int n_out>
+inline bool ReorgType(const NodeAttrs &attrs,
+                      std::vector<int> *in_attrs,
+                      std::vector<int> *out_attrs) {
+  if (n_in != -1) {
+    CHECK_EQ(in_attrs->size(), static_cast<size_t>(n_in))
+        << " in operator " << attrs.name;
+  }
+  if (n_out != -1) {
+    CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out))
+        << " in operator " << attrs.name;
+  }
+  return ReorgAttr<int, type_is_none, type_assign, true, type_string>(
+      attrs, in_attrs, out_attrs, -1);
+}
+
+struct ReorgParam : public dmlc::Parameter<ReorgParam> {
+  int stride;
+
+  DMLC_DECLARE_PARAMETER(ReorgParam) {
+    DMLC_DECLARE_FIELD(stride).set_default(1).describe("Stride value");
+  }
+};
+}  // namespace top
+}  // namespace nnvm
+#endif  // NNVM_TOP_VISION_YOLO2_REORG_H_
diff --git a/nnvm/tests/ci_build/Dockerfile.gpu b/nnvm/tests/ci_build/Dockerfile.gpu
index bde32322c..2ee5ed04e 100644
--- a/nnvm/tests/ci_build/Dockerfile.gpu
+++ b/nnvm/tests/ci_build/Dockerfile.gpu
@@ -41,6 +41,9 @@ RUN bash /install/ubuntu_install_coreml.sh
 COPY install/ubuntu_install_keras.sh /install/ubuntu_install_keras.sh
 RUN bash /install/ubuntu_install_keras.sh
 
+COPY install/ubuntu_install_darknet.sh /install/ubuntu_install_darknet.sh
+RUN bash /install/ubuntu_install_darknet.sh
+
 RUN pip install Pillow
 
 # Environment variables
diff --git a/nnvm/tests/ci_build/install/ubuntu_install_darknet.sh b/nnvm/tests/ci_build/install/ubuntu_install_darknet.sh
new file mode 100644
index 000000000..f5e0c2791
--- /dev/null
+++ b/nnvm/tests/ci_build/install/ubuntu_install_darknet.sh
@@ -0,0 +1,4 @@
+# install the necessary dependencies: cffi, opencv
+wget 'https://github.com/siju-samuel/darknet/blob/master/lib/libdarknet.so?raw=true' -O libdarknet.so
+pip2 install opencv-python cffi
+pip3 install opencv-python cffi
diff --git a/nnvm/tests/python/frontend/darknet/test_forward.py b/nnvm/tests/python/frontend/darknet/test_forward.py
new file mode 100644
index 000000000..ad28c49c0
--- /dev/null
+++ b/nnvm/tests/python/frontend/darknet/test_forward.py
@@ -0,0 +1,257 @@
+"""
+Compile Darknet Models
+=====================
+This article is a test script to test darknet models with NNVM.
+All the required models and libraries will be downloaded from the internet
+by the script.
+"""
+import os
+import requests
+import numpy as np
+from nnvm import frontend
+from nnvm.testing.darknet import __darknetffi__
+import nnvm.compiler
+import tvm
+import sys
+import urllib
+if sys.version_info >= (3,):
+    import urllib.request as urllib2
+else:
+    import urllib2
+
+def _download(url, path, overwrite=False, sizecompare=False):
+    ''' Download from internet'''
+    if os.path.isfile(path) and not overwrite:
+        if sizecompare:
+            file_size = os.path.getsize(path)
+            res_head = requests.head(url)
+            res_get = requests.get(url, stream=True)
+            if 'Content-Length' not in res_head.headers:
+                res_get = urllib2.urlopen(url)
+            urlfile_size = int(res_get.headers['Content-Length'])
+            if urlfile_size != file_size:
+                print("exist file got corrupted, downloading", path, " file freshly")
+                _download(url, path, True, False)
+                return
+        print('File {} exists, skip.'.format(path))
+        return
+    print('Downloading from url {} to {}'.format(url, path))
+    try:
+        urllib.request.urlretrieve(url, path)
+        print('')
+    except AttributeError:
+        # Python 2 fallback: urllib.request does not exist.
+        urllib.urlretrieve(url, path)
+
+DARKNET_LIB = 'libdarknet.so'
+DARKNETLIB_URL = 'https://github.com/siju-samuel/darknet/blob/master/lib/' \
+                                    + DARKNET_LIB + '?raw=true'
+_download(DARKNETLIB_URL, DARKNET_LIB)
+LIB = __darknetffi__.dlopen('./' + DARKNET_LIB)
+
+def test_forward(net):
+    '''Test network with given input image on both darknet and tvm'''
+    def get_darknet_output(net, img):
+        return LIB.network_predict_image(net, img)
+
+    def get_tvm_output(net, img):
+        '''Compute TVM output'''
+        dtype = 'float32'
+        batch_size = 1
+        sym, params = frontend.darknet.from_darknet(net, dtype)
+        data = np.empty([batch_size, img.c, img.h, img.w], dtype)
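+        # darknet exposes the image as a flat CHW float buffer; copy it
+        # element-wise into the NCHW input tensor.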
+        i = 0
+        for c in range(img.c):
+            for h in range(img.h):
+                for k in range(img.w):
+                    data[0][c][h][k] = img.data[i]
+                    i = i + 1
+
+        target = 'llvm'
+        shape_dict = {'data': data.shape}
+        graph, library, params = nnvm.compiler.build(sym, target, shape_dict, dtype, params=params)
+        # execute on TVM
+        from tvm.contrib import graph_runtime
+        ctx = tvm.cpu(0)
+        m = graph_runtime.create(graph, library, ctx)
+        # set inputs
+        m.set_input('data', tvm.nd.array(data.astype(dtype)))
+        m.set_input(**params)
+        m.run()
+        # get outputs
+        out_shape = (net.outputs,)
+        tvm_out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).asnumpy()
+        return tvm_out
+
+    test_image = 'dog.jpg'
+    img_url = 'https://github.com/siju-samuel/darknet/blob/master/data/' + test_image + '?raw=true'
+    _download(img_url, test_image)
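+    # letterbox_image scales the image to net.w x net.h, preserving the aspect
+    # ratio and padding the border with a constant value.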
+    img = LIB.letterbox_image(LIB.load_image_color(test_image.encode('utf-8'), 0, 0), net.w, net.h)
+    darknet_output = get_darknet_output(net, img)
+    darknet_out = np.zeros(net.outputs, dtype='float32')
+    for i in range(net.outputs):
+        darknet_out[i] = darknet_output[i]
+    tvm_out = get_tvm_output(net, img)
+    np.testing.assert_allclose(darknet_out, tvm_out, rtol=1e-3, atol=1e-3)
+
+def test_forward_extraction():
+    '''test extraction model'''
+    model_name = 'extraction'
+    cfg_name = model_name + '.cfg'
+    weights_name = model_name + '.weights'
+    cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true'
+    weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true'
+    _download(cfg_url, cfg_name)
+    _download(weights_url, weights_name)
+    net = LIB.load_network(cfg_name.encode('utf-8'), weights_name.encode('utf-8'), 0)
+    test_forward(net)
+    LIB.free_network(net)
+
+def test_forward_alexnet():
+    '''test alexnet model'''
+    model_name = 'alexnet'
+    cfg_name = model_name + '.cfg'
+    weights_name = model_name + '.weights'
+    cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true'
+    weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true'
+    _download(cfg_url, cfg_name)
+    _download(weights_url, weights_name)
+    net = LIB.load_network(cfg_name.encode('utf-8'), weights_name.encode('utf-8'), 0)
+    test_forward(net)
+    LIB.free_network(net)
+
+def test_forward_resnet50():
+    '''test resnet50 model'''
+    model_name = 'resnet50'
+    cfg_name = model_name + '.cfg'
+    weights_name = model_name + '.weights'
+    cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true'
+    weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true'
+    _download(cfg_url, cfg_name)
+    _download(weights_url, weights_name)
+    net = LIB.load_network(cfg_name.encode('utf-8'), weights_name.encode('utf-8'), 0)
+    test_forward(net)
+    LIB.free_network(net)
+
+def test_forward_yolo():
+    '''test yolo model'''
+    model_name = 'yolo'
+    cfg_name = model_name + '.cfg'
+    weights_name = model_name + '.weights'
+    cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true'
+    weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true'
+    _download(cfg_url, cfg_name)
+    _download(weights_url, weights_name)
+    net = LIB.load_network(cfg_name.encode('utf-8'), weights_name.encode('utf-8'), 0)
+    test_forward(net)
+    LIB.free_network(net)
+
+def test_forward_convolutional():
+    '''test convolutional layer'''
+    net = LIB.make_network(1)
+    layer = LIB.make_convolutional_layer(1, 224, 224, 3, 32, 1, 3, 2, 0, 1, 0, 0, 0, 0)
+    net.layers[0] = layer
+    net.w = net.h = 224
+    LIB.resize_network(net, 224, 224)
+    test_forward(net)
+    LIB.free_network(net)
+
+def test_forward_dense():
+    '''test fully connected layer'''
+    net = LIB.make_network(1)
+    layer = LIB.make_connected_layer(1, 75, 20, 1, 0, 0)
+    net.layers[0] = layer
+    net.w = net.h = 5
+    LIB.resize_network(net, 5, 5)
+    test_forward(net)
+    LIB.free_network(net)
+
+def test_forward_maxpooling():
+    '''test maxpooling layer'''
+    net = LIB.make_network(1)
+    layer = LIB.make_maxpool_layer(1, 224, 224, 3, 2, 2, 0)
+    net.layers[0] = layer
+    net.w = net.h = 224
+    LIB.resize_network(net, 224, 224)
+    test_forward(net)
+    LIB.free_network(net)
+
+def test_forward_avgpooling():
+    '''test average pooling layer'''
+    net = LIB.make_network(1)
+    layer = LIB.make_avgpool_layer(1, 224, 224, 3)
+    net.layers[0] = layer
+    net.w = net.h = 224
+    LIB.resize_network(net, 224, 224)
+    test_forward(net)
+    LIB.free_network(net)
+
+def test_forward_batch_norm():
+    '''test batch normalization layer'''
+    net = LIB.make_network(1)
+    layer = LIB.make_convolutional_layer(1, 224, 224, 3, 32, 1, 3, 2, 0, 1, 1, 0, 0, 0)
+    for i in range(32):
+        layer.rolling_mean[i] = np.random.rand(1)
+        layer.rolling_variance[i] = np.random.rand(1)
+    net.layers[0] = layer
+    net.w = net.h = 224
+    LIB.resize_network(net, 224, 224)
+    test_forward(net)
+    LIB.free_network(net)
+
+def test_forward_shortcut():
+    '''test shortcut layer'''
+    net = LIB.make_network(3)
+    layer_1 = LIB.make_convolutional_layer(1, 224, 224, 3, 32, 1, 3, 2, 0, 1, 0, 0, 0, 0)
+    layer_2 = LIB.make_convolutional_layer(1, 111, 111, 32, 32, 1, 1, 1, 0, 1, 0, 0, 0, 0)
+    layer_3 = LIB.make_shortcut_layer(1, 0, 111, 111, 32, 111, 111, 32)
+    layer_3.activation = 1
+    net.layers[0] = layer_1
+    net.layers[1] = layer_2
+    net.layers[2] = layer_3
+    net.w = net.h = 224
+    LIB.resize_network(net, 224, 224)
+    test_forward(net)
+    LIB.free_network(net)
+
+def test_forward_reorg():
+    '''test reorg layer'''
+    net = LIB.make_network(2)
+    layer_1 = LIB.make_convolutional_layer(1, 222, 222, 3, 32, 1, 3, 2, 0, 1, 0, 0, 0, 0)
+    layer_2 = LIB.make_reorg_layer(1, 110, 110, 32, 2, 0, 0, 0)
+    net.layers[0] = layer_1
+    net.layers[1] = layer_2
+    net.w = net.h = 222
+    LIB.resize_network(net, 222, 222)
+    test_forward(net)
+    LIB.free_network(net)
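+
+# Shape-level sketch (helper added for illustration only; darknet's reorg uses
+# its own element ordering, so this captures just the shape transform above):
+def _reorg_ref_shape(shape, stride):
+    '''(N, C, H, W) -> (N, C*stride*stride, H//stride, W//stride)'''
+    batch, channels, height, width = shape
+    assert height % stride == 0 and width % stride == 0, 'stride must divide H and W'
+    return (batch, channels * stride * stride, height // stride, width // stride)
+# e.g. _reorg_ref_shape((1, 32, 110, 110), 2) == (1, 128, 55, 55)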
+
+def test_forward_region():
+    '''test region layer'''
+    net = LIB.make_network(2)
+    layer_1 = LIB.make_convolutional_layer(1, 224, 224, 3, 8, 1, 3, 2, 0, 1, 0, 0, 0, 0)
+    layer_2 = LIB.make_region_layer(1, 111, 111, 2, 2, 1)
+    layer_2.softmax = 1
+    net.layers[0] = layer_1
+    net.layers[1] = layer_2
+    net.w = net.h = 224
+    LIB.resize_network(net, 224, 224)
+    test_forward(net)
+    LIB.free_network(net)
+
+if __name__ == '__main__':
+    test_forward_resnet50()
+    test_forward_alexnet()
+    test_forward_extraction()
+    test_forward_yolo()
+    test_forward_convolutional()
+    test_forward_maxpooling()
+    test_forward_avgpooling()
+    test_forward_batch_norm()
+    test_forward_shortcut()
+    test_forward_dense()
+    test_forward_reorg()
+    test_forward_region()
diff --git a/nnvm/tutorials/from_darknet.py b/nnvm/tutorials/from_darknet.py
new file mode 100644
index 000000000..b10327168
--- /dev/null
+++ b/nnvm/tutorials/from_darknet.py
@@ -0,0 +1,227 @@
+"""
+Tutorial for running Yolo-V2 in Darknet Models
+=====================
+**Author**: `Siju Samuel <https://siju-samuel.github.io/>`_
+
+This article is an introductory tutorial to deploy darknet models with NNVM.
+
+All the required models and libraries will be downloaded from the internet
+
+by the script.
+
+This script runs the YOLO-V2 Model with the bounding boxes
+
+Darknet parsing have dependancy with CFFI and CV2 library
+
+Please install CFFI and CV2 before executing this script
+
+pip install cffi
+
+pip install opencv-python
+"""
+import nnvm
+import nnvm.frontend.darknet
+from nnvm.testing.darknet import __darknetffi__
+import matplotlib.pyplot as plt
+import numpy as np
+import tvm
+import os
+import sys
+import time
+import urllib
+import requests
+if sys.version_info >= (3,):
+    import urllib.request as urllib2
+else:
+    import urllib2
+
+######################################################################
+# Set the parameters here.
+# Supported models: alexnet, resnet50, resnet152, extraction, yolo.
+######################################################################
+model_name = 'yolo'
+test_image = 'dog.jpg'
+target = 'llvm'
+ctx = tvm.cpu(0)
+######################################################################
+
+def dlProgress(count, block_size, total_size):
+    """Show the download progress."""
+    global start_time
+    if count == 0:
+        start_time = time.time()
+        return
+    duration = time.time() - start_time
+    progress_size = int(count * block_size)
+    speed = int(progress_size / (1024 * duration))
+    percent = int(count * block_size * 100 / total_size)
+    sys.stdout.write("\r...%d%%, %d MB, %d KB/s, %d seconds passed" %
+          (percent, progress_size / (1024 * 1024), speed, duration))
+    sys.stdout.flush()
+
+def download(url, path, overwrite=False, sizecompare=False):
+    """Download a file from the internet.
+
+    Parameters
+    ----------
+    url : str
+        URL to download from.
+    path : str
+        Local path to save the file to.
+    overwrite : bool
+        Re-download even if the file already exists.
+    sizecompare : bool
+        Compare the local file size with the remote Content-Length and
+        re-download if they differ.
+    """
+    if os.path.isfile(path) and not overwrite:
+        if sizecompare:
+            file_size = os.path.getsize(path)
+            res_head = requests.head(url)
+            res_get = requests.get(url, stream=True)
+            if 'Content-Length' not in res_head.headers:
+                res_get = urllib2.urlopen(url)
+            urlfile_size = int(res_get.headers['Content-Length'])
+            if urlfile_size != file_size:
+                print("existing file is corrupted; downloading", path, "again")
+                download(url, path, True, False)
+                return
+        print('File {} exists, skip.'.format(path))
+        return
+    print('Downloading from url {} to {}'.format(url, path))
+    try:
+        urllib.request.urlretrieve(url, path, reporthook=dlProgress)
+        print('')
+    except AttributeError:
+        # Python 2 fallback: urllib.request does not exist.
+        urllib.urlretrieve(url, path, reporthook=dlProgress)
+
+######################################################################
+# Prepare cfg and weights file
+# Pretrained models are available at https://pjreddie.com/darknet/imagenet/
+# --------------------------------------------------------------------
+# Download the cfg and weights files on the first run.
+
+cfg_name = model_name + '.cfg'
+weights_name = model_name + '.weights'
+cfg_url = 'https://github.com/siju-samuel/darknet/blob/master/cfg/' + \
+            cfg_name + '?raw=true'
+weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true'
+
+download(cfg_url, cfg_name)
+download(weights_url, weights_name)
+
+######################################################################
+# Download and Load darknet library
+# ---------------------------------
+
+darknet_lib = 'libdarknet.so'
+darknetlib_url = 'https://github.com/siju-samuel/darknet/blob/master/lib/' + \
+                        darknet_lib + '?raw=true'
+download(darknetlib_url, darknet_lib)
+
+# if the darknet library is missing, exit normally.
+if not os.path.isfile('./' + darknet_lib):
+    exit(0)
+
+darknet_lib = __darknetffi__.dlopen('./' + darknet_lib)
+cfg = "./" + str(cfg_name)
+weights = "./" + str(weights_name)
+net = darknet_lib.load_network(cfg.encode('utf-8'), weights.encode('utf-8'), 0)
+dtype = 'float32'
+batch_size = 1
+print("Converting darknet to nnvm symbols...")
+sym, params = nnvm.frontend.darknet.from_darknet(net, dtype)
+
+######################################################################
+# Compile the model on NNVM
+# --------------------------------------------------------------------
+# compile the model
+data = np.empty([batch_size, net.c, net.h, net.w], dtype)
+shape = {'data': data.shape}
+print("Compiling the model...")
+with nnvm.compiler.build_config(opt_level=2):
+    graph, lib, params = nnvm.compiler.build(sym, target, shape, dtype, params)
+
+#####################################################################
+# Save the json
+# --------------------------------------------------------------------
+def save_lib():
+    # Save the graph, params and .so to the current directory.
+    print("Saving the compiled output...")
+    path_name = 'nnvm_darknet_' + model_name
+    path_lib = path_name + '_deploy_lib.so'
+    lib.export_library(path_lib)
+    with open(path_name + "_deploy_graph.json", "w") as fo:
+        fo.write(graph.json())
+    with open(path_name + "_deploy_param.params", "wb") as fo:
+        fo.write(nnvm.compiler.save_param_dict(params))
+# uncomment to save the compiled module, graph and params to disk
+#save_lib()
+
+######################################################################
+# Load a test image
+# --------------------------------------------------------------------
+print("Loading the test image...")
+img_url = 'https://github.com/siju-samuel/darknet/blob/master/data/' + \
+            test_image   +'?raw=true'
+download(img_url, test_image)
+
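+# load_image returns the test image as a CHW float array sized to the
+# network input (net.w x net.h).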
+data = nnvm.testing.darknet.load_image(test_image, net.w, net.h)
+
+######################################################################
+# Execute on TVM
+# --------------------------------------------------------------------
+# The process is no different from other examples.
+from tvm.contrib import graph_runtime
+
+m = graph_runtime.create(graph, lib, ctx)
+
+# set inputs
+m.set_input('data', tvm.nd.array(data.astype(dtype)))
+m.set_input(**params)
+# execute
+print("Running the test image...")
+
+m.run()
+# get outputs
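+# net.outputs is the flattened size of the final layer's output tensor.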
+out_shape = (net.outputs,)
+tvm_out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).asnumpy()
+
+# run the detection and draw the bounding boxes
+thresh = 0.24
+hier_thresh = 0.5
+img = nnvm.testing.darknet.load_image_color(test_image)
+_, im_h, im_w = img.shape
+probs = []
+boxes = []
+region_layer = net.layers[net.n - 1]
+boxes, probs = nnvm.testing.yolo2_detection.get_region_boxes(region_layer, im_w, im_h, net.w, net.h,
+                       thresh, probs, boxes, 1, tvm_out)
+
+boxes, probs = nnvm.testing.yolo2_detection.do_nms_sort(boxes, probs,
+                       region_layer.w*region_layer.h*region_layer.n, region_layer.classes, 0.3)
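+# get_region_boxes decodes the raw region-layer output into candidate boxes
+# and per-class probabilities; do_nms_sort then drops overlapping boxes whose
+# IoU exceeds 0.3, keeping the highest-scoring box in each cluster.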
+
+coco_name = 'coco.names'
+coco_url = 'https://github.com/siju-samuel/darknet/blob/master/data/' + coco_name + '?raw=true'
+font_name = 'arial.ttf'
+font_url = 'https://github.com/siju-samuel/darknet/blob/master/data/' + font_name + '?raw=true'
+download(coco_url, coco_name)
+download(font_url, font_name)
+
+with open(coco_name) as f:
+    content = f.readlines()
+
+names = [x.strip() for x in content]
+
+nnvm.testing.yolo2_detection.draw_detections(img, region_layer.w*region_layer.h*region_layer.n,
+                 thresh, boxes, probs, names, region_layer.classes)
+plt.imshow(img.transpose(1, 2, 0))
+plt.show()
-- 
GitLab