From 9c44e4b43dca7f02959ab839a892ec97e2dd40d7 Mon Sep 17 00:00:00 2001
From: Tianqi Chen <tqchen@users.noreply.github.com>
Date: Thu, 12 Apr 2018 22:12:43 -0700
Subject: [PATCH] [DRIVER] Add simulator, unify testcase to unittest (#25)

---
 vta/Makefile                               |  20 +-
 vta/include/vta/driver.h                   |   2 +-
 vta/include/vta/hw_spec.h                  |   4 +-
 vta/include/vta/runtime.h                  |   2 +-
 vta/make/config.mk                         |   3 +-
 vta/python/vta/__init__.py                 |   1 +
 vta/python/vta/environment.py              |  22 +-
 vta/python/vta/ir_pass.py                  |  16 +-
 vta/python/vta/rpc_client.py               |   3 +-
 vta/python/vta/testing/__init__.py         |   3 +
 vta/python/vta/testing/simulator.py        |  51 ++
 vta/python/vta/testing/util.py             |  30 ++
 vta/src/data_buffer.h                      |   2 +-
 vta/src/pynq/pynq_driver.cc                |   4 +-
 vta/src/runtime.cc                         |  25 +-
 vta/src/sim/sim_driver.cc                  | 581 +++++++++++++++++++++
 vta/src/tvm/vta_device_api.cc              |   3 -
 vta/tests/python/pynq/test_vta_insn.py     | 504 ------------------
 vta/tests/python/unittest/test_vta_insn.py | 482 +++++++++++++++++
 19 files changed, 1217 insertions(+), 541 deletions(-)
 create mode 100644 vta/python/vta/testing/__init__.py
 create mode 100644 vta/python/vta/testing/simulator.py
 create mode 100644 vta/python/vta/testing/util.py
 create mode 100644 vta/src/sim/sim_driver.cc
 delete mode 100644 vta/tests/python/pynq/test_vta_insn.py
 create mode 100644 vta/tests/python/unittest/test_vta_insn.py

diff --git a/vta/Makefile b/vta/Makefile
index 069f6e01c..6bfa82dc2 100644
--- a/vta/Makefile
+++ b/vta/Makefile
@@ -40,6 +40,19 @@ ifneq ($(ADD_LDFLAGS), NONE)
 	LDFLAGS += $(ADD_LDFLAGS)
 endif
 
+UNAME_S := $(shell uname -s)
+
+ifeq ($(UNAME_S), Darwin)
+	SHARED_LIBRARY_SUFFIX := dylib
+	WHOLE_ARCH= -all_load
+	NO_WHOLE_ARCH= -noall_load
+	LDFLAGS += -undefined dynamic_lookup
+else
+	SHARED_LIBRARY_SUFFIX := so
+	WHOLE_ARCH= --whole-archive
+	NO_WHOLE_ARCH= --no-whole-archive
+endif
+
 
 all: lib/libvta.so lib/libvta_runtime.so
 
@@ -53,6 +66,10 @@ ifeq ($(TARGET), VTA_PYNQ_TARGET)
 	LDFLAGS += -l:libdma.so
 endif
 
+ifeq ($(TARGET), sim)
+	VTA_LIB_SRC += $(wildcard src/sim/*.cc)
+endif
+
 VTA_LIB_OBJ = $(patsubst src/%.cc, build/%.o, $(VTA_LIB_SRC))
 
 build/%.o: src/%.cc
@@ -71,7 +88,7 @@ lib/libvta_runtime.so: build/runtime.o
 lint: pylint cpplint
 
 cpplint:
-	python nnvm/dmlc-core/scripts/lint.py vta cpp include src hardware tests
+	python nnvm/dmlc-core/scripts/lint.py vta cpp include src
 
 pylint:
 	pylint python/vta --rcfile=$(ROOTDIR)/tests/lint/pylintrc
@@ -86,3 +103,4 @@ clean:
 -include build/*.d
 -include build/*/*.d
 -include build/*/*/*.d
+-include build/*/*/*/*.d
diff --git a/vta/include/vta/driver.h b/vta/include/vta/driver.h
index 8a29fc47a..58f778806 100644
--- a/vta/include/vta/driver.h
+++ b/vta/include/vta/driver.h
@@ -77,7 +77,7 @@ void VTAMemFree(void* buf);
  * \param buf Pointer to memory region allocated with VTAMemAlloc.
  * \return The physical address of the memory region.
  */
-vta_phy_addr_t VTAGetMemPhysAddr(void* buf);
+vta_phy_addr_t VTAMemGetPhyAddr(void* buf);
 
 /*!
  * \brief Flushes the region of memory out of the CPU cache to DRAM.
diff --git a/vta/include/vta/hw_spec.h b/vta/include/vta/hw_spec.h
index 7eae322a0..9d62d4e7d 100644
--- a/vta/include/vta/hw_spec.h
+++ b/vta/include/vta/hw_spec.h
@@ -519,8 +519,8 @@ typedef struct {
   uint64_t alu_opcode     : VTA_ALU_OPCODE_BIT_WIDTH;
   /*! \brief Use immediate is true */
   uint64_t use_imm        : 1;
-  /*! \brief Immediate value */
-  uint64_t imm            : VTA_ALUOP_IMM_BIT_WIDTH;
+  /*! \brief Immediate value: allow negative value */
+  int64_t imm            : VTA_ALUOP_IMM_BIT_WIDTH;
 } VTAAluInsn;
 
 /*! \brief VTA ALU instruction converter */
diff --git a/vta/include/vta/runtime.h b/vta/include/vta/runtime.h
index c9373846d..479540129 100644
--- a/vta/include/vta/runtime.h
+++ b/vta/include/vta/runtime.h
@@ -196,7 +196,7 @@ void VTAUopPush(uint32_t mode,
                 uint32_t wgt_index,
                 uint32_t opcode,
                 uint32_t use_imm,
-                uint32_t imm_val);
+                int32_t imm_val);
 
 /*!
  * \brief Mark start of a micro op loop.
diff --git a/vta/make/config.mk b/vta/make/config.mk
index 9f611896a..e329dcf98 100644
--- a/vta/make/config.mk
+++ b/vta/make/config.mk
@@ -27,7 +27,7 @@ ADD_LDFLAGS=
 ADD_CFLAGS=
 
 # the hardware target
-TARGET = VTA_PYNQ_TARGET
+TARGET = pynq
 
 #---------------------
 # VTA hardware parameters
@@ -89,7 +89,6 @@ VTA_OUT_BUFF_SIZE = $(shell echo "$$(( 1 << $(VTA_LOG_OUT_BUFF_SIZE) ))" )
 
 # Update ADD_CFLAGS
 ADD_CFLAGS += \
-	-D$(TARGET) \
 	-DVTA_LOG_WGT_WIDTH=$(VTA_LOG_WGT_WIDTH) -DVTA_LOG_INP_WIDTH=$(VTA_LOG_INP_WIDTH) \
 	-DVTA_LOG_ACC_WIDTH=$(VTA_LOG_ACC_WIDTH) -DVTA_LOG_OUT_WIDTH=$(VTA_LOG_OUT_WIDTH) \
 	-DVTA_LOG_BATCH=$(VTA_LOG_BATCH) \
diff --git a/vta/python/vta/__init__.py b/vta/python/vta/__init__.py
index 275c15b22..693a4124f 100644
--- a/vta/python/vta/__init__.py
+++ b/vta/python/vta/__init__.py
@@ -8,6 +8,7 @@ try:
     from . import arm_conv2d, vta_conv2d
     from .build_module import build_config, lower, build
     from .rpc_client import reconfig_runtime, program_fpga
+
     from . import graph
 except ImportError:
     pass
diff --git a/vta/python/vta/environment.py b/vta/python/vta/environment.py
index b0d7d170f..8ff2bbce2 100644
--- a/vta/python/vta/environment.py
+++ b/vta/python/vta/environment.py
@@ -89,7 +89,7 @@ class Environment(object):
     """
     current = None
     cfg_keys = [
-        "target",
+        "TARGET",
         "LOG_INP_WIDTH",
         "LOG_WGT_WIDTH",
         "LOG_ACC_WIDTH",
@@ -204,9 +204,19 @@ class Environment(object):
 
     @property
     def gevm(self):
-        """GEMM intrinsic"""
+        """GEVM intrinsic"""
         return self.dev.gevm
 
+    @property
+    def target_host(self):
+        """The target host"""
+        if self.TARGET == "pynq":
+            return "llvm -target=armv7-none-linux-gnueabihf"
+        elif self.TARGET == "sim":
+            return "llvm"
+        else:
+            raise ValueError("Unknown target %s" % self.TARGET)
+
 
 def get_env():
     """Get the current VTA Environment.
@@ -278,6 +288,7 @@ def _init_env():
 
     for k in Environment.cfg_keys:
         keys.add("VTA_" + k)
+    keys.add("TARGET")
 
     if not os.path.isfile(filename):
         raise RuntimeError(
@@ -290,8 +301,11 @@ def _init_env():
             for k in keys:
                 if k  +" =" in line:
                     val = line.split("=")[1].strip()
-                    cfg[k[4:]] = int(val)
-    cfg["target"] = "pynq"
+                    if k.startswith("VTA_"):
+                        k = k[4:]
+                        cfg[k] = int(val)
+                    else:
+                        cfg[k] = val
     return Environment(cfg)
 
 Environment.current = _init_env()
diff --git a/vta/python/vta/ir_pass.py b/vta/python/vta/ir_pass.py
index f85f1760e..9310d46dc 100644
--- a/vta/python/vta/ir_pass.py
+++ b/vta/python/vta/ir_pass.py
@@ -78,8 +78,7 @@ def fold_uop_loop(stmt_in):
             if not fail[0]:
                 begin = tvm.call_extern(
                     "int32", "VTAUopLoopBegin", stmt.extent, *gemm_offsets)
-                end = tvm.call_extern(
-                    "int32", "VTAUopLoopEnd", stmt.extent, *gemm_offsets)
+                end = tvm.call_extern("int32", "VTAUopLoopEnd")
                 return [begin, ret, end]
         raise ValueError("Failed to fold the GEMM instructions..")
 
@@ -683,8 +682,14 @@ def inject_alu_intrin(stmt_in):
                 else:
                     raise RuntimeError(
                         "Function call not recognized %s" % (loop_body.value.name))
+            elif isinstance(loop_body.value, tvm.expr.Load):
+                alu_opcode = env.dev.ALU_OPCODE_SHR
+                lhs = loop_body.value
+                rhs = tvm.const(0)
             else:
-                raise RuntimeError("Expression not recognized %s" % (type(loop_body.value)))
+                raise RuntimeError(
+                    "Expression not recognized %s, %s, %s" % (
+                        type(loop_body.value), str(loop_body.value), str(stmt)))
 
             # Derive array index coefficients
             dst_coeff = tvm.arith.DetectLinearEquation(dst_idx, indices)
@@ -772,7 +777,9 @@ def inject_alu_intrin(stmt_in):
             irb = tvm.ir_builder.create()
             for idx, extent in enumerate(extents):
                 irb.emit(tvm.call_extern(
-                    "int32", "VTAUopLoopBegin", extent, dst_coeff[idx], src_coeff[idx]))
+                    "int32", "VTAUopLoopBegin",
+                    extent, dst_coeff[idx], src_coeff[idx], 0))
+            use_imm = int(use_imm)
             irb.emit(tvm.call_extern(
                 "int32", "VTAUopPush",
                 1, 0,
@@ -804,5 +811,6 @@ def debug_print(stmt):
     stmt : Stmt
         The
     """
+    # pylint: disable=superfluous-parens
     print(stmt)
     return stmt
diff --git a/vta/python/vta/rpc_client.py b/vta/python/vta/rpc_client.py
index a6355a592..f0a02d7cc 100644
--- a/vta/python/vta/rpc_client.py
+++ b/vta/python/vta/rpc_client.py
@@ -24,8 +24,7 @@ def reconfig_runtime(remote):
             "VTA_LOG_WGT_BUFF_SIZE",
             "VTA_LOG_ACC_BUFF_SIZE",
             "VTA_LOG_OUT_BUFF_SIZE"]
-
-    cflags = ["-DVTA_%s_TARGET" % env.target.upper()]
+    cflags = []
     for k in keys:
         cflags += ["-D%s=%s" % (k, str(getattr(env, k[4:])))]
     freconfig = remote.get_function("tvm.contrib.vta.reconfig_runtime")
diff --git a/vta/python/vta/testing/__init__.py b/vta/python/vta/testing/__init__.py
new file mode 100644
index 000000000..513fa1e99
--- /dev/null
+++ b/vta/python/vta/testing/__init__.py
@@ -0,0 +1,3 @@
+"""Testing utilities, this namespace is not imported by default."""
+
+from . util import run
diff --git a/vta/python/vta/testing/simulator.py b/vta/python/vta/testing/simulator.py
new file mode 100644
index 000000000..bb436a185
--- /dev/null
+++ b/vta/python/vta/testing/simulator.py
@@ -0,0 +1,51 @@
+"""Utilities to start simulator."""
+import os
+import ctypes
+import json
+import tvm
+
+def _load_lib():
+    """Load local library, assuming they are simulator."""
+    # pylint: disable=unused-variable
+    curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
+    dll_path = [
+        os.path.abspath(os.path.join(curr_path, "../../../lib/libvta.so")),
+        os.path.abspath(os.path.join(curr_path, "../../../lib/libvta_runtime.so"))
+    ]
+    runtime_dll = []
+    if not all(os.path.exists(f) for f in dll_path):
+        return []
+    try:
+        for fname in dll_path:
+            runtime_dll.append(ctypes.CDLL(fname, ctypes.RTLD_GLOBAL))
+        return runtime_dll
+    except OSError:
+        return []
+
+
+def enabled():
+    """Check if simulator is enabled."""
+    f = tvm.get_global_func("vta.simulator.profiler_clear", True)
+    return f is not None
+
+
+def clear_stats():
+    """Clear profiler statistics"""
+    f = tvm.get_global_func("vta.simulator.profiler_clear", True)
+    if f:
+        f()
+
+
+def stats():
+    """Clear profiler statistics
+
+    Returns
+    -------
+    stats : dict
+        Current profiler statistics
+    """
+    x = tvm.get_global_func("vta.simulator.profiler_status")()
+    return json.loads(x)
+
+
+LIBS = _load_lib()
diff --git a/vta/python/vta/testing/util.py b/vta/python/vta/testing/util.py
new file mode 100644
index 000000000..bbf6417a1
--- /dev/null
+++ b/vta/python/vta/testing/util.py
@@ -0,0 +1,30 @@
+"""Test Utilities"""
+from __future__ import absolute_import as _abs
+
+import os
+from tvm.contrib import rpc
+from ..environment import get_env
+from . import simulator
+
+
+def run(run_func):
+    """Run test function on all available env.
+
+    Parameters
+    ----------
+    run_func : function(env, remote)
+    """
+    env = get_env()
+    # run on simulator
+    if simulator.enabled():
+        env.TARGET = "sim"
+        run_func(env, rpc.LocalSession())
+
+    # Run on PYNQ if env variable exists
+    pynq_host = os.environ.get("VTA_PYNQ_RPC_HOST", None)
+    if pynq_host:
+        env.TARGET = "pynq"
+        port = os.environ.get("VTA_PYNQ_RPC_PORT", "9091")
+        port = int(port)
+        remote = rpc.connect(pynq_host, port)
+        run_func(env, remote)
diff --git a/vta/src/data_buffer.h b/vta/src/data_buffer.h
index aed92c49e..fba46dc07 100644
--- a/vta/src/data_buffer.h
+++ b/vta/src/data_buffer.h
@@ -57,7 +57,7 @@ struct DataBuffer {
     assert(data != nullptr);
     DataBuffer* buffer = new DataBuffer();
     buffer->data_ = data;
-    buffer->phy_addr_ = VTAGetMemPhysAddr(data);
+    buffer->phy_addr_ = VTAMemGetPhyAddr(data);
     return buffer;
   }
   /*!
diff --git a/vta/src/pynq/pynq_driver.cc b/vta/src/pynq/pynq_driver.cc
index 0330450db..e2630b14a 100644
--- a/vta/src/pynq/pynq_driver.cc
+++ b/vta/src/pynq/pynq_driver.cc
@@ -1,6 +1,6 @@
 /*!
  *  Copyright (c) 2018 by Contributors
- * \file vta_pynq_driver.c
+ * \file pynq_driver.c
  * \brief VTA driver for Pynq board.
  */
 
@@ -17,7 +17,7 @@ void VTAMemFree(void* buf) {
   cma_free(buf);
 }
 
-vta_phy_addr_t VTAGetMemPhysAddr(void* buf) {
+vta_phy_addr_t VTAMemGetPhyAddr(void* buf) {
   return cma_get_phy_addr(buf);
 }
 
diff --git a/vta/src/runtime.cc b/vta/src/runtime.cc
index a8819323f..da5109c14 100644
--- a/vta/src/runtime.cc
+++ b/vta/src/runtime.cc
@@ -11,6 +11,7 @@
 #include <vta/driver.h>
 #include <vta/hw_spec.h>
 #include <vta/runtime.h>
+#include <dmlc/logging.h>
 
 #include <cassert>
 #include <vector>
@@ -113,7 +114,7 @@ class UopKernel {
             uint32_t wgt_index,
             uint32_t opcode,
             uint32_t use_imm,
-            uint32_t imm_val) {
+            int32_t imm_val) {
     // The loop nest structure
     VerifyDep(dst_index);
     VTAUop op;
@@ -166,7 +167,7 @@ class UopKernel {
   uint32_t opcode_{0xFFFFFFFF};
   uint32_t reset_out_{0xFFFFFFFF};
   bool use_imm_{false};
-  uint16_t imm_val_{0};
+  int16_t imm_val_{0};
 
  private:
   // Verify that we don't write to the same acc_mem index two cycles in a row
@@ -195,10 +196,6 @@ class UopKernel {
 
 /*!
  * \brief Base class of all queues to send and recv serial data.
- * \param kElemBytes Element unit bytes.
- * \param kMaxBytes Maximum number of bytes.
- * \param kCoherent Whether we have coherent access to the buffer.
- * \param kAlwaysCache Wether we should use cached memory.
  */
 class BaseQueue {
  public:
@@ -227,7 +224,7 @@ class BaseQueue {
     dram_buffer_ = static_cast<char*>(VTAMemAlloc(
         max_bytes, coherent || always_cache_));
     assert(dram_buffer_ != nullptr);
-    dram_phy_addr_ = VTAGetMemPhysAddr(dram_buffer_);
+    dram_phy_addr_ = VTAMemGetPhyAddr(dram_buffer_);
   }
   /*!
    * \brief Reset the pointer of the buffer.
@@ -597,14 +594,14 @@ class InsnQueue : public BaseQueue {
         }
         // Print instruction field information
         if (c.mem.opcode == VTA_OPCODE_LOAD) {
-            printf("LOAD ");
-            if (c.mem.memory_type == VTA_MEM_ID_UOP) printf("UOP\n");
-            if (c.mem.memory_type == VTA_MEM_ID_WGT) printf("WGT\n");
-            if (c.mem.memory_type == VTA_MEM_ID_INP) printf("INP\n");
-            if (c.mem.memory_type == VTA_MEM_ID_ACC) printf("ACC\n");
+          printf("LOAD ");
+          if (c.mem.memory_type == VTA_MEM_ID_UOP) printf("UOP\n");
+          if (c.mem.memory_type == VTA_MEM_ID_WGT) printf("WGT\n");
+          if (c.mem.memory_type == VTA_MEM_ID_INP) printf("INP\n");
+          if (c.mem.memory_type == VTA_MEM_ID_ACC) printf("ACC\n");
         }
         if (c.mem.opcode == VTA_OPCODE_STORE) {
-            printf("STORE\n");
+          printf("STORE:\n");
         }
         printf("\tdep - pop prev: %d, pop next: %d, push prev: %d, push next: %d\n",
                static_cast<int>(c.mem.pop_prev_dep),
@@ -1210,7 +1207,7 @@ void VTAUopPush(uint32_t mode,
                 uint32_t wgt_index,
                 uint32_t opcode,
                 uint32_t use_imm,
-                uint32_t imm_val) {
+                int32_t imm_val) {
   vta::CommandQueue::ThreadLocal()->record_kernel()
       ->Push(mode, reset_out, dst_index, src_index,
              wgt_index, opcode, use_imm, imm_val);
diff --git a/vta/src/sim/sim_driver.cc b/vta/src/sim/sim_driver.cc
new file mode 100644
index 000000000..a88ab2466
--- /dev/null
+++ b/vta/src/sim/sim_driver.cc
@@ -0,0 +1,581 @@
+/*!
+ *  Copyright (c) 2018 by Contributors
+ * \file sim_driver.cc
+ * \brief VTA driver for simulated backend.
+ */
+#include <vta/driver.h>
+#include <vta/hw_spec.h>
+#include <tvm/runtime/registry.h>
+#include <type_traits>
+#include <mutex>
+#include <map>
+#include <unordered_map>
+#include <cstring>
+#include <sstream>
+
+namespace vta {
+namespace sim {
+
+/*!
+ * \brief Helper class to pack and unpack bits
+ *  Applies truncation when pack to low level bits.
+ *
+ * \tparam bits The number of bits in integer.
+ * \note This implementation relies on little endian.
+ */
+template<uint32_t bits>
+class BitPacker {
+ public:
+  explicit BitPacker(void* data) {
+    data_ = static_cast<uint32_t*>(data);
+  }
+
+  uint32_t GetUnsigned(uint32_t index) const {
+    if (bits == 32) {
+      return data_[index];
+    } else if (bits == 16) {
+      return reinterpret_cast<uint16_t*>(data_)[index];
+    } else if (bits == 8) {
+      return reinterpret_cast<uint8_t*>(data_)[index];
+    } else {
+      uint32_t offset = index / kNumPackElem;
+      uint32_t shift = index % kNumPackElem;
+      return (data_[offset] >> shift) & kMask;
+    }
+  }
+
+  int32_t GetSigned(uint32_t index) const {
+    if (bits == 32) {
+      return reinterpret_cast<int32_t*>(data_)[index];
+    } else if (bits == 16) {
+      return reinterpret_cast<int16_t*>(data_)[index];
+    } else if (bits == 8) {
+      return reinterpret_cast<int8_t*>(data_)[index];
+    } else {
+      uint32_t offset = index / kNumPackElem;
+      uint32_t shift = (index % kNumPackElem) * bits;
+      int32_t uvalue = static_cast<int32_t>(
+          (data_[offset] >> shift) & kMask);
+      int kleft = 32 - bits;
+      return (uvalue << kleft) >> kleft;
+    }
+  }
+
+  void SetUnsigned(uint32_t index, uint32_t value) {
+    if (bits == 32) {
+      data_[index] = value;
+    } else if (bits == 16) {
+      reinterpret_cast<uint16_t*>(data_)[index] = value;
+    } else if (bits == 8) {
+      reinterpret_cast<uint8_t*>(data_)[index] = value;
+    } else {
+      uint32_t offset = index / kNumPackElem;
+      uint32_t shift = (index % kNumPackElem) * bits;
+      data_[offset] &= (~(kMask << shift));
+      data_[offset] |= (value & kMask) << shift;
+    }
+  }
+
+  void SetSigned(uint32_t index, int32_t value) {
+    if (bits == 32) {
+      reinterpret_cast<int32_t*>(data_)[index] = value;
+    } else if (bits == 16) {
+      reinterpret_cast<int16_t*>(data_)[index] = value;
+    } else if (bits == 8) {
+      reinterpret_cast<int8_t*>(data_)[index] = value;
+    } else {
+      uint32_t offset = index / kNumPackElem;
+      uint32_t shift = (index % kNumPackElem) * bits;
+      data_[offset] &= (~(kMask << shift));
+      data_[offset] |= static_cast<uint32_t>(value & kMask) << shift;
+    }
+  }
+
+ private:
+  uint32_t* data_;
+  static constexpr uint32_t kNumPackElem = 32 / bits;
+  static constexpr uint32_t kMask = (1U << (bits >= 32U ? 31U : bits)) - 1U;
+};
+
+/*!
+ * \brief DRAM memory manager
+ *  Implements simple paging to allow physical address translation.
+ */
+class DRAM {
+ public:
+  /*!
+   * \brief Get virtual address given physical address.
+   * \param phy_addr The simulator phyiscal address.
+   * \return The true virtual address;
+   */
+  void* GetAddr(uint64_t phy_addr) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    uint64_t loc = (phy_addr >> kPageBits) - 1;
+    CHECK_LT(loc, ptable_.size());
+    Page* p = ptable_[loc];
+    CHECK(p != nullptr);
+    size_t offset = (loc - p->ptable_begin) << kPageBits;
+    offset += phy_addr & (kPageSize - 1);
+    return reinterpret_cast<char*>(p->data) + offset;
+  }
+  /*!
+   * \brief Get physical address
+   * \param buf The virtual address.
+   * \return The true physical address;
+   */
+  vta_phy_addr_t GetPhyAddr(void* buf) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    auto it = pmap_.find(buf);
+    CHECK(it != pmap_.end());
+    Page* p = it->second.get();
+    return (p->ptable_begin + 1) << kPageBits;
+  }
+  /*!
+   * \brief Allocate memory from manager
+   * \param size The size of memory
+   * \return The virtual address
+   */
+  void* Alloc(size_t size) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    size_t npage = (size + kPageSize - 1) / kPageSize;
+    auto it = free_map_.lower_bound(npage);
+    if (it != free_map_.end()) {
+      Page* p = it->second;
+      free_map_.erase(it);
+      return p->data;
+    }
+    size_t start = ptable_.size();
+    std::unique_ptr<Page> p(new Page(start, npage));
+    // insert page entry
+    ptable_.resize(start + npage, p.get());
+    void* data = p->data;
+    pmap_[data] = std::move(p);
+    return data;
+  }
+  /*!
+   * \brief Free the memory.
+   * \param size The size of memory
+   * \return The virtual address
+   */
+  void Free(void* data) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    auto it = pmap_.find(data);
+    CHECK(it != pmap_.end());
+    Page* p = it->second.get();
+    free_map_.insert(std::make_pair(p->num_pages, p));
+  }
+
+  static DRAM* Global() {
+    static DRAM inst;
+    return &inst;
+  }
+
+
+ private:
+  // The bits in page table
+  static constexpr vta_phy_addr_t kPageBits = 16;
+  // page size, also the maximum allocable size 16 K
+  static constexpr vta_phy_addr_t kPageSize = 1 << kPageBits;
+  /*! \brief A page in the DRAM */
+  struct Page {
+    /*! \brief Data Type */
+    using DType = typename std::aligned_storage<kPageSize, 256>::type;
+    /*! \brief Start location in page table */
+    size_t ptable_begin;
+    /*! \brief The total number of pages */
+    size_t num_pages;
+    /*! \brief Data */
+    DType* data{nullptr};
+    // construct a new page
+    explicit Page(size_t ptable_begin, size_t num_pages)
+        : ptable_begin(ptable_begin), num_pages(num_pages) {
+      data = new DType[num_pages];
+    }
+    ~Page() {
+      delete [] data;
+    }
+  };
+  // Internal lock
+  std::mutex mutex_;
+  // Physical address -> page
+  std::vector<Page*> ptable_;
+  // virtual addres -> page
+  std::unordered_map<void*, std::unique_ptr<Page> > pmap_;
+  // Free map
+  std::multimap<size_t, Page*> free_map_;
+};
+
+/*!
+ * \brief Register file.
+ * \tparam kBits Number of bits of one value.
+ * \tparam kLane Number of lanes in one element.
+ * \tparam kMaxNumElem Maximum number of element.
+ */
+template<int kBits, int kLane, int kMaxNumElem>
+class SRAM {
+ public:
+  /*! \brief Bytes of single vector element */
+  static const int kElemBytes = (kBits * kLane + 7) / 8;
+  /*! \brief content data type */
+  using DType = typename std::aligned_storage<kElemBytes, kElemBytes>::type;
+  SRAM() {
+    data_ = new DType[kMaxNumElem];
+  }
+  ~SRAM() {
+    delete [] data_;
+  }
+  // Get the i-th index
+  void* BeginPtr(uint32_t index) {
+    CHECK_LT(index, kMaxNumElem);
+    return &(data_[index]);
+  }
+  // Execute the load instruction on this SRAM
+  void Load(const VTAMemInsn* op, DRAM* dram, uint64_t* load_counter) {
+    load_counter[0] += (op->x_size * op->y_size) * kElemBytes;
+    DType* sram_ptr = data_ + op->sram_base;
+    uint8_t* dram_ptr = static_cast<uint8_t*>(dram->GetAddr(
+        op->dram_base * kElemBytes));
+    uint64_t xtotal = op->x_size + op->x_pad_0 + op->x_pad_1;
+    uint32_t ytotal = op->y_size + op->y_pad_0 + op->y_pad_1;
+    uint64_t sram_end = op->sram_base + xtotal * ytotal;
+    CHECK_LE(sram_end, kMaxNumElem);
+    memset(sram_ptr, 0, kElemBytes * xtotal * op->y_pad_0);
+    sram_ptr += xtotal * op->y_pad_0;
+    for (uint32_t y = 0; y < op->y_size; ++y) {
+      memset(sram_ptr, 0, kElemBytes * op->x_pad_0);
+      sram_ptr += op->x_pad_0;
+      memcpy(sram_ptr, dram_ptr, kElemBytes * op->x_size);
+      sram_ptr += op->x_size;
+      BitPacker<kBits> src(sram_ptr);
+      memset(sram_ptr, 0, kElemBytes * op->x_pad_1);
+      sram_ptr += op->x_pad_1;
+      dram_ptr += kElemBytes * op->x_stride;
+    }
+    memset(sram_ptr, 0, kElemBytes * xtotal * op->y_pad_1);
+  }
+  // Execute the store instruction on this SRAM apply trucation.
+  // This relies on the elements is 32 bits
+  template<int target_bits>
+  void TruncStore(const VTAMemInsn* op, DRAM* dram) {
+    CHECK_EQ(op->x_pad_0, 0);
+    CHECK_EQ(op->x_pad_1, 0);
+    CHECK_EQ(op->y_pad_0, 0);
+    CHECK_EQ(op->y_pad_1, 0);
+    int target_width = (target_bits * kLane + 7) / 8;
+    BitPacker<kBits> src(data_ + op->sram_base);
+    BitPacker<target_bits> dst(dram->GetAddr(op->dram_base * target_width));
+    for (uint32_t y = 0; y < op->y_size; ++y) {
+      for (uint32_t x = 0; x < op->x_size; ++x) {
+        uint32_t sram_base = y * op->x_size + x;
+        uint32_t dram_base = y * op->x_stride + x;
+        for (int i = 0; i < kLane; ++i) {
+          dst.SetSigned(dram_base * kLane + i,
+                        src.GetSigned(sram_base * kLane +i));
+        }
+      }
+    }
+  }
+
+ private:
+  /*! \brief internal data content */
+  DType* data_;
+};
+
+
+/*!
+ * \brief Memory information of special memory region.
+ *  Use MemoryInfo as its container type
+ */
+class Profiler {
+ public:
+  /*! \brief The memory load statistics */
+  uint64_t inp_load_nbytes{0};
+  /*! \brief The memory load statistics */
+  uint64_t wgt_load_nbytes{0};
+  /*! \brief The ACC memory load statistics */
+  uint64_t acc_load_nbytes{0};
+  /*! \brief The ACC memory load statistics */
+  uint64_t uop_load_nbytes{0};
+  /*! \brief The ACC memory load statistics */
+  uint64_t out_store_nbytes{0};
+  /*! \brief instr counter for gemm */
+  uint64_t gemm_counter{0};
+  /*! \brief instr counter for ALU ops */
+  uint64_t alu_counter{0};
+  /*! \brief clear the profiler */
+  void Clear() {
+    inp_load_nbytes = 0;
+    wgt_load_nbytes = 0;
+    acc_load_nbytes = 0;
+    uop_load_nbytes = 0;
+    out_store_nbytes = 0;
+    gemm_counter = 0;
+    alu_counter = 0;
+  }
+
+  std::string AsJSON() {
+    std::ostringstream os;
+    os << "{\n"
+       << " \"inp_load_nbytes\":" << inp_load_nbytes << ",\n"
+       << " \"wgt_load_nbytes\":" << wgt_load_nbytes << ",\n"
+       << " \"acc_load_nbytes\":" << acc_load_nbytes << ",\n"
+       << " \"uop_load_nbytes\":" << uop_load_nbytes << ",\n"
+       << " \"out_store_nbytes\":" << out_store_nbytes << ",\n"
+       << " \"gemm_counter\":" << gemm_counter << ",\n"
+       << " \"alu_counter\":" << alu_counter << "\n"
+       <<"}\n";
+    return os.str();
+  }
+
+  static Profiler* ThreadLocal() {
+    static thread_local Profiler inst;
+    return &inst;
+  }
+};
+
+
+// Simulate device
+// TODO(tqchen,thierry): queue based event driven simulation.
+class Device {
+ public:
+  Device() {
+    prof_ = Profiler::ThreadLocal();
+    dram_ = DRAM::Global();
+  }
+
+  int Run(vta_phy_addr_t insn_phy_addr,
+          uint32_t insn_count,
+          uint32_t wait_cycles) {
+    VTAGenericInsn* insn = static_cast<VTAGenericInsn*>(
+        dram_->GetAddr(insn_phy_addr));
+    finish_counter_ = 0;
+    for (uint32_t i = 0; i < insn_count; ++i) {
+      this->Run(insn + i);
+    }
+    return 0;
+  }
+
+ private:
+  void Run(const VTAGenericInsn* insn) {
+    const VTAMemInsn* mem = reinterpret_cast<const VTAMemInsn*>(insn);
+    const VTAGemInsn* gem = reinterpret_cast<const VTAGemInsn*>(insn);
+    const VTAAluInsn* alu = reinterpret_cast<const VTAAluInsn*>(insn);
+    switch (mem->opcode) {
+      case VTA_OPCODE_LOAD: RunLoad(mem); break;
+      case VTA_OPCODE_STORE: RunStore(mem); break;
+      case VTA_OPCODE_GEMM: RunGEMM(gem); break;
+      case VTA_OPCODE_ALU: RunALU(alu); break;
+      case VTA_OPCODE_FINISH: ++finish_counter_; break;
+      default: {
+        LOG(FATAL) << "Unknown op_code" << mem->opcode;
+      }
+    }
+  }
+
+  void RunLoad(const VTAMemInsn* op) {
+    if (op->x_size == 0) return;
+    if (op->memory_type == VTA_MEM_ID_INP) {
+      inp_.Load(op, dram_, &(prof_->inp_load_nbytes));
+    } else if (op->memory_type == VTA_MEM_ID_WGT) {
+      wgt_.Load(op, dram_, &(prof_->wgt_load_nbytes));
+    } else if (op->memory_type == VTA_MEM_ID_ACC) {
+      acc_.Load(op, dram_, &(prof_->acc_load_nbytes));
+    } else if (op->memory_type == VTA_MEM_ID_UOP) {
+      uop_.Load(op, dram_, &(prof_->uop_load_nbytes));
+    } else {
+      LOG(FATAL) << "Unknown memory_type=" << op->memory_type;
+    }
+  }
+
+  void RunStore(const VTAMemInsn* op) {
+    if (op->memory_type == VTA_MEM_ID_ACC ||
+        op->memory_type == VTA_MEM_ID_UOP) {
+      prof_->out_store_nbytes += (
+          op->x_size * op->y_size * VTA_BATCH * VTA_BLOCK_OUT * VTA_OUT_WIDTH / 8);
+      acc_.TruncStore<VTA_OUT_WIDTH>(op, dram_);
+    } else {
+      LOG(FATAL) << "Store do not support memory_type="
+                 << op->memory_type;
+    }
+  }
+
+  void RunGEMM(const VTAGemInsn* op) {
+    if (!op->reset_reg) {
+      prof_->gemm_counter += op->iter_out * op->iter_in;
+      for (uint32_t y = 0; y < op->iter_out; ++y) {
+        for (uint32_t x = 0; x < op->iter_in; ++x) {
+          for (uint32_t uindex = op->uop_bgn; uindex < op->uop_end; ++uindex) {
+            VTAUop* uop_ptr = static_cast<VTAUop*>(uop_.BeginPtr(uindex));
+            // Read in memory indices
+            uint32_t acc_idx = uop_ptr->dst_idx;
+            uint32_t inp_idx = uop_ptr->src_idx;
+            uint32_t wgt_idx = uop_ptr->wgt_idx;
+            acc_idx += y * op->dst_factor_out + x * op->dst_factor_in;
+            inp_idx += y * op->src_factor_out + x * op->src_factor_in;
+            wgt_idx += y * op->wgt_factor_out + x * op->wgt_factor_in;
+            BitPacker<VTA_ACC_WIDTH> acc(acc_.BeginPtr(acc_idx));
+            BitPacker<VTA_INP_WIDTH> inp(inp_.BeginPtr(inp_idx));
+            BitPacker<VTA_WGT_WIDTH> wgt(wgt_.BeginPtr(wgt_idx));
+            // gemm loop
+            for (uint32_t i = 0; i < VTA_BATCH; ++i) {
+              for (uint32_t j = 0; j < VTA_BLOCK_OUT; ++j) {
+                uint32_t acc_offset = i * VTA_BLOCK_OUT + j;
+                int32_t sum = acc.GetSigned(acc_offset);
+                for (uint32_t k = 0; k < VTA_BLOCK_IN; ++k) {
+                  sum +=
+                      inp.GetSigned(i * VTA_BLOCK_IN + k) *
+                      wgt.GetSigned(j * VTA_BLOCK_IN + k);
+                }
+                acc.SetSigned(acc_offset, sum);
+              }
+            }
+          }
+        }
+      }
+    } else {
+      // reset
+      for (uint32_t y = 0; y < op->iter_out; ++y) {
+        for (uint32_t x = 0; x < op->iter_in; ++x) {
+          for (uint32_t uindex = op->uop_bgn; uindex < op->uop_end; ++uindex) {
+            VTAUop* uop_ptr = static_cast<VTAUop*>(uop_.BeginPtr(uindex));
+            uint32_t acc_idx = uop_ptr->dst_idx;
+            acc_idx += y * op->dst_factor_out + x * op->dst_factor_in;
+            BitPacker<VTA_ACC_WIDTH> acc(acc_.BeginPtr(acc_idx));
+            for (uint32_t i = 0; i < VTA_BATCH * VTA_BLOCK_OUT; ++i) {
+              acc.SetSigned(i, 0);
+            }
+          }
+        }
+      }
+    }
+  }
+
+  void RunALU(const VTAAluInsn* op) {
+    prof_->alu_counter += op->iter_out * op->iter_in;
+    if (op->use_imm) {
+      RunALU_<true>(op);
+    } else {
+      RunALU_<false>(op);
+    }
+  }
+
+  template<bool use_imm>
+  void RunALU_(const VTAAluInsn* op) {
+    switch (op->alu_opcode) {
+      case VTA_ALU_OPCODE_ADD: {
+        return RunALULoop<use_imm>(op, [](int32_t x, int32_t y) {
+            return x + y;
+          });
+      }
+      case VTA_ALU_OPCODE_MAX: {
+        return RunALULoop<use_imm>(op, [](int32_t x, int32_t y) {
+            return std::max(x, y);
+          });
+      }
+      case VTA_ALU_OPCODE_MIN: {
+        return RunALULoop<use_imm>(op, [](int32_t x, int32_t y) {
+            return std::min(x, y);
+          });
+      }
+      case VTA_ALU_OPCODE_SHR: {
+        return RunALULoop<use_imm>(op, [](int32_t x, int32_t y) {
+            if (y >= 0) {
+              return x >> y;
+            } else {
+              return x << (-y);
+            }
+          });
+      }
+      default: {
+        LOG(FATAL) << "Unknown ALU code " << op->alu_opcode;
+      }
+    }
+  }
+
+  template<bool use_imm, typename F>
+  void RunALULoop(const VTAAluInsn* op, F func) {
+    for (int y = 0; y < op->iter_out; ++y) {
+      for (int x = 0; x < op->iter_in; ++x) {
+        for (int k = op->uop_bgn; k < op->uop_end; ++k) {
+          // Read micro op
+          VTAUop* uop_ptr = static_cast<VTAUop*>(uop_.BeginPtr(k));
+          uint32_t dst_index = uop_ptr->dst_idx;
+          uint32_t src_index = uop_ptr->src_idx;
+          dst_index += y * op->dst_factor_out + x * op->dst_factor_in;
+          src_index += y * op->src_factor_out + x * op->src_factor_in;
+          BitPacker<VTA_ACC_WIDTH> dst(acc_.BeginPtr(dst_index));
+          BitPacker<VTA_ACC_WIDTH> src(acc_.BeginPtr(src_index));
+          for (int k = 0; k < VTA_BLOCK_OUT; ++k) {
+            if (use_imm) {
+              dst.SetSigned(k, func(dst.GetSigned(k), op->imm));
+            } else {
+              dst.SetSigned(k, func(dst.GetSigned(k), src.GetSigned(k)));
+            }
+          }
+        }
+      }
+    }
+  }
+  // the finish counter
+  int finish_counter_{0};
+  // Prof_
+  Profiler* prof_;
+  // The DRAM interface
+  DRAM* dram_;
+  // The SRAM
+  SRAM<VTA_INP_WIDTH, VTA_BATCH * VTA_BLOCK_IN, VTA_INP_BUFF_DEPTH> inp_;
+  SRAM<VTA_WGT_WIDTH, VTA_BLOCK_IN * VTA_BLOCK_OUT, VTA_WGT_BUFF_DEPTH> wgt_;
+  SRAM<VTA_ACC_WIDTH, VTA_BATCH * VTA_BLOCK_OUT, VTA_ACC_BUFF_DEPTH> acc_;
+  SRAM<VTA_UOP_WIDTH, 1, VTA_UOP_BUFF_DEPTH> uop_;
+};
+
+using tvm::runtime::TVMRetValue;
+using tvm::runtime::TVMArgs;
+
+TVM_REGISTER_GLOBAL("vta.simulator.profiler_clear")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+    Profiler::ThreadLocal()->Clear();
+  });
+TVM_REGISTER_GLOBAL("vta.simulator.profiler_status")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+    *rv = Profiler::ThreadLocal()->AsJSON();
+  });
+}  // namespace sim
+}  // namespace vta
+
+void* VTAMemAlloc(size_t size, int cached) {
+  return vta::sim::DRAM::Global()->Alloc(size);
+}
+
+void VTAMemFree(void* buf) {
+  vta::sim::DRAM::Global()->Free(buf);
+}
+
+vta_phy_addr_t VTAMemGetPhyAddr(void* buf) {
+  return vta::sim::DRAM::Global()->GetPhyAddr(buf);
+}
+
+void VTAFlushCache(vta_phy_addr_t buf, int size) {
+}
+
+void VTAInvalidateCache(vta_phy_addr_t buf, int size) {
+}
+
+VTADeviceHandle VTADeviceAlloc() {
+  return new vta::sim::Device();
+}
+
+void VTADeviceFree(VTADeviceHandle handle) {
+  delete static_cast<vta::sim::Device*>(handle);
+}
+
+int VTADeviceRun(VTADeviceHandle handle,
+                 vta_phy_addr_t insn_phy_addr,
+                 uint32_t insn_count,
+                 uint32_t wait_cycles) {
+  return static_cast<vta::sim::Device*>(handle)->Run(
+      insn_phy_addr, insn_count, wait_cycles);
+}
+
+void VTAProgram(const char* bitstream) {
+}
diff --git a/vta/src/tvm/vta_device_api.cc b/vta/src/tvm/vta_device_api.cc
index 450b23b05..e4671d8a0 100644
--- a/vta/src/tvm/vta_device_api.cc
+++ b/vta/src/tvm/vta_device_api.cc
@@ -67,9 +67,6 @@ class VTADeviceAPI final : public DeviceAPI {
         std::make_shared<VTADeviceAPI>();
     return inst;
   }
-
- private:
-  void* runtime_dll_{nullptr};
 };
 
 struct VTAWorkspacePool : public WorkspacePool {
diff --git a/vta/tests/python/pynq/test_vta_insn.py b/vta/tests/python/pynq/test_vta_insn.py
deleted file mode 100644
index 14baede4e..000000000
--- a/vta/tests/python/pynq/test_vta_insn.py
+++ /dev/null
@@ -1,504 +0,0 @@
-"""Unit test VTA's instructions """
-import tvm
-import vta
-import mxnet as mx
-import numpy as np
-import topi
-from tvm.contrib import rpc, util
-
-host = "pynq"
-port = 9091
-target = "llvm -target=armv7-none-linux-gnueabihf"
-do_verify = True
-print_ir = False
-
-def test_save_load_out():
-    env = vta.get_env()
-    """Test save/store output command"""
-    n = 4
-    x = tvm.placeholder(
-        (n, n, env.BATCH, env.BLOCK_OUT),
-        name="x",
-        dtype=env.acc_dtype)
-    x_buf = tvm.compute(
-        (n, n, env.BATCH, env.BLOCK_OUT),
-        lambda *i: x(*i),
-        "x_buf")
-    # insert no-op that won't be optimized away
-    y_buf = tvm.compute(
-        (n, n, env.BATCH, env.BLOCK_OUT),
-        lambda *i: x_buf(*i)>>0,
-        "y_buf")
-    y = tvm.compute(
-        (n, n, env.BATCH, env.BLOCK_OUT),
-        lambda *i: y_buf(*i).astype(env.inp_dtype),
-        "y")
-    # schedule
-    s = tvm.create_schedule(y.op)
-    s[x_buf].set_scope(env.acc_scope)
-    s[x_buf].pragma(x_buf.op.axis[0], env.dma_copy)
-    s[y_buf].set_scope(env.acc_scope)
-    s[y_buf].pragma(y_buf.op.axis[0], env.alu)
-    s[y].pragma(y.op.axis[0], env.dma_copy)
-
-    def verify():
-        # build
-        with vta.build_config(env.DEBUG_DUMP_INSN):
-            m = vta.build(s, [x, y], "ext_dev", target)
-        temp = util.tempdir()
-        remote = rpc.connect(host, port)
-        m.save(temp.relpath("load_act.o"))
-        remote.upload(temp.relpath("load_act.o"))
-        f = remote.load_module("load_act.o")
-        # verify
-        ctx = remote.ext_dev(0)
-        x_np = np.random.randint(
-            1, 10, size=(n, n, env.BATCH, env.BLOCK_OUT)).astype(x.dtype)
-        y_np = x_np.astype(y.dtype)
-        x_nd = tvm.nd.array(x_np, ctx)
-        y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype)
-        f(x_nd, y_nd)
-        np.testing.assert_equal(y_np, y_nd.asnumpy())
-        print("\tFinished verification...")
-    if do_verify:
-        verify()
-
-
-def test_padded_load():
-    """Test padded load."""
-    env = vta.get_env()
-    # declare
-    n = 21
-    m = 20
-    pad_before = [0, 1, 0, 0]
-    pad_after = [1, 3, 0, 0]
-    x = tvm.placeholder(
-        (n, m, env.BATCH, env.BLOCK_OUT),
-        name="x",
-        dtype=env.acc_dtype)
-    x_buf = topi.nn.pad(x, pad_before, pad_after, name="y")
-    # insert no-op that won't be optimized away
-    y_buf = tvm.compute((n + pad_before[0] + pad_after[0],
-                         m + pad_before[1] + pad_after[1],
-                         env.BATCH,
-                         env.BLOCK_OUT), lambda *i: x_buf(*i)>>0, "y_buf")
-    y = tvm.compute((n + pad_before[0] + pad_after[0],
-                     m + pad_before[1] + pad_after[1],
-                     env.BATCH,
-                     env.BLOCK_OUT), lambda *i: y_buf(*i).astype(env.inp_dtype), "y")
-    # schedule
-    s = tvm.create_schedule(y.op)
-    s[x_buf].set_scope(env.acc_scope)
-    s[x_buf].pragma(x_buf.op.axis[0], env.dma_copy)
-    s[y_buf].set_scope(env.acc_scope)
-    s[y_buf].pragma(y_buf.op.axis[0], env.alu)
-    s[y].pragma(y.op.axis[0], env.dma_copy)
-
-    def verify():
-        # build
-        with vta.build_config(env.DEBUG_DUMP_INSN):
-            mod = vta.build(s, [x, y], "ext_dev", target)
-        temp = util.tempdir()
-        remote = rpc.connect(host, port)
-        mod.save(temp.relpath("padded_load.o"))
-        remote.upload(temp.relpath("padded_load.o"))
-        f = remote.load_module("padded_load.o")
-        # verify
-        ctx = remote.ext_dev(0)
-        x_np = np.random.randint(1, 2, size=(
-            n, m, env.BATCH, env.BLOCK_OUT)).astype(x.dtype)
-        y_np = np.zeros((n + pad_before[0] + pad_after[0],
-                         m + pad_before[1] + pad_after[1],
-                         env.BATCH,
-                         env.BLOCK_OUT)).astype(y.dtype)
-        y_np[pad_before[0]:pad_before[0] + n,
-             pad_before[1]:pad_before[1] + m,
-             :] = x_np
-        x_nd = tvm.nd.array(x_np, ctx)
-        y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype)
-        f(x_nd, y_nd)
-        np.testing.assert_equal(y_np, y_nd.asnumpy())
-        print("\tFinished verification...")
-
-    if print_ir:
-        print(vta.lower(s, [y, x], simple_mode=True))
-
-
-def test_gemm():
-    """Test GEMM."""
-    env = vta.get_env()
-    # declare
-    o = 4
-    n = 4
-    m = 4
-    x = tvm.placeholder((o, n, env.BATCH, env.BLOCK_IN), name="x", dtype=env.inp_dtype)
-    w = tvm.placeholder((m, n, env.BLOCK_OUT, env.BLOCK_IN), name="w", dtype=env.wgt_dtype)
-    x_buf = tvm.compute((o, n, env.BATCH, env.BLOCK_IN), lambda *i: x(*i), "x_buf")
-    w_buf = tvm.compute((m, n, env.BLOCK_OUT, env.BLOCK_IN), lambda *i: w(*i), "w_buf")
-    ko = tvm.reduce_axis((0, n), name="ko")
-    ki = tvm.reduce_axis((0, env.BLOCK_IN), name="ki")
-    y_gem = tvm.compute(
-        (o, m, env.BATCH, env.BLOCK_OUT),
-        lambda bo, co, bi, ci:
-            tvm.sum(x_buf[bo, ko, bi, ki].astype(env.acc_dtype) *
-                    w_buf[co, ko, ci, ki].astype(env.acc_dtype),
-                    axis=[ko, ki]),
-        name="y_gem")
-    y_shf = tvm.compute(
-        (o, m, env.BATCH, env.BLOCK_OUT),
-        lambda *i: y_gem(*i)>>8,
-        name="y_shf")
-    y_max = tvm.compute(
-        (o, m, env.BATCH, env.BLOCK_OUT),
-        lambda *i: tvm.max(y_shf(*i), 0),
-        "y_max") #relu
-    y_min = tvm.compute(
-        (o, m, env.BATCH, env.BLOCK_OUT),
-        lambda *i: tvm.min(y_max(*i), (1<<(env.INP_WIDTH-1))-1),
-        "y_min") #relu
-    y = tvm.compute(
-        (o, m, env.BATCH, env.BLOCK_OUT),
-        lambda *i: y_min(*i).astype(env.inp_dtype),
-        name="y")
-
-    def verify(s):
-        mod = vta.build(s, [x, w, y], "ext_dev", target)
-        temp = util.tempdir()
-        remote = rpc.connect(host, port)
-        mod.save(temp.relpath("gemm.o"))
-        remote.upload(temp.relpath("gemm.o"))
-        f = remote.load_module("gemm.o")
-        # verify
-        ctx = remote.ext_dev(0)
-        x_np = np.random.randint(
-            -128, 128, size=(o, n, env.BATCH, env.BLOCK_IN)).astype(x.dtype)
-        w_np = np.random.randint(
-            -128, 128, size=(m, n, env.BLOCK_OUT, env.BLOCK_IN)).astype(w.dtype)
-        y_np = np.zeros((o, m, env.BATCH, env.BLOCK_OUT)).astype(y.dtype)
-        x_nd = tvm.nd.array(x_np, ctx)
-        w_nd = tvm.nd.array(w_np, ctx)
-        y_nd = tvm.nd.array(y_np, ctx)
-        y_np = y_np.astype(env.acc_dtype)
-        for b in range(o):
-            for i in range(m):
-                for j in range(n):
-                    y_np[b,i,:] += np.dot(x_np[b,j,:].astype(env.acc_dtype),
-                                          w_np[i,j].T.astype(env.acc_dtype))
-        y_np = np.right_shift(y_np, 8)
-        y_np = np.clip(y_np, 0, (1<<(env.INP_WIDTH-1))-1).astype(y.dtype)
-        f(x_nd, w_nd, y_nd)
-        np.testing.assert_equal(y_np, y_nd.asnumpy())
-        print("\tFinished verification...")
-
-    def test_schedule1():
-        # default schedule with no smt
-        s = tvm.create_schedule(y.op)
-        # set the scope of the SRAM buffers
-        s[x_buf].set_scope(env.SCOPE_INP)
-        s[w_buf].set_scope(env.SCOPE_WGT)
-        s[y_gem].set_scope(env.acc_scope)
-        s[y_shf].set_scope(env.acc_scope)
-        s[y_max].set_scope(env.acc_scope)
-        s[y_min].set_scope(env.acc_scope)
-        # set pragmas for DMA transfer and ALU ops
-        s[x_buf].pragma(s[x_buf].op.axis[0], env.dma_copy)
-        s[w_buf].pragma(s[w_buf].op.axis[0], env.dma_copy)
-        s[y_shf].pragma(s[y_shf].op.axis[0], env.alu)
-        s[y_max].pragma(s[y_max].op.axis[0], env.alu)
-        s[y_min].pragma(s[y_min].op.axis[0], env.alu)
-        s[y].pragma(s[y].op.axis[0], env.dma_copy)
-        # tensorization
-        s[y_gem].reorder(
-            ko,
-            s[y_gem].op.axis[0],
-            s[y_gem].op.axis[1],
-            s[y_gem].op.axis[2],
-            s[y_gem].op.axis[3],
-            ki)
-        s[y_gem].tensorize(s[y_gem].op.axis[2], env.GEMM)
-        if print_ir:
-            print(vta.lower(s, [x, w, y], simple_mode=True))
-        if do_verify:
-            with vta.build_config(env.DEBUG_DUMP_INSN):
-                verify(s)
-
-    def test_smt():
-        # test smt schedule
-        s = tvm.create_schedule(y.op)
-        s[x_buf].set_scope(env.SCOPE_INP)
-        s[w_buf].set_scope(env.SCOPE_WGT)
-        s[y_gem].set_scope(env.acc_scope)
-        s[y_shf].set_scope(env.acc_scope)
-        s[y_max].set_scope(env.acc_scope)
-        s[y_min].set_scope(env.acc_scope)
-        abo, aco, abi, aci = s[y].op.axis
-        abo1, abo2 = s[y].split(abo, nparts=2)
-        s[y].bind(abo1, tvm.thread_axis("cthread"))
-        s[y_gem].compute_at(s[y], abo1)
-        s[y_shf].compute_at(s[y], abo1)
-        s[y_max].compute_at(s[y], abo1)
-        s[y_min].compute_at(s[y], abo1)
-        s[y_gem].reorder(
-            ko,
-            s[y_gem].op.axis[0],
-            s[y_gem].op.axis[1],
-            s[y_gem].op.axis[2],
-            s[y_gem].op.axis[3],
-            ki)
-        s[y_gem].tensorize(s[y_gem].op.axis[2], env.GEMM)
-        s[y_shf].pragma(s[y_shf].op.axis[0], env.alu)
-        s[y_max].pragma(s[y_max].op.axis[0], env.alu)
-        s[y_min].pragma(s[y_min].op.axis[0], env.alu)
-        s[x_buf].compute_at(s[y_gem], ko)
-        s[x_buf].pragma(s[x_buf].op.axis[0], env.dma_copy)
-        s[w_buf].compute_at(s[y_gem], ko)
-        s[w_buf].pragma(s[w_buf].op.axis[0], env.dma_copy)
-        s[y].pragma(abo2, env.dma_copy)
-        if print_ir:
-            print(vta.lower(s, [x, y, w], simple_mode=True))
-        if do_verify:
-            with vta.build_config(env.DEBUG_DUMP_INSN):
-                verify(s)
-
-    test_schedule1()
-    test_smt()
-
-def test_alu(tvm_op, np_op=None, use_imm=False):
-    """Test ALU"""
-    env = vta.get_env()
-    m = 8
-    n = 8
-    imm = np.random.randint(1,5)
-    # compute
-    a = tvm.placeholder(
-        (m, n, env.BATCH, env.BLOCK_OUT),
-        name="a",
-        dtype=env.acc_dtype)
-    a_buf = tvm.compute(
-        (m, n, env.BATCH, env.BLOCK_OUT),
-        lambda *i: a(*i),
-        "a_buf") #DRAM->SRAM
-    if use_imm:
-        res_buf = tvm.compute(
-            (m, n, env.BATCH, env.BLOCK_OUT),
-            lambda *i: tvm_op(a_buf(*i), imm),
-            "res_buf") #compute
-    else:
-        b = tvm.placeholder(
-            (m, n, env.BATCH, env.BLOCK_OUT),
-            name="b",
-            dtype=env.acc_dtype)
-        b_buf = tvm.compute(
-            (m, n, env.BATCH, env.BLOCK_OUT),
-            lambda *i: b(*i),
-            "b_buf") #DRAM->SRAM
-        res_buf = tvm.compute(
-            (m, n, env.BATCH, env.BLOCK_OUT),
-            lambda *i: tvm_op(a_buf(*i), b_buf(*i)),
-            "res_buf") #compute
-    res = tvm.compute(
-        (m, n, env.BATCH, env.BLOCK_OUT),
-        lambda *i: res_buf(*i).astype(env.inp_dtype),
-        "res") #SRAM->DRAM
-    # schedule
-    s = tvm.create_schedule(res.op)
-    s[a_buf].set_scope(env.acc_scope) # SRAM
-    s[a_buf].pragma(a_buf.op.axis[0], env.dma_copy) # DRAM->SRAM
-    s[res_buf].set_scope(env.acc_scope) # SRAM
-    s[res_buf].pragma(res_buf.op.axis[0], env.alu) # compute
-    s[res].pragma(res.op.axis[0], env.dma_copy) # SRAM->DRAM
-    if use_imm:
-        if print_ir:
-            print(vta.lower(s, [a, res], simple_mode=True))
-    else:
-        s[b_buf].set_scope(env.acc_scope) # SRAM
-        s[b_buf].pragma(b_buf.op.axis[0], env.dma_copy) # DRAM->SRAM
-        if print_ir:
-            print(vta.lower(s, [a, b, res], simple_mode=True))
-
-    def verify():
-        # build
-        with vta.build_config():
-            if use_imm:
-                mod = vta.build(s, [a, res], "ext_dev", target)
-            else:
-                mod = vta.build(s, [a, b, res], "ext_dev", target)
-        temp = util.tempdir()
-        remote = rpc.connect(host, port)
-        mod.save(temp.relpath("load_act.o"))
-        remote.upload(temp.relpath("load_act.o"))
-        f = remote.load_module("load_act.o")
-        # verify
-        ctx = remote.ext_dev(0)
-        a_np = np.random.randint(
-            -16, 16, size=(m, n, env.BATCH, env.BLOCK_OUT)).astype(a.dtype)
-        if use_imm:
-            res_np = np_op(a_np, imm) if np_op else tvm_op(a_np, imm)
-        else:
-            b_np = np.random.randint(
-                -16, 16, size=(m, n, env.BATCH, env.BLOCK_OUT)).astype(b.dtype)
-            res_np = np_op(a_np, b_np) if np_op else tvm_op(a_np, b_np)
-        res_np = res_np.astype(res.dtype)
-        a_nd = tvm.nd.array(a_np, ctx)
-        res_nd = tvm.nd.array(
-            np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
-        if use_imm:
-            f(a_nd, res_nd)
-        else:
-            b_nd = tvm.nd.array(b_np, ctx)
-            f(a_nd, b_nd, res_nd)
-        np.testing.assert_equal(res_np, res_nd.asnumpy())
-        print("\tFinished verification...")
-
-    if do_verify:
-        verify()
-
-def test_relu():
-    """Test RELU on ALU"""
-    env = vta.get_env()
-    m = 8
-    n = 8
-    # compute
-    a = tvm.placeholder(
-        (m, n, env.BATCH, env.BLOCK_OUT),
-        name="a",
-        dtype=env.acc_dtype)
-    a_buf = tvm.compute(
-        (m, n, env.BATCH, env.BLOCK_OUT),
-        lambda *i: a(*i),
-        "a_buf") # DRAM->SRAM
-    max_buf = tvm.compute(
-        (m, n, env.BATCH, env.BLOCK_OUT),
-        lambda *i: tvm.max(a_buf(*i), 0),
-        "res_buf") # relu
-    min_buf = tvm.compute(
-        (m, n, env.BATCH, env.BLOCK_OUT),
-        lambda *i: tvm.min(max_buf(*i), (1<<(env.INP_WIDTH-1))-1),
-        "max_buf") # relu
-    res = tvm.compute(
-        (m, n, env.BATCH, env.BLOCK_OUT),
-        lambda *i: min_buf(*i).astype(env.inp_dtype),
-        "min_buf") # SRAM->DRAM
-    # schedule
-    s = tvm.create_schedule(res.op)
-    s[a_buf].set_scope(env.acc_scope) # SRAM
-    s[a_buf].pragma(a_buf.op.axis[0], env.dma_copy) # DRAM->SRAM
-    s[max_buf].set_scope(env.acc_scope) # SRAM
-    s[min_buf].set_scope(env.acc_scope) # SRAM
-    s[max_buf].pragma(max_buf.op.axis[0], env.alu) # compute
-    s[min_buf].pragma(min_buf.op.axis[0], env.alu) # compute
-    s[res].pragma(res.op.axis[0], env.dma_copy) # SRAM->DRAM
-    if print_ir:
-        print(vta.lower(s, [a, res], simple_mode=True))
-
-    def verify():
-        # build
-        with vta.build_config(env.DEBUG_DUMP_INSN):
-            mod = vta.build(s, [a, res], "ext_dev", target)
-        temp = util.tempdir()
-        remote = rpc.connect(host, port)
-        mod.save(temp.relpath("load_act.o"))
-        remote.upload(temp.relpath("load_act.o"))
-        f = remote.load_module("load_act.o")
-        # verify
-        ctx = remote.ext_dev(0)
-        a_np = np.random.randint(
-            -256, 256, size=(m, n, env.BATCH, env.BLOCK_OUT)).astype(a.dtype)
-        res_np = np.clip(a_np, 0, (1<<(env.INP_WIDTH-1))-1).astype(res.dtype)
-        a_nd = tvm.nd.array(a_np, ctx)
-        res_nd = tvm.nd.array(
-            np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
-        f(a_nd, res_nd)
-        np.testing.assert_equal(res_np, res_nd.asnumpy())
-        print("\tFinished verification...")
-
-    if do_verify:
-        verify()
-
-def test_shift_and_scale():
-    """Test shift and scale on ALU"""
-    env = vta.get_env()
-    m = 8
-    n = 8
-    imm_shift = np.random.randint(-10,10)
-    imm_scale = np.random.randint(1,5)
-    # compute
-    a = tvm.placeholder(
-        (m, n, env.BATCH, env.BLOCK_OUT),
-        name="a", dtype=env.acc_dtype)
-    a_buf = tvm.compute(
-        (m, n, env.BATCH, env.BLOCK_OUT),
-        lambda *i: a(*i),
-        "a_buf") # DRAM->SRAM
-    res_shift = tvm.compute(
-        (m, n, env.BATCH, env.BLOCK_OUT),
-        lambda *i: a_buf(*i)+imm_shift,
-        "res_shift") # compute
-    res_scale = tvm.compute(
-        (m, n, env.BATCH, env.BLOCK_OUT),
-        lambda *i: res_shift(*i)>>imm_scale,
-        "res_scale") # compute
-    res = tvm.compute(
-        (m, n, env.BATCH, env.BLOCK_OUT),
-        lambda *i: res_scale(*i).astype(env.inp_dtype),
-        "res") # SRAM->DRAM
-    # schedule
-    s = tvm.create_schedule(res.op)
-    s[a_buf].set_scope(env.acc_scope) # SRAM
-    s[res_shift].set_scope(env.acc_scope) # SRAM
-    s[res_scale].set_scope(env.acc_scope) # SRAM
-    s[a_buf].pragma(a_buf.op.axis[0], env.dma_copy) # DRAM->SRAM
-    s[res_shift].pragma(res_shift.op.axis[0], env.alu) # compute
-    s[res_scale].pragma(res_scale.op.axis[0], env.alu) # compute
-    s[res].pragma(res.op.axis[0], env.dma_copy) # SRAM->DRAM
-
-    if print_ir:
-        print(vta.lower(s, [a, res], simple_mode=True))
-
-    def verify():
-        # build
-        mod = vta.build(s, [a, res], "ext_dev", target)
-        temp = util.tempdir()
-        remote = rpc.connect(host, port)
-        mod.save(temp.relpath("load_act.o"))
-        remote.upload(temp.relpath("load_act.o"))
-        f = remote.load_module("load_act.o")
-        # verify
-        ctx = remote.ext_dev(0)
-        a_np = np.random.randint(
-            -10, 10, size=(m, n, env.BATCH, env.BLOCK_OUT)).astype(a.dtype)
-        res_np = np.right_shift((a_np + imm_shift), imm_scale)
-        res_np = res_np.astype(res.dtype)
-        a_nd = tvm.nd.array(a_np, ctx)
-        res_nd = tvm.nd.array(
-            np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
-        f(a_nd, res_nd)
-        np.testing.assert_equal(res_np, res_nd.asnumpy())
-        print("\tFinished verification...")
-
-    if do_verify:
-        verify()
-
-if __name__ == "__main__":
-    print("Load/store test")
-    test_save_load_out()
-    print("Padded load test")
-    test_padded_load()
-    # print("GEMM test")
-    # test_gemm()
-    print("Max immediate")
-    test_alu(tvm.max, np.maximum, use_imm=True)
-    print("Max")
-    test_alu(tvm.max, np.maximum)
-    print("Add immediate")
-    test_alu(lambda x, y: x + y, use_imm=True)
-    print("Add")
-    test_alu(lambda x, y: x + y)
-    print("Shift right immediate")
-    test_alu(lambda x, y: x >> y, np.right_shift, use_imm=True)
-    print("Shift left immediate")
-    test_alu(lambda x, y: x << y, np.left_shift, use_imm=True)
-    print("Relu")
-    test_relu()
-    # print("Shift and scale")
-    # test_shift_and_scale()
diff --git a/vta/tests/python/unittest/test_vta_insn.py b/vta/tests/python/unittest/test_vta_insn.py
new file mode 100644
index 000000000..339d8d31e
--- /dev/null
+++ b/vta/tests/python/unittest/test_vta_insn.py
@@ -0,0 +1,482 @@
+"""Unit test VTA's instructions """
+import tvm
+import numpy as np
+import topi
+from tvm.contrib import rpc, util
+
+import vta
+import vta.testing
+from vta.testing import simulator
+
+
+def test_save_load_out():
+    """Test save/store output command"""
+    def _run(env, remote):
+        n = 6
+        x = tvm.placeholder(
+            (n, n, env.BATCH, env.BLOCK_OUT),
+            name="x",
+            dtype=env.acc_dtype)
+        x_buf = tvm.compute(
+            (n, n, env.BATCH, env.BLOCK_OUT),
+            lambda *i: x(*i), "x_buf")
+        # insert no-op that won't be optimized away
+        y_buf = tvm.compute(
+            (n, n, env.BATCH, env.BLOCK_OUT),
+            lambda *i: x_buf(*i)>>0, "y_buf")
+        y = tvm.compute(
+            (n, n, env.BATCH, env.BLOCK_OUT),
+            lambda *i: y_buf(*i).astype(env.inp_dtype), "y")
+        # schedule
+        s = tvm.create_schedule(y.op)
+        s[x_buf].set_scope(env.acc_scope)
+        s[x_buf].pragma(x_buf.op.axis[0], env.dma_copy)
+        s[y_buf].set_scope(env.acc_scope)
+        s[y_buf].pragma(y_buf.op.axis[0], env.alu)
+        s[y].pragma(y.op.axis[0], env.dma_copy)
+
+        # verification
+        with vta.build_config():
+            m = vta.build(s, [x, y], "ext_dev", env.target_host)
+
+        if not remote:
+            return
+        temp = util.tempdir()
+        m.save(temp.relpath("load_act.o"))
+        remote.upload(temp.relpath("load_act.o"))
+        f = remote.load_module("load_act.o")
+        # verify
+        ctx = remote.ext_dev(0)
+        x_np = np.random.randint(
+            1, 10, size=(n, n, env.BATCH, env.BLOCK_OUT)).astype(x.dtype)
+        y_np = x_np.astype(y.dtype)
+        x_nd = tvm.nd.array(x_np, ctx)
+        y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype)
+        f(x_nd, y_nd)
+        np.testing.assert_equal(y_np, y_nd.asnumpy())
+
+    vta.testing.run(_run)
+
+
+def test_padded_load():
+    """Test padded load."""
+    def _run(env, remote):
+        # declare
+        n = 21
+        m = 20
+        pad_before = [0, 1, 0, 0]
+        pad_after = [1, 3, 0, 0]
+        x = tvm.placeholder(
+            (n, m, env.BATCH, env.BLOCK_OUT),
+            name="x",
+            dtype=env.acc_dtype)
+        x_buf = topi.nn.pad(x, pad_before, pad_after, name="y")
+        # insert no-op that won't be optimized away
+        y_buf = tvm.compute((n + pad_before[0] + pad_after[0],
+                             m + pad_before[1] + pad_after[1],
+                             env.BATCH,
+                             env.BLOCK_OUT), lambda *i: x_buf(*i)>>0, "y_buf")
+        y = tvm.compute((n + pad_before[0] + pad_after[0],
+                         m + pad_before[1] + pad_after[1],
+                         env.BATCH,
+                         env.BLOCK_OUT), lambda *i: y_buf(*i).astype(env.inp_dtype), "y")
+        # schedule
+        s = tvm.create_schedule(y.op)
+        s[x_buf].set_scope(env.acc_scope)
+        s[x_buf].pragma(x_buf.op.axis[0], env.dma_copy)
+        s[y_buf].set_scope(env.acc_scope)
+        s[y_buf].pragma(y_buf.op.axis[0], env.alu)
+        s[y].pragma(y.op.axis[0], env.dma_copy)
+        # build
+        with vta.build_config():
+            mod = vta.build(s, [x, y], "ext_dev", env.target_host)
+
+        if not remote:
+            return
+        temp = util.tempdir()
+        mod.save(temp.relpath("padded_load.o"))
+        remote.upload(temp.relpath("padded_load.o"))
+        f = remote.load_module("padded_load.o")
+        # verify
+        ctx = remote.ext_dev(0)
+        x_np = np.random.randint(1, 2, size=(
+            n, m, env.BATCH, env.BLOCK_OUT)).astype(x.dtype)
+        y_np = np.zeros((n + pad_before[0] + pad_after[0],
+                         m + pad_before[1] + pad_after[1],
+                         env.BATCH,
+                         env.BLOCK_OUT)).astype(y.dtype)
+        y_np[pad_before[0]:pad_before[0] + n,
+             pad_before[1]:pad_before[1] + m,
+             :] = x_np
+        x_nd = tvm.nd.array(x_np, ctx)
+        y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype)
+        f(x_nd, y_nd)
+        np.testing.assert_equal(y_np, y_nd.asnumpy())
+
+    vta.testing.run(_run)
+
+
+def test_gemm():
+    """Test GEMM."""
+    def _run(env, remote):
+        # declare
+        o = 4
+        n = 1
+        m = 4
+        x = tvm.placeholder((o, n, env.BATCH, env.BLOCK_IN), name="x", dtype=env.inp_dtype)
+        w = tvm.placeholder((m, n, env.BLOCK_OUT, env.BLOCK_IN), name="w", dtype=env.wgt_dtype)
+        x_buf = tvm.compute((o, n, env.BATCH, env.BLOCK_IN), lambda *i: x(*i), "x_buf")
+        w_buf = tvm.compute((m, n, env.BLOCK_OUT, env.BLOCK_IN), lambda *i: w(*i), "w_buf")
+        ko = tvm.reduce_axis((0, n), name="ko")
+        ki = tvm.reduce_axis((0, env.BLOCK_IN), name="ki")
+        y_gem = tvm.compute(
+            (o, m, env.BATCH, env.BLOCK_OUT),
+            lambda bo, co, bi, ci:
+            tvm.sum(x_buf[bo, ko, bi, ki].astype(env.acc_dtype) *
+                    w_buf[co, ko, ci, ki].astype(env.acc_dtype),
+                    axis=[ko, ki]),
+            name="y_gem")
+        y_shf = tvm.compute(
+            (o, m, env.BATCH, env.BLOCK_OUT),
+            lambda *i: y_gem(*i)>>8,
+            name="y_shf")
+        y_max = tvm.compute(
+            (o, m, env.BATCH, env.BLOCK_OUT),
+            lambda *i: tvm.max(y_shf(*i), 0),
+            "y_max") #relu
+        y_min = tvm.compute(
+            (o, m, env.BATCH, env.BLOCK_OUT),
+            lambda *i: tvm.min(y_max(*i), (1<<(env.INP_WIDTH-1))-1),
+            "y_min") #relu
+        y = tvm.compute(
+            (o, m, env.BATCH, env.BLOCK_OUT),
+            lambda *i: y_min(*i).astype(env.inp_dtype),
+            name="y")
+
+        if not remote:
+            return
+
+        def verify(s):
+            mod = vta.build(s, [x, w, y], "ext_dev", env.target_host)
+            temp = util.tempdir()
+            mod.save(temp.relpath("gemm.o"))
+            remote.upload(temp.relpath("gemm.o"))
+            f = remote.load_module("gemm.o")
+            # verify
+            ctx = remote.ext_dev(0)
+            x_np = np.random.randint(
+                -128, 128, size=(o, n, env.BATCH, env.BLOCK_IN)).astype(x.dtype)
+            w_np = np.random.randint(
+                -128, 128, size=(m, n, env.BLOCK_OUT, env.BLOCK_IN)).astype(w.dtype)
+            y_np = np.zeros((o, m, env.BATCH, env.BLOCK_OUT)).astype(y.dtype)
+            x_nd = tvm.nd.array(x_np, ctx)
+            w_nd = tvm.nd.array(w_np, ctx)
+            y_nd = tvm.nd.array(y_np, ctx)
+            y_np = y_np.astype(env.acc_dtype)
+            for b in range(o):
+                for i in range(m):
+                    for j in range(n):
+                        y_np[b,i,:] += np.dot(x_np[b,j,:].astype(env.acc_dtype),
+                                              w_np[i,j].T.astype(env.acc_dtype))
+            y_np = np.right_shift(y_np, 8)
+            y_np = np.clip(y_np, 0, (1<<(env.INP_WIDTH-1))-1).astype(y.dtype)
+
+            if env.TARGET == "sim":
+                simulator.clear_stats()
+                f(x_nd, w_nd, y_nd)
+                print(simulator.stats())
+            else:
+                f(x_nd, w_nd, y_nd)
+
+            np.testing.assert_equal(y_np, y_nd.asnumpy())
+
+        def test_schedule1():
+            # default schedule with no smt
+            s = tvm.create_schedule(y.op)
+            # set the scope of the SRAM buffers
+            s[x_buf].set_scope(env.inp_scope)
+            s[w_buf].set_scope(env.wgt_scope)
+            s[y_gem].set_scope(env.acc_scope)
+            s[y_shf].set_scope(env.acc_scope)
+            s[y_max].set_scope(env.acc_scope)
+            s[y_min].set_scope(env.acc_scope)
+            # set pragmas for DMA transfer and ALU ops
+            s[x_buf].compute_at(s[y_gem], ko)
+            s[x_buf].pragma(s[x_buf].op.axis[0], env.dma_copy)
+            s[w_buf].compute_at(s[y_gem], ko)
+            s[w_buf].pragma(s[w_buf].op.axis[0], env.dma_copy)
+            s[y_shf].pragma(s[y_shf].op.axis[0], env.alu)
+            s[y_max].pragma(s[y_max].op.axis[0], env.alu)
+            s[y_min].pragma(s[y_min].op.axis[0], env.alu)
+            s[y].pragma(s[y].op.axis[0], env.dma_copy)
+            # tensorization
+            s[y_gem].reorder(
+                ko,
+                s[y_gem].op.axis[0],
+                s[y_gem].op.axis[1],
+                s[y_gem].op.axis[2],
+                s[y_gem].op.axis[3],
+                ki)
+            s[y_gem].tensorize(s[y_gem].op.axis[2], env.gemm)
+            verify(s)
+
+        def test_smt():
+            # test smt schedule
+            s = tvm.create_schedule(y.op)
+            s[x_buf].set_scope(env.inp_scope)
+            s[w_buf].set_scope(env.wgt_scope)
+            s[y_gem].set_scope(env.acc_scope)
+            s[y_shf].set_scope(env.acc_scope)
+            s[y_max].set_scope(env.acc_scope)
+            s[y_min].set_scope(env.acc_scope)
+            abo, aco, abi, aci = s[y].op.axis
+            abo1, abo2 = s[y].split(abo, nparts=2)
+            s[y].bind(abo1, tvm.thread_axis("cthread"))
+            s[y_gem].compute_at(s[y], abo1)
+            s[y_shf].compute_at(s[y], abo1)
+            s[y_max].compute_at(s[y], abo1)
+            s[y_min].compute_at(s[y], abo1)
+            s[y_gem].reorder(
+                ko,
+                s[y_gem].op.axis[0],
+                s[y_gem].op.axis[1],
+                s[y_gem].op.axis[2],
+                s[y_gem].op.axis[3],
+                ki)
+            s[y_gem].tensorize(s[y_gem].op.axis[2], env.gemm)
+            s[y_shf].pragma(s[y_shf].op.axis[0], env.alu)
+            s[y_max].pragma(s[y_max].op.axis[0], env.alu)
+            s[y_min].pragma(s[y_min].op.axis[0], env.alu)
+            s[x_buf].compute_at(s[y_gem], ko)
+            s[x_buf].pragma(s[x_buf].op.axis[0], env.dma_copy)
+            s[w_buf].compute_at(s[y_gem], ko)
+            s[w_buf].pragma(s[w_buf].op.axis[0], env.dma_copy)
+            s[y].pragma(abo2, env.dma_copy)
+            verify(s)
+
+        test_schedule1()
+        test_smt()
+    vta.testing.run(_run)
+
+
+def test_alu():
+    def _run(env, remote):
+        def check_alu(tvm_op, np_op=None, use_imm=False):
+            """Test ALU"""
+            m = 8
+            n = 8
+            imm = np.random.randint(1,5)
+            # compute
+            a = tvm.placeholder(
+                (m, n, env.BATCH, env.BLOCK_OUT),
+                name="a",
+                dtype=env.acc_dtype)
+            a_buf = tvm.compute(
+                (m, n, env.BATCH, env.BLOCK_OUT),
+                lambda *i: a(*i),
+                "a_buf") #DRAM->SRAM
+            if use_imm:
+                res_buf = tvm.compute(
+                    (m, n, env.BATCH, env.BLOCK_OUT),
+                    lambda *i: tvm_op(a_buf(*i), imm),
+                    "res_buf") #compute
+            else:
+                b = tvm.placeholder(
+                    (m, n, env.BATCH, env.BLOCK_OUT),
+                    name="b",
+                    dtype=env.acc_dtype)
+                b_buf = tvm.compute(
+                    (m, n, env.BATCH, env.BLOCK_OUT),
+                    lambda *i: b(*i),
+                    "b_buf") #DRAM->SRAM
+                res_buf = tvm.compute(
+                    (m, n, env.BATCH, env.BLOCK_OUT),
+                    lambda *i: tvm_op(a_buf(*i), b_buf(*i)),
+                    "res_buf") #compute5B
+            res = tvm.compute(
+                (m, n, env.BATCH, env.BLOCK_OUT),
+                lambda *i: res_buf(*i).astype(env.inp_dtype),
+                "res") #SRAM->DRAM
+            # schedule
+            s = tvm.create_schedule(res.op)
+            s[a_buf].set_scope(env.acc_scope) # SRAM
+            s[a_buf].pragma(a_buf.op.axis[0], env.dma_copy) # DRAM->SRAM
+            s[res_buf].set_scope(env.acc_scope) # SRAM
+            s[res_buf].pragma(res_buf.op.axis[0], env.alu) # compute
+            s[res].pragma(res.op.axis[0], env.dma_copy) # SRAM->DRAM
+            if not use_imm:
+                s[b_buf].set_scope(env.acc_scope) # SRAM
+                s[b_buf].pragma(b_buf.op.axis[0], env.dma_copy) # DRAM->SRAM
+
+            if not remote:
+                return
+
+            # build
+            with vta.build_config():
+                if use_imm:
+                    mod = vta.build(s, [a, res], "ext_dev", env.target_host)
+                else:
+                    mod = vta.build(s, [a, b, res], "ext_dev", env.target_host)
+            temp = util.tempdir()
+            mod.save(temp.relpath("load_act.o"))
+            remote.upload(temp.relpath("load_act.o"))
+            f = remote.load_module("load_act.o")
+            # verify
+            ctx = remote.ext_dev(0)
+            a_np = np.random.randint(
+                -16, 16, size=(m, n, env.BATCH, env.BLOCK_OUT)).astype(a.dtype)
+            if use_imm:
+                res_np = np_op(a_np, imm) if np_op else tvm_op(a_np, imm)
+            else:
+                b_np = np.random.randint(
+                    -16, 16, size=(m, n, env.BATCH, env.BLOCK_OUT)).astype(b.dtype)
+                res_np = np_op(a_np, b_np) if np_op else tvm_op(a_np, b_np)
+            res_np = res_np.astype(res.dtype)
+            a_nd = tvm.nd.array(a_np, ctx)
+            res_nd = tvm.nd.array(
+                np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
+            if use_imm:
+                f(a_nd, res_nd)
+            else:
+                b_nd = tvm.nd.array(b_np, ctx)
+                f(a_nd, b_nd, res_nd)
+            np.testing.assert_equal(res_np, res_nd.asnumpy())
+
+        check_alu(lambda x, y: x << y, np.left_shift, use_imm=True)
+        check_alu(tvm.max, np.maximum, use_imm=True)
+        check_alu(tvm.max, np.maximum)
+        check_alu(lambda x, y: x + y, use_imm=True)
+        check_alu(lambda x, y: x + y)
+        check_alu(lambda x, y: x >> y, np.right_shift, use_imm=True)
+
+    vta.testing.run(_run)
+
+
+def test_relu():
+    """Test RELU on ALU"""
+    def _run(env, remote):
+        m = 8
+        n = 10
+        # compute
+        a = tvm.placeholder(
+            (m, n, env.BATCH, env.BLOCK_OUT),
+            name="a",
+            dtype=env.acc_dtype)
+        a_buf = tvm.compute(
+            (m, n, env.BATCH, env.BLOCK_OUT),
+            lambda *i: a(*i),
+            "a_buf") # DRAM->SRAM
+        max_buf = tvm.compute(
+            (m, n, env.BATCH, env.BLOCK_OUT),
+            lambda *i: tvm.max(a_buf(*i), 0),
+            "res_buf") # relu
+        min_buf = tvm.compute(
+            (m, n, env.BATCH, env.BLOCK_OUT),
+            lambda *i: tvm.min(max_buf(*i), (1<<(env.INP_WIDTH-1))-1),
+            "max_buf") # relu
+        res = tvm.compute(
+            (m, n, env.BATCH, env.BLOCK_OUT),
+            lambda *i: min_buf(*i).astype(env.inp_dtype),
+            "min_buf") # SRAM->DRAM
+        # schedule
+        s = tvm.create_schedule(res.op)
+        s[a_buf].set_scope(env.acc_scope) # SRAM
+        s[a_buf].pragma(a_buf.op.axis[0], env.dma_copy) # DRAM->SRAM
+        s[max_buf].set_scope(env.acc_scope) # SRAM
+        s[min_buf].set_scope(env.acc_scope) # SRAM
+        s[max_buf].pragma(max_buf.op.axis[0], env.alu) # compute
+        s[min_buf].pragma(min_buf.op.axis[0], env.alu) # compute
+        s[res].pragma(res.op.axis[0], env.dma_copy) # SRAM->DRAM
+        # build
+        with vta.build_config():
+            mod = vta.build(s, [a, res], "ext_dev", env.target_host)
+        if not remote:
+            return
+        temp = util.tempdir()
+        mod.save(temp.relpath("load_act.o"))
+        remote.upload(temp.relpath("load_act.o"))
+        f = remote.load_module("load_act.o")
+        # verify
+        ctx = remote.ext_dev(0)
+        a_np = np.random.randint(
+            -256, 256, size=(m, n, env.BATCH, env.BLOCK_OUT)).astype(a.dtype)
+        res_np = np.clip(a_np, 0, (1<<(env.INP_WIDTH-1))-1).astype(res.dtype)
+        a_nd = tvm.nd.array(a_np, ctx)
+        res_nd = tvm.nd.array(
+            np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
+        f(a_nd, res_nd)
+        np.testing.assert_equal(res_np, res_nd.asnumpy())
+
+    vta.testing.run(_run)
+
+
+def test_shift_and_scale():
+    """Test shift and scale on ALU"""
+    def _run(env, remote):
+        m = 2
+        n = 8
+        imm_shift = np.random.randint(0,8)
+        imm_scale = np.random.randint(1,5)
+        # compute
+        a = tvm.placeholder(
+            (m, n, env.BATCH, env.BLOCK_OUT),
+            name="a", dtype=env.acc_dtype)
+        a_buf = tvm.compute(
+            (m, n, env.BATCH, env.BLOCK_OUT),
+            lambda *i: a(*i),
+            "a_buf") # DRAM->SRAM
+        res_shift = tvm.compute(
+            (m, n, env.BATCH, env.BLOCK_OUT),
+            lambda *i: a_buf(*i)+imm_shift,
+            "res_shift") # compute
+        res_scale = tvm.compute(
+            (m, n, env.BATCH, env.BLOCK_OUT),
+            lambda *i: res_shift(*i)>>imm_scale,
+            "res_scale") # compute
+        res = tvm.compute(
+            (m, n, env.BATCH, env.BLOCK_OUT),
+            lambda *i: res_scale(*i).astype(env.inp_dtype),
+            "res") # SRAM->DRAM
+        # schedule
+        s = tvm.create_schedule(res.op)
+        s[a_buf].set_scope(env.acc_scope) # SRAM
+        s[res_shift].set_scope(env.acc_scope) # SRAM
+        s[res_scale].set_scope(env.acc_scope) # SRAM
+        s[a_buf].pragma(a_buf.op.axis[0], env.dma_copy) # DRAM->SRAM
+        s[res_shift].pragma(res_shift.op.axis[0], env.alu) # compute
+        s[res_scale].pragma(res_scale.op.axis[0], env.alu) # compute
+        s[res].pragma(res.op.axis[0], env.dma_copy) # SRAM->DRAM
+        # build
+        mod = vta.build(s, [a, res], "ext_dev", env.target_host)
+        if not remote:
+            return
+        temp = util.tempdir()
+        mod.save(temp.relpath("load_act.o"))
+        remote.upload(temp.relpath("load_act.o"))
+        f = remote.load_module("load_act.o")
+        # verify
+        ctx = remote.ext_dev(0)
+        a_np = np.random.randint(
+            -10, 10, size=(m, n, env.BATCH, env.BLOCK_OUT)).astype(a.dtype)
+        res_np = np.right_shift((a_np + imm_shift), imm_scale)
+        res_np = res_np.astype(res.dtype)
+        a_nd = tvm.nd.array(a_np, ctx)
+        res_nd = tvm.nd.array(
+            np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
+        f(a_nd, res_nd)
+        np.testing.assert_equal(res_np, res_nd.asnumpy())
+
+    vta.testing.run(_run)
+
+if __name__ == "__main__":
+    print("Load/store test")
+    test_save_load_out()
+    print("Padded load test")
+    #test_padded_load()
+    print("GEMM test")
+    test_gemm()
+    test_alu()
+    print("ALU test")
+    test_relu()
+    print("Shift and scale")
+    test_shift_and_scale()
-- 
GitLab