diff --git a/vta/docs/how_to/install.md b/vta/docs/how_to/install.md
index e2ccf63e90801da093e195a0f6307267beffb4ca..816dd6e307ea138fa4cc0743f96886e754b93c6e 100644
--- a/vta/docs/how_to/install.md
+++ b/vta/docs/how_to/install.md
@@ -59,31 +59,7 @@ In the 'config.mk' file, make sure that:
 
-For the *Python Package Installation*, we recommend updating your `~/.bashrc` file to extend your `PYTHONPATH` with the TVM Python libraries.
+For the *Python Package Installation*, we recommend updating your `~/.bashrc` file to extend your `PYTHONPATH` with the TVM, TOPI, and NNVM Python libraries.
 ```bash
-export PYTHONPATH=<tvm root>/python:<tvm root>/topi/python:${PYTHONPATH}
-```
-
-#### NNVM Installation
-
-Clone the NNVM repository from `tqchen` in the directory of your choosing:
-```bash
-git clone git@github.com:tqchen/nnvm.git --recursive
-```
-
-To run this example, we rely on a special branch of NNVM `qt`:
-```bash
-cd <nnvm root>
-git checkout qt
-```
-
-Launch the compilation, this takes about a minute on two threads.
-```bash
-cd <nnvm root>
-make -j2
-```
-
-Finally update your `~/.bashrc` file to include the NNVM python libraries in your `PYTHONPATH`:
-```bash
-export PYTHONPATH=<nnvm root>/python:${PYTHONPATH}
+export PYTHONPATH=<tvm root>/python:<tvm root>/topi/python:<tvm root>/nnvm/python:${PYTHONPATH}
 ```
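+
+To check that the paths resolve correctly (a minimal sanity check, assuming the TVM build completed), open a fresh shell and run:
+```bash
+python -c "import tvm; import topi; import nnvm"
+```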
 
 #### MxNet Installation
@@ -236,7 +212,7 @@ This time again, we will run the 2D convolution testbench. But beforehand, we'll
 * Runtime building on the Pynq, which needs to be run every time the `config.json` configuration is modified. This ensures that the VTA software runtime, which generates the accelerator's executable via just-in-time (JIT) compilation, matches the specifications of the VTA design programmed on the FPGA. The build process takes about 30 seconds to complete.
 
 ```bash
-python tests/python/pynq/test_program_rpc.py 
+python tests/python/pynq/test_program_rpc.py
 ```
 
 > Tip: You can track progress of the FPGA programming and the runtime rebuilding steps by looking at the RPC server's logging messages in your Pynq `ssh` session.
@@ -244,7 +220,7 @@ python tests/python/pynq/test_program_rpc.py
 We are now ready to run the 2D convolution testbench for the ResNet-18 workload in hardware.
 
 ```bash
-python tests/python/pynq/test_benchmark_conv2d.py 
+python tests/python/pynq/test_benchmark_conv2d.py
 ```
 
 The performance metrics measured on the Pynq board will be reported for each convolutional layer.
@@ -280,7 +256,7 @@ You’ll need to install Xilinx’ FPGA compilation toolchain, [Vivado HL WebPAC
 ```bash
 chmod u+x Xilinx_Vivado_SDK_2017.1_0415_1_Lin64.bin
 ```
-5. Now you can execute the binary: 
+5. Now you can execute the binary:
 ```bash
 ./Xilinx_Vivado_SDK_2017.1_0415_1_Lin64.bin
 ```
@@ -337,7 +313,7 @@ If you just want to generate the HLS-based VTA IP cores without launching the en
 make ip
 ```
 You'll be able to view the HLS synthesis reports under `<vta root>/build/hardware/xilinx/hls/<configuration>/<block>/solution0/syn/report/<block>_csynth.rpt`
-> Note: The `<configuration>` name is a string that summarizes the VTA configuration parameters specified in the `config.json`. The `<block>` name refers to the specific module in the VTA pipeline. 
+> Note: The `<configuration>` name is a string that summarizes the VTA configuration parameters specified in the `config.json`. The `<block>` name refers to the specific module in the VTA pipeline.
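+> For example, with a hypothetical configuration string such as `1x16_i8w8a32_15_15_18_17` and the `fetch` block, the report would be found at `<vta root>/build/hardware/xilinx/hls/1x16_i8w8a32_15_15_18_17/fetch/solution0/syn/report/fetch_csynth.rpt`.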
 
 Finally, to run the full hardware compilation and generate the bitstream, run:
 
diff --git a/vta/examples/resnet18/pynq/imagenet_predict.py b/vta/examples/resnet18/pynq/imagenet_predict.py
index 5fce83231c5c5cf31d2450b664cae2660b072bf6..f22ec0a8e9771966c5d464cbf2d6ee813da3971e 100644
--- a/vta/examples/resnet18/pynq/imagenet_predict.py
+++ b/vta/examples/resnet18/pynq/imagenet_predict.py
@@ -26,8 +26,8 @@ data_dir = "_data/"
 url = "https://homes.cs.washington.edu/~moreau/media/vta/"
 TEST_FILE = 'cat.jpg'
 CATEG_FILE = 'synset.txt'
-RESNET_GRAPH_FILE = 'quantize_graph.json'
-RESNET_PARAMS_FILE = 'quantize_params.pkl'
+RESNET_GRAPH_FILE = 'resnet18_qt8.json'
+RESNET_PARAMS_FILE = 'resnet18_qt8_params.pkl'
 # Create data dir
 if not os.path.exists(data_dir):
     os.makedirs(data_dir)
@@ -70,7 +70,7 @@ def mark_nop(graph, conv_layer=-1, skip_conv_layer=()):
         attrs = node["attrs"]
         node_name = node["name"]
         func_name = attrs["func_name"]
-        if func_name.find("quantized_conv2d") != -1:
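+        # fused kernels embed the op name, so match any conv2d variant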
+        if func_name.find("conv2d") != -1:
             if conv_layer >= 0:
                 if counter != conv_layer:
                     attrs["func_name"] = "__nop"
@@ -109,9 +109,9 @@ graph_attr.set_dtype_inputs(sym, dtype_dict)
 graph = graph.apply("InferShape").apply("InferType")
 
 dtype = "float32"
-sym = vta.graph.remove_stochastic(sym)
 sym = vta.graph.clean_cast(sym)
 sym = vta.graph.clean_conv_fuse(sym)
+
 if target.device_name == "vta":
     sym = vta.graph.pack(sym, shape_dict, bfactor, cfactor)
 
@@ -166,8 +166,10 @@ def run_e2e(graph):
     # get outputs
     tvm_output = m.get_output(
         0,tvm.nd.empty((1000,), dtype, remote.cpu(0)))
-    top1 = np.argmax(tvm_output.asnumpy())
-    print('TVM prediction top-1:', top1, synset[top1])
+
+    # report the top-5 predictions (argsort is ascending, so reverse it)
+    top = np.argsort(tvm_output.asnumpy())[::-1]
+    for i in range(5):
+        print('TVM prediction top-%d: %s' % (i + 1, synset[top[i]]))
     print("t-cost=%g" % tcost.mean)
 
 
diff --git a/vta/python/vta/graph.py b/vta/python/vta/graph.py
index 41a38c2b655b83ea3779d142a9bb12592312af0d..7f2a26fdc4bfa75540998ae23e7fa4a0c3024ea3 100644
--- a/vta/python/vta/graph.py
+++ b/vta/python/vta/graph.py
@@ -71,48 +71,6 @@ def _get_shape(sym, shape_dict):
     return graph_util.infer_shape(
         nnvm.graph.create(sym), **shape_dict)[1][0]
 
-
-def remove_stochastic(graph):
-    """
-    Replace stochastic rounding and shift with determinstic version.
-
-    Parameters
-    ----------
-    graph : Graph
-        The input graph
-
-    Returns
-    -------
-    replaced_graph : Graph
-        The final replaced graph.
-    """
-    gidx = graph.index
-    node_map = {}
-
-    for nid, node in enumerate(gidx.nodes):
-        children = [node_map[e[0]] for e in node["inputs"]]
-        attrs = node.get("attrs", {})
-        node_name = node["name"]
-        op_name = node["op"]
-        get_clone = lambda c, o_n, n_n, a: getattr(nnvm.symbol, o_n)(
-            *c, name=n_n, **a)
-        if op_name == "null":
-            new_node = nnvm.symbol.Variable(node_name)
-        elif op_name == "stochastic_round":
-            new_node = children[0]
-        elif op_name == "noise_lshift":
-            new_node = nnvm.symbol.left_shift(
-                children[0], **attrs)
-        else:
-            new_node = get_clone(children, op_name, node_name, attrs)
-        node_map[nid] = new_node
-
-    assert len(graph.index.output_entries) == 1
-    ret = node_map[graph.index.output_entries[0][0]]
-    ret = nnvm.graph.create(ret)
-    return ret
-
-
 def clean_conv_fuse(graph):
     """Cleanup the convolution's later fuse stages
 
@@ -131,8 +89,8 @@ def clean_conv_fuse(graph):
         if flag:
             node = nnvm.symbol.clip(node, a_max=127, a_min=-127)
             node = nnvm.symbol.cast(node, dtype="int8")
-            # Use identity as a hint to block conv2d schedules
-            node = nnvm.symbol.identity(node)
+            # Use copy as a hint to block conv2d schedules
+            node = nnvm.symbol.copy(node)
             flag = False
         return node, flag
 
@@ -166,13 +124,13 @@ def clean_conv_fuse(graph):
                 new_entry = (
                     get_clone([children[0][0]], op_name, node_name, attrs),
                     False)
-        elif op_name == "quantized_conv2d":
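+        # quantized conv2d is now expressed as a regular conv2d that accumulates into int32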
+        elif op_name == "conv2d" and attrs["out_dtype"] == "int32":
             data, weight = children
             data = _clean_entry(data)
-            new_node = nnvm.sym.quantized_conv2d(
+            new_node = nnvm.sym.conv2d(
                 data[0], weight[0], name=node_name, **attrs)
             new_entry = (new_node, True)
-        elif op_name in ("left_shift", "right_shift", "relu"):
+        elif op_name in ("__lshift_scalar__", "__rshift_scalar__", "relu"):
             new_entry = (
                 get_clone([children[0][0]], op_name, node_name, attrs),
                 children[0][1])
@@ -199,7 +157,6 @@ def clean_conv_fuse(graph):
     ret = nnvm.graph.create(ret)
     return ret
 
-
 def clean_cast(graph):
     """
     Move the casts to early part of graph,
@@ -232,11 +189,11 @@ def clean_cast(graph):
         elif op_name == "cast":
             dtype = attrs["dtype"]
             new_node, _ = _clean_cast(children[0], dtype)
-        elif op_name == "quantized_conv2d":
+        elif op_name == "conv2d" and attrs["out_dtype"] == "int32":
             data, weight = children
             data, _ = _clean_cast(data, "int8")
             weight, _ = _clean_cast(weight, "int8")
-            new_node = nnvm.sym.quantized_conv2d(
+            new_node = nnvm.sym.conv2d(
                 data, weight, name=node_name, **attrs)
         elif op_name == "elemwise_add":
             lhs, rhs = children
@@ -314,21 +271,21 @@ def pack(graph, shape_dict, bfactor, cfactor, start_name=None):
                     *children, name=node_name, **attrs)
             else:
                 new_node = get_clone(children, op_name, node_name, attrs)
-        elif op_name == "quantized_conv2d":
+        elif op_name == "conv2d" and attrs["out_dtype"] == "int32":
             if start_pack:
-                attrs["pack_batch"] = str(bfactor)
-                attrs["pack_channel"] = str(cfactor)
+                attrs["layout"] = "NCHW%dn%dc" % (bfactor, cfactor)
+                attrs["kernel_layout"] = "OIHW%do%di" % (cfactor, cfactor)
                 data, weight = children
                 weight = _pack_weight(weight, ishape[1], cfactor)
-                new_node = nnvm.sym.quantized_conv2d(
+                new_node = nnvm.sym.conv2d(
                     data, weight, name=node_name, **attrs)
             elif counter == 1:
-                attrs["pack_batch"] = str(bfactor)
-                attrs["pack_channel"] = str(cfactor)
+                attrs["layout"] = "NCHW%dn%dc" % (bfactor, cfactor)
+                attrs["kernel_layout"] = "OIHW%do%di" % (cfactor, cfactor)
                 data, weight = children
                 data = _pack_batch_channel(data, ishape[0], bfactor, cfactor)
                 weight = _pack_weight(weight, ishape[1], cfactor)
-                new_node = nnvm.sym.quantized_conv2d(
+                new_node = nnvm.sym.conv2d(
                     data, weight, name=node_name, **attrs)
                 new_node = _unpack_batch_channel(new_node, oshape)
                 counter = counter + 1
diff --git a/vta/python/vta/top/vta_conv2d.py b/vta/python/vta/top/vta_conv2d.py
index 489acb1ce12d7b6c1312349d5aa93eb1509d6d59..258d2b20bb8582620820b9dfc7753a0390badfb2 100644
--- a/vta/python/vta/top/vta_conv2d.py
+++ b/vta/python/vta/top/vta_conv2d.py
@@ -215,7 +215,7 @@ def _lower(sch, inputs, func_name, graph):
         f, (tvm.container.Array, tuple, list)) else [f]
 
 
-@reg.register_compute("clip", level=11)
+@reg.register_compute("clip", level=15)
 def compute_clip(attrs, inputs, _):
     """ Clip operator.
     """
@@ -231,11 +231,24 @@ def compute_clip(attrs, inputs, _):
             x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB")
     return x
 
+# override to force the fusion partition to break at the copy operator
+reg.register_pattern("copy", OpPattern.INJECTIVE, level=15)
 
-reg.register_pattern("identity", OpPattern.INJECTIVE, level=11)
+def is_packed_layout(layout):
+    """Check if layout is packed layout"""
+    if layout == "NCHW":
+        return False
+    assert "n" in layout
+    assert "c" in layout
+    return True
 
-@reg.register_compute("quantized_conv2d", level=11)
-def compute_quantized_conv2d(attrs, inputs, out):
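+# override alter_op_layout at a higher level; returning None keeps the packed conv2d layout untouched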
+@reg.register_alter_op_layout("conv2d", level=15)
+def alter_conv2d_layout(*_):
+    return None
+
+
+@reg.register_compute("conv2d", level=15)
+def compute_conv2d(attrs, inputs, out):
     """ 2D convolution algorithm.
     """
     padding = attrs.get_int_tuple("padding")
@@ -244,36 +257,30 @@ def compute_quantized_conv2d(attrs, inputs, out):
     groups = attrs.get_int("groups")
     channels = attrs.get_int("channels")
     layout = attrs["layout"]
-    out_dtype = attrs['out_type']
-    cmp_dtype = 'int32' # compute data type
-
-    assert layout == "NCHW", "only support nchw for now"
+    out_dtype = attrs['out_dtype']
     assert dilation == (1, 1), "dilation not supported yet"
     assert attrs.get_bool("use_bias") is False
-    pack_channel = attrs.get_int("pack_channel")
-    if pack_channel != 0:
+    if is_packed_layout(layout):
         assert groups == 1
         return packed_conv2d(inputs[0], inputs[1],
-                             padding, strides)
+                             padding, strides, out_dtype=out_dtype)
     if groups == 1:
-        out = topi.nn.conv2d(inputs[0], inputs[1], strides, padding, out_dtype=cmp_dtype)
+        out = topi.nn.conv2d(inputs[0], inputs[1], strides, padding, out_dtype=out_dtype)
     elif groups == get_const_int(inputs[0].shape[1]) and groups == channels:
         out = topi.nn.depthwise_conv2d_nchw(
-            inputs[0], inputs[1], strides, padding, out_dtype=cmp_dtype)
+            inputs[0], inputs[1], strides, padding, out_dtype=out_dtype)
     else:
         raise ValueError("arbitrary group count not supported yet")
-
-    assert out_dtype == cmp_dtype
     return out
 
 
-@reg.register_schedule("quantized_conv2d", level=11)
+@reg.register_schedule("conv2d", level=15)
 def schedule_quantized_conv2d(attrs, outs, target):
     """ 2D convolution schedule.
     """
-    channels = attrs.get_int("channels")
-    pack_channel = attrs.get_int("pack_channel")
-    if channels != 0 and pack_channel:
+    layout = attrs["layout"]
+
+    if is_packed_layout(layout):
         target = tvm.target.create(target)
         if target.device_name == "vta":
             return schedule_packed_conv2d(outs)
diff --git a/vta/tests/python/integration/test_benchmark_topi_conv2d.py b/vta/tests/python/integration/test_benchmark_topi_conv2d.py
index b70fc9cfbf8535fa3fea684a46318d6d1276b59d..24d9968dc31a41fc246e4c22d92a1b996306643d 100644
--- a/vta/tests/python/integration/test_benchmark_topi_conv2d.py
+++ b/vta/tests/python/integration/test_benchmark_topi_conv2d.py
@@ -38,7 +38,7 @@ def test_vta_conv2d():
         res_conv = vta.top.packed_conv2d(
             data, kernel, padding=(wl.hpad, wl.wpad), strides=(wl.hstride, wl.wstride))
         res = topi.right_shift(res_conv, 8)
-        res = topi.broadcast_add(res, bias)
+        res = topi.add(res, bias)
         res = my_clip(res, 0, 127)
         res = topi.cast(res, "int8")