Newer
Older
/*!
* Copyright (c) 2018 by Contributors
*
* \file util.cc
*
* \brief simple util for relay.
*/
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include "../ir/type_functor.h"
// FreeTypeVar
class FreeTypeVarTVisitor : public TypeVisitor {
public:
FreeTypeVarTVisitor(
Array<TypeVar>* free_vars,
std::unordered_set<TypeVar, NodeHash, NodeEqual>* bound_vars)
: free_vars_(free_vars), bound_vars_(bound_vars) { }
void VisitType_(const TypeVarNode* tp) final {
TypeVar var = GetRef<TypeVar>(tp);
if (bound_vars_->count(var) == 0) {
free_vars_->push_back(var);
}
}
void VisitType_(const FuncTypeNode* f) final {
for (auto type_param : f->type_params) {
bound_vars_->insert(type_param);
TypeVisitor::VisitType_(f);
}
private:
Array<TypeVar>* free_vars_;
std::unordered_set<TypeVar, NodeHash, NodeEqual>* bound_vars_;
};
class FreeTypeVarEVisitor : private ExprVisitor {
public:
Array<TypeVar> Find(const Expr& expr) {
this->VisitExpr(expr);
return free_vars_;
Array<TypeVar> Find(const Type& type) {
this->VisitType(type);
return free_vars_;
void VisitType(const Type& t) final {
FreeTypeVarTVisitor(&free_vars_, &bound_vars_)
.VisitType(t);
private:
// The result list
Array<TypeVar> free_vars_;
std::unordered_set<TypeVar, NodeHash, NodeEqual> bound_vars_;
};
class FreeVarVisitor : protected ExprVisitor {
Array<Var> Find(const Expr& expr) {
this->VisitExpr(expr);
return free_vars_;
}
void VisitExpr_(const VarNode* var) final {
if (bound_vars_.count(var) == 0) {
free_vars_.push_back(GetRef<Var>(var));
}
void VisitExpr_(const FunctionNode* op) final {
for (const auto& param : op->params) {
bound_vars_.insert(param.operator->());
}
VisitExpr(op->body);
}
void VisitExpr_(const LetNode* op) final {
bound_vars_.insert(op->var.operator->());
VisitExpr(op->value);
VisitExpr(op->body);
}
private:
// The result list
Array<Var> free_vars_;
std::unordered_set<const VarNode*> bound_vars_;
tvm::Array<TypeVar> FreeTypeVars(const Expr& expr) {
return FreeTypeVarEVisitor().Find(expr);
tvm::Array<TypeVar> FreeTypeVars(const Type& type) {
return FreeTypeVarEVisitor().Find(type);
tvm::Array<Var> FreeVars(const Expr& expr) {
return FreeVarVisitor().Find(expr);
}
TVM_REGISTER_API("relay._ir_pass.free_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) {
});
TVM_REGISTER_API("relay._ir_pass.free_type_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef x = args[0];
if (x.as<TypeNode>()) {
*ret = FreeTypeVars(Downcast<Type>(x));
*ret = FreeTypeVars(Downcast<Expr>(x));