From 02631f67780a9175ac202e5564e25bc6d93393c2 Mon Sep 17 00:00:00 2001
From: Jared Roesch <roeschinc@gmail.com>
Date: Fri, 25 Jan 2019 10:17:31 -0800
Subject: [PATCH] [Relay] Add generic & informative Relay error reporting
 (#2408)

---
 .gitmodules                                |   3 +
 3rdparty/rang                              |   1 +
 CMakeLists.txt                             |   1 +
 include/tvm/relay/error.h                  | 127 +++++++++++++--
 include/tvm/relay/module.h                 |  19 +++
 include/tvm/relay/pass.h                   |   2 +-
 include/tvm/relay/type.h                   |   6 +
 src/relay/ir/error.cc                      | 128 +++++++++++++++
 src/relay/ir/module.cc                     |  17 ++
 src/relay/op/type_relations.cc             |   8 +-
 src/relay/pass/type_infer.cc               | 173 ++++++++++++++-------
 src/relay/pass/type_solver.cc              |  78 ++++++++--
 src/relay/pass/type_solver.h               |  26 +++-
 tests/python/relay/test_error_reporting.py |  34 ++++
 14 files changed, 537 insertions(+), 86 deletions(-)
 create mode 160000 3rdparty/rang
 create mode 100644 src/relay/ir/error.cc
 create mode 100644 tests/python/relay/test_error_reporting.py

diff --git a/.gitmodules b/.gitmodules
index 8011ec12d..984326434 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -7,3 +7,6 @@
 [submodule "dlpack"]
 	path = 3rdparty/dlpack
 	url = https://github.com/dmlc/dlpack
+[submodule "3rdparty/rang"]
+	path = 3rdparty/rang
+	url = https://github.com/agauniyal/rang
diff --git a/3rdparty/rang b/3rdparty/rang
new file mode 160000
index 000000000..cabe04d6d
--- /dev/null
+++ b/3rdparty/rang
@@ -0,0 +1 @@
+Subproject commit cabe04d6d6b05356fa8f9741704924788f0dd762
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 8765a3346..23dd58a2c 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -53,6 +53,7 @@ tvm_option(USE_ANTLR "Build with ANTLR for Relay parsing" OFF)
 include_directories("include")
 include_directories("3rdparty/dlpack/include")
 include_directories("3rdparty/dmlc-core/include")
+include_directories("3rdparty/rang/include")
 include_directories("3rdparty/compiler-rt")
 
 # initial variables
diff --git a/include/tvm/relay/error.h b/include/tvm/relay/error.h
index 1c2b90611..0451a9826 100644
--- a/include/tvm/relay/error.h
+++ b/include/tvm/relay/error.h
@@ -7,25 +7,134 @@
 #define TVM_RELAY_ERROR_H_
 
 #include <string>
+#include <vector>
+#include <sstream>
 #include "./base.h"
+#include "./expr.h"
+#include "./module.h"
 
 namespace tvm {
 namespace relay {
 
-struct Error : public dmlc::Error {
-  explicit Error(const std::string &msg) : dmlc::Error(msg) {}
-};
+#define RELAY_ERROR(msg) (RelayErrorStream() << msg)
+
+// Forward declaratio for RelayErrorStream.
+struct Error;
+
+/*! \brief A wrapper around std::stringstream.
+ *
+ * This is designed to avoid platform specific
+ * issues compiling and using std::stringstream
+ * for error reporting.
+ */
+struct RelayErrorStream {
+  std::stringstream ss;
+
+  template<typename T>
+  RelayErrorStream& operator<<(const T& t) {
+    ss << t;
+    return *this;
+  }
 
-struct InternalError : public Error {
-  explicit InternalError(const std::string &msg) : Error(msg) {}
+  std::string str() const {
+    return ss.str();
+  }
+
+  void Raise() const;
 };
 
-struct FatalTypeError : public Error {
-  explicit FatalTypeError(const std::string &s) : Error(s) {}
+struct Error : public dmlc::Error {
+  Span sp;
+  explicit Error(const std::string& msg) : dmlc::Error(msg), sp() {}
+  Error(const std::stringstream& msg) : dmlc::Error(msg.str()), sp() {} // NOLINT(*)
+  Error(const RelayErrorStream& msg) : dmlc::Error(msg.str()), sp() {} // NOLINT(*)
 };
 
-struct TypecheckerError : public Error {
-  explicit TypecheckerError(const std::string &msg) : Error(msg) {}
+/*! \brief An abstraction around how errors are stored and reported.
+ * Designed to be opaque to users, so we can support a robust and simpler
+ * error reporting mode, as well as a more complex mode.
+ *
+ * The first mode is the most accurate: we report a Relay error at a specific
+ * Span, and then render the error message directly against a textual representation
+ * of the program, highlighting the exact lines in which it occurs. This mode is not
+ * implemented in this PR and will not work.
+ *
+ * The second mode is a general-purpose mode, which attempts to annotate the program's
+ * textual format with errors.
+ *
+ * The final mode represents the old mode, if we report an error that has no span or
+ * expression, we will default to throwing an exception with a textual representation
+ * of the error and no indication of where it occured in the original program.
+ *
+ * The latter mode is not ideal, and the goal of the new error reporting machinery is
+ * to avoid ever reporting errors in this style.
+ */
+class ErrorReporter {
+ public:
+  ErrorReporter() : errors_(), node_to_error_() {}
+
+  /*! \brief Report a tvm::relay::Error.
+   *
+   * This API is useful for reporting spanned errors.
+   *
+   * \param err The error to report.
+   */
+  void Report(const Error& err) {
+    if (!err.sp.defined()) {
+      throw err;
+    }
+
+    this->errors_.push_back(err);
+  }
+
+  /*! \brief Report an error against a program, using the full program
+   * error reporting strategy.
+   *
+   * This error reporting method requires the global function in which
+   * to report an error, the expression to report the error on,
+   * and the error object.
+   *
+   * \param global The global function in which the expression is contained.
+   * \param node The expression or type to report the error at.
+   * \param err The error message to report.
+   */
+  inline void ReportAt(const GlobalVar& global, const NodeRef& node, std::stringstream& err) {
+    this->ReportAt(global, node, Error(err));
+  }
+
+  /*! \brief Report an error against a program, using the full program
+   * error reporting strategy.
+   *
+   * This error reporting method requires the global function in which
+   * to report an error, the expression to report the error on,
+   * and the error object.
+   *
+   * \param global The global function in which the expression is contained.
+   * \param node The expression or type to report the error at.
+   * \param err The error to report.
+   */
+  void ReportAt(const GlobalVar& global, const NodeRef& node, const Error& err);
+
+  /*! \brief Render all reported errors and exit the program.
+   *
+   * This function should be used after executing a pass to render reported errors.
+   *
+   * It will build an error message from the set of errors, depending on the error
+   * reporting strategy.
+   *
+   * \param module The module to report errors on.
+   * \param use_color Controls whether to colorize the output.
+   */
+  void RenderErrors(const Module& module, bool use_color = true);
+
+  inline bool AnyErrors() {
+    return errors_.size() != 0;
+  }
+
+ private:
+  std::vector<Error> errors_;
+  std::unordered_map<NodeRef, std::vector<size_t>, NodeHash, NodeEqual> node_to_error_;
+  std::unordered_map<NodeRef, GlobalVar, NodeHash, NodeEqual> node_to_gv_;
 };
 
 }  // namespace relay
diff --git a/include/tvm/relay/module.h b/include/tvm/relay/module.h
index 8d302c09d..45ccfe3a8 100644
--- a/include/tvm/relay/module.h
+++ b/include/tvm/relay/module.h
@@ -43,11 +43,15 @@ class ModuleNode : public RelayNode {
   /*! \brief A map from ids to all global functions. */
   tvm::Map<GlobalVar, Function> functions;
 
+  /*! \brief The entry function (i.e. "main"). */
+  GlobalVar entry_func;
+
   ModuleNode() {}
 
   void VisitAttrs(tvm::AttrVisitor* v) final {
     v->Visit("functions", &functions);
     v->Visit("global_var_map_", &global_var_map_);
+    v->Visit("entry_func", &entry_func);
   }
 
   TVM_DLL static Module make(tvm::Map<GlobalVar, Function> global_funcs);
@@ -111,6 +115,20 @@ class ModuleNode : public RelayNode {
    */
   void Update(const Module& other);
 
+  /*! \brief Construct a module from a standalone expression.
+   *
+   * Allows one to optionally pass a global function map as
+   * well.
+   *
+   * \param expr The expression to set as the entry point to the module.
+   * \param global_funcs The global function map.
+   *
+   * \returns A module with expr set as the entry point.
+   */
+  static Module FromExpr(
+    const Expr& expr,
+    const tvm::Map<GlobalVar, Function>& global_funcs = {});
+
   static constexpr const char* _type_key = "relay.Module";
   TVM_DECLARE_NODE_TYPE_INFO(ModuleNode, Node);
 
@@ -132,6 +150,7 @@ struct Module : public NodeRef {
   using ContainerType = ModuleNode;
 };
 
+
 }  // namespace relay
 }  // namespace tvm
 
diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h
index 8527ab7a2..38f6a805f 100644
--- a/include/tvm/relay/pass.h
+++ b/include/tvm/relay/pass.h
@@ -6,8 +6,8 @@
 #ifndef TVM_RELAY_PASS_H_
 #define TVM_RELAY_PASS_H_
 
-#include <tvm/relay/module.h>
 #include <tvm/relay/expr.h>
+#include <tvm/relay/module.h>
 #include <tvm/relay/op_attr_types.h>
 #include <string>
 
diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h
index 69a8a4fb0..f3bcf2c0a 100644
--- a/include/tvm/relay/type.h
+++ b/include/tvm/relay/type.h
@@ -295,6 +295,12 @@ class TypeReporterNode : public Node {
    */
   TVM_DLL virtual bool AssertEQ(const IndexExpr& lhs, const IndexExpr& rhs) = 0;
 
+  /*!
+   * \brief Set the location at which to report unification errors.
+   * \param ref The program node to report the error.
+   */
+  TVM_DLL virtual void SetLocation(const NodeRef& ref) = 0;
+
   // solver is not serializable.
   void VisitAttrs(tvm::AttrVisitor* v) final {}
 
diff --git a/src/relay/ir/error.cc b/src/relay/ir/error.cc
new file mode 100644
index 000000000..24f8d1c49
--- /dev/null
+++ b/src/relay/ir/error.cc
@@ -0,0 +1,128 @@
+/*!
+ *  Copyright (c) 2018 by Contributors
+ * \file error_reporter.h
+ * \brief The set of errors raised by Relay.
+ */
+
+#include <tvm/relay/expr.h>
+#include <tvm/relay/module.h>
+#include <tvm/relay/error.h>
+#include <string>
+#include <vector>
+#include <rang.hpp>
+
+namespace tvm {
+namespace relay {
+
+void RelayErrorStream::Raise() const {
+  throw Error(*this);
+}
+
+template<typename T, typename U>
+using NodeMap = std::unordered_map<T, U, NodeHash, NodeEqual>;
+
+void ErrorReporter::RenderErrors(const Module& module, bool use_color) {
+  // First we pick an error reporting strategy for each error.
+  // TODO(@jroesch): Spanned errors are currently not supported.
+  for (auto err : this->errors_) {
+    CHECK(!err.sp.defined()) << "attempting to use spanned errors, currently not supported";
+  }
+
+  NodeMap<GlobalVar, NodeMap<NodeRef, std::string>> error_maps;
+
+  // Set control mode in order to produce colors;
+  if (use_color) {
+    rang::setControlMode(rang::control::Force);
+  }
+
+  for (auto pair : this->node_to_gv_) {
+    auto node = pair.first;
+    auto global = Downcast<GlobalVar>(pair.second);
+
+    auto has_errs = this->node_to_error_.find(node);
+
+    CHECK(has_errs != this->node_to_error_.end());
+
+    const auto& error_indicies = has_errs->second;
+
+    std::stringstream err_msg;
+
+    err_msg << rang::fg::red;
+    for (auto index : error_indicies) {
+      err_msg << this->errors_[index].what() << "; ";
+    }
+    err_msg << rang::fg::reset;
+
+    // Setup error map.
+    auto it = error_maps.find(global);
+    if (it != error_maps.end()) {
+      it->second.insert({ node, err_msg.str() });
+    } else {
+      error_maps.insert({ global, { { node, err_msg.str() }}});
+    }
+  }
+
+  // Now we will construct the fully-annotated program to display to
+  // the user.
+  std::stringstream annotated_prog;
+
+  // First we output a header for the errors.
+  annotated_prog <<
+  rang::style::bold << std::endl <<
+  "Error(s) have occurred. We have annotated the program with them:"
+  << std::endl << std::endl << rang::style::reset;
+
+  // For each global function which contains errors, we will
+  // construct an annotated function.
+  for (auto pair : error_maps) {
+    auto global = pair.first;
+    auto err_map = pair.second;
+    auto func = module->Lookup(global);
+
+    // We output the name of the function before displaying
+    // the annotated program.
+    annotated_prog <<
+      rang::style::bold <<
+      "In `" << global->name_hint << "`: " <<
+      std::endl <<
+      rang::style::reset;
+
+    // We then call into the Relay printer to generate the program.
+    //
+    // The annotation callback will annotate the error messages
+    // contained in the map.
+    annotated_prog << RelayPrint(func, false, [&err_map](tvm::relay::Expr expr) {
+      auto it = err_map.find(expr);
+      if (it != err_map.end()) {
+        return it->second;
+      } else {
+        return std::string("");
+      }
+    });
+  }
+
+  auto msg = annotated_prog.str();
+
+  if (use_color) {
+    rang::setControlMode(rang::control::Auto);
+  }
+
+  // Finally we report the error, currently we do so to LOG(FATAL),
+  // it may be good to instead report it to std::cout.
+  LOG(FATAL) << annotated_prog.str() << std::endl;
+}
+
+void ErrorReporter::ReportAt(const GlobalVar& global, const NodeRef& node, const Error& err) {
+  size_t index_to_insert = this->errors_.size();
+  this->errors_.push_back(err);
+  auto it = this->node_to_error_.find(node);
+  if (it != this->node_to_error_.end()) {
+    it->second.push_back(index_to_insert);
+  } else {
+    this->node_to_error_.insert({ node, { index_to_insert }});
+  }
+  this->node_to_gv_.insert({ node, global });
+}
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc
index cbb0b7768..9ba5efece 100644
--- a/src/relay/ir/module.cc
+++ b/src/relay/ir/module.cc
@@ -23,6 +23,8 @@ Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs) {
         << "Duplicate global function name " << kv.first->name_hint;
     n->global_var_map_.Set(kv.first->name_hint, kv.first);
   }
+
+  n->entry_func = GlobalVarNode::make("main");
   return Module(n);
 }
 
@@ -96,6 +98,21 @@ void ModuleNode::Update(const Module& mod) {
   }
 }
 
+Module ModuleNode::FromExpr(
+  const Expr& expr,
+  const tvm::Map<GlobalVar, Function>& global_funcs) {
+  auto mod = ModuleNode::make(global_funcs);
+  auto func_node = expr.as<FunctionNode>();
+  Function func;
+  if (func_node) {
+    func = GetRef<Function>(func_node);
+  } else {
+    func = FunctionNode::make({}, expr, Type(), {}, {});
+  }
+  mod->Add(mod->entry_func, func);
+  return mod;
+}
+
 TVM_REGISTER_NODE_TYPE(ModuleNode);
 
 TVM_REGISTER_API("relay._make.Module")
diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc
index 467c0fcde..2618054a6 100644
--- a/src/relay/op/type_relations.cc
+++ b/src/relay/op/type_relations.cc
@@ -70,9 +70,12 @@ Type ConcreteBroadcast(const TensorType& t1,
     } else if (EqualConstInt(s2, 1)) {
       oshape.push_back(s1);
     } else {
-      LOG(FATAL) << "Incompatible broadcast type " << t1 << " and " << t2;
+      RELAY_ERROR(
+          "Incompatible broadcast type "
+              << t1 << " and " << t2).Raise();
     }
   }
+
   size_t max_ndim = std::max(ndim1, ndim2);
   auto& rshape = (ndim1 > ndim2) ? t1->shape : t2->shape;
   for (; i <= max_ndim; ++i) {
@@ -92,7 +95,8 @@ bool BroadcastRel(const Array<Type>& types,
   if (auto t0 = ToTensorType(types[0])) {
     if (auto t1 = ToTensorType(types[1])) {
       CHECK_EQ(t0->dtype, t1->dtype);
-      reporter->Assign(types[2], ConcreteBroadcast(t0, t1, t0->dtype));
+      reporter->Assign(types[2],
+        ConcreteBroadcast(t0, t1, t0->dtype));
       return true;
     }
   }
diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc
index af4cc6607..3135715f7 100644
--- a/src/relay/pass/type_infer.cc
+++ b/src/relay/pass/type_infer.cc
@@ -82,10 +82,10 @@ struct ResolvedTypeInfo {
 class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
  public:
   // constructors
-  TypeInferencer() {
-  }
-  explicit TypeInferencer(Module mod)
-      : mod_(mod) {
+
+  explicit TypeInferencer(Module mod, GlobalVar current_func)
+      : mod_(mod), current_func_(current_func),
+        err_reporter(), solver_(current_func, &this->err_reporter) {
   }
 
   // inference the type of expr.
@@ -96,6 +96,13 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
   class Resolver;
   // internal environment
   Module mod_;
+
+  // The current function being type checked.
+  GlobalVar current_func_;
+
+  // The error reporter.
+  ErrorReporter err_reporter;
+
   // map from expression to checked type
   // type inferencer will populate it up
   std::unordered_map<Expr, ResolvedTypeInfo, NodeHash, NodeEqual> type_map_;
@@ -109,18 +116,21 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
   // relation function
   TypeRelationFn tuple_getitem_rel_;
   TypeRelationFn make_tuple_rel_;
-  // Unify two types
-  Type Unify(const Type& t1, const Type& t2, const Span& span) {
+
+  // Perform unification on two types and report the error at the expression
+  // or the span of the expression.
+  Type Unify(const Type& t1, const Type& t2, const Expr& expr) {
     // TODO(tqchen, jroesch): propagate span to solver
     try {
-      return solver_.Unify(t1, t2);
+      return solver_.Unify(t1, t2, expr);
     } catch (const dmlc::Error &e) {
-      LOG(FATAL)
-          << "Error unifying `"
+      this->ReportFatalError(
+        expr,
+        RELAY_ERROR("Error unifying `"
           << t1
           << "` and `"
           << t2
-          << "`: " << e.what();
+          << "`: " << e.what()));
       return Type();
     }
   }
@@ -151,7 +161,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
   }
 
   // Lazily get type for expr
-  // will call visit to deduce it if it is not in the type_map_
+  // expression, we will populate it now, and return the result.
   Type GetType(const Expr &expr) {
     auto it = type_map_.find(expr);
     if (it != type_map_.end() && it->second.checked_type.defined()) {
@@ -163,7 +173,13 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
     return ret;
   }
 
-  // Visitor logics
+  void ReportFatalError(const Expr& expr, const Error& err) {
+    CHECK(this->current_func_.defined());
+    this->err_reporter.ReportAt(this->current_func_, expr, err);
+    this->err_reporter.RenderErrors(this->mod_);
+  }
+
+  // Visitor Logic
   Type VisitExpr_(const VarNode* op) final {
     if (op->type_annotation.defined()) {
       return op->type_annotation;
@@ -174,8 +190,13 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
 
   Type VisitExpr_(const GlobalVarNode* op) final {
     GlobalVar var = GetRef<GlobalVar>(op);
-    CHECK(mod_.defined())
-        << "Cannot do type inference without a global variable";
+    if (!mod_.defined()) {
+      this->ReportFatalError(
+        GetRef<GlobalVar>(op),
+        RELAY_ERROR(
+          "Cannot do type inference on global variables " \
+          "without a module"));
+    }
     Expr e = mod_->Lookup(var);
     return e->checked_type();
   }
@@ -202,7 +223,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
     auto attrs = make_node<TupleGetItemAttrs>();
     attrs->index = op->index;
     solver_.AddConstraint(TypeRelationNode::make(
-        tuple_getitem_rel_, {tuple_type, rtype}, 1, Attrs(attrs)));
+        tuple_getitem_rel_, {tuple_type, rtype}, 1, Attrs(attrs)), GetRef<TupleGetItem>(op));
     return rtype;
   }
 
@@ -210,40 +231,43 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
     return op->op_type;
   }
 
-  Type VisitExpr_(const LetNode* op) final {
+  Type VisitExpr_(const LetNode* let) final {
     // if the definition is a function literal, permit recursion
-    bool is_functional_literal = op->value.as<FunctionNode>() != nullptr;
+    bool is_functional_literal = let->value.as<FunctionNode>() != nullptr;
     if (is_functional_literal) {
-      type_map_[op->var].checked_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType);
+      type_map_[let->var].checked_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType);
     }
 
-    Type vtype = GetType(op->value);
-    if (op->var->type_annotation.defined()) {
-      vtype = Unify(vtype, op->var->type_annotation, op->span);
+    Type vtype = GetType(let->value);
+    if (let->var->type_annotation.defined()) {
+      vtype = Unify(vtype, let->var->type_annotation, GetRef<Let>(let));
     }
-    CHECK(is_functional_literal || !type_map_.count(op->var));
+    CHECK(is_functional_literal || !type_map_.count(let->var));
     // NOTE: no scoping is necessary because var are unique in program
-    type_map_[op->var].checked_type = vtype;
-    return GetType(op->body);
+    type_map_[let->var].checked_type = vtype;
+    return GetType(let->body);
   }
 
-  Type VisitExpr_(const IfNode* op) final {
+  Type VisitExpr_(const IfNode* ite) final {
     // Ensure the type of the guard is of Tensor[Bool, ()],
     // that is a rank-0 boolean tensor.
-    Type cond_type = this->GetType(op->cond);
+    Type cond_type = this->GetType(ite->cond);
     this->Unify(cond_type,
                 TensorTypeNode::Scalar(tvm::Bool()),
-                op->cond->span);
-    Type checked_true = this->GetType(op->true_branch);
-    Type checked_false = this->GetType(op->false_branch);
-    return this->Unify(checked_true, checked_false, op->span);
+                ite->cond);
+    Type checked_true = this->GetType(ite->true_branch);
+    Type checked_false = this->GetType(ite->false_branch);
+    return this->Unify(checked_true, checked_false, GetRef<If>(ite));
   }
 
-  // Handle special case basic primitive operator,
-  // if successful return the return type
+  // This code is special-cased for primitive operators,
+  // which are registered in the style defined in src/relay/op/*.
+  //
+  // The result will be the return type of the operator.
   Type PrimitiveCall(const FuncTypeNode* op,
                      Array<Type> arg_types,
-                     const Attrs& attrs) {
+                     const Attrs& attrs,
+                     const NodeRef& loc) {
     if (op->type_params.size() != arg_types.size() + 1) return Type();
     if (op->type_constraints.size() != 1) return Type();
     const TypeRelationNode* rel = op->type_constraints[0].as<TypeRelationNode>();
@@ -256,7 +280,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
     arg_types.push_back(rtype);
     // we can do simple replacement here
     solver_.AddConstraint(TypeRelationNode::make(
-        rel->func, arg_types, arg_types.size() - 1, attrs));
+        rel->func, arg_types, arg_types.size() - 1, attrs), loc);
     return rtype;
   }
 
@@ -304,16 +328,19 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
     auto* fn_ty_node = ftype.as<FuncTypeNode>();
     auto* inc_ty_node = ftype.as<IncompleteTypeNode>();
 
-    CHECK(fn_ty_node != nullptr || inc_ty_node != nullptr)
-      << "only expressions with function types can be called, found "
-      << ftype << " at " << call->span;
+    if (fn_ty_node == nullptr && inc_ty_node == nullptr) {
+      this->ReportFatalError(
+        GetRef<Call>(call),
+        RELAY_ERROR("only expressions with function types can be called, found "
+        << ftype));
+    }
 
     // incomplete type => it must be a function taking the arg types
     // with an unknown return type
     if (inc_ty_node != nullptr) {
       Type ret_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType);
       Type func_type = FuncTypeNode::make(arg_types, ret_type, {}, {});
-      Type unified = this->Unify(ftype, func_type, call->span);
+      Type unified = this->Unify(ftype, func_type, GetRef<Call>(call));
       fn_ty_node = unified.as<FuncTypeNode>();
     }
 
@@ -323,10 +350,16 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
         type_args.push_back(IncompleteTypeNode::make(TypeVarNode::Kind::kType));
       }
     }
-    CHECK(type_args.size() == fn_ty_node->type_params.size())
-      << "Incorrect number of type args in " << call->span << ": "
-      << "Expected " << fn_ty_node->type_params.size()
-      << "but got " << type_args.size();
+
+    if (type_args.size() != fn_ty_node->type_params.size()) {
+      this->ReportFatalError(GetRef<Call>(call),
+        RELAY_ERROR("Incorrect number of type args in "
+          << call->span << ": "
+          << "Expected "
+          << fn_ty_node->type_params.size()
+          << "but got " << type_args.size()));
+    }
+
     FuncType fn_ty = InstantiateFuncType(fn_ty_node, type_args);
 
     AddTypeArgs(GetRef<Call>(call), type_args);
@@ -336,22 +369,29 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
 
     if (type_arity != number_of_args) {
       if (type_arity < number_of_args) {
-        LOG(FATAL) << "the function is provided too many arguments " << call->span;
+        this->ReportFatalError(
+          GetRef<Call>(call),
+          RELAY_ERROR("the function is provided too many arguments "
+          << "expected " << type_arity << ", found " << number_of_args));
       } else {
-        LOG(FATAL) << "the function is provided too few arguments" << call->span;
+        this->ReportFatalError(
+          GetRef<Call>(call),
+          RELAY_ERROR("the function is provided too few arguments "
+          << "expected " << type_arity << ", found " << number_of_args));
       }
     }
 
     for (size_t i = 0; i < fn_ty->arg_types.size(); i++) {
-      this->Unify(fn_ty->arg_types[i], arg_types[i], call->args[i]->span);
+      this->Unify(fn_ty->arg_types[i], arg_types[i], call->args[i]);
     }
 
     for (auto cs : fn_ty->type_constraints) {
       if (auto tr = cs.as<TypeRelationNode>()) {
         solver_.AddConstraint(
-          TypeRelationNode::make(tr->func, tr->args, tr->num_inputs, call->attrs));
+          TypeRelationNode::make(tr->func, tr->args, tr->num_inputs, call->attrs),
+          GetRef<Call>(call));
       } else {
-        solver_.AddConstraint(cs);
+        solver_.AddConstraint(cs, GetRef<Call>(call));
       }
     }
 
@@ -367,7 +407,8 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
     if (const OpNode* opnode = call->op.as<OpNode>()) {
       Type rtype = PrimitiveCall(opnode->op_type.as<FuncTypeNode>(),
                                  arg_types,
-                                 call->attrs);
+                                 call->attrs,
+                                 GetRef<Call>(call));
       if (rtype.defined()) {
         AddTypeArgs(GetRef<Call>(call), arg_types);
         return rtype;
@@ -385,7 +426,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
     }
     Type rtype = GetType(f->body);
     if (f->ret_type.defined()) {
-      rtype = this->Unify(f->ret_type, rtype, f->span);
+      rtype = this->Unify(f->ret_type, rtype, GetRef<Function>(f));
     }
     auto ret = FuncTypeNode::make(arg_types, rtype, f->type_params, {});
     return solver_.Resolve(ret);
@@ -445,6 +486,9 @@ class TypeInferencer::Resolver : public ExprMutator {
     auto it = tmap_.find(GetRef<Expr>(op));
     CHECK(it != tmap_.end());
     Type checked_type = solver_->Resolve(it->second.checked_type);
+
+    // TODO(@jroesch): it would be nice if we would report resolution
+    // errors directly on the program.
     CHECK(checked_type.as<IncompleteTypeNode>() == nullptr)
         << "Cannot resolve type of " << GetRef<Expr>(op)
         << " at " << op->span;
@@ -542,6 +586,10 @@ Expr TypeInferencer::Infer(Expr expr) {
   // Step 1: Solve the constraints.
   solver_.Solve();
 
+  if (err_reporter.AnyErrors()) {
+    err_reporter.RenderErrors(mod_);
+  }
+
   // Step 2: Attach resolved types to checked_type field.
   auto resolved_expr = Resolver(type_map_, &solver_).VisitExpr(expr);
   CHECK(WellFormed(resolved_expr));
@@ -549,10 +597,27 @@ Expr TypeInferencer::Infer(Expr expr) {
 }
 
 
-Expr InferType(const Expr& expr, const Module& mod) {
-  auto e = TypeInferencer(mod).Infer(expr);
-  CHECK(WellFormed(e));
-  return e;
+Expr InferType(const Expr& expr, const Module& mod_ref) {
+  if (!mod_ref.defined()) {
+    Module mod = ModuleNode::FromExpr(expr);
+    // NB(@jroesch): By adding the expression to the module we will
+    // type check it anyway; afterwards we can just recover type
+    // from the type-checked function to avoid doing unnecessary work.
+
+    Function func = mod->Lookup(mod->entry_func);
+
+    // FromExpr wraps a naked expression as a function, we will unbox
+    // it here.
+    if (expr.as<FunctionNode>()) {
+      return func;
+    } else {
+      return func->body;
+    }
+  } else {
+    auto e = TypeInferencer(mod_ref, mod_ref->entry_func).Infer(expr);
+    CHECK(WellFormed(e));
+    return e;
+  }
 }
 
 Function InferType(const Function& func,
@@ -561,7 +626,7 @@ Function InferType(const Function& func,
   Function func_copy = Function(make_node<FunctionNode>(*func.operator->()));
   func_copy->checked_type_ = func_copy->func_type_annotation();
   mod->AddUnchecked(var, func_copy);
-  Expr func_ret = TypeInferencer(mod).Infer(func_copy);
+  Expr func_ret = TypeInferencer(mod, var).Infer(func_copy);
   mod->Remove(var);
   CHECK(WellFormed(func_ret));
   return Downcast<Function>(func_ret);
diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc
index caea3755b..dafcaf560 100644
--- a/src/relay/pass/type_solver.cc
+++ b/src/relay/pass/type_solver.cc
@@ -16,7 +16,7 @@ class TypeSolver::Reporter : public TypeReporterNode {
       : solver_(solver) {}
 
   void Assign(const Type& dst, const Type& src) final {
-    solver_->Unify(dst, src);
+    solver_->Unify(dst, src, location);
   }
 
   bool Assert(const IndexExpr& cond) final {
@@ -35,7 +35,14 @@ class TypeSolver::Reporter : public TypeReporterNode {
     return true;
   }
 
+  TVM_DLL void SetLocation(const NodeRef& ref) final {
+    location = ref;
+  }
+
  private:
+  /*! \brief The location to report unification errors at. */
+  mutable NodeRef location;
+
   TypeSolver* solver_;
 };
 
@@ -329,8 +336,10 @@ class TypeSolver::Merger : public TypeFunctor<void(const Type&)> {
 };
 
 // constructor
-TypeSolver::TypeSolver()
-  : reporter_(make_node<Reporter>(this)) {
+TypeSolver::TypeSolver(const GlobalVar &current_func, ErrorReporter* err_reporter)
+  : reporter_(make_node<Reporter>(this)),
+    current_func(current_func),
+    err_reporter_(err_reporter) {
 }
 
 // destructor
@@ -351,16 +360,26 @@ void TypeSolver::MergeFromTo(TypeNode* src, TypeNode* dst) {
 }
 
 // Add equality constraint
-Type TypeSolver::Unify(const Type& dst, const Type& src) {
+Type TypeSolver::Unify(const Type& dst, const Type& src, const NodeRef&) {
+  // NB(@jroesch): we should probably pass location into the unifier to do better
+  // error reporting as well.
   Unifier unifier(this);
   return unifier.Unify(dst, src);
 }
 
+void TypeSolver::ReportError(const Error& err, const NodeRef& location)  {
+    this->err_reporter_->ReportAt(
+      this->current_func,
+      location,
+      err);
+  }
+
 // Add type constraint to the solver.
-void TypeSolver::AddConstraint(const TypeConstraint& constraint) {
+void TypeSolver::AddConstraint(const TypeConstraint& constraint, const NodeRef& loc) {
   if (auto *op = constraint.as<TypeRelationNode>()) {
     // create a new relation node.
     RelationNode* rnode = arena_.make<RelationNode>();
+    rnode->location = loc;
     rnode->rel = GetRef<TypeRelation>(op);
     rel_nodes_.push_back(rnode);
     // populate the type information.
@@ -404,29 +423,52 @@ bool TypeSolver::Solve() {
       args.push_back(Resolve(tlink->value->FindRoot()->resolved_type));
       CHECK_LE(args.size(), rel->args.size());
     }
-    // call the function
-    bool resolved = rel->func(args, rel->num_inputs, rel->attrs, reporter_);
-    // mark inqueue as false after the function call
-    // so that rnode itself won't get enqueued again.
-    rnode->inqueue = false;
 
-    if (resolved) {
-      ++num_resolved_rels_;
+    CHECK(rnode->location.defined())
+      << "undefined location, should be set when constructing relation node";
+
+    // We need to set this in order to understand where unification
+    // errors generated by the error reporting are coming from.
+    reporter_->SetLocation(rnode->location);
+
+    try {
+      // Call the Type Relation's function.
+      bool resolved = rel->func(args, rel->num_inputs, rel->attrs, reporter_);
+
+      if (resolved) {
+        ++num_resolved_rels_;
+      }
+
+      rnode->resolved = resolved;
+    } catch (const Error& err) {
+      this->ReportError(err, rnode->location);
+      rnode->resolved = false;
+    } catch (const dmlc::Error& err) {
+      rnode->resolved = false;
+      this->ReportError(
+          RELAY_ERROR(
+            "an internal invariant was violdated while" \
+            "typechecking your program" <<
+            err.what()), rnode->location);
     }
-    rnode->resolved = resolved;
+
+    // Mark inqueue as false after the function call
+    // so that rnode itself won't get enqueued again.
+    rnode->inqueue = false;
   }
+
   // This criterion is not necessarily right for all the possible cases
   // TODO(tqchen): We should also count the number of in-complete types.
   return num_resolved_rels_ == rel_nodes_.size();
 }
 
-
 // Expose type solver only for debugging purposes.
 TVM_REGISTER_API("relay._ir_pass._test_type_solver")
 .set_body([](runtime::TVMArgs args, runtime::TVMRetValue* ret) {
     using runtime::PackedFunc;
     using runtime::TypedPackedFunc;
-    auto solver = std::make_shared<TypeSolver>();
+    ErrorReporter err_reporter;
+    auto solver = std::make_shared<TypeSolver>(GlobalVarNode::make("test"), &err_reporter);
 
     auto mod = [solver](std::string name) -> PackedFunc {
       if (name == "Solve") {
@@ -435,7 +477,7 @@ TVM_REGISTER_API("relay._ir_pass._test_type_solver")
           });
       } else if (name == "Unify") {
         return TypedPackedFunc<Type(Type, Type)>([solver](Type lhs, Type rhs) {
-            return solver->Unify(lhs, rhs);
+            return solver->Unify(lhs, rhs, lhs);
           });
       } else if (name == "Resolve") {
         return TypedPackedFunc<Type(Type)>([solver](Type t) {
@@ -443,7 +485,9 @@ TVM_REGISTER_API("relay._ir_pass._test_type_solver")
           });
       } else if (name == "AddConstraint") {
         return TypedPackedFunc<void(TypeConstraint)>([solver](TypeConstraint c) {
-            return solver->AddConstraint(c);
+            Expr e = VarNode::make("dummy_var",
+              IncompleteTypeNode::make(TypeVarNode::Kind::kType));
+            return solver->AddConstraint(c, e);
           });
       } else {
         return PackedFunc();
diff --git a/src/relay/pass/type_solver.h b/src/relay/pass/type_solver.h
index b4635fdec..b56d45c3b 100644
--- a/src/relay/pass/type_solver.h
+++ b/src/relay/pass/type_solver.h
@@ -6,8 +6,10 @@
 #ifndef TVM_RELAY_PASS_TYPE_SOLVER_H_
 #define TVM_RELAY_PASS_TYPE_SOLVER_H_
 
+#include <tvm/relay/expr.h>
 #include <tvm/relay/type.h>
 #include <tvm/relay/pass.h>
+#include <tvm/relay/error.h>
 #include <vector>
 #include <queue>
 #include "../../common/arena.h"
@@ -40,13 +42,14 @@ using common::LinkedList;
  */
 class TypeSolver {
  public:
-  TypeSolver();
+  TypeSolver(const GlobalVar& current_func, ErrorReporter* err_reporter);
   ~TypeSolver();
   /*!
    * \brief Add a type constraint to the solver.
    * \param constraint The constraint to be added.
+   * \param location The location at which the constraint was incurred.
    */
-  void AddConstraint(const TypeConstraint& constraint);
+  void AddConstraint(const TypeConstraint& constraint, const NodeRef& lcoation);
   /*!
    * \brief Resolve type to the solution type in the solver.
    * \param type The type to be resolved.
@@ -62,8 +65,16 @@ class TypeSolver {
    * \brief Unify lhs and rhs.
    * \param lhs The left operand.
    * \param rhs The right operand
+   * \param location The location at which the unification problem arose.
    */
-  Type Unify(const Type& lhs, const Type& rhs);
+  Type Unify(const Type& lhs, const Type& rhs, const NodeRef& location);
+
+  /*!
+   * \brief Report an error at the provided location.
+   * \param err The error to report.
+   * \param loc The location at which to report the error.
+   */
+  void ReportError(const Error& err, const NodeRef& location);
 
  private:
   class OccursChecker;
@@ -112,6 +123,7 @@ class TypeSolver {
       return root;
     }
   };
+
   /*! \brief relation node */
   struct RelationNode {
     /*! \brief Whether the relation is in the queue to be solved */
@@ -122,7 +134,10 @@ class TypeSolver {
     TypeRelation rel;
     /*! \brief list types to this relation */
     LinkedList<TypeNode*> type_list;
+    /*! \brief The location this type relation originated from. */
+    NodeRef location;
   };
+
   /*! \brief List of all allocated type nodes */
   std::vector<TypeNode*> type_nodes_;
   /*! \brief List of all allocated relation nodes */
@@ -137,6 +152,11 @@ class TypeSolver {
   common::Arena arena_;
   /*! \brief Reporter that reports back to self */
   TypeReporter reporter_;
+  /*! \brief The global representing the current function. */
+  GlobalVar current_func;
+  /*! \brief Error reporting. */
+  ErrorReporter* err_reporter_;
+
   /*!
    * \brief GetTypeNode that is corresponds to t.
    * if it do not exist, create a new one.
diff --git a/tests/python/relay/test_error_reporting.py b/tests/python/relay/test_error_reporting.py
new file mode 100644
index 000000000..1720af21a
--- /dev/null
+++ b/tests/python/relay/test_error_reporting.py
@@ -0,0 +1,34 @@
+import tvm
+from tvm import relay
+
+def check_type_err(expr, msg):
+    try:
+        expr = relay.ir_pass.infer_type(expr)
+        assert False
+    except tvm.TVMError as err:
+        assert msg in str(err)
+
+def test_too_many_args():
+    x = relay.var('x', shape=(10, 10))
+    f = relay.Function([x], x)
+    y = relay.var('y', shape=(10, 10))
+    check_type_err(
+        f(x, y),
+        "the function is provided too many arguments expected 1, found 2;")
+
+def test_too_few_args():
+    x = relay.var('x', shape=(10, 10))
+    y = relay.var('y', shape=(10, 10))
+    f = relay.Function([x, y], x)
+    check_type_err(f(x), "the function is provided too few arguments expected 2, found 1;")
+
+def test_rel_fail():
+    x = relay.var('x', shape=(10, 10))
+    y = relay.var('y', shape=(11, 10))
+    f = relay.Function([x, y], x + y)
+    check_type_err(f(x, y), "Incompatible broadcast type TensorType([10, 10], float32) and TensorType([11, 10], float32);")
+
+if __name__ == "__main__":
+    test_too_many_args()
+    test_too_few_args()
+    test_rel_fail()
-- 
GitLab