Commit 5088a034 (unverified)
Authored Mar 20, 2020 by Zhi; committed by GitHub on Mar 20, 2020
[Relay][BYOCG] Propagate constant to subgraphs (#5094)

* bind constant to subgraphs
* con -> constant
Parent: 53643bdb
Showing 4 changed files with 115 additions and 3 deletions (+115 -3)
src/relay/backend/contrib/codegen_c/codegen.cc     +55  -1
src/relay/backend/contrib/codegen_c/codegen_c.h     +2  -2
src/relay/transforms/partition_graph.cc            +13  -0
tests/python/relay/test_pass_partition_graph.py    +45  -0
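At a glance: the change makes PartitionGraph fold Relay constants into the external subgraphs it creates, instead of lifting them into extra subgraph parameters, and teaches the example C codegen to materialize those constants as initialized buffers. A minimal Python sketch of the user-facing flow, distilled from the new test_constant_propagation below (WhiteListAnnotator is a helper defined in the test file; without an annotation step PartitionGraph leaves the module unchanged, so this only illustrates the intended flow):

import numpy as np
import tvm
from tvm import relay
from tvm.relay import transform

ones = np.ones((8, 8), dtype="float32")
x = relay.var("x", shape=(8, 8))
y = relay.var("y", shape=(8, 8))
f = relay.Function([x, y], relay.log(x + y))
# Bind x to a constant tensor by name; it becomes a relay constant in the body.
f = relay.build_module.bind_params_by_name(f, {"x": tvm.nd.array(ones)})
mod = tvm.IRModule()
mod["main"] = f
# After annotating "add" for the external "ccompiler" backend, e.g.
#   mod = WhiteListAnnotator(["add"], "ccompiler")(mod)
# partitioning produces an external function that captures the constant and
# takes only %y as an argument, rather than (constant, %y).
mod = transform.PartitionGraph()(mod)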
src/relay/backend/contrib/codegen_c/codegen.cc
@@ -19,6 +19,7 @@
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/transform.h>
 #include <tvm/relay/type.h>
+#include <tvm/runtime/ndarray.h>
 #include <tvm/runtime/module.h>
 #include <tvm/runtime/object.h>
@@ -40,7 +41,7 @@ class CodegenC : public ExprVisitor, public CodegenCBase {
  public:
   explicit CodegenC(const std::string& id) { this->ext_func_id_ = id; }

-  void VisitExpr_(const VarNode* node) {
+  void VisitExpr_(const VarNode* node) final {
     ext_func_args_.push_back(GetRef<Var>(node));
     out_.clear();
     Output output;
@@ -48,6 +49,55 @@ class CodegenC : public ExprVisitor, public CodegenCBase {
     out_.push_back(output);
   }

+  void VisitExpr_(const ConstantNode* cn) final {
+    Constant constant = GetRef<Constant>(cn);
+    if (visited_.count(constant)) {
+      // Note this is for demonstration purposes. ConstantNode doesn't necessarily
+      // belong to calls. We need to revisit this when tuples come into play.
+      out_.push_back(visited_[constant]);
+      return;
+    }
+
+    std::ostringstream decl_stream;
+    std::ostringstream buf_stream;
+
+    out_.clear();
+    Output output;
+    output.name = "const_" + std::to_string(const_idx_++);
+    out_.push_back(output);
+    visited_[constant] = output;
+
+    runtime::NDArray array = cn->data;
+    const auto& shape = array.Shape();
+    const DLTensor& dl_tensor = array.ToDLPack()->dl_tensor;
+
+    // Get the number of elements.
+    int64_t num_elems = 1;
+    for (auto i : shape) num_elems *= i;
+
+    const auto* type_node = cn->checked_type().as<TensorTypeNode>();
+    CHECK(type_node);
+    const auto& dtype = GetDtypeString(type_node);
+
+    // Define a const buffer: float const_0[64] = {1.0, 2.0, ...};
+    //
+    // Technically, you may need: static float* const_0 = (float*)malloc(4 * 64)
+    // to avoid possible stack overflow.
+    buf_stream << dtype << " " << output.name << "[" << num_elems << "] = {";
+    if (dtype == "float") {
+      float* p_flt = static_cast<float*>(dl_tensor.data);
+      for (int64_t i = 0; i < num_elems - 1; i++) buf_stream << p_flt[i] << ", ";
+      if (num_elems) buf_stream << p_flt[num_elems - 1];
+    } else if (dtype == "int") {
+      int* p_flt = static_cast<int*>(dl_tensor.data);
+      for (int64_t i = 0; i < num_elems - 1; i++) buf_stream << p_flt[i] << ", ";
+      if (num_elems) buf_stream << p_flt[num_elems - 1];
+    } else {
+      LOG(FATAL) << "Only float and int are supported for now.";
+    }
+    buf_stream << "};";
+    ext_func_body.insert(ext_func_body.begin(), buf_stream.str());
+  }
+
   void VisitExpr_(const CallNode* call) final {
     std::ostringstream macro_stream;
     std::ostringstream decl_stream;
@@ -138,6 +188,8 @@ class CodegenC : public ExprVisitor, public CodegenCBase {
   int func_idx = 0;
   /*! \brief The index of allocated buffers. */
   int buf_idx_ = 0;
+  /*! \brief The index of global constants. */
+  int const_idx_ = 0;
   /*! \brief The arguments of a C compiler compatible function. */
   Array<Var> ext_func_args_;
   /*! \brief The statements of a C compiler compatible function. */
@@ -148,6 +200,8 @@ class CodegenC : public ExprVisitor, public CodegenCBase {
   std::vector<std::string> buf_decl_;
   /*! \brief The name and index pairs for output. */
   std::vector<Output> out_;
+  /*! \brief The cached expressions. */
+  std::unordered_map<Expr, Output, ObjectHash, ObjectEqual> visited_;
 };

 class CSourceCodegen : public CSourceModuleCodegenBase {
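For intuition, the new ConstantNode visitor above serializes a constant tensor into a C initializer that is prepended to the external function body. A small Python sketch that mimics that emission logic (emit_const_buffer is illustrative only, not part of the commit):

import numpy as np

def emit_const_buffer(name, array):
    # Flatten the tensor and render it the way buf_stream does:
    # "<dtype> <name>[<num_elems>] = {v0, v1, ...};"
    flat = np.asarray(array).ravel()
    ctype = {"float32": "float", "int32": "int"}[str(flat.dtype)]
    values = ", ".join(str(v) for v in flat)
    return f"{ctype} {name}[{flat.size}] = {{{values}}};"

print(emit_const_buffer("const_0", np.ones((2, 2), dtype="float32")))
# float const_0[4] = {1.0, 1.0, 1.0, 1.0};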
src/relay/backend/contrib/codegen_c/codegen_c.h
@@ -197,7 +197,7 @@ class CodegenCBase {
    * \return true if the call's name is equivalent to the given name. Otherwise,
    * false.
    */
-  bool IsOp(const CallNode* call, std::string op_name) const {
+  bool IsOp(const CallNode* call, const std::string& op_name) const {
     const auto* op_node = call->op.as<OpNode>();
     CHECK(op_node) << "Expects a single op.";
     Op op = GetRef<Op>(op_node);
@@ -218,7 +218,7 @@ class CodegenCBase {
    *
    * \return The emitted code string.
    */
-  std::string JitImpl(std::string ext_func_id, const Array<Var>& args,
+  std::string JitImpl(const std::string& ext_func_id, const Array<Var>& args,
                      const std::vector<std::string>& buf_decl,
                      const std::vector<std::string>& body,
                      const std::vector<Output>& out) {
src/relay/transforms/partition_graph.cc
@@ -42,6 +42,8 @@
 #include <utility>
 #include <vector>

+#include "../backend/utils.h"
+
 namespace tvm {
 namespace relay {
 namespace partitioning {
@@ -200,15 +202,21 @@ class Partitioner : public ExprMutator {
       auto input = VisitExpr(call->args[0]);
       Array<Var> params;
       Array<Expr> args;
+      std::unordered_map<std::string, runtime::NDArray> params_bind;

       // The subgraph may be merged so we need to update it again.
       subgraph = GetSubgraph(GetRef<Call>(call));
       CHECK(subgraph);

+      // Record the constants for propagation.
       for (auto pair : subgraph->args) {
         params.push_back(pair.first);
+        if (const auto* cn = pair.second.as<ConstantNode>()) {
+          params_bind[pair.first->name_hint()] = cn->data;
+        } else {
           args.push_back(pair.second);
+        }
       }

       auto subgraph_func = Function(params, input, call->checked_type_, {});
@@ -223,6 +231,11 @@ class Partitioner : public ExprMutator {
           tvm::tir::StringImmNode::make(compiler_attrs->compiler));
       subgraph_func =
           WithAttr(std::move(subgraph_func), attr::kInline, tvm::Integer(1));
+
+      // Constant propagation
+      if (!params_bind.empty()) {
+        subgraph_func = backend::BindParamsByName(subgraph_func, params_bind);
+      }
       CHECK(!module_->ContainGlobalVar(name))
           << "Global function " << name << " already exists";
       // Create a global function and add it to the IRModule for the subgraph.
tests/python/relay/test_pass_partition_graph.py
@@ -634,6 +634,50 @@ def test_function_lifting_inline():
     assert relay.analysis.alpha_equal(partitioned, ref_mod)


+def test_constant_propagation():
+    ones = np.ones(shape=(8, 8), dtype="float32")
+
+    def expected():
+        mod = tvm.IRModule()
+        x = relay.const(ones)
+        y = relay.var("y", shape=(8, 8))
+        x0 = relay.const(ones)
+        y0 = relay.var("y0", shape=(8, 8))
+        add = x0 + y0
+        # Function that uses C compiler
+        func = relay.Function([y0], add)
+        func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+        func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
+        func = func.with_attr("Compiler", tvm.tir.StringImm("ccompiler"))
+        func = func.with_attr("ExternalSymbol", tvm.tir.StringImm("ccompiler_0"))
+        glb_0 = relay.GlobalVar("ccompiler_0")
+        mod[glb_0] = func
+        add_call = relay.Call(glb_0, [y])
+        log = relay.log(add_call)
+        main = relay.Function([y], log)
+        mod["main"] = main
+        return mod
+
+    x = relay.var("x", shape=(8, 8))
+    y = relay.var("y", shape=(8, 8))
+    add = x + y
+    log = relay.log(add)
+    f = relay.Function([x, y], log)
+    f = relay.build_module.bind_params_by_name(f, {"x": tvm.nd.array(ones)})
+    mod = tvm.IRModule()
+    mod["main"] = f
+    mod = WhiteListAnnotator(["add"], "ccompiler")(mod)
+    mod = transform.PartitionGraph()(mod)
+
+    expected_mod = expected()
+    assert relay.alpha_equal(mod, expected_mod)
+
+    y_data = np.random.rand(8, 8).astype('float32')
+    np_add = ones + y_data
+    check_result(mod, {"y": y_data}, (8, 8), np.log(np_add))
+
+
 if __name__ == "__main__":
     test_multi_node_compiler()
     test_extern_ccompiler_single_op()
@@ -643,3 +687,4 @@ if __name__ == "__main__":
     test_extern_dnnl_mobilenet()
     test_function_lifting()
     test_function_lifting_inline()
+    test_constant_propagation()