Newer
Older
/*!
* Copyright (c) 2018 by Contributors
*
* \file util.cc
*
* \brief Simple Relay utilities: collection of free variables and free type variables.
*/
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include "../ir/type_functor.h"
namespace tvm {
namespace relay {
class FreeVar;

/*!
 * \brief Type visitor that records the type variables occurring in a type.
 *
 * A TypeVar that is not already in *bound_vars is added to *free_vars.
 * Visiting a FuncType first adds its type parameters to *bound_vars, then
 * descends into its type constraints, argument types and return type; the
 * added bindings are not removed afterwards.
 *
 * Both sets are supplied (and owned) by the caller. The constructor is
 * private, so only the friend class FreeVar can instantiate this visitor.
 */
class FreeTypeVar : private TypeVisitor<> {
  using TypeVarSet = std::unordered_set<TypeVar, NodeHash, NodeEqual>;

  /*! \brief Output set of free type variables (not owned). */
  TypeVarSet* free_vars;
  /*! \brief Type variables considered bound (not owned; grows during visits). */
  TypeVarSet* bound_vars;

  FreeTypeVar(TypeVarSet* free_vars, TypeVarSet* bound_vars)
      : free_vars(free_vars), bound_vars(bound_vars) {}

  void VisitType_(const TypeVarNode* tp) final {
    TypeVar var = GetRef<TypeVar>(tp);
    const bool already_bound = bound_vars->count(var) != 0;
    if (!already_bound) {
      free_vars->insert(var);
    }
  }

  void VisitType_(const FuncTypeNode* f) final {
    // The function type's own type parameters are binders for its signature.
    for (const auto& param : f->type_params) {
      bound_vars->insert(param);
    }
    for (const auto& constraint : f->type_constraints) {
      this->VisitType(constraint);
    }
    for (const auto& arg : f->arg_types) {
      this->VisitType(arg);
    }
    this->VisitType(f->ret_type);
  }

  friend FreeVar;
};
class FreeVar : public ExprVisitor {
auto var = GetRef<Var>(v);
if (bound_vars.count(var) == 0) {
free_vars.insert(var);
}
if (v->type_annotation.defined()) {
VisitType(v->type_annotation);
}
for (const auto& tp : f->type_params) {
bound_types.insert(tp);
}
for (const auto& param : f->params) {
bound_vars.insert(param);
}
VisitExpr(f->body);
VisitType(f->ret_type);
}
bound_vars.insert(l->var);
VisitExpr(l->value);
VisitExpr(l->body);
}
public:
std::unordered_set<Var, NodeHash, NodeEqual> free_vars;
std::unordered_set<Var, NodeHash, NodeEqual> bound_vars;
std::unordered_set<TypeVar, NodeHash, NodeEqual> free_types;
std::unordered_set<TypeVar, NodeHash, NodeEqual> bound_types;
void VisitType(const Type& t) final {
FreeTypeVar(&free_types, &bound_types)(t);
}
};
/*!
 * \brief Get the Relay variables that occur free in an expression.
 * \param e The expression to inspect.
 * \return The free variables. Their order comes from iterating an
 *  unordered_set and is therefore unspecified.
 */
tvm::Array<Var> FreeVariables(const Expr& e) {
  FreeVar collector;
  collector.VisitExpr(e);
  const auto& found = collector.free_vars;
  return tvm::Array<Var>(found.begin(), found.end());
}
/*!
 * \brief Get the type variables that occur free in an expression.
 * \param e The expression to inspect (including its type annotations).
 * \return The free type variables, in unspecified order.
 */
tvm::Array<TypeVar> FreeTypeVariables(const Expr& e) {
  FreeVar fv;
  fv.VisitExpr(e);
  return tvm::Array<TypeVar>(fv.free_types.begin(), fv.free_types.end());
}

/*!
 * \brief Get the type variables that occur free in a type.
 * \param t The type to inspect.
 * \return The free type variables, in unspecified order.
 */
tvm::Array<TypeVar> FreeTypeVariables(const Type& t) {
  FreeVar fv;
  fv.VisitType(t);
  return tvm::Array<TypeVar>(fv.free_types.begin(), fv.free_types.end());
}
// Packed-function entry point: free_vars(expr) -> Array<Var>.
TVM_REGISTER_API("relay._ir_pass.free_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    Expr expr = args[0];
    *ret = FreeVariables(expr);
  });
// Packed-function entry point: free_type_vars(type_or_expr) -> Array<TypeVar>.
// A Type argument uses the Type overload; any other node goes through the
// Expr overload.
TVM_REGISTER_API("relay._ir_pass.free_type_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    NodeRef node = args[0];
    if (!node.as<TypeNode>()) {
      *ret = FreeTypeVariables(Downcast<Expr>(node));
      return;
    }
    *ret = FreeTypeVariables(Downcast<Type>(node));
  });
} // namespace relay
} // namespace tvm