Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
395804e5
Commit
395804e5
authored
Dec 22, 2018
by
Jared Roesch
Committed by
Tianqi Chen
Dec 22, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Small refactors and bug fixes. (#2281)
parent
5cb729ec
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
188 additions
and
144 deletions
+188
-144
include/tvm/relay/expr.h
+7
-0
python/tvm/relay/__init__.py
+5
-0
python/tvm/relay/backend/graph_runtime_codegen.py
+5
-1
python/tvm/relay/expr.py
+3
-127
python/tvm/relay/expr_functor.py
+155
-0
src/relay/backend/compile_engine.cc
+4
-4
src/relay/backend/interpreter.cc
+1
-8
src/relay/ir/expr.cc
+7
-1
src/relay/pass/fuse_ops.cc
+1
-3
No files found.
include/tvm/relay/expr.h
View file @
395804e5
...
@@ -248,6 +248,13 @@ class FunctionNode : public ExprNode {
...
@@ -248,6 +248,13 @@ class FunctionNode : public ExprNode {
*/
*/
TVM_DLL
FuncType
func_type_annotation
()
const
;
TVM_DLL
FuncType
func_type_annotation
()
const
;
/*!
* \brief Check whether the function is a primitive function.
*
* \return Whether the function is primitive or not.
*/
bool
IsPrimitive
()
const
;
TVM_DLL
static
Function
make
(
tvm
::
Array
<
Var
>
params
,
TVM_DLL
static
Function
make
(
tvm
::
Array
<
Var
>
params
,
Expr
body
,
Expr
body
,
Type
ret_type
,
Type
ret_type
,
...
...
python/tvm/relay/__init__.py
View file @
395804e5
...
@@ -5,6 +5,7 @@ from ..api import register_func
...
@@ -5,6 +5,7 @@ from ..api import register_func
from
.
import
base
from
.
import
base
from
.
import
ty
from
.
import
ty
from
.
import
expr
from
.
import
expr
from
.
import
expr_functor
from
.
import
module
from
.
import
module
from
.
import
ir_pass
from
.
import
ir_pass
from
.build_module
import
build
,
build_config
,
create_executor
from
.build_module
import
build
,
build_config
,
create_executor
...
@@ -53,6 +54,10 @@ Let = expr.Let
...
@@ -53,6 +54,10 @@ Let = expr.Let
If
=
expr
.
If
If
=
expr
.
If
TupleGetItem
=
expr
.
TupleGetItem
TupleGetItem
=
expr
.
TupleGetItem
# ExprFunctor
ExprFunctor
=
expr_functor
.
ExprFunctor
ExprMutator
=
expr_functor
.
ExprMutator
# helper functions
# helper functions
var
=
expr
.
var
var
=
expr
.
var
const
=
expr
.
const
const
=
expr
.
const
...
...
python/tvm/relay/backend/graph_runtime_codegen.py
View file @
395804e5
...
@@ -24,7 +24,8 @@ import attr
...
@@ -24,7 +24,8 @@ import attr
from
.
import
_backend
from
.
import
_backend
from
.
import
compile_engine
from
.
import
compile_engine
from
..op
import
Op
from
..op
import
Op
from
..expr
import
Function
,
GlobalVar
,
ExprFunctor
from
..expr
import
Function
,
GlobalVar
from
..expr_functor
import
ExprFunctor
from
..ty
import
TupleType
,
TensorType
from
..ty
import
TupleType
,
TensorType
...
@@ -251,6 +252,9 @@ class GraphRuntimeCodegen(ExprFunctor):
...
@@ -251,6 +252,9 @@ class GraphRuntimeCodegen(ExprFunctor):
op_name
,
inputs
,
{})
op_name
,
inputs
,
{})
return
self
.
add_node
(
op_node
,
call
)
return
self
.
add_node
(
op_node
,
call
)
def
visit_op
(
self
,
_
):
raise
Exception
(
"can not compile op in non-eta expanded form"
)
def
_get_json
(
self
):
def
_get_json
(
self
):
"""
"""
Convert the sequence of nodes stored by the compiler into the
Convert the sequence of nodes stored by the compiler into the
...
...
python/tvm/relay/expr.py
View file @
395804e5
...
@@ -222,12 +222,13 @@ class Function(Expr):
...
@@ -222,12 +222,13 @@ class Function(Expr):
params
,
params
,
body
,
body
,
ret_type
=
None
,
ret_type
=
None
,
type_params
=
None
):
type_params
=
None
,
attrs
=
None
):
if
type_params
is
None
:
if
type_params
is
None
:
type_params
=
convert
([])
type_params
=
convert
([])
self
.
__init_handle_by_constructor__
(
self
.
__init_handle_by_constructor__
(
_make
.
Function
,
params
,
body
,
ret_type
,
type_params
)
_make
.
Function
,
params
,
body
,
ret_type
,
type_params
,
attrs
)
def
__call__
(
self
,
*
args
):
def
__call__
(
self
,
*
args
):
"""Invoke the gobal function.
"""Invoke the gobal function.
...
@@ -343,131 +344,6 @@ class TempExpr(Expr):
...
@@ -343,131 +344,6 @@ class TempExpr(Expr):
return
_expr
.
TempExprRealize
(
self
)
return
_expr
.
TempExprRealize
(
self
)
class
ExprFunctor
(
object
):
"""
An abstract visitor defined over Expr.
Defines the default dispatch over expressions, and
implements memoization.
"""
def
__init__
(
self
):
self
.
memo_map
=
{}
# pylint: disable=no-else-return
def
visit
(
self
,
expr
):
"""Apply the visitor to an expression."""
found
=
self
.
memo_map
.
get
(
expr
)
if
found
:
return
found
if
isinstance
(
expr
,
Function
):
res
=
self
.
visit_function
(
expr
)
elif
isinstance
(
expr
,
Call
):
res
=
self
.
visit_call
(
expr
)
elif
isinstance
(
expr
,
Let
):
res
=
self
.
visit_let
(
expr
)
elif
isinstance
(
expr
,
Var
):
res
=
self
.
visit_var
(
expr
)
elif
isinstance
(
expr
,
GlobalVar
):
res
=
self
.
visit_global_var
(
expr
)
elif
isinstance
(
expr
,
If
):
res
=
self
.
visit_if
(
expr
)
elif
isinstance
(
expr
,
Tuple
):
res
=
self
.
visit_tuple
(
expr
)
elif
isinstance
(
expr
,
TupleGetItem
):
res
=
self
.
visit_tuple_getitem
(
expr
)
elif
isinstance
(
expr
,
Constant
):
res
=
self
.
visit_constant
(
expr
)
else
:
raise
Exception
(
"warning unhandled case: {0}"
.
format
(
type
(
expr
)))
self
.
memo_map
[
expr
]
=
res
return
res
def
visit_function
(
self
,
_
):
raise
NotImplementedError
()
def
visit_let
(
self
,
_
):
raise
NotImplementedError
()
def
visit_call
(
self
,
_
):
raise
NotImplementedError
()
def
visit_var
(
self
,
_
):
raise
NotImplementedError
()
def
visit_type
(
self
,
typ
):
return
typ
def
visit_if
(
self
,
_
):
raise
NotImplementedError
()
def
visit_tuple
(
self
,
_
):
raise
NotImplementedError
()
def
visit_tuple_getitem
(
self
,
_
):
raise
NotImplementedError
()
def
visit_constant
(
self
,
_
):
raise
NotImplementedError
()
def
visit_global_var
(
self
,
_
):
raise
NotImplementedError
()
class
ExprMutator
(
ExprFunctor
):
"""
A functional visitor over Expr.
The default behavior recursively traverses the AST
and reconstructs the AST.
"""
def
visit_function
(
self
,
fn
):
new_body
=
self
.
visit
(
fn
.
body
)
return
Function
(
list
(
fn
.
params
),
fn
.
ret_type
,
new_body
,
fn
.
type_params
)
def
visit_let
(
self
,
let
):
new_var
=
self
.
visit
(
let
.
var
)
new_val
=
self
.
visit
(
let
.
value
)
new_body
=
self
.
visit
(
let
.
body
)
return
Let
(
new_var
,
new_val
,
new_body
)
def
visit_call
(
self
,
call
):
new_fn
=
self
.
visit
(
call
.
op
)
new_args
=
[
self
.
visit
(
arg
)
for
arg
in
call
.
args
]
return
Call
(
new_fn
,
new_args
,
call
.
attrs
)
def
visit_var
(
self
,
rvar
):
return
rvar
def
visit_global_id
(
self
,
global_var
):
return
global_var
def
visit_if
(
self
,
ite
):
return
If
(
self
.
visit
(
ite
.
guard
),
self
.
visit
(
ite
.
true_b
),
self
.
visit
(
ite
.
false_b
))
def
visit_tuple
(
self
,
tup
):
return
Tuple
([
self
.
visit
(
field
)
for
field
in
tup
.
fields
])
def
visit_tuple_getitem
(
self
,
op
):
tuple_value
=
self
.
visit
(
op
.
tuple_value
)
if
not
tuple_value
.
same_as
(
op
.
tuple_value
):
return
TupleGetItem
(
tuple_value
,
op
.
index
)
return
op
def
visit_global_var
(
self
,
gvar
):
return
gvar
def
visit_constant
(
self
,
rconst
):
return
rconst
class
TupleWrapper
(
object
):
class
TupleWrapper
(
object
):
"""TupleWrapper.
"""TupleWrapper.
...
...
python/tvm/relay/expr_functor.py
0 → 100644
View file @
395804e5
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""The expression functor of Relay."""
from
.expr
import
Function
,
Call
,
Let
,
Var
,
GlobalVar
,
If
,
Tuple
,
TupleGetItem
,
Constant
from
.op
import
Op
class
ExprFunctor
:
"""
An abstract visitor defined over Expr.
Defines the default dispatch over expressions, and
implements memoization.
"""
def
__init__
(
self
):
self
.
memo_map
=
{}
# pylint: disable=no-else-return
def
visit
(
self
,
expr
):
"""Apply the visitor to an expression."""
found
=
self
.
memo_map
.
get
(
expr
)
if
found
:
return
found
if
isinstance
(
expr
,
Function
):
res
=
self
.
visit_function
(
expr
)
elif
isinstance
(
expr
,
Call
):
res
=
self
.
visit_call
(
expr
)
elif
isinstance
(
expr
,
Let
):
res
=
self
.
visit_let
(
expr
)
elif
isinstance
(
expr
,
Var
):
res
=
self
.
visit_var
(
expr
)
elif
isinstance
(
expr
,
GlobalVar
):
res
=
self
.
visit_global_var
(
expr
)
elif
isinstance
(
expr
,
If
):
res
=
self
.
visit_if
(
expr
)
elif
isinstance
(
expr
,
Tuple
):
res
=
self
.
visit_tuple
(
expr
)
elif
isinstance
(
expr
,
TupleGetItem
):
res
=
self
.
visit_tuple_getitem
(
expr
)
elif
isinstance
(
expr
,
Constant
):
res
=
self
.
visit_constant
(
expr
)
elif
isinstance
(
expr
,
Op
):
res
=
self
.
visit_op
(
expr
)
else
:
raise
Exception
(
"warning unhandled case: {0}"
.
format
(
type
(
expr
)))
self
.
memo_map
[
expr
]
=
res
return
res
def
visit_function
(
self
,
_
):
raise
NotImplementedError
()
def
visit_let
(
self
,
_
):
raise
NotImplementedError
()
def
visit_call
(
self
,
_
):
raise
NotImplementedError
()
def
visit_var
(
self
,
_
):
raise
NotImplementedError
()
def
visit_type
(
self
,
typ
):
return
typ
def
visit_if
(
self
,
_
):
raise
NotImplementedError
()
def
visit_tuple
(
self
,
_
):
raise
NotImplementedError
()
def
visit_tuple_getitem
(
self
,
_
):
raise
NotImplementedError
()
def
visit_global_var
(
self
,
_
):
raise
NotImplementedError
()
def
visit_op
(
self
,
_
):
raise
NotImplementedError
()
def
visit_constant
(
self
,
_
):
raise
NotImplementedError
()
class
ExprMutator
(
ExprFunctor
):
"""
A functional visitor over Expr.
The default behavior recursively traverses the AST
and reconstructs the AST.
"""
def
visit_function
(
self
,
fn
):
new_body
=
self
.
visit
(
fn
.
body
)
return
Function
(
list
(
fn
.
params
),
new_body
,
fn
.
ret_type
,
fn
.
type_params
,
fn
.
attrs
)
def
visit_let
(
self
,
let
):
new_var
=
self
.
visit
(
let
.
var
)
new_val
=
self
.
visit
(
let
.
value
)
new_body
=
self
.
visit
(
let
.
body
)
return
Let
(
new_var
,
new_val
,
new_body
)
def
visit_call
(
self
,
call
):
new_fn
=
self
.
visit
(
call
.
op
)
new_args
=
[
self
.
visit
(
arg
)
for
arg
in
call
.
args
]
return
Call
(
new_fn
,
new_args
,
call
.
attrs
)
def
visit_var
(
self
,
rvar
):
return
rvar
def
visit_global_id
(
self
,
global_var
):
return
global_var
def
visit_if
(
self
,
ite
):
return
If
(
self
.
visit
(
ite
.
guard
),
self
.
visit
(
ite
.
true_b
),
self
.
visit
(
ite
.
false_b
))
def
visit_tuple
(
self
,
tup
):
return
Tuple
([
self
.
visit
(
field
)
for
field
in
tup
.
fields
])
def
visit_tuple_getitem
(
self
,
op
):
tuple_value
=
self
.
visit
(
op
.
tuple_value
)
if
not
tuple_value
.
same_as
(
op
.
tuple_value
):
return
TupleGetItem
(
tuple_value
,
op
.
index
)
return
op
def
visit_global_var
(
self
,
gvar
):
return
gvar
def
visit_op
(
self
,
op
):
return
op
def
visit_constant
(
self
,
const
):
return
const
def
visit_constructor
(
self
,
con
):
return
con
def
visit_match
(
self
,
m
):
return
Match
(
self
.
visit
(
m
.
data
),
[
Clause
(
c
.
lhs
,
self
.
visit
(
c
.
rhs
))
for
c
in
m
.
pattern
])
def
visit_ref_new
(
self
,
r
):
return
RefNew
(
self
.
visit
(
r
.
value
))
def
visit_ref_write
(
self
,
r
):
return
RefWrite
(
self
.
visit
(
r
.
ref
),
self
.
visit
(
r
.
value
))
def
visit_ref_read
(
self
,
r
):
return
RefRead
(
self
.
visit
(
r
.
ref
))
src/relay/backend/compile_engine.cc
View file @
395804e5
...
@@ -157,14 +157,14 @@ class ScheduleGetter :
...
@@ -157,14 +157,14 @@ class ScheduleGetter :
int
op_pattern
=
fpattern
[
op
];
int
op_pattern
=
fpattern
[
op
];
if
(
op_pattern
>=
kCommReduce
)
{
if
(
op_pattern
>=
kCommReduce
)
{
CHECK
(
!
master_op_
.
defined
()
||
master_op_pat
et
rn_
<
kCommReduce
)
CHECK
(
!
master_op_
.
defined
()
||
master_op_pat
te
rn_
<
kCommReduce
)
<<
"Two complicated op in a primitive function "
<<
"Two complicated op in a primitive function "
<<
" master="
<<
master_op_
<<
" current="
<<
op
;
<<
" master="
<<
master_op_
<<
" current="
<<
op
;
}
}
if
(
op_pattern
>=
master_op_pat
et
rn_
)
{
if
(
op_pattern
>=
master_op_pat
te
rn_
)
{
master_op_
=
op
;
master_op_
=
op
;
master_attrs_
=
call_node
->
attrs
;
master_attrs_
=
call_node
->
attrs
;
master_op_pat
et
rn_
=
op_pattern
;
master_op_pat
te
rn_
=
op_pattern
;
}
}
if
(
outputs
.
size
()
!=
1
)
{
if
(
outputs
.
size
()
!=
1
)
{
const
auto
*
tuple_type
=
const
auto
*
tuple_type
=
...
@@ -213,7 +213,7 @@ class ScheduleGetter :
...
@@ -213,7 +213,7 @@ class ScheduleGetter :
tvm
::
Target
target_
;
tvm
::
Target
target_
;
Op
master_op_
;
Op
master_op_
;
Attrs
master_attrs_
;
Attrs
master_attrs_
;
int
master_op_pat
et
rn_
{
0
};
int
master_op_pat
te
rn_
{
0
};
std
::
ostringstream
readable_name_stream_
;
std
::
ostringstream
readable_name_stream_
;
std
::
unordered_map
<
Expr
,
Array
<
Tensor
>
,
NodeHash
,
NodeEqual
>
memo_
;
std
::
unordered_map
<
Expr
,
Array
<
Tensor
>
,
NodeHash
,
NodeEqual
>
memo_
;
};
};
...
...
src/relay/backend/interpreter.cc
View file @
395804e5
...
@@ -292,17 +292,10 @@ class Interpreter :
...
@@ -292,17 +292,10 @@ class Interpreter :
}
}
}
}
// Check if function is a primitive function.
bool
IsPrimitive
(
const
Function
&
func
)
const
{
NodeRef
res
=
FunctionGetAttr
(
func
,
"Primitive"
);
const
ir
::
IntImm
*
pval
=
res
.
as
<
ir
::
IntImm
>
();
return
pval
&&
pval
->
value
!=
0
;
}
// Invoke the closure
// Invoke the closure
Value
Invoke
(
const
Closure
&
closure
,
const
tvm
::
Array
<
Value
>&
args
)
{
Value
Invoke
(
const
Closure
&
closure
,
const
tvm
::
Array
<
Value
>&
args
)
{
// Get a reference to the function inside the closure.
// Get a reference to the function inside the closure.
if
(
IsPrimitive
(
closure
->
func
))
{
if
(
closure
->
func
->
IsPrimitive
(
))
{
return
InvokePrimitiveOp
(
closure
->
func
,
args
);
return
InvokePrimitiveOp
(
closure
->
func
,
args
);
}
}
auto
func
=
closure
->
func
;
auto
func
=
closure
->
func
;
...
...
src/relay/ir/expr.cc
View file @
395804e5
...
@@ -135,6 +135,12 @@ FuncType FunctionNode::func_type_annotation() const {
...
@@ -135,6 +135,12 @@ FuncType FunctionNode::func_type_annotation() const {
return
FuncTypeNode
::
make
(
param_types
,
this
->
ret_type
,
this
->
type_params
,
{});
return
FuncTypeNode
::
make
(
param_types
,
this
->
ret_type
,
this
->
type_params
,
{});
}
}
bool
FunctionNode
::
IsPrimitive
()
const
{
NodeRef
res
=
FunctionGetAttr
(
GetRef
<
Function
>
(
this
),
"Primitive"
);
const
ir
::
IntImm
*
pval
=
res
.
as
<
ir
::
IntImm
>
();
return
pval
&&
pval
->
value
!=
0
;
}
NodeRef
FunctionGetAttr
(
const
Function
&
func
,
const
std
::
string
&
key
)
{
NodeRef
FunctionGetAttr
(
const
Function
&
func
,
const
std
::
string
&
key
)
{
if
(
!
func
->
attrs
.
defined
())
{
return
NodeRef
();
}
if
(
!
func
->
attrs
.
defined
())
{
return
NodeRef
();
}
...
@@ -172,7 +178,7 @@ TVM_REGISTER_NODE_TYPE(FunctionNode);
...
@@ -172,7 +178,7 @@ TVM_REGISTER_NODE_TYPE(FunctionNode);
TVM_REGISTER_API
(
"relay._make.Function"
)
TVM_REGISTER_API
(
"relay._make.Function"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
FunctionNode
::
make
(
args
[
0
],
args
[
1
],
args
[
2
],
args
[
3
]);
*
ret
=
FunctionNode
::
make
(
args
[
0
],
args
[
1
],
args
[
2
],
args
[
3
]
,
args
[
4
]
);
});
});
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
...
...
src/relay/pass/fuse_ops.cc
View file @
395804e5
...
@@ -699,9 +699,7 @@ class FuseMutator : private ExprMutator {
...
@@ -699,9 +699,7 @@ class FuseMutator : private ExprMutator {
std
::
unordered_map
<
GraphPartitioner
::
Group
*
,
GroupInfo
>
ginfo_
;
std
::
unordered_map
<
GraphPartitioner
::
Group
*
,
GroupInfo
>
ginfo_
;
// Skip primitive function.
// Skip primitive function.
Expr
VisitExpr_
(
const
FunctionNode
*
fn_node
)
{
Expr
VisitExpr_
(
const
FunctionNode
*
fn_node
)
{
NodeRef
res
=
FunctionGetAttr
(
GetRef
<
Function
>
(
fn_node
),
"Primitive"
);
if
(
fn_node
->
IsPrimitive
())
{
const
ir
::
IntImm
*
pval
=
res
.
as
<
ir
::
IntImm
>
();
if
(
pval
&&
pval
->
value
!=
0
)
{
return
GetRef
<
Expr
>
(
fn_node
);
return
GetRef
<
Expr
>
(
fn_node
);
}
else
{
}
else
{
return
ExprMutator
::
VisitExpr_
(
fn_node
);
return
ExprMutator
::
VisitExpr_
(
fn_node
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment