From 6a3a9572b94c5198a4f769e9a6d7a26f0f15d897 Mon Sep 17 00:00:00 2001
From: Sergei Grechanik <grechanik.sergey@huawei.com>
Date: Tue, 2 Oct 2018 06:43:38 +0300
Subject: [PATCH] [NNVM][TEST] Numgrad: fix nan and multioutput (#1754)

---
 nnvm/python/nnvm/testing/check_computation.py | 84 ++++++++++---------
 nnvm/tests/python/compiler/test_top_level1.py |  1 +
 2 files changed, 46 insertions(+), 39 deletions(-)

diff --git a/nnvm/python/nnvm/testing/check_computation.py b/nnvm/python/nnvm/testing/check_computation.py
index a207e8eb8..76d7b66b1 100644
--- a/nnvm/python/nnvm/testing/check_computation.py
+++ b/nnvm/python/nnvm/testing/check_computation.py
@@ -55,84 +55,84 @@ def infer_shapes_dtypes(graph, shape=None, dtype=None, fallback_dtype=None):
     """
     # Preprocess input parameters
     if shape is None:
-        shape = {}
+        provided_shapes = {}
+    elif isinstance(shape, dict):
+        provided_shapes = shape
+    else:
+        provided_shapes = {x: shape for x in graph.symbol.list_input_variables()}
 
     if dtype is None:
-        dtype = {}
-
-    if not isinstance(shape, dict):
-        shape = {x: shape for x in graph.symbol.list_input_variables()}
-
-    if not isinstance(dtype, dict):
-        dtype = {x: dtype for x in graph.symbol.list_input_variables()}
+        provided_dtypes = {}
+    elif isinstance(dtype, dict):
+        provided_dtypes = dtype
+    else:
+        provided_dtypes = {x: dtype for x in graph.symbol.list_input_variables()}
 
-    shape = _dict_var_to_dict_str(shape)
-    dtype = _dict_var_to_dict_str(dtype)
+    provided_shapes = _dict_var_to_dict_str(provided_shapes)
+    provided_dtypes = _dict_var_to_dict_str(provided_dtypes)
 
     # The graph may already contain shape and dtype info, so extract it and merge with
     # the user-specified shapes and dtypes (use the user-specified one on contradiction)
-    all_initial_shapes = graph.json_attr('shape')
-    all_initial_dtypes = graph.json_attr('dtype')
+    preexisting_shapes = graph.json_attr('shape')
+    preexisting_dtypes = graph.json_attr('dtype')
 
-    if all_initial_shapes:
+    if preexisting_shapes:
         for x in graph.index.input_names:
-            if x not in shape:
-                x_shape = tuple(all_initial_shapes[graph.index.entry_id(x)])
-                shape[x] = x_shape
+            if x not in provided_shapes:
+                x_shape = tuple(preexisting_shapes[graph.index.entry_id(x)])
+                provided_shapes[x] = x_shape
 
-    if all_initial_dtypes:
+    if preexisting_dtypes:
         for x in graph.index.input_names:
-            if x not in dtype:
-                x_dtype = TCODE_TO_DTYPE[all_initial_dtypes[graph.index.entry_id(x)]]
-                dtype[x] = x_dtype
+            if x not in provided_dtypes:
+                x_dtype = TCODE_TO_DTYPE[preexisting_dtypes[graph.index.entry_id(x)]]
+                provided_dtypes[x] = x_dtype
 
     # Perform inference
-    nnvm.compiler.graph_attr.set_shape_inputs(graph, shape)
-    nnvm.compiler.graph_attr.set_dtype_inputs(graph, dtype)
+    nnvm.compiler.graph_attr.set_shape_inputs(graph, provided_shapes)
+    nnvm.compiler.graph_attr.set_dtype_inputs(graph, provided_dtypes)
 
     graph = graph.apply('InferShape').apply('InferType')
 
-    shapes = graph.json_attr('shape')
-    dtypes = graph.json_attr('dtype')
-
-    out_len = len(graph.symbol.list_output_names())
+    inferred_shapes = graph.json_attr('shape')
+    inferred_dtypes = graph.json_attr('dtype')
 
     index = graph.index
 
-    output_shapes = \
-        [tuple(shapes[index.entry_id(index.output_entries[i])]) for i in range(out_len)]
-    output_dtypes = \
-        [TCODE_TO_DTYPE[dtypes[index.entry_id(index.output_entries[i])]] for i in range(out_len)]
+    output_shapes = [tuple(inferred_shapes[index.entry_id(entry)])
+                     for entry in index.output_entries]
+    output_dtypes = [TCODE_TO_DTYPE[inferred_dtypes[index.entry_id(entry)]]
+                     for entry in index.output_entries]
 
     # Postprocess the results
-    input_shapes = shape.copy()
-    input_dtypes = dtype.copy()
+    input_shapes = provided_shapes.copy()
+    input_dtypes = provided_dtypes.copy()
 
     for x in graph.symbol.list_input_variables():
         x_name = x.attr('name')
-        x_node_id = graph.index.node_id(x_name)
-        input_shapes[x_name] = tuple(shapes[x_node_id])
-        input_dtypes[x_name] = TCODE_TO_DTYPE[dtypes[x_node_id]]
+        x_entry_id = graph.index.entry_id(x_name)
+        input_shapes[x_name] = tuple(inferred_shapes[x_entry_id])
+        input_dtypes[x_name] = TCODE_TO_DTYPE[inferred_dtypes[x_entry_id]]
 
     # Merge the original user-specified shapes in case some of them are specified for non-existing
     # variables
-    for x_name, x_shape in shape.items():
+    for x_name, x_shape in provided_shapes.items():
         x_shape = tuple(x_shape)
         if input_shapes.get(x_name, x_shape) != x_shape:
             raise RuntimeError("Inferred shape differs from the provided shape.\n"
                                "Provided shapes: {}\nInferred shapes: {}"
-                               .format(shapes, input_shapes))
+                               .format(provided_shapes, input_shapes))
         else:
             input_shapes[x_name] = x_shape
 
     # Merge the original user-specified dtypes
-    for x_name, x_dtype in dtype.items():
+    for x_name, x_dtype in provided_dtypes.items():
         if not isinstance(x_dtype, str):
             x_dtype = TCODE_TO_DTYPE[x_dtype]
         if input_dtypes.get(x_name, x_dtype) != x_dtype:
             raise RuntimeError("Inferred dtype differs from the provided dtype.\n"
                                "Provided dtypes: {}\nInferred dtypes: {}"
-                               .format(dtypes, input_dtypes))
+                               .format(provided_dtypes, input_dtypes))
         else:
             input_dtypes[x_name] = x_dtype
 
@@ -622,6 +622,12 @@ def check_numerical_grads(function, input_values, grad_values, function_value=No
         dist = np.sqrt(np.sum((ngrad - grad)**2))
         grad_norm = np.sqrt(np.sum(ngrad**2))
 
+        if not (np.isfinite(dist) and np.isfinite(grad_norm)):
+            raise ValueError(
+                "NaN or infinity detected during numerical gradient checking wrt {}\n"
+                "analytical grad = {}\n numerical grad = {}\n"
+                .format(x_name, grad, ngrad))
+
         # we multiple atol by this number to make it more universal for different sizes
         sqrt_n = np.sqrt(float(np.prod(grad.shape)))
 
diff --git a/nnvm/tests/python/compiler/test_top_level1.py b/nnvm/tests/python/compiler/test_top_level1.py
index ba6280dd9..089ae84cd 100644
--- a/nnvm/tests/python/compiler/test_top_level1.py
+++ b/nnvm/tests/python/compiler/test_top_level1.py
@@ -96,6 +96,7 @@ def test_check_function():
     _check_function_must_fail(sym.block_grad(x + 2*y), numerical_grads=True)
     _check_function_must_fail(x*x, numerical_grads=True,
                               numerical_grads_params={'atol': 0.0, 'rtol': 0.0})
+    _check_function_must_fail(sym.log(-x*x), numerical_grads=True, error=ValueError)
 
     # different styles of returning results from the forward function
     check_function(x + 2*y, lambda x, y: [x + 2*y], numerical_grads=False)
--
GitLab
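
The first hunk of check_computation.py stops reusing the `shape` and `dtype` parameters in place: the user-supplied values are normalized into separate `provided_*` dictionaries (accepting None, a dict keyed by input name, or a single value applied to every input), so they can no longer be confused with the `inferred_*` results later in the function. The sketch below shows that normalization pattern in plain Python; `normalize_per_input` and the sample input names are hypothetical illustrations, not part of the NNVM API.

    def normalize_per_input(value, input_names):
        """Normalize an argument that may be None, a dict keyed by input name,
        or a single value meant to apply to every input."""
        if value is None:
            return {}                        # nothing provided by the caller
        if isinstance(value, dict):
            return dict(value)               # keep only the explicitly listed inputs
        # a single value: broadcast it to every input
        return {name: value for name in input_names}

    # The same helper accepts all three calling conventions.
    inputs = ["data", "weight"]
    print(normalize_per_input(None, inputs))              # {}
    print(normalize_per_input({"data": (1, 3)}, inputs))  # {'data': (1, 3)}
    print(normalize_per_input((1, 3), inputs))            # both inputs get (1, 3)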
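The second hunk is the NaN fix named in the commit title: check_numerical_grads now raises ValueError as soon as the distance between the analytical and numerical gradients, or the numerical gradient's norm, is not finite, rather than letting a NaN reach the later tolerance comparison (a comparison against NaN is always False, so the check would otherwise pass silently). The standalone sketch below reproduces that guard with plain NumPy; `numerical_grad` and `compare_grads` are hypothetical helpers for illustration, not the NNVM functions.

    import numpy as np

    def numerical_grad(func, x, eps=1e-5):
        """Central-difference estimate of d func(x) / d x for a scalar-valued func."""
        x = np.asarray(x, dtype=float)
        grad = np.zeros_like(x)
        for i in range(x.size):
            step = np.zeros_like(x)
            step.flat[i] = eps
            grad.flat[i] = (func(x + step) - func(x - step)) / (2 * eps)
        return grad

    def compare_grads(x_name, grad, ngrad):
        """Compare an analytical gradient with a numerical one, failing loudly on NaN/inf."""
        dist = np.sqrt(np.sum((ngrad - grad) ** 2))
        grad_norm = np.sqrt(np.sum(ngrad ** 2))
        # Same idea as the patch: a non-finite value would make a
        # `dist > atol + rtol * grad_norm` style check False and hide the problem.
        if not (np.isfinite(dist) and np.isfinite(grad_norm)):
            raise ValueError(
                "NaN or infinity detected during numerical gradient checking wrt {}\n"
                "analytical grad = {}\n numerical grad = {}\n"
                .format(x_name, grad, ngrad))
        return dist, grad_norm

    # log(-x*x) is undefined for real x, so its numerical gradient is NaN; this mirrors
    # the new _check_function_must_fail(sym.log(-x*x), ..., error=ValueError) test case.
    x = np.array([1.0, 2.0])
    analytical = 2.0 / x  # would-be gradient of log(x*x), used only as a stand-in
    with np.errstate(invalid="ignore"):
        numerical = numerical_grad(lambda v: np.log(-np.sum(v * v)), x)
    try:
        compare_grads("x", analytical, numerical)
    except ValueError as err:
        print("caught:", err)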