From dc8fd79c43b4425dcba22b2e10a81f9d11718fbd Mon Sep 17 00:00:00 2001 From: Wuwei Lin <vincentl13x@gmail.com> Date: Tue, 25 Dec 2018 03:16:32 +0800 Subject: [PATCH] [RELAY] Add missing arg in vgg (#2329) --- python/tvm/relay/testing/vgg.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/testing/vgg.py b/python/tvm/relay/testing/vgg.py index 811de33c5..bec141f70 100644 --- a/python/tvm/relay/testing/vgg.py +++ b/python/tvm/relay/testing/vgg.py @@ -98,7 +98,8 @@ def get_workload(batch_size, num_classes=1000, image_shape=(3, 224, 224), dtype="float32", - num_layers=11): + num_layers=11, + batch_norm=False): """Get benchmark workload for VGG nets. Parameters @@ -118,6 +119,9 @@ def get_workload(batch_size, num_layers : int Number of layers for the variant of vgg. Options are 11, 13, 16, 19. + batch_norm : bool + Use batch normalization. + Returns ------- net : nnvm.Symbol @@ -126,5 +130,5 @@ def get_workload(batch_size, params : dict of str to NDArray The parameters. """ - net = get_net(batch_size, image_shape, num_classes, dtype, num_layers) + net = get_net(batch_size, image_shape, num_classes, dtype, num_layers, batch_norm) return create_workload(net) -- GitLab