diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index 67378c5d14a615e9c402bf9d7b42e6d1a73f526e..3ca161d23f728e593b01a64e7703f3f6ae743c65 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -61,6 +61,11 @@ Type TypeSolver::Unify(const Type& dst, const Type& src) { // - handle shape pattern matching TypeNode* lhs = GetTypeNode(dst); TypeNode* rhs = GetTypeNode(src); + + // do occur check so we don't create self-referencing structure + if (lhs->FindRoot() == rhs->FindRoot()) { + return lhs->resolved_type; + } if (lhs->resolved_type.as<IncompleteTypeNode>()) { MergeFromTo(lhs, rhs); return rhs->resolved_type; diff --git a/tests/cpp/relay_pass_type_infer_test.cc b/tests/cpp/relay_pass_type_infer_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..e1a81d3c0535f44d0375489e28d1a0de8bcab6bc --- /dev/null +++ b/tests/cpp/relay_pass_type_infer_test.cc @@ -0,0 +1,22 @@ +#include <gtest/gtest.h> +#include <tvm/tvm.h> +#include <tvm/relay/expr.h> +#include <tvm/relay/type.h> +#include <tvm/relay/pass.h> + +TEST(Relay, SelfReference) { + using namespace tvm; + auto type_a = relay::TypeVarNode::make("a", relay::TypeVarNode::kType); + auto type_b = relay::TypeVarNode::make("b", relay::TypeVarNode::kType); + auto x = relay::VarNode::make("x", type_a); + auto f = relay::FunctionNode::make(tvm::Array<relay::Var>{ x }, x, type_b, Array<relay::TypeVar>{}); + auto fx = relay::CallNode::make(f, Array<relay::Expr>{ x }); + auto type_fx = relay::InferType(fx, relay::EnvironmentNode::make(Map<relay::GlobalVar, relay::Function>{})); + CHECK_EQ(type_fx->checked_type(), type_a); +} + +int main(int argc, char ** argv) { + testing::InitGoogleTest(&argc, argv); + testing::FLAGS_gtest_death_test_style = "threadsafe"; + return RUN_ALL_TESTS(); +} diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index e1d749e758631dcc3db236e8248277120e5036f8..8f92fc0f51921b31979870c67b9521bdb2a1478e 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -107,6 +107,22 @@ def test_type_args(): assert sh2[0].value == 1 assert sh2[1].value == 10 +def test_self_reference(): + """ + Program: + def f(x) { + return x; + } + """ + a = relay.TypeVar("a") + x = relay.var("x", a) + sb = relay.ScopeBuilder() + f = relay.Function([x], x) + fx = relay.Call(f, [x]) + assert relay.ir_pass.infer_type(x).checked_type == a + assert relay.ir_pass.infer_type(f).checked_type == relay.FuncType([a], a) + assert relay.ir_pass.infer_type(fx).checked_type == a + if __name__ == "__main__": test_free_expr() test_dual_op() @@ -117,3 +133,4 @@ if __name__ == "__main__": test_tuple() test_free_expr() test_type_args() + test_self_reference()