Skip to content
Snippets Groups Projects
Commit a81ebd90 authored by Siva's avatar Siva Committed by Tianqi Chen
Browse files

[NNVM][FRONTEND] Tensorflow frontend support (#1188)

parent 7afeab07
No related branches found
No related tags found
No related merge requests found
......@@ -270,6 +270,10 @@ def build(graph, target=None, shape=None, dtype="float32",
# Apply optimization
with target:
graph = optimize(graph, shape, dtype, layout)
# Clear extra params without nodes.
_remove_noref_params(params, graph)
# Precompute prune
if params and cfg.pass_enabled("PrecomputePrune"):
graph, params = precompute_prune(graph, params)
......@@ -296,6 +300,24 @@ def build(graph, target=None, shape=None, dtype="float32",
params.update(init_var)
return graph, libmod, params
def _remove_noref_params(params, graph):
""" Helper to clear non referenced params
Parameters
----------
graph : Graph
The input graph
params: dict of str to ndarray
The parameter dictionary
"""
arg_list = set(graph.symbol.list_input_names())
if params:
param_keys = list(params.keys())
for key in param_keys:
if key not in arg_list:
params.pop(key)
def _run_graph(graph, params):
"""Helper utility to build and run and get outputs, only use cpu mode.
......
......@@ -5,3 +5,4 @@ from .onnx import from_onnx
from .coreml import from_coreml
from .keras import from_keras
from .darknet import from_darknet
from .tensorflow import from_tensorflow
This diff is collapsed.
# pylint: disable=invalid-name, unused-variable, unused-argument, no-init
"""
Tensorflow Model Helpers
========================
Some helper definitions for tensorflow models.
"""
import re
import os.path
import numpy as np
# Tensorflow imports
import tensorflow as tf
from tensorflow.core.framework import graph_pb2
######################################################################
# Some helper functions
# ---------------------
def ProcessGraphDefParam(graph_def):
"""Type-checks and possibly canonicalizes `graph_def`.
Parameters
----------
graph_def : Obj
tensorflow graph definition.
Returns
-------
graph_def : Obj
tensorflow graph devinition
"""
if not isinstance(graph_def, graph_pb2.GraphDef):
# `graph_def` could be a dynamically-created message, so try a duck-typed
# approach
try:
old_graph_def = graph_def
graph_def = graph_pb2.GraphDef()
graph_def.MergeFrom(old_graph_def)
except TypeError:
raise TypeError('graph_def must be a GraphDef proto.')
return graph_def
class NodeLookup(object):
"""Converts integer node ID's to human readable labels."""
def __init__(self,
label_lookup_path=None,
uid_lookup_path=None):
self.node_lookup = self.load(label_lookup_path, uid_lookup_path)
def load(self, label_lookup_path, uid_lookup_path):
"""Loads a human readable English name for each softmax node.
Parameters
----------
label_lookup_path: String
File containing String UID to integer node ID mapping .
uid_lookup_path: String
File containing String UID to human-readable string mapping.
Returns
-------
node_id_to_name: dict
dict from integer node ID to human-readable string.
"""
if not tf.gfile.Exists(uid_lookup_path):
tf.logging.fatal('File does not exist %s', uid_lookup_path)
if not tf.gfile.Exists(label_lookup_path):
tf.logging.fatal('File does not exist %s', label_lookup_path)
# Loads mapping from string UID to human-readable string
proto_as_ascii_lines = tf.gfile.GFile(uid_lookup_path).readlines()
uid_to_human = {}
p = re.compile(r'[n\d]*[ \S,]*')
for line in proto_as_ascii_lines:
parsed_items = p.findall(line)
uid = parsed_items[0]
human_string = parsed_items[2]
uid_to_human[uid] = human_string
# Loads mapping from string UID to integer node ID.
node_id_to_uid = {}
proto_as_ascii = tf.gfile.GFile(label_lookup_path).readlines()
for line in proto_as_ascii:
if line.startswith(' target_class:'):
target_class = int(line.split(': ')[1])
if line.startswith(' target_class_string:'):
target_class_string = line.split(': ')[1]
node_id_to_uid[target_class] = target_class_string[1:-2]
# Loads the final mapping of integer node ID to human-readable string
node_id_to_name = {}
for key, val in node_id_to_uid.items():
if val not in uid_to_human:
tf.logging.fatal('Failed to locate: %s', val)
name = uid_to_human[val]
node_id_to_name[key] = name
return node_id_to_name
def id_to_string(self, node_id):
if node_id not in self.node_lookup:
return ''
return self.node_lookup[node_id]
def read_normalized_tensor_from_image_file(file_name,
input_height=299,
input_width=299,
input_mean=0,
input_std=255):
""" Preprocessing of image
Parameters
----------
file_name: String
Image filename.
input_height: int
model input height.
input_width: int
model input width
input_mean: int
Mean to be substracted in normalization.
input_std: int
Standard deviation used in normalization.
Returns
-------
np_array: Numpy array
Normalized image data as a numpy array.
"""
input_name = "file_reader"
output_name = "normalized"
file_reader = tf.read_file(file_name, input_name)
image_reader = tf.image.decode_jpeg(file_reader, channels=3,
name='jpeg_reader')
float_caster = tf.cast(image_reader, tf.float32)
dims_expander = tf.expand_dims(float_caster, 0)
resized = tf.image.resize_bilinear(dims_expander, [input_height, input_width])
normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std])
tf.InteractiveSession()
np_array = normalized.eval()
return np_array
def get_workload_inception_v3():
""" Import Inception V3 workload from frozen protobuf
Parameters
----------
Nothing.
Returns
-------
(normalized, graph_def) : Tuple
normalized is normalized input for graph testing.
graph_def is the tensorflow workload for Inception V3.
"""
repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV3/'
model_name = 'inception_v3_2016_08_28_frozen-with_shapes.pb'
model_url = os.path.join(repo_base, model_name)
image_name = 'elephant-299.jpg'
image_url = os.path.join(repo_base, image_name)
from mxnet.gluon.utils import download
download(model_url, model_name)
download(image_url, image_name)
normalized = read_normalized_tensor_from_image_file(os.path.join("./", image_name))
# Creates graph from saved graph_def.pb.
with tf.gfile.FastGFile(os.path.join("./", model_name), 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name='')
return (normalized, graph_def)
def get_workload_inception_v1():
""" Import Inception V1 workload from frozen protobuf
Parameters
----------
Nothing.
Returns
-------
(image_data, tvm_data, graph_def) : Tuple
image_data is raw encoded image data for TF input.
tvm_data is the decoded image data for TVM input.
graph_def is the tensorflow workload for Inception V1.
"""
repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/'
model_name = 'classify_image_graph_def-with_shapes.pb'
model_url = os.path.join(repo_base, model_name)
image_name = 'elephant-299.jpg'
image_url = os.path.join(repo_base, image_name)
from mxnet.gluon.utils import download
download(model_url, model_name)
download(image_url, image_name)
if not tf.gfile.Exists(os.path.join("./", image_name)):
tf.logging.fatal('File does not exist %s', image)
image_data = tf.gfile.FastGFile(os.path.join("./", image_name), 'rb').read()
# TVM doesn't handle decode, hence decode it.
from PIL import Image
tvm_data = Image.open(os.path.join("./", image_name)).resize((299, 299))
tvm_data = np.array(tvm_data)
# Creates graph from saved graph_def.pb.
with tf.gfile.FastGFile(os.path.join("./", model_name), 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name='')
return (image_data, tvm_data, graph_def)
......@@ -52,6 +52,15 @@ reg.register_schedule("_assign", _fschedule_broadcast)
reg.register_pattern("copy", OpPattern.ELEMWISE)
reg.register_schedule("copy", _fschedule_broadcast)
# cast
@reg.register_compute("cast")
def compute_cast(attrs, inputs, _):
"""Compute definition of cast"""
dtype = attrs.get_string("dtype")
return topi.cast(inputs[0], dtype)
reg.register_pattern("cast", OpPattern.ELEMWISE)
reg.register_schedule("cast", _fschedule_broadcast)
# exp
reg.register_pattern("exp", OpPattern.ELEMWISE)
reg.register_schedule("exp", _fschedule_broadcast)
......
# pylint: disable=import-self, invalid-name, unused-argument
"""
Tensorflow testcases
====================
This article is a test script to test tensorflow operator with NNVM.
"""
from __future__ import print_function
import numpy as np
import nnvm.compiler
import tvm
import tensorflow as tf
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.core.framework import graph_pb2
import nnvm.testing.tf
#######################################################################
# Generic run functions for TVM & tensorflow
# ------------------------------------------
def run_tvm_graph(graph_def, input_data, input_node, output_shape, output_dtype):
""" Generic function to compile on nnvm and execute on tvm """
sym, params = nnvm.frontend.from_tensorflow(graph_def)
target = 'llvm'
if isinstance(input_data, list):
shape_dict = {}
dtype_dict = {}
for i, e in enumerate(input_node):
shape_dict[e] = input_data[i].shape
dtype_dict[e] = input_data[i].dtype
else:
shape_dict = {input_node: input_data.shape}
dtype_dict = {input_node: input_data.dtype}
graph, lib, params = nnvm.compiler.build(sym, target, shape_dict,
dtype=dtype_dict, params=params)
ctx = tvm.cpu(0)
from tvm.contrib import graph_runtime
m = graph_runtime.create(graph, lib, ctx)
# set inputs
if isinstance(input_data, list):
for i, e in enumerate(input_node):
m.set_input(e, tvm.nd.array(input_data[i].astype(input_data[i].dtype)))
else:
m.set_input(input_node, tvm.nd.array(input_data.astype(input_data.dtype)))
m.set_input(**params)
# execute
m.run()
# get outputs
tvm_output = m.get_output(0, tvm.nd.empty((output_shape), output_dtype))
return tvm_output.asnumpy()
def run_tf_graph(sess, input_data, input_node, output_node):
""" Generic function to execute tensorflow """
tensor = sess.graph.get_tensor_by_name(output_node)
if isinstance(input_data, list):
input_dict = {}
for i, e in enumerate(input_node):
input_dict[e] = input_data[i]
else:
input_dict = {input_node: input_data}
output_data = sess.run(tensor, input_dict)
return output_data
#######################################################################
# Pooling
# -------
def _test_pooling(input_shape, **kwargs):
""" One iteration of pool operation with given shapes and attributes """
x = -np.arange(
np.prod(input_shape), dtype=np.float32).reshape(input_shape) - 1
with tf.Graph().as_default():
in_data = constant_op.constant(x, shape=input_shape, dtype='float32')
# pylint: disable=unused-variable
pool = nn_ops.pool(in_data, **kwargs)
# pylint: enable=unused-variable
if kwargs['pooling_type'] == 'MAX':
out_node = 'max_pool'
out_name = 'max_pool:0'
else:
out_node = 'avg_pool'
out_name = 'avg_pool:0'
with tf.Session() as sess:
graph_def = tf.graph_util.convert_variables_to_constants(
sess,
sess.graph.as_graph_def(add_shapes=True),
[out_node],
)
tf_output = run_tf_graph(sess, x, 'Const:0', out_name)
tvm_output = run_tvm_graph(graph_def, x.astype('float32'),
"Const", tf_output.shape, 'float32')
np.testing.assert_allclose(tf_output, tvm_output, atol=1e-3, rtol=1e-3)
sess.close()
def test_forward_pooling():
""" Pooling """
_test_pooling(input_shape=[2, 9, 10, 2],
window_shape=[1, 1],
padding='SAME',
pooling_type='MAX',
dilation_rate=[1, 1],
strides=[1, 1])
_test_pooling(input_shape=[2, 9, 10, 2],
window_shape=[1, 1],
padding='SAME',
pooling_type='AVG',
dilation_rate=[1, 1],
strides=[1, 1])
_test_pooling(input_shape=[2, 10, 9, 2],
window_shape=[1, 1],
padding='SAME',
pooling_type='MAX',
dilation_rate=[1, 1],
strides=[1, 1])
_test_pooling(input_shape=[2, 10, 9, 2],
window_shape=[1, 1],
padding='SAME',
pooling_type='AVG',
dilation_rate=[1, 1],
strides=[1, 1])
#######################################################################
# Convolution
# -----------
def _test_convolution(tensor_in_sizes, filter_in_sizes,
dilations, strides, padding, data_format):
""" One iteration of convolution with given shapes and attributes """
total_size_1 = 1
total_size_2 = 1
for s in tensor_in_sizes:
total_size_1 *= s
for s in filter_in_sizes:
total_size_2 *= s
# Initializes the input tensor with array containing incrementing
# numbers from 1.
data_array = [f * 1.0 for f in range(1, total_size_1 + 1)]
filter_array = [f * 1.0 for f in range(1, total_size_2 + 1)]
with tf.Graph().as_default():
in_data = constant_op.constant(data_array, shape=tensor_in_sizes, dtype='float32')
in_filter = constant_op.constant(filter_array, shape=filter_in_sizes, dtype='float32')
strides = [1] + strides + [1]
dilations = [1] + dilations + [1]
# pylint: disable=unused-variable
conv = nn_ops.conv2d(in_data,
in_filter,
strides=strides,
padding=padding,
data_format=data_format)
# pylint: enable=unused-variable
with tf.Session() as sess:
graph_def = tf.graph_util.convert_variables_to_constants(
sess,
sess.graph.as_graph_def(add_shapes=True),
['Conv2D'],
)
tf_output = run_tf_graph(sess, np.reshape(data_array, tensor_in_sizes),
'Const:0', 'Conv2D:0')
tvm_output = run_tvm_graph(graph_def,
np.reshape(data_array, tensor_in_sizes).astype('float32'),
"Const", tf_output.shape, 'float32')
np.testing.assert_allclose(tf_output, tvm_output, atol=1e-3, rtol=1e-3)
sess.close()
def test_forward_convolution():
_test_convolution([4, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', 'NHWC')
_test_convolution([4, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', 'NHWC')
_test_convolution([4, 17, 17, 124], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME', 'NHWC')
_test_convolution([4, 17, 17, 12], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID', 'NHWC')
#######################################################################
# Reshape
# -------
def _test_reshape(data, out_shape):
""" One iteration of reshape operation with given data and out shape """
with tf.Graph().as_default():
in_data = constant_op.constant(data, shape=data.shape, dtype=data.dtype)
# pylint: disable=unused-variable
reshape_out = array_ops.reshape(in_data, out_shape)
# pylint: enable=unused-variable
with tf.Session() as sess:
graph_def = tf.graph_util.convert_variables_to_constants(
sess,
sess.graph.as_graph_def(add_shapes=True),
['Reshape'],
)
tf_output = run_tf_graph(sess, data,
'Const:0', 'Reshape:0')
tvm_output = run_tvm_graph(graph_def,
data,
"Const", tf_output.shape, data.dtype)
np.testing.assert_allclose(tf_output, tvm_output)
sess.close()
def test_forward_reshape():
_test_reshape(np.arange(6.0), [2, 3])
_test_reshape(np.arange(6), [-1, 2])
_test_reshape(np.arange(6), [3, -1])
_test_reshape(np.arange(6), [-1])
#######################################################################
# Squeeze
# -------
def _test_squeeze(data, squeeze_dims=None):
""" One iteration of squeeze """
if squeeze_dims is None:
squeeze_dims = []
with tf.Graph().as_default():
in_data = constant_op.constant(data, shape=data.shape, dtype=data.dtype)
# pylint: disable=unused-variable
if squeeze_dims:
squeeze_out = array_ops.squeeze(in_data, squeeze_dims)
else:
squeeze_out = array_ops.squeeze(in_data)
# pylint: enable=unused-variable
with tf.Session() as sess:
graph_def = tf.graph_util.convert_variables_to_constants(
sess,
sess.graph.as_graph_def(add_shapes=True),
['Squeeze'],
)
tf_output = run_tf_graph(sess, data,
'Const:0', 'Squeeze:0')
tvm_output = run_tvm_graph(graph_def,
data,
"Const", tf_output.shape, data.dtype)
np.testing.assert_allclose(tf_output, tvm_output)
sess.close()
def test_forward_squeeze():
""" Squeeze """
# Nothing to squeeze.
_test_squeeze(np.arange(2).reshape((2)))
_test_squeeze(np.arange(6).reshape((2, 3)))
# Squeeze the middle element away.
_test_squeeze(np.arange(4).reshape((2, 1, 2)))
# Squeeze on both ends.
_test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)))
# Positive squeeze dim index.
_test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [0])
_test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [2, 4])
_test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [0, 4, 2])
# Negative squeeze dim index.
_test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [-1])
_test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [-3, -5])
_test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [-3, -5, -1])
#######################################################################
# ConcatV2
# --------
def _test_concat_v2(data, dim):
""" One iteration of ConcatV2 """
with tf.Graph().as_default():
# pylint: disable=unused-variable
concat_out = gen_array_ops._concat_v2(data, dim)
# pylint: enable=unused-variable
with tf.Session() as sess:
graph_def = tf.graph_util.convert_variables_to_constants(
sess,
sess.graph.as_graph_def(add_shapes=True),
['ConcatV2'],
)
tf_output = run_tf_graph(sess, data,
['ConcatV2/values_0:0', 'ConcatV2/values_1:0'], 'ConcatV2:0')
tvm_output = run_tvm_graph(graph_def,
data,
["ConcatV2/values_0", 'ConcatV2/values_1'],
tf_output.shape, tf_output.dtype)
np.testing.assert_allclose(tf_output, tvm_output)
sess.close()
def _test_forward_concat_v2():
t1 = np.array([])
t2 = np.array([])
test_concat_v2([t1, t2], 0)
t1 = np.array([[1, 2, 3], [4, 5, 6]])
t2 = np.array([[7, 8, 9], [10, 11, 12]])
_test_concat_v2([t1, t2], 1)
#######################################################################
# Multi Input to graph
# --------------------
def test_forward_multi_input():
with tf.Graph().as_default():
in1 = tf.placeholder(tf.int32, shape=[3, 3], name='in1')
in2 = tf.placeholder(tf.int32, shape=[3, 3], name='in2')
in3 = tf.placeholder(tf.int32, shape=[3, 3], name='in3')
in4 = tf.placeholder(tf.int32, shape=[3, 3], name='in4')
out1 = tf.add(in1, in2, name='out1')
out2 = tf.subtract(in3, in4, name='out2')
out = tf.multiply(out1, out2, name='out')
with tf.Session() as sess:
graph_def = tf.graph_util.convert_variables_to_constants(
sess,
sess.graph.as_graph_def(add_shapes=True),
['out'],
)
in_data = np.arange(9, dtype='int32').reshape([3, 3])
tf_output = run_tf_graph(sess, [in_data, in_data, in_data, in_data ],
['in1:0', 'in2:0', 'in3:0', 'in4:0'], 'out:0')
tvm_output = run_tvm_graph(graph_def,
[in_data, in_data, in_data, in_data ],
['in1', 'in2', 'in3', 'in4'],
tf_output.shape, tf_output.dtype)
np.testing.assert_allclose(tf_output, tvm_output)
sess.close()
#######################################################################
# Inception V3
# ------------
def test_forward_inception_v3():
'''test inception V3 model'''
with tf.Graph().as_default():
(data, graph_def) = nnvm.testing.tf.get_workload_inception_v3()
# Call the utility to import the graph definition into default graph.
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
tvm_output = run_tvm_graph(graph_def, data, 'input', (1, 1001), 'float32')
with tf.Session() as sess:
tf_output = run_tf_graph(sess, data, 'input:0', 'InceptionV3/Predictions/Reshape_1:0')
top_tvm = np.squeeze(tvm_output).argsort()[-3:][::-1]
top_tf = np.squeeze(tf_output).argsort()[-3:][::-1]
# TVM implementation of SAME padding some times make a slight deviation.
# Hence check for top predictions.
top_tvm = np.sort(top_tvm)
top_tf = np.sort(top_tf)
np.testing.assert_allclose(top_tf, top_tvm)
#######################################################################
# Inception V1
# ------------
def test_forward_inception_v1():
'''test inception V1 model'''
with tf.Graph().as_default():
(data, tvm_data, graph_def) = nnvm.testing.tf.get_workload_inception_v1()
# Call the utility to import the graph definition into default graph.
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
tvm_output = run_tvm_graph(graph_def, tvm_data, 'DecodeJpeg/contents', (1, 1008), 'float32')
with tf.Session() as sess:
tf_output = run_tf_graph(sess, data, 'DecodeJpeg/contents:0', 'softmax:0')
np.testing.assert_allclose(tf_output, tvm_output, rtol=2e-2, atol=2e-2)
#######################################################################
# Main
# ----
if __name__ == '__main__':
test_forward_convolution()
test_forward_pooling()
test_forward_reshape()
test_forward_squeeze()
if tf.__version__ == '1.4.1':
_test_forward_concat_v2()
test_forward_multi_input()
test_forward_inception_v3()
test_forward_inception_v1()
......@@ -18,3 +18,6 @@ python3 -m nose -v nnvm/tests/python/frontend/mxnet || exit -1
echo "Running Keras frontend test..."
python3 -m nose -v nnvm/tests/python/frontend/keras || exit -1
echo "Running Tensorflow frontend test..."
python3 -m nose -v nnvm/tests/python/frontend/tensorflow || exit -1
"""
Compile Tensorflow Models
=========================
This article is an introductory tutorial to deploy tensorflow models with NNVM.
For us to begin with, tensorflow module is required to be installed.
A quick solution is to install tensorlfow from
https://www.tensorflow.org/install/install_sources
"""
import nnvm
import tvm
import numpy as np
import os.path
# Tensorflow imports
import tensorflow as tf
from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_util
import nnvm.testing.tf
repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/'
img_name = 'elephant-299.jpg'
image_url = os.path.join(repo_base, img_name)
model_name = 'classify_image_graph_def-with_shapes.pb'
model_url = os.path.join(repo_base, model_name)
map_proto = 'imagenet_2012_challenge_label_map_proto.pbtxt'
map_proto_url = os.path.join(repo_base, map_proto)
lable_map = 'imagenet_synset_to_human_label_map.txt'
lable_map_url = os.path.join(repo_base, lable_map)
######################################################################
# Download processed tensorflow model
# -----------------------------------
# In this section, we download a pretrained Tensorflow model and classify an image.
from mxnet.gluon.utils import download
download(image_url, img_name)
download(model_url, model_name)
download(map_proto_url, map_proto)
download(lable_map_url, lable_map)
######################################################################
# Creates graph from saved graph_def.pb.
# --------------------------------------
with tf.gfile.FastGFile(os.path.join(
"./", model_name), 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name='')
# Call the utility to import the graph definition into default graph.
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
######################################################################
# Decode image
# ------------
from PIL import Image
image = Image.open(img_name).resize((299, 299))
def transform_image(image):
image = np.array(image)
return image
x = transform_image(image)
######################################################################
# Import the graph to NNVM
# ------------------------
sym, params = nnvm.frontend.from_tensorflow(graph_def)
######################################################################
# Now compile the graph through NNVM
import nnvm.compiler
target = 'llvm'
shape_dict = {'DecodeJpeg/contents': x.shape}
dtype_dict = {'DecodeJpeg/contents': 'uint8'}
graph, lib, params = nnvm.compiler.build(sym, target, shape_dict, dtype=dtype_dict, params=params)
######################################################################
# Execute the portable graph on TVM
# ---------------------------------
# Now, we would like to reproduce the same forward computation using TVM.
from tvm.contrib import graph_runtime
ctx = tvm.cpu(0)
dtype = 'uint8'
m = graph_runtime.create(graph, lib, ctx)
# set inputs
m.set_input('DecodeJpeg/contents', tvm.nd.array(x.astype(dtype)))
m.set_input(**params)
# execute
m.run()
# get outputs
tvm_output = m.get_output(0, tvm.nd.empty(((1, 1008)), 'float32'))
######################################################################
# Process the output to human readable
# ------------------------------------
predictions = tvm_output.asnumpy()
predictions = np.squeeze(predictions)
# Creates node ID --> English string lookup.
node_lookup = nnvm.testing.tf.NodeLookup(label_lookup_path=os.path.join("./", map_proto),
uid_lookup_path=os.path.join("./", lable_map))
top_k = predictions.argsort()[-5:][::-1]
for node_id in top_k:
human_string = node_lookup.id_to_string(node_id)
score = predictions[node_id]
print('%s (score = %.5f)' % (human_string, score))
######################################################################
# Run the same graph with tensorflow and dump output.
# ---------------------------------------------------
def create_graph():
"""Creates a graph from saved GraphDef file and returns a saver."""
# Creates graph from saved graph_def.pb.
with tf.gfile.FastGFile(model_name, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name='')
# Call the utility to import the graph definition into default graph.
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
def run_inference_on_image(image):
"""Runs inference on an image.
Parameters
----------
image: String
Image file name.
Returns
-------
Nothing
"""
if not tf.gfile.Exists(image):
tf.logging.fatal('File does not exist %s', image)
image_data = tf.gfile.FastGFile(image, 'rb').read()
# Creates graph from saved GraphDef.
create_graph()
with tf.Session() as sess:
softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
predictions = sess.run(softmax_tensor,
{'DecodeJpeg/contents:0': image_data})
predictions = np.squeeze(predictions)
# Creates node ID --> English string lookup.
node_lookup = nnvm.testing.tf.NodeLookup(label_lookup_path=os.path.join("./", map_proto),
uid_lookup_path=os.path.join("./", lable_map))
top_k = predictions.argsort()[-5:][::-1]
print ("===== TENSORFLOW RESULTS =======")
for node_id in top_k:
human_string = node_lookup.id_to_string(node_id)
score = predictions[node_id]
print('%s (score = %.5f)' % (human_string, score))
run_inference_on_image (img_name)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment