From 96488c11097529c3e98cdc466888506178053e4d Mon Sep 17 00:00:00 2001 From: Thierry Moreau <moreau@cs.washington.edu> Date: Thu, 22 Mar 2018 21:42:12 -0400 Subject: [PATCH] [PYTHON, TVM] Python TVM library, unit tests and end to end example * VTA python library * Python unit tests * End to end example with Resnet18 * README instructions * Bug fixes --- vta/Makefile | 6 +- vta/apps/pynq_rpc/README.md | 80 ++ vta/apps/pynq_rpc/start_rpc_server.sh | 2 +- vta/examples/resnet18/pynq/.gitignore | 5 + vta/examples/resnet18/pynq/README.md | 98 +++ .../resnet18/pynq/imagenet_predict.py | 174 ++++ vta/hardware/{vivado => xilinx}/.gitignore | 0 vta/hardware/{vivado => xilinx}/Makefile | 4 +- vta/hardware/xilinx/README.md | 73 ++ .../{vivado => xilinx}/scripts/hls.tcl | 0 .../{vivado => xilinx}/scripts/hsi.tcl | 0 .../{vivado => xilinx}/scripts/vivado.tcl | 0 .../{vivado => xilinx}/sim/vta_test.cc | 2 + vta/hardware/{vivado => xilinx}/src/vta.cc | 105 +-- vta/hardware/{vivado => xilinx}/src/vta.h | 30 +- vta/make/config.mk | 2 +- vta/python/vta/__init__.py | 13 +- vta/python/vta/arm_conv2d.py | 335 ++++++++ vta/python/vta/build.py | 55 ++ vta/python/vta/graph.py | 348 ++++++++ vta/python/vta/hw_spec.py | 70 ++ vta/python/vta/intrin.py | 183 +++++ vta/python/vta/ir_pass.py | 762 ++++++++++++++++++ vta/python/vta/mock.py | 7 + vta/python/vta/runtime.py | 42 + vta/python/vta/vta_conv2d.py | 373 +++++++++ vta/src/pynq/pynq_driver.h | 2 +- vta/src/runtime.cc | 4 +- vta/tests/hardware/common/test_lib.h | 2 +- vta/tests/hardware/pynq/Makefile | 2 +- .../python/pynq/test_benchmark_conv2d.py | 414 ++++++++++ vta/tests/python/pynq/test_benchmark_gemm.py | 267 ++++++ vta/tests/python/pynq/test_benchmark_topi.py | 144 ++++ vta/tests/python/pynq/test_program_rpc.py | 21 + vta/tests/python/pynq/test_vta_insn.py | 498 ++++++++++++ 35 files changed, 4046 insertions(+), 77 deletions(-) create mode 100644 vta/apps/pynq_rpc/README.md create mode 100644 vta/examples/resnet18/pynq/.gitignore create mode 100644 vta/examples/resnet18/pynq/README.md create mode 100644 vta/examples/resnet18/pynq/imagenet_predict.py rename vta/hardware/{vivado => xilinx}/.gitignore (100%) rename vta/hardware/{vivado => xilinx}/Makefile (95%) create mode 100644 vta/hardware/xilinx/README.md rename vta/hardware/{vivado => xilinx}/scripts/hls.tcl (100%) rename vta/hardware/{vivado => xilinx}/scripts/hsi.tcl (100%) rename vta/hardware/{vivado => xilinx}/scripts/vivado.tcl (100%) rename vta/hardware/{vivado => xilinx}/sim/vta_test.cc (91%) rename vta/hardware/{vivado => xilinx}/src/vta.cc (90%) rename vta/hardware/{vivado => xilinx}/src/vta.h (93%) create mode 100644 vta/python/vta/arm_conv2d.py create mode 100644 vta/python/vta/build.py create mode 100644 vta/python/vta/graph.py create mode 100644 vta/python/vta/hw_spec.py create mode 100644 vta/python/vta/intrin.py create mode 100644 vta/python/vta/ir_pass.py create mode 100644 vta/python/vta/mock.py create mode 100644 vta/python/vta/runtime.py create mode 100644 vta/python/vta/vta_conv2d.py create mode 100644 vta/tests/python/pynq/test_benchmark_conv2d.py create mode 100644 vta/tests/python/pynq/test_benchmark_gemm.py create mode 100644 vta/tests/python/pynq/test_benchmark_topi.py create mode 100644 vta/tests/python/pynq/test_program_rpc.py create mode 100644 vta/tests/python/pynq/test_vta_insn.py diff --git a/vta/Makefile b/vta/Makefile index 74c23d691..0137929c9 100644 --- a/vta/Makefile +++ b/vta/Makefile @@ -55,10 +55,10 @@ endif all: lib/libvta.$(SHARED_LIBRARY_SUFFIX) VTA_LIB_SRC = $(wildcard src/*.cc src/tvm/*.cc) -ifeq ($(TARGET), PYNQ_TARGET) +ifeq ($(TARGET), VTA_PYNQ_TARGET) VTA_LIB_SRC += $(wildcard src/pynq/*.cc) LDFLAGS += -L/usr/lib -lsds_lib - LDFLAGS += -L/opt/python3.6/lib/python3.6/site-packages/pynq/drivers/ -l:libdma.so + LDFLAGS += -L/opt/python3.6/lib/python3.6/site-packages/pynq/lib/ -l:libdma.so endif VTA_LIB_OBJ = $(patsubst %.cc, build/%.o, $(VTA_LIB_SRC)) @@ -79,7 +79,7 @@ cpplint: python nnvm/dmlc-core/scripts/lint.py vta cpp include src hardware tests pylint: - pylint python/vta --rcfile=$(ROOTDIR)/tests/lint/pylintrc + pylint python/tvm_vta --rcfile=$(ROOTDIR)/tests/lint/pylintrc doc: doxygen docs/Doxyfile diff --git a/vta/apps/pynq_rpc/README.md b/vta/apps/pynq_rpc/README.md new file mode 100644 index 000000000..4cc2b46fb --- /dev/null +++ b/vta/apps/pynq_rpc/README.md @@ -0,0 +1,80 @@ +### PYNQ RPC Server for VTA + +This guide describes how to setup a Pynq-based RPC server to accelerate deep learning workloads with VTA. + +## Pynq Setup + +Follow the getting started tutorial for the [Pynq board](http://pynq.readthedocs.io/en/latest/getting_started.html). +* For this RPC setup make sure to go with the *Connect to a Computer* Ethernet setup. + +Make sure that you can ssh into your Pynq board successfully: +```bash +ssh xilinx@192.168.2.99 +``` + +When ssh-ing onto the board, the default password for the `xilinx` account is `xilinx`. + +For convenience let's go ahead and mount the Pynq board's file system to easily access it and maintain it: +```bash +sshfs xilinx@192.168.2.99:/home/xilinx <mountpoint> +``` + +## Pynq TVM & VTA installation + +On your **host PC**, go to the `<mountpoint>` directory of your Pynq board file system. +```bash +cd <mountpoint> +``` + +From there, clone the VTA repository: +```bash +git clone git@github.com:uwsaml/vta.git --recursive +``` + +Next, clone the TVM repository: +```bash +git clone git@github.com:dmlc/tvm.git --recursive +``` + +TVM is rapidly changing, and to ensure stability, we keep track of working TVM checkpoints. +As of now, the TVM checkpoint `e4c2af9abdcb3c7aabafba8084414d7739c17c4c` is known to work with VTA. +```bash +git checkout e4c2af9abdcb3c7aabafba8084414d7739c17c4c +``` + +Now, ssh into your **Pynq board** to build the TVM runtime with the following commands: +```bash +ssh xilinx@192.168.2.99 # ssh if you haven't done so +cd ~/tvm +cp make/config.mk . +echo USE_RPC=1 >> config.mk +make runtime -j2 +``` + +## Pynq RPC server setup + +We're now ready to build the Pynq RPC server on the Pynq board. +```bash +ssh xilinx@192.168.2.99 # ssh if you haven't done so +cd ~/vta +export TVM_PATH = /home/xilinx/tvm +make +``` + +The last stage will build the `192.168.2.99:home/xilinx/vta/lib/libvta.so` library file. We are now ready to launch the RPC server on the Pynq. In order to enable the FPGA drivers, we need to run the RPC server with administrator privileges (using `su`, account: `xilinx`, pwd: `xilinx`). +```bash +ssh xilinx@192.168.2.99 # ssh if you haven't done so +cd ~/vta +su +./apps/pynq_rpc/start_rpc_server.sh +``` + +You should see the following being displayed when starting the RPC server: +``` +INFO:root:Load additional library /home/xilinx/vta/lib/libvta.so +INFO:root:RPCServer: bind to 0.0.0.0:9091 +``` + +Note that it should be listening on port `9091`. + +To kill the RPC server, just enter the `Ctrl + c` command. \ No newline at end of file diff --git a/vta/apps/pynq_rpc/start_rpc_server.sh b/vta/apps/pynq_rpc/start_rpc_server.sh index d5a1202a1..950da23ce 100755 --- a/vta/apps/pynq_rpc/start_rpc_server.sh +++ b/vta/apps/pynq_rpc/start_rpc_server.sh @@ -1,4 +1,4 @@ #!/bin/bash export PYTHONPATH=${PYTHONPATH}:/home/xilinx/tvm/python -export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/opt/python3.6/lib/python3.6/site-packages/pynq/drivers/ +export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/opt/python3.6/lib/python3.6/site-packages/pynq/lib/ python -m tvm.exec.rpc_server --load-library /home/xilinx/vta/lib/libvta.so diff --git a/vta/examples/resnet18/pynq/.gitignore b/vta/examples/resnet18/pynq/.gitignore new file mode 100644 index 000000000..11686e2f9 --- /dev/null +++ b/vta/examples/resnet18/pynq/.gitignore @@ -0,0 +1,5 @@ +quantize_graph.json +quantize_params.pkl +synset.txt +*.jpg +vta.bit \ No newline at end of file diff --git a/vta/examples/resnet18/pynq/README.md b/vta/examples/resnet18/pynq/README.md new file mode 100644 index 000000000..5d35fcbdd --- /dev/null +++ b/vta/examples/resnet18/pynq/README.md @@ -0,0 +1,98 @@ +# Resnet-18 Example on Pynq-based VTA Design + +In order to run this example you'll need to have: +* VTA installed +* TVM installed +* NNVM installed +* A Pynq-based RPC server running + +## VTA installation + +Clone the VTA repository in the directory of your choosing: +```bash +git clone git@github.com:uwsaml/vta.git --recursive +``` + +Update your `~/.bashrc` file to include the VTA python libraries in your `PYTHONPATH` (don't forget to source the newly modified `.bashrc` file!): +```bash +export PYTHONPATH=<vta root>/python:${PYTHONPATH} +``` + +## TVM installation + +Clone the TVM repository in the directory of your choosing: +```bash +git clone git@github.com:dmlc/tvm.git --recursive +``` + +TVM is rapidly changing, and to ensure stability, we keep track of working TVM checkpoints. +As of now, the TVM checkpoint `e4c2af9abdcb3c7aabafba8084414d7739c17c4c` is known to work with VTA. +```bash +git checkout e4c2af9abdcb3c7aabafba8084414d7739c17c4c +``` + +Before building TVM, copy the `make/config.mk` file into the root TVM directory: +```bash +cd <tvm root> +cp make/config.mk . +``` + +In the 'config.mk' file sure that: +* `LLVM_CONFIG` points to the llvm-config executable (e.g. `LLVM_CONFIG = /usr/bin/llvm-config-4.0`). You'll need to have llvm4.0 installed or later. +* `USE_RPC` should be set to 1 + +Launch the compilation, this takes about 5 minutes. +```bash +cd <tvm root> +make -j4 +``` + +Finally update your `~/.bashrc` file to include the TVM python libraries in your `PYTHONPATH` (don't forget to source the newly modified `.bashrc` file!): +```bash +export PYTHONPATH=<tvm root>/python:<tvm root>/topi/python:${PYTHONPATH} +``` + +## NNVM installation + +Clone the NNVM repository from `tqchen` in the directory of your choosing: +```bash +git clone git@github.com:tqchen/nnvm.git --recursive +``` + +To run this example, we rely on a special branch of NNVM: `qt`: +```bash +cd <nnvm root> +git checkout qt +``` + +Launch the compilation, this takes less a minute. +```bash +cd <nnvm root> +make -j4 +``` + +Finally update your `~/.bashrc` file to include the NNVM python libraries in your `PYTHONPATH` (don't forget to source the newly modified `.bashrc` file!): +```bash +export PYTHONPATH=<nnvm root>/python:${PYTHONPATH} +``` + +## Pynq RPC Server Setup + +Follow the [Pynq RPC Server Guide](https://github.com/saml/vta/tree/master/apps/pynq_rpc/README.md) + +## Running the example + +Simply run the following python script: +```bash +python imagenet_predict.py +``` + +This will run imagenet classification using the ResNet18 architecture on a VTA design that performs 8-bit integer inference, to perform classification on a cat image `cat.jpg`. + +The script reports runtime measured on the Pynq board, and the top-1 result category: +``` +('x', (1, 3, 224, 224)) +Build complete... +('TVM prediction top-1:', 281, 'tabby, tabby cat') +t-cost=0.41906 +``` \ No newline at end of file diff --git a/vta/examples/resnet18/pynq/imagenet_predict.py b/vta/examples/resnet18/pynq/imagenet_predict.py new file mode 100644 index 000000000..8221281bd --- /dev/null +++ b/vta/examples/resnet18/pynq/imagenet_predict.py @@ -0,0 +1,174 @@ +# some standard imports +import nnvm +import tvm +from nnvm.compiler import graph_attr +import vta +import os +import numpy as np +from PIL import Image +import pickle +import json +import logging +import wget +from tvm.contrib import graph_runtime, rpc, util + +factor = 16 +host = "pynq" +port = 9091 +verbose = False +# only run fpga component, mark non-conv ops as nop +debug_fpga_only = False + +# Obtain model and hardware files (they're too large to check-in) +url = "https://homes.cs.washington.edu/~moreau/media/vta/" +TEST_FILE = 'cat.jpg' +CATEG_FILE = 'synset.txt' +RESNET_GRAPH_FILE = 'quantize_graph.json' +RESNET_PARAMS_FILE = 'quantize_params.pkl' +BITSTREAM_FILE = 'vta.bit' +for file in [TEST_FILE, CATEG_FILE, RESNET_GRAPH_FILE, RESNET_PARAMS_FILE, BITSTREAM_FILE]: + if not os.path.isfile(file): + print "Downloading {}".format(file) + wget.download(url+file) + +# Program the FPGA remotely +assert tvm.module.enabled("rpc") +remote = rpc.connect(host, port) +remote.upload(BITSTREAM_FILE, BITSTREAM_FILE) +fprogram = remote.get_function("tvm.contrib.vta.init") +fprogram(BITSTREAM_FILE) + +if verbose: + logging.basicConfig(level=logging.INFO) + +# Change to -device=tcpu to run cpu only inference. +target = "llvm -device=vta" + +synset = eval(open(os.path.join(CATEG_FILE)).read()) +image = Image.open(os.path.join(TEST_FILE)).resize((224, 224)) + +def transform_image(image): + image = np.array(image) - np.array([123., 117., 104.]) + image /= np.array([58.395, 57.12, 57.375]) + image = image.transpose((2, 0, 1)) + image = image[np.newaxis, :] + return image + +def mark_nop(graph, conv_layer=-1, skip_conv_layer=()): + """Helper function to mark certain op as nop + + Useful to debug performance issues. + """ + jgraph = json.loads(graph.json()) + counter = 0 + for nid, node in enumerate(jgraph["nodes"]): + op_name = node["op"] + if op_name != "tvm_op": + continue + attrs = node["attrs"] + node_name = node["name"] + func_name = attrs["func_name"] + if func_name.find("quantized_conv2d") != -1: + if conv_layer >= 0: + if counter != conv_layer: + attrs["func_name"] = "__nop" + if counter in skip_conv_layer: + attrs["func_name"] = "__nop" + counter += 1 + else: + if conv_layer >= 0: + attrs["func_name"] = "__nop" + attrs["func_name"] = "__nop" + if attrs["func_name"] != "__nop": + print("Run function %s"% func_name) + graph = nnvm.graph.load_json(json.dumps(jgraph)) + return graph + +x = transform_image(image) +print('x', x.shape) + +###################################################################### +# now compile the graph +import nnvm.compiler +np.random.seed(0) +sym = nnvm.graph.load_json( + open(os.path.join(RESNET_GRAPH_FILE)).read()) +params = pickle.load( + open(os.path.join(RESNET_PARAMS_FILE))) + +shape_dict = {"data": x.shape} +dtype_dict = {"data": 'float32'} +shape_dict.update({k: v.shape for k, v in params.items()}) +dtype_dict.update({k: str(v.dtype) for k, v in params.items()}) + +graph = nnvm.graph.create(sym) +graph_attr.set_shape_inputs(sym, shape_dict) +graph_attr.set_dtype_inputs(sym, dtype_dict) +graph = graph.apply("InferShape").apply("InferType") + +dtype = "float32" +sym = vta.graph.remove_stochastic(sym) +sym = vta.graph.clean_cast(sym) +sym = vta.graph.clean_conv_fuse(sym) +if "vta" in target: + sym = vta.graph.pack(sym, shape_dict, factor) + +graph_attr.set_shape_inputs(sym, shape_dict) +sym = sym.apply("InferShape") +graph_attr.set_dtype_inputs(sym, dtype_dict) +sym = sym.apply("InferType") + +with nnvm.compiler.build_config(opt_level=3): + bdict = {} + if "vta" not in target: + bdict = {"add_lower_pass": []} + else: + bdict = {"add_lower_pass": vta.debug_mode(0)} + with tvm.build_config(**bdict): + graph, lib, params = nnvm.compiler.build( + sym, target, shape_dict, dtype_dict, + params=params) + +remote = rpc.connect(host, port) +temp = util.tempdir() +lib.save(temp.relpath("graphlib.o")) +remote.upload(temp.relpath("graphlib.o")) +lib = remote.load_module("graphlib.o") +ctx = remote.ext_dev(0) if "vta" in target else remote.cpu(0) + +print("Build complete...") + +def run_e2e(graph): + """Running end to end example + """ + if debug_fpga_only: + graph = mark_nop(graph, skip_conv_layer=(0,)) + m = graph_runtime.create(graph, lib, ctx) + # set inputs + m.set_input('data', tvm.nd.array(x.astype("float32"))) + m.set_input(**params) + # execute + timer = m.module.time_evaluator("run", ctx, number=10) + tcost = timer() + # get outputs + tvm_output = m.get_output( + 0,tvm.nd.empty((1000,), dtype, remote.cpu(0))) + top1 = np.argmax(tvm_output.asnumpy()) + print('TVM prediction top-1:', top1, synset[top1]) + print("t-cost=%g" % tcost.mean) + + +def run_layer(old_graph): + """Run a certain layer.""" + for layer_id in range(1, 2): + graph = mark_nop(old_graph, layer_id) + m = graph_runtime.create(graph, lib, ctx) + # set inputs + m.set_input('data', tvm.nd.array(x.astype("float32"))) + m.set_input(**params) + # execute + timer = m.module.time_evaluator("run", ctx, number=10) + tcost = timer() + print("resnet[%d]: %g\n"% (layer_id, tcost.mean)) + +run_e2e(graph) diff --git a/vta/hardware/vivado/.gitignore b/vta/hardware/xilinx/.gitignore similarity index 100% rename from vta/hardware/vivado/.gitignore rename to vta/hardware/xilinx/.gitignore diff --git a/vta/hardware/vivado/Makefile b/vta/hardware/xilinx/Makefile similarity index 95% rename from vta/hardware/vivado/Makefile rename to vta/hardware/xilinx/Makefile index f3d779ee2..c7ffc7ec9 100644 --- a/vta/hardware/vivado/Makefile +++ b/vta/hardware/xilinx/Makefile @@ -1,6 +1,6 @@ # Directories ROOTDIR = $(CURDIR) -BUILD_DIR = $(ROOTDIR)/../../build/hardware/vivado +BUILD_DIR = $(ROOTDIR)/../../build/hardware/xilinx SCRIPT_DIR = $(ROOTDIR)/scripts SRC_DIR = $(ROOTDIR)/src SIM_DIR = $(ROOTDIR)/sim @@ -64,7 +64,7 @@ bit: ip cd $(HW_BUILD_PATH) && \ $(VIVADO) -mode tcl -source $(SCRIPT_DIR)/vivado.tcl \ -tclargs $(IP_BUILD_PATH) $(VTA_HW_COMP_THREADS) $(VTA_HW_COMP_CLOCK_FREQ) \ - $(VTA_INP_WIDTH) $(VTA_WGT_WIDTH) $(OUT_WIDTH) \ + $(VTA_INP_WIDTH) $(VTA_WGT_WIDTH) $(VTA_OUT_WIDTH) \ $(VTA_BATCH) $(VTA_IN_BLOCK) $(VTA_OUT_BLOCK) \ $(VTA_INP_BUFF_SIZE) $(VTA_WGT_BUFF_SIZE) $(VTA_OUT_BUFF_SIZE) diff --git a/vta/hardware/xilinx/README.md b/vta/hardware/xilinx/README.md new file mode 100644 index 000000000..0c68724ea --- /dev/null +++ b/vta/hardware/xilinx/README.md @@ -0,0 +1,73 @@ +# Hardware Compilation Guide + +**This hardware compilation guide aims to provide guidance on generating VTA bitstreams with the Xilinx Vivado toolchains.** + +As of writing this guide, we recommend using `Vivado 2017.1` since our scripts have been tested to work on this version of the Xilinx toolchains. + +# Vivado Toolchains Installation for Pynq Board + +## Ubuntu instructions + +You’ll need to install Xilinx’ FPGA compilation toolchain, [Vivado HL WebPACK 2017.1](https://www.xilinx.com/products/design-tools/vivado.html), which a license-free version of the Vivado HLx toolchain. + +### Obtaining and launching the installation binary + +1. Go to the [download webpage](https://www.xilinx.com/support/download.html), and download the Linux Self Extracting Web Installer for Vivado HL 2017.1 WebPACK and Editions. +2. You’ll have to sign in with a Xilinx account. This requires a Xilinx account creation that will take 2 minutes. +3. Complete the Name and Address Verification by clicking “Nextâ€, and you will get the opportunity to download a binary file, called `Xilinx_Vivado_SDK_2017.1_0415_1_Lin64.bin`. +4. Now that the file is downloaded, go to your `Downloads` directory, and change the file permissions so it can be executed: +```bash +chmod u+x Xilinx_Vivado_SDK_2017.1_0415_1_Lin64.bin +``` +5. Now you can execute the binary: +```bash +./Xilinx_Vivado_SDK_2017.1_0415_1_Lin64.bin +``` + +### Installation Steps + +At this point you've launched the Vivado 2017.1 Installer GUI program. + +1. Click “Next†on the **Welcome** screen. +2. Enter your Xilinx User Credentials under “User Authentication†and select the “Download and Install Now†before clicking “Next†on the **Select Install Type** screen. +3. Accept all terms before clicking on “Next†on the **Accept License Agreements** screen. +4. Select the “Vivado HL WebPACK†before clicking on “Next†on the **Select Edition to Install** screen. +5. Under the **Vivado HL WebPACK** screen, before hitting “Next", check the following options (the rest should be unchecked): + * Design Tools -> Vivado Design Suite -> Vivado + * Design Tools -> Vivado Design Suite -> Vivado High Level Synthesis + * Devices -> Production Services -> SoCs -> Zynq-7000 Series +6. Your total download size should be about 3GB and the amount of Disk Space Required 13GB. +7. Set the installation directory before clicking “Next†on the **Select Destination Directory** screen. It might highlight some paths as red - that’s because the installer doesn’t have the permission to write to that directory. In that case select a path that doesn’t require special write permissions (e.g. in your home directory). +8. Hit “Install†under the **Installation Summary** screen. +9. An **Installation Progress Window** will pop-up to track progress of the download and the installation. +10. This process will take about 20-30 minutes depending on your connection speed. +11. A pop-up window will inform you that the installation completed successfully. Click "OK". +12. Finally the **Vivado License Manager** will launch. Select "Get Free ISE WebPACK, ISE/Vivado IP or PetaLinux License" and click "Connect Now" to complete the license registration process. + +### Environment Setup + +The last step is to update your `~/.bashrc` with the following line: +```bash +# Xilinx Vivado 2017.1 environment +source <install_path>/Vivado/2017.1/settings64.sh +``` + +This will include all of the Xilinx binary paths so you can launch compilation scripts from the command line. + +Note that this will overwrite the paths to GCC required to build TVM, or NNVM. Therefore, when attempting to build TVM and NNVM, please comment this line from your `~/.bashrc` before re-sourcing it. + +# Bitstream compilation + +High-level parameters are listed under `<vta root>/make/config.mk` and can be customized by the user. + +Bitstream generation is driven by a makefile. All it takes is to enter the following command: +```bash +make +``` + +The local `Makefile` containts several variables that can be tweaked by the user: +* `VTA_HW_COMP_THREADS`: determines the number of threads used for the Vivado compilation job (default 8 threads). +* `VTA_HW_COMP_CLOCK_FREQ`: determines the target frequency of the VTA design (default 100MHz). It can only be set to 100, 142, 167 or 200MHz. +* `VTA_HW_COMP_TIMING_COMP`: determines how much additional slack must be provided to close timing (default 0ns). Generally when utilization is high for an FPGA design, setting this paramter to 1, 2 or 3 can help close timing. + +Once the compilation completes, the generated bitstream can be found under `<vta root>/build/hardware/xilinx/vivado/<design name>/export/vta.bit`. \ No newline at end of file diff --git a/vta/hardware/vivado/scripts/hls.tcl b/vta/hardware/xilinx/scripts/hls.tcl similarity index 100% rename from vta/hardware/vivado/scripts/hls.tcl rename to vta/hardware/xilinx/scripts/hls.tcl diff --git a/vta/hardware/vivado/scripts/hsi.tcl b/vta/hardware/xilinx/scripts/hsi.tcl similarity index 100% rename from vta/hardware/vivado/scripts/hsi.tcl rename to vta/hardware/xilinx/scripts/hsi.tcl diff --git a/vta/hardware/vivado/scripts/vivado.tcl b/vta/hardware/xilinx/scripts/vivado.tcl similarity index 100% rename from vta/hardware/vivado/scripts/vivado.tcl rename to vta/hardware/xilinx/scripts/vivado.tcl diff --git a/vta/hardware/vivado/sim/vta_test.cc b/vta/hardware/xilinx/sim/vta_test.cc similarity index 91% rename from vta/hardware/vivado/sim/vta_test.cc rename to vta/hardware/xilinx/sim/vta_test.cc index 16f37a866..60685b384 100644 --- a/vta/hardware/vivado/sim/vta_test.cc +++ b/vta/hardware/xilinx/sim/vta_test.cc @@ -40,6 +40,8 @@ int main(void) { status |= alu_test(VTA_ALU_OPCODE_ADD, true, 16, 128, false); status |= alu_test(VTA_ALU_OPCODE_SHR, true, 16, 128, true); status |= alu_test(VTA_ALU_OPCODE_SHR, true, 16, 128, false); + status |= alu_test(VTA_ALU_OPCODE_SHL, true, 16, 128, true); + status |= alu_test(VTA_ALU_OPCODE_SHL, true, 16, 128, false); // Run ALU test (vector-vector operators) status |= alu_test(VTA_ALU_OPCODE_MIN, false, 16, 128, true); diff --git a/vta/hardware/vivado/src/vta.cc b/vta/hardware/xilinx/src/vta.cc similarity index 90% rename from vta/hardware/vivado/src/vta.cc rename to vta/hardware/xilinx/src/vta.cc index f628b749d..815dc014b 100644 --- a/vta/hardware/vivado/src/vta.cc +++ b/vta/hardware/xilinx/src/vta.cc @@ -13,9 +13,9 @@ void fetch( uint32_t insn_count, volatile insn_T *insns, - hls::stream<insn_T> *load_queue, - hls::stream<insn_T> *gemm_queue, - hls::stream<insn_T> *store_queue) { + hls::stream<insn_T> &load_queue, + hls::stream<insn_T> &gemm_queue, + hls::stream<insn_T> &store_queue) { #pragma HLS INTERFACE s_axilite port = insn_count bundle = CONTROL_BUS #pragma HLS INTERFACE m_axi port = insns offset = slave bundle = ins_port #pragma HLS INTERFACE axis port = load_queue @@ -32,12 +32,12 @@ void fetch( memop_id_T memory_type = insn.range(VTA_INSN_MEM_5_1, VTA_INSN_MEM_5_0); // Push to appropriate instruction queue if (opcode == VTA_OPCODE_STORE) { - store_queue->write(insn); + store_queue.write(insn); } else if (opcode == VTA_OPCODE_LOAD && (memory_type == VTA_MEM_ID_INP || memory_type == VTA_MEM_ID_WGT)) { - load_queue->write(insn); + load_queue.write(insn); } else { - gemm_queue->write(insn); + gemm_queue.write(insn); } } } @@ -45,9 +45,9 @@ void fetch( void load( volatile inp_vec_T *inputs, volatile wgt_vec_T *weights, - hls::stream<insn_T> *load_queue, - hls::stream<bool> *g2l_dep_queue, - hls::stream<bool> *l2g_dep_queue, + hls::stream<insn_T> &load_queue, + hls::stream<bool> &g2l_dep_queue, + hls::stream<bool> &l2g_dep_queue, inp_vec_T inp_mem[VTA_INP_BUFF_DEPTH][VTA_BATCH], wgt_vec_T wgt_mem[VTA_WGT_BUFF_DEPTH][VTA_BLOCK_OUT] ) { @@ -61,7 +61,7 @@ void load( #pragma HLS INTERFACE s_axilite port = return bundle = CONTROL_BUS // Pop load instruction - insn_T insn = load_queue->read(); + insn_T insn = load_queue.read(); // Decode instruction bool pop_prev_dependence = insn[VTA_INSN_MEM_1]; @@ -81,7 +81,7 @@ void load( // Pop dependence token if instructed if (pop_next_dependence) { - g2l_dep_queue->read(); + g2l_dep_queue.read(); } // Initialize indices @@ -170,19 +170,19 @@ void load( // Push dependence token if instructed if (push_next_dependence) { - l2g_dep_queue->write(1); + l2g_dep_queue.write(1); } } void compute( - volatile uint32_t *done, + volatile uint32_t &done, volatile uop_T *uops, volatile acc_vec_T *biases, - hls::stream<insn_T> *gemm_queue, - hls::stream<bool> *l2g_dep_queue, - hls::stream<bool> *s2g_dep_queue, - hls::stream<bool> *g2l_dep_queue, - hls::stream<bool> *g2s_dep_queue, + hls::stream<insn_T> &gemm_queue, + hls::stream<bool> &l2g_dep_queue, + hls::stream<bool> &s2g_dep_queue, + hls::stream<bool> &g2l_dep_queue, + hls::stream<bool> &g2s_dep_queue, out_vec_T inp_mem[VTA_INP_BUFF_DEPTH][VTA_BATCH], wgt_vec_T wgt_mem[VTA_WGT_BUFF_DEPTH][VTA_BLOCK_OUT], out_vec_T out_mem[VTA_ACC_BUFF_DEPTH][VTA_BATCH] @@ -210,7 +210,7 @@ void compute( #pragma HLS ARRAY_PARTITION variable = acc_mem complete dim = 2 // Pop GEMM instruction - insn_T insn = gemm_queue->read(); + insn_T insn = gemm_queue.read(); // Decode opcode_T opcode = insn.range(VTA_INSN_MEM_0_1, VTA_INSN_MEM_0_0); @@ -221,19 +221,19 @@ void compute( // Pop dependence token if instructed if (pop_prev_dependence) { - l2g_dep_queue->read(); + l2g_dep_queue.read(); } if (pop_next_dependence) { - s2g_dep_queue->read(); + s2g_dep_queue.read(); } // Perform action based on opcode if (opcode == VTA_OPCODE_FINISH) { // Set done flag if we reach a FINISH instruction - *done = 1; + done = 1; } else if (opcode == VTA_OPCODE_LOAD || opcode == VTA_OPCODE_STORE) { // Set done value - *done = 0; + done = 0; // Decode instruction memop_id_T memory_type = insn.range(VTA_INSN_MEM_5_1, VTA_INSN_MEM_5_0); @@ -283,7 +283,7 @@ void compute( } } else if (opcode == VTA_OPCODE_GEMM || opcode == VTA_OPCODE_ALU) { // Set done value - *done = 0; + done = 0; // Decode uop_idx_T uop_bgn = insn.range(VTA_INSN_GEM_5_1, VTA_INSN_GEM_5_0); @@ -383,6 +383,7 @@ void compute( } else if (opcode == VTA_OPCODE_ALU) { // Iterate over micro op READ_ALU_UOP: for (int upc = uop_bgn; upc < uop_end; upc++) { +#pragma HLS PIPELINE II = 2 rewind // Read micro-op fields uop_T uop = uop_mem[upc]; @@ -405,14 +406,15 @@ void compute( // Result matrices acc_vec_T cmp_res[VTA_BATCH]; acc_vec_T add_res[VTA_BATCH]; - acc_vec_T shr_res[VTA_BATCH]; + acc_vec_T rshr_res[VTA_BATCH]; + acc_vec_T lshr_res[VTA_BATCH]; out_vec_T short_cmp_res[VTA_BATCH]; out_vec_T short_add_res[VTA_BATCH]; - out_vec_T short_shr_res[VTA_BATCH]; + out_vec_T short_rshr_res[VTA_BATCH]; + out_vec_T short_lshr_res[VTA_BATCH]; // Perform ALU op over matrix elements for (int i = 0; i < VTA_BATCH; i++) { -#pragma HLS PIPELINE II = 1 rewind // Results vector acc_vec_T res_vec = 0; for (int b = 0; b < VTA_BLOCK_OUT; b++) { @@ -434,12 +436,18 @@ void compute( add_res[i].range((b + 1) * VTA_ACC_WIDTH - 1, b * VTA_ACC_WIDTH) = add_val; short_add_res[i].range((b + 1) * VTA_OUT_WIDTH - 1, b * VTA_OUT_WIDTH) = (inp_T) add_val.range(VTA_OUT_WIDTH - 1, 0); - // Compute Shift - acc_T shr_val = + // Compute Right Shift + acc_T rshr_val = src_0 >> (aluop_sh_imm_T) src_1.range(VTA_LOG_ACC_WIDTH - 1, 0); - shr_res[i].range((b + 1) * VTA_ACC_WIDTH - 1, b * VTA_ACC_WIDTH) = shr_val; - short_shr_res[i].range((b + 1) * VTA_OUT_WIDTH - 1, b * VTA_OUT_WIDTH) = - (inp_T) shr_val.range(VTA_OUT_WIDTH-1, 0); + rshr_res[i].range((b + 1) * VTA_ACC_WIDTH - 1, b * VTA_ACC_WIDTH) = rshr_val; + short_rshr_res[i].range((b + 1) * VTA_OUT_WIDTH - 1, b * VTA_OUT_WIDTH) = + (inp_T) rshr_val.range(VTA_OUT_WIDTH-1, 0); + // Compute Left Shift + acc_T lshr_val = + src_0 << (aluop_sh_imm_T) src_1.range(VTA_LOG_ACC_WIDTH - 1, 0); + lshr_res[i].range((b + 1) * VTA_ACC_WIDTH - 1, b * VTA_ACC_WIDTH) = lshr_val; + short_lshr_res[i].range((b + 1) * VTA_OUT_WIDTH - 1, b * VTA_OUT_WIDTH) = + (inp_T) lshr_val.range(VTA_OUT_WIDTH-1, 0); } // Store to accum memory/store buffer @@ -451,8 +459,11 @@ void compute( acc_mem[dst_idx][i] = add_res[i]; out_mem[dst_idx][i] = short_add_res[i]; } else if (alu_opcode == VTA_ALU_OPCODE_SHR) { - acc_mem[dst_idx][i] = shr_res[i]; - out_mem[dst_idx][i] = short_shr_res[i]; + acc_mem[dst_idx][i] = rshr_res[i]; + out_mem[dst_idx][i] = short_rshr_res[i]; + } else if (alu_opcode == VTA_ALU_OPCODE_SHL) { + acc_mem[dst_idx][i] = lshr_res[i]; + out_mem[dst_idx][i] = short_lshr_res[i]; } } } @@ -473,18 +484,18 @@ void compute( // Push dependence token if instructed if (push_prev_dependence) { - g2l_dep_queue->write(1); + g2l_dep_queue.write(1); } if (push_next_dependence) { - g2s_dep_queue->write(1); + g2s_dep_queue.write(1); } } void store( volatile out_vec_T *outputs, - hls::stream<insn_T> *store_queue, - hls::stream<bool> *g2s_dep_queue, - hls::stream<bool> *s2g_dep_queue, + hls::stream<insn_T> &store_queue, + hls::stream<bool> &g2s_dep_queue, + hls::stream<bool> &s2g_dep_queue, out_vec_T out_mem[VTA_ACC_BUFF_DEPTH][VTA_BATCH] ) { #pragma HLS INTERFACE m_axi port = outputs offset = slave bundle = data_port @@ -495,7 +506,7 @@ void store( #pragma HLS INTERFACE s_axilite port = return bundle = CONTROL_BUS // Load buffer - insn_T insn = store_queue->read(); + insn_T insn = store_queue.read(); // Decode bool pop_prev_dependence = insn[VTA_INSN_MEM_1]; @@ -515,7 +526,7 @@ void store( // Pop dependence token if instructed if (pop_prev_dependence) { - g2s_dep_queue->read(); + g2s_dep_queue.read(); } // Initialize indices @@ -546,7 +557,7 @@ void store( // Push dependence token if instructed if (push_prev_dependence) { - s2g_dep_queue->write(1); + s2g_dep_queue.write(1); } } @@ -589,7 +600,7 @@ void vta( out_vec_T out_mem[VTA_ACC_BUFF_DEPTH][VTA_BATCH]; // Push all instructions into the queues - fetch(insn_count, insns, &tmp_load_queue, &tmp_gemm_queue, &tmp_store_queue); + fetch(insn_count, insns, tmp_load_queue, tmp_gemm_queue, tmp_store_queue); // Global done indicator uint32_t done = 0; @@ -621,7 +632,7 @@ void vta( // Push the instruction in the load queue load_queue.write(tmp_load); tmp_load_popped = false; - load(inputs, weights, &load_queue, &g2l_dep_queue, &l2g_dep_queue, inp_mem, wgt_mem); + load(inputs, weights, load_queue, g2l_dep_queue, l2g_dep_queue, inp_mem, wgt_mem); } else { // Execution of load stage pending on completion of other stages, so break here... break; @@ -649,8 +660,8 @@ void vta( // Push the instruction in the load queue gemm_queue.write(tmp_gemv); tmp_gemm_popped = false; - compute(&done, uops, biases, &gemm_queue, &l2g_dep_queue, &s2g_dep_queue, - &g2l_dep_queue, &g2s_dep_queue, inp_mem, wgt_mem, out_mem); + compute(done, uops, biases, gemm_queue, l2g_dep_queue, s2g_dep_queue, + g2l_dep_queue, g2s_dep_queue, inp_mem, wgt_mem, out_mem); } else { // Execution of load stage pending on completion of other stages, // so break here... @@ -671,7 +682,7 @@ void vta( // Push the instruction in the load queue store_queue.write(tmp_store); tmp_store_popped = false; - store(outputs, &store_queue, &g2s_dep_queue, &s2g_dep_queue, out_mem); + store(outputs, store_queue, g2s_dep_queue, s2g_dep_queue, out_mem); } else { // Execution of load stage pending on completion of other stages, so break here... break; diff --git a/vta/hardware/vivado/src/vta.h b/vta/hardware/xilinx/src/vta.h similarity index 93% rename from vta/hardware/vivado/src/vta.h rename to vta/hardware/xilinx/src/vta.h index 37395722f..e01ef90bb 100644 --- a/vta/hardware/vivado/src/vta.h +++ b/vta/hardware/xilinx/src/vta.h @@ -107,9 +107,9 @@ typedef ap_uint<VTA_LOG_ACC_WIDTH> aluop_sh_imm_T; void fetch( uint32_t insn_count, volatile insn_T *insns, - hls::stream<insn_T> *load_queue, - hls::stream<insn_T> *gemm_queue, - hls::stream<insn_T> *store_queue); + hls::stream<insn_T> &load_queue, + hls::stream<insn_T> &gemm_queue, + hls::stream<insn_T> &store_queue); /*! * \brief Load module. @@ -129,9 +129,9 @@ void fetch( void load( volatile inp_vec_T *inputs, volatile wgt_vec_T *weights, - hls::stream<insn_T> *load_queue, - hls::stream<bool> *g2l_dep_queue, - hls::stream<bool> *l2g_dep_queue, + hls::stream<insn_T> &load_queue, + hls::stream<bool> &g2l_dep_queue, + hls::stream<bool> &l2g_dep_queue, inp_vec_T inp_mem[VTA_INP_BUFF_DEPTH][VTA_BATCH], wgt_vec_T wgt_mem[VTA_WGT_BUFF_DEPTH][VTA_BLOCK_OUT]); @@ -159,14 +159,14 @@ void load( * \param out_mem Local output SRAM buffer. Write only single port BRAM. */ void compute( - volatile uint32_t *done, + volatile uint32_t &done, volatile uop_T *uops, volatile acc_vec_T *biases, - hls::stream<insn_T> *gemm_queue, - hls::stream<bool> *l2g_dep_queue, - hls::stream<bool> *s2g_dep_queue, - hls::stream<bool> *g2l_dep_queue, - hls::stream<bool> *g2s_dep_queue, + hls::stream<insn_T> &gemm_queue, + hls::stream<bool> &l2g_dep_queue, + hls::stream<bool> &s2g_dep_queue, + hls::stream<bool> &g2l_dep_queue, + hls::stream<bool> &g2s_dep_queue, out_vec_T inp_mem[VTA_INP_BUFF_DEPTH][VTA_BATCH], wgt_vec_T wgt_mem[VTA_WGT_BUFF_DEPTH][VTA_BLOCK_OUT], out_vec_T out_mem[VTA_ACC_BUFF_DEPTH][VTA_BATCH]); @@ -186,9 +186,9 @@ void compute( */ void store( volatile out_vec_T *outputs, - hls::stream<insn_T> *store_queue, - hls::stream<bool> *g2s_dep_queue, - hls::stream<bool> *s2g_dep_queue, + hls::stream<insn_T> &store_queue, + hls::stream<bool> &g2s_dep_queue, + hls::stream<bool> &s2g_dep_queue, out_vec_T out_mem[VTA_ACC_BUFF_DEPTH][VTA_BATCH]); /*! diff --git a/vta/make/config.mk b/vta/make/config.mk index 06143d777..062dfa8c3 100644 --- a/vta/make/config.mk +++ b/vta/make/config.mk @@ -84,7 +84,7 @@ VTA_ACC_BUFF_SIZE = $(shell echo "$$(( 1 << $(VTA_LOG_ACC_BUFF_SIZE) ))" ) VTA_LOG_OUT_BUFF_SIZE = \ $(shell echo "$$(( $(VTA_LOG_ACC_BUFF_SIZE) + $(VTA_LOG_OUT_WIDTH) - $(VTA_LOG_ACC_WIDTH) ))" ) # Out buffer size in Bytes -VTA_OUT_BUFF_SIZE = $(shell echo "$$(( 1 << $(LOG_OUT_BUFF_SIZE) ))" ) +VTA_OUT_BUFF_SIZE = $(shell echo "$$(( 1 << $(VTA_LOG_OUT_BUFF_SIZE) ))" ) # Update ADD_CFLAGS ADD_CFLAGS += \ diff --git a/vta/python/vta/__init__.py b/vta/python/vta/__init__.py index 59a494fc7..4a6f760d0 100644 --- a/vta/python/vta/__init__.py +++ b/vta/python/vta/__init__.py @@ -1,5 +1,12 @@ -"""VTA Python package backed by TVM""" +"""TVM VTA runtime""" +from __future__ import absolute_import as _abs +from .hw_spec import * -# version of this package -__version__ = "0.1.0" +from .runtime import SCOPE_INP, SCOPE_OUT, SCOPE_WGT, DMA_COPY, ALU +from .intrin import GEVM, GEMM +from .build import debug_mode + +from . import mock, ir_pass +from . import arm_conv2d, vta_conv2d +from . import graph diff --git a/vta/python/vta/arm_conv2d.py b/vta/python/vta/arm_conv2d.py new file mode 100644 index 000000000..9e46ee7f8 --- /dev/null +++ b/vta/python/vta/arm_conv2d.py @@ -0,0 +1,335 @@ +# pylint: disable=invalid-name,unused-variable,invalid-name +"""Conv2D schedule ported from RASP + +Used for CPU conv2d +""" +from __future__ import absolute_import as _abs + +import tvm +from topi import tag +from topi.nn.conv2d import conv2d, _get_schedule +from topi.nn.conv2d import SpatialPack, Im2ColPack, Workload +from topi.nn.conv2d import _SCH_TO_DECL_FUNC +from topi.nn.conv2d import _get_workload +from topi.nn.util import infer_pad, infer_stride +from topi import generic + +_WORKLOADS = [ + Workload('int8', 'int32', 224, 224, 3, 64, 7, 7, 3, 3, 2, 2), + Workload('int8', 'int32', 56, 56, 64, 64, 3, 3, 1, 1, 1, 1), + Workload('int8', 'int32', 56, 56, 64, 64, 1, 1, 0, 0, 1, 1), + Workload('int8', 'int32', 56, 56, 64, 128, 3, 3, 1, 1, 2, 2), + Workload('int8', 'int32', 56, 56, 64, 128, 1, 1, 0, 0, 2, 2), + Workload('int8', 'int32', 28, 28, 128, 128, 3, 3, 1, 1, 1, 1), + Workload('int8', 'int32', 28, 28, 128, 256, 3, 3, 1, 1, 2, 2), + Workload('int8', 'int32', 28, 28, 128, 256, 1, 1, 0, 0, 2, 2), + Workload('int8', 'int32', 14, 14, 256, 256, 3, 3, 1, 1, 1, 1), + Workload('int8', 'int32', 14, 14, 256, 512, 3, 3, 1, 1, 2, 2), + Workload('int8', 'int32', 14, 14, 256, 512, 1, 1, 0, 0, 2, 2), + Workload('int8', 'int32', 7, 7, 512, 512, 3, 3, 1, 1, 1, 1), +] +_SCHEDULES = [ + # float32 imagenet + SpatialPack(1, 8, 4, 1, 4, True), + SpatialPack(1, 7, 4, 2, 4, True), + SpatialPack(1, 4, 8, 4, 1, True), + SpatialPack(1, 4, 4, 1, 16, False), + SpatialPack(1, 4, 8, 4, 8, False), + SpatialPack(1, 7, 4, 3, 8, True), + SpatialPack(1, 2, 8, 1, 8, True), + SpatialPack(2, 1, 16, 1, 4, True), + SpatialPack(1, 7, 4, 1, 1, True), + Im2ColPack(7, 4, 1, 16, True), + Im2ColPack(7, 4, 1, 8, False), + Im2ColPack(7, 4, 1, 16, False), +] + +@_get_schedule.register(["tcpu", "vta"]) +def _schedule_conv2d(wkl): + if wkl not in _WORKLOADS: + raise ValueError("no schedule for such workload: {}".format(wkl)) + idx = _WORKLOADS.index(wkl) + sch = _SCHEDULES[idx] + return sch + + +@conv2d.register(["tcpu", "vta"]) +def _declaration_conv2d(data, kernel, stride, padding, layout, out_dtype): + assert layout == 'NCHW', "only support NCHW convolution on tcpu" + assert data.shape[0].value == 1, "only support batch size=1 convolution on tcpu" + wkl = _get_workload(data, kernel, stride, padding, out_dtype) + sch = _get_schedule(wkl) + return _SCH_TO_DECL_FUNC[type(sch)](data, kernel, stride, padding, out_dtype) + + +def _schedule_spatial_conv2d(s, data, data_pad, data_vec, + kernel, kernel_vec, + conv_out, output, last): + # no stride and padding info here + padding = infer_pad(data, data_pad) + if data_pad is None: + stride = infer_stride(data, kernel, output) + else: + stride = infer_stride(data_pad, kernel, output) + wkl = _get_workload(data, kernel, stride, padding, output.dtype) + sch = _get_schedule(wkl) + + H, W = wkl.height, wkl.width + CI, CO = wkl.in_filter, wkl.out_filter + HK, WK = wkl.hkernel, wkl.wkernel + HPAD, WPAD = wkl.hpad, wkl.wpad + HSTR, WSTR = wkl.hstride, wkl.wstride + + HCAT, WCAT = HK-1, WK-1 + DOPAD = (HPAD != 0 and WPAD != 0) + + VH = sch.vh + VW = sch.vw + VC = sch.vc + UNROLL = sch.unroll + + A, B, C = data, kernel, last + A0, A1 = data_pad, data_vec + B0 = kernel_vec + C0, C1 = conv_out, output + + CC = s.cache_write(C0, "global") + + _, co, oh, ow, vh, vw, vc = s[C0].op.axis + if UNROLL: + s[C0].unroll(vw) + s[C0].vectorize(vc) + + s[CC].compute_at(s[C0], ow) + _, co, oh, ow, vh, vw, vc = s[CC].op.axis + ci, dh, dw = s[CC].op.reduce_axis + s[CC].reorder(ci, dh, vh, dw, vw, vc) + + if UNROLL: + s[CC].unroll(vw) + s[CC].vectorize(vc) + + ##### Schedule A + if DOPAD: + s[A0].compute_inline() + + _, h, _, _, _, _ = s[A1].op.axis + if sch.ba == 1: + oaxis = h + paxis = h + else: + oh, ih = s[A1].split(h, sch.ba) + oaxis = oh + paxis = ih + + s[A1].parallel(paxis) + s[A1].pragma(oaxis, "parallel_launch_point") + s[A1].pragma(paxis, "parallel_stride_pattern") + s[A1].pragma(oaxis, "parallel_barrier_when_finish") + + + ##### Schedule B + co, _, _, _, _ = s[B0].op.axis + if sch.bc == 1: + oaxis = co + paxis = co + else: + oco, ico = s[B0].split(co, sch.bc) + oaxis = oco + paxis = ico + + s[B0].parallel(paxis) + s[B0].pragma(oaxis, "parallel_launch_point") + s[B0].pragma(paxis, "parallel_stride_pattern") + s[B0].pragma(oaxis, "parallel_barrier_when_finish") + + + ##### Schedule C + n, co, h, w = s[C].op.axis + co, vc = s[C].split(co, VC) + oh, ow, vh, vw = s[C].tile(h, w, VH, VW) + s[C].reorder(n, co, oh, ow, vh, vw, vc) + if C != C1: + s[C1].compute_inline() + s[C0].compute_at(s[C], ow) + + if sch.bc == 1: + oaxis = co + paxis = co + else: + oco, ico = s[C].split(co, sch.bc) + oaxis = oco + paxis = ico + + s[C].parallel(paxis) + s[C].pragma(oaxis, "parallel_launch_point") + s[C].pragma(paxis, "parallel_stride_pattern") + s[C].pragma(oaxis, "parallel_barrier_when_finish") + + return s + +def _schedule_im2col_conv2d(s, data, data_pad, data_col, data_vec, + kernel, kernel_vec, + conv_out, output, last): + # no stride and padding info here + padding = infer_pad(data, data_pad) + if data_pad is None: + stride = infer_stride(data, kernel, output) + else: + stride = infer_stride(data_pad, kernel, output) + wkl = _get_workload(data, kernel, stride, padding, output.dtype) + + sch = _get_schedule(wkl) + + H, W = wkl.height, wkl.width + CI = wkl.in_filter + CO = wkl.out_filter + HK, WK = wkl.hkernel, wkl.wkernel + HPAD, WPAD = wkl.hpad, wkl.wpad + HSTR, WSTR = wkl.hstride, wkl.wstride + + HCAT, WCAT = HK-1, WK-1 + DOPAD = (HPAD != 0 and WPAD != 0) + + P = sch.vp + Q = sch.vq + UNROLL = sch.unroll + + A, B, C = data, kernel, last + A0, A1, A2 = data_pad, data_col, data_vec + B0 = kernel_vec + C0, C1 = conv_out, output + + CC = s.cache_write(C0, "global") + AA = s.cache_read(A2, "global", [CC]) + BB = s.cache_read(B0, "global", [CC]) + + + ##### Schedule CC + _, co, im, vim, vco = s[C0].op.axis + s[C0].unroll(vim) + s[C0].vectorize(vco) + + s[CC].compute_at(s[C0], im) + _, co, im, vim, vco = s[CC].op.axis + ci, hk, wk = s[CC].op.reduce_axis + s[CC].reorder(ci, hk, wk, vim, vco) + s[CC].unroll(vim) + s[CC].vectorize(vco) + # s[CC].unroll(ccr) + + ### Schedule C + _, co, h, w = s[C].op.axis + im = s[C].fuse(h, w) + im, vim = s[C].split(im, P) + co, vco = s[C].split(co, Q) + s[C].reorder(co, im, vim, vco) + + if sch.bc == 1: + oaxis = co + paxis = co + else: + oco, ico = s[C].split(co, sch.bc) + oaxis = oco + paxis = ico + + s[C].parallel(paxis) + s[C].pragma(oaxis, "parallel_launch_point") + s[C].pragma(paxis, "parallel_stride_pattern") + s[C].pragma(oaxis, "parallel_barrier_when_finish") + if C1 != C: + s[C1].compute_inline() + + s[C0].compute_at(s[C], paxis) + + ##### Schedule A + if DOPAD: + s[A0].compute_inline() + s[A1].compute_inline() + s[AA].compute_at(s[CC], wk) + s[AA].unroll(AA.op.axis[4]) + + _, im, _, _, _, _ = s[A2].op.axis + if sch.ba == 1: + oaxis = im + paxis = im + else: + oim, iim = s[A2].split(im, sch.ba) + oaxis = oim + paxis = iim + + s[A2].parallel(paxis) + s[A2].pragma(oaxis, "parallel_launch_point") + s[A2].pragma(paxis, "parallel_stride_pattern") + s[A2].pragma(oaxis, "parallel_barrier_when_finish") + + + ##### Schedule B + s[BB].compute_at(s[CC], wk) + s[BB].vectorize(BB.op.axis[4]) + + co, _, _, _, _ = s[B0].op.axis + if sch.bc == 1: + oaxis = co + paxis = co + else: + oco, ico = s[B0].split(co, sch.bc) + oaxis = oco + paxis = ico + + s[B0].parallel(paxis) + s[B0].pragma(oaxis, "parallel_launch_point") + s[B0].pragma(paxis, "parallel_stride_pattern") + s[B0].pragma(oaxis, "parallel_barrier_when_finish") + + return s + +@generic.schedule_conv2d_nchw.register(["tcpu", "vta"]) +def schedule_conv2d(outs): + """Create schedule for tensors""" + s = tvm.create_schedule([x.op for x in outs]) + + def traverse(op): + """Traverse operators from computation graph""" + # inline all one-to-one-mapping operators except the last stage (output) + if tag.is_broadcast(op.tag): + if op not in s.outputs: + s[op].compute_inline() + for tensor in op.input_tensors: + if tensor.op.input_tensors: + traverse(tensor.op) + + if 'spatial_conv_output' in op.tag: + output = op.output(0) + conv_out = op.input_tensors[0] + kernel_vec = conv_out.op.input_tensors[1] + kernel = kernel_vec.op.input_tensors[0] + data_vec = conv_out.op.input_tensors[0] + data = data_vec.op.input_tensors[0] + data_pad = None + if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag: + data_pad = data + data = data_pad.op.input_tensors[0] + + _schedule_spatial_conv2d(s, data, data_pad, data_vec, + kernel, kernel_vec, + conv_out, output, outs[0]) + + if 'im2col_conv_output' in op.tag: + output = op.output(0) + conv_out = op.input_tensors[0] + kernel_vec = conv_out.op.input_tensors[1] + kernel = kernel_vec.op.input_tensors[0] + data_vec = conv_out.op.input_tensors[0] + data_col = data_vec.op.input_tensors[0] + data = data_col.op.input_tensors[0] + data_pad = None + if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag: + data_pad = data + data = data_pad.op.input_tensors[0] + _schedule_im2col_conv2d(s, data, data_pad, data_col, data_vec, + kernel, kernel_vec, + conv_out, output, outs[0]) + + traverse(outs[0].op) + return s diff --git a/vta/python/vta/build.py b/vta/python/vta/build.py new file mode 100644 index 000000000..8f422e1fa --- /dev/null +++ b/vta/python/vta/build.py @@ -0,0 +1,55 @@ +"""Runtime function related hooks""" +from __future__ import absolute_import as _abs + +import tvm +from tvm import build_module +from . runtime import CB_HANDLE +from . import ir_pass + + +def lift_coproc_scope(x): + x = ir_pass.lift_alloc_to_scope_begin(x) + x = tvm.ir_pass.LiftAttrScope(x, "coproc_scope", False) + return x + +def early_rewrite(stmt): + try: + return tvm.ir_pass.StorageRewrite(stmt) + except tvm.TVMError: + return stmt + + +def debug_mode(debug_flag): + """Pass to enable vta debug mode. + + Parameters + ---------- + debug_flag : int + The dbeug flag to be passed. + + Returns + ------- + pass_list: list of function + The pass to set to build_config(add_lower_pass=vta.debug_mode(mode)) + """ + def add_debug(stmt): + debug = tvm.call_extern( + "int32", "VTASetDebugMode", CB_HANDLE, debug_flag) + return tvm.make.stmt_seq(debug, stmt) + pass_list = [(1, ir_pass.inject_dma_intrin), + (1, ir_pass.inject_skip_copy), + (1, ir_pass.annotate_alu_coproc_scope), + (1, lambda x: tvm.ir_pass.LiftAttrScope(x, "coproc_uop_scope", True)), + (1, lift_coproc_scope), + (1, ir_pass.inject_coproc_sync), + (1, early_rewrite)] + if debug_flag: + pass_list.append((1, add_debug)) + pass_list.append((2, ir_pass.inject_alu_intrin)) + pass_list.append((3, ir_pass.fold_uop_loop)) + pass_list.append((3, ir_pass.cpu_access_rewrite)) + return pass_list + + +# Add a lower pass to sync uop +build_module.BuildConfig.current.add_lower_pass = debug_mode(0) diff --git a/vta/python/vta/graph.py b/vta/python/vta/graph.py new file mode 100644 index 000000000..b8237980d --- /dev/null +++ b/vta/python/vta/graph.py @@ -0,0 +1,348 @@ +"""Graph transformation specific to accelerator. + +This module provide specific NNVM graph transformations +to transform a generic NNVM graph to a version that can +be executed on accelerator. +""" + +import nnvm + +from nnvm.compiler import graph_attr, graph_util + + +def _pack_channel(data, dshape, factor): + """Pack the data channel dimension. + """ + assert dshape[1] % factor == 0 + data = nnvm.sym.reshape(data, + shape=(dshape[0], dshape[1] // factor, + factor, dshape[2], dshape[3])) + data = nnvm.sym.transpose( + data, axes=(0, 1, 3, 4, 2)) + return data + + +def _unpack_channel(data, old_shape): + """Unpack the data channel dimension. + """ + data = nnvm.sym.transpose(data, axes=(0, 1, 4, 2, 3)) + data = nnvm.sym.reshape(data, shape=old_shape) + return data + + +def _pack_weight(data, dshape, factor): + """Pack the weight into packed format. + """ + assert len(dshape) == 4 + assert dshape[0] % factor == 0 + assert dshape[1] % factor == 0 + data = nnvm.sym.reshape(data, + shape=(dshape[0] // factor, factor, + dshape[1] // factor, factor, + dshape[2], dshape[3])) + data = nnvm.sym.transpose( + data, axes=(0, 2, 4, 5, 1, 3)) + return data + + +def _pack_bias(data, dshape, factor): + """Pack the bias parameter. + """ + assert len(dshape) == 3 + assert dshape[0] % factor == 0 + data = nnvm.sym.reshape(data, + shape=(dshape[0] // factor, + factor, dshape[1], dshape[2])) + data = nnvm.sym.transpose( + data, axes=(0, 2, 3, 1)) + return data + + +def _get_shape(sym, shape_dict): + """Get the shape of a node. + """ + return graph_util.infer_shape( + nnvm.graph.create(sym), **shape_dict)[1][0] + + +def remove_stochastic(graph): + """ + Replace stochastic rounding and shift with determinstic version. + + Parameters + ---------- + graph : Graph + The input graph + + Returns + ------- + replaced_graph : Graph + The final replaced graph. + """ + gidx = graph.index + node_map = {} + + for nid, node in enumerate(gidx.nodes): + children = [node_map[e[0]] for e in node["inputs"]] + attrs = node.get("attrs", {}) + node_name = node["name"] + op_name = node["op"] + get_clone = lambda c, o_n, n_n, a: getattr(nnvm.symbol, o_n)( + *c, name=n_n, **a) + if op_name == "null": + new_node = nnvm.symbol.Variable(node_name) + elif op_name == "stochastic_round": + new_node = children[0] + elif op_name == "noise_lshift": + new_node = nnvm.symbol.left_shift( + children[0], **attrs) + else: + new_node = get_clone(children, op_name, node_name, attrs) + node_map[nid] = new_node + + assert len(graph.index.output_entries) == 1 + ret = node_map[graph.index.output_entries[0][0]] + ret = nnvm.graph.create(ret) + return ret + + +def clean_conv_fuse(graph): + """Cleanup the convolution's later fuse stages + + Parameters + ---------- + graph : Graph + Input graph + + Returns + ------- + graph : Graph + Optimized graph + """ + def _clean_entry(entry): + node, flag = entry + if flag: + node = nnvm.symbol.clip(node, a_max=127, a_min=-127) + node = nnvm.symbol.cast(node, dtype="int8") + # Use identity as a hint to block conv2d schedules + node = nnvm.symbol.identity(node) + flag = False + return node, flag + + gidx = graph.index + ref_count = {} + # count reference of each node + for nid, node in enumerate(gidx.nodes): + ref_count[nid] = 0 + for elem in node["inputs"]: + ref_count[elem[0]] += 1 + # construction remap + # entry_id->(new_node, conv_fuse) + # need_fold: bool indicates if we need fold + node_map = {} + + for nid, node in enumerate(gidx.nodes): + children = [node_map[e[0]] for e in node["inputs"]] + attrs = node.get("attrs", {}) + node_name = node["name"] + op_name = node["op"] + get_clone = lambda c, o_n, n_n, a: getattr(nnvm.symbol, o_n)( + *c, name=n_n, **a) + + new_entry = None + if op_name == "null": + new_entry = (nnvm.symbol.Variable(node_name), False) + elif op_name in ("cast", "clip"): + if children[0][1]: + new_entry = children[0] + else: + new_entry = ( + get_clone([children[0][0]], op_name, node_name, attrs), + False) + elif op_name == "quantized_conv2d": + data, weight = children + data = _clean_entry(data) + new_node = nnvm.sym.quantized_conv2d( + data[0], weight[0], name=node_name, **attrs) + new_entry = (new_node, True) + elif op_name in ("left_shift", "right_shift", "relu"): + new_entry = ( + get_clone([children[0][0]], op_name, node_name, attrs), + children[0][1]) + elif op_name in ("broadcast_add", "broadcast_mul"): + rhs = children[1][0] + lhs, _ = _clean_entry(children[0]) + lhs = nnvm.sym.cast(lhs, dtype="int32") + rhs = nnvm.sym.cast(rhs, dtype="int32") + new_entry = ( + get_clone([lhs, rhs], op_name, node_name, attrs), + False) + + if new_entry is None: + inputs = [_clean_entry(x) for x in children] + new_entry = ( + get_clone([x[0] for x in inputs], op_name, node_name, attrs), + False) + if ref_count[nid] > 1: + new_entry = _clean_entry(new_entry) + node_map[nid] = new_entry + + assert len(graph.index.output_entries) == 1 + ret = node_map[graph.index.output_entries[0][0]][0] + ret = nnvm.graph.create(ret) + return ret + + +def clean_cast(graph): + """ + Move the casts to early part of graph, + remove uncessary clip operations when possible. + """ + gidx = graph.index + node_map = {} + + def _clean_cast(node, target_type): + op_name = node.attr("op_name") + if op_name == "cast": + return _clean_cast(node.get_children(), target_type) + elif op_name == "relu": + data, has_clip = _clean_cast( + node.get_children(), target_type) + data = nnvm.sym.relu(data) + return data, has_clip + return nnvm.sym.cast(node, dtype=target_type), False + + for nid, node in enumerate(gidx.nodes): + children = [node_map[e[0]] for e in node["inputs"]] + attrs = node.get("attrs", {}) + node_name = node["name"] + op_name = node["op"] + get_clone = lambda c, o_n, n_n, a: getattr(nnvm.symbol, o_n)( + *c, name=n_n, **a) + + if op_name == "null": + new_node = nnvm.symbol.Variable(node_name) + elif op_name == "cast": + dtype = attrs["dtype"] + new_node, _ = _clean_cast(children[0], dtype) + elif op_name == "quantized_conv2d": + data, weight = children + data, _ = _clean_cast(data, "int8") + weight, _ = _clean_cast(weight, "int8") + new_node = nnvm.sym.quantized_conv2d( + data, weight, name=node_name, **attrs) + elif op_name == "elemwise_add": + lhs, rhs = children + rhs = nnvm.sym.cast(rhs, dtype="int8") + new_node = nnvm.sym.elemwise_add(lhs, rhs) + else: + new_node = get_clone(children, op_name, node_name, attrs) + node_map[nid] = new_node + + assert len(graph.index.output_entries) == 1 + ret = node_map[graph.index.output_entries[0][0]] + ret = nnvm.graph.create(ret) + return ret + + +def pack(graph, shape_dict, factor, start_name=None): + """Pack the graph into channel packed format. + + Parameters + ---------- + graph : Graph + The input graph. + + shape_dict : dict of str to shapex + The input shape. + + factor : int + The packing factor + + start_name: str, optional + Start name start packing from certain known node. + + Returns + ------- + graph : Graph + The transformed graph. + """ + graph = graph_attr.set_shape_inputs(graph, shape_dict) + graph = graph.apply("InferShape") + shape = graph.json_attr("shape") + gidx = graph.index + node_map = {} + dset = set() + counter = 0 + start_pack = False + + for nid, node in enumerate(gidx.nodes): + children = [node_map[e[0]] for e in node["inputs"]] + ishape = [shape[gidx.entry_id(e)] for e in node["inputs"]] + oshape = shape[gidx.entry_id(nid, 0)] + attrs = node.get("attrs", {}) + node_name = node["name"] + op_name = node["op"] + get_clone = lambda c, o_n, n_n, a: getattr(nnvm.symbol, o_n)( + *c, name=n_n, **a) + + if op_name == "null": + new_node = nnvm.symbol.Variable(node_name) + if start_name and node_name == start_name: + start_pack = True + new_node = _pack_channel(new_node, oshape, factor) + elif op_name == "max_pool2d": + assert not start_pack + start_pack = True + new_node = get_clone(children, op_name, node_name, attrs) + new_node = _pack_channel(new_node, oshape, factor) + elif op_name == "global_avg_pool2d": + if start_pack: + start_pack = False + children[0] = _unpack_channel(children[0], ishape[0]) + new_node = getattr(nnvm.symbol, op_name)( + *children, name=node_name, **attrs) + else: + new_node = get_clone(children, op_name, node_name, attrs) + elif op_name == "quantized_conv2d": + if start_pack: + attrs["pack_channel"] = str(factor) + data, weight = children + weight = _pack_weight(weight, ishape[1], factor) + new_node = nnvm.sym.quantized_conv2d( + data, weight, name=node_name, **attrs) + elif counter == 1: + attrs["pack_channel"] = str(factor) + data, weight = children + data = _pack_channel(data, ishape[0], factor) + weight = _pack_weight(weight, ishape[1], factor) + new_node = nnvm.sym.quantized_conv2d( + data, weight, name=node_name, **attrs) + new_node = _unpack_channel(new_node, oshape) + counter = counter + 1 + else: + new_node = get_clone(children, op_name, node_name, attrs) + elif op_name.startswith("broadcast"): + if start_pack: + assert len(ishape[1]) == 3 + children[1] = _pack_bias(children[1], ishape[1], factor) + new_node = getattr(nnvm.symbol, op_name)( + *children, name=node_name, **attrs) + else: + new_node = get_clone(children, op_name, node_name, attrs) + elif op_name.startswith("elementwise_add"): + new_node = get_clone(children, op_name, node_name, attrs) + else: + new_node = get_clone(children, op_name, node_name, attrs) + dset.add(op_name) + node_map[nid] = new_node + + assert len(graph.index.output_entries) == 1 + ret = node_map[graph.index.output_entries[0][0]] + if start_pack: + oshape = shape[graph.index.output_entries[0][0]] + ret = _unpack_channel(ret, oshape) + graph = nnvm.graph.create(ret) + graph = graph_attr.set_shape_inputs(graph, shape_dict) + graph = graph.apply("InferShape") + return graph diff --git a/vta/python/vta/hw_spec.py b/vta/python/vta/hw_spec.py new file mode 100644 index 000000000..b6b89df81 --- /dev/null +++ b/vta/python/vta/hw_spec.py @@ -0,0 +1,70 @@ +"""VTA configuration constants (should match hw_spec.h""" +from __future__ import absolute_import as _abs + +# The Constants +VTA_WGT_WIDTH = 8 +VTA_INP_WIDTH = VTA_WGT_WIDTH +VTA_OUT_WIDTH = 32 + +# Dimensions of the GEMM unit +# (BATCH,BLOCK_IN) x (BLOCK_IN,BLOCK_OUT) +VTA_BATCH = 1 +VTA_BLOCK_IN = 16 +VTA_BLOCK_OUT = 16 + +# log-2 On-chip wgt buffer size in Bytes +VTA_LOG_WGT_BUFF_SIZE = 15 +# log-2 On-chip input buffer size in Bytes +VTA_LOG_INP_BUFF_SIZE = 15 +# log-2 On-chip output buffer size in Bytes +VTA_LOG_OUT_BUFF_SIZE = 17 +# On-chip wgt buffer size in Bytes +VTA_WGT_BUFF_SIZE = 1 << VTA_LOG_WGT_BUFF_SIZE +# Input buffer size +VTA_INP_BUFF_SIZE = 1 << VTA_LOG_INP_BUFF_SIZE +# Output buffer size. +VTA_OUT_BUFF_SIZE = 1 << VTA_LOG_OUT_BUFF_SIZE + +# Number of bytes per buffer +VTA_INP_ELEM_BYTES = (VTA_BATCH*VTA_BLOCK_IN*VTA_INP_WIDTH//8) +VTA_WGT_ELEM_BYTES = (VTA_BLOCK_OUT*VTA_BLOCK_IN*VTA_WGT_WIDTH//8) +VTA_OUT_ELEM_BYTES = (VTA_BATCH*VTA_BLOCK_OUT*VTA_OUT_WIDTH//8) + +# Maximum external buffer size in bytes +VTA_MAX_XFER = 1 << 22 + +# Number of elements +VTA_INP_BUFF_DEPTH = VTA_INP_BUFF_SIZE//VTA_INP_ELEM_BYTES +VTA_WGT_BUFF_DEPTH = VTA_WGT_BUFF_SIZE//VTA_WGT_ELEM_BYTES +VTA_OUT_BUFF_DEPTH = VTA_OUT_BUFF_SIZE//VTA_OUT_ELEM_BYTES + +# Memory id for DMA +VTA_MEM_ID_UOP = 0 +VTA_MEM_ID_WGT = 1 +VTA_MEM_ID_INP = 2 +VTA_MEM_ID_ACC = 3 +VTA_MEM_ID_OUT = 4 + +# VTA ALU Opcodes +VTA_ALU_OPCODE_MIN = 0 +VTA_ALU_OPCODE_MAX = 1 +VTA_ALU_OPCODE_ADD = 2 +VTA_ALU_OPCODE_SUB = 3 +VTA_ALU_OPCODE_MUL = 4 +VTA_ALU_OPCODE_SHL = 5 +VTA_ALU_OPCODE_SHR = 6 +VTA_ALU_OPCODE_UNSET = 7 + +# Task queue id (pipeline stage) +VTA_QID_LOAD_INP = 1 +VTA_QID_LOAD_WGT = 1 +VTA_QID_LOAD_OUT = 2 +VTA_QID_STORE_OUT = 3 +VTA_QID_COMPUTE = 2 +VTA_QID_STORE_INP = 3 + +# Debug flags +DEBUG_DUMP_INSN = (1 << 1) +DEBUG_DUMP_UOP = (1 << 2) +DEBUG_SKIP_READ_BARRIER = (1 << 3) +DEBUG_SKIP_WRITE_BARRIER = (1 << 4) \ No newline at end of file diff --git a/vta/python/vta/intrin.py b/vta/python/vta/intrin.py new file mode 100644 index 000000000..73928aa11 --- /dev/null +++ b/vta/python/vta/intrin.py @@ -0,0 +1,183 @@ +"""VTA related intrinsics""" +from __future__ import absolute_import as _abs + +import tvm +from . import hw_spec as spec +from .runtime import VTA_AXIS, VTA_PUSH_UOP, get_task_qid +from .runtime import SCOPE_OUT, SCOPE_INP, SCOPE_WGT + +# The memory information for the compiler +@tvm.register_func("tvm.info.mem.%s" % SCOPE_INP) +def mem_info_inp_buffer(): + return tvm.make.node("MemoryInfo", + unit_bits=spec.VTA_INP_ELEM_BYTES * 8, + max_simd_bits=spec.VTA_INP_ELEM_BYTES * 8, + max_num_bits=spec.VTA_INP_BUFF_SIZE * 8, + head_address=None) + +@tvm.register_func("tvm.info.mem.%s" % SCOPE_WGT) +def mem_info_wgt_buffer(): + return tvm.make.node("MemoryInfo", + unit_bits=spec.VTA_WGT_ELEM_BYTES * 8, + max_simd_bits=spec.VTA_WGT_ELEM_BYTES * 8, + max_num_bits=spec.VTA_WGT_BUFF_SIZE * 8, + head_address=None) + +@tvm.register_func("tvm.info.mem.%s" % SCOPE_OUT) +def mem_info_out_buffer(): + return tvm.make.node("MemoryInfo", + unit_bits=spec.VTA_OUT_ELEM_BYTES * 8, + max_simd_bits=spec.VTA_OUT_ELEM_BYTES * 8, + max_num_bits=spec.VTA_OUT_BUFF_SIZE * 8, + head_address=None) + +def intrin_gevm(mock=False): + """Vector-matrix multiply intrinsic""" + wgt_lanes = spec.VTA_WGT_ELEM_BYTES * 8 // spec.VTA_WGT_WIDTH + assert wgt_lanes == spec.VTA_BLOCK_OUT * spec.VTA_BLOCK_IN + wgt_shape = (spec.VTA_BLOCK_OUT, spec.VTA_BLOCK_IN) + assert wgt_shape[0] * wgt_shape[1] == wgt_lanes + inp_lanes = spec.VTA_INP_ELEM_BYTES * 8 // spec.VTA_INP_WIDTH + out_lanes = spec.VTA_OUT_ELEM_BYTES * 8 // spec.VTA_OUT_WIDTH + wgt = tvm.placeholder((wgt_shape[0], wgt_shape[1]), + dtype="int%d" % spec.VTA_WGT_WIDTH, + name=SCOPE_WGT) + inp = tvm.placeholder((wgt_shape[1], ), + dtype="int%d" % spec.VTA_INP_WIDTH, + name=SCOPE_INP) + k = tvm.reduce_axis((0, wgt_shape[1]), name="k") + out_dtype = "int%d" % spec.VTA_OUT_WIDTH + out = tvm.compute((wgt_shape[0],), + lambda i: tvm.sum(inp[k].astype(out_dtype) * + wgt[i, k].astype(out_dtype), + axis=[k]), + name="out") + wgt_layout = tvm.decl_buffer( + wgt.shape, wgt.dtype, SCOPE_WGT, + scope=SCOPE_WGT, offset_factor=wgt_lanes, data_alignment=wgt_lanes) + inp_layout = tvm.decl_buffer( + inp.shape, inp.dtype, SCOPE_INP, + scope=SCOPE_INP, offset_factor=inp_lanes, data_alignment=inp_lanes) + out_layout = tvm.decl_buffer( + out.shape, out.dtype, SCOPE_OUT, + scope=SCOPE_OUT, offset_factor=out_lanes, data_alignment=out_lanes) + + def intrin_func(ins, outs): + """Vector-matrix multiply intrinsic function""" + dinp, dwgt = ins + dout = outs[0] + def instr(index): + """Generate vector-matrix multiply VTA instruction""" + irb = tvm.ir_builder.create() + irb.scope_attr(VTA_AXIS, "coproc_scope", get_task_qid(spec.VTA_QID_COMPUTE)) + irb.scope_attr(VTA_AXIS, "coproc_uop_scope", VTA_PUSH_UOP) + if index == 0 or index == 2: + irb.emit(tvm.call_extern( + "int32", "VTAUopPush", + 0, 0, + dout.access_ptr("rw", "int32"), + dinp.access_ptr("r", "int32"), + dwgt.access_ptr("r", "int32"), + 0, 0, 0)) + else: + irb.emit(tvm.call_extern( + "int32", "VTAUopPush", + 0, 1, + dout.access_ptr("rw", "int32"), + 0, + 0, + 0, 0, 0)) + return irb.get() + # return a triple of normal-set, reset, update + nop = tvm.make.Evaluate(0) + if mock: + return (nop, nop, nop) + return (instr(0), instr(1), instr(2)) + + return tvm.decl_tensor_intrin(out.op, intrin_func, + name="GEVM", + binds={inp: inp_layout, + wgt: wgt_layout, + out: out_layout}) + + +def intrin_gemm(mock=False): + """Matrix-matrix multiply intrinsic""" + wgt_lanes = spec.VTA_WGT_ELEM_BYTES * 8 // spec.VTA_WGT_WIDTH + assert wgt_lanes == spec.VTA_BLOCK_OUT * spec.VTA_BLOCK_IN + wgt_shape = (spec.VTA_BLOCK_OUT, spec.VTA_BLOCK_IN) + assert wgt_shape[0] * wgt_shape[1] == wgt_lanes + + inp_lanes = spec.VTA_INP_ELEM_BYTES * 8 // spec.VTA_INP_WIDTH + assert inp_lanes == spec.VTA_BATCH * spec.VTA_BLOCK_IN + inp_shape = (spec.VTA_BATCH, spec.VTA_BLOCK_IN) + assert inp_shape[0] * inp_shape[1] == inp_lanes + + out_lanes = spec.VTA_OUT_ELEM_BYTES * 8 // spec.VTA_OUT_WIDTH + assert out_lanes == spec.VTA_BATCH * spec.VTA_BLOCK_OUT + out_shape = (spec.VTA_BATCH, spec.VTA_BLOCK_OUT) + assert out_shape[0] * out_shape[1] == out_lanes + + wgt = tvm.placeholder((wgt_shape[0], wgt_shape[1]), + dtype="int%d" % spec.VTA_WGT_WIDTH, + name=SCOPE_WGT) + inp = tvm.placeholder((inp_shape[0], inp_shape[1]), + dtype="int%d" % spec.VTA_INP_WIDTH, + name=SCOPE_INP) + k = tvm.reduce_axis((0, wgt_shape[1]), name="k") + out_dtype = "int%d" % spec.VTA_OUT_WIDTH + out = tvm.compute((out_shape[0], out_shape[1]), + lambda i, j: tvm.sum(inp[i, k].astype(out_dtype) * + wgt[j, k].astype(out_dtype), + axis=[k]), + name="out") + wgt_layout = tvm.decl_buffer( + wgt.shape, wgt.dtype, SCOPE_WGT, + scope=SCOPE_WGT, offset_factor=wgt_lanes, data_alignment=wgt_lanes) + inp_layout = tvm.decl_buffer( + inp.shape, inp.dtype, SCOPE_INP, + scope=SCOPE_INP, offset_factor=inp_lanes, data_alignment=inp_lanes) + out_layout = tvm.decl_buffer( + out.shape, out.dtype, SCOPE_OUT, + scope=SCOPE_OUT, offset_factor=out_lanes, data_alignment=out_lanes) + + def intrin_func(ins, outs): + """Matrix-matrix multiply intrinsic function""" + dinp, dwgt = ins + dout = outs[0] + def instr(index): + """Generate matrix-matrix multiply VTA instruction""" + irb = tvm.ir_builder.create() + irb.scope_attr(VTA_AXIS, "coproc_scope", get_task_qid(spec.VTA_QID_COMPUTE)) + irb.scope_attr(VTA_AXIS, "coproc_uop_scope", VTA_PUSH_UOP) + if index == 0 or index == 2: + irb.emit(tvm.call_extern( + "int32", "VTAUopPush", + 0, 0, + dout.access_ptr("rw", "int32"), + dinp.access_ptr("r", "int32"), + dwgt.access_ptr("r", "int32"), + 0, 0, 0)) + else: + irb.emit(tvm.call_extern( + "int32", "VTAUopPush", + 0, 1, + dout.access_ptr("rw", "int32"), + 0, + 0, + 0, 0, 0)) + return irb.get() + # return a triple of normal-set, reset, update + nop = tvm.make.Evaluate(0) + if mock: + return (nop, nop, nop) + return (instr(0), instr(1), instr(2)) + + return tvm.decl_tensor_intrin(out.op, intrin_func, + name="GEMM", + binds={inp: inp_layout, + wgt: wgt_layout, + out: out_layout}) + +GEMM = intrin_gemm() +GEVM = intrin_gevm() diff --git a/vta/python/vta/ir_pass.py b/vta/python/vta/ir_pass.py new file mode 100644 index 000000000..00c73c4ab --- /dev/null +++ b/vta/python/vta/ir_pass.py @@ -0,0 +1,762 @@ +"""Additional IR Pass for VTA""" +from __future__ import absolute_import as _abs + +import tvm +from topi import util as util + +from . import hw_spec as spec +from . runtime import CB_HANDLE, VTA_AXIS, VTA_PUSH_UOP +from . runtime import SCOPE_OUT, SCOPE_INP, SCOPE_WGT, DMA_COPY, get_task_qid + + +def fold_uop_loop(stmt_in): + """Pass to fold uop loops""" + def _fold_outermost_loop(body): + stmt = body + while not isinstance(stmt, tvm.stmt.For): + if isinstance(stmt, (tvm.stmt.ProducerConsumer, )): + stmt = stmt.body + else: + return None, body, None + + loop_var = stmt.loop_var + gemm_offsets = [None, None, None] + fail = [False] + + def _post_order(op): + assert isinstance(op, tvm.expr.Call) + base_args = 2 + if op.name == "VTAUopPush": + args = [] + args += op.args[:base_args] + for i in range(3): + m = tvm.arith.DetectLinearEquation(op.args[i + base_args], [loop_var]) + if not m: + fail[0] = True + return op + if gemm_offsets[i] is not None: + if not tvm.ir_pass.Equal(m[0], gemm_offsets[i]): + fail[0] = True + return op + args.append(m[1]) + else: + gemm_offsets[i] = m[0] + args.append(m[1]) + args += op.args[base_args+3:] + return tvm.call_extern("int32", "VTAUopPush", *args) + else: + if op.name not in ("VTATLSCommandHandle", "tvm_thread_context"): + raise RuntimeError("unexpected op %s" % op) + return op + + ret = tvm.ir_pass.IRTransform( + stmt.body, None, _post_order, ["Call"]) + + if not fail[0] and all(x is not None for x in gemm_offsets): + def _visit(op): + if op.same_as(loop_var): + fail[0] = True + tvm.ir_pass.PostOrderVisit(ret, _visit) + if not fail[0]: + begin = tvm.call_extern( + "int32", "VTAUopLoopBegin", stmt.extent, *gemm_offsets) + end = tvm.call_extern( + "int32", "VTAUopLoopEnd", stmt.extent, *gemm_offsets) + return [begin, ret, end] + raise ValueError("Failed to fold the GEMM instructions..") + + def _do_fold(stmt): + if (stmt.attr_key == "coproc_uop_scope" and + isinstance(stmt.value, tvm.expr.StringImm) and + stmt.value.value == VTA_PUSH_UOP.value): + body = stmt.body + begins = [] + ends = [] + try: + begin, body, end = _fold_outermost_loop(body) + if begin is not None: + begins.append(begin) + if end is not None: + ends.append(end) + begin, body, end = _fold_outermost_loop(body) + if begin is not None: + begins.append(begin) + if end is not None: + ends.append(end) + except ValueError: + pass + if body == stmt.body: + return stmt + ends = list(reversed(ends)) + body = tvm.make.stmt_seq(*(begins + [body] + ends)) + return tvm.make.AttrStmt( + stmt.node, stmt.attr_key, stmt.value, body) + return None + out = tvm.ir_pass.IRTransform( + stmt_in, _do_fold, None, ["AttrStmt"]) + return out + + +def cpu_access_rewrite(stmt_in): + """Rewrite the code when there is CPU access happening. + + Parameters + ---------- + stmt_in : Stmt + Input statement + + Returns + ------- + stmt_out : Stmt + Transformed statement + """ + rw_info = {} + def _post_order(op): + if isinstance(op, tvm.stmt.Allocate): + buffer_var = op.buffer_var + if not buffer_var in rw_info: + return None + new_var = rw_info[buffer_var] + let_stmt = tvm.make.LetStmt( + new_var, tvm.call_extern( + "handle", "VTABufferCPUPtr", CB_HANDLE, + buffer_var), op.body) + alloc = tvm.make.Allocate( + buffer_var, op.dtype, op.extents, + op.condition, let_stmt) + del rw_info[buffer_var] + return alloc + elif isinstance(op, tvm.expr.Load): + buffer_var = op.buffer_var + if not buffer_var in rw_info: + rw_info[buffer_var] = tvm.var( + buffer_var.name + "_ptr", "handle") + new_var = rw_info[buffer_var] + return tvm.make.Load(op.dtype, new_var, op.index) + elif isinstance(op, tvm.stmt.Store): + buffer_var = op.buffer_var + if not buffer_var in rw_info: + rw_info[buffer_var] = tvm.var( + buffer_var.name + "_ptr", "handle") + new_var = rw_info[buffer_var] + return tvm.make.Store(new_var, op.value, op.index) + else: + raise RuntimeError("not reached") + stmt = tvm.ir_pass.IRTransform( + stmt_in, None, _post_order, ["Allocate", "Load", "Store"]) + for buffer_var, new_var in rw_info.items(): + stmt = tvm.make.LetStmt( + new_var, tvm.call_extern( + "handle", "VTABufferCPUPtr", CB_HANDLE, + buffer_var), stmt) + return stmt + + +def lift_alloc_to_scope_begin(stmt_in): + """Lift allocate to beginning of the current scope. + + Parameters + ---------- + stmt_in : Stmt + Input statement + + Returns + ------- + stmt_out : Stmt + Transformed statement + """ + lift_stmt = [[]] + def _merge_block(slist, body): + for op in slist: + if op.body == body: + body = op + elif isinstance(op, tvm.stmt.Allocate): + body = tvm.make.Allocate( + op.buffer_var, op.dtype, + op.extents, op.condition, body) + elif isinstance(op, tvm.stmt.AttrStmt): + body = tvm.make.AttrStmt( + op.node, op.attr_key, op.value, body) + elif isinstance(op, tvm.stmt.For): + body = tvm.make.For( + op.loop_var, op.min, op.extent, op.for_type, + op.device_api, body) + else: + raise RuntimeError("unexpected op") + del slist[:] + # n = len(slist) + # for i in range(n): + # slist.pop() + return body + + def _pre_order(op): + if isinstance(op, tvm.stmt.For): + lift_stmt.append([]) + elif isinstance(op, tvm.stmt.AttrStmt): + if op.attr_key == "virtual_thread": + lift_stmt.append([]) + + return None + + def _post_order(op): + if isinstance(op, tvm.stmt.Allocate): + lift_stmt[-1].append(op) + return op.body + elif isinstance(op, tvm.stmt.AttrStmt): + if op.attr_key == "storage_scope": + lift_stmt[-1].append(op) + return op.body + elif op.attr_key == "virtual_thread": + return _merge_block(lift_stmt.pop() + [op], op.body) + return op + elif isinstance(op, tvm.stmt.For): + return _merge_block(lift_stmt.pop() + [op], op.body) + else: + raise RuntimeError("not reached") + stmt = tvm.ir_pass.IRTransform( + stmt_in, _pre_order, _post_order, ["Allocate", "AttrStmt", "For"]) + assert len(lift_stmt) == 1 + return _merge_block(lift_stmt[0], stmt) + + +def inject_skip_copy(stmt_in): + """Pass to inject skip copy stmt, used in debug. + + Parameters + ---------- + stmt_in : Stmt + Input statement + + Returns + ------- + stmt_out : Stmt + Transformed statement + """ + def _do_fold(stmt): + if stmt.attr_key == "pragma_scope" and stmt.value.value == "skip_dma_copy": + return tvm.make.Evaluate(0) + return None + return tvm.ir_pass.IRTransform( + stmt_in, _do_fold, None, ["AttrStmt"]) + + +def inject_coproc_sync(stmt_in): + """Pass to inject skip copy stmt, used in debug. + + Parameters + ---------- + stmt_in : Stmt + Input statement + + Returns + ------- + stmt_out : Stmt + Transformed statement + """ + success = [False] + def _do_fold(stmt): + if stmt.attr_key == "pragma_scope" and stmt.value.value == "coproc_sync": + success[0] = True + sync = tvm.make.Call( + "int32", "vta.coproc_sync", [], tvm.expr.Call.Intrinsic, None, 0) + return tvm.make.Block(stmt.body, tvm.make.Evaluate(sync)) + elif stmt.attr_key == "pragma_scope" and stmt.value.value == "trim_loop": + op = stmt.body + assert isinstance(op, tvm.stmt.For) + return tvm.make.For( + op.loop_var, op.min, 2, op.for_type, + op.device_api, op.body) + return None + stmt = tvm.ir_pass.IRTransform( + stmt_in, None, _do_fold, ["AttrStmt"]) + stmt = tvm.ir_pass.CoProcSync(stmt) + return stmt + + +def inject_dma_intrin(stmt_in): + """Pass to inject DMA copy intrinsics. + + Parameters + ---------- + stmt_in : Stmt + Input statement + + Returns + ------- + stmt_out : Stmt + Transformed statement + """ + def _check_compact(buf): + ndim = len(buf.shape) + size = tvm.const(1, buf.shape[0].dtype) + for i in reversed(range(ndim)): + if not util.equal_const_int(size - buf.strides[i], 0): + raise RuntimeError( + "Cannot prove compact: shape=%s, strides=%s" % (buf.shape, buf.strides)) + size = size * buf.shape[i] + + def _fold_buffer_dim(buf, scope, elem_block): + ndim = len(buf.shape) + x_size = 1 + base = 0 + for i in range(1, ndim + 1): + if not util.equal_const_int(buf.strides[ndim - i] - x_size, 0): + raise RuntimeError("scope %s need need to have block=%d" % (scope, elem_block)) + x_size = x_size * buf.shape[ndim - i] + if util.equal_const_int(x_size - elem_block, 0): + base = i + 1 + break + if base == 0: + raise RuntimeError("scope %s need to have block=%d, shape=%s" % ( + scope, elem_block, buf.shape)) + shape = [elem_block] + strides = [1] + + if (base < ndim + 1 and not util.equal_const_int(buf.strides[ndim - base], elem_block)): + shape.append(1) + strides.append(elem_block) + + while base < ndim + 1: + x_size = 1 + x_stride = buf.strides[ndim - base] + next_base = base + if not util.equal_const_int(x_stride % elem_block, 0): + raise RuntimeError("scope %s need to have block=%d, shape=%s, strides=%s" % ( + scope, elem_block, buf.shape, buf.strides)) + for i in range(base, ndim + 1): + k = ndim - i + if not util.equal_const_int(x_size * x_stride - buf.strides[k], 0): + break + x_size = x_size * buf.shape[k] + next_base = i + 1 + shape.append(tvm.ir_pass.Simplify(x_size)) + strides.append(x_stride) + assert next_base != base + base = next_base + + strides = list(reversed(strides)) + shape = list(reversed(shape)) + return shape, strides + + + def _get_2d_pattern(buf, elem_width, elem_bytes, dtype, scope, allow_fold): + elem_block = elem_bytes * 8 // elem_width + if buf.dtype != dtype: + raise RuntimeError("Expect buffer type to be %s instead of %s" % + (dtype, buf.dtype)) + shape, strides = buf.shape, buf.strides + if not util.equal_const_int(buf.elem_offset % elem_block, 0): + raise RuntimeError("scope %s need to have block=%d" % (scope, elem_block)) + if allow_fold: + shape, strides = _fold_buffer_dim(buf, scope, elem_block) + else: + shape = list(x for x in shape) + strides = list(x for x in strides) + + def raise_error(): + """Internal function to raise error """ + raise RuntimeError( + ("Scope[%s]: cannot detect 2d pattern with elem_block=%d:" + + " shape=%s, strides=%s") % (scope, elem_block, buf.shape, buf.strides)) + + ndim = len(shape) + + # Check if the inner-tensor is already flat + flat = util.equal_const_int(shape[-1], elem_block) + + if flat: + if not util.equal_const_int(strides[-1], 1): + raise_error() + + if ndim == 1: + x_size = 1 + x_stride = 1 + y_size = 1 + return x_size, y_size, x_stride, buf.elem_offset / elem_block + if not util.equal_const_int(strides[-2] - elem_block, 0): + raise_error() + + if ndim == 2: + x_size = shape[-2] + x_stride = shape[-2] + y_size = 1 + return x_size, y_size, x_stride, buf.elem_offset / elem_block + if not util.equal_const_int(strides[-3] % elem_block, 0): + raise_error() + + if ndim == 3: + x_size = shape[-2] + x_stride = strides[-3] / elem_block + y_size = shape[-3] + return x_size, y_size, x_stride, buf.elem_offset / elem_block + + else: + if not util.equal_const_int(strides[-1], 1): + raise_error() + if not util.equal_const_int(strides[-2] - shape[-1], 0): + raise_error() + if not util.equal_const_int(shape[-1] * shape[-2], elem_block): + raise_error() + + if ndim == 2: + x_size = 1 + x_stride = 1 + y_size = 1 + return x_size, y_size, x_stride, buf.elem_offset / elem_block + if not util.equal_const_int(strides[-3], elem_block): + raise_error() + + if ndim == 3: + x_size = shape[-3] + x_stride = shape[-3] + y_size = 1 + return x_size, y_size, x_stride, buf.elem_offset / elem_block + if not util.equal_const_int(strides[-4] % elem_block, 0): + raise_error() + + if ndim == 4: + x_size = shape[-3] + x_stride = strides[-4] / elem_block + y_size = shape[-4] + return x_size, y_size, x_stride, buf.elem_offset / elem_block + + raise_error() + + + def _inject_copy(src, dst, pad_before, pad_after, pad_value): + # FIXME: pad_value is ignored... + if dst.scope == "global": + # Store + if pad_before or pad_after: + raise RuntimeError("Do not support copy into DRAM with pad") + if src.scope == SCOPE_OUT: + elem_width = spec.VTA_INP_WIDTH # output compression to inp type + elem_bytes = spec.VTA_INP_ELEM_BYTES # output compression to inp type + mem_type = spec.VTA_MEM_ID_OUT + data_type = "int%d" % spec.VTA_INP_WIDTH + task_qid = spec.VTA_QID_STORE_OUT + else: + raise RuntimeError("Do not support copy %s->dram" % (src.scope)) + _check_compact(src) + x_size, y_size, x_stride, offset = _get_2d_pattern( + dst, elem_width, elem_bytes, data_type, src.scope, allow_fold=True) + irb = tvm.ir_builder.create() + irb.scope_attr(VTA_AXIS, "coproc_scope", get_task_qid(task_qid)) + irb.emit(tvm.call_extern( + "int32", "VTAStoreBuffer2D", CB_HANDLE, + src.access_ptr("r", "int32"), + mem_type, dst.data, offset, x_size, y_size, x_stride)) + return irb.get() + elif src.scope == "global": + if dst.scope == SCOPE_OUT: + elem_width = spec.VTA_OUT_WIDTH + elem_bytes = spec.VTA_OUT_ELEM_BYTES + mem_type = spec.VTA_MEM_ID_ACC + data_type = "int%d" % spec.VTA_OUT_WIDTH + task_qid = spec.VTA_QID_LOAD_OUT + elif dst.scope == SCOPE_INP: + elem_width = spec.VTA_INP_WIDTH + elem_bytes = spec.VTA_INP_ELEM_BYTES + mem_type = spec.VTA_MEM_ID_INP + data_type = "int%d" % spec.VTA_INP_WIDTH + task_qid = spec.VTA_QID_LOAD_INP + elif dst.scope == SCOPE_WGT: + elem_width = spec.VTA_WGT_WIDTH + elem_bytes = spec.VTA_WGT_ELEM_BYTES + mem_type = spec.VTA_MEM_ID_WGT + data_type = "int%d" % spec.VTA_WGT_WIDTH + task_qid = spec.VTA_QID_LOAD_WGT + else: + raise RuntimeError("Do not support copy dram->%s" % (dst.scope)) + # collect pad statistics + if pad_before: + assert pad_after + ndim = len(pad_before) + if ndim <= 2 or ndim > 4: + raise ValueError("Limitation of 2D pad load forbid ndim=%d" % ndim) + if ndim > 2: + if not util.equal_const_int(pad_before[ndim - 1], 0): + raise ValueError("Do not support pad on the innermost block") + if not util.equal_const_int(pad_after[ndim - 1], 0): + raise ValueError("Do not support pad on the innermost block") + if ndim > 3: + if not util.equal_const_int(pad_before[ndim - 2], 0): + raise ValueError("Do not support pad on the innermost block") + if not util.equal_const_int(pad_after[ndim - 2], 0): + raise ValueError("Do not support pad on the innermost block") + y_pad_before = pad_before[0] + x_pad_before = pad_before[1] + y_pad_after = pad_after[0] + x_pad_after = pad_after[1] + allow_fold = False + else: + x_pad_before = 0 + y_pad_before = 0 + x_pad_after = 0 + y_pad_after = 0 + allow_fold = True + + _check_compact(dst) + x_size, y_size, x_stride, offset = _get_2d_pattern( + src, elem_width, elem_bytes, data_type, dst.scope, allow_fold=allow_fold) + + irb = tvm.ir_builder.create() + irb.scope_attr(VTA_AXIS, "coproc_scope", get_task_qid(task_qid)) + + irb.emit(tvm.call_extern( + "int32", "VTALoadBuffer2D", CB_HANDLE, + src.data, offset, x_size, y_size, x_stride, + x_pad_before, y_pad_before, + x_pad_after, y_pad_after, + dst.access_ptr("r", "int32"), mem_type)) + return irb.get() + + else: + raise RuntimeError("Donot support copy %s->%s" % (src.scope, dst.scope)) + + return tvm.ir_pass.InjectCopyIntrin(stmt_in, DMA_COPY, _inject_copy) + + +def annotate_alu_coproc_scope(stmt_in): + """Pass to insert ALU instruction. + + Parameters + ---------- + stmt_in : Stmt + Input statement + + Returns + ------- + stmt_out : Stmt + Transformed statement + """ + def _do_fold(stmt): + if (stmt.attr_key == "pragma_scope" and stmt.value.value == "alu"): + irb = tvm.ir_builder.create() + irb.scope_attr(VTA_AXIS, "coproc_scope", get_task_qid(spec.VTA_QID_COMPUTE)) + irb.scope_attr(VTA_AXIS, "coproc_uop_scope", tvm.make.StringImm("VTAPushALUOp")) + irb.emit(stmt) + return irb.get() + elif (stmt.attr_key == "pragma_scope" and stmt.value.value == "skip_alu"): + return tvm.make.Evaluate(0) + return stmt + + stmt_out = tvm.ir_pass.IRTransform( + stmt_in, None, _do_fold, ["AttrStmt"]) + + return stmt_out + + +def inject_alu_intrin(stmt_in): + """Pass to inject ALU micro-ops. + + Parameters + ---------- + stmt_in : Stmt + Input statement + + Returns + ------- + stmt_out : Stmt + Transformed statement + """ + def _do_fold(stmt): + def _equal(x, y): + return tvm.ir_pass.Equal(tvm.ir_pass.Simplify(x - y), 0) + + def _flatten_loop(src_coeff, dst_coeff, extents): + src_coeff = list(src_coeff) + dst_coeff = list(dst_coeff) + extents = list(extents) + rev_src_coeff = [src_coeff.pop()] + rev_dst_coeff = [dst_coeff.pop()] + rev_extents = [] + assert src_coeff + vsrc = src_coeff.pop() + vdst = dst_coeff.pop() + vext = extents.pop() + while src_coeff: + next_src = src_coeff.pop() + next_dst = dst_coeff.pop() + next_ext = extents.pop() + + if (_equal(next_src, vsrc * vext) and _equal(next_dst, vdst * vext)): + vext = tvm.ir_pass.Simplify(vext * next_ext) + else: + rev_src_coeff.append(vsrc) + rev_dst_coeff.append(vdst) + rev_extents.append(vext) + vsrc = next_src + vdst = next_dst + vext = next_ext + rev_src_coeff.append(vsrc) + rev_dst_coeff.append(vdst) + rev_extents.append(vext) + rev_src_coeff.reverse() + rev_dst_coeff.reverse() + rev_extents.reverse() + + return rev_src_coeff, rev_dst_coeff, rev_extents + + if (stmt.attr_key == "pragma_scope" and stmt.value.value == "alu"): + # Get to the innermost loop body + loop_body = stmt.body + nest_size = 0 + while isinstance(loop_body, tvm.stmt.For): + loop_body = loop_body.body + nest_size += 1 + # Get the src/dst arguments + dst_var = loop_body.buffer_var + dst_idx = loop_body.index + # Derive loop variables and extents + tmp_body = stmt.body + indices = [] + extents = [] + for _ in range(nest_size): + indices.append(tmp_body.loop_var) + extents.append(tmp_body.extent) + tmp_body = tmp_body.body + # Derive opcode + alu_opcode = spec.VTA_ALU_OPCODE_UNSET + if isinstance(loop_body.value, tvm.expr.Add): + alu_opcode = spec.VTA_ALU_OPCODE_ADD + lhs = loop_body.value.a + rhs = loop_body.value.b + elif isinstance(loop_body.value, tvm.expr.Sub): + alu_opcode = spec.VTA_ALU_OPCODE_SUB + lhs = loop_body.value.a + rhs = loop_body.value.b + elif isinstance(loop_body.value, tvm.expr.Mul): + alu_opcode = spec.VTA_ALU_OPCODE_MUL + lhs = loop_body.value.a + rhs = loop_body.value.b + elif isinstance(loop_body.value, tvm.expr.Min): + alu_opcode = spec.VTA_ALU_OPCODE_MIN + lhs = loop_body.value.a + rhs = loop_body.value.b + elif isinstance(loop_body.value, tvm.expr.Max): + alu_opcode = spec.VTA_ALU_OPCODE_MAX + lhs = loop_body.value.a + rhs = loop_body.value.b + elif isinstance(loop_body.value, tvm.expr.Call): + if loop_body.value.name == 'shift_left': + alu_opcode = spec.VTA_ALU_OPCODE_SHL + lhs = loop_body.value.args[0] + rhs = loop_body.value.args[1] + elif loop_body.value.name == 'shift_right': + alu_opcode = spec.VTA_ALU_OPCODE_SHR + lhs = loop_body.value.args[0] + rhs = loop_body.value.args[1] + else: + raise RuntimeError( + "Function call not recognized %s" % (loop_body.value.name)) + else: + raise RuntimeError("Expression not recognized %s" % (type(loop_body.value))) + + # Derive array index coefficients + dst_coeff = tvm.arith.DetectLinearEquation(dst_idx, indices) + # Check if lhs/rhs is immediate + use_imm = False + imm_val = None + if isinstance(rhs, tvm.expr.IntImm): + assert lhs.buffer_var.same_as(dst_var) + src_coeff = tvm.arith.DetectLinearEquation(lhs.index, indices) + use_imm = True + imm_val = rhs + if isinstance(lhs, tvm.expr.IntImm): + assert rhs.buffer_var.same_as(dst_var) + src_coeff = tvm.arith.DetectLinearEquation(rhs.index, indices) + use_imm = True + imm_val = lhs + if imm_val is None: + imm_val = 0 + assert lhs.buffer_var.same_as(dst_var) and rhs.buffer_var.same_as(dst_var) + src_lhs_coeff = tvm.arith.DetectLinearEquation(lhs.index, indices) + src_rhs_coeff = tvm.arith.DetectLinearEquation(rhs.index, indices) + # Determine which side has the same coefficients + lhs_equal = True + rhs_equal = True + for i, coef in enumerate(dst_coeff): + if not tvm.ir_pass.Equal(coef, src_lhs_coeff[i]): + lhs_equal = False + if not tvm.ir_pass.Equal(coef, src_rhs_coeff[i]): + rhs_equal = False + # Make sure at least one of the source is identical to the + # destination (in-place computation) + assert lhs_equal or rhs_equal + # Assign the source coefficients + if lhs_equal: + src_coeff = src_rhs_coeff + else: + src_coeff = src_lhs_coeff + + # Ensure that we have the proper tensor dimensions in the + # innermost loop (pattern match) + src_coeff = list(src_coeff) + dst_coeff = list(dst_coeff) + extents = list(extents) + assert len(src_coeff) > 1 + assert len(dst_coeff) > 1 + assert len(extents) > 0 + assert tvm.ir_pass.Equal( + tvm.ir_pass.Simplify( + src_coeff[-1]%(spec.VTA_BATCH*spec.VTA_BLOCK_OUT)), 0) + assert tvm.ir_pass.Equal( + tvm.ir_pass.Simplify( + dst_coeff[-1]%(spec.VTA_BATCH*spec.VTA_BLOCK_OUT)), 0) + assert tvm.ir_pass.Equal(src_coeff[-2], 1) + assert tvm.ir_pass.Equal(dst_coeff[-2], 1) + if spec.VTA_BATCH > 1: + assert len(src_coeff) > 2 + assert len(dst_coeff) > 2 + assert len(extents) > 1 + assert tvm.ir_pass.Equal(src_coeff[-3], spec.VTA_BLOCK_OUT) + assert tvm.ir_pass.Equal(dst_coeff[-3], spec.VTA_BLOCK_OUT) + + # Apply tensorization of the loop coefficients + src_offset = src_coeff[-1] + dst_offset = dst_coeff[-1] + if spec.VTA_BATCH == 1: + src_coeff = src_coeff[:-2] + dst_coeff = dst_coeff[:-2] + extents = extents[:-1] + else: + src_coeff = src_coeff[:-3] + dst_coeff = dst_coeff[:-3] + extents = extents[:-2] + src_coeff.append(src_offset) + dst_coeff.append(dst_offset) + src_coeff = [ + tvm.ir_pass.Simplify(c/(spec.VTA_BATCH*spec.VTA_BLOCK_OUT)) for c in src_coeff] + dst_coeff = [ + tvm.ir_pass.Simplify(c/(spec.VTA_BATCH*spec.VTA_BLOCK_OUT)) for c in dst_coeff] + + # Flatten the outer loops + src_coeff, dst_coeff, extents = _flatten_loop(src_coeff, dst_coeff, extents) + + # Insert ALU micro-ops + irb = tvm.ir_builder.create() + for idx, extent in enumerate(extents): + irb.emit(tvm.call_extern( + "int32", "VTAUopLoopBegin", extent, dst_coeff[idx], src_coeff[idx])) + irb.emit(tvm.call_extern( + "int32", "VTAUopPush", + 1, 0, + dst_coeff[len(dst_coeff)-1], + src_coeff[len(src_coeff)-1], + 0, + alu_opcode, use_imm, imm_val)) + for extent in extents: + irb.emit(tvm.call_extern( + "int32", "VTAUopLoopEnd")) + return irb.get() + + return stmt + + stmt_out = tvm.ir_pass.IRTransform( + stmt_in, None, _do_fold, ["AttrStmt"]) + return stmt_out + +def debug_print(stmt): + print stmt + return stmt diff --git a/vta/python/vta/mock.py b/vta/python/vta/mock.py new file mode 100644 index 000000000..3b50bb8c7 --- /dev/null +++ b/vta/python/vta/mock.py @@ -0,0 +1,7 @@ +"""Mock interface for skip part of compute """ +from .intrin import intrin_gevm, intrin_gemm + +GEMM = intrin_gemm(True) +GEVM = intrin_gevm(True) +DMA_COPY = "skip_dma_copy" +ALU = "skip_alu" diff --git a/vta/python/vta/runtime.py b/vta/python/vta/runtime.py new file mode 100644 index 000000000..bfcd130ff --- /dev/null +++ b/vta/python/vta/runtime.py @@ -0,0 +1,42 @@ +"""Runtime function related hooks""" +from __future__ import absolute_import as _abs + +import tvm + +def thread_local_command_buffer(): + """Get thread local command buffer""" + ctx = tvm.call_extern("handle", "VTATLSCommandHandle") + return tvm.make.Call( + "handle", "tvm_thread_context", [ctx], tvm.expr.Call.Intrinsic, None, 0) + +CB_HANDLE = thread_local_command_buffer() + +VTA_AXIS = tvm.thread_axis("vta") +VTA_PUSH_UOP = tvm.make.StringImm("VTAPushGEMMOp") + +SCOPE_INP = "local.inp_buffer" +SCOPE_OUT = "local.out_buffer" +SCOPE_WGT = "local.wgt_buffer" +DMA_COPY = "dma_copy" +ALU = "alu" +DEBUG_NO_SYNC = False + +def get_task_qid(qid): + """Get transformed queue index.""" + return 1 if DEBUG_NO_SYNC else qid + + +@tvm.register_func("tvm.intrin.rule.default.vta.coproc_sync") +def coproc_sync(op): + return tvm.call_extern( + "int32", "VTASynchronize", CB_HANDLE, 1<<31) + +@tvm.register_func("tvm.intrin.rule.default.vta.coproc_dep_push") +def coproc_dep_push(op): + return tvm.call_extern( + "int32", "VTADepPush", CB_HANDLE, op.args[0], op.args[1]) + +@tvm.register_func("tvm.intrin.rule.default.vta.coproc_dep_pop") +def coproc_dep_pop(op): + return tvm.call_extern( + "int32", "VTADepPop", CB_HANDLE, op.args[0], op.args[1]) diff --git a/vta/python/vta/vta_conv2d.py b/vta/python/vta/vta_conv2d.py new file mode 100644 index 000000000..97fa42bf1 --- /dev/null +++ b/vta/python/vta/vta_conv2d.py @@ -0,0 +1,373 @@ +"""Namespace for supporting packed_conv2d + ewise variant of nnvm.""" + +from collections import namedtuple + +import logging +import tvm +import topi + +from nnvm.top import registry as reg, OpPattern + +from . import intrin, runtime as vta +from intrin import GEVM + +TARGET_BOARD = "llvm -mtriple=armv7-none-linux-gnueabihf -mcpu=cortex-a9 -mattr=+neon" + +Workload = namedtuple("Conv2DWorkload", + ['height', 'width', 'in_filter', 'out_filter', + 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride']) + + +def packed_conv2d(data, + kernel, + padding, + strides, + out_dtype="int32"): + """ Packed conv2d function. + """ + if padding[0]: + pad_data = topi.nn.pad(data, [0, 0, padding[0], padding[1], 0], name="pad_data") + else: + pad_data = data + assert len(data.shape) == 5 + assert len(kernel.shape) == 6 + oheight = topi.util.simplify((pad_data.shape[2] - kernel.shape[2]) // strides[0] + 1) + owidth = topi.util.simplify((pad_data.shape[3] - kernel.shape[3]) // strides[1] + 1) + oshape = (data.shape[0], kernel.shape[0], oheight, owidth, kernel.shape[4]) + + ishape = topi.util.get_const_tuple(data.shape) + kshape = topi.util.get_const_tuple(kernel.shape) + assert data.dtype == "int8", data.dtype + assert kernel.dtype == "int8", kernel.dtype + d_i = tvm.reduce_axis((0, kshape[2]), name='d_i') + d_j = tvm.reduce_axis((0, kshape[3]), name='d_j') + k_o = tvm.reduce_axis((0, ishape[1]), name='k_o') + k_i = tvm.reduce_axis((0, ishape[-1]), name='k_i') + hstride, wstride = strides + res = tvm.compute( + oshape, + lambda b, co, i, j, ci: tvm.sum( + pad_data[b, k_o, i*hstride+d_i, j*wstride+d_j, k_i].astype(out_dtype) * + kernel[co, k_o, d_i, d_j, ci, k_i].astype(out_dtype), + axis=[k_o, d_i, d_j, k_i]), + name="res", tag="packed_conv2d") + return res + + +@tvm.register_func("nnvm.compiler.build_target", override=True) +def _build(funcs, target, target_host): + tvm_t = tvm.target.create(target) + if tvm_t.device_name == "vta": + return tvm.build(funcs, target="ext_dev", + target_host=TARGET_BOARD) + elif tvm_t.device_name == "rasp" or tvm_t.device_name == "tcpu": + return tvm.build(funcs, target=TARGET_BOARD) + return tvm.build(funcs, target=target) + + +@tvm.register_func("nnvm.compiler.lower", override=True) +def _lower(sch, inputs, func_name, graph): + import traceback + # pylint: disable=broad-except + try: + f = tvm.lower(sch, inputs, name=func_name) + if "quantized_conv2d" in func_name: + logging.info(graph.ir(join_entry_attrs=["shape"])) + except Exception: + msg = traceback.format_exc() + msg += "Error during compile graph\n" + msg += "--------------------------\n" + msg += graph.ir(join_entry_attrs=["shape"]) + raise RuntimeError(msg) + return f if isinstance( + f, (tvm.container.Array, tuple, list)) else [f] + + +@reg.register_compute("clip", level=11) +def compute_clip(attrs, inputs, _): + """ Clip operator. + """ + x = inputs[0] + a_min = attrs.get_float("a_min") + a_max = attrs.get_float("a_max") + const_min = tvm.const(a_min, x.dtype) + const_max = tvm.const(a_max, x.dtype) + with tvm.tag_scope(topi.tag.ELEMWISE): + x = tvm.compute( + x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA") + x = tvm.compute( + x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB") + return x + + +reg.register_pattern("identity", OpPattern.INJECTIVE, level=11) + +@reg.register_compute("quantized_conv2d", level=11) +def compute_quantized_conv2d(attrs, inputs, out): + """ 2D convolution algorithm. + """ + padding = attrs.get_int_tuple("padding") + strides = attrs.get_int_tuple("strides") + dilation = attrs.get_int_tuple("dilation") + groups = attrs.get_int("groups") + channels = attrs.get_int("channels") + layout = attrs["layout"] + out_dtype = attrs['out_type'] + cmp_dtype = 'int32' # compute data type + + assert layout == "NCHW", "only support nchw for now" + assert dilation == (1, 1), "not support dilate now" + assert attrs.get_bool("use_bias") is False + pack_channel = attrs.get_int("pack_channel") + if pack_channel != 0: + assert groups == 1 + return packed_conv2d(inputs[0], inputs[1], + padding, strides) + if groups == 1: + out = topi.nn.conv2d(inputs[0], inputs[1], strides, padding, out_dtype=cmp_dtype) + elif groups == get_const_int(inputs[0].shape[1]) and groups == channels: + out = topi.nn.depthwise_conv2d_nchw( + inputs[0], inputs[1], strides, padding, out_dtype=cmp_dtype) + else: + raise ValueError("not support arbitrary group number for now") + + assert out_dtype == cmp_dtype + return out + + +@reg.register_schedule("quantized_conv2d", level=11) +def schedule_quantized_conv2d(attrs, outs, target): + """ 2D convolution schedule. + """ + channels = attrs.get_int("channels") + pack_channel = attrs.get_int("pack_channel") + if channels != 0 and pack_channel: + target = tvm.target.create(target) + if target.device_name == "vta": + return schedule_packed_conv2d(outs) + elif target.startswith("llvm"): + return tvm.create_schedule([x.op for x in outs]) + else: + raise RuntimeError("not support target %s" % target) + with tvm.target.create(target): + return topi.generic.schedule_conv2d_nchw(outs) + + +def _get_workload(data, pad_data, kernel, output): + """ Get the workload structure. + """ + o_shape = topi.util.get_const_tuple(output.shape) + d_shape = topi.util.get_const_tuple(data.shape) + k_shape = topi.util.get_const_tuple(kernel.shape) + o_b, o_c, o_h, o_w, o_blk = o_shape + i_b, i_c, i_h, i_w, i_blk = d_shape + k_o, k_i, k_h, k_w, ko_blk, ki_blk = k_shape + # For now we need to assume that input channel blocking is the same + # as the output channel blocking + assert o_blk == i_blk + # Make sure that dimensions match + assert o_b == i_b + assert o_blk == ko_blk + assert i_blk == ki_blk + assert k_o == o_c + assert k_i == i_c + # Scale the channel size + i_c *= i_blk + o_c *= o_blk + if pad_data is not None: + p_shape = topi.util.get_const_tuple(pad_data.shape) + h_pad = (p_shape[2] - d_shape[2]) // 2 + w_pad = (p_shape[3] - d_shape[3]) // 2 + else: + h_pad, w_pad = 0, 0 + h_str = (i_h + h_pad*2 - k_h) // (o_h - 1) + w_str = (i_w + w_pad*2 - k_w) // (o_w - 1) + return Workload(i_h, i_w, i_c, o_c, k_h, k_w, h_pad, w_pad, h_str, w_str) + +_WL2PLAN = {} + +def schedule_packed_conv2d(outs): + """ Schedule the packed conv2d. + """ + assert len(outs) == 1 + output = outs[0] + ewise_inputs = [] + ewise_ops = [] + conv2d_res = [] + assert output.dtype == "int8" + assert output.op.input_tensors[0].dtype == "int32" + + def _traverse(op): + if topi.tag.is_broadcast(op.tag): + if not op.same_as(output.op): + ewise_ops.append(op) + for tensor in op.input_tensors: + if isinstance(tensor.op, tvm.tensor.PlaceholderOp): + ewise_inputs.append((op, tensor)) + else: + _traverse(tensor.op) + else: + assert op.tag == "packed_conv2d" + conv2d_res.append(op) + + _traverse(output.op) + assert len(conv2d_res) == 1 + conv2d_stage = conv2d_res[0].output(0) + + data, kernel = conv2d_stage.op.input_tensors + if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag: + temp = data.op.input_tensors[0] + pad_data = data + data = temp + else: + pad_data = None + wrkld = _get_workload(data, pad_data, kernel, output) + + plan = _WL2PLAN[wrkld] + load_inp = load_wgt = load_out = store_out = "dma_copy" + alu = "alu" + gevm = GEVM + + # schedule1 + oshape = topi.util.get_const_tuple(output.shape) + s = tvm.create_schedule(output.op) + + # setup pad + if pad_data is not None: + cdata = pad_data + s[pad_data].set_scope(vta.SCOPE_INP) + else: + cdata = s.cache_read(data, vta.SCOPE_INP, [conv2d_stage]) + ckernel = s.cache_read(kernel, vta.SCOPE_WGT, [conv2d_stage]) + s[conv2d_stage].set_scope(vta.SCOPE_OUT) + # cache read input + cache_read_ewise = [] + + for consumer, tensor in ewise_inputs: + cache_read_ewise.append( + s.cache_read(tensor, vta.SCOPE_OUT, [consumer])) + # set ewise scope + for op in ewise_ops: + s[op].set_scope(vta.SCOPE_OUT) + s[op].pragma(s[op].op.axis[0], alu) + + # tile + oc_factor = (plan.oc_factor if plan.oc_factor + else wrkld.out_filter // vta.VTA_BLOCK_OUT) + h_factor = (plan.h_factor if plan.h_factor else oshape[2]) + w_factor = (plan.w_factor if plan.w_factor else oshape[3]) + + x_b, x_oc, x_i, x_j, x_ic = s[output].op.axis + x_oc0, x_oc1 = s[output].split(x_oc, factor=oc_factor) + x_i0, x_i1 = s[output].split(x_i, factor=h_factor) + x_j0, x_j1 = s[output].split(x_j, factor=w_factor) + s[output].reorder(x_b, x_oc0, x_i0, x_j0, x_oc1, x_i1, x_j1, x_ic) + store_pt = x_j0 + + # set all compute scopes + s[conv2d_stage].compute_at(s[output], store_pt) + for op in ewise_ops: + s[op].compute_at(s[output], store_pt) + + for tensor in cache_read_ewise: + s[tensor].compute_at(s[output], store_pt) + s[tensor].pragma(s[tensor].op.axis[0], load_out) + + # virtual threading along output channel axes + if plan.oc_nthread: + _, v_t = s[output].split(x_oc0, factor=plan.oc_nthread) + s[output].reorder(v_t, x_b) + s[output].bind(v_t, tvm.thread_axis("cthread")) + + # virtual threading along spatial rows + if plan.h_nthread: + _, v_t = s[output].split(x_i0, factor=plan.h_nthread) + s[output].reorder(v_t, x_b) + s[output].bind(v_t, tvm.thread_axis("cthread")) + + x_b, x_oc, x_i, x_j, x_ic = s[conv2d_stage].op.axis + k_o, d_i, d_j, k_i = s[conv2d_stage].op.reduce_axis + s[conv2d_stage].reorder(k_o, x_j, d_j, d_i, x_oc, x_i, x_ic, k_i) + + if plan.ko_factor: + k_o, _ = s[conv2d_stage].split(k_o, factor=plan.ko_factor) + s[cdata].compute_at(s[conv2d_stage], k_o) + s[ckernel].compute_at(s[conv2d_stage], k_o) + + # Use VTA instructions + s[cdata].pragma(s[cdata].op.axis[0], load_inp) + s[ckernel].pragma(s[ckernel].op.axis[0], load_wgt) + s[conv2d_stage].tensorize(x_ic, gevm) + s[output].pragma(x_oc1, store_out) + return s + + +class Conv2DSchedule(object): + """ 2D convolution schedule object. + """ + def __init__(self, + oc_factor, + ko_factor=1, + h_factor=1, + w_factor=0, + oc_nthread=0, + h_nthread=0): + self.oc_factor = oc_factor + self.ko_factor = ko_factor + self.h_factor = h_factor + self.w_factor = w_factor + self.oc_nthread = oc_nthread + self.h_nthread = h_nthread + +Schedule = Conv2DSchedule + +# ResNet18 workloads +RESNET = { + # Workloads of resnet18 on imagenet + 0: Workload(224, 224, 16, 64, 7, 7, 3, 3, 2, 2), + 1: Workload(56, 56, 64, 64, 3, 3, 1, 1, 1, 1), + 2: Workload(56, 56, 64, 64, 1, 1, 0, 0, 1, 1), + 3: Workload(56, 56, 64, 128, 3, 3, 1, 1, 2, 2), + 4: Workload(56, 56, 64, 128, 1, 1, 0, 0, 2, 2), + 5: Workload(28, 28, 128, 128, 3, 3, 1, 1, 1, 1), + 6: Workload(28, 28, 128, 256, 3, 3, 1, 1, 2, 2), + 7: Workload(28, 28, 128, 256, 1, 1, 0, 0, 2, 2), + 8: Workload(14, 14, 256, 256, 3, 3, 1, 1, 1, 1), + 9: Workload(14, 14, 256, 512, 3, 3, 1, 1, 2, 2), + 10: Workload(14, 14, 256, 512, 1, 1, 0, 0, 2, 2), + 11: Workload(7, 7, 512, 512, 3, 3, 1, 1, 1, 1), +} + +# Serial schedule +RESNET_SERIAL = { + RESNET[0]: Schedule(oc_factor=1, ko_factor=1, h_factor=4, w_factor=56), + RESNET[1]: Schedule(oc_factor=2, ko_factor=1, h_factor=14, w_factor=0), + RESNET[2]: Schedule(oc_factor=4, ko_factor=4, h_factor=8, w_factor=0), + RESNET[3]: Schedule(oc_factor=4, ko_factor=1, h_factor=14, w_factor=0), + RESNET[4]: Schedule(oc_factor=8, ko_factor=1, h_factor=4, w_factor=0), + RESNET[5]: Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0), + RESNET[6]: Schedule(oc_factor=8, ko_factor=1, h_factor=14, w_factor=0), + RESNET[7]: Schedule(oc_factor=16, ko_factor=1, h_factor=7, w_factor=0), + RESNET[8]: Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0), + RESNET[9]: Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0), + RESNET[10]: Schedule(oc_factor=16, ko_factor=1, h_factor=7, w_factor=0), + RESNET[11]: Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0), +} + +# Latency hiding schedule +RESNET_OPT = { + RESNET[0]: Schedule(oc_factor=1, ko_factor=1, h_factor=4, w_factor=56), + RESNET[1]: Schedule(oc_factor=2, ko_factor=1, h_factor=7, h_nthread=2), + RESNET[2]: Schedule(oc_factor=4, ko_factor=2, h_factor=4, w_factor=0, h_nthread=2), + RESNET[3]: Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, h_nthread=2), + RESNET[4]: Schedule(oc_factor=4, ko_factor=1, h_factor=7, h_nthread=2), + RESNET[5]: Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, h_nthread=2), + RESNET[6]: Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2), + RESNET[7]: Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2), + RESNET[8]: Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2), + RESNET[9]: Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2), + RESNET[10]: Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2), + RESNET[11]: Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2), +} + +_WL2PLAN = RESNET_OPT diff --git a/vta/src/pynq/pynq_driver.h b/vta/src/pynq/pynq_driver.h index 952c4cff8..481df6bbe 100644 --- a/vta/src/pynq/pynq_driver.h +++ b/vta/src/pynq/pynq_driver.h @@ -80,4 +80,4 @@ void xlnkInvalidateCache(void* buf, int size); #ifdef __cplusplus } #endif -#endif // VTA_PYNQ_PYNQ_DRIVER_H_ \ No newline at end of file +#endif // VTA_PYNQ_PYNQ_DRIVER_H_ diff --git a/vta/src/runtime.cc b/vta/src/runtime.cc index dde88e8cc..7c5708b4e 100644 --- a/vta/src/runtime.cc +++ b/vta/src/runtime.cc @@ -1043,9 +1043,9 @@ class CommandQueue { VTAWriteMappedReg(vta_load_handle_, 0x10, 0); // LOAD @ 0x18 : Data signal of weight_V VTAWriteMappedReg(vta_load_handle_, 0x18, 0); - // COMPUTE @ 0x10 : Data signal of uops_V + // COMPUTE @ 0x20 : Data signal of uops_V VTAWriteMappedReg(vta_compute_handle_, 0x20, 0); - // COMPUTE @ 0x18 : Data signal of biases_V + // COMPUTE @ 0x28 : Data signal of biases_V VTAWriteMappedReg(vta_compute_handle_, 0x28, 0); // STORE @ 0x10 : Data signal of outputs_V VTAWriteMappedReg(vta_store_handle_, 0x10, 0); diff --git a/vta/tests/hardware/common/test_lib.h b/vta/tests/hardware/common/test_lib.h index 037e2fcee..458ff7138 100644 --- a/vta/tests/hardware/common/test_lib.h +++ b/vta/tests/hardware/common/test_lib.h @@ -39,7 +39,7 @@ uint64_t vta( #else // NO_SIM -#include "../../../hardware/vivado/src/vta.h" +#include "../../../hardware/xilinx/src/vta.h" #endif // NO_SIM diff --git a/vta/tests/hardware/pynq/Makefile b/vta/tests/hardware/pynq/Makefile index 7e70366f3..dabf55e26 100644 --- a/vta/tests/hardware/pynq/Makefile +++ b/vta/tests/hardware/pynq/Makefile @@ -1,6 +1,6 @@ CC ?= g++ CFLAGS = -Wall -O3 -std=c++11 -I/usr/include -LDFLAGS = -L/usr/lib -L/home/xilinx/pynq/drivers +LDFLAGS = -L/usr/lib -L/opt/python3.6/lib/python3.6/site-packages/pynq/lib/ LIBS = -l:libsds_lib.so -l:libdma.so INCLUDE_DIR = ../../../include DRIVER_DIR = ../../../src/pynq diff --git a/vta/tests/python/pynq/test_benchmark_conv2d.py b/vta/tests/python/pynq/test_benchmark_conv2d.py new file mode 100644 index 000000000..1ffb3e73d --- /dev/null +++ b/vta/tests/python/pynq/test_benchmark_conv2d.py @@ -0,0 +1,414 @@ +import os +import tvm +import mxnet as mx +import vta +import numpy as np +import topi +from collections import namedtuple +from tvm.contrib import rpc, util +import pandas as pd + +host = "pynq" +port = 9091 +target = "llvm -target=armv7-none-linux-gnueabihf -mattr=+neon" +out_dtype = "int%d" % vta.VTA_OUT_WIDTH +inp_dtype = "int%d" % vta.VTA_INP_WIDTH +wgt_dtype = "int%d" % vta.VTA_WGT_WIDTH + +Workload = namedtuple("Conv2DWorkload", + ['height', 'width', 'in_filter', 'out_filter', + 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride']) + +class Conv2DSchedule(object): + def __init__(self, + oc_factor, + ko_factor=1, + h_factor=1, + w_factor=0, + oc_nthread=0, + h_nthread=0, + debug_sync=False): + self.oc_factor = oc_factor + self.ko_factor = ko_factor + self.h_factor = h_factor + self.w_factor = w_factor + self.oc_nthread = oc_nthread + self.h_nthread = h_nthread + self.debug_sync = debug_sync + +Schedule = Conv2DSchedule + +def test_conv2d_chwv(key, batch_size, wl, plan, log_frame, profile=True): + assert batch_size % vta.VTA_BATCH == 0 + assert wl.in_filter % vta.VTA_BLOCK_IN == 0 + assert wl.out_filter % vta.VTA_BLOCK_OUT == 0 + data_shape = (batch_size//vta.VTA_BATCH, wl.in_filter//vta.VTA_BLOCK_IN, + wl.height, wl.width, vta.VTA_BATCH, vta.VTA_BLOCK_IN) + kernel_shape = (wl.out_filter//vta.VTA_BLOCK_OUT, wl.in_filter//vta.VTA_BLOCK_IN, + wl.hkernel, wl.wkernel, vta.VTA_BLOCK_OUT, vta.VTA_BLOCK_IN) + fout_height = (wl.height + 2 * wl.hpad - wl.hkernel) // wl.hstride + 1 + fout_width = (wl.width + 2 * wl.wpad - wl.wkernel) // wl.wstride + 1 + res_shape = (batch_size//vta.VTA_BATCH, wl.out_filter//vta.VTA_BLOCK_OUT, + fout_height, fout_width, vta.VTA_BATCH, vta.VTA_BLOCK_OUT) + data = tvm.placeholder(data_shape, name="data", dtype=inp_dtype) + kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=wgt_dtype) + if wl.hpad or wl.wpad: + data_buf = topi.nn.pad(data, [0, 0, wl.hpad, wl.wpad, 0, 0], name="data_buf") + else: + data_buf = tvm.compute(data_shape, lambda *i: data(*i), "data_buf") + kernel_buf = tvm.compute(kernel_shape, lambda *i: kernel(*i), "kernel_buf") + di = tvm.reduce_axis((0, wl.hkernel), name='di') + dj = tvm.reduce_axis((0, wl.wkernel), name='dj') + ko = tvm.reduce_axis((0, wl.in_filter//vta.VTA_BLOCK_IN), name='ko') + ki = tvm.reduce_axis((0, vta.VTA_BLOCK_IN), name='ki') + res_cnv = tvm.compute( + res_shape, + lambda bo, co, i, j, bi, ci: tvm.sum( + data_buf[bo, ko, i*wl.hstride+di, j*wl.wstride+dj, bi, ki].astype(out_dtype) * + kernel_buf[co, ko, di, dj, ci, ki].astype(out_dtype), + axis=[ko, di, dj, ki]), + name="res_cnv") + res_shf = tvm.compute(res_shape, lambda *i: res_cnv(*i) >> 8, name="res_shf") + res = tvm.compute(res_shape, lambda *i: res_shf(*i).astype(inp_dtype), name="res") + num_ops = batch_size * fout_height * fout_width * wl.hkernel * wl.wkernel * wl.out_filter * wl.in_filter + + def verify(s, check_correctness): + mod = tvm.build(s, [data, kernel, res], "ext_dev", target, name="conv2d") + temp = util.tempdir() + remote = rpc.connect(host, port) + mod.save(temp.relpath("conv2d.o")) + remote.upload(temp.relpath("conv2d.o")) + f = remote.load_module("conv2d.o") + # verify + ctx = remote.ext_dev(0) + # Data in original format + data_orig = np.random.randint( + -128, 128, size=(batch_size, wl.in_filter, wl.height, wl.width)).astype(data.dtype) + kernel_orig = np.random.randint( + -128, 128, size=(wl.out_filter, wl.in_filter, wl.hkernel, wl.wkernel)).astype(kernel.dtype) + data_packed = data_orig.reshape( + batch_size//vta.VTA_BATCH, vta.VTA_BATCH, + wl.in_filter//vta.VTA_BLOCK_IN, vta.VTA_BLOCK_IN, + wl.height, wl.width).transpose((0, 2, 4, 5, 1, 3)) + kernel_packed = kernel_orig.reshape( + wl.out_filter//vta.VTA_BLOCK_OUT, vta.VTA_BLOCK_OUT, + wl.in_filter//vta.VTA_BLOCK_IN, vta.VTA_BLOCK_IN, + wl.hkernel, wl.wkernel).transpose((0, 2, 4, 5, 1, 3)) + res_np = np.zeros(res_shape).astype(res.dtype) + data_arr = tvm.nd.array(data_packed, ctx) + kernel_arr = tvm.nd.array(kernel_packed, ctx) + res_arr = tvm.nd.array(res_np, ctx) + time_f = f.time_evaluator("conv2d", ctx, number=10) + cost = time_f(data_arr, kernel_arr, res_arr) + res_unpack = res_arr.asnumpy().transpose( + (0, 4, 1, 5, 2, 3)).reshape(batch_size, wl.out_filter, fout_height, fout_width) + if check_correctness: + res_ref = mx.nd.Convolution( + mx.nd.array(data_orig.astype(out_dtype), mx.cpu(0)), + mx.nd.array(kernel_orig.astype(out_dtype), mx.cpu(0)), + stride=(wl.hstride, wl.wstride), + kernel=(wl.hkernel, wl.wkernel), + num_filter=wl.out_filter, + no_bias=True, + pad=(wl.hpad, wl.wpad)).asnumpy().astype(out_dtype) + res_ref = np.right_shift(res_ref, 8).astype(res.dtype) + np.testing.assert_allclose(res_unpack, res_ref) + print("Correctness check pass...") + return cost + + def run_schedule(load_inp, load_wgt, gemm, alu, store_out, + print_ir, check_correctness): + # schedule1 + s = tvm.create_schedule(res.op) + s[data_buf].set_scope(vta.SCOPE_INP) + s[kernel_buf].set_scope(vta.SCOPE_WGT) + s[res_cnv].set_scope(vta.SCOPE_OUT) + s[res_shf].set_scope(vta.SCOPE_OUT) + # tile + oc_factor = (plan.oc_factor if plan.oc_factor + else wl.out_filter // vta.VTA_BLOCK_OUT) + h_factor = (plan.h_factor if plan.h_factor else fout_height) + w_factor = (plan.w_factor if plan.w_factor else fout_width) + xbo, xco, xi, xj, xbi, xci = s[res].op.axis + xco0, xco1 = s[res].split(xco, factor=oc_factor) + xi0, xi1 = s[res].split(xi, factor=h_factor) + xj0, xj1 = s[res].split(xj, factor=w_factor) + s[res].reorder(xbo, xi0, xco0, xj0, xco1, xi1, xj1, xbi, xci) + s[res_cnv].compute_at(s[res], xj0) + s[res_shf].compute_at(s[res], xj0) + + if plan.oc_nthread: + _, tx = s[res].split(xco0, factor=plan.oc_nthread) + s[res].reorder(tx, xbo) + s[res].bind(tx, tvm.thread_axis("cthread")) + + if plan.h_nthread: + xo, tx = s[res].split(xi0, factor=plan.h_nthread) + s[res].reorder(tx, xbo) + s[res].bind(tx, tvm.thread_axis("cthread")) + + xbo, xco, xi, xj, xbi, xci = s[res_cnv].op.axis + s[res_cnv].reorder(xbo, ko, xj, dj, di, xco, xi, xbi, xci, ki) + + if plan.ko_factor: + ko0, ko1 = s[res_cnv].split(ko, factor=plan.ko_factor) + s[data_buf].compute_at(s[res_cnv], ko0) + s[kernel_buf].compute_at(s[res_cnv], ko0) + # Use VTA instructions + s[data_buf].pragma(s[data_buf].op.axis[0], load_inp) + s[kernel_buf].pragma(s[kernel_buf].op.axis[0], load_wgt) + s[res_cnv].tensorize(xbi, gemm) + s[res_shf].pragma(s[res_shf].op.axis[0], alu) + s[res].pragma(xco1, store_out) + if plan.debug_sync: + s[res].pragma(xco0, "coproc_sync") + if print_ir: + print(tvm.lower(s, [data, kernel, res], simple_mode=True)) + return verify(s, check_correctness) + + def conv_normal(print_ir): + print("----- CONV2D End-to-End Test-------") + def run_test(header, print_ir, check_correctness): + cost = run_schedule( + vta.DMA_COPY, vta.DMA_COPY, + vta.GEMM, vta.ALU, vta.DMA_COPY, + print_ir, check_correctness) + gops = (num_ops / cost.mean) / float(10 ** 9) + print(header) + print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops)) + log_frame["key"].append(key) + log_frame["total-gops"].append(gops) + log_frame["total-cost"].append(cost.mean) + + with tvm.build_config(add_lower_pass=vta.debug_mode(0)): + run_test("NORMAL", print_ir, True) + + def skip_alu_unittest(print_ir): + mock = vta.mock + print("----- Skip ALU Unit Test-------") + def run_test(header, print_ir): + cost = run_schedule( + vta.DMA_COPY, vta.DMA_COPY, + vta.GEMM, mock.ALU, vta.DMA_COPY, + print_ir, False) + gops = (num_ops / cost.mean) / float(10 ** 9) + print(header) + print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops)) + log_frame["skip-alu-gops"].append(gops) + log_frame["skip-alu-cost"].append(cost.mean) + + with tvm.build_config(add_lower_pass=vta.debug_mode(0)): + run_test("NORMAL", print_ir) + print("") + + def gemm_unittest(print_ir): + mock = vta.mock + print("----- GEMM Unit Test-------") + def run_test(header, print_ir): + cost = run_schedule( + mock.DMA_COPY, mock.DMA_COPY, + vta.GEMM, mock.ALU, mock.DMA_COPY, + print_ir, False) + gops = (num_ops / cost.mean) / float(10 ** 9) + print(header) + print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops)) + log_frame["gemm-gops"].append(gops) + log_frame["gemm-cost"].append(cost.mean) + with tvm.build_config(add_lower_pass=vta.debug_mode(0)): + run_test("NORMAL", print_ir) + print("") + + def alu_unittest(print_ir): + mock = vta.mock + print("----- ALU Unit Test-------") + def run_test(header, print_ir): + cost = run_schedule( + mock.DMA_COPY, mock.DMA_COPY, + mock.GEMM, vta.ALU, mock.DMA_COPY, + print_ir, False) + gops = (num_ops / cost.mean) / float(10 ** 9) + print(header) + print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops)) + log_frame["alu-gops"].append(gops) + log_frame["alu-cost"].append(cost.mean) + with tvm.build_config(add_lower_pass=vta.debug_mode(0)): + run_test("NORMAL", print_ir) + print("") + + def load_inp_unittest(print_ir): + mock = vta.mock + print("----- LoadInp Unit Test-------") + def run_test(header, print_ir): + cost = run_schedule( + vta.DMA_COPY, mock.DMA_COPY, + mock.GEMM, mock.ALU, mock.DMA_COPY, + print_ir, False) + gops = (num_ops / cost.mean) / float(10 ** 9) + bandwith = (batch_size * wl.in_filter * wl.height * + wl.width * vta.INP_WIDTH / cost.mean) / float(10 ** 9) + print(header) + print("\tTime cost = %g sec/op, %g GFLOPS, bandwith=%g gbits" % ( + cost.mean, gops, bandwith)) + log_frame["ld-inp-gbits"].append(bandwith) + log_frame["ld-inp-cost"].append(cost.mean) + with tvm.build_config(add_lower_pass=vta.debug_mode(0)): + run_test("NORMAL", print_ir) + print("") + + def load_wgt_unittest(print_ir): + mock = vta.mock + print("----- LoadWgt Unit Test-------") + def run_test(header, print_ir): + cost = run_schedule( + mock.DMA_COPY, vta.DMA_COPY, + mock.GEMM, mock.ALU, mock.DMA_COPY, print_ir, + False) + gops = (num_ops / cost.mean) / float(10 ** 9) + bandwith = (wl.out_filter * wl.in_filter * wl.hkernel * + wl.wkernel * vta.WGT_WIDTH / cost.mean) / float(10 ** 9) + print(header) + print("\tTime cost = %g sec/op, %g GFLOPS, bandwith=%g gbits" % ( + cost.mean, gops, bandwith)) + log_frame["ld-wgt-gbits"].append(bandwith) + log_frame["ld-wgt-cost"].append(cost.mean) + with tvm.build_config(add_lower_pass=vta.debug_mode(0)): + run_test("NORMAL", print_ir) + print("") + + def store_out_unittest(print_ir): + mock = vta.mock + print("----- StoreOut Unit Test-------") + def run_test(header, print_ir): + cost = run_schedule( + mock.DMA_COPY, mock.DMA_COPY, + mock.GEMM, mock.ALU, vta.DMA_COPY, print_ir, + False) + gops = (num_ops / cost.mean) / float(10 ** 9) + bandwith = (batch_size * wl.out_filter * fout_height * + fout_width * vta.OUT_WIDTH / cost.mean) / float(10 ** 9) + print(header) + print("\tTime cost = %g sec/op, %g GFLOPS, bandwith=%g gbits" % ( + cost.mean, gops, bandwith)) + log_frame["st-out-gbits"].append(bandwith) + log_frame["st-out-cost"].append(cost.mean) + with tvm.build_config(add_lower_pass=vta.debug_mode(0)): + run_test("NORMAL", print_ir) + print("") + + def manual_unittest(print_ir): + # Manual section used to teak the components + mock = vta.mock + print("----- Manual Unit Test-------") + def run_test(header, print_ir): + cost = run_schedule( + vta.DMA_COPY, vta.DMA_COPY, + vta.GEMM, vta.ALU, mock.DMA_COPY, print_ir, + False) + gops = (num_ops / cost.mean) / float(10 ** 9) + print(header) + print("\tTime cost = %g sec/op, %g GFLOPS" % ( + cost.mean, gops)) + with tvm.build_config(add_lower_pass=vta.debug_mode(0)): + run_test("NORMAL", print_ir) + print("") + + print("=================================") + print("key=%s" % key) + print(wl) + conv_normal(False) + if not profile: + return + skip_alu_unittest(False) + gemm_unittest(False) + alu_unittest(False) + load_inp_unittest(False) + load_wgt_unittest(False) + store_out_unittest(False) + +# ResNet18 workloads +resnet = { + # Workloads of resnet18 on imagenet + 0: Workload(224, 224, 16, 64, 7, 7, 3, 3, 2, 2), + 1: Workload(56, 56, 64, 64, 3, 3, 1, 1, 1, 1), + 2: Workload(56, 56, 64, 64, 1, 1, 0, 0, 1, 1), + 3: Workload(56, 56, 64, 128, 3, 3, 1, 1, 2, 2), + 4: Workload(56, 56, 64, 128, 1, 1, 0, 0, 2, 2), + 5: Workload(28, 28, 128, 128, 3, 3, 1, 1, 1, 1), + 6: Workload(28, 28, 128, 256, 3, 3, 1, 1, 2, 2), + 7: Workload(28, 28, 128, 256, 1, 1, 0, 0, 2, 2), + 8: Workload(14, 14, 256, 256, 3, 3, 1, 1, 1, 1), + 9: Workload(14, 14, 256, 512, 3, 3, 1, 1, 2, 2), + 10: Workload(14, 14, 256, 512, 1, 1, 0, 0, 2, 2), + 11: Workload(7, 7, 512, 512, 3, 3, 1, 1, 1, 1), +} + +# List of simple benchmarks +simple = [ + Workload(height=22, width=22, in_filter=256, out_filter=64, + hkernel=3, wkernel=3, hpad=1, wpad=1, hstride=1, wstride=1) +] + +# Serial schedule +resnet_serial = [ + [None, None], + [resnet[1], Schedule(oc_factor=2, ko_factor=1, h_factor=14, w_factor=0)], + [resnet[2], Schedule(oc_factor=4, ko_factor=4, h_factor=8, w_factor=0)], + [resnet[3], Schedule(oc_factor=4, ko_factor=1, h_factor=14, w_factor=0)], + [resnet[4], Schedule(oc_factor=8, ko_factor=1, h_factor=4, w_factor=0)], + [resnet[5], Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0)], + [resnet[6], Schedule(oc_factor=8, ko_factor=1, h_factor=14, w_factor=0)], + [resnet[7], Schedule(oc_factor=16, ko_factor=1, h_factor=7, w_factor=0)], + [resnet[8], Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0)], + [resnet[9], Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0)], + [resnet[10], Schedule(oc_factor=16, ko_factor=1, h_factor=7, w_factor=0)], + [resnet[11], Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0)], +] + +# SMT schedule +resnet_smt = [ + [resnet[0], Schedule(oc_factor=1, ko_factor=1, h_factor=4, w_factor=56)], + [resnet[1], Schedule(oc_factor=2, ko_factor=1, h_factor=7, h_nthread=2)], + [resnet[2], Schedule(oc_factor=4, ko_factor=2, h_factor=4, w_factor=0, h_nthread=2)], + [resnet[3], Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, h_nthread=2)], + [resnet[4], Schedule(oc_factor=4, ko_factor=1, h_factor=7, h_nthread=2)], + [resnet[5], Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, h_nthread=2)], + [resnet[6], Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2)], + [resnet[7], Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2)], + [resnet[8], Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2)], + [resnet[9], Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2)], + [resnet[10], Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2)], + [resnet[11], Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2)], +] + +# Perform profiling +profile = False +# Whether use SMT +use_smt = True +# Data set batch size +batch_size = 1 + +resnet_schedule = resnet_smt if use_smt else resnet_serial + +begin = 0 +end = len(resnet_schedule) +keys = ["key", "total-gops", "total-cost", + "skip-alu-gops", "skip-alu-cost", + "gemm-gops", "gemm-cost", "alu-gops", "alu-cost", + "ld-inp-cost", "ld-wgt-cost", "st-out-cost", + "ld-inp-gbits", "ld-wgt-gbits", "st-out-gbits",] +log_frame = { + k : [] for k in keys +} +for i, x in enumerate(resnet_schedule[begin:end]): + wl, plan = x + if not wl: + continue + key = "resnet-cfg[%d]" % i + test_conv2d_chwv(key, batch_size, wl, plan, log_frame, profile) + +if profile: + pd.set_option('expand_frame_repr', False) + log_df = pd.DataFrame() + for k in keys: + log_df[k] = log_frame[k] + print(log_df) diff --git a/vta/tests/python/pynq/test_benchmark_gemm.py b/vta/tests/python/pynq/test_benchmark_gemm.py new file mode 100644 index 000000000..e5180c144 --- /dev/null +++ b/vta/tests/python/pynq/test_benchmark_gemm.py @@ -0,0 +1,267 @@ +import os +import tvm +import vta +import numpy as np +import time +from tvm.contrib import rpc, util + +host = "pynq" +port = 9091 +target = "llvm -target=armv7-none-linux-gnueabihf" +out_dtype = "int%d" % vta.VTA_OUT_WIDTH +inp_dtype = "int%d" % vta.VTA_INP_WIDTH +wgt_dtype = "int%d" % vta.VTA_WGT_WIDTH + +def test_gemm_packed(batch_size, channel, block): + data_shape = (batch_size//vta.VTA_BATCH, + channel//vta.VTA_BLOCK_IN, + vta.VTA_BATCH, + vta.VTA_BLOCK_IN) + weight_shape = (channel//vta.VTA_BLOCK_OUT, + channel//vta.VTA_BLOCK_IN, + vta.VTA_BLOCK_OUT, + vta.VTA_BLOCK_IN) + res_shape = (batch_size//vta.VTA_BATCH, + channel//vta.VTA_BLOCK_OUT, + vta.VTA_BATCH, + vta.VTA_BLOCK_OUT) + num_ops = channel * channel * batch_size + + ko = tvm.reduce_axis((0, channel//vta.VTA_BLOCK_IN), name='ko') + ki = tvm.reduce_axis((0, vta.VTA_BLOCK_IN), name='ki') + + data = tvm.placeholder(data_shape, + name="data", + dtype=inp_dtype) + weight = tvm.placeholder(weight_shape, + name="weight", + dtype=wgt_dtype) + data_buf = tvm.compute(data_shape, + lambda *i: data(*i), + "data_buf") + weight_buf = tvm.compute(weight_shape, + lambda *i: weight(*i), + "weight_buf") + res_gem = tvm.compute(res_shape, + lambda bo, co, bi, ci: tvm.sum( + data_buf[bo, ko, bi, ki].astype(out_dtype) * + weight_buf[co, ko, ci, ki].astype(out_dtype), + axis=[ko, ki]), + name="res_gem") + res_shf = tvm.compute(res_shape, + lambda *i: res_gem(*i)>>8, + name="res_shf") + res_max = tvm.compute(res_shape, + lambda *i: tvm.max(res_shf(*i), 0), + "res_max") #relu + res_min = tvm.compute(res_shape, + lambda *i: tvm.min(res_max(*i), (1<<(vta.VTA_INP_WIDTH-1))-1), + "res_min") #relu + res = tvm.compute(res_shape, + lambda *i: res_min(*i).astype(inp_dtype), + name="res") + + def verify(s, check_correctness=True): + mod = tvm.build(s, [data, weight, res], "ext_dev", target, name="gemm") + temp = util.tempdir() + remote = rpc.connect(host, port) + mod.save(temp.relpath("gemm.o")) + remote.upload(temp.relpath("gemm.o")) + f = remote.load_module("gemm.o") + # verify + ctx = remote.ext_dev(0) + # Data in original format + data_orig = np.random.randint( + -128, 128, size=(batch_size, channel)).astype(data.dtype) + weight_orig = np.random.randint( + -128, 128, size=(channel, channel)).astype(weight.dtype) + data_packed = data_orig.reshape( + batch_size//vta.VTA_BATCH, vta.VTA_BATCH, + channel//vta.VTA_BLOCK_IN, vta.VTA_BLOCK_IN).transpose((0, 2, 1, 3)) + weight_packed = weight_orig.reshape( + channel//vta.VTA_BLOCK_OUT, vta.VTA_BLOCK_OUT, + channel//vta.VTA_BLOCK_IN, vta.VTA_BLOCK_IN).transpose((0, 2, 1, 3)) + res_np = np.zeros(res_shape).astype(res.dtype) + data_arr = tvm.nd.array(data_packed, ctx) + weight_arr = tvm.nd.array(weight_packed, ctx) + res_arr = tvm.nd.array(res_np, ctx) + res_ref = np.zeros(res_shape).astype(out_dtype) + for b in range(batch_size//vta.VTA_BATCH): + for i in range(channel//vta.VTA_BLOCK_OUT): + for j in range(channel//vta.VTA_BLOCK_IN): + res_ref[b,i,:] += np.dot(data_packed[b,j,:].astype(out_dtype), + weight_packed[i,j].T.astype(out_dtype)) + res_ref = np.right_shift(res_ref, 8) + res_ref = np.clip(res_ref, 0, (1<<(vta.VTA_INP_WIDTH-1))-1).astype(res.dtype) + time_f = f.time_evaluator("gemm", ctx, number=20) + cost = time_f(data_arr, weight_arr, res_arr) + res_unpack = res_arr.asnumpy().reshape(batch_size//vta.VTA_BATCH, + channel//vta.VTA_BLOCK_OUT, + vta.VTA_BATCH, + vta.VTA_BLOCK_OUT) + if check_correctness: + np.testing.assert_allclose(res_unpack, res_ref) + return cost + + def run_schedule(load_inp, + load_wgt, + gemm, + alu, + store_out, + print_ir, + check_correctness): + s = tvm.create_schedule(res.op) + s[data_buf].set_scope(vta.SCOPE_INP) + s[weight_buf].set_scope(vta.SCOPE_WGT) + s[res_gem].set_scope(vta.SCOPE_OUT) + s[res_shf].set_scope(vta.SCOPE_OUT) + s[res_min].set_scope(vta.SCOPE_OUT) + s[res_max].set_scope(vta.SCOPE_OUT) + + if block: + bblock = block // vta.VTA_BATCH + iblock = block // vta.VTA_BLOCK_IN + oblock = block // vta.VTA_BLOCK_OUT + xbo, xco, xbi, xci = s[res].op.axis + xb1, xco1, xb2, xco2 = s[res].tile(xbo, xco, bblock, oblock) + store_pt = xb2 + + s[res_gem].compute_at(s[res], xco1) + s[res_shf].compute_at(s[res], xco1) + s[res_min].compute_at(s[res], xco1) + s[res_max].compute_at(s[res], xco1) + + xbo, xco, xbi, xci = s[res_gem].op.axis + # Compute one line at a time + ko1, ko2 = s[res_gem].split(ko, iblock) + s[res_gem].reorder(ko1, ko2, xbo, xco, xbi, xci, ki) + s[data_buf].compute_at(s[res_gem], ko1) + s[weight_buf].compute_at(s[res_gem], ko1) + # Use VTA instructions + s[data_buf].pragma(s[data_buf].op.axis[0], load_inp) + s[weight_buf].pragma(s[weight_buf].op.axis[0], load_wgt) + s[res_gem].tensorize(xbi, gemm) + s[res_shf].pragma(s[res_shf].op.axis[0], alu) + s[res_min].pragma(s[res_min].op.axis[0], alu) + s[res_max].pragma(s[res_max].op.axis[0], alu) + s[res].pragma(store_pt, store_out) + else: + xbo, xco, xbi, xci = s[res_gem].op.axis + s[res_gem].reorder(ko, xbo, xco, xbi, xci, ki) + # Use VTA instructions + s[data_buf].pragma(s[data_buf].op.axis[0], load_inp) + s[weight_buf].pragma(s[weight_buf].op.axis[0], load_wgt) + s[res_gem].tensorize(xbi, gemm) + s[res_shf].pragma(s[res_shf].op.axis[0], alu) + s[res_min].pragma(s[res_min].op.axis[0], alu) + s[res_max].pragma(s[res_max].op.axis[0], alu) + s[res].pragma(s[res].op.axis[0], store_out) + + if print_ir: + print(tvm.lower(s, [data, weight, res], simple_mode=True)) + return verify(s, check_correctness) + + def gemm_normal(print_ir): + mock = vta.mock + print("----- GEMM GFLOPS End-to-End Test-------") + def run_test(header, print_ir, check_correctness): + cost = run_schedule( + vta.DMA_COPY, vta.DMA_COPY, vta.GEMM, vta.ALU, vta.DMA_COPY, + print_ir, check_correctness) + gops = (num_ops / cost.mean) / float(10 ** 9) + print(header) + print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops)) + with tvm.build_config(add_lower_pass=vta.debug_mode(vta.DEBUG_DUMP_INSN)): + run_test("NORMAL", print_ir, True) + + print("") + + def gevm_unittest(print_ir): + mock = vta.mock + print("----- GEMM Unit Test-------") + def run_test(header, print_ir): + cost = run_schedule( + mock.DMA_COPY, mock.DMA_COPY, vta.GEMM, mock.ALU, mock.DMA_COPY, + print_ir, False) + gops = (num_ops / cost.mean) / float(10 ** 9) + print(header) + print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops)) + with tvm.build_config(add_lower_pass=vta.debug_mode(0)): + run_test("NORMAL", print_ir) + print("") + + def alu_unittest(print_ir): + mock = vta.mock + print("----- ALU Unit Test-------") + def run_test(header, print_ir): + cost = run_schedule( + mock.DMA_COPY, mock.DMA_COPY, mock.GEMM, vta.ALU, mock.DMA_COPY, + print_ir, False) + gops = (num_ops / cost.mean) / float(10 ** 9) + print(header) + print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops)) + with tvm.build_config(add_lower_pass=vta.debug_mode(0)): + run_test("NORMAL", print_ir) + print("") + + def load_inp_unittest(print_ir): + mock = vta.mock + print("----- LoadInp Unit Test-------") + def run_test(header, print_ir): + cost = run_schedule( + vta.DMA_COPY, mock.DMA_COPY, mock.GEMM, mock.ALU, mock.DMA_COPY, print_ir, False) + gops = (num_ops / cost.mean) / float(10 ** 9) + bandwith = (batch_size * channel * vta.VTA_INP_WIDTH / cost.mean) / float(10 ** 9) + print(header) + print("\tTime cost = %g sec/op, %g GFLOPS, bandwidth=%g Gbits" % ( + cost.mean, gops, bandwith)) + with tvm.build_config(add_lower_pass=vta.debug_mode(0)): + run_test("NORMAL", print_ir) + print("") + + def load_wgt_unittest(print_ir): + mock = vta.mock + print("----- LoadWgt Unit Test-------") + def run_test(header, print_ir): + cost = run_schedule( + mock.DMA_COPY, vta.DMA_COPY, mock.GEMM, mock.ALU, mock.DMA_COPY, print_ir, False) + gops = (num_ops / cost.mean) / float(10 ** 9) + bandwith = (channel * channel * vta.VTA_WGT_WIDTH / cost.mean) / float(10 ** 9) + print(header) + print("\tTime cost = %g sec/op, %g GFLOPS, bandwidth=%g Gbits" % ( + cost.mean, gops, bandwith)) + with tvm.build_config(add_lower_pass=vta.debug_mode(0)): + run_test("NORMAL", print_ir) + print("") + + def store_out_unittest(print_ir): + mock = vta.mock + print("----- StoreOut Unit Test-------") + def run_test(header, print_ir): + cost = run_schedule( + mock.DMA_COPY, mock.DMA_COPY, mock.GEMM, mock.ALU, vta.DMA_COPY, + print_ir, False) + gops = (num_ops / cost.mean) / float(10 ** 9) + bandwith = (batch_size * channel * vta.VTA_OUT_WIDTH / cost.mean) / float(10 ** 9) + print(header) + print("\tTime cost = %g sec/op, %g GFLOPS, bandwidth=%g Gbits" % ( + cost.mean, gops, bandwith)) + with tvm.build_config(add_lower_pass=vta.debug_mode(0)): + run_test("NORMAL", print_ir) + print("") + + gemm_normal(False) + gevm_unittest(False) + alu_unittest(False) + # FIXME: report time that is too short + # load_inp_unittest(False) + # load_wgt_unittest(False) + # store_out_unittest(False) + + +print("========GEMM 128=========") +test_gemm_packed(128, 128, 128) + +# FIXME: hanging run +# print("========GEMM 1024========") +# test_gemm_packed(1024, 1024, 128) diff --git a/vta/tests/python/pynq/test_benchmark_topi.py b/vta/tests/python/pynq/test_benchmark_topi.py new file mode 100644 index 000000000..f98e70f01 --- /dev/null +++ b/vta/tests/python/pynq/test_benchmark_topi.py @@ -0,0 +1,144 @@ +"""Testing if we can generate code in topi style""" + +import topi +import tvm +from tvm.contrib import util, rpc +import vta +from vta import vta_conv2d +import numpy as np +import mxnet as mx + +Workload = vta_conv2d.Workload + +@tvm.tag_scope(tag=topi.tag.ELEMWISE) +def my_clip(x, a_min, a_max): + """Unlike topi's current clip, put min and max into two stages.""" + const_min = tvm.const(a_min, x.dtype) + const_max = tvm.const(a_max, x.dtype) + x = tvm.compute(x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA") + x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB") + return x + +host = "pynq" +port = 9091 +out_dtype = "int%d" % vta.VTA_OUT_WIDTH +wgt_dtype = "int%d" % vta.VTA_WGT_WIDTH +inp_dtype = "int%d" % vta.VTA_INP_WIDTH +target = "llvm -target=armv7-none-linux-gnueabihf -mattr=+neon" +print_ir = False + +def test_vta_conv2d(key, batch_size, wl, profile=True): + data_shape = (batch_size, wl.in_filter//vta.VTA_BLOCK_IN, + wl.height, wl.width, vta.VTA_BLOCK_IN) + kernel_shape = (wl.out_filter//vta.VTA_BLOCK_OUT, wl.in_filter//vta.VTA_BLOCK_IN, + wl.hkernel, wl.wkernel, vta.VTA_BLOCK_OUT, vta.VTA_BLOCK_IN) + bias_shape = (wl.out_filter//vta.VTA_BLOCK_OUT, 1, 1, vta.VTA_BLOCK_OUT) + + + fout_height = (wl.height + 2 * wl.hpad - wl.hkernel) // wl.hstride + 1 + fout_width = (wl.width + 2 * wl.wpad - wl.wkernel) // wl.wstride + 1 + data = tvm.placeholder(data_shape, name="data", dtype=inp_dtype) + kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=wgt_dtype) + bias = tvm.placeholder(bias_shape, name="kernel", dtype=out_dtype) + + res_conv = vta_conv2d.packed_conv2d( + data, kernel, padding=(wl.hpad, wl.wpad), strides=(wl.hstride, wl.wstride)) + res = topi.right_shift(res_conv, 8) + res = topi.broadcast_add(res, bias) + res = my_clip(res, 0, 127) + res = topi.cast(res, "int8") + + num_ops = fout_height * fout_width * wl.hkernel * wl.wkernel * wl.out_filter * wl.in_filter + + def verify(s, check_correctness): + mod = tvm.build(s, [data, kernel, bias, res], "ext_dev", target, name="conv2d") + temp = util.tempdir() + remote = rpc.connect(host, port) + mod.save(temp.relpath("conv2d.o")) + remote.upload(temp.relpath("conv2d.o")) + f = remote.load_module("conv2d.o") + # verify + ctx = remote.ext_dev(0) + # Data in original format + data_orig = (np.random.uniform( + size=(batch_size, wl.in_filter, wl.height, wl.width)) * 4).astype(data.dtype) + kernel_orig = (np.random.uniform( + size=(wl.out_filter, wl.in_filter, wl.hkernel, wl.wkernel)) * 4).astype(kernel.dtype) + bias_orig = (np.random.uniform(size=(wl.out_filter,)) * 4).astype("int32") + + data_orig = np.abs(data_orig) + kernel_orig = np.abs(kernel_orig) + bias_orig = np.abs(bias_orig) + + data_packed = data_orig.reshape( + batch_size, wl.in_filter//vta.VTA_BLOCK_IN, vta.VTA_BLOCK_IN, + wl.height, wl.width).transpose((0, 1, 3, 4, 2)) + kernel_packed = kernel_orig.reshape( + wl.out_filter//vta.VTA_BLOCK_OUT, vta.VTA_BLOCK_OUT, + wl.in_filter//vta.VTA_BLOCK_IN, vta.VTA_BLOCK_IN, + wl.hkernel, wl.wkernel).transpose((0, 2, 4, 5, 1, 3)) + bias_packed = bias_orig.reshape( + wl.out_filter//vta.VTA_BLOCK_OUT, 1, 1, vta.VTA_BLOCK_OUT) + res_shape = topi.util.get_const_tuple(res.shape) + + res_np = np.zeros(res_shape).astype(res.dtype) + data_arr = tvm.nd.array(data_packed, ctx) + kernel_arr = tvm.nd.array(kernel_packed, ctx) + bias_arr = tvm.nd.array(bias_packed, ctx) + res_arr = tvm.nd.array(res_np, ctx) + time_f = f.time_evaluator("conv2d", ctx, number=10) + cost = time_f(data_arr, kernel_arr, bias_arr, res_arr) + res_unpack = res_arr.asnumpy().transpose( + (0, 1, 4, 2, 3)).reshape(batch_size, wl.out_filter, fout_height, fout_width) + if check_correctness: + res_ref = mx.nd.Convolution( + mx.nd.array(data_orig.astype(out_dtype), mx.cpu(0)), + mx.nd.array(kernel_orig.astype(out_dtype), mx.cpu(0)), + stride=(wl.hstride, wl.wstride), + kernel=(wl.hkernel, wl.wkernel), + num_filter=wl.out_filter, + no_bias=True, + pad=(wl.hpad, wl.wpad)).asnumpy().astype(out_dtype) + res_ref = res_ref >> 8 + res_ref += bias_orig.reshape(wl.out_filter, 1, 1) + res_ref = np.clip(res_ref, 0, 127).astype("int8") + np.testing.assert_allclose(res_unpack, res_ref) + print("Correctness check pass...") + return cost + + def conv_normal(print_ir): + print("----- CONV2D End-to-End Test-------") + with tvm.build_config(add_lower_pass=vta.debug_mode(0)): + s = vta_conv2d.schedule_packed_conv2d([res]) + if print_ir: + print(tvm.lower(s, [data, kernel, bias, res], simple_mode=True)) + cost = verify(s, True) + gops = (num_ops / cost.mean) / float(10 ** 9) + print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops)) + + conv_normal(print_ir) + +# ResNet18 workloads +resnet = { + # Workloads of resnet18 on imagenet + 0: Workload(224, 224, 16, 64, 7, 7, 3, 3, 2, 2), + 1: Workload(56, 56, 64, 64, 3, 3, 1, 1, 1, 1), + 2: Workload(56, 56, 64, 64, 1, 1, 0, 0, 1, 1), + 3: Workload(56, 56, 64, 128, 3, 3, 1, 1, 2, 2), + 4: Workload(56, 56, 64, 128, 1, 1, 0, 0, 2, 2), + 5: Workload(28, 28, 128, 128, 3, 3, 1, 1, 1, 1), + 6: Workload(28, 28, 128, 256, 3, 3, 1, 1, 2, 2), + 7: Workload(28, 28, 128, 256, 1, 1, 0, 0, 2, 2), + 8: Workload(14, 14, 256, 256, 3, 3, 1, 1, 1, 1), + 9: Workload(14, 14, 256, 512, 3, 3, 1, 1, 2, 2), + 10: Workload(14, 14, 256, 512, 1, 1, 0, 0, 2, 2), + 11: Workload(7, 7, 512, 512, 3, 3, 1, 1, 1, 1), +} + +batch_size = 1 +for i in range(0, len(resnet)): + wl = resnet[i] + key = "resnet-cfg[%d]" % i + print "key=%s" % key + print wl + test_vta_conv2d(key, batch_size, wl) diff --git a/vta/tests/python/pynq/test_program_rpc.py b/vta/tests/python/pynq/test_program_rpc.py new file mode 100644 index 000000000..cc3b79281 --- /dev/null +++ b/vta/tests/python/pynq/test_program_rpc.py @@ -0,0 +1,21 @@ +import tvm +import vta +import os +from tvm.contrib import rpc, util + +host = "pynq" +port = 9091 +target = "llvm -target=armv7-none-linux-gnueabihf" +bit = "vta.bit" + +curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) +bitstream = os.path.join(curr_path, "./", bit) + +def test_program_rpc(): + assert tvm.module.enabled("rpc") + remote = rpc.connect(host, port) + remote.upload(bitstream, bit) + fprogram = remote.get_function("tvm.contrib.vta.init") + fprogram(bit) + +test_program_rpc() diff --git a/vta/tests/python/pynq/test_vta_insn.py b/vta/tests/python/pynq/test_vta_insn.py new file mode 100644 index 000000000..b1d17b6d8 --- /dev/null +++ b/vta/tests/python/pynq/test_vta_insn.py @@ -0,0 +1,498 @@ +"""Unit test TPU's instructions """ +import tvm +import vta +import mxnet as mx +import numpy as np +import topi +from tvm.contrib import rpc, util + +host = "pynq" +port = 9091 +target = "llvm -target=armv7-none-linux-gnueabihf" +out_dtype = "int%d" % vta.VTA_OUT_WIDTH +inp_dtype = "int%d" % vta.VTA_INP_WIDTH +wgt_dtype = "int%d" % vta.VTA_WGT_WIDTH +do_verify = True +print_ir = False + + +def test_save_load_out(): + """Test save/store output command""" + n = 4 + x = tvm.placeholder( + (n, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), + name="x", + dtype=out_dtype) + x_buf = tvm.compute( + (n, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), + lambda *i: x(*i), + "x_buf") + # insert no-op that won't be optimized away + y_buf = tvm.compute( + (n, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), + lambda *i: x_buf(*i)>>0, + "y_buf") + y = tvm.compute( + (n, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), + lambda *i: y_buf(*i).astype(inp_dtype), + "y") + # schedule + s = tvm.create_schedule(y.op) + s[x_buf].set_scope(vta.SCOPE_OUT) + s[x_buf].pragma(x_buf.op.axis[0], vta.DMA_COPY) + s[y_buf].set_scope(vta.SCOPE_OUT) + s[y_buf].pragma(y_buf.op.axis[0], vta.ALU) + s[y].pragma(y.op.axis[0], vta.DMA_COPY) + + def verify(): + # build + m = tvm.build(s, [x, y], "ext_dev", target) + temp = util.tempdir() + remote = rpc.connect(host, port) + m.save(temp.relpath("load_act.o")) + remote.upload(temp.relpath("load_act.o")) + f = remote.load_module("load_act.o") + # verify + ctx = remote.ext_dev(0) + x_np = np.random.randint( + 1, 10, size=(n, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT)).astype(x.dtype) + y_np = x_np.astype(y.dtype) + x_nd = tvm.nd.array(x_np, ctx) + y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype) + f(x_nd, y_nd) + np.testing.assert_equal(y_np, y_nd.asnumpy()) + print("\tFinished verification...") + if do_verify: + verify() + +def test_padded_load(): + """Test padded load.""" + # declare + n = 21 + m = 20 + pad_before = [0, 1, 0, 0] + pad_after = [1, 3, 0, 0] + x = tvm.placeholder( + (n, m, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), + name="x", + dtype=out_dtype) + x_buf = topi.nn.pad(x, pad_before, pad_after, name="y") + # insert no-op that won't be optimized away + y_buf = tvm.compute((n + pad_before[0] + pad_after[0], + m + pad_before[1] + pad_after[1], + vta.VTA_BATCH, + vta.VTA_BLOCK_OUT), lambda *i: x_buf(*i)>>0, "y_buf") + y = tvm.compute((n + pad_before[0] + pad_after[0], + m + pad_before[1] + pad_after[1], + vta.VTA_BATCH, + vta.VTA_BLOCK_OUT), lambda *i: y_buf(*i).astype(inp_dtype), "y") + # schedule + s = tvm.create_schedule(y.op) + s[x_buf].set_scope(vta.SCOPE_OUT) + s[x_buf].pragma(x_buf.op.axis[0], vta.DMA_COPY) + s[y_buf].set_scope(vta.SCOPE_OUT) + s[y_buf].pragma(y_buf.op.axis[0], vta.ALU) + s[y].pragma(y.op.axis[0], vta.DMA_COPY) + + def verify(): + # build + mod = tvm.build(s, [x, y], "ext_dev", target) + temp = util.tempdir() + remote = rpc.connect(host, port) + mod.save(temp.relpath("padded_load.o")) + remote.upload(temp.relpath("padded_load.o")) + f = remote.load_module("padded_load.o") + # verify + ctx = remote.ext_dev(0) + x_np = np.random.randint(1, 2, size=( + n, m, vta.VTA_BATCH, vta.VTA_BLOCK_OUT)).astype(x.dtype) + y_np = np.zeros((n + pad_before[0] + pad_after[0], + m + pad_before[1] + pad_after[1], + vta.VTA_BATCH, + vta.VTA_BLOCK_OUT)).astype(y.dtype) + y_np[pad_before[0]:pad_before[0] + n, + pad_before[1]:pad_before[1] + m, + :] = x_np + x_nd = tvm.nd.array(x_np, ctx) + y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype) + f(x_nd, y_nd) + np.testing.assert_equal(y_np, y_nd.asnumpy()) + print("\tFinished verification...") + if print_ir: + print(tvm.lower(s, [y, x], simple_mode=True)) + if do_verify: + with tvm.build_config(add_lower_pass=vta.debug_mode( + vta.DEBUG_DUMP_INSN)): + verify() + +def test_gemm(): + """Test GEMM.""" + # declare + o = 4 + n = 4 + m = 4 + x = tvm.placeholder((o, n, vta.VTA_BATCH, vta.VTA_BLOCK_IN), name="x", dtype=inp_dtype) + w = tvm.placeholder((m, n, vta.VTA_BLOCK_OUT, vta.VTA_BLOCK_IN), name="w", dtype=wgt_dtype) + x_buf = tvm.compute((o, n, vta.VTA_BATCH, vta.VTA_BLOCK_IN), lambda *i: x(*i), "x_buf") + w_buf = tvm.compute((m, n, vta.VTA_BLOCK_OUT, vta.VTA_BLOCK_IN), lambda *i: w(*i), "w_buf") + ko = tvm.reduce_axis((0, n), name="ko") + ki = tvm.reduce_axis((0, vta.VTA_BLOCK_IN), name="ki") + y_gem = tvm.compute( + (o, m, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), + lambda bo, co, bi, ci: + tvm.sum(x_buf[bo, ko, bi, ki].astype(out_dtype) * + w_buf[co, ko, ci, ki].astype(out_dtype), + axis=[ko, ki]), + name="y_gem") + y_shf = tvm.compute( + (o, m, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), + lambda *i: y_gem(*i)>>8, + name="y_shf") + y_max = tvm.compute( + (o, m, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), + lambda *i: tvm.max(y_shf(*i), 0), + "y_max") #relu + y_min = tvm.compute( + (o, m, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), + lambda *i: tvm.min(y_max(*i), (1<<(vta.VTA_INP_WIDTH-1))-1), + "y_min") #relu + y = tvm.compute( + (o, m, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), + lambda *i: y_min(*i).astype(inp_dtype), + name="y") + + def verify(s): + mod = tvm.build(s, [x, w, y], "ext_dev", target) + temp = util.tempdir() + remote = rpc.connect(host, port) + mod.save(temp.relpath("gemm.o")) + remote.upload(temp.relpath("gemm.o")) + f = remote.load_module("gemm.o") + # verify + ctx = remote.ext_dev(0) + x_np = np.random.randint( + -128, 128, size=(o, n, vta.VTA_BATCH, vta.VTA_BLOCK_IN)).astype(x.dtype) + w_np = np.random.randint( + -128, 128, size=(m, n, vta.VTA_BLOCK_OUT, vta.VTA_BLOCK_IN)).astype(w.dtype) + y_np = np.zeros((o, m, vta.VTA_BATCH, vta.VTA_BLOCK_OUT)).astype(y.dtype) + x_nd = tvm.nd.array(x_np, ctx) + w_nd = tvm.nd.array(w_np, ctx) + y_nd = tvm.nd.array(y_np, ctx) + y_np = y_np.astype(out_dtype) + for b in range(o): + for i in range(m): + for j in range(n): + y_np[b,i,:] += np.dot(x_np[b,j,:].astype(out_dtype), + w_np[i,j].T.astype(out_dtype)) + y_np = np.right_shift(y_np, 8) + y_np = np.clip(y_np, 0, (1<<(vta.VTA_INP_WIDTH-1))-1).astype(y.dtype) + f(x_nd, w_nd, y_nd) + np.testing.assert_equal(y_np, y_nd.asnumpy()) + print("\tFinished verification...") + + def test_schedule1(): + # default schedule with no smt + s = tvm.create_schedule(y.op) + # set the scope of the SRAM buffers + s[x_buf].set_scope(vta.SCOPE_INP) + s[w_buf].set_scope(vta.SCOPE_WGT) + s[y_gem].set_scope(vta.SCOPE_OUT) + s[y_shf].set_scope(vta.SCOPE_OUT) + s[y_max].set_scope(vta.SCOPE_OUT) + s[y_min].set_scope(vta.SCOPE_OUT) + # set pragmas for DMA transfer and ALU ops + s[x_buf].pragma(s[x_buf].op.axis[0], vta.DMA_COPY) + s[w_buf].pragma(s[w_buf].op.axis[0], vta.DMA_COPY) + s[y_shf].pragma(s[y_shf].op.axis[0], vta.ALU) + s[y_max].pragma(s[y_max].op.axis[0], vta.ALU) + s[y_min].pragma(s[y_min].op.axis[0], vta.ALU) + s[y].pragma(s[y].op.axis[0], vta.DMA_COPY) + # tensorization + s[y_gem].reorder( + ko, + s[y_gem].op.axis[0], + s[y_gem].op.axis[1], + s[y_gem].op.axis[2], + s[y_gem].op.axis[3], + ki) + s[y_gem].tensorize(s[y_gem].op.axis[2], vta.GEMM) + if print_ir: + print(tvm.lower(s, [x, w, y], simple_mode=True)) + if do_verify: + with tvm.build_config( + add_lower_pass=vta.debug_mode(vta.DEBUG_DUMP_INSN)): + verify(s) + + def test_smt(): + # test smt schedule + s = tvm.create_schedule(y.op) + s[x_buf].set_scope(vta.SCOPE_INP) + s[w_buf].set_scope(vta.SCOPE_WGT) + s[y_gem].set_scope(vta.SCOPE_OUT) + s[y_shf].set_scope(vta.SCOPE_OUT) + s[y_max].set_scope(vta.SCOPE_OUT) + s[y_min].set_scope(vta.SCOPE_OUT) + abo, aco, abi, aci = s[y].op.axis + abo1, abo2 = s[y].split(abo, nparts=2) + s[y].bind(abo1, tvm.thread_axis("cthread")) + s[y_gem].compute_at(s[y], abo1) + s[y_shf].compute_at(s[y], abo1) + s[y_max].compute_at(s[y], abo1) + s[y_min].compute_at(s[y], abo1) + s[y_gem].reorder( + ko, + s[y_gem].op.axis[0], + s[y_gem].op.axis[1], + s[y_gem].op.axis[2], + s[y_gem].op.axis[3], + ki) + s[y_gem].tensorize(s[y_gem].op.axis[2], vta.GEMM) + s[y_shf].pragma(s[y_shf].op.axis[0], vta.ALU) + s[y_max].pragma(s[y_max].op.axis[0], vta.ALU) + s[y_min].pragma(s[y_min].op.axis[0], vta.ALU) + s[x_buf].compute_at(s[y_gem], ko) + s[x_buf].pragma(s[x_buf].op.axis[0], vta.DMA_COPY) + s[w_buf].compute_at(s[y_gem], ko) + s[w_buf].pragma(s[w_buf].op.axis[0], vta.DMA_COPY) + s[y].pragma(abo2, vta.DMA_COPY) + if print_ir: + print(tvm.lower(s, [x, y, w], simple_mode=True)) + if do_verify: + with tvm.build_config( + add_lower_pass=vta.debug_mode(vta.DEBUG_DUMP_INSN)): + verify(s) + + test_schedule1() + test_smt() + +def test_alu(tvm_op, np_op=None, use_imm=False): + """Test ALU""" + m = 8 + n = 8 + imm = np.random.randint(1,5) + # compute + a = tvm.placeholder( + (m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), + name="a", + dtype=out_dtype) + a_buf = tvm.compute( + (m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), + lambda *i: a(*i), + "a_buf") #DRAM->SRAM + if use_imm: + res_buf = tvm.compute( + (m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), + lambda *i: tvm_op(a_buf(*i), imm), + "res_buf") #compute + else: + b = tvm.placeholder( + (m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), + name="b", + dtype=out_dtype) + b_buf = tvm.compute( + (m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), + lambda *i: b(*i), + "b_buf") #DRAM->SRAM + res_buf = tvm.compute( + (m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), + lambda *i: tvm_op(a_buf(*i), b_buf(*i)), + "res_buf") #compute + res = tvm.compute( + (m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), + lambda *i: res_buf(*i).astype(inp_dtype), + "res") #SRAM->DRAM + # schedule + s = tvm.create_schedule(res.op) + s[a_buf].set_scope(vta.SCOPE_OUT) # SRAM + s[a_buf].pragma(a_buf.op.axis[0], vta.DMA_COPY) # DRAM->SRAM + s[res_buf].set_scope(vta.SCOPE_OUT) # SRAM + s[res_buf].pragma(res_buf.op.axis[0], vta.ALU) # compute + s[res].pragma(res.op.axis[0], vta.DMA_COPY) # SRAM->DRAM + if use_imm: + if print_ir: + print(tvm.lower(s, [a, res], simple_mode=True)) + else: + s[b_buf].set_scope(vta.SCOPE_OUT) # SRAM + s[b_buf].pragma(b_buf.op.axis[0], vta.DMA_COPY) # DRAM->SRAM + if print_ir: + print(tvm.lower(s, [a, b, res], simple_mode=True)) + + def verify(): + # build + if use_imm: + mod = tvm.build(s, [a, res], "ext_dev", target) + else: + mod = tvm.build(s, [a, b, res], "ext_dev", target) + temp = util.tempdir() + remote = rpc.connect(host, port) + mod.save(temp.relpath("load_act.o")) + remote.upload(temp.relpath("load_act.o")) + f = remote.load_module("load_act.o") + # verify + ctx = remote.ext_dev(0) + a_np = np.random.randint( + -16, 16, size=(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT)).astype(a.dtype) + if use_imm: + res_np = np_op(a_np, imm) if np_op else tvm_op(a_np, imm) + else: + b_np = np.random.randint( + -16, 16, size=(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT)).astype(b.dtype) + res_np = np_op(a_np, b_np) if np_op else tvm_op(a_np, b_np) + res_np = res_np.astype(res.dtype) + a_nd = tvm.nd.array(a_np, ctx) + res_nd = tvm.nd.array( + np.zeros((m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT)).astype(res.dtype), ctx) + if use_imm: + f(a_nd, res_nd) + else: + b_nd = tvm.nd.array(b_np, ctx) + f(a_nd, b_nd, res_nd) + np.testing.assert_equal(res_np, res_nd.asnumpy()) + print("\tFinished verification...") + + if do_verify: + verify() + +def test_relu(): + """Test RELU on ALU""" + m = 8 + n = 8 + # compute + a = tvm.placeholder( + (m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), + name="a", + dtype=out_dtype) + a_buf = tvm.compute( + (m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), + lambda *i: a(*i), + "a_buf") # DRAM->SRAM + max_buf = tvm.compute( + (m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), + lambda *i: tvm.max(a_buf(*i), 0), + "res_buf") # relu + min_buf = tvm.compute( + (m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), + lambda *i: tvm.min(max_buf(*i), (1<<(vta.VTA_INP_WIDTH-1))-1), + "max_buf") # relu + res = tvm.compute( + (m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), + lambda *i: min_buf(*i).astype(inp_dtype), + "min_buf") # SRAM->DRAM + # schedule + s = tvm.create_schedule(res.op) + s[a_buf].set_scope(vta.SCOPE_OUT) # SRAM + s[a_buf].pragma(a_buf.op.axis[0], vta.DMA_COPY) # DRAM->SRAM + s[max_buf].set_scope(vta.SCOPE_OUT) # SRAM + s[min_buf].set_scope(vta.SCOPE_OUT) # SRAM + s[max_buf].pragma(max_buf.op.axis[0], vta.ALU) # compute + s[min_buf].pragma(min_buf.op.axis[0], vta.ALU) # compute + s[res].pragma(res.op.axis[0], vta.DMA_COPY) # SRAM->DRAM + if print_ir: + print(tvm.lower(s, [a, res], simple_mode=True)) + + def verify(): + # build + mod = tvm.build(s, [a, res], "ext_dev", target) + temp = util.tempdir() + remote = rpc.connect(host, port) + mod.save(temp.relpath("load_act.o")) + remote.upload(temp.relpath("load_act.o")) + f = remote.load_module("load_act.o") + # verify + ctx = remote.ext_dev(0) + a_np = np.random.randint( + -256, 256, size=(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT)).astype(a.dtype) + res_np = np.clip(a_np, 0, (1<<(vta.VTA_INP_WIDTH-1))-1).astype(res.dtype) + a_nd = tvm.nd.array(a_np, ctx) + res_nd = tvm.nd.array( + np.zeros((m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT)).astype(res.dtype), ctx) + f(a_nd, res_nd) + np.testing.assert_equal(res_np, res_nd.asnumpy()) + print("\tFinished verification...") + + if do_verify: + verify() + +def test_shift_and_scale(): + """Test shift and scale on ALU""" + m = 8 + n = 8 + imm_shift = np.random.randint(-10,10) + imm_scale = np.random.randint(1,5) + # compute + a = tvm.placeholder( + (m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), + name="a", dtype=out_dtype) + a_buf = tvm.compute( + (m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), + lambda *i: a(*i), + "a_buf") # DRAM->SRAM + res_shift = tvm.compute( + (m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), + lambda *i: a_buf(*i)+imm_shift, + "res_shift") # compute + res_scale = tvm.compute( + (m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), + lambda *i: res_shift(*i)>>imm_scale, + "res_scale") # compute + res = tvm.compute( + (m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), + lambda *i: res_scale(*i).astype(inp_dtype), + "res") # SRAM->DRAM + # schedule + s = tvm.create_schedule(res.op) + s[a_buf].set_scope(vta.SCOPE_OUT) # SRAM + s[res_shift].set_scope(vta.SCOPE_OUT) # SRAM + s[res_scale].set_scope(vta.SCOPE_OUT) # SRAM + s[a_buf].pragma(a_buf.op.axis[0], vta.DMA_COPY) # DRAM->SRAM + s[res_shift].pragma(res_shift.op.axis[0], vta.ALU) # compute + s[res_scale].pragma(res_scale.op.axis[0], vta.ALU) # compute + s[res].pragma(res.op.axis[0], vta.DMA_COPY) # SRAM->DRAM + if print_ir: + print(tvm.lower(s, [a, res], simple_mode=True)) + + def verify(): + # build + mod = tvm.build(s, [a, res], "ext_dev", target) + temp = util.tempdir() + remote = rpc.connect(host, port) + mod.save(temp.relpath("load_act.o")) + remote.upload(temp.relpath("load_act.o")) + f = remote.load_module("load_act.o") + # verify + ctx = remote.ext_dev(0) + a_np = np.random.randint( + -10, 10, size=(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT)).astype(a.dtype) + res_np = np.right_shift((a_np + imm_shift), imm_scale) + res_np = res_np.astype(res.dtype) + a_nd = tvm.nd.array(a_np, ctx) + res_nd = tvm.nd.array( + np.zeros((m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT)).astype(res.dtype), ctx) + f(a_nd, res_nd) + np.testing.assert_equal(res_np, res_nd.asnumpy()) + print("\tFinished verification...") + + if do_verify: + verify() + +if __name__ == "__main__": + print("Padded load test") + test_padded_load() + print("Load/store test") + test_save_load_out() + print("GEMM test") + test_gemm() + print("Max immediate") + test_alu(tvm.max, np.maximum, use_imm=True) + print("Max") + test_alu(tvm.max, np.maximum) + print("Add immediate") + test_alu(lambda x, y: x + y, use_imm=True) + print("Add") + test_alu(lambda x, y: x + y) + print("Shift right immediate") + test_alu(lambda x, y: x >> y, np.right_shift, use_imm=True) + print("Relu") + test_relu() + # print("Shift and scale") + # test_shift_and_scale() -- GitLab