From edac6a8dac9c4222ff9a719a4b2ec2fc06031655 Mon Sep 17 00:00:00 2001
From: Tianqi Chen <tqchen@users.noreply.github.com>
Date: Sat, 14 Apr 2018 19:25:26 -0700
Subject: [PATCH] Refactor, refactor code structure, fix pynq rpc (#29)

---
 .../resnet18/pynq/imagenet_predict.py         |  16 +-
 vta/python/vta/__init__.py                    |   5 +-
 vta/python/vta/exec/rpc_server.py             |  10 +-
 vta/python/vta/top/__init__.py                |   5 +
 vta/python/vta/{ => top}/arm_conv2d.py        |  10 +-
 vta/python/vta/{ => top}/vta_conv2d.py        |   7 +-
 .../integration/test_benchmark_topi_conv2d.py | 155 ++++++++++++++++++
 vta/tests/python/pynq/test_benchmark_topi.py  | 146 -----------------
 8 files changed, 185 insertions(+), 169 deletions(-)
 create mode 100644 vta/python/vta/top/__init__.py
 rename vta/python/vta/{ => top}/arm_conv2d.py (97%)
 rename vta/python/vta/{ => top}/vta_conv2d.py (98%)
 create mode 100644 vta/tests/python/integration/test_benchmark_topi_conv2d.py
 delete mode 100644 vta/tests/python/pynq/test_benchmark_topi.py

diff --git a/vta/examples/resnet18/pynq/imagenet_predict.py b/vta/examples/resnet18/pynq/imagenet_predict.py
index eb660ea12..ae8d49017 100644
--- a/vta/examples/resnet18/pynq/imagenet_predict.py
+++ b/vta/examples/resnet18/pynq/imagenet_predict.py
@@ -37,10 +37,10 @@ remote = rpc.connect(host, port)
 vta.program_fpga(remote, BITSTREAM_FILE)
 
 if verbose:
-    logging.basicConfig(level=logging.INFO)
+    logging.basicConfig(level=logging.DEBUG)
 
-# Change to -device=vta-cpu to run cpu only inference.
-target = "llvm -device=vta"
+# Change to -device=vtacpu to run cpu only inference.
+target = tvm.target.create("llvm -device=vta")
 target_host = "llvm -mtriple=armv7-none-linux-gnueabihf -mcpu=cortex-a9 -mattr=+neon"
 
 synset = eval(open(os.path.join(CATEG_FILE)).read())
@@ -109,7 +109,7 @@ dtype = "float32"
 sym = vta.graph.remove_stochastic(sym)
 sym = vta.graph.clean_cast(sym)
 sym = vta.graph.clean_conv_fuse(sym)
-if "vta" in target:
+if target.device_name == "vta":
     sym = vta.graph.pack(sym, shape_dict, factor)
 
 graph_attr.set_shape_inputs(sym, shape_dict)
@@ -118,10 +118,10 @@ graph_attr.set_dtype_inputs(sym, dtype_dict)
 sym = sym.apply("InferType")
 
 with nnvm.compiler.build_config(opt_level=3):
-    if "vta" not in target:
+    if target.device_name != "vta":
         graph, lib, params = nnvm.compiler.build(
-            sym, target, shape_dict, dtype_dict,
-            params=params, target_host=target_host)
+            sym, target_host, shape_dict, dtype_dict,
+            params=params)
     else:
         with vta.build_config():
             graph, lib, params = nnvm.compiler.build(
@@ -133,7 +133,7 @@ temp = util.tempdir()
 lib.save(temp.relpath("graphlib.o"))
 remote.upload(temp.relpath("graphlib.o"))
 lib = remote.load_module("graphlib.o")
-ctx = remote.ext_dev(0) if "vta" in target else remote.cpu(0)
+ctx = remote.ext_dev(0) if target.device_name == "vta" else remote.cpu(0)
 
 print("Build complete...")
 
diff --git a/vta/python/vta/__init__.py b/vta/python/vta/__init__.py
index d0a3279f1..4be16ccfb 100644
--- a/vta/python/vta/__init__.py
+++ b/vta/python/vta/__init__.py
@@ -3,11 +3,12 @@ from __future__ import absolute_import as _abs
 
 
 from .environment import get_env, Environment
-from . import arm_conv2d, vta_conv2d
-from .build_module import build_config, lower, build
 from .rpc_client import reconfig_runtime, program_fpga
 
+
 try:
+    from . import top
+    from .build_module import build_config, lower, build
     from . import graph
 except (ImportError, RuntimeError):
     pass
diff --git a/vta/python/vta/exec/rpc_server.py b/vta/python/vta/exec/rpc_server.py
index 014b40564..f3db38c92 100644
--- a/vta/python/vta/exec/rpc_server.py
+++ b/vta/python/vta/exec/rpc_server.py
@@ -75,20 +75,20 @@ def server_start():
         pkg = PkgConfig(cfg, proj_root)
         # check if the configuration is already the same
         if os.path.isfile(cfg_path):
-            old_cfg = json.load(open(cfg_path))
+            old_cfg = json.loads(open(cfg_path, "r").read())
             if pkg.same_config(old_cfg):
-                logging.info("Skip reconfiguration because runtime config is the same")
+                logging.info("Skip reconfig_runtime due to same config.")
                 return
-        cflags += ["-O2", "-std=c++11"]
+        cflags = ["-O2", "-std=c++11"]
         cflags += pkg.cflags
         ldflags = pkg.ldflags
         lib_name = dll_path
-        source = env.pkg_config.lib_source
+        source = pkg.lib_source
         logging.info("Rebuild runtime: output=%s, cflags=%s, source=%s, ldflags=%s",
                      dll_path, str(cflags), str(source), str(ldflags))
         cc.create_shared(lib_name, source, cflags + ldflags)
         with open(cfg_path, "w") as outputfile:
-            json.dump(pkg.cfg_json, outputfile)
+            outputfile.write(pkg.cfg_json)
 
 
 def main():
diff --git a/vta/python/vta/top/__init__.py b/vta/python/vta/top/__init__.py
new file mode 100644
index 000000000..614ed2347
--- /dev/null
+++ b/vta/python/vta/top/__init__.py
@@ -0,0 +1,5 @@
+"""TVM TOPI connector, eventually most of these should go to TVM repo"""
+
+from .vta_conv2d import packed_conv2d, schedule_packed_conv2d
+from . import vta_conv2d
+from . import arm_conv2d
diff --git a/vta/python/vta/arm_conv2d.py b/vta/python/vta/top/arm_conv2d.py
similarity index 97%
rename from vta/python/vta/arm_conv2d.py
rename to vta/python/vta/top/arm_conv2d.py
index 9e46ee7f8..c959f1ee9 100644
--- a/vta/python/vta/arm_conv2d.py
+++ b/vta/python/vta/top/arm_conv2d.py
@@ -44,7 +44,7 @@ _SCHEDULES = [
     Im2ColPack(7, 4, 1, 16, False),
 ]
 
-@_get_schedule.register(["tcpu", "vta"])
+@_get_schedule.register(["vtacpu", "vta"])
 def _schedule_conv2d(wkl):
     if wkl not in _WORKLOADS:
         raise ValueError("no schedule for such workload: {}".format(wkl))
@@ -53,10 +53,10 @@ def _schedule_conv2d(wkl):
     return sch
 
 
-@conv2d.register(["tcpu", "vta"])
+@conv2d.register(["vtacpu", "vta"])
 def _declaration_conv2d(data, kernel, stride, padding, layout, out_dtype):
-    assert layout == 'NCHW', "only support NCHW convolution on tcpu"
-    assert data.shape[0].value == 1, "only support batch size=1 convolution on tcpu"
+    assert layout == 'NCHW', "only support NCHW convolution on vtacpu"
+    assert data.shape[0].value == 1, "only support batch size=1 convolution on vtacpu"
     wkl = _get_workload(data, kernel, stride, padding, out_dtype)
     sch = _get_schedule(wkl)
     return _SCH_TO_DECL_FUNC[type(sch)](data, kernel, stride, padding, out_dtype)
@@ -284,7 +284,7 @@ def _schedule_im2col_conv2d(s, data, data_pad, data_col, data_vec,
 
     return s
 
-@generic.schedule_conv2d_nchw.register(["tcpu", "vta"])
+@generic.schedule_conv2d_nchw.register(["vtacpu", "vta"])
 def schedule_conv2d(outs):
     """Create schedule for tensors"""
     s = tvm.create_schedule([x.op for x in outs])
diff --git a/vta/python/vta/vta_conv2d.py b/vta/python/vta/top/vta_conv2d.py
similarity index 98%
rename from vta/python/vta/vta_conv2d.py
rename to vta/python/vta/top/vta_conv2d.py
index 0baca7ba5..577eac8e1 100644
--- a/vta/python/vta/vta_conv2d.py
+++ b/vta/python/vta/top/vta_conv2d.py
@@ -1,4 +1,5 @@
 """Namespace for supporting packed_conv2d + ewise variant of nnvm."""
+from __future__ import absolute_import as _abs
 
 from collections import namedtuple
 
@@ -7,7 +8,7 @@ import tvm
 import topi
 
 from nnvm.top import registry as reg, OpPattern
-from . import environment as vta
+from ..environment import get_env
 
 
 Workload = namedtuple("Conv2DWorkload",
@@ -219,7 +220,7 @@ def schedule_packed_conv2d(outs):
     wrkld = _get_workload(data, pad_data, kernel, output)
 
     plan = _WL2PLAN[wrkld]
-    env = vta.get_env()
+    env = get_env()
 
     load_inp = load_wgt = load_out = store_out = env.dma_copy
     alu = env.alu
@@ -251,7 +252,7 @@ def schedule_packed_conv2d(outs):
 
     # tile
     oc_factor = (plan.oc_factor if plan.oc_factor
-                 else wrkld.out_filter // vta.BLOCK_OUT)
+                 else plan.out_filter // env.BLOCK_OUT)
     h_factor = (plan.h_factor if plan.h_factor else oshape[2])
     w_factor = (plan.w_factor if plan.w_factor else oshape[3])
 
diff --git a/vta/tests/python/integration/test_benchmark_topi_conv2d.py b/vta/tests/python/integration/test_benchmark_topi_conv2d.py
new file mode 100644
index 000000000..0a5edfdc7
--- /dev/null
+++ b/vta/tests/python/integration/test_benchmark_topi_conv2d.py
@@ -0,0 +1,155 @@
+"""Testing if we can generate code in topi style"""
+
+import tvm
+from tvm.contrib import util
+from tvm.contrib.pickle_memoize import memoize
+import topi
+import topi.testing
+import vta
+import vta.testing
+import numpy as np
+
+Workload = vta.top.vta_conv2d.Workload
+
+@tvm.tag_scope(tag=topi.tag.ELEMWISE)
+def my_clip(x, a_min, a_max):
+    """Unlike topi's current clip, put min and max into two stages."""
+    const_min = tvm.const(a_min, x.dtype)
+    const_max = tvm.const(a_max, x.dtype)
+    x = tvm.compute(x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA")
+    x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB")
+    return x
+
+
+def test_vta_conv2d():
+    def run_vta_conv2d(env, remote, key, batch_size, wl, profile=True):
+        data_shape = (batch_size, wl.in_filter // env.BLOCK_IN,
+                      wl.height, wl.width, env.BLOCK_IN)
+        kernel_shape = (wl.out_filter // env.BLOCK_OUT,
+                        wl.in_filter // env.BLOCK_IN,
+                        wl.hkernel, wl.wkernel,
+                        env.BLOCK_OUT, env.BLOCK_IN)
+        bias_shape = (wl.out_filter // env.BLOCK_OUT, 1, 1, env.BLOCK_OUT)
+
+
+        fout_height = (wl.height + 2 * wl.hpad - wl.hkernel) // wl.hstride + 1
+        fout_width = (wl.width + 2 * wl.wpad - wl.wkernel) // wl.wstride + 1
+        data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype)
+        kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
+        bias = tvm.placeholder(bias_shape, name="kernel", dtype=env.acc_dtype)
+
+        res_conv = vta.top.packed_conv2d(
+            data, kernel, padding=(wl.hpad, wl.wpad), strides=(wl.hstride, wl.wstride))
+        res = topi.right_shift(res_conv, 8)
+        res = topi.broadcast_add(res, bias)
+        res = my_clip(res, 0, 127)
+        res = topi.cast(res, "int8")
+
+        num_ops = fout_height * fout_width * wl.hkernel * wl.wkernel * wl.out_filter * wl.in_filter
+
+        a_shape = (batch_size, wl.in_filter, wl.height, wl.width)
+        w_shape = (wl.out_filter, wl.in_filter, wl.hkernel, wl.wkernel)
+        stride = (wl.hstride, wl.wstride)
+        data_dtype = data.dtype
+        acc_dtype = env.acc_dtype
+        assert wl.hpad == wl.wpad
+        padding = wl.hpad
+
+        @memoize("vta.tests.test_benchmark_topi.conv2d,verify_nhwc")
+        def get_ref_data():
+            a_np = (np.random.uniform(size=a_shape) * 4).astype(data_dtype)
+            w_np = (np.random.uniform(size=w_shape) * 4).astype(data_dtype)
+            a_np = np.abs(a_np)
+            w_np = np.abs(w_np)
+            b_np = topi.testing.conv2d_nchw_python(
+                a_np.astype(acc_dtype), w_np.astype(acc_dtype), stride, padding).astype(acc_dtype)
+            return a_np, w_np, b_np
+
+
+        def verify(s, check_correctness):
+            mod = vta.build(s, [data, kernel, bias, res], "ext_dev",
+                            env.target_host, name="conv2d")
+            temp = util.tempdir()
+
+            mod.save(temp.relpath("conv2d.o"))
+            remote.upload(temp.relpath("conv2d.o"))
+            f = remote.load_module("conv2d.o")
+            # verify
+            ctx = remote.ext_dev(0)
+            # Data in original format
+            data_orig, kernel_orig, res_ref = get_ref_data()
+            bias_orig = (np.random.uniform(size=(wl.out_filter,)) * 4).astype("int32")
+            bias_orig = np.abs(bias_orig)
+
+            data_packed = data_orig.reshape(
+                batch_size, wl.in_filter // env.BLOCK_IN, env.BLOCK_IN,
+                wl.height, wl.width).transpose((0, 1, 3, 4, 2))
+            kernel_packed = kernel_orig.reshape(
+                wl.out_filter // env.BLOCK_OUT, env.BLOCK_OUT,
+                wl.in_filter // env.BLOCK_IN, env.BLOCK_IN,
+                wl.hkernel, wl.wkernel).transpose((0, 2, 4, 5, 1, 3))
+            bias_packed = bias_orig.reshape(
+                wl.out_filter // env.BLOCK_OUT, 1, 1, env.BLOCK_OUT)
+            res_shape = topi.util.get_const_tuple(res.shape)
+
+            res_np = np.zeros(res_shape).astype(res.dtype)
+            data_arr = tvm.nd.array(data_packed, ctx)
+            kernel_arr = tvm.nd.array(kernel_packed, ctx)
+            bias_arr = tvm.nd.array(bias_packed, ctx)
+            res_arr = tvm.nd.array(res_np, ctx)
+            time_f = f.time_evaluator("conv2d", ctx, number=5)
+            cost = time_f(data_arr, kernel_arr, bias_arr, res_arr)
+            res_unpack = res_arr.asnumpy().transpose(
+                (0, 1, 4, 2, 3)).reshape(batch_size, wl.out_filter, fout_height, fout_width)
+            if check_correctness:
+                assert wl.hpad == wl.wpad
+                stride = (wl.hstride, wl.wstride)
+                padding = wl.hpad
+                res_ref = res_ref >> 8
+                res_ref += bias_orig.reshape(wl.out_filter, 1, 1)
+                res_ref = np.clip(res_ref, 0, 127).astype("int8")
+                np.testing.assert_allclose(res_unpack, res_ref)
+            return cost
+
+        def conv_normal(print_ir):
+            print("----- CONV2D End-to-End Test-------")
+            with vta.build_config():
+                s = vta.top.schedule_packed_conv2d([res])
+                if print_ir:
+                    print(vta.lower(s, [data, kernel, bias, res], simple_mode=True))
+            cost = verify(s, True)
+            gops = (num_ops / cost.mean) / float(10 ** 9)
+            print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
+
+        conv_normal(False)
+
+    def _run(env, remote):
+        # ResNet18 workloads
+        resnet = {
+            # Workloads of resnet18 on imagenet
+            0: Workload(224, 224, 16, 64, 7, 7, 3, 3, 2, 2),
+            1: Workload(56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
+            2: Workload(56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
+            3: Workload(56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
+            4: Workload(56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
+            5: Workload(28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
+            6: Workload(28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
+            7: Workload(28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
+            8: Workload(14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
+            9: Workload(14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
+            10: Workload(14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
+            11: Workload(7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
+        }
+
+        batch_size = 1
+        for i in range(0, len(resnet)):
+            wl = resnet[i]
+            key = "resnet-cfg[%d]" % i
+            print("key=%s" % key)
+            print(wl)
+            run_vta_conv2d(env, remote, key, batch_size, wl)
+    vta.testing.run(_run)
+
+
+if __name__ == "__main__":
+    test_vta_conv2d()
diff --git a/vta/tests/python/pynq/test_benchmark_topi.py b/vta/tests/python/pynq/test_benchmark_topi.py
deleted file mode 100644
index 3e2d19a67..000000000
--- a/vta/tests/python/pynq/test_benchmark_topi.py
+++ /dev/null
@@ -1,146 +0,0 @@
-"""Testing if we can generate code in topi style"""
-
-import topi
-import tvm
-from tvm.contrib import util, rpc
-import vta
-from vta import vta_conv2d
-import numpy as np
-import mxnet as mx
-
-Workload = vta_conv2d.Workload
-
-@tvm.tag_scope(tag=topi.tag.ELEMWISE)
-def my_clip(x, a_min, a_max):
-    """Unlike topi's current clip, put min and max into two stages."""
-    const_min = tvm.const(a_min, x.dtype)
-    const_max = tvm.const(a_max, x.dtype)
-    x = tvm.compute(x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA")
-    x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB")
-    return x
-
-host = "pynq"
-port = 9091
-target = "llvm -target=armv7-none-linux-gnueabihf -mattr=+neon"
-print_ir = False
-
-
-def test_vta_conv2d(key, batch_size, wl, profile=True):
-    env = vta.get_env()
-    data_shape = (batch_size, wl.in_filter // env.BLOCK_IN,
-                  wl.height, wl.width, env.BLOCK_IN)
-    kernel_shape = (wl.out_filter // env.BLOCK_OUT,
-                    wl.in_filter // env.BLOCK_IN,
-                    wl.hkernel, wl.wkernel,
-                    env.BLOCK_OUT, env.BLOCK_IN)
-    bias_shape = (wl.out_filter // env.BLOCK_OUT, 1, 1, env.BLOCK_OUT)
-
-
-    fout_height = (wl.height + 2 * wl.hpad - wl.hkernel) // wl.hstride + 1
-    fout_width = (wl.width + 2 * wl.wpad - wl.wkernel) // wl.wstride + 1
-    data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype)
-    kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
-    bias = tvm.placeholder(bias_shape, name="kernel", dtype=env.acc_dtype)
-
-    res_conv = vta_conv2d.packed_conv2d(
-        data, kernel, padding=(wl.hpad, wl.wpad), strides=(wl.hstride, wl.wstride))
-    res = topi.right_shift(res_conv, 8)
-    res = topi.broadcast_add(res, bias)
-    res = my_clip(res, 0, 127)
-    res = topi.cast(res, "int8")
-
-    num_ops = fout_height * fout_width * wl.hkernel * wl.wkernel * wl.out_filter * wl.in_filter
-
-    def verify(s, check_correctness):
-        mod = tvm.build(s, [data, kernel, bias, res], "ext_dev", target, name="conv2d")
-        temp = util.tempdir()
-        remote = rpc.connect(host, port)
-
-        mod.save(temp.relpath("conv2d.o"))
-        remote.upload(temp.relpath("conv2d.o"))
-        f = remote.load_module("conv2d.o")
-        # verify
-        ctx = remote.ext_dev(0)
-        # Data in original format
-        data_orig = (np.random.uniform(
-            size=(batch_size, wl.in_filter, wl.height, wl.width)) * 4).astype(data.dtype)
-        kernel_orig = (np.random.uniform(
-            size=(wl.out_filter, wl.in_filter, wl.hkernel, wl.wkernel)) * 4).astype(kernel.dtype)
-        bias_orig = (np.random.uniform(size=(wl.out_filter,)) * 4).astype("int32")
-
-        data_orig = np.abs(data_orig)
-        kernel_orig = np.abs(kernel_orig)
-        bias_orig = np.abs(bias_orig)
-
-        data_packed = data_orig.reshape(
-            batch_size, wl.in_filter // env.BLOCK_IN, env.BLOCK_IN,
-            wl.height, wl.width).transpose((0, 1, 3, 4, 2))
-        kernel_packed = kernel_orig.reshape(
-            wl.out_filter // env.BLOCK_OUT, env.BLOCK_OUT,
-            wl.in_filter // env.BLOCK_IN, env.BLOCK_IN,
-            wl.hkernel, wl.wkernel).transpose((0, 2, 4, 5, 1, 3))
-        bias_packed = bias_orig.reshape(
-            wl.out_filter // env.BLOCK_OUT, 1, 1, env.BLOCK_OUT)
-        res_shape = topi.util.get_const_tuple(res.shape)
-
-        res_np = np.zeros(res_shape).astype(res.dtype)
-        data_arr = tvm.nd.array(data_packed, ctx)
-        kernel_arr = tvm.nd.array(kernel_packed, ctx)
-        bias_arr = tvm.nd.array(bias_packed, ctx)
-        res_arr = tvm.nd.array(res_np, ctx)
-        time_f = f.time_evaluator("conv2d", ctx, number=10)
-        cost = time_f(data_arr, kernel_arr, bias_arr, res_arr)
-        res_unpack = res_arr.asnumpy().transpose(
-            (0, 1, 4, 2, 3)).reshape(batch_size, wl.out_filter, fout_height, fout_width)
-        if check_correctness:
-            res_ref = mx.nd.Convolution(
-                mx.nd.array(data_orig.astype(env.acc_dtype), mx.cpu(0)),
-                mx.nd.array(kernel_orig.astype(env.acc_dtype), mx.cpu(0)),
-                stride=(wl.hstride, wl.wstride),
-                kernel=(wl.hkernel, wl.wkernel),
-                num_filter=wl.out_filter,
-                no_bias=True,
-                pad=(wl.hpad, wl.wpad)).asnumpy().astype(env.acc_dtype)
-            res_ref = res_ref >> 8
-            res_ref += bias_orig.reshape(wl.out_filter, 1, 1)
-            res_ref = np.clip(res_ref, 0, 127).astype("int8")
-            np.testing.assert_allclose(res_unpack, res_ref)
-            print("Correctness check pass...")
-        return cost
-
-    def conv_normal(print_ir):
-        print("----- CONV2D End-to-End Test-------")
-        with vta.build_config():
-            s = vta_conv2d.schedule_packed_conv2d([res])
-            if print_ir:
-                print(vta.lower(s, [data, kernel, bias, res], simple_mode=True))
-            cost = verify(s, True)
-        gops = (num_ops / cost.mean) / float(10 ** 9)
-        print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
-
-    conv_normal(print_ir)
-
-# ResNet18 workloads
-resnet = {
-    # Workloads of resnet18 on imagenet
-    0: Workload(224, 224, 16, 64, 7, 7, 3, 3, 2, 2),
-    1: Workload(56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
-    2: Workload(56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
-    3: Workload(56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
-    4: Workload(56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
-    5: Workload(28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
-    6: Workload(28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
-    7: Workload(28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
-    8: Workload(14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
-    9: Workload(14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
-    10: Workload(14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
-    11: Workload(7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
-}
-
-batch_size = 1
-for i in range(0, len(resnet)):
-    wl = resnet[i]
-    key = "resnet-cfg[%d]" % i
-    print "key=%s" % key
-    print wl
-    test_vta_conv2d(key, batch_size, wl)
-- 
GitLab