From dc8fd79c43b4425dcba22b2e10a81f9d11718fbd Mon Sep 17 00:00:00 2001
From: Wuwei Lin <vincentl13x@gmail.com>
Date: Tue, 25 Dec 2018 03:16:32 +0800
Subject: [PATCH] [RELAY] Add missing arg in vgg (#2329)

---
 python/tvm/relay/testing/vgg.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relay/testing/vgg.py b/python/tvm/relay/testing/vgg.py
index 811de33c5..bec141f70 100644
--- a/python/tvm/relay/testing/vgg.py
+++ b/python/tvm/relay/testing/vgg.py
@@ -98,7 +98,8 @@ def get_workload(batch_size,
                  num_classes=1000,
                  image_shape=(3, 224, 224),
                  dtype="float32",
-                 num_layers=11):
+                 num_layers=11,
+                 batch_norm=False):
     """Get benchmark workload for VGG nets.
 
     Parameters
@@ -118,6 +119,9 @@ def get_workload(batch_size,
     num_layers : int
         Number of layers for the variant of vgg. Options are 11, 13, 16, 19.
 
+    batch_norm : bool
+        Use batch normalization.
+
     Returns
     -------
     net : nnvm.Symbol
@@ -126,5 +130,5 @@ def get_workload(batch_size,
     params : dict of str to NDArray
         The parameters.
     """
-    net = get_net(batch_size, image_shape, num_classes, dtype, num_layers)
+    net = get_net(batch_size, image_shape, num_classes, dtype, num_layers, batch_norm)
     return create_workload(net)
-- 
GitLab