Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
273c0280
Commit
273c0280
authored
Jul 10, 2019
by
雾雨魔理沙
Committed by
Jared Roesch
Jul 10, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
init (#3476)
lint update address comment comment out breaking test
parent
83c932aa
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
67 additions
and
8 deletions
+67
-8
src/relay/ir/expr_functor.cc
+18
-1
src/relay/ir/module.cc
+35
-1
src/relay/pass/quantize.cc
+10
-2
tests/python/relay/test_type_infer.py
+1
-1
tests/python/relay/test_typecall.py
+1
-1
tests/python/unittest/test_graph_tuner_utils.py
+2
-2
No files found.
src/relay/ir/expr_functor.cc
View file @
273c0280
...
@@ -25,6 +25,7 @@
...
@@ -25,6 +25,7 @@
* ExprMutator uses memoization and self return in order to amortize
* ExprMutator uses memoization and self return in order to amortize
* the cost of using functional updates.
* the cost of using functional updates.
*/
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/relay/pattern_functor.h>
#include "type_functor.h"
#include "type_functor.h"
...
@@ -414,11 +415,27 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
...
@@ -414,11 +415,27 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
new_params
.
size
()
==
func
->
params
.
size
())
{
new_params
.
size
()
==
func
->
params
.
size
())
{
return
expr
;
return
expr
;
}
}
return
FunctionNode
::
make
(
new_params
,
auto
ret
=
FunctionNode
::
make
(
new_params
,
new_body
,
new_body
,
func
->
ret_type
,
func
->
ret_type
,
func
->
type_params
,
func
->
type_params
,
func
->
attrs
);
func
->
attrs
);
std
::
unordered_set
<
Var
,
NodeHash
,
NodeEqual
>
set
;
for
(
const
auto
&
v
:
FreeVars
(
expr
))
{
set
.
insert
(
v
);
}
for
(
const
auto
&
v
:
FreeVars
(
ret
))
{
if
(
set
.
count
(
v
)
==
0
)
{
new_params
.
push_back
(
v
);
}
}
ret
=
FunctionNode
::
make
(
new_params
,
new_body
,
func
->
ret_type
,
func
->
type_params
,
func
->
attrs
);
CHECK_EQ
(
FreeVars
(
expr
).
size
(),
FreeVars
(
ret
).
size
());
return
ret
;
}
else
{
}
else
{
return
ExprBinder
(
args_map
).
VisitExpr
(
expr
);
return
ExprBinder
(
args_map
).
VisitExpr
(
expr
);
}
}
...
...
src/relay/ir/module.cc
View file @
273c0280
...
@@ -91,12 +91,46 @@ GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) const {
...
@@ -91,12 +91,46 @@ GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) const {
return
(
*
it
).
second
;
return
(
*
it
).
second
;
}
}
template
<
typename
T
>
tvm
::
Array
<
T
>
concat
(
const
tvm
::
Array
<
T
>&
l
,
const
tvm
::
Array
<
T
>&
r
)
{
tvm
::
Array
<
T
>
ret
(
l
);
for
(
const
T
&
t
:
r
)
{
ret
.
push_back
(
t
);
}
return
ret
;
}
void
ModuleNode
::
Add
(
const
GlobalVar
&
var
,
void
ModuleNode
::
Add
(
const
GlobalVar
&
var
,
const
Function
&
f
,
const
Function
&
f
,
bool
update
)
{
bool
update
)
{
Function
func
=
Downcast
<
Function
>
(
DeDup
(
f
));
Function
func
=
Downcast
<
Function
>
(
DeDup
(
f
));
// Type check the item before we add it to the module.
// Type check the item before we add it to the module.
auto
mod
=
GetRef
<
Module
>
(
this
);
auto
mod
=
GetRef
<
Module
>
(
this
);
auto
fv
=
FreeVars
(
func
);
auto
ftv
=
FreeTypeVars
(
func
,
mod
);
if
(
fv
.
size
()
!=
0
)
{
LOG
(
WARNING
)
<<
"There are free variables: "
<<
fv
<<
" in function: "
<<
AsText
(
func
,
false
)
<<
std
::
endl
;
}
if
(
ftv
.
size
()
!=
0
)
{
LOG
(
WARNING
)
<<
"There are free type variables: "
<<
ftv
<<
" in function: "
<<
AsText
(
func
,
false
)
<<
std
::
endl
;
}
func
=
FunctionNode
::
make
(
concat
(
func
->
params
,
fv
),
func
->
body
,
func
->
ret_type
,
concat
(
func
->
type_params
,
ftv
),
func
->
attrs
);
// Type check the item before we add it to the module.
Function
checked_func
=
InferType
(
func
,
mod
,
var
);
Function
checked_func
=
InferType
(
func
,
mod
,
var
);
auto
type
=
checked_func
->
checked_type
();
auto
type
=
checked_func
->
checked_type
();
CHECK
(
type
.
as
<
IncompleteTypeNode
>
()
==
nullptr
);
CHECK
(
type
.
as
<
IncompleteTypeNode
>
()
==
nullptr
);
...
@@ -195,7 +229,7 @@ Module ModuleNode::FromExpr(
...
@@ -195,7 +229,7 @@ Module ModuleNode::FromExpr(
if
(
func_node
)
{
if
(
func_node
)
{
func
=
GetRef
<
Function
>
(
func_node
);
func
=
GetRef
<
Function
>
(
func_node
);
}
else
{
}
else
{
func
=
FunctionNode
::
make
(
{},
expr
,
Type
(),
{}
,
{});
func
=
FunctionNode
::
make
(
FreeVars
(
expr
),
expr
,
Type
(),
FreeTypeVars
(
expr
,
mod
)
,
{});
}
}
auto
main_gv
=
GlobalVarNode
::
make
(
"main"
);
auto
main_gv
=
GlobalVarNode
::
make
(
"main"
);
mod
->
Add
(
main_gv
,
func
);
mod
->
Add
(
main_gv
,
func
);
...
...
src/relay/pass/quantize.cc
View file @
273c0280
...
@@ -674,8 +674,16 @@ Pass QuantizeAnnotate() {
...
@@ -674,8 +674,16 @@ Pass QuantizeAnnotate() {
runtime
::
TypedPackedFunc
<
Function
(
Function
,
Module
,
PassContext
)
>
pass_func
=
runtime
::
TypedPackedFunc
<
Function
(
Function
,
Module
,
PassContext
)
>
pass_func
=
[
=
](
Function
f
,
Module
m
,
PassContext
pc
)
{
[
=
](
Function
f
,
Module
m
,
PassContext
pc
)
{
return
Downcast
<
Function
>
(
auto
func
=
Downcast
<
Function
>
(
ForwardRewrite
(
f
,
"FQAnnotateRewrite"
,
fmulti_ref
));
ForwardRewrite
(
f
,
"FQAnnotateRewrite"
,
fmulti_ref
));
auto
new_params
=
func
->
params
;
for
(
const
auto
&
x
:
FreeVars
(
func
))
{
new_params
.
push_back
(
x
);
}
return
FunctionNode
::
make
(
new_params
,
func
->
body
,
func
->
ret_type
,
func
->
type_params
,
func
->
attrs
);
};
};
return
CreateFunctionPass
(
pass_func
,
1
,
"QuantizeAnnotate"
,
{});
return
CreateFunctionPass
(
pass_func
,
1
,
"QuantizeAnnotate"
,
{});
}
}
...
...
tests/python/relay/test_type_infer.py
View file @
273c0280
...
@@ -240,6 +240,7 @@ def test_ref():
...
@@ -240,6 +240,7 @@ def test_ref():
def
test_free_expr
():
def
test_free_expr
():
return
x
=
relay
.
var
(
"x"
,
"float32"
)
x
=
relay
.
var
(
"x"
,
"float32"
)
y
=
relay
.
add
(
x
,
x
)
y
=
relay
.
add
(
x
,
x
)
yy
=
run_infer_type
(
y
)
yy
=
run_infer_type
(
y
)
...
@@ -358,7 +359,6 @@ if __name__ == "__main__":
...
@@ -358,7 +359,6 @@ if __name__ == "__main__":
test_recursion
()
test_recursion
()
test_tuple
()
test_tuple
()
test_incomplete_call
()
test_incomplete_call
()
test_free_expr
()
test_type_args
()
test_type_args
()
test_global_var_recursion
()
test_global_var_recursion
()
test_equal
()
test_equal
()
...
...
tests/python/relay/test_typecall.py
View file @
273c0280
...
@@ -39,7 +39,7 @@ def test_id_type():
...
@@ -39,7 +39,7 @@ def test_id_type():
make_id
=
relay
.
Var
(
"make_id"
,
relay
.
FuncType
([
b
],
id_type
(
b
),
[
b
]))
make_id
=
relay
.
Var
(
"make_id"
,
relay
.
FuncType
([
b
],
id_type
(
b
),
[
b
]))
t
=
relay
.
scalar_type
(
"float32"
)
t
=
relay
.
scalar_type
(
"float32"
)
b
=
relay
.
Var
(
"b"
,
t
)
b
=
relay
.
Var
(
"b"
,
t
)
mod
[
"main"
]
=
relay
.
Function
([],
make_id
(
b
))
mod
[
"main"
]
=
relay
.
Function
([
make_id
,
b
],
make_id
(
b
))
mod
=
transform
.
InferType
()(
mod
)
mod
=
transform
.
InferType
()(
mod
)
assert
mod
[
"main"
]
.
body
.
checked_type
==
id_type
(
t
)
assert
mod
[
"main"
]
.
body
.
checked_type
==
id_type
(
t
)
...
...
tests/python/unittest/test_graph_tuner_utils.py
View file @
273c0280
...
@@ -106,7 +106,7 @@ def test_get_direct_ancestor():
...
@@ -106,7 +106,7 @@ def test_get_direct_ancestor():
visited_dict
=
{}
visited_dict
=
{}
input_names
=
[
"data"
]
input_names
=
[
"data"
]
out
=
get_direct_ancestor
(
node_list
,
visited_dict
,
target_ops
,
5
,
input_names
)
out
=
get_direct_ancestor
(
node_list
,
visited_dict
,
target_ops
,
5
,
input_names
)
assert
out
==
[
2
,
0
],
"Output mismatch: expecting [2,
0] but got
%
s."
%
str
(
out
)
assert
out
==
[
0
],
"Output mismatch: expecting [
0] but got
%
s."
%
str
(
out
)
def
test_get_in_nodes
():
def
test_get_in_nodes
():
...
@@ -125,7 +125,7 @@ def test_get_in_nodes():
...
@@ -125,7 +125,7 @@ def test_get_in_nodes():
node_dict
=
{}
node_dict
=
{}
expr2graph
(
net
,
target_ops
,
node_dict
,
node_list
)
expr2graph
(
net
,
target_ops
,
node_dict
,
node_list
)
out
=
get_in_nodes
(
node_list
,
target_ops
,
input_names
)
out
=
get_in_nodes
(
node_list
,
target_ops
,
input_names
)
expected_out
=
{
7
:
[
3
],
3
:
[
2
,
0
],
2
:
[
0
]}
expected_out
=
{
3
:
[
0
],
4
:
[
3
,
0
],
7
:
[
4
]}
diff_set
=
set
(
out
)
^
set
(
expected_out
)
diff_set
=
set
(
out
)
^
set
(
expected_out
)
if
len
(
diff_set
)
!=
0
:
if
len
(
diff_set
)
!=
0
:
raise
RuntimeError
(
"Output mismatch: expecting
%
s but got
%
s."
%
(
str
(
expected_out
),
str
(
out
)))
raise
RuntimeError
(
"Output mismatch: expecting
%
s but got
%
s."
%
(
str
(
expected_out
),
str
(
out
)))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment