Commit 7d906c9d by 雾雨魔理沙 Committed by Tianqi Chen

[Relay] Free Variables (#1786)

parent e928109c
...@@ -92,6 +92,36 @@ bool AlphaEqual(const Type& t1, const Type& t2); ...@@ -92,6 +92,36 @@ bool AlphaEqual(const Type& t1, const Type& t2);
*/ */
bool WellFormed(const Expr & e); bool WellFormed(const Expr & e);
/*! \brief Get free variables from expression e.
*
* Free variables are variables that are not bound by a let or a function parameter in the context.
*
* \param e the expression.
*
* \return the set of free variable.
*/
tvm::Array<Var> FreeVariables(const Expr & e);
/*! \brief Get free type parameters from expression e.
*
* Free type parameters are type parameters that are not bound by a function type in the context.
*
* \param e the expression.
*
* \return the set of free type variables.
*/
tvm::Array<TypeParam> FreeTypeVariables(const Expr & e);
/*! \brief Get free type parameters from type t.
*
* Free type parameters are type parameters that are not bound by a function type in the context.
*
* \param t the type.
*
* \return the set of free type variables.
*/
tvm::Array<TypeParam> FreeTypeVariables(const Type & t);
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_PASS_H_ #endif // TVM_RELAY_PASS_H_
...@@ -14,3 +14,7 @@ check_expr = _ir_pass.check_expr ...@@ -14,3 +14,7 @@ check_expr = _ir_pass.check_expr
well_formed = _ir_pass.well_formed well_formed = _ir_pass.well_formed
check_kind = _ir_pass.check_kind check_kind = _ir_pass.check_kind
free_vars = _ir_pass.free_vars
free_type_vars = _ir_pass.free_type_vars
/*!
* Copyright (c) 2018 by Contributors
*
* \file util.cc
*
* \brief simple util for relay.
*/
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include "./type_visitor.h"
namespace tvm {
namespace relay {
class FreeVar;
class FreeTypeVar : private TypeVisitor<> {
std::unordered_set<TypeParam, NodeHash, NodeEqual> * free_vars;
std::unordered_set<TypeParam, NodeHash, NodeEqual> * bound_vars;
FreeTypeVar(std::unordered_set<TypeParam, NodeHash, NodeEqual> * free_vars,
std::unordered_set<TypeParam, NodeHash, NodeEqual> * bound_vars) :
free_vars(free_vars), bound_vars(bound_vars) { }
void VisitType_(const TypeParamNode* tp) final {
auto var = GetRef<TypeParam>(tp);
if (bound_vars->count(var) == 0) {
free_vars->insert(var);
}
}
void VisitType_(const FuncTypeNode* f) final {
for (auto type_param : f->type_params) {
bound_vars->insert(type_param);
}
for (auto type_cs : f->type_constraints) {
this->VisitType(type_cs);
}
for (auto arg_type : f->arg_types) {
this->VisitType(arg_type);
}
this->VisitType(f->ret_type);
}
friend FreeVar;
};
class FreeVar : public ExprVisitor {
void VisitExpr_(const VarNode *v) final {
auto var = GetRef<Var>(v);
if (bound_vars.count(var) == 0) {
free_vars.insert(var);
}
}
void VisitExpr_(const FunctionNode *f) final {
for (const auto& tp : f->type_params) {
bound_types.insert(tp);
}
for (const auto& p : f->params) {
bound_vars.insert(p->var);
}
VisitExpr(f->body);
VisitType(f->ret_type);
}
void VisitExpr_(const LetNode *l) final {
bound_vars.insert(l->var);
VisitExpr(l->value);
VisitExpr(l->body);
VisitType(l->value_type);
}
public:
std::unordered_set<Var, NodeHash, NodeEqual> free_vars;
std::unordered_set<Var, NodeHash, NodeEqual> bound_vars;
std::unordered_set<TypeParam, NodeHash, NodeEqual> free_types;
std::unordered_set<TypeParam, NodeHash, NodeEqual> bound_types;
void VisitType(const Type& t) final {
FreeTypeVar(&free_types, &bound_types)(t);
}
};
tvm::Array<Var> FreeVariables(const Expr& e) {
FreeVar fv;
fv.VisitExpr(e);
return tvm::Array<Var>(fv.free_vars.begin(), fv.free_vars.end());
}
tvm::Array<TypeParam> FreeTypeVariables(const Expr& e) {
FreeVar fv;
fv.VisitExpr(e);
return tvm::Array<TypeParam>(fv.free_types.begin(), fv.free_types.end());
}
tvm::Array<TypeParam> FreeTypeVariables(const Type& t) {
FreeVar fv;
fv.VisitType(t);
return tvm::Array<TypeParam>(fv.free_types.begin(), fv.free_types.end());
}
TVM_REGISTER_API("relay._ir_pass.free_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = FreeVariables(args[0]);
});
TVM_REGISTER_API("relay._ir_pass.free_type_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef x = args[0];
if (x.as<TypeNode>()) {
*ret = FreeTypeVariables(Downcast<Type>(x));
} else {
*ret = FreeTypeVariables(Downcast<Expr>(x));
}
});
} // namespace relay
} // namespace tvm
import tvm
from tvm import relay
from tvm.relay.ir_pass import free_vars, free_type_vars
def test_free_vars():
x = relay.Var("x")
fvx = free_vars(x)
assert len(fvx) == 1
assert fvx[0] == x
v = relay.Constant(tvm.nd.array(10))
ty = relay.TensorType([], "int32")
let = relay.Let(x, v, x, ty)
fvx = free_vars(let)
assert len(free_vars(let)) == 0
f = relay.Function([relay.Param(x, ty)], ty, x)
assert len(free_vars(f)) == 0
def test_free_type_vars():
tp = relay.TypeParam("")
ty = relay.TupleType([tp, relay.TensorType([], "int32")])
x = relay.Var("x")
y = relay.Var("y")
let = relay.Let(x, y, x, ty)
fvl = free_vars(let)
assert len(fvl) == 1
assert fvl[0] == y
ftvl = free_type_vars(let)
assert len(ftvl) == 1
assert ftvl[0] == tp
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment