From ffe1badd9da2086055aa3ce78c8ba1195f62b17a Mon Sep 17 00:00:00 2001
From: Thierry Moreau <moreau@cs.washington.edu>
Date: Tue, 3 Jul 2018 09:47:31 -0700
Subject: [PATCH] [TUTORIAL] Resnet-18 end to end tutorial example (#55)

---
 vta/tutorials/resnet.py | 326 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 326 insertions(+)
 create mode 100644 vta/tutorials/resnet.py

diff --git a/vta/tutorials/resnet.py b/vta/tutorials/resnet.py
new file mode 100644
index 000000000..72ed5d7d2
--- /dev/null
+++ b/vta/tutorials/resnet.py
@@ -0,0 +1,326 @@
+"""
+ResNet Inference Example
+========================
+**Author**: `Thierry Moreau <https://homes.cs.washington.edu/~moreau/>`_
+
+This tutorial provides an end-to-end demo of running ResNet-18 inference
+on the VTA accelerator design to perform ImageNet classification.
+
+"""
+
+
+######################################################################
+# Import Libraries
+# ----------------
+# We start by importing the tvm, vta, and nnvm libraries needed to run this example.
+
+from __future__ import absolute_import, print_function
+
+import ast
+import os
+import sys
+import nnvm
+import nnvm.compiler
+import tvm
+import vta
+import vta.testing
+import numpy as np
+import json
+import requests
+import time
+
+from nnvm.compiler import graph_attr
+from tvm.contrib import graph_runtime, rpc, util
+from tvm.contrib.download import download
+from vta.testing import simulator
+
+from io import BytesIO
+from matplotlib import pyplot as plt
+from PIL import Image
+
+# Load VTA parameters from the config.json file
+env = vta.get_env()
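+
+# As a quick sanity check (a minimal sketch), we can print a few of the VTA
+# parameters this script relies on, all read from config.json.
+print("VTA target: {}".format(env.TARGET))
+print("BATCH={}, BLOCK_IN={}, BLOCK_OUT={}".format(
+    env.BATCH, env.BLOCK_IN, env.BLOCK_OUT))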
+
+# Helper to center-crop an image to a square and resize it to (224, 224)
+# Takes in an Image object, returns an Image object
+def thumbnailify(image, pad=15):
+    w, h = image.size
+    crop = ((w-h)//2+pad, pad, h+(w-h)//2-pad, h-pad)
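+    # For example, a 1280x720 frame yields crop = (295, 15, 985, 705),
+    # a centered 690x690 square inset by the padding on every side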
+    image = image.crop(crop)
+    image = image.resize((224, 224))
+    return image
+
+# Helper function to preprocess an input image
+# Takes in an Image object, returns a TVM NDArray
+def process_image(image):
+    # Normalize with the per-channel ImageNet mean and standard deviation
+    image = np.array(image) - np.array([123., 117., 104.])
+    image /= np.array([58.395, 57.12, 57.375])
+    image = image.transpose((2, 0, 1))
+    image = image[np.newaxis, :]
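+    # The array is now in NCHW layout with batch size 1: (1, 3, 224, 224)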
+
+    return tvm.nd.array(image.astype("float32"))
+
+# Classification helper function
+# Takes in the graph runtime and an image; returns the top result and run time
+def classify(m, image):
+    m.set_input('data', image)
+    timer = m.module.time_evaluator("run", ctx, number=1)
+    tcost = timer()
+    tvm_output = m.get_output(0, tvm.nd.empty((1000,), "float32", remote.cpu(0)))
+    top = np.argmax(tvm_output.asnumpy())
+    tcost = "t={0:.2f}s".format(tcost.mean)
+    return tcost + " {}".format(synset[top])
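+
+# For example, once the graph runtime ``m`` and a preprocessed ``image`` are
+# built below, ``classify(m, image)`` returns a string such as
+# "t=0.42s tabby, tabby cat" (timing and label are illustrative).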
+
+# Helper function to compile the NNVM graph
+# Takes in a path to a graph file, params file, and device target
+# Returns the NNVM graph object, a compiled library object, and the params dict
+def generate_graph(graph_fn, params_fn, device="vta"):
+
+    # Measure build start time
+    build_start = time.time()
+
+    # Derive the TVM target
+    target = tvm.target.create("llvm -device={}".format(device))
+
+    # Derive the LLVM compiler flags
+    # When targeting the Pynq, cross-compile to the ARMv7 ISA
+    if env.TARGET == "sim":
+        target_host = "llvm"
+    elif env.TARGET == "pynq":
+        target_host = "llvm -mtriple=armv7-none-linux-gnueabihf -mcpu=cortex-a9 -mattr=+neon"
+    else:
+        raise RuntimeError("Unsupported VTA target: {}".format(env.TARGET))
+
+    # Load the ResNet-18 graph and parameters
+    sym = nnvm.graph.load_json(open(graph_fn).read())
+    params = nnvm.compiler.load_param_dict(open(params_fn, 'rb').read())
+
+    # Populate the shape and data type dictionary
+    shape_dict = {"data": (1, 3, 224, 224)}
+    dtype_dict = {"data": 'float32'}
+    shape_dict.update({k: v.shape for k, v in params.items()})
+    dtype_dict.update({k: str(v.dtype) for k, v in params.items()})
+
+    # Create NNVM graph
+    graph = nnvm.graph.create(sym)
+    graph_attr.set_shape_inputs(sym, shape_dict)
+    graph_attr.set_dtype_inputs(sym, dtype_dict)
+    graph = graph.apply("InferShape").apply("InferType")
+
+    # Apply NNVM graph optimization passes
+    sym = vta.graph.clean_cast(sym)
+    sym = vta.graph.clean_conv_fuse(sym)
+    if target.device_name == "vta":
+        assert env.BLOCK_IN == env.BLOCK_OUT
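+        # Pack the graph: tile the channel dimensions into (BATCH, BLOCK)
+        # sub-tensors so the data layout matches VTA's tensor intrinsic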
+        sym = vta.graph.pack(sym, shape_dict, env.BATCH, env.BLOCK_OUT)
+
+    # Compile NNVM graph
+    with nnvm.compiler.build_config(opt_level=3):
+        if target.device_name != "vta":
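+            # The host target string serves as the compilation target here,
+            # so the whole graph is compiled for the (remote) CPU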
+            graph, lib, params = nnvm.compiler.build(
+                sym, target_host, shape_dict, dtype_dict,
+                params=params)
+        else:
+            with vta.build_config():
+                graph, lib, params = nnvm.compiler.build(
+                    sym, target, shape_dict, dtype_dict,
+                    params=params, target_host=target_host)
+
+    # Save the compiled inference graph library
+    assert tvm.module.enabled("rpc")
+    temp = util.tempdir()
+    lib.save(temp.relpath("graphlib.o"))
+
+    # Send the inference library over to the remote RPC server
+    remote.upload(temp.relpath("graphlib.o"))
+    lib = remote.load_module("graphlib.o")
+
+    # Measure build time
+    build_time = time.time() - build_start
+    print("ResNet-18 inference graph built in {0:.2f}s!".format(build_time))
+
+    return graph, lib, params
+
+
+######################################################################
+# Download ResNet Model
+# ---------------------
+# Download the necessary files to run ResNet-18.
+#
+
+# Download the ResNet model files into the _data directory
+url = "https://github.com/uwsaml/web-data/raw/master/vta/models/"
+categ_fn = 'synset.txt'
+graph_fn = 'resnet18_qt8.json'
+params_fn = 'resnet18_qt8.params'
+
+# Create data dir
+data_dir = "_data/"
+if not os.path.exists(data_dir):
+    os.makedirs(data_dir)
+
+# Download the files if they are not already present
+for fn in [categ_fn, graph_fn, params_fn]:
+    if not os.path.isfile(os.path.join(data_dir, fn)):
+        download(url + fn, os.path.join(data_dir, fn))
+
+# Read in the ImageNet categories: the file contains a Python dict literal
+# mapping each class index to a human-readable label
+synset = ast.literal_eval(open(os.path.join(data_dir, categ_fn)).read())
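+
+# As an illustrative check (assuming the standard ImageNet class ordering,
+# where index 281 is "tabby, tabby cat"), one could run:
+# print(synset[281])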
+
+
+######################################################################
+# Setup the Pynq Board's RPC Server
+# ---------------------------------
+# Build the RPC server's VTA runtime and program the Pynq FPGA.
+
+# Measure reconfiguration start time
+reconfig_start = time.time()
+
+# We read the Pynq RPC host IP address and port number from the OS environment
+host = os.environ.get("VTA_PYNQ_RPC_HOST", "192.168.2.99")
+port = int(os.environ.get("VTA_PYNQ_RPC_PORT", "9091"))
+
+# We configure both the bitstream and the runtime system on the Pynq
+# to match the VTA configuration specified by the config.json file.
+if env.TARGET == "pynq":
+
+    # Make sure that TVM was compiled with RPC=1
+    assert tvm.module.enabled("rpc")
+    remote = rpc.connect(host, port)
+
+    # Reconfigure the JIT runtime
+    vta.reconfig_runtime(remote)
+
+    # Program the FPGA with a pre-compiled VTA bitstream.
+    # You can program the FPGA with your own custom bitstream
+    # by passing the path to the bitstream file instead of None.
+    vta.program_fpga(remote, bitstream=None)
+
+    # Report on reconfiguration time
+    reconfig_time = time.time() - reconfig_start
+    print("Reconfigured FPGA and RPC runtime in {0:.2f}s!".format(reconfig_time))
+
+# In simulation mode, host the RPC server locally.
+elif env.TARGET == "sim":
+    remote = rpc.LocalSession()
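+
+# Optional sanity check (a minimal sketch): allocate a small array on the
+# remote CPU context to confirm that the RPC session is alive.
+probe = tvm.nd.array(np.zeros((2, 2), dtype="float32"), remote.cpu(0))
+assert probe.shape == (2, 2)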
+
+
+######################################################################
+# Build the ResNet Runtime
+# ------------------------
+# Build the ResNet graph runtime, and configure the parameters.
+
+# Set ``device=vta`` to run inference on the FPGA,
+# or ``device=vtacpu`` to run inference on the CPU.
+device = "vta"
+
+# Device context
+ctx = remote.ext_dev(0) if device == "vta" else remote.cpu(0)
+
+# Build the graph runtime
+graph, lib, params = generate_graph(os.path.join(data_dir, graph_fn),
+                                    os.path.join(data_dir, params_fn),
+                                    device)
+m = graph_runtime.create(graph, lib, ctx)
+
+# Set the parameters
+m.set_input(**params)
+
+
+######################################################################
+# Run ResNet-18 inference on a sample image
+# -----------------------------------------
+# Perform image classification on a test image.
+# You can change the test image URL to any image of your choosing.
+
+# Read in test image
+image_url = 'https://homes.cs.washington.edu/~moreau/media/vta/cat.jpg'
+response = requests.get(image_url)
+image = Image.open(BytesIO(response.content)).resize((224, 224))
+# Show Image
+plt.imshow(image)
+plt.show()
+# Set the input
+image = process_image(image)
+m.set_input('data', image)
+
+# Perform inference
+timer = m.module.time_evaluator("run", ctx, number=1)
+tcost = timer()
+
+# Get classification results
+tvm_output = m.get_output(0, tvm.nd.empty((1000,), "float32", remote.cpu(0)))
+top_categories = np.argsort(tvm_output.asnumpy())
+
+# Report top-5 classification results
+print("ResNet-18 Prediction #1:", synset[top_categories[-1]])
+print("                     #2:", synset[top_categories[-2]])
+print("                     #3:", synset[top_categories[-3]])
+print("                     #4:", synset[top_categories[-4]])
+print("                     #5:", synset[top_categories[-5]])
+print("Performed inference in {0:.2f}s".format(tcost.mean))
+
+
+######################################################################
+# Run a Youtube Video Image Classifier
+# ------------------------------------
+# Perform image classification on a video stream, one frame out of every 48.
+# Change ``if False:`` to ``if True:`` below to run the demo.
+
+# Early-exit guard: set to True to run the demo
+if False:
+
+    import cv2
+    import pafy
+    from IPython.display import clear_output
+
+    # Set the figure size to 16x16 inches
+    plt.rcParams['figure.figsize'] = [16, 16]
+
+    # Stream the video in
+    url = "https://www.youtube.com/watch?v=PJlmYh27MHg&t=2s"
+    video = pafy.new(url)
+    best = video.getbest(preftype="mp4")
+    cap = cv2.VideoCapture(best.url)
+
+    # Process one frame out of every 48 for variety
+    count = 0
+    guess = ""
+    while(count<2400):
+
+        # Capture frame-by-frame; stop if the stream has ended
+        ret, frame = cap.read()
+        if not ret:
+            break
+
+        # Process one every 48 frames
+        if count % 48 == 1:
+            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+            frame = Image.fromarray(frame)
+            # Crop and resize
+            thumb = np.array(thumbnailify(frame))
+            image = process_image(thumb)
+            guess = classify(m, image)
+
+            # Overlay the guess on a black banner at the top of the frame
+            frame = cv2.rectangle(thumb, (0, 0), (200, 0), (0, 0, 0), 50)
+            cv2.putText(frame, guess, (5, 15), cv2.FONT_HERSHEY_SIMPLEX,
+                        0.5, (255, 255, 255), 1, cv2.LINE_AA)
+
+            plt.imshow(thumb)
+            plt.axis('off')
+            plt.show()
+            if cv2.waitKey(1) & 0xFF == ord('q'):
+                break
+            clear_output(wait=True)
+
+        count += 1
+
+    # When everything is done, release the capture
+    cap.release()
+    cv2.destroyAllWindows()
\ No newline at end of file
-- 
GitLab