Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
10f85d03
Unverified
Commit
10f85d03
authored
Jan 31, 2020
by
masahi
Committed by
GitHub
Jan 31, 2020
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Dedup BindParamByName function in VM compiler (#4793)
parent
24126b42
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
44 additions
and
83 deletions
+44
-83
src/relay/backend/build_module.cc
+1
-38
src/relay/backend/utils.h
+41
-0
src/relay/backend/vm/compiler.cc
+2
-35
src/relay/backend/vm/compiler.h
+0
-10
No files found.
src/relay/backend/build_module.cc
View file @
10f85d03
...
@@ -42,43 +42,6 @@ using TargetsMap = Map<tvm::Integer, tvm::Target>;
...
@@ -42,43 +42,6 @@ using TargetsMap = Map<tvm::Integer, tvm::Target>;
using
namespace
tvm
::
relay
::
transform
;
using
namespace
tvm
::
relay
::
transform
;
/*!
/*!
* \brief Bind params to function by using name
* \param func Relay function
* \param params params dict
* \return relay::Function
*/
relay
::
Function
BindParamsByName
(
relay
::
Function
func
,
const
std
::
unordered_map
<
std
::
string
,
runtime
::
NDArray
>&
params
)
{
std
::
unordered_map
<
std
::
string
,
relay
::
Var
>
name_dict
;
std
::
unordered_set
<
relay
::
Var
,
ObjectHash
,
ObjectEqual
>
repeat_var
;
for
(
auto
arg
:
func
->
params
)
{
const
auto
&
name
=
arg
->
name_hint
();
if
(
name_dict
.
count
(
name
))
{
repeat_var
.
insert
(
arg
);
}
else
{
name_dict
[
name
]
=
arg
;
}
}
std
::
unordered_map
<
relay
::
Var
,
Expr
,
ObjectHash
,
ObjectEqual
>
bind_dict
;
for
(
auto
&
kv
:
params
)
{
if
(
name_dict
.
count
(
kv
.
first
)
==
0
)
{
continue
;
}
auto
arg
=
name_dict
.
at
(
kv
.
first
);
if
(
repeat_var
.
count
(
arg
))
{
LOG
(
FATAL
)
<<
"Multiple args in the function have name "
<<
kv
.
first
;
}
bind_dict
[
arg
]
=
ConstantNode
::
make
(
kv
.
second
);
}
Expr
bound_expr
=
relay
::
Bind
(
func
,
bind_dict
);
Function
ret
=
Downcast
<
Function
>
(
bound_expr
);
CHECK
(
ret
.
defined
())
<<
"The returning type is expected to be a Relay Function."
<<
"
\n
"
;
return
ret
;
}
/*!
* \brief Output of building module
* \brief Output of building module
*
*
*/
*/
...
@@ -527,7 +490,7 @@ TVM_REGISTER_GLOBAL("relay.build_module.BindParamsByName")
...
@@ -527,7 +490,7 @@ TVM_REGISTER_GLOBAL("relay.build_module.BindParamsByName")
for
(
const
auto
&
kv
:
params
)
{
for
(
const
auto
&
kv
:
params
)
{
params_
[
kv
.
first
]
=
kv
.
second
->
data
;
params_
[
kv
.
first
]
=
kv
.
second
->
data
;
}
}
*
rv
=
BindParamsByName
(
args
[
0
],
params_
);
*
rv
=
relay
::
backend
::
BindParamsByName
(
args
[
0
],
params_
);
});
});
}
// namespace backend
}
// namespace backend
...
...
src/relay/backend/utils.h
View file @
10f85d03
...
@@ -27,6 +27,7 @@
...
@@ -27,6 +27,7 @@
#include <dmlc/json.h>
#include <dmlc/json.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/type.h>
#include <tvm/relay/type.h>
#include <tvm/relay/transform.h>
#include <tvm/driver/driver_api.h>
#include <tvm/driver/driver_api.h>
#include <tvm/target/codegen.h>
#include <tvm/target/codegen.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/ir_pass.h>
...
@@ -34,6 +35,8 @@
...
@@ -34,6 +35,8 @@
#include <typeinfo>
#include <typeinfo>
#include <string>
#include <string>
#include <unordered_map>
#include <unordered_set>
namespace
tvm
{
namespace
tvm
{
namespace
relay
{
namespace
relay
{
...
@@ -81,6 +84,44 @@ inline std::string DType2String(const tvm::DataType dtype) {
...
@@ -81,6 +84,44 @@ inline std::string DType2String(const tvm::DataType dtype) {
return
os
.
str
();
return
os
.
str
();
}
}
/*!
* \brief Bind params to function by using name
* \param func Relay function
* \param params params dict
* \return relay::Function
*/
inline
relay
::
Function
BindParamsByName
(
relay
::
Function
func
,
const
std
::
unordered_map
<
std
::
string
,
runtime
::
NDArray
>&
params
)
{
std
::
unordered_map
<
std
::
string
,
relay
::
Var
>
name_dict
;
std
::
unordered_set
<
relay
::
Var
,
ObjectHash
,
ObjectEqual
>
repeat_var
;
for
(
auto
arg
:
func
->
params
)
{
const
auto
&
name
=
arg
->
name_hint
();
if
(
name_dict
.
count
(
name
))
{
repeat_var
.
insert
(
arg
);
}
else
{
name_dict
[
name
]
=
arg
;
}
}
std
::
unordered_map
<
relay
::
Var
,
Expr
,
ObjectHash
,
ObjectEqual
>
bind_dict
;
for
(
auto
&
kv
:
params
)
{
if
(
name_dict
.
count
(
kv
.
first
)
==
0
)
{
continue
;
}
auto
arg
=
name_dict
.
at
(
kv
.
first
);
if
(
repeat_var
.
count
(
arg
))
{
LOG
(
FATAL
)
<<
"Multiple args in the function have name "
<<
kv
.
first
;
}
bind_dict
[
arg
]
=
ConstantNode
::
make
(
kv
.
second
);
}
Expr
bound_expr
=
relay
::
Bind
(
func
,
bind_dict
);
Function
ret
=
Downcast
<
Function
>
(
bound_expr
);
CHECK
(
ret
.
defined
())
<<
"The returning type is expected to be a Relay Function."
<<
"
\n
"
;
return
ret
;
}
}
// namespace backend
}
// namespace backend
}
// namespace relay
}
// namespace relay
}
// namespace tvm
}
// namespace tvm
...
...
src/relay/backend/vm/compiler.cc
View file @
10f85d03
...
@@ -37,9 +37,8 @@
...
@@ -37,9 +37,8 @@
#include <memory>
#include <memory>
#include <string>
#include <string>
#include <tuple>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <vector>
#include "../utils.h"
#include "../../backend/compile_engine.h"
#include "../../backend/compile_engine.h"
#include "../../pass/pass_util.h"
#include "../../pass/pass_util.h"
#include "../../op/op_common.h"
#include "../../op/op_common.h"
...
@@ -783,38 +782,6 @@ void VMCompiler::SetParam(const std::string& name, runtime::NDArray data_in) {
...
@@ -783,38 +782,6 @@ void VMCompiler::SetParam(const std::string& name, runtime::NDArray data_in) {
params_
[
name
]
=
data_in
;
params_
[
name
]
=
data_in
;
}
}
relay
::
Function
VMCompiler
::
BindParamsByName
(
relay
::
Function
func
,
const
std
::
unordered_map
<
std
::
string
,
runtime
::
NDArray
>&
params
)
{
std
::
unordered_map
<
std
::
string
,
relay
::
Var
>
name_dict
;
std
::
unordered_set
<
relay
::
Var
,
ObjectHash
,
ObjectEqual
>
repeat_var
;
for
(
auto
arg
:
func
->
params
)
{
const
auto
&
name
=
arg
->
name_hint
();
if
(
name_dict
.
count
(
name
))
{
repeat_var
.
insert
(
arg
);
}
else
{
name_dict
[
name
]
=
arg
;
}
}
std
::
unordered_map
<
relay
::
Var
,
Expr
,
ObjectHash
,
ObjectEqual
>
bind_dict
;
for
(
auto
&
kv
:
params
)
{
if
(
name_dict
.
count
(
kv
.
first
)
==
0
)
{
continue
;
}
auto
arg
=
name_dict
.
at
(
kv
.
first
);
if
(
repeat_var
.
count
(
arg
))
{
LOG
(
FATAL
)
<<
"Multiple args in the function have name "
<<
kv
.
first
;
}
bind_dict
[
arg
]
=
ConstantNode
::
make
(
kv
.
second
);
}
Expr
bound_expr
=
relay
::
Bind
(
func
,
bind_dict
);
Function
ret
=
Downcast
<
Function
>
(
bound_expr
);
CHECK
(
ret
.
defined
())
<<
"The returning type is expected to be a Relay Function."
<<
"
\n
"
;
return
ret
;
}
void
VMCompiler
::
Lower
(
IRModule
mod
,
void
VMCompiler
::
Lower
(
IRModule
mod
,
const
TargetsMap
&
targets
,
const
TargetsMap
&
targets
,
const
tvm
::
Target
&
target_host
)
{
const
tvm
::
Target
&
target_host
)
{
...
@@ -824,7 +791,7 @@ void VMCompiler::Lower(IRModule mod,
...
@@ -824,7 +791,7 @@ void VMCompiler::Lower(IRModule mod,
BaseFunc
base_func
=
mod
->
Lookup
(
"main"
);
BaseFunc
base_func
=
mod
->
Lookup
(
"main"
);
CHECK
(
base_func
->
IsInstance
<
FunctionNode
>
())
CHECK
(
base_func
->
IsInstance
<
FunctionNode
>
())
<<
"VM compiler expects to compile relay::Function"
;
<<
"VM compiler expects to compile relay::Function"
;
auto
f
=
BindParamsByName
(
Downcast
<
Function
>
(
base_func
),
params_
);
auto
f
=
relay
::
backend
::
BindParamsByName
(
Downcast
<
Function
>
(
base_func
),
params_
);
auto
gvar
=
mod
->
GetGlobalVar
(
"main"
);
auto
gvar
=
mod
->
GetGlobalVar
(
"main"
);
mod
->
Add
(
gvar
,
f
);
mod
->
Add
(
gvar
,
f
);
}
}
...
...
src/relay/backend/vm/compiler.h
View file @
10f85d03
...
@@ -115,16 +115,6 @@ class VMCompiler : public runtime::ModuleNode {
...
@@ -115,16 +115,6 @@ class VMCompiler : public runtime::ModuleNode {
void
Codegen
();
void
Codegen
();
protected
:
protected
:
/*!
* \brief Bind params to function by using name
* \param func Relay function
* \param params params dict
* \return relay::Function
*/
relay
::
Function
BindParamsByName
(
relay
::
Function
func
,
const
std
::
unordered_map
<
std
::
string
,
runtime
::
NDArray
>&
params
);
IRModule
OptimizeModule
(
const
IRModule
&
mod
,
const
TargetsMap
&
targets
);
IRModule
OptimizeModule
(
const
IRModule
&
mod
,
const
TargetsMap
&
targets
);
void
PopulateGlobalMap
();
void
PopulateGlobalMap
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment