diff --git a/nnvm/python/nnvm/frontend/tensorflow.py b/nnvm/python/nnvm/frontend/tensorflow.py
index e7282eb9afd6bd49fc7c95f38086330d82eb92ca..13ed717b045097658eb626ffc6dcc3e5023a127f 100644
--- a/nnvm/python/nnvm/frontend/tensorflow.py
+++ b/nnvm/python/nnvm/frontend/tensorflow.py
@@ -1039,7 +1039,7 @@ class GraphProto(object):
         self._num_param = 0
         self._num_rnn_layer = False
 
-    def from_tensorflow(self, graph, layout="NHWC", shape=None):
+    def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
         """Construct nnvm nodes from tensorflow  graph definition - GraphDef.
 
         Follow the tensorflow graph definition to parse and convert it to NNVM.
@@ -1086,6 +1086,7 @@ class GraphProto(object):
             raise NotImplementedError( \
                 "The following operators are not implemented: {}".format(missing_operators))
 
+        final_op = None
         # Parse the nodes to re-create TF graph using Symbol API of NNVM
         for node in graph.node:
             # Tensorflow doesn't have seperate list for params extraction.
@@ -1165,6 +1166,7 @@ class GraphProto(object):
 
                 # Assuming only one output.
                 self._nodes[node.name] = op
+                final_op = op
 
             # Infer shapes if passed explicitely
             node_output = self._nodes[node.name]
@@ -1175,13 +1177,16 @@ class GraphProto(object):
                 _, out_shapes = graph_util.infer_shape(g, **shape_dict)
                 self._output_shapes[node.name] = out_shapes
 
-        # Assume the final node is the output node
-        out = node_output
+        out = []
+        if outputs is None:
+            out.append(final_op)
+        else:
+            out = [self._nodes[out_name] for out_name in outputs]
 
         #Add the RNN outputs also with 'head' nodes of the nnvm graph
         if self._num_rnn_layer:
             out_rnn = _sym.concatenate(*self._out_rnn, axis=0)
-            out = [out, out_rnn]
+            out.append(out_rnn)
 
         if isinstance(out, list):
             out = _sym.Group(out)
@@ -1378,7 +1383,7 @@ class GraphProto(object):
 
         return inputs
 
-def from_tensorflow(graph, layout="NHWC", shape=None):
+def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
     """  Load tensorflow graph which is a python tensorflow graph object into nnvm graph.
     The companion parameters will be handled automatically.
 
@@ -1396,5 +1401,5 @@ def from_tensorflow(graph, layout="NHWC", shape=None):
         Dict of converted parameters stored in tvm.ndarray format
     """
     g = GraphProto()
-    sym, params = g.from_tensorflow(graph, layout, shape)
+    sym, params = g.from_tensorflow(graph, layout, shape, outputs)
     return sym, params
diff --git a/nnvm/tests/python/frontend/tensorflow/test_forward.py b/nnvm/tests/python/frontend/tensorflow/test_forward.py
index 62d3577ba10ae5ee917b50f29c4f4a135e9a7cc3..e93f14ceb96892a875e289815d900a3b0eaa99d2 100644
--- a/nnvm/tests/python/frontend/tensorflow/test_forward.py
+++ b/nnvm/tests/python/frontend/tensorflow/test_forward.py
@@ -26,8 +26,15 @@ import nnvm.testing.tf
 #######################################################################
 # Generic run functions for TVM & tensorflow
 # ------------------------------------------
-def run_tvm_graph(graph_def, input_data, input_node, num_output=1, target='llvm'):
+def convert_to_list(x):
+    if not isinstance(x, list):
+        x = [x]
+    return x
+
+def run_tvm_graph(graph_def, input_data, input_node, num_output=1, target='llvm', out_names=None):
     """ Generic function to compile on nnvm and execute on tvm """
+    input_data = convert_to_list(input_data)
+    input_node = convert_to_list(input_node)
 
     layout = None
     if target == "cuda":
@@ -43,8 +50,8 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1, target='llvm'
     else:
         shape_dict = {input_node: input_data.shape}
         dtype_dict = {input_node: input_data.dtype}
-
-    sym, params = nnvm.frontend.from_tensorflow(graph_def, layout=layout, shape=shape_dict)
+   
+    sym, params = nnvm.frontend.from_tensorflow(graph_def, layout=layout, shape=shape_dict, outputs=out_names)
     graph, lib, params = nnvm.compiler.build(sym, target=target, target_host=target_host, shape=shape_dict,
                                              dtype=dtype_dict, params=params)
 
@@ -52,37 +59,34 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1, target='llvm'
     from tvm.contrib import graph_runtime
     m = graph_runtime.create(graph, lib, ctx)
     # set inputs
-    if isinstance(input_data, list):
-        for i, e in enumerate(input_node):
-            m.set_input(e, tvm.nd.array(input_data[i].astype(input_data[i].dtype)))
-    else:
-        m.set_input(input_node, tvm.nd.array(input_data.astype(input_data.dtype)))
+    for i, e in enumerate(input_node):
+        m.set_input(e, tvm.nd.array(input_data[i].astype(input_data[i].dtype)))
 
     m.set_input(**params)
     # execute
     m.run()
     # get outputs
-    if num_output > 1:
-        tvm_output_list = []
-        for i in range(0, num_output):
-            tvm_output = m.get_output(i)
-            tvm_output_list.append(tvm_output.asnumpy())
-        return tvm_output_list
-    else:
-        tvm_output = m.get_output(0)
-        return tvm_output.asnumpy()
+    assert out_names is None or num_output == len(out_names),"out_names: {} num_output: {}".format(
+                                                              out_names, num_output)
+    tvm_output_list = []
+    for i in range(0, num_output):
+        tvm_output = m.get_output(i)
+        tvm_output_list.append(tvm_output.asnumpy())
+    return tvm_output_list
 
 def run_tf_graph(sess, input_data, input_node, output_node):
     """ Generic function to execute tensorflow """
+    input_data = convert_to_list(input_data)
+    input_node = convert_to_list(input_node)
+    output_node = convert_to_list(output_node)
 
-    tensor = sess.graph.get_tensor_by_name(output_node)
+    tensor = [0] * len(output_node)
+    for i in range(len(output_node)):
+        tensor[i] = sess.graph.get_tensor_by_name(output_node[i])
 
-    if isinstance(input_data, list):
-        input_dict = {}
-        for i, e in enumerate(input_node):
-            input_dict[e] = input_data[i]
-    else:
-        input_dict = {input_node: input_data}
+    input_dict = {}
+    for i, e in enumerate(input_node):
+        input_dict[e] = input_data[i]
 
     output_data = sess.run(tensor, input_dict)
     return output_data
@@ -91,14 +95,16 @@ def run_tf_graph(sess, input_data, input_node, output_node):
 def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False, no_gpu=False):
     """Generic function to generate and compare tensorflow and TVM output"""
 
-    out_node = out_name.split(':')[0] if ":" in out_name else out_name
+    out_name = convert_to_list(out_name)
+    out_node = [0]*len(out_name)
+    for i in range(len(out_name)):
+        out_node[i] = out_name[i].split(':')[0] if ":" in out_name[i] else out_name[i]
 
-    if isinstance(in_name, list):
-        in_node = [0]*len(in_name)
-        for i in range(len(in_name)):
-            in_node[i] = in_name[i].split(':')[0] if ":" in in_name[i] else in_name[i]
-    else:
-        in_node = in_name.split(':')[0] if ":" in in_name else in_name
+    in_data = convert_to_list(in_data)
+    in_name = convert_to_list(in_name)
+    in_node = [0]*len(in_name)
+    for i in range(len(in_name)):
+        in_node[i] = in_name[i].split(':')[0] if ":" in in_name[i] else in_name[i]
 
     with tf.Session() as sess:
         if init_global_variables:
@@ -106,9 +112,8 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
         final_graph_def = tf.graph_util.convert_variables_to_constants(
             sess,
             sess.graph.as_graph_def(add_shapes=True),
-            [out_node],
+            out_node,
             )
-
         tf_output = run_tf_graph(sess, in_data, in_name, out_name)
 
         for device in ["llvm", "cuda"]:
@@ -120,7 +125,10 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
                 continue
 
             tvm_output = run_tvm_graph(final_graph_def, in_data, in_node, target=device)
-            tvm.testing.assert_allclose(tf_output, tvm_output, atol=1e-5, rtol=1e-5)
+            # since the names from tensorflow and nnvm runs are not exactly same, 
+            # first len(tf_output) will be compared
+            for i in range(len(tf_output)):
+                tvm.testing.assert_allclose(tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)
 
         sess.close()
 
@@ -259,6 +267,7 @@ def test_forward_reshape():
     _test_reshape(np.arange(6), [3, -1])
     _test_reshape(np.arange(6), [-1])
 
+#######################################################################
 #######################################################################
 # Squeeze
 # -------
@@ -508,6 +517,35 @@ def test_forward_multi_input():
         compare_tf_with_tvm([in_data, in_data, in_data, in_data],
                             ['in1:0', 'in2:0', 'in3:0', 'in4:0'], 'out:0')
 
+#######################################################################
+# Multi Output to Graph
+# ---------------------
+
+def test_forward_multi_output():
+    with tf.Graph().as_default():
+        in1 = tf.placeholder(tf.int32, shape=[3, 3], name='in1')
+        in2 = tf.placeholder(tf.int32, shape=[3, 3], name='in2')
+        in3 = tf.placeholder(tf.int32, shape=[3, 3], name='in3')
+        in4 = tf.placeholder(tf.int32, shape=[3, 3], name='in4')
+
+        out1 = tf.add(in1, in2, name='out1')
+        out2 = tf.subtract(in3, in4, name='out2')
+        in_data = np.arange(9, dtype='int32').reshape([3, 3])
+        in_data = [in_data] * 4
+        in_name = ['in1:0', 'in2:0', 'in3:0', 'in4:0']
+        out_name = ['out1:0', 'out2:0']
+        out_node = [out.strip(':0') for out in out_name]
+        in_node = [inp.strip(':0') for inp in in_name]
+        
+        with tf.Session() as sess:
+            final_graph_def = tf.graph_util.convert_variables_to_constants(
+                sess, sess.graph.as_graph_def(add_shapes=True), out_node,)
+            tf_output = run_tf_graph(sess, in_data, in_name, out_name)
+            tvm_output = run_tvm_graph(final_graph_def, in_data, in_node, target='llvm',
+                                       out_names=out_node, num_output=2)
+            for i in range(len(tf_output)):
+                tvm.testing.assert_allclose(tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)
+
 #######################################################################
 # Resize Bilinear
 # ---------------
@@ -580,7 +618,7 @@ def _test_lstm_cell(batch_size, num_hidden, num_layers, forget_bias, dtype):
     out_state_c = np.reshape(out_state_tup[0], (batch_size, num_hidden))
     out_state_h = np.reshape(out_state_tup[1], (batch_size, num_hidden))
     tvm_out = [out, out_state_c, out_state_h]
-    tvm.testing.assert_allclose(tf_out, tvm_out, rtol=1e-3, atol=1e-3)
+    tvm.testing.assert_allclose(tf_out[0], tvm_out[0], rtol=1e-3, atol=1e-3)
 
 def test_forward_lstm():
     '''test LSTM block cell'''
@@ -653,7 +691,7 @@ def test_forward_inception_v3():
         with tf.Session() as sess:
             tf_output = run_tf_graph(sess, data, 'input:0', 'InceptionV3/Predictions/Reshape_1:0')
             tvm_output = run_tvm_graph(graph_def, data, 'input')
-            tvm.testing.assert_allclose(tf_output, tvm_output, rtol=1e-5, atol=1e-5)
+            tvm.testing.assert_allclose(tf_output[0], tvm_output[0], rtol=1e-5, atol=1e-5)
 
 #######################################################################
 # Inception V1
@@ -689,7 +727,7 @@ def test_forward_inception_v1():
         with tf.Session() as sess:
             tf_output = run_tf_graph(sess, data, 'DecodeJpeg/contents:0', 'softmax:0')
             tvm_output = run_tvm_graph(graph_def, tvm_data, 'DecodeJpeg/contents')
-            tvm.testing.assert_allclose(tf_output, tvm_output, rtol=1e-5, atol=1e-5)
+            tvm.testing.assert_allclose(tf_output[0], tvm_output[0], rtol=1e-5, atol=1e-5)
 
 #######################################################################
 # Mobilenet
@@ -712,7 +750,7 @@ def test_forward_mobilenet():
             graph_def = nnvm.testing.tf.AddShapesToGraphDef(sess, out_node)
             tf_output = run_tf_graph(sess, data, 'input:0', out_node + ':0')
             tvm_output = run_tvm_graph(graph_def, data, 'input')
-            tvm.testing.assert_allclose(np.squeeze(tvm_output), np.squeeze(tf_output), rtol=1e-5, atol=1e-5)
+            tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5)
 
 #######################################################################
 # ResnetV2
@@ -731,7 +769,7 @@ def test_forward_resnetv2():
             with tf.Session() as sess:
                 tf_output = run_tf_graph(sess, data, 'input_tensor:0', out_node + ':0')
                 tvm_output = run_tvm_graph(graph_def, data, 'input_tensor', tf_output.shape, 'float32')
-                tvm.testing.assert_allclose(np.squeeze(tvm_output), np.squeeze(tf_output), rtol=1e-5, atol=1e-5)
+                tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5)
 
 #######################################################################
 # PTB
@@ -797,6 +835,7 @@ def test_forward_ptb():
             state_output = model.get_output(1, tvm.nd.empty(out_state_shape,
                                                         "float32")).asnumpy()
             sample = nnvm.testing.tf.pick_from_weight(tvm_output[0])
+
             return sample, state_output
 
         for x in data:
@@ -942,7 +981,7 @@ def test_forward_leaky_relu():
     with tf.Graph().as_default():
         in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype)
         tf.nn.leaky_relu(in1, alpha=0.4)
-        compare_tf_with_tvm(inp_array, 'Placeholder:0', 'LeakyRelu:0')
+        compare_tf_with_tvm(inp_array, 'Placeholder:0', 'LeakyRelu/mul:0')
 
 def test_forward_elu():
     ishape = (1, 3, 10, 10)
@@ -1042,6 +1081,7 @@ if __name__ == '__main__':
 
     # General
     test_forward_multi_input()
+    test_forward_multi_output()
     test_forward_variable()
 
     # End to End