From 33a309b239260a36d67fad70b9767e54e5dd333d Mon Sep 17 00:00:00 2001 From: Tianqi Chen <tqchen@users.noreply.github.com> Date: Sun, 8 Apr 2018 21:43:54 -0700 Subject: [PATCH] [RPC][RUNTIME] Support dynamic reload of runtime API according to config (#19) --- vta/Makefile | 33 ++-- vta/apps/pynq_rpc/start_rpc_server.sh | 4 +- .../resnet18/pynq/imagenet_predict.py | 4 +- vta/include/vta/runtime.h | 66 ++++---- vta/make/config.mk | 1 + vta/python/vta/__init__.py | 22 ++- vta/python/vta/exec/__init__.py | 1 + vta/python/vta/exec/rpc_server.py | 104 +++++++++++++ vta/python/vta/hw_spec.py | 22 ++- vta/python/vta/rpc_client.py | 45 ++++++ vta/python/vta/runtime.py | 1 - vta/src/data_buffer.cc | 44 ++++++ vta/src/data_buffer.h | 90 +++++++++++ vta/src/runtime.cc | 144 +++--------------- vta/src/tvm/vta_device_api.cc | 51 ++++--- vta/tests/python/pynq/test_benchmark_topi.py | 2 + vta/tests/python/pynq/test_program_rpc.py | 10 +- 17 files changed, 433 insertions(+), 211 deletions(-) create mode 100644 vta/python/vta/exec/__init__.py create mode 100644 vta/python/vta/exec/rpc_server.py create mode 100644 vta/python/vta/rpc_client.py create mode 100644 vta/src/data_buffer.cc create mode 100644 vta/src/data_buffer.h diff --git a/vta/Makefile b/vta/Makefile index cdea90e6f..069f6e01c 100644 --- a/vta/Makefile +++ b/vta/Makefile @@ -40,36 +40,31 @@ ifneq ($(ADD_LDFLAGS), NONE) LDFLAGS += $(ADD_LDFLAGS) endif -ifeq ($(UNAME_S), Darwin) - SHARED_LIBRARY_SUFFIX := dylib - WHOLE_ARCH= -all_load - NO_WHOLE_ARCH= -noall_load - LDFLAGS += -undefined dynamic_lookup -else - SHARED_LIBRARY_SUFFIX := so - WHOLE_ARCH= --whole-archive - NO_WHOLE_ARCH= --no-whole-archive -endif - -all: lib/libvta.$(SHARED_LIBRARY_SUFFIX) +all: lib/libvta.so lib/libvta_runtime.so VTA_LIB_SRC = $(wildcard src/*.cc src/tvm/*.cc) + ifeq ($(TARGET), VTA_PYNQ_TARGET) VTA_LIB_SRC += $(wildcard src/pynq/*.cc) LDFLAGS += -L/usr/lib -lsds_lib - LDFLAGS += -L/opt/python3.6/lib/python3.6/site-packages/pynq/drivers/ -l:libdma.so + LDFLAGS += -L/opt/python3.6/lib/python3.6/site-packages/pynq/drivers/ + LDFLAGS += -L/opt/python3.6/lib/python3.6/site-packages/pynq/lib/ + LDFLAGS += -l:libdma.so endif -VTA_LIB_OBJ = $(patsubst %.cc, build/%.o, $(VTA_LIB_SRC)) -test: $(TEST) +VTA_LIB_OBJ = $(patsubst src/%.cc, build/%.o, $(VTA_LIB_SRC)) -build/src/%.o: src/%.cc +build/%.o: src/%.cc @mkdir -p $(@D) - $(CXX) $(CFLAGS) -MM -MT build/src/$*.o $< >build/src/$*.d + $(CXX) $(CFLAGS) -MM -MT build/src/$*.o $< >build/$*.d $(CXX) -c $(CFLAGS) -c $< -o $@ -lib/libvta.$(SHARED_LIBRARY_SUFFIX): $(VTA_LIB_OBJ) +lib/libvta.so: $(filter-out build/runtime.o, $(VTA_LIB_OBJ)) + @mkdir -p $(@D) + $(CXX) $(CFLAGS) -shared -o $@ $(filter %.o, $^) $(LDFLAGS) + +lib/libvta_runtime.so: build/runtime.o @mkdir -p $(@D) $(CXX) $(CFLAGS) -shared -o $@ $(filter %.o, $^) $(LDFLAGS) @@ -79,7 +74,7 @@ cpplint: python nnvm/dmlc-core/scripts/lint.py vta cpp include src hardware tests pylint: - pylint python/tvm_vta --rcfile=$(ROOTDIR)/tests/lint/pylintrc + pylint python/vta --rcfile=$(ROOTDIR)/tests/lint/pylintrc doc: doxygen docs/Doxyfile diff --git a/vta/apps/pynq_rpc/start_rpc_server.sh b/vta/apps/pynq_rpc/start_rpc_server.sh index d5a1202a1..445b72ea7 100755 --- a/vta/apps/pynq_rpc/start_rpc_server.sh +++ b/vta/apps/pynq_rpc/start_rpc_server.sh @@ -1,4 +1,4 @@ #!/bin/bash -export PYTHONPATH=${PYTHONPATH}:/home/xilinx/tvm/python +export PYTHONPATH=${PYTHONPATH}:/home/xilinx/tvm/python:/home/xilinx/vta/python export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/opt/python3.6/lib/python3.6/site-packages/pynq/drivers/ -python -m tvm.exec.rpc_server --load-library /home/xilinx/vta/lib/libvta.so +python -m vta.exec.rpc_server diff --git a/vta/examples/resnet18/pynq/imagenet_predict.py b/vta/examples/resnet18/pynq/imagenet_predict.py index b9e6ab56a..74b34fecf 100644 --- a/vta/examples/resnet18/pynq/imagenet_predict.py +++ b/vta/examples/resnet18/pynq/imagenet_predict.py @@ -34,9 +34,7 @@ for file in [TEST_FILE, CATEG_FILE, RESNET_GRAPH_FILE, RESNET_PARAMS_FILE, BITST # Program the FPGA remotely assert tvm.module.enabled("rpc") remote = rpc.connect(host, port) -remote.upload(BITSTREAM_FILE, BITSTREAM_FILE) -fprogram = remote.get_function("tvm.contrib.vta.init") -fprogram(BITSTREAM_FILE) +vta.program_fpga(remote, BITSTREAM_FILE) if verbose: logging.basicConfig(level=logging.INFO) diff --git a/vta/include/vta/runtime.h b/vta/include/vta/runtime.h index e1aae32f4..c9373846d 100644 --- a/vta/include/vta/runtime.h +++ b/vta/include/vta/runtime.h @@ -23,40 +23,20 @@ extern "C" { #define VTA_DEBUG_SKIP_WRITE_BARRIER (1 << 4) #define VTA_DEBUG_FORCE_SERIAL (1 << 5) -/*! \brief VTA command handle */ -typedef void * VTACommandHandle; - -/*! \brief Shutdown hook of VTA to cleanup resources */ -void VTARuntimeShutdown(); - -/*! - * \brief Get thread local command handle. - * \return A thread local command handle. - */ -VTACommandHandle VTATLSCommandHandle(); - /*! * \brief Allocate data buffer. * \param cmd The VTA command handle. * \param size Buffer size. * \return A pointer to the allocated buffer. */ -void* VTABufferAlloc(VTACommandHandle cmd, size_t size); +void* VTABufferAlloc(size_t size); /*! * \brief Free data buffer. * \param cmd The VTA command handle. * \param buffer The data buffer to be freed. */ -void VTABufferFree(VTACommandHandle cmd, void* buffer); - -/*! - * \brief Get the buffer access pointer on CPU. - * \param cmd The VTA command handle. - * \param buffer The data buffer. - * \return The pointer that can be accessed by the CPU. - */ -void* VTABufferCPUPtr(VTACommandHandle cmd, void* buffer); +void VTABufferFree(void* buffer); /*! * \brief Copy data buffer from one location to another. @@ -68,20 +48,32 @@ void* VTABufferCPUPtr(VTACommandHandle cmd, void* buffer); * \param size Size of copy. * \param kind_mask The memory copy kind. */ -void VTABufferCopy(VTACommandHandle cmd, - const void* from, +void VTABufferCopy(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, int kind_mask); +/*! \brief VTA command handle */ +typedef void* VTACommandHandle; + +/*! \brief Shutdown hook of VTA to cleanup resources */ +void VTARuntimeShutdown(); + /*! - * \brief Set debug mode on the command handle. + * \brief Get thread local command handle. + * \return A thread local command handle. + */ +VTACommandHandle VTATLSCommandHandle(); + +/*! + * \brief Get the buffer access pointer on CPU. * \param cmd The VTA command handle. - * \param debug_flag The debug flag. + * \param buffer The data buffer. + * \return The pointer that can be accessed by the CPU. */ -void VTASetDebugMode(VTACommandHandle cmd, int debug_flag); +void* VTABufferCPUPtr(VTACommandHandle cmd, void* buffer); /*! * \brief Perform a write barrier to make a memory region visible to the CPU. @@ -92,9 +84,10 @@ void VTASetDebugMode(VTACommandHandle cmd, int debug_flag); * \param extent The end of the region (in elements). */ void VTAWriteBarrier(VTACommandHandle cmd, - void* buffer, uint32_t elem_bits, - uint32_t start, uint32_t extent); - + void* buffer, + uint32_t elem_bits, + uint32_t start, + uint32_t extent); /*! * \brief Perform a read barrier to a memory region visible to VTA. * \param cmd The VTA command handle. @@ -104,8 +97,17 @@ void VTAWriteBarrier(VTACommandHandle cmd, * \param extent The end of the region (in elements). */ void VTAReadBarrier(VTACommandHandle cmd, - void* buffer, uint32_t elem_bits, - uint32_t start, uint32_t extent); + void* buffer, + uint32_t elem_bits, + uint32_t start, + uint32_t extent); + +/*! + * \brief Set debug mode on the command handle. + * \param cmd The VTA command handle. + * \param debug_flag The debug flag. + */ +void VTASetDebugMode(VTACommandHandle cmd, int debug_flag); /*! * \brief Perform a 2D data load from DRAM. diff --git a/vta/make/config.mk b/vta/make/config.mk index 062dfa8c3..9f611896a 100644 --- a/vta/make/config.mk +++ b/vta/make/config.mk @@ -54,6 +54,7 @@ VTA_LOG_WGT_BUFF_SIZE = 15 # Log of acc buffer size in Bytes VTA_LOG_ACC_BUFF_SIZE = 17 + #--------------------- # Derived VTA hardware parameters #-------------------- diff --git a/vta/python/vta/__init__.py b/vta/python/vta/__init__.py index 4a6f760d0..dc06b7ad0 100644 --- a/vta/python/vta/__init__.py +++ b/vta/python/vta/__init__.py @@ -1,12 +1,20 @@ -"""TVM VTA runtime""" +"""TVM-based VTA Compiler Toolchain""" from __future__ import absolute_import as _abs from .hw_spec import * -from .runtime import SCOPE_INP, SCOPE_OUT, SCOPE_WGT, DMA_COPY, ALU -from .intrin import GEVM, GEMM -from .build import debug_mode +try: + from .runtime import SCOPE_INP, SCOPE_OUT, SCOPE_WGT, DMA_COPY, ALU + from .intrin import GEVM, GEMM + from .build import debug_mode + from . import mock, ir_pass + from . import arm_conv2d, vta_conv2d +except AttributeError: + pass -from . import mock, ir_pass -from . import arm_conv2d, vta_conv2d -from . import graph +from .rpc_client import reconfig_runtime, program_fpga + +try: + from . import graph +except ImportError: + pass diff --git a/vta/python/vta/exec/__init__.py b/vta/python/vta/exec/__init__.py new file mode 100644 index 000000000..2fa9de930 --- /dev/null +++ b/vta/python/vta/exec/__init__.py @@ -0,0 +1 @@ +"""VTA Command line utils.""" diff --git a/vta/python/vta/exec/rpc_server.py b/vta/python/vta/exec/rpc_server.py new file mode 100644 index 000000000..8a27859f7 --- /dev/null +++ b/vta/python/vta/exec/rpc_server.py @@ -0,0 +1,104 @@ +"""VTA customized TVM RPC Server + +Provides additional runtime function and library loading. +""" +from __future__ import absolute_import + +import logging +import argparse +import os +import ctypes +import tvm +from tvm.contrib import rpc, util, cc + + +@tvm.register_func("tvm.contrib.rpc.server.start", override=True) +def server_start(): + curr_path = os.path.dirname( + os.path.abspath(os.path.expanduser(__file__))) + dll_path = os.path.abspath( + os.path.join(curr_path, "../../../lib/libvta_runtime.so")) + runtime_dll = [] + _load_module = tvm.get_global_func("tvm.contrib.rpc.server.load_module") + + @tvm.register_func("tvm.contrib.rpc.server.load_module", override=True) + def load_module(file_name): + if not runtime_dll: + runtime_dll.append(ctypes.CDLL(dll_path, ctypes.RTLD_GLOBAL)) + return _load_module(file_name) + + @tvm.register_func("tvm.contrib.rpc.server.shutdown", override=True) + def server_shutdown(): + if runtime_dll: + runtime_dll[0].VTARuntimeShutdown() + runtime_dll.pop() + + @tvm.register_func("tvm.contrib.vta.reconfig_runtime", override=True) + def reconfig_runtime(cflags): + """Rebuild and reload runtime with new configuration. + + Parameters + ---------- + cfg_json : str + JSON string used for configurations. + """ + if runtime_dll: + raise RuntimeError("Can only reconfig in the beginning of session...") + cflags = cflags.split() + cflags += ["-O2", "-std=c++11"] + lib_name = dll_path + curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) + proj_root = os.path.abspath(os.path.join(curr_path, "../../../")) + runtime_source = os.path.join(proj_root, "src/runtime.cc") + cflags += ["-I%s/include" % proj_root] + cflags += ["-I%s/nnvm/tvm/include" % proj_root] + cflags += ["-I%s/nnvm/tvm/dlpack/include" % proj_root] + cflags += ["-I%s/nnvm/dmlc-core/include" % proj_root] + logging.info("Rebuild runtime dll with %s", str(cflags)) + cc.create_shared(lib_name, [runtime_source], cflags) + + +def main(): + """Main funciton""" + parser = argparse.ArgumentParser() + parser.add_argument('--host', type=str, default="0.0.0.0", + help='the hostname of the server') + parser.add_argument('--port', type=int, default=9090, + help='The port of the PRC') + parser.add_argument('--port-end', type=int, default=9199, + help='The end search port of the PRC') + parser.add_argument('--key', type=str, default="", + help="RPC key used to identify the connection type.") + parser.add_argument('--tracker', type=str, default="", + help="Report to RPC tracker") + args = parser.parse_args() + logging.basicConfig(level=logging.INFO) + curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) + proj_root = os.path.abspath(os.path.join(curr_path, "../../../")) + lib_path = os.path.abspath(os.path.join(proj_root, "lib/libvta.so")) + + libs = [] + for file_name in [lib_path]: + libs.append(ctypes.CDLL(file_name, ctypes.RTLD_GLOBAL)) + logging.info("Load additional library %s", file_name) + + if args.tracker: + url, port = args.tracker.split(":") + port = int(port) + tracker_addr = (url, port) + if not args.key: + raise RuntimeError( + "Need key to present type of resource when tracker is available") + else: + tracker_addr = None + + server = rpc.Server(args.host, + args.port, + args.port_end, + key=args.key, + tracker_addr=tracker_addr) + server.libs += libs + server.proc.join() + +if __name__ == "__main__": + main() diff --git a/vta/python/vta/hw_spec.py b/vta/python/vta/hw_spec.py index b6b89df81..ec7595f4c 100644 --- a/vta/python/vta/hw_spec.py +++ b/vta/python/vta/hw_spec.py @@ -1,11 +1,31 @@ """VTA configuration constants (should match hw_spec.h""" from __future__ import absolute_import as _abs +# Log of input/activation width in bits (default 3 -> 8 bits) +VTA_LOG_INP_WIDTH = 3 +# Log of kernel weight width in bits (default 3 -> 8 bits) +VTA_LOG_WGT_WIDTH = 3 +# Log of accum width in bits (default 5 -> 32 bits) +VTA_LOG_ACC_WIDTH = 5 +# Log of tensor batch size (A in (A,B)x(B,C) matrix multiplication) +VTA_LOG_BATCH = 0 +# Log of tensor inner block size (B in (A,B)x(B,C) matrix multiplication) +VTA_LOG_BLOCK_IN = 4 +# Log of tensor outer block size (C in (A,B)x(B,C) matrix multiplication) +VTA_LOG_BLOCK_OUT = 4 +VTA_LOG_OUT_WIDTH = VTA_LOG_INP_WIDTH +# Log of uop buffer size in Bytes +VTA_LOG_UOP_BUFF_SIZE = 15 +# Log of acc buffer size in Bytes +VTA_LOG_ACC_BUFF_SIZE = 17 + # The Constants VTA_WGT_WIDTH = 8 VTA_INP_WIDTH = VTA_WGT_WIDTH VTA_OUT_WIDTH = 32 +VTA_TARGET = "VTA_PYNQ_TARGET" + # Dimensions of the GEMM unit # (BATCH,BLOCK_IN) x (BLOCK_IN,BLOCK_OUT) VTA_BATCH = 1 @@ -67,4 +87,4 @@ VTA_QID_STORE_INP = 3 DEBUG_DUMP_INSN = (1 << 1) DEBUG_DUMP_UOP = (1 << 2) DEBUG_SKIP_READ_BARRIER = (1 << 3) -DEBUG_SKIP_WRITE_BARRIER = (1 << 4) \ No newline at end of file +DEBUG_SKIP_WRITE_BARRIER = (1 << 4) diff --git a/vta/python/vta/rpc_client.py b/vta/python/vta/rpc_client.py new file mode 100644 index 000000000..fb51113e1 --- /dev/null +++ b/vta/python/vta/rpc_client.py @@ -0,0 +1,45 @@ +"""VTA RPC client function""" +import os +from . import hw_spec as spec + +def reconfig_runtime(remote): + """Reconfigure remote runtime based on current hardware spec. + + Parameters + ---------- + remote : RPCSession + The TVM RPC session + """ + keys = ["VTA_LOG_WGT_WIDTH", + "VTA_LOG_INP_WIDTH", + "VTA_LOG_ACC_WIDTH", + "VTA_LOG_OUT_WIDTH", + "VTA_LOG_BATCH", + "VTA_LOG_BLOCK_IN", + "VTA_LOG_BLOCK_OUT", + "VTA_LOG_UOP_BUFF_SIZE", + "VTA_LOG_INP_BUFF_SIZE", + "VTA_LOG_WGT_BUFF_SIZE", + "VTA_LOG_ACC_BUFF_SIZE", + "VTA_LOG_OUT_BUFF_SIZE"] + cflags = ["-D%s" % spec.VTA_TARGET] + for k in keys: + cflags += ["-D%s=%s" % (k, str(getattr(spec, k)))] + freconfig = remote.get_function("tvm.contrib.vta.reconfig_runtime") + freconfig(" ".join(cflags)) + + +def program_fpga(remote, bitstream): + """Upload and program bistream + + Parameters + ---------- + remote : RPCSession + The TVM RPC session + + bitstream : str + Path to a local bistream file. + """ + fprogram = remote.get_function("tvm.contrib.vta.init") + remote.upload(bitstream) + fprogram(os.path.basename(bitstream)) diff --git a/vta/python/vta/runtime.py b/vta/python/vta/runtime.py index bfcd130ff..dfbdfe670 100644 --- a/vta/python/vta/runtime.py +++ b/vta/python/vta/runtime.py @@ -25,7 +25,6 @@ def get_task_qid(qid): """Get transformed queue index.""" return 1 if DEBUG_NO_SYNC else qid - @tvm.register_func("tvm.intrin.rule.default.vta.coproc_sync") def coproc_sync(op): return tvm.call_extern( diff --git a/vta/src/data_buffer.cc b/vta/src/data_buffer.cc new file mode 100644 index 000000000..99f959ad8 --- /dev/null +++ b/vta/src/data_buffer.cc @@ -0,0 +1,44 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file data_buffer.cc + * \brief Buffer related API for VTA. + * \note Buffer API remains stable across VTA designes. + */ +#include "./data_buffer.h" + +void* VTABufferAlloc(size_t size) { + return vta::DataBuffer::Alloc(size); +} + +void VTABufferFree(void* buffer) { + vta::DataBuffer::Free(vta::DataBuffer::FromHandle(buffer)); +} + +void VTABufferCopy(const void* from, + size_t from_offset, + void* to, + size_t to_offset, + size_t size, + int kind_mask) { + vta::DataBuffer* from_buffer = nullptr; + vta::DataBuffer* to_buffer = nullptr; + + if (kind_mask & 2) { + from_buffer = vta::DataBuffer::FromHandle(from); + from = from_buffer->virt_addr(); + } + if (kind_mask & 1) { + to_buffer = vta::DataBuffer::FromHandle(to); + to = to_buffer->virt_addr(); + } + if (from_buffer) { + from_buffer->InvalidateCache(from_offset, size); + } + + memcpy(static_cast<char*>(to) + to_offset, + static_cast<const char*>(from) + from_offset, + size); + if (to_buffer) { + to_buffer->FlushCache(to_offset, size); + } +} diff --git a/vta/src/data_buffer.h b/vta/src/data_buffer.h new file mode 100644 index 000000000..117a423d0 --- /dev/null +++ b/vta/src/data_buffer.h @@ -0,0 +1,90 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file data_buffer.h + * \brief VTA runtime internal data buffer structure. + */ +#ifndef VTA_DATA_BUFFER_H_ +#define VTA_DATA_BUFFER_H_ + +#include <vta/driver.h> +#include <vta/runtime.h> +#include <cassert> +#include <cstring> + +namespace vta { + +/*! \brief Enable coherent access between VTA and CPU. */ +static const bool kBufferCoherent = true; + +/*! + * \brief Data buffer represents data on CMA. + */ +struct DataBuffer { + /*! \return Virtual address of the data. */ + void* virt_addr() const { + return data_; + } + /*! \return Physical address of the data. */ + uint32_t phy_addr() const { + return phy_addr_; + } + /*! + * \brief Invalidate the cache of given location in data buffer. + * \param offset The offset to the data. + * \param size The size of the data. + */ + void InvalidateCache(size_t offset, size_t size) { + if (!kBufferCoherent) { + VTAInvalidateCache(reinterpret_cast<void*>(phy_addr_ + offset), size); + } + } + /*! + * \brief Invalidate the cache of certain location in data buffer. + * \param offset The offset to the data. + * \param size The size of the data. + */ + void FlushCache(size_t offset, size_t size) { + if (!kBufferCoherent) { + VTAFlushCache(reinterpret_cast<void*>(phy_addr_ + offset), size); + } + } + /*! + * \brief Allocate a buffer of a given size. + * \param size The size of the buffer. + */ + static DataBuffer* Alloc(size_t size) { + void* data = VTAMemAlloc(size, 1); + assert(data != nullptr); + DataBuffer* buffer = new DataBuffer(); + buffer->data_ = data; + buffer->phy_addr_ = VTAGetMemPhysAddr(data); + return buffer; + } + /*! + * \brief Free the data buffer. + * \param buffer The buffer to be freed. + */ + static void Free(DataBuffer* buffer) { + VTAMemFree(buffer->data_); + delete buffer; + } + /*! + * \brief Create data buffer header from buffer ptr. + * \param buffer The buffer pointer. + * \return The corresponding data buffer header. + */ + static DataBuffer* FromHandle(const void* buffer) { + return const_cast<DataBuffer*>( + reinterpret_cast<const DataBuffer*>(buffer)); + } + + private: + /*! \brief The internal data. */ + void* data_; + /*! \brief The physical address of the buffer, excluding header. */ + uint32_t phy_addr_; +}; + +} // namespace vta + +#endif // VTA_DATA_BUFFER_H_ diff --git a/vta/src/runtime.cc b/vta/src/runtime.cc index 7c5708b4e..8c7c9fdc9 100644 --- a/vta/src/runtime.cc +++ b/vta/src/runtime.cc @@ -1,6 +1,6 @@ /*! * Copyright (c) 2018 by Contributors - * \file vta_runtime.cc + * \file runtime.cc * \brief VTA runtime for PYNQ in C++11 */ @@ -13,85 +13,14 @@ #include <vta/runtime.h> #include <cassert> -#include <cstring> #include <vector> #include <thread> #include <memory> #include <atomic> -namespace vta { - -/*! \brief Enable coherent access between VTA and CPU. */ -static const bool kBufferCoherent = true; +#include "./data_buffer.h" -/*! - * \brief Data buffer represents data on CMA. - */ -struct DataBuffer { - /*! \return Virtual address of the data. */ - void* virt_addr() const { - return data_; - } - /*! \return Physical address of the data. */ - uint32_t phy_addr() const { - return phy_addr_; - } - /*! - * \brief Invalidate the cache of given location in data buffer. - * \param offset The offset to the data. - * \param size The size of the data. - */ - void InvalidateCache(size_t offset, size_t size) { - if (!kBufferCoherent) { - VTAInvalidateCache(reinterpret_cast<void*>(phy_addr_ + offset), size); - } - } - /*! - * \brief Invalidate the cache of certain location in data buffer. - * \param offset The offset to the data. - * \param size The size of the data. - */ - void FlushCache(size_t offset, size_t size) { - if (!kBufferCoherent) { - VTAFlushCache(reinterpret_cast<void*>(phy_addr_ + offset), size); - } - } - /*! - * \brief Allocate a buffer of a given size. - * \param size The size of the buffer. - */ - static DataBuffer* Alloc(size_t size) { - void* data = VTAMemAlloc(size, 1); - assert(data != nullptr); - DataBuffer* buffer = new DataBuffer(); - buffer->data_ = data; - buffer->phy_addr_ = VTAGetMemPhysAddr(data); - return buffer; - } - /*! - * \brief Free the data buffer. - * \param buffer The buffer to be freed. - */ - static void Free(DataBuffer* buffer) { - VTAMemFree(buffer->data_); - delete buffer; - } - /*! - * \brief Create data buffer header from buffer ptr. - * \param buffer The buffer pointer. - * \return The corresponding data buffer header. - */ - static DataBuffer* FromHandle(const void* buffer) { - return const_cast<DataBuffer*>( - reinterpret_cast<const DataBuffer*>(buffer)); - } - - private: - /*! \brief The internal data. */ - void* data_; - /*! \brief The physical address of the buffer, excluding header. */ - uint32_t phy_addr_; -}; +namespace vta { /*! * \brief Micro op kernel. @@ -1130,6 +1059,9 @@ class CommandQueue { static std::shared_ptr<CommandQueue>& ThreadLocal() { static std::shared_ptr<CommandQueue> inst = std::make_shared<CommandQueue>(); + if (inst == nullptr) { + inst = std::make_shared<CommandQueue>(); + } return inst; } @@ -1254,63 +1186,29 @@ void VTARuntimeShutdown() { vta::CommandQueue::Shutdown(); } -void* VTABufferAlloc(VTACommandHandle cmd, size_t size) { - return vta::DataBuffer::Alloc(size); -} - -void VTABufferFree(VTACommandHandle cmd, void* buffer) { - vta::DataBuffer::Free(vta::DataBuffer::FromHandle(buffer)); +void VTASetDebugMode(VTACommandHandle cmd, int debug_flag) { + static_cast<vta::CommandQueue*>(cmd)-> + SetDebugFlag(debug_flag); } void* VTABufferCPUPtr(VTACommandHandle cmd, void* buffer) { return vta::DataBuffer::FromHandle(buffer)->virt_addr(); } -void VTABufferCopy(VTACommandHandle cmd, - const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t size, - int kind_mask) { - vta::DataBuffer* from_buffer = nullptr; - vta::DataBuffer* to_buffer = nullptr; - - if (kind_mask & 2) { - from_buffer = vta::DataBuffer::FromHandle(from); - from = from_buffer->virt_addr(); - } - if (kind_mask & 1) { - to_buffer = vta::DataBuffer::FromHandle(to); - to = to_buffer->virt_addr(); - } - if (from_buffer) { - from_buffer->InvalidateCache(from_offset, size); - } - - memcpy(static_cast<char*>(to) + to_offset, - static_cast<const char*>(from) + from_offset, - size); - if (to_buffer) { - to_buffer->FlushCache(to_offset, size); - } -} - -void VTASetDebugMode(VTACommandHandle cmd, int debug_flag) { - static_cast<vta::CommandQueue*>(cmd)-> - SetDebugFlag(debug_flag); -} - void VTAWriteBarrier(VTACommandHandle cmd, - void* buffer, uint32_t elem_bits, - uint32_t start, uint32_t extent) { + void* buffer, + uint32_t elem_bits, + uint32_t start, + uint32_t extent) { static_cast<vta::CommandQueue*>(cmd)-> WriteBarrier(buffer, elem_bits, start, extent); } void VTAReadBarrier(VTACommandHandle cmd, - void* buffer, uint32_t elem_bits, - uint32_t start, uint32_t extent) { + void* buffer, + uint32_t elem_bits, + uint32_t start, + uint32_t extent) { static_cast<vta::CommandQueue*>(cmd)-> ReadBarrier(buffer, elem_bits, start, extent); } @@ -1409,3 +1307,11 @@ void VTASynchronize(VTACommandHandle cmd, uint32_t wait_cycles) { static_cast<vta::CommandQueue*>(cmd)-> Synchronize(wait_cycles); } + +extern "C" int VTARuntimeDynamicMagic() { +#ifdef VTA_DYNAMIC_MAGIC + return VTA_DYNAMIC_MAGIC; +#else + return 0; +#endif +} diff --git a/vta/src/tvm/vta_device_api.cc b/vta/src/tvm/vta_device_api.cc index b7b57e199..ce864df09 100644 --- a/vta/src/tvm/vta_device_api.cc +++ b/vta/src/tvm/vta_device_api.cc @@ -7,31 +7,17 @@ #include <tvm/runtime/registry.h> #include <dmlc/thread_local.h> #include <vta/runtime.h> +#include <dlfcn.h> #include "../../nnvm/tvm/src/runtime/workspace_pool.h" -namespace tvm { -namespace runtime { - -std::string VTARPCGetPath(const std::string& name) { - static const PackedFunc* f = - runtime::Registry::Get("tvm.contrib.rpc.server.workpath"); - CHECK(f != nullptr) << "require tvm.contrib.rpc.server.workpath"; - return (*f)(name); +extern "C" { + typedef void (*FShutdown)(); + typedef int (*FDynamicMagic)(); } -// Global functions that can be called -TVM_REGISTER_GLOBAL("tvm.contrib.vta.init") -.set_body([](TVMArgs args, TVMRetValue* rv) { - std::string path = VTARPCGetPath(args[0]); - VTAProgram(path.c_str()); - LOG(INFO) << "VTA initialization end with bistream " << path; - }); - -TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.shutdown") -.set_body([](TVMArgs args, TVMRetValue* rv) { - VTARuntimeShutdown(); - }); +namespace tvm { +namespace runtime { class VTADeviceAPI final : public DeviceAPI { public: @@ -46,11 +32,11 @@ class VTADeviceAPI final : public DeviceAPI { void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment, TVMType type_hint) final { - return VTABufferAlloc(VTATLSCommandHandle(), size); + return VTABufferAlloc(size); } void FreeDataSpace(TVMContext ctx, void* ptr) final { - VTABufferFree(VTATLSCommandHandle(), ptr); + VTABufferFree(ptr); } void CopyDataFromTo(const void* from, @@ -68,8 +54,7 @@ class VTADeviceAPI final : public DeviceAPI { if (ctx_to.device_type != kDLCPU) { kind_mask |= 1; } - VTABufferCopy(VTATLSCommandHandle(), - from, from_offset, + VTABufferCopy(from, from_offset, to, to_offset, size, kind_mask); } @@ -86,6 +71,9 @@ class VTADeviceAPI final : public DeviceAPI { std::make_shared<VTADeviceAPI>(); return inst; } + + private: + void* runtime_dll_{nullptr}; }; struct VTAWorkspacePool : public WorkspacePool { @@ -103,6 +91,21 @@ void VTADeviceAPI::FreeWorkspace(TVMContext ctx, void* data) { dmlc::ThreadLocalStore<VTAWorkspacePool>::Get()->FreeWorkspace(ctx, data); } +std::string VTARPCGetPath(const std::string& name) { + static const PackedFunc* f = + runtime::Registry::Get("tvm.contrib.rpc.server.workpath"); + CHECK(f != nullptr) << "require tvm.contrib.rpc.server.workpath"; + return (*f)(name); +} + +// Global functions that can be called +TVM_REGISTER_GLOBAL("tvm.contrib.vta.init") +.set_body([](TVMArgs args, TVMRetValue* rv) { + std::string path = VTARPCGetPath(args[0]); + VTAProgram(path.c_str()); + LOG(INFO) << "VTA initialization end with bistream " << path; + }); + TVM_REGISTER_GLOBAL("device_api.ext_dev") .set_body([](TVMArgs args, TVMRetValue* rv) { DeviceAPI* ptr = VTADeviceAPI::Global().get(); diff --git a/vta/tests/python/pynq/test_benchmark_topi.py b/vta/tests/python/pynq/test_benchmark_topi.py index f98e70f01..e6dea3e29 100644 --- a/vta/tests/python/pynq/test_benchmark_topi.py +++ b/vta/tests/python/pynq/test_benchmark_topi.py @@ -27,6 +27,7 @@ inp_dtype = "int%d" % vta.VTA_INP_WIDTH target = "llvm -target=armv7-none-linux-gnueabihf -mattr=+neon" print_ir = False + def test_vta_conv2d(key, batch_size, wl, profile=True): data_shape = (batch_size, wl.in_filter//vta.VTA_BLOCK_IN, wl.height, wl.width, vta.VTA_BLOCK_IN) @@ -54,6 +55,7 @@ def test_vta_conv2d(key, batch_size, wl, profile=True): mod = tvm.build(s, [data, kernel, bias, res], "ext_dev", target, name="conv2d") temp = util.tempdir() remote = rpc.connect(host, port) + mod.save(temp.relpath("conv2d.o")) remote.upload(temp.relpath("conv2d.o")) f = remote.load_module("conv2d.o") diff --git a/vta/tests/python/pynq/test_program_rpc.py b/vta/tests/python/pynq/test_program_rpc.py index cc3b79281..eaf09577d 100644 --- a/vta/tests/python/pynq/test_program_rpc.py +++ b/vta/tests/python/pynq/test_program_rpc.py @@ -14,8 +14,12 @@ bitstream = os.path.join(curr_path, "./", bit) def test_program_rpc(): assert tvm.module.enabled("rpc") remote = rpc.connect(host, port) - remote.upload(bitstream, bit) - fprogram = remote.get_function("tvm.contrib.vta.init") - fprogram(bit) + vta.program_fpga(remote, bit) + +def test_reconfig_runtime(): + assert tvm.module.enabled("rpc") + remote = rpc.connect(host, port) + vta.reconfig_runtime(remote) test_program_rpc() +test_reconfig_runtime() -- GitLab