Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
06bbc7c9
Unverified
Commit
06bbc7c9
authored
Mar 17, 2020
by
Zhi
Committed by
GitHub
Mar 17, 2020
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Replace UseDefaultCompiler with GetAttr (#5088)
parent
4ae46748
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
24 additions
and
37 deletions
+24
-37
include/tvm/relay/function.h
+0
-9
src/relay/backend/compile_engine.cc
+2
-2
src/relay/backend/graph_runtime_codegen.cc
+3
-2
src/relay/backend/vm/compiler.cc
+2
-2
src/relay/backend/vm/inline_primitives.cc
+5
-5
src/relay/backend/vm/lambda_lift.cc
+5
-5
src/relay/ir/function.cc
+0
-5
src/relay/ir/transform.cc
+1
-1
src/relay/transforms/inline.cc
+5
-5
src/relay/transforms/to_a_normal_form.cc
+1
-1
No files found.
include/tvm/relay/function.h
View file @
06bbc7c9
...
...
@@ -76,15 +76,6 @@ class FunctionNode : public BaseFuncNode {
*/
TVM_DLL
FuncType
func_type_annotation
()
const
;
/*!
* \brief Check whether the function should use the TVM default compiler to build, or
* use other compilers.
*
* \return Whether the function will be compiled using the default compiler
* (e.g. those are used in the TVM stack).
*/
bool
UseDefaultCompiler
()
const
;
static
constexpr
const
char
*
_type_key
=
"relay.Function"
;
TVM_DECLARE_FINAL_OBJECT_INFO
(
FunctionNode
,
BaseFuncNode
);
};
...
...
src/relay/backend/compile_engine.cc
View file @
06bbc7c9
...
...
@@ -616,7 +616,7 @@ class CompileEngineImpl : public CompileEngineNode {
for
(
const
auto
&
it
:
cache_
)
{
auto
src_func
=
it
.
first
->
source_func
;
CHECK
(
src_func
.
defined
());
if
(
!
src_func
->
UseDefaultCompiler
())
{
if
(
src_func
->
GetAttr
<
tir
::
StringImm
>
(
attr
::
kCompiler
).
defined
())
{
auto
code_gen
=
src_func
->
GetAttr
<
tir
::
StringImm
>
(
attr
::
kCompiler
);
CHECK
(
code_gen
.
defined
())
<<
"No external codegen is set"
;
if
(
ext_mods
.
find
(
code_gen
->
value
)
==
ext_mods
.
end
())
{
...
...
@@ -690,7 +690,7 @@ class CompileEngineImpl : public CompileEngineNode {
}
// No need to lower external functions for now. We will invoke the external
// codegen tool once and lower all functions together.
if
(
!
key
->
source_func
->
UseDefaultCompiler
())
{
if
(
key
->
source_func
->
GetAttr
<
tir
::
StringImm
>
(
attr
::
kCompiler
).
defined
())
{
auto
cache_node
=
make_object
<
CachedFuncNode
>
();
const
auto
name_node
=
key
->
source_func
->
GetAttr
<
tir
::
StringImm
>
(
attr
::
kExternalSymbol
);
...
...
src/relay/backend/graph_runtime_codegen.cc
View file @
06bbc7c9
...
...
@@ -424,7 +424,7 @@ class GraphRuntimeCodegen
auto
pf1
=
GetPackedFunc
(
"relay.backend._CompileEngineLower"
);
Target
target
;
// Handle external function
if
(
!
func
->
UseDefaultCompiler
())
{
if
(
func
->
GetAttr
<
tir
::
StringImm
>
(
attr
::
kCompiler
).
defined
())
{
target
=
tvm
::
target
::
ext_dev
();
CCacheKey
key
=
(
*
pf0
)(
func
,
target
);
CachedFunc
ext_func
=
(
*
pf1
)(
compile_engine_
,
key
);
...
...
@@ -490,7 +490,8 @@ class GraphRuntimeCodegen
return
{};
}
std
::
vector
<
GraphNodeRef
>
VisitExpr_
(
const
FunctionNode
*
op
)
override
{
CHECK
(
!
op
->
UseDefaultCompiler
())
<<
"Only functions supported by custom codegen"
;
CHECK
(
op
->
GetAttr
<
tir
::
StringImm
>
(
attr
::
kCompiler
).
defined
())
<<
"Only functions supported by custom codegen"
;
return
{};
}
std
::
vector
<
GraphNodeRef
>
VisitExpr_
(
const
RefCreateNode
*
op
)
override
{
...
...
src/relay/backend/vm/compiler.cc
View file @
06bbc7c9
...
...
@@ -471,7 +471,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
Target
target
;
if
(
!
func
->
UseDefaultCompiler
())
{
if
(
func
->
GetAttr
<
tir
::
StringImm
>
(
attr
::
kCompiler
).
defined
())
{
target
=
tvm
::
target
::
ext_dev
();
}
else
{
// Next generate the invoke instruction.
...
...
@@ -489,7 +489,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
auto
cfunc
=
engine_
->
Lower
(
key
);
auto
op_index
=
-
1
;
if
(
!
func
->
UseDefaultCompiler
())
{
if
(
func
->
GetAttr
<
tir
::
StringImm
>
(
attr
::
kCompiler
).
defined
())
{
op_index
=
context_
->
cached_funcs
.
size
();
context_
->
cached_funcs
.
push_back
(
cfunc
);
}
else
{
...
...
src/relay/backend/vm/inline_primitives.cc
View file @
06bbc7c9
...
...
@@ -122,17 +122,17 @@ struct PrimitiveInliner : ExprMutator {
auto
global
=
pair
.
first
;
auto
base_func
=
pair
.
second
;
if
(
auto
*
n
=
base_func
.
as
<
FunctionNode
>
())
{
if
(
!
n
->
UseDefaultCompiler
())
continue
;
if
(
n
->
GetAttr
<
tir
::
StringImm
>
(
attr
::
kCompiler
).
defined
())
continue
;
auto
func
=
GetRef
<
Function
>
(
n
);
DLOG
(
INFO
)
<<
"Before inlining primitives: "
<<
global
<<
std
::
endl
<<
AsText
(
func
,
false
);
func
=
Function
(
func
->
params
,
VisitExpr
(
func
->
body
),
func
->
ret_type
,
func
->
type_params
,
func
->
attrs
);
VisitExpr
(
func
->
body
),
func
->
ret_type
,
func
->
type_params
,
func
->
attrs
);
module_
->
Add
(
global
,
func
,
true
);
DLOG
(
INFO
)
<<
"After inlining primitives: "
<<
global
...
...
src/relay/backend/vm/lambda_lift.cc
View file @
06bbc7c9
...
...
@@ -187,13 +187,13 @@ class LambdaLifter : public ExprMutator {
auto
glob_funcs
=
module_
->
functions
;
for
(
auto
pair
:
glob_funcs
)
{
if
(
auto
*
n
=
pair
.
second
.
as
<
FunctionNode
>
())
{
if
(
!
n
->
UseDefaultCompiler
())
continue
;
if
(
n
->
GetAttr
<
tir
::
StringImm
>
(
attr
::
kCompiler
).
defined
())
continue
;
auto
func
=
GetRef
<
Function
>
(
n
);
func
=
Function
(
func
->
params
,
VisitExpr
(
func
->
body
),
func
->
ret_type
,
func
->
type_params
,
func
->
attrs
);
VisitExpr
(
func
->
body
),
func
->
ret_type
,
func
->
type_params
,
func
->
attrs
);
module_
->
Add
(
pair
.
first
,
func
,
true
);
}
}
...
...
src/relay/ir/function.cc
View file @
06bbc7c9
...
...
@@ -55,11 +55,6 @@ FuncType FunctionNode::func_type_annotation() const {
return
FuncType
(
param_types
,
ret_type
,
this
->
type_params
,
{});
}
bool
FunctionNode
::
UseDefaultCompiler
()
const
{
tir
::
StringImm
val
=
this
->
GetAttr
<
tir
::
StringImm
>
(
attr
::
kCompiler
);
return
!
val
.
defined
()
||
val
->
value
==
"default"
;
}
TVM_REGISTER_NODE_TYPE
(
FunctionNode
);
TVM_REGISTER_GLOBAL
(
"relay.ir.Function"
)
...
...
src/relay/ir/transform.cc
View file @
06bbc7c9
...
...
@@ -140,7 +140,7 @@ IRModule FunctionPassNode::operator()(const IRModule& mod,
bool
FunctionPassNode
::
SkipFunction
(
const
Function
&
func
)
const
{
return
func
->
GetAttr
<
Integer
>
(
attr
::
kSkipOptimization
,
0
)
->
value
!=
0
||
!
(
func
->
UseDefaultCompiler
());
(
func
->
GetAttr
<
tir
::
StringImm
>
(
attr
::
kCompiler
).
defined
());
}
Pass
CreateFunctionPass
(
...
...
src/relay/transforms/inline.cc
View file @
06bbc7c9
...
...
@@ -125,13 +125,13 @@ class Inliner : ExprMutator {
CHECK
(
fn
)
<<
"Expected to work on a Relay function."
;
auto
func
=
Function
(
fn
->
params
,
fn
->
body
,
fn
->
ret_type
,
fn
->
type_params
,
fn
->
attrs
);
fn
->
body
,
fn
->
ret_type
,
fn
->
type_params
,
fn
->
attrs
);
// Inline the function body to the caller if this function uses default
// compiler, i.e. no external codegen is needed.
if
(
func
->
UseDefaultCompiler
())
{
if
(
!
func
->
GetAttr
<
tir
::
StringImm
>
(
attr
::
kCompiler
).
defined
())
{
CHECK_EQ
(
func
->
params
.
size
(),
args
.
size
())
<<
"Mismatch found in the number of parameters and call args"
;
// Bind the parameters with call args.
...
...
src/relay/transforms/to_a_normal_form.cc
View file @
06bbc7c9
...
...
@@ -299,7 +299,7 @@ IRModule ToANormalForm(const IRModule& m) {
for
(
const
auto
&
it
:
funcs
)
{
CHECK_EQ
(
FreeVars
(
it
.
second
).
size
(),
0
);
if
(
const
auto
*
n
=
it
.
second
.
as
<
FunctionNode
>
())
{
if
(
!
n
->
UseDefaultCompiler
())
continue
;
if
(
n
->
GetAttr
<
tir
::
StringImm
>
(
attr
::
kCompiler
).
defined
())
continue
;
}
Expr
ret
=
TransformF
([
&
](
const
Expr
&
e
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment