From 8c9758b6064bc2ebb2ea82a86b31701f03e1067f Mon Sep 17 00:00:00 2001
From: Tianqi Chen <tqchen@users.noreply.github.com>
Date: Mon, 21 May 2018 12:23:31 -0700
Subject: [PATCH] Update Graph Support for Batching, Fix Swapping (#37)

* fix graph transform for batch dimension

* fix

* fix
---
 .../resnet18/pynq/imagenet_predict.py         | 25 ++++--
 vta/python/vta/graph.py                       | 76 +++++++++++--------
 vta/src/runtime.cc                            |  4 +-
 vta/src/sim/sim_driver.cc                     |  1 +
 4 files changed, 67 insertions(+), 39 deletions(-)

diff --git a/vta/examples/resnet18/pynq/imagenet_predict.py b/vta/examples/resnet18/pynq/imagenet_predict.py
index e4f82b175..554cceabd 100644
--- a/vta/examples/resnet18/pynq/imagenet_predict.py
+++ b/vta/examples/resnet18/pynq/imagenet_predict.py
@@ -3,6 +3,7 @@ import nnvm
 import tvm
 from nnvm.compiler import graph_attr
 import vta
+import vta.testing
 import os
 import numpy as np
 from PIL import Image
@@ -12,7 +13,8 @@ import logging
 import wget
 from tvm.contrib import graph_runtime, rpc, util
 
-factor = 16
+bfactor = 1
+cfactor = 16
 host = "pynq"
 port = 9091
 verbose = False
@@ -38,6 +40,10 @@ if verbose:
 target = tvm.target.create("llvm -device=vta")
 target_host = "llvm -mtriple=armv7-none-linux-gnueabihf -mcpu=cortex-a9 -mattr=+neon"
 
+if vta.get_env().TARGET == "sim":
+    target_host = "llvm"
+
+
 synset = eval(open(os.path.join(CATEG_FILE)).read())
 image = Image.open(os.path.join(TEST_FILE)).resize((224, 224))
 
@@ -105,7 +111,7 @@ sym = vta.graph.remove_stochastic(sym)
 sym = vta.graph.clean_cast(sym)
 sym = vta.graph.clean_conv_fuse(sym)
 if target.device_name == "vta":
-    sym = vta.graph.pack(sym, shape_dict, factor)
+    sym = vta.graph.pack(sym, shape_dict, bfactor, cfactor)
 
 graph_attr.set_shape_inputs(sym, shape_dict)
 sym = sym.apply("InferShape")
@@ -127,7 +133,13 @@ with nnvm.compiler.build_config(opt_level=3):
 assert tvm.module.enabled("rpc")
 temp = util.tempdir()
 lib.save(temp.relpath("graphlib.o"))
-remote = rpc.connect(host, port)
+
+if vta.get_env().TARGET == "sim":
+    remote = rpc.LocalSession()
+    print("local session")
+else:
+    remote = rpc.connect(host, port)
+
 remote.upload(temp.relpath("graphlib.o"))
 lib = remote.load_module("graphlib.o")
 ctx = remote.ext_dev(0) if target.device_name == "vta" else remote.cpu(0)
@@ -154,16 +166,17 @@ def run_e2e(graph):
     print("t-cost=%g" % tcost.mean)
 
 
-def run_layer(old_graph):
+def run_layer(old_graph, layer_begin, layer_end):
     """Run a certain layer."""
-    for layer_id in range(1, 2):
+    for layer_id in range(layer_begin, layer_end):
+        print("run resnet[%d]..."% (layer_id))
         graph = mark_nop(old_graph, layer_id)
         m = graph_runtime.create(graph, lib, ctx)
         # set inputs
         m.set_input('data', tvm.nd.array(x.astype("float32")))
         m.set_input(**params)
         # execute
-        timer = m.module.time_evaluator("run", ctx, number=10)
+        timer = m.module.time_evaluator("run", ctx, number=1)
         tcost = timer()
         print("resnet[%d]: %g\n"% (layer_id, tcost.mean))
 
diff --git a/vta/python/vta/graph.py b/vta/python/vta/graph.py
index b8237980d..41a38c2b6 100644
--- a/vta/python/vta/graph.py
+++ b/vta/python/vta/graph.py
@@ -10,51 +10,58 @@ import nnvm
 from nnvm.compiler import graph_attr, graph_util
 
 
-def _pack_channel(data, dshape, factor):
+def _pack_batch_channel(data, dshape, bfactor, cfactor):
     """Pack the data channel dimension.
     """
-    assert dshape[1] % factor == 0
+    assert dshape[0] % bfactor == 0
+    assert dshape[1] % cfactor == 0
     data = nnvm.sym.reshape(data,
-                            shape=(dshape[0], dshape[1] // factor,
-                                   factor, dshape[2], dshape[3]))
+                            shape=(dshape[0] // bfactor, bfactor,
+                                   dshape[1] // cfactor, cfactor,
+                                   dshape[2], dshape[3]))
     data = nnvm.sym.transpose(
-        data, axes=(0, 1, 3, 4, 2))
+        data, axes=(0, 2, 4, 5, 1, 3))
     return data
 
 
-def _unpack_channel(data, old_shape):
+def _unpack_batch_channel(data, old_shape):
     """Unpack the data channel dimension.
     """
-    data = nnvm.sym.transpose(data, axes=(0, 1, 4, 2, 3))
+    data = nnvm.sym.transpose(data, axes=(0, 4, 1, 5, 2, 3))
     data = nnvm.sym.reshape(data, shape=old_shape)
     return data
 
 
-def _pack_weight(data, dshape, factor):
+def _pack_weight(data, dshape, cfactor):
     """Pack the weight into packed format.
     """
     assert len(dshape) == 4
-    assert dshape[0] % factor == 0
-    assert dshape[1] % factor == 0
+    assert dshape[0] % cfactor == 0
+    assert dshape[1] % cfactor == 0
     data = nnvm.sym.reshape(data,
-                            shape=(dshape[0] // factor, factor,
-                                   dshape[1] // factor, factor,
+                            shape=(dshape[0] // cfactor, cfactor,
+                                   dshape[1] // cfactor, cfactor,
                                    dshape[2], dshape[3]))
     data = nnvm.sym.transpose(
         data, axes=(0, 2, 4, 5, 1, 3))
     return data
 
 
-def _pack_bias(data, dshape, factor):
+def _pack_bias(data, dshape, bfactor, cfactor):
     """Pack the bias parameter.
     """
     assert len(dshape) == 3
-    assert dshape[0] % factor == 0
+    assert dshape[0] % cfactor == 0
     data = nnvm.sym.reshape(data,
-                            shape=(dshape[0] // factor,
-                                   factor, dshape[1], dshape[2]))
+                            shape=(dshape[0] // cfactor,
+                                   cfactor, dshape[1],
+                                   dshape[2], 1))
     data = nnvm.sym.transpose(
-        data, axes=(0, 2, 3, 1))
+        data, axes=(0, 2, 3, 4, 1))
+    # broadcast batch dimension to bfactor
+    data = nnvm.sym.broadcast_to(
+        data,
+        shape=(dshape[0] // cfactor, dshape[1], dshape[2], bfactor, cfactor))
     return data
 
 
@@ -245,8 +252,8 @@ def clean_cast(graph):
     return ret
 
 
-def pack(graph, shape_dict, factor, start_name=None):
-    """Pack the graph into channel packed format.
+def pack(graph, shape_dict, bfactor, cfactor, start_name=None):
+    """Pack the graph into batch&channel packed format.
 
     Parameters
     ----------
@@ -256,8 +263,11 @@ def pack(graph, shape_dict, factor, start_name=None):
     shape_dict : dict of str to shapex
        The input shape.
 
-    factor : int
-       The packing factor
+    bfactor : int
+       The packing factor in batch
+
+    cfactor : int
+       The packing factor in channel
 
     start_name: str, optional
        Start name start packing from certain known node.
@@ -290,42 +300,44 @@ def pack(graph, shape_dict, factor, start_name=None):
             new_node = nnvm.symbol.Variable(node_name)
             if start_name and node_name == start_name:
                 start_pack = True
-                new_node = _pack_channel(new_node, oshape, factor)
+                new_node = _pack_batch_channel(new_node, oshape, bfactor, cfactor)
         elif op_name == "max_pool2d":
             assert not start_pack
             start_pack = True
             new_node = get_clone(children, op_name, node_name, attrs)
-            new_node = _pack_channel(new_node, oshape, factor)
+            new_node = _pack_batch_channel(new_node, oshape, bfactor, cfactor)
         elif op_name == "global_avg_pool2d":
             if start_pack:
                 start_pack = False
-                children[0] = _unpack_channel(children[0], ishape[0])
+                children[0] = _unpack_batch_channel(children[0], ishape[0])
                 new_node = getattr(nnvm.symbol, op_name)(
                     *children, name=node_name, **attrs)
             else:
                 new_node = get_clone(children, op_name, node_name, attrs)
         elif op_name == "quantized_conv2d":
             if start_pack:
-                attrs["pack_channel"] = str(factor)
+                attrs["pack_batch"] = str(bfactor)
+                attrs["pack_channel"] = str(cfactor)
                 data, weight = children
-                weight = _pack_weight(weight, ishape[1], factor)
+                weight = _pack_weight(weight, ishape[1], cfactor)
                 new_node = nnvm.sym.quantized_conv2d(
                     data, weight, name=node_name, **attrs)
             elif counter == 1:
-                attrs["pack_channel"] = str(factor)
+                attrs["pack_batch"] = str(bfactor)
+                attrs["pack_channel"] = str(cfactor)
                 data, weight = children
-                data = _pack_channel(data, ishape[0], factor)
-                weight = _pack_weight(weight, ishape[1], factor)
+                data = _pack_batch_channel(data, ishape[0], bfactor, cfactor)
+                weight = _pack_weight(weight, ishape[1], cfactor)
                 new_node = nnvm.sym.quantized_conv2d(
                     data, weight, name=node_name, **attrs)
-                new_node = _unpack_channel(new_node, oshape)
+                new_node = _unpack_batch_channel(new_node, oshape)
                 counter = counter + 1
             else:
                 new_node = get_clone(children, op_name, node_name, attrs)
         elif op_name.startswith("broadcast"):
             if start_pack:
                 assert len(ishape[1]) == 3
-                children[1] = _pack_bias(children[1], ishape[1], factor)
+                children[1] = _pack_bias(children[1], ishape[1], bfactor, cfactor)
                 new_node = getattr(nnvm.symbol, op_name)(
                     *children, name=node_name, **attrs)
             else:
@@ -341,7 +353,7 @@ def pack(graph, shape_dict, factor, start_name=None):
     ret = node_map[graph.index.output_entries[0][0]]
     if start_pack:
         oshape = shape[graph.index.output_entries[0][0]]
-        ret = _unpack_channel(ret, oshape)
+        ret = _unpack_batch_channel(ret, oshape)
     graph = nnvm.graph.create(ret)
     graph = graph_attr.set_shape_inputs(graph, shape_dict)
     graph = graph.apply("InferShape")
diff --git a/vta/src/runtime.cc b/vta/src/runtime.cc
index c0de87fa3..9e84acfc0 100644
--- a/vta/src/runtime.cc
+++ b/vta/src/runtime.cc
@@ -367,9 +367,10 @@ class UopQueue : public BaseQueue {
     }
     assert(num_op <= kMaxNumUop);
     uint32_t uop_begin = 0;
-    if (sram_end_ + num_op > kMaxElems) {
+    if (sram_end_ + num_op > kMaxNumUop) {
       // Need to evict
       cache_ptr_ = 0;
+      sram_begin_ = 0;
       sram_end_ = num_op;
     } else {
       uop_begin = sram_end_;
@@ -388,6 +389,7 @@ class UopQueue : public BaseQueue {
     dram_end_ += num_op;
     kernel->sram_begin_ = uop_begin;
     kernel->sram_end_ = sram_end_;
+    CHECK(kernel->cached());
     assert(uop_begin != sram_end_);
     cache_.insert(cache_.begin() + cache_ptr_, kernel);
     cache_.erase(cache_.begin() + evict_begin, cache_.begin() + cache_ptr_);
diff --git a/vta/src/sim/sim_driver.cc b/vta/src/sim/sim_driver.cc
index 57bc21c98..9a953e7ae 100644
--- a/vta/src/sim/sim_driver.cc
+++ b/vta/src/sim/sim_driver.cc
@@ -162,6 +162,7 @@ class DRAM {
    */
   void Free(void* data) {
     std::lock_guard<std::mutex> lock(mutex_);
+    if (pmap_.size() == 0) return;
     auto it = pmap_.find(data);
     CHECK(it != pmap_.end());
     Page* p = it->second.get();
-- 
GitLab