Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
7d906c9d
Commit
7d906c9d
authored
Sep 30, 2018
by
雾雨魔理沙
Committed by
Tianqi Chen
Sep 30, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay] Free Variables (#1786)
parent
e928109c
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
187 additions
and
6 deletions
+187
-6
include/tvm/relay/pass.h
+30
-0
python/tvm/relay/ir_pass.py
+4
-0
src/relay/pass/type_visitor.h
+6
-6
src/relay/pass/util.cc
+118
-0
tests/python/relay/test_free_vars.py
+29
-0
No files found.
include/tvm/relay/pass.h
View file @
7d906c9d
...
...
@@ -92,6 +92,36 @@ bool AlphaEqual(const Type& t1, const Type& t2);
*/
bool
WellFormed
(
const
Expr
&
e
);
/*! \brief Get free variables from expression e.
*
* Free variables are variables that are not bound by a let or a function parameter in the context.
*
* \param e the expression.
*
* \return the set of free variable.
*/
tvm
::
Array
<
Var
>
FreeVariables
(
const
Expr
&
e
);
/*! \brief Get free type parameters from expression e.
*
* Free type parameters are type parameters that are not bound by a function type in the context.
*
* \param e the expression.
*
* \return the set of free type variables.
*/
tvm
::
Array
<
TypeParam
>
FreeTypeVariables
(
const
Expr
&
e
);
/*! \brief Get free type parameters from type t.
*
* Free type parameters are type parameters that are not bound by a function type in the context.
*
* \param t the type.
*
* \return the set of free type variables.
*/
tvm
::
Array
<
TypeParam
>
FreeTypeVariables
(
const
Type
&
t
);
}
// namespace relay
}
// namespace tvm
#endif // TVM_RELAY_PASS_H_
python/tvm/relay/ir_pass.py
View file @
7d906c9d
...
...
@@ -14,3 +14,7 @@ check_expr = _ir_pass.check_expr
well_formed
=
_ir_pass
.
well_formed
check_kind
=
_ir_pass
.
check_kind
free_vars
=
_ir_pass
.
free_vars
free_type_vars
=
_ir_pass
.
free_type_vars
src/relay/pass/type_visitor.h
View file @
7d906c9d
...
...
@@ -95,13 +95,13 @@ struct TypeMutator : TypeFunctor<Type(const Type& n)> {
type_params
,
type_constraints
);
}
Type
VisitType_
(
const
TupleTypeNode
*
op
)
override
{
std
::
vector
<
Type
>
new_fields
;
for
(
const
Type
&
t
:
op
->
fields
)
{
new_fields
.
push_back
(
this
->
VisitType
(
t
));
}
return
TupleTypeNode
::
make
(
new_fields
);
Type
VisitType_
(
const
TupleTypeNode
*
op
)
override
{
std
::
vector
<
Type
>
new_fields
;
for
(
const
Type
&
t
:
op
->
fields
)
{
new_fields
.
push_back
(
this
->
VisitType
(
t
));
}
return
TupleTypeNode
::
make
(
new_fields
);
}
Type
VisitType_
(
const
TypeRelationNode
*
type_rel
)
override
{
std
::
vector
<
Type
>
new_args
;
...
...
src/relay/pass/util.cc
0 → 100644
View file @
7d906c9d
/*!
* Copyright (c) 2018 by Contributors
*
* \file util.cc
*
* \brief simple util for relay.
*/
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include "./type_visitor.h"
namespace
tvm
{
namespace
relay
{
class
FreeVar
;
class
FreeTypeVar
:
private
TypeVisitor
<>
{
std
::
unordered_set
<
TypeParam
,
NodeHash
,
NodeEqual
>
*
free_vars
;
std
::
unordered_set
<
TypeParam
,
NodeHash
,
NodeEqual
>
*
bound_vars
;
FreeTypeVar
(
std
::
unordered_set
<
TypeParam
,
NodeHash
,
NodeEqual
>
*
free_vars
,
std
::
unordered_set
<
TypeParam
,
NodeHash
,
NodeEqual
>
*
bound_vars
)
:
free_vars
(
free_vars
),
bound_vars
(
bound_vars
)
{
}
void
VisitType_
(
const
TypeParamNode
*
tp
)
final
{
auto
var
=
GetRef
<
TypeParam
>
(
tp
);
if
(
bound_vars
->
count
(
var
)
==
0
)
{
free_vars
->
insert
(
var
);
}
}
void
VisitType_
(
const
FuncTypeNode
*
f
)
final
{
for
(
auto
type_param
:
f
->
type_params
)
{
bound_vars
->
insert
(
type_param
);
}
for
(
auto
type_cs
:
f
->
type_constraints
)
{
this
->
VisitType
(
type_cs
);
}
for
(
auto
arg_type
:
f
->
arg_types
)
{
this
->
VisitType
(
arg_type
);
}
this
->
VisitType
(
f
->
ret_type
);
}
friend
FreeVar
;
};
class
FreeVar
:
public
ExprVisitor
{
void
VisitExpr_
(
const
VarNode
*
v
)
final
{
auto
var
=
GetRef
<
Var
>
(
v
);
if
(
bound_vars
.
count
(
var
)
==
0
)
{
free_vars
.
insert
(
var
);
}
}
void
VisitExpr_
(
const
FunctionNode
*
f
)
final
{
for
(
const
auto
&
tp
:
f
->
type_params
)
{
bound_types
.
insert
(
tp
);
}
for
(
const
auto
&
p
:
f
->
params
)
{
bound_vars
.
insert
(
p
->
var
);
}
VisitExpr
(
f
->
body
);
VisitType
(
f
->
ret_type
);
}
void
VisitExpr_
(
const
LetNode
*
l
)
final
{
bound_vars
.
insert
(
l
->
var
);
VisitExpr
(
l
->
value
);
VisitExpr
(
l
->
body
);
VisitType
(
l
->
value_type
);
}
public
:
std
::
unordered_set
<
Var
,
NodeHash
,
NodeEqual
>
free_vars
;
std
::
unordered_set
<
Var
,
NodeHash
,
NodeEqual
>
bound_vars
;
std
::
unordered_set
<
TypeParam
,
NodeHash
,
NodeEqual
>
free_types
;
std
::
unordered_set
<
TypeParam
,
NodeHash
,
NodeEqual
>
bound_types
;
void
VisitType
(
const
Type
&
t
)
final
{
FreeTypeVar
(
&
free_types
,
&
bound_types
)(
t
);
}
};
tvm
::
Array
<
Var
>
FreeVariables
(
const
Expr
&
e
)
{
FreeVar
fv
;
fv
.
VisitExpr
(
e
);
return
tvm
::
Array
<
Var
>
(
fv
.
free_vars
.
begin
(),
fv
.
free_vars
.
end
());
}
tvm
::
Array
<
TypeParam
>
FreeTypeVariables
(
const
Expr
&
e
)
{
FreeVar
fv
;
fv
.
VisitExpr
(
e
);
return
tvm
::
Array
<
TypeParam
>
(
fv
.
free_types
.
begin
(),
fv
.
free_types
.
end
());
}
tvm
::
Array
<
TypeParam
>
FreeTypeVariables
(
const
Type
&
t
)
{
FreeVar
fv
;
fv
.
VisitType
(
t
);
return
tvm
::
Array
<
TypeParam
>
(
fv
.
free_types
.
begin
(),
fv
.
free_types
.
end
());
}
TVM_REGISTER_API
(
"relay._ir_pass.free_vars"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
FreeVariables
(
args
[
0
]);
});
TVM_REGISTER_API
(
"relay._ir_pass.free_type_vars"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
NodeRef
x
=
args
[
0
];
if
(
x
.
as
<
TypeNode
>
())
{
*
ret
=
FreeTypeVariables
(
Downcast
<
Type
>
(
x
));
}
else
{
*
ret
=
FreeTypeVariables
(
Downcast
<
Expr
>
(
x
));
}
});
}
// namespace relay
}
// namespace tvm
tests/python/relay/test_free_vars.py
0 → 100644
View file @
7d906c9d
import
tvm
from
tvm
import
relay
from
tvm.relay.ir_pass
import
free_vars
,
free_type_vars
def
test_free_vars
():
x
=
relay
.
Var
(
"x"
)
fvx
=
free_vars
(
x
)
assert
len
(
fvx
)
==
1
assert
fvx
[
0
]
==
x
v
=
relay
.
Constant
(
tvm
.
nd
.
array
(
10
))
ty
=
relay
.
TensorType
([],
"int32"
)
let
=
relay
.
Let
(
x
,
v
,
x
,
ty
)
fvx
=
free_vars
(
let
)
assert
len
(
free_vars
(
let
))
==
0
f
=
relay
.
Function
([
relay
.
Param
(
x
,
ty
)],
ty
,
x
)
assert
len
(
free_vars
(
f
))
==
0
def
test_free_type_vars
():
tp
=
relay
.
TypeParam
(
""
)
ty
=
relay
.
TupleType
([
tp
,
relay
.
TensorType
([],
"int32"
)])
x
=
relay
.
Var
(
"x"
)
y
=
relay
.
Var
(
"y"
)
let
=
relay
.
Let
(
x
,
y
,
x
,
ty
)
fvl
=
free_vars
(
let
)
assert
len
(
fvl
)
==
1
assert
fvl
[
0
]
==
y
ftvl
=
free_type_vars
(
let
)
assert
len
(
ftvl
)
==
1
assert
ftvl
[
0
]
==
tp
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment