Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
06bbc7c9
Unverified
Commit
06bbc7c9
authored
Mar 17, 2020
by
Zhi
Committed by
GitHub
Mar 17, 2020
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Replace UseDefaultCompiler with GetAttr (#5088)
parent
4ae46748
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
12 additions
and
25 deletions
+12
-25
include/tvm/relay/function.h
+0
-9
src/relay/backend/compile_engine.cc
+2
-2
src/relay/backend/graph_runtime_codegen.cc
+3
-2
src/relay/backend/vm/compiler.cc
+2
-2
src/relay/backend/vm/inline_primitives.cc
+1
-1
src/relay/backend/vm/lambda_lift.cc
+1
-1
src/relay/ir/function.cc
+0
-5
src/relay/ir/transform.cc
+1
-1
src/relay/transforms/inline.cc
+1
-1
src/relay/transforms/to_a_normal_form.cc
+1
-1
No files found.
include/tvm/relay/function.h
View file @
06bbc7c9
...
@@ -76,15 +76,6 @@ class FunctionNode : public BaseFuncNode {
...
@@ -76,15 +76,6 @@ class FunctionNode : public BaseFuncNode {
*/
*/
TVM_DLL
FuncType
func_type_annotation
()
const
;
TVM_DLL
FuncType
func_type_annotation
()
const
;
/*!
* \brief Check whether the function should use the TVM default compiler to build, or
* use other compilers.
*
* \return Whether the function will be compiled using the default compiler
* (e.g. those are used in the TVM stack).
*/
bool
UseDefaultCompiler
()
const
;
static
constexpr
const
char
*
_type_key
=
"relay.Function"
;
static
constexpr
const
char
*
_type_key
=
"relay.Function"
;
TVM_DECLARE_FINAL_OBJECT_INFO
(
FunctionNode
,
BaseFuncNode
);
TVM_DECLARE_FINAL_OBJECT_INFO
(
FunctionNode
,
BaseFuncNode
);
};
};
...
...
src/relay/backend/compile_engine.cc
View file @
06bbc7c9
...
@@ -616,7 +616,7 @@ class CompileEngineImpl : public CompileEngineNode {
...
@@ -616,7 +616,7 @@ class CompileEngineImpl : public CompileEngineNode {
for
(
const
auto
&
it
:
cache_
)
{
for
(
const
auto
&
it
:
cache_
)
{
auto
src_func
=
it
.
first
->
source_func
;
auto
src_func
=
it
.
first
->
source_func
;
CHECK
(
src_func
.
defined
());
CHECK
(
src_func
.
defined
());
if
(
!
src_func
->
UseDefaultCompiler
())
{
if
(
src_func
->
GetAttr
<
tir
::
StringImm
>
(
attr
::
kCompiler
).
defined
())
{
auto
code_gen
=
src_func
->
GetAttr
<
tir
::
StringImm
>
(
attr
::
kCompiler
);
auto
code_gen
=
src_func
->
GetAttr
<
tir
::
StringImm
>
(
attr
::
kCompiler
);
CHECK
(
code_gen
.
defined
())
<<
"No external codegen is set"
;
CHECK
(
code_gen
.
defined
())
<<
"No external codegen is set"
;
if
(
ext_mods
.
find
(
code_gen
->
value
)
==
ext_mods
.
end
())
{
if
(
ext_mods
.
find
(
code_gen
->
value
)
==
ext_mods
.
end
())
{
...
@@ -690,7 +690,7 @@ class CompileEngineImpl : public CompileEngineNode {
...
@@ -690,7 +690,7 @@ class CompileEngineImpl : public CompileEngineNode {
}
}
// No need to lower external functions for now. We will invoke the external
// No need to lower external functions for now. We will invoke the external
// codegen tool once and lower all functions together.
// codegen tool once and lower all functions together.
if
(
!
key
->
source_func
->
UseDefaultCompiler
())
{
if
(
key
->
source_func
->
GetAttr
<
tir
::
StringImm
>
(
attr
::
kCompiler
).
defined
())
{
auto
cache_node
=
make_object
<
CachedFuncNode
>
();
auto
cache_node
=
make_object
<
CachedFuncNode
>
();
const
auto
name_node
=
const
auto
name_node
=
key
->
source_func
->
GetAttr
<
tir
::
StringImm
>
(
attr
::
kExternalSymbol
);
key
->
source_func
->
GetAttr
<
tir
::
StringImm
>
(
attr
::
kExternalSymbol
);
...
...
src/relay/backend/graph_runtime_codegen.cc
View file @
06bbc7c9
...
@@ -424,7 +424,7 @@ class GraphRuntimeCodegen
...
@@ -424,7 +424,7 @@ class GraphRuntimeCodegen
auto
pf1
=
GetPackedFunc
(
"relay.backend._CompileEngineLower"
);
auto
pf1
=
GetPackedFunc
(
"relay.backend._CompileEngineLower"
);
Target
target
;
Target
target
;
// Handle external function
// Handle external function
if
(
!
func
->
UseDefaultCompiler
())
{
if
(
func
->
GetAttr
<
tir
::
StringImm
>
(
attr
::
kCompiler
).
defined
())
{
target
=
tvm
::
target
::
ext_dev
();
target
=
tvm
::
target
::
ext_dev
();
CCacheKey
key
=
(
*
pf0
)(
func
,
target
);
CCacheKey
key
=
(
*
pf0
)(
func
,
target
);
CachedFunc
ext_func
=
(
*
pf1
)(
compile_engine_
,
key
);
CachedFunc
ext_func
=
(
*
pf1
)(
compile_engine_
,
key
);
...
@@ -490,7 +490,8 @@ class GraphRuntimeCodegen
...
@@ -490,7 +490,8 @@ class GraphRuntimeCodegen
return
{};
return
{};
}
}
std
::
vector
<
GraphNodeRef
>
VisitExpr_
(
const
FunctionNode
*
op
)
override
{
std
::
vector
<
GraphNodeRef
>
VisitExpr_
(
const
FunctionNode
*
op
)
override
{
CHECK
(
!
op
->
UseDefaultCompiler
())
<<
"Only functions supported by custom codegen"
;
CHECK
(
op
->
GetAttr
<
tir
::
StringImm
>
(
attr
::
kCompiler
).
defined
())
<<
"Only functions supported by custom codegen"
;
return
{};
return
{};
}
}
std
::
vector
<
GraphNodeRef
>
VisitExpr_
(
const
RefCreateNode
*
op
)
override
{
std
::
vector
<
GraphNodeRef
>
VisitExpr_
(
const
RefCreateNode
*
op
)
override
{
...
...
src/relay/backend/vm/compiler.cc
View file @
06bbc7c9
...
@@ -471,7 +471,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
...
@@ -471,7 +471,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
Target
target
;
Target
target
;
if
(
!
func
->
UseDefaultCompiler
())
{
if
(
func
->
GetAttr
<
tir
::
StringImm
>
(
attr
::
kCompiler
).
defined
())
{
target
=
tvm
::
target
::
ext_dev
();
target
=
tvm
::
target
::
ext_dev
();
}
else
{
}
else
{
// Next generate the invoke instruction.
// Next generate the invoke instruction.
...
@@ -489,7 +489,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
...
@@ -489,7 +489,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
auto
cfunc
=
engine_
->
Lower
(
key
);
auto
cfunc
=
engine_
->
Lower
(
key
);
auto
op_index
=
-
1
;
auto
op_index
=
-
1
;
if
(
!
func
->
UseDefaultCompiler
())
{
if
(
func
->
GetAttr
<
tir
::
StringImm
>
(
attr
::
kCompiler
).
defined
())
{
op_index
=
context_
->
cached_funcs
.
size
();
op_index
=
context_
->
cached_funcs
.
size
();
context_
->
cached_funcs
.
push_back
(
cfunc
);
context_
->
cached_funcs
.
push_back
(
cfunc
);
}
else
{
}
else
{
...
...
src/relay/backend/vm/inline_primitives.cc
View file @
06bbc7c9
...
@@ -122,7 +122,7 @@ struct PrimitiveInliner : ExprMutator {
...
@@ -122,7 +122,7 @@ struct PrimitiveInliner : ExprMutator {
auto
global
=
pair
.
first
;
auto
global
=
pair
.
first
;
auto
base_func
=
pair
.
second
;
auto
base_func
=
pair
.
second
;
if
(
auto
*
n
=
base_func
.
as
<
FunctionNode
>
())
{
if
(
auto
*
n
=
base_func
.
as
<
FunctionNode
>
())
{
if
(
!
n
->
UseDefaultCompiler
())
continue
;
if
(
n
->
GetAttr
<
tir
::
StringImm
>
(
attr
::
kCompiler
).
defined
())
continue
;
auto
func
=
GetRef
<
Function
>
(
n
);
auto
func
=
GetRef
<
Function
>
(
n
);
DLOG
(
INFO
)
<<
"Before inlining primitives: "
<<
global
DLOG
(
INFO
)
<<
"Before inlining primitives: "
<<
global
...
...
src/relay/backend/vm/lambda_lift.cc
View file @
06bbc7c9
...
@@ -187,7 +187,7 @@ class LambdaLifter : public ExprMutator {
...
@@ -187,7 +187,7 @@ class LambdaLifter : public ExprMutator {
auto
glob_funcs
=
module_
->
functions
;
auto
glob_funcs
=
module_
->
functions
;
for
(
auto
pair
:
glob_funcs
)
{
for
(
auto
pair
:
glob_funcs
)
{
if
(
auto
*
n
=
pair
.
second
.
as
<
FunctionNode
>
())
{
if
(
auto
*
n
=
pair
.
second
.
as
<
FunctionNode
>
())
{
if
(
!
n
->
UseDefaultCompiler
())
continue
;
if
(
n
->
GetAttr
<
tir
::
StringImm
>
(
attr
::
kCompiler
).
defined
())
continue
;
auto
func
=
GetRef
<
Function
>
(
n
);
auto
func
=
GetRef
<
Function
>
(
n
);
func
=
Function
(
func
->
params
,
func
=
Function
(
func
->
params
,
VisitExpr
(
func
->
body
),
VisitExpr
(
func
->
body
),
...
...
src/relay/ir/function.cc
View file @
06bbc7c9
...
@@ -55,11 +55,6 @@ FuncType FunctionNode::func_type_annotation() const {
...
@@ -55,11 +55,6 @@ FuncType FunctionNode::func_type_annotation() const {
return
FuncType
(
param_types
,
ret_type
,
this
->
type_params
,
{});
return
FuncType
(
param_types
,
ret_type
,
this
->
type_params
,
{});
}
}
bool
FunctionNode
::
UseDefaultCompiler
()
const
{
tir
::
StringImm
val
=
this
->
GetAttr
<
tir
::
StringImm
>
(
attr
::
kCompiler
);
return
!
val
.
defined
()
||
val
->
value
==
"default"
;
}
TVM_REGISTER_NODE_TYPE
(
FunctionNode
);
TVM_REGISTER_NODE_TYPE
(
FunctionNode
);
TVM_REGISTER_GLOBAL
(
"relay.ir.Function"
)
TVM_REGISTER_GLOBAL
(
"relay.ir.Function"
)
...
...
src/relay/ir/transform.cc
View file @
06bbc7c9
...
@@ -140,7 +140,7 @@ IRModule FunctionPassNode::operator()(const IRModule& mod,
...
@@ -140,7 +140,7 @@ IRModule FunctionPassNode::operator()(const IRModule& mod,
bool
FunctionPassNode
::
SkipFunction
(
const
Function
&
func
)
const
{
bool
FunctionPassNode
::
SkipFunction
(
const
Function
&
func
)
const
{
return
func
->
GetAttr
<
Integer
>
(
attr
::
kSkipOptimization
,
0
)
->
value
!=
0
||
return
func
->
GetAttr
<
Integer
>
(
attr
::
kSkipOptimization
,
0
)
->
value
!=
0
||
!
(
func
->
UseDefaultCompiler
());
(
func
->
GetAttr
<
tir
::
StringImm
>
(
attr
::
kCompiler
).
defined
());
}
}
Pass
CreateFunctionPass
(
Pass
CreateFunctionPass
(
...
...
src/relay/transforms/inline.cc
View file @
06bbc7c9
...
@@ -131,7 +131,7 @@ class Inliner : ExprMutator {
...
@@ -131,7 +131,7 @@ class Inliner : ExprMutator {
fn
->
attrs
);
fn
->
attrs
);
// Inline the function body to the caller if this function uses default
// Inline the function body to the caller if this function uses default
// compiler, i.e. no external codegen is needed.
// compiler, i.e. no external codegen is needed.
if
(
func
->
UseDefaultCompiler
())
{
if
(
!
func
->
GetAttr
<
tir
::
StringImm
>
(
attr
::
kCompiler
).
defined
())
{
CHECK_EQ
(
func
->
params
.
size
(),
args
.
size
())
CHECK_EQ
(
func
->
params
.
size
(),
args
.
size
())
<<
"Mismatch found in the number of parameters and call args"
;
<<
"Mismatch found in the number of parameters and call args"
;
// Bind the parameters with call args.
// Bind the parameters with call args.
...
...
src/relay/transforms/to_a_normal_form.cc
View file @
06bbc7c9
...
@@ -299,7 +299,7 @@ IRModule ToANormalForm(const IRModule& m) {
...
@@ -299,7 +299,7 @@ IRModule ToANormalForm(const IRModule& m) {
for
(
const
auto
&
it
:
funcs
)
{
for
(
const
auto
&
it
:
funcs
)
{
CHECK_EQ
(
FreeVars
(
it
.
second
).
size
(),
0
);
CHECK_EQ
(
FreeVars
(
it
.
second
).
size
(),
0
);
if
(
const
auto
*
n
=
it
.
second
.
as
<
FunctionNode
>
())
{
if
(
const
auto
*
n
=
it
.
second
.
as
<
FunctionNode
>
())
{
if
(
!
n
->
UseDefaultCompiler
())
continue
;
if
(
n
->
GetAttr
<
tir
::
StringImm
>
(
attr
::
kCompiler
).
defined
())
continue
;
}
}
Expr
ret
=
Expr
ret
=
TransformF
([
&
](
const
Expr
&
e
)
{
TransformF
([
&
](
const
Expr
&
e
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment