Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
138ec7be
Commit
138ec7be
authored
May 24, 2019
by
Zhi
Committed by
Tianqi Chen
May 24, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay][Transform] merge PassContext and BuildConfig (#3234)
parent
415a270d
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
499 additions
and
262 deletions
+499
-262
docs/api/python/relay/build_module.rst
+0
-8
docs/api/python/relay/transform.rst
+45
-0
include/tvm/relay/transform.h
+81
-11
python/tvm/relay/__init__.py
+2
-1
python/tvm/relay/build_module.py
+13
-85
python/tvm/relay/quantize/quantize.py
+7
-7
python/tvm/relay/transform.py
+98
-27
src/relay/pass/pass_manager.cc
+249
-119
tests/python/frontend/coreml/test_forward.py
+2
-2
tests/python/frontend/keras/test_forward.py
+1
-1
tutorials/frontend/from_tflite.py
+1
-1
No files found.
docs/api/python/relay/build_module.rst
View file @
138ec7be
...
...
@@ -22,17 +22,9 @@ tvm.relay.build_module
.. autofunction:: tvm.relay.build_module.build
.. autofunction:: tvm.relay.build_module.build_config
.. autofunction:: tvm.relay.build_module.optimize
.. autofunction:: tvm.relay.build_module.create_executor
.. autoclass:: tvm.relay.build_module.BuildConfig
:members:
.. autofunction:: tvm.relay.build_module.build_config
:members:
.. autoclass:: tvm.relay.build_module.GraphExecutor
:members:
docs/api/python/relay/transform.rst
0 → 100644
View file @
138ec7be
.. Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
.. http://www.apache.org/licenses/LICENSE-2.0
.. Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
tvm.relay.transform
----------------------
.. automodule:: tvm.relay.transform
.. autofunction:: tvm.relay.transform.build_config
.. autofunction:: tvm.relay.transform.module_pass
.. autofunction:: tvm.relay.transform.function_pass
.. autoclass:: tvm.relay.transform.Pass
:members:
.. autoclass:: tvm.relay.transform.PassInfo
:members:
.. autoclass:: tvm.relay.transform.PassContext
:members:
.. autoclass:: tvm.relay.transform.ModulePass
:members:
.. autoclass:: tvm.relay.transform.FunctionPass
:members:
.. autoclass:: tvm.relay.transform.Sequential
:members:
include/tvm/relay/transform.h
View file @
138ec7be
...
...
@@ -56,11 +56,13 @@
#ifndef TVM_RELAY_TRANSFORM_H_
#define TVM_RELAY_TRANSFORM_H_
#include <tvm/base.h>
#include <tvm/packed_func_ext.h>
#include <tvm/relay/error.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/module.h>
#include <string>
#include <unordered_map>
#include <vector>
namespace
tvm
{
...
...
@@ -83,18 +85,69 @@ class PassContextNode : public RelayNode {
*/
ErrorReporter
err_reporter
;
/*! \brief The default optimization level. */
int
opt_level
{
2
};
/*! \brief CPU is the default fallback device for heterogeneous execution. */
int
fallback_device
{
static_cast
<
int
>
(
kDLCPU
)};
/*! \brief The list of required passes. */
tvm
::
Array
<
tvm
::
Expr
>
required_pass
;
/*! \brief The list of disabled passes. */
tvm
::
Array
<
tvm
::
Expr
>
disabled_pass
;
PassContextNode
()
=
default
;
void
VisitAttrs
(
tvm
::
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"opt_level"
,
&
opt_level
);
v
->
Visit
(
"fallback_device"
,
&
fallback_device
);
v
->
Visit
(
"required_pass"
,
&
required_pass
);
v
->
Visit
(
"disabled_pass"
,
&
disabled_pass
);
}
TVM_DLL
static
PassContext
make
();
static
constexpr
const
char
*
_type_key
=
"relay.PassContext"
;
TVM_DECLARE_NODE_TYPE_INFO
(
PassContextNode
,
RelayNode
);
};
TVM_DEFINE_NODE_REF
(
PassContext
,
PassContextNode
)
class
PassContext
:
public
NodeRef
{
public
:
PassContext
()
{}
explicit
PassContext
(
tvm
::
NodePtr
<
Node
>
n
)
:
NodeRef
(
n
)
{}
/*
* \brief Constructor of a `PassContext` object.
*
* \param opt_level The optimization level that will be applied.
* \param fallback_device The fallback device used for heterogeneous
* execution.
* \param required_pass The passes that are required for a context to execute
* other passes.
* \param required_pass The passes that will be disabled during the
* optimization under a context.
*/
TVM_DLL
PassContext
(
int
opt_level
,
int
fallback_device
,
tvm
::
Array
<
tvm
::
Expr
>
required_pass
,
tvm
::
Array
<
tvm
::
Expr
>
disabled_pass
);
// Get the currently used pass context.
TVM_DLL
static
PassContext
Current
();
const
PassContextNode
*
operator
->
()
const
;
using
ContainerType
=
PassContextNode
;
class
Internal
;
private
:
// The entry of a pass context scope.
TVM_DLL
void
EnterWithScope
();
// The exit of a pass context scope.
TVM_DLL
void
ExitWithScope
();
// Classes to get the Python `with` like syntax.
friend
class
Internal
;
friend
class
tvm
::
With
<
PassContext
>
;
};
/*
* \brief The meta data of a pass.
...
...
@@ -150,20 +203,28 @@ class PassNode : public RelayNode {
virtual
PassInfo
Info
()
const
=
0
;
/*!
* \brief Set the context information for a pass.
* \brief Execute the optimization pass using a functor. This functor
* internally uses a current pass context.
*
* \param mod The module that an optimization pass runs on.
*
* \
param pass_ctx The context information for a certain pass
.
* \
return The updated module
.
*/
virtual
void
SetContext
(
const
PassContext
&
pass_ctx
)
=
0
;
Module
operator
()(
const
Module
&
mod
)
const
{
return
this
->
operator
()(
mod
,
PassContext
::
Current
());
}
/*!
* \brief Execute the optimization pass using a functor.
* \brief Execute the optimization pass using a functor
under a given pass context
.
*
* \param mod The module that an optimization pass runs on.
* \param pass_ctx The pass context that will be used to help the execution of
* optimizations.
*
* \return The updated module.
*/
virtual
Module
operator
()(
const
Module
&
mod
)
const
=
0
;
virtual
Module
operator
()(
const
Module
&
mod
,
const
PassContext
&
pass_ctx
)
const
=
0
;
void
VisitAttrs
(
tvm
::
AttrVisitor
*
v
)
override
{}
...
...
@@ -189,13 +250,22 @@ class Sequential : public Pass {
public
:
/*!
* \brief The constructor of `Sequential`.
*
* \param passes The passes to apply.
* \param pass_info The pass metadata.
* \param disabled The passes that will not be applied.
*/
TVM_DLL
Sequential
(
tvm
::
Array
<
Pass
>
passes
,
PassInfo
pass_info
,
tvm
::
Array
<
tvm
::
Expr
>
disabled
);
PassInfo
pass_info
);
/*!
* \brief The constructor of `Sequential`.
*
* \param passes The passes to apply.
* \param name The name of a sequential pass. It's defaulted to "sequential".
* This allows users to only provide a list of passes and execute them
* under a given context.
*/
TVM_DLL
Sequential
(
tvm
::
Array
<
Pass
>
passes
,
std
::
string
name
=
"sequential"
);
Sequential
()
=
default
;
explicit
Sequential
(
tvm
::
NodePtr
<::
tvm
::
Node
>
n
)
:
Pass
(
n
)
{}
...
...
python/tvm/relay/__init__.py
View file @
138ec7be
...
...
@@ -26,7 +26,8 @@ from . import module
from
.
import
adt
from
.
import
ir_pass
from
.
import
transform
from
.build_module
import
build
,
build_config
,
create_executor
from
.build_module
import
build
,
create_executor
from
.transform
import
build_config
from
.
import
prelude
from
.
import
parser
from
.
import
debug
...
...
python/tvm/relay/build_module.py
View file @
138ec7be
...
...
@@ -28,81 +28,10 @@ from . import _build_module
from
.
import
ir_pass
from
.
import
ty
as
_ty
from
.
import
expr
as
_expr
from
.
import
transform
as
_transform
from
.backend
import
interpreter
as
_interpreter
from
.backend.vm
import
VMExecutor
class
BuildConfig
(
object
):
"""Configuration scope to set a build config option.
Parameters
----------
kwargs
Keyword arguments of configurations to set.
"""
current
=
None
defaults
=
{
"opt_level"
:
2
,
"add_pass"
:
None
,
"disable_pass"
:
None
,
"fallback_device"
:
None
,
}
def
__init__
(
self
,
**
kwargs
):
self
.
_old_scope
=
None
for
k
,
_
in
kwargs
.
items
():
if
k
not
in
BuildConfig
.
defaults
:
raise
ValueError
(
"invalid argument
%
s, candidates are
%
s"
%
(
k
,
BuildConfig
.
defaults
.
keys
()))
self
.
_attr
=
kwargs
def
__getattr__
(
self
,
name
):
if
name
not
in
self
.
_attr
:
return
BuildConfig
.
defaults
[
name
]
return
self
.
_attr
[
name
]
def
__enter__
(
self
):
# pylint: disable=protected-access
self
.
_old_scope
=
BuildConfig
.
current
attr
=
BuildConfig
.
current
.
_attr
.
copy
()
attr
.
update
(
self
.
_attr
)
self
.
_attr
=
attr
BuildConfig
.
current
=
self
return
self
def
__exit__
(
self
,
ptype
,
value
,
trace
):
assert
self
.
_old_scope
BuildConfig
.
current
=
self
.
_old_scope
BuildConfig
.
current
=
BuildConfig
()
def
build_config
(
**
kwargs
):
"""Configure the build behavior by setting config variables.
Parameters
----------
opt_level: int, default=2
Optimization level. See OPT_PASS_LEVEL for level of each pass.
add_pass: set of str
Optimization pass to be added regardless of optimization level.
disable_pass: set of str
Optimization pass to be disabled during optimization.
fallback_device : str or tvm.TVMContext
The fallback device. It is also used as the default device for
operators without specified device during heterogeneous execution.
Returns
-------
config: BuildConfig
The build configuration
"""
return
BuildConfig
(
**
kwargs
)
def
_update_target
(
target
):
target
=
target
if
target
else
_target
.
current_target
()
if
target
is
None
:
...
...
@@ -189,7 +118,7 @@ class BuildModule(object):
return
graph_json
,
mod
,
params
def
_setup_build_config
(
self
,
params
):
cfg
=
BuildConfig
.
current
cfg
=
_transform
.
PassContext
.
current
()
# Set opt_level.
self
.
set_opt_level
(
cfg
.
opt_level
)
...
...
@@ -199,24 +128,24 @@ class BuildModule(object):
self
.
set_fallback_device
(
cfg
.
fallback_device
)
# Add required passes.
if
cfg
.
ad
d_pass
:
if
cfg
.
require
d_pass
:
passes
=
set
()
if
isinstance
(
cfg
.
ad
d_pass
,
(
list
,
tuple
,
set
)):
passes
=
set
(
cfg
.
ad
d_pass
)
if
isinstance
(
cfg
.
require
d_pass
,
(
list
,
tuple
,
set
)):
passes
=
set
(
cfg
.
require
d_pass
)
else
:
raise
TypeError
(
"add_pass must be list, tuple, or set, but "
+
"got {}"
.
format
(
type
(
cfg
.
ad
d_pass
)))
"got {}"
.
format
(
type
(
cfg
.
require
d_pass
)))
for
pass_name
in
passes
:
self
.
add_pass
(
pass_name
)
# Add disabled passes.
if
cfg
.
disable_pass
:
if
cfg
.
disable
d
_pass
:
passes
=
set
()
if
isinstance
(
cfg
.
disable_pass
,
(
list
,
tuple
,
set
)):
passes
=
set
(
cfg
.
disable_pass
)
if
isinstance
(
cfg
.
disable
d
_pass
,
(
list
,
tuple
,
set
)):
passes
=
set
(
cfg
.
disable
d
_pass
)
else
:
raise
TypeError
(
"disable_pass must be list, tuple, or set, "
+
"but got {}"
.
format
(
type
(
cfg
.
disable_pass
)))
"but got {}"
.
format
(
type
(
cfg
.
disable
d
_pass
)))
for
pass_name
in
passes
:
self
.
disable_pass
(
pass_name
)
...
...
@@ -287,12 +216,11 @@ class BuildModule(object):
fallback_device : str or tvm.TVMContext
The fallback device used for heterogeneous execution.
"""
if
isinstance
(
fallback_device
,
str
):
if
isinstance
(
fallback_device
,
(
int
,
str
)
):
fallback_device
=
_nd
.
context
(
fallback_device
)
if
not
isinstance
(
fallback_device
,
TVMContext
):
raise
TypeError
(
"fallback_device is expected to be str "
+
"TVMContext, or dict of device name to target, "
+
"but received: {}"
.
format
(
type
(
fallback_device
)))
raise
TypeError
(
"fallback_device is expected to be str, int, or "
+
"TVMContext but received: {}"
.
format
(
type
(
fallback_device
)))
self
.
_set_fallback_device
(
fallback_device
.
device_type
)
...
...
python/tvm/relay/quantize/quantize.py
View file @
138ec7be
...
...
@@ -22,7 +22,7 @@ import numpy as np
from
.
import
_quantize
from
..
import
expr
as
_expr
from
..
import
ir_pass
as
_ir_pass
from
..
import
build_module
as
_build
from
..
import
transform
as
_transform
from
..
import
op
as
_op
from
...
import
make
as
_make
from
..base
import
NodeBase
,
register_relay_node
...
...
@@ -301,7 +301,7 @@ def optimize(func, params=None):
"FoldConstant"
,
"CanonicalizeOps"
]
cfg
=
_
build
.
build_config
(
ad
d_pass
=
opt_passes
)
cfg
=
_
transform
.
build_config
(
require
d_pass
=
opt_passes
)
if
params
:
name_dict
=
{}
...
...
@@ -321,25 +321,25 @@ def optimize(func, params=None):
bind_dict
[
arg
]
=
_expr
.
const
(
v
)
func
=
_expr
.
bind
(
func
,
bind_dict
)
if
"SimplifyInference"
in
cfg
.
ad
d_pass
:
if
"SimplifyInference"
in
cfg
.
require
d_pass
:
func
=
_ir_pass
.
infer_type
(
func
)
func
=
_ir_pass
.
simplify_inference
(
func
)
if
"FoldConstant"
in
cfg
.
ad
d_pass
:
if
"FoldConstant"
in
cfg
.
require
d_pass
:
func
=
_ir_pass
.
fold_constant
(
func
)
if
"FoldScaleAxis"
in
cfg
.
ad
d_pass
:
if
"FoldScaleAxis"
in
cfg
.
require
d_pass
:
func
=
_ir_pass
.
infer_type
(
func
)
func
=
_ir_pass
.
backward_fold_scale_axis
(
func
)
func
=
_ir_pass
.
infer_type
(
func
)
func
=
_ir_pass
.
forward_fold_scale_axis
(
func
)
func
=
_ir_pass
.
fold_constant
(
func
)
if
"CanonicalizeOps"
in
cfg
.
ad
d_pass
:
if
"CanonicalizeOps"
in
cfg
.
require
d_pass
:
func
=
_ir_pass
.
infer_type
(
func
)
func
=
_ir_pass
.
canonicalize_ops
(
func
)
if
"FoldConstant"
in
cfg
.
ad
d_pass
:
if
"FoldConstant"
in
cfg
.
require
d_pass
:
func
=
_ir_pass
.
fold_constant
(
func
)
return
func
...
...
python/tvm/relay/transform.py
View file @
138ec7be
...
...
@@ -23,8 +23,10 @@ conveniently.
"""
import
types
from
tvm._ffi.runtime_ctypes
import
TVMContext
from
.
import
_transform
from
.base
import
RelayNode
,
register_relay_node
from
..
import
nd
as
_nd
@register_relay_node
...
...
@@ -57,10 +59,102 @@ class PassContext(RelayNode):
Each pass context contains a number of auxiliary information that is used
to help an optimization pass. Such information includes the error reporter
to record the errors of during the optimization, etc.
opt_level : Optional[int]
The optimization level of this pass.
fallback_device : Optional[Union[int, str, TVMContext]]
The fallback device type. It is also used as the default device for
operators that are not annotated during heterogeneous execution.
required_pass : Optional[Union[List[str], Set[str], Tuple[str]]]
The list of passes that are required by a certain pass.
disabled_pass : Optional[Union[List[str], Set[str], Tuple[str]]]
The list of passes that are disabled.
"""
def
__init__
(
self
,
opt_level
=
2
,
fallback_device
=
_nd
.
cpu
(),
required_pass
=
None
,
disabled_pass
=
None
):
if
isinstance
(
fallback_device
,
str
):
fallback_device
=
_nd
.
context
(
fallback_device
)
.
device_type
elif
isinstance
(
fallback_device
,
TVMContext
):
fallback_device
=
fallback_device
.
device_type
if
not
isinstance
(
fallback_device
,
int
):
raise
TypeError
(
"required_pass is expected to be the type of "
+
"int/str/TVMContext."
)
required
=
list
(
required_pass
)
if
required_pass
else
[]
if
not
isinstance
(
required
,
(
list
,
tuple
)):
raise
TypeError
(
"required_pass is expected to be the type of "
+
"list/tuple/set."
)
disabled
=
list
(
disabled_pass
)
if
disabled_pass
else
[]
if
not
isinstance
(
disabled
,
(
list
,
tuple
)):
raise
TypeError
(
"disabled_pass is expected to be the type of "
+
"list/tuple/set."
)
self
.
__init_handle_by_constructor__
(
_transform
.
PassContext
,
opt_level
,
fallback_device
,
required
,
disabled
)
def
__enter__
(
self
):
_transform
.
EnterPassContext
(
self
)
return
self
def
__exit__
(
self
,
ptype
,
value
,
trace
):
_transform
.
ExitPassContext
(
self
)
@staticmethod
def
current
():
"""Return the current pass context."""
return
_transform
.
GetCurrentPassContext
()
def
build_config
(
opt_level
=
2
,
fallback_device
=
_nd
.
cpu
(),
required_pass
=
None
,
disabled_pass
=
None
):
"""Configure the build behavior by setting config variables.
Parameters
----------
opt_level: int, optional
Optimization level. The optimization pass name and level are as the
following:
.. code-block:: python
OPT_PASS_LEVEL = {
"SimplifyInference": 0,
"OpFusion": 1,
"FoldConstant": 2,
"CombineParallelConv2D": 3,
"FoldScaleAxis": 3,
"AlterOpLayout": 3,
"CanonicalizeOps": 3,
"EliminateCommonSubexpr": 3,
}
def
__init__
(
self
):
self
.
__init_handle_by_constructor__
(
_transform
.
PassContext
)
fallback_device : int, str, or tvm.TVMContext, optional
The fallback device. It is also used as the default device for
operators without specified device during heterogeneous execution.
required_pass: set of str, optional
Optimization passes that are required regardless of optimization level.
disabled_pass: set of str, optional
Optimization passes to be disabled during optimization.
Returns
-------
pass_context: PassContext
The pass context for optimizations.
"""
return
PassContext
(
opt_level
,
fallback_device
,
required_pass
,
disabled_pass
)
@register_relay_node
...
...
@@ -70,20 +164,6 @@ class Pass(RelayNode):
conveniently interact with the base class.
"""
def
set_pass_context
(
self
,
pass_ctx
):
"""Setup the pass context for analysis and optimizations. This context
could be shared by different passes for sequential passes.
Parameters
----------
pass_ctx : PassContext
The context that is used to help perform a certain pass or a series
of passes.
"""
if
not
isinstance
(
pass_ctx
,
PassContext
):
raise
TypeError
(
"pass_ctx is expected to be the PassContext type"
)
_transform
.
SetContext
(
self
,
pass_ctx
)
@property
def
info
(
self
):
"""Get the pass meta."""
...
...
@@ -150,32 +230,23 @@ class Sequential(Pass):
required : Optional[List[str]]
The list of passes that the sequential pass is dependent on.
disabled : Optional[List[str]]
A list of disabled passes.
"""
def
__init__
(
self
,
passes
=
None
,
opt_level
=
2
,
name
=
"sequential"
,
required
=
None
,
disabled
=
None
):
required
=
None
):
passes
=
passes
if
passes
else
[]
if
not
isinstance
(
passes
,
(
list
,
tuple
)):
raise
TypeError
(
"passes must be a list of Pass objects."
)
disabled
=
disabled
if
disabled
else
[]
if
not
isinstance
(
disabled
,
(
list
,
tuple
)):
raise
TypeError
(
"disabled must be a list or tuple of pass names"
)
required
=
required
if
required
else
[]
if
not
isinstance
(
required
,
(
list
,
tuple
)):
raise
TypeError
(
"Required is expected to be the type of list/tuple."
)
self
.
__init_handle_by_constructor__
(
_transform
.
Sequential
,
passes
,
opt_level
,
name
,
required
,
disabled
)
passes
,
opt_level
,
name
,
required
)
def
module_pass
(
pass_func
=
None
,
opt_level
=
None
,
name
=
None
,
required
=
None
):
...
...
src/relay/pass/pass_manager.cc
View file @
138ec7be
...
...
@@ -22,8 +22,14 @@
* \file src/relay/pass/pass_manager.cc
* \brief Relay pass manager implementation.
*/
#include <dmlc/thread_local.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/device_api.h>
#include <algorithm>
#include <stack>
#include <unordered_set>
namespace
tvm
{
namespace
relay
{
...
...
@@ -31,6 +37,98 @@ namespace transform {
using
tvm
::
IRPrinter
;
/*!
* \brief A data structure to map the names of specific optimizations to
* numeric optimization levels
*/
class
OptPassLevel
{
public
:
/*!
* \brief Get level for an optimization pass
*
* \param key pass name
* \return int level
*/
int
operator
[](
const
std
::
string
&
key
)
const
{
const
auto
data
=
CreateMap
();
auto
it
=
data
.
find
(
key
);
if
(
it
==
data
.
end
())
{
return
-
1
;
}
return
it
->
second
;
}
private
:
static
const
std
::
unordered_map
<
std
::
string
,
int
>
CreateMap
()
{
const
std
::
unordered_map
<
std
::
string
,
int
>
m
=
{
{
"SimplifyInference"
,
0
},
{
"OpFusion"
,
1
},
{
"FoldConstant"
,
2
},
{
"CombineParallelConv2D"
,
3
},
{
"FoldScaleAxis"
,
3
},
{
"AlterOpLayout"
,
3
},
{
"CanonicalizeOps"
,
3
},
{
"EliminateCommonSubexpr"
,
3
}
};
return
m
;
}
};
PassContext
::
PassContext
(
int
opt_level
,
int
fallback_device
,
tvm
::
Array
<
tvm
::
Expr
>
required_pass
,
tvm
::
Array
<
tvm
::
Expr
>
disabled_pass
)
{
auto
ctx
=
make_node
<
PassContextNode
>
();
ctx
->
opt_level
=
opt_level
;
ctx
->
fallback_device
=
fallback_device
;
ctx
->
required_pass
=
std
::
move
(
required_pass
);
ctx
->
disabled_pass
=
std
::
move
(
disabled_pass
);
node_
=
std
::
move
(
ctx
);
}
const
PassContextNode
*
PassContext
::
operator
->
()
const
{
return
static_cast
<
const
PassContextNode
*>
(
node_
.
get
());
}
struct
RelayPassContextThreadLocalEntry
{
/*! \brief The default pass context. */
PassContext
default_context
;
/*! \brief The current pass context. */
std
::
stack
<
PassContext
>
context_stack
;
RelayPassContextThreadLocalEntry
()
{
default_context
=
PassContext
(
make_node
<
PassContextNode
>
());
}
};
/*! \brief Thread local store to hold the pass context. */
typedef
dmlc
::
ThreadLocalStore
<
RelayPassContextThreadLocalEntry
>
RelayPassContextThreadLocalStore
;
void
PassContext
::
EnterWithScope
()
{
RelayPassContextThreadLocalEntry
*
entry
=
RelayPassContextThreadLocalStore
::
Get
();
entry
->
context_stack
.
push
(
*
this
);
}
void
PassContext
::
ExitWithScope
()
{
RelayPassContextThreadLocalEntry
*
entry
=
RelayPassContextThreadLocalStore
::
Get
();
CHECK
(
!
entry
->
context_stack
.
empty
());
CHECK
(
entry
->
context_stack
.
top
().
same_as
(
*
this
));
entry
->
context_stack
.
pop
();
}
PassContext
PassContext
::
Current
()
{
RelayPassContextThreadLocalEntry
*
entry
=
RelayPassContextThreadLocalStore
::
Get
();
if
(
!
entry
->
context_stack
.
empty
())
{
return
entry
->
context_stack
.
top
();
}
else
{
return
entry
->
default_context
;
}
}
class
ModulePass
;
/*!
...
...
@@ -58,38 +156,26 @@ class ModulePassNode : public PassNode {
}
/*!
* \brief Run a module pass on
a certain module
.
* \brief Run a module pass on
given pass context
.
*
* \param mod The module that an optimization pass runs on.
* \param mod The module that an optimization pass is applied on.
* \param mod The context that an optimization pass executes on.
*
* \return Return the updated module.
*/
Module
operator
()(
const
Module
&
mod
)
const
final
;
Module
operator
()(
const
Module
&
mod
,
const
PassContext
&
pass_ctx
)
const
final
;
/*!
* \brief Get the pass information/meta data.
*/
PassInfo
Info
()
const
{
return
pass_info
;
}
/*!
* \brief Set the context information for a module pass.
*
* \param pass_ctx The context information for a module pass.
*/
void
SetContext
(
const
PassContext
&
pass_ctx
)
final
;
TVM_DLL
static
ModulePass
make
(
runtime
::
TypedPackedFunc
<
Module
(
Module
,
PassContext
)
>
pass_func
,
PassInfo
pass_info
);
static
constexpr
const
char
*
_type_key
=
"relay.ModulePass"
;
TVM_DECLARE_NODE_TYPE_INFO
(
ModulePassNode
,
PassNode
);
private
:
/*!
* \brief The context information that is used to help perform a module pass.
*/
PassContext
pass_ctx_
;
};
RELAY_DEFINE_NODE_REF
(
ModulePass
,
ModulePassNode
,
Pass
);
...
...
@@ -124,26 +210,20 @@ class FunctionPassNode : public PassNode {
}
/*!
* \brief Run a function pass on
a certain module
.
* \brief Run a function pass on
given pass context
.
*
* \param mod The module that an optimization pass runs on.
* \param mod The module that an optimization pass is applied on.
* \param mod The context that an optimization pass executes on.
*
* \return Return the updated module.
*/
Module
operator
()(
const
Module
&
mod
)
const
final
;
Module
operator
()(
const
Module
&
mod
,
const
PassContext
&
pass_ctx
)
const
final
;
/*!
* \brief Get the pass information/meta data.
*/
PassInfo
Info
()
const
{
return
pass_info
;
}
/*!
* \brief Set the context information for a function-level pass.
*
* \param pass_ctx The context information for a function-level pass.
*/
void
SetContext
(
const
PassContext
&
pass_ctx
)
final
;
TVM_DLL
static
FunctionPass
make
(
runtime
::
TypedPackedFunc
<
Function
(
Function
,
PassContext
)
>
pass_func
,
PassInfo
pass_info
);
...
...
@@ -160,11 +240,6 @@ class FunctionPassNode : public PassNode {
* \return Return true if the function will be skipped, otherwise false.
*/
bool
SkipFunction
(
const
Function
&
func
)
const
;
/*!
* \brief The context information that is used to help perform a module pass.
*/
PassContext
pass_ctx_
;
};
RELAY_DEFINE_NODE_REF
(
FunctionPass
,
FunctionPassNode
,
Pass
);
...
...
@@ -182,18 +257,17 @@ class SequentialNode : public PassNode {
/* \brief The pass meta data.*/
PassInfo
pass_info
;
/*! \brief A list of passes that used to compose a sequential pass. */
tvm
::
Array
<
Pass
>
passes
;
/*!
* \brief A
list of disabled passes that should be excluded when executing the
*
sequential pass
.
* \brief A
helper struct to get the optimization pass name to opt level
*
mapping
.
*/
tvm
::
Array
<
tvm
::
Expr
>
disabled
;
OptPassLevel
opt_pass_level
;
/*! \brief A list of passes that used to compose a sequential pass. */
tvm
::
Array
<
Pass
>
passes
;
void
VisitAttrs
(
tvm
::
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"pass_info"
,
&
pass_info
);
v
->
Visit
(
"passes"
,
&
passes
);
v
->
Visit
(
"disabled"
,
&
disabled
);
}
/*!
...
...
@@ -211,6 +285,15 @@ class SequentialNode : public PassNode {
}
/*!
* \brief Check if a pass is enabled.
*
* \param pass_name The name of an optimization/analysis pass.
*
* \return true if the pass is enabled. Otherwise, false.
*/
bool
pass_enabled
(
const
std
::
string
&
pass_name
)
const
;
/*!
* \brief Resolve the pass dependency. It globs all required passes by
* a given pass and executes them.
*
...
...
@@ -224,7 +307,11 @@ class SequentialNode : public PassNode {
*/
void
ResolveDependency
(
const
Module
&
mod
);
TVM_DLL
std
::
vector
<
std
::
string
>
DisabledPasses
()
const
;
std
::
unordered_set
<
std
::
string
>
DisabledPasses
(
const
Array
<
tvm
::
Expr
>&
disabled
)
const
;
std
::
unordered_set
<
std
::
string
>
RequiredPasses
(
const
Array
<
tvm
::
Expr
>&
disabled
)
const
;
/*!
* \brief Perform optimizations on a series of passes. The aforementioned
...
...
@@ -232,27 +319,15 @@ class SequentialNode : public PassNode {
* be overloaded to focus on different metrics, i.e. performance,
* memory footprint, etc.
*
* \param mod The module that an optimization pass runs on.
* \param mod The module that these passes are applied on.
* \param pass_ctx The context that these passes execute on.
*
* \return Return the updated module.
*/
Module
operator
()(
const
Module
&
mod
)
const
final
;
/*!
* \brief Set the context information for a sequential pass.
*
* \param pass_ctx The context information for a sequential pass.
*/
void
SetContext
(
const
PassContext
&
pass_ctx
)
final
;
Module
operator
()(
const
Module
&
mod
,
const
PassContext
&
pass_ctx
)
const
final
;
static
constexpr
const
char
*
_type_key
=
"relay.Sequential"
;
TVM_DECLARE_NODE_TYPE_INFO
(
SequentialNode
,
PassNode
);
private
:
/*!
* \brief The context information that is used to help perform a module pass.
*/
PassContext
pass_ctx_
;
};
PassInfo
PassInfoNode
::
make
(
int
opt_level
,
std
::
string
name
,
...
...
@@ -264,11 +339,6 @@ PassInfo PassInfoNode::make(int opt_level, std::string name,
return
PassInfo
(
pass_info
);
}
PassContext
PassContextNode
::
make
()
{
auto
ctx
=
make_node
<
PassContextNode
>
();
return
PassContext
(
ctx
);
}
ModulePass
ModulePassNode
::
make
(
runtime
::
TypedPackedFunc
<
Module
(
Module
,
PassContext
)
>
pass_func
,
PassInfo
pass_info
)
{
...
...
@@ -279,23 +349,19 @@ ModulePass ModulePassNode::make(
}
// Module -> Module optimizations.
// TODO(zhiics) 1. Check and handle the required passes.
// 2. Probably use CoW for all places that use module instead of
// returning the updated one.
Module
ModulePassNode
::
operator
()(
const
Module
&
mod
)
const
{
// TODO(zhiics) Check and handle the required passes.
Module
ModulePassNode
::
operator
()(
const
Module
&
mod
,
const
PassContext
&
pass_ctx
)
const
{
PassInfo
pass_info
=
Info
();
LOG
(
INFO
)
<<
"Executing module pass : "
<<
pass_info
.
operator
->
()
->
name
<<
" with opt level: "
<<
pass_info
.
operator
->
()
->
opt_level
<<
"
\n
"
;
CHECK
(
mod
.
defined
());
auto
updated_mod
=
pass_func
(
mod
,
pass_ctx
_
);
auto
updated_mod
=
pass_func
(
mod
,
pass_ctx
);
CHECK
(
updated_mod
.
defined
());
return
updated_mod
;
}
void
ModulePassNode
::
SetContext
(
const
PassContext
&
pass_ctx
)
{
pass_ctx_
=
pass_ctx
;
}
FunctionPass
FunctionPassNode
::
make
(
runtime
::
TypedPackedFunc
<
Function
(
Function
,
PassContext
)
>
pass_func
,
PassInfo
pass_info
)
{
...
...
@@ -307,31 +373,22 @@ FunctionPass FunctionPassNode::make(
// Perform Module -> Module optimizations at the Function level.
// TODO(zhiics) Check and handle the required passes.
Module
FunctionPassNode
::
operator
()(
const
Module
&
mod
)
const
{
Module
FunctionPassNode
::
operator
()(
const
Module
&
mod
,
const
PassContext
&
pass_ctx
)
const
{
PassInfo
pass_info
=
Info
();
LOG
(
INFO
)
<<
"Executing function pass : "
<<
pass_info
.
operator
->
()
->
name
<<
" with opt level: "
<<
pass_info
.
operator
->
()
->
opt_level
<<
"
\n
"
;
CHECK
(
mod
.
defined
());
std
::
vector
<
std
::
pair
<
GlobalVar
,
Function
>>
updated_funcs
;
ModuleNode
*
mod_node
=
mod
.
operator
->
();
for
(
const
auto
&
it
:
mod_node
->
functions
)
{
if
(
!
SkipFunction
(
it
.
second
))
{
auto
updated_func
=
pass_func
(
it
.
second
,
pass_ctx_
);
CHECK
(
updated_func
.
defined
());
updated_funcs
.
push_back
({
std
::
move
(
it
.
first
),
std
::
move
(
updated_func
)});
}
}
Module
new_mod
=
ModuleNode
::
make
({},
mod
->
type_definitions
);
// Update the optimized functions.
for
(
const
auto
&
it
:
updated_funcs
)
{
mod_node
->
Update
(
it
.
first
,
it
.
second
);
// Execute the pass function and return a new module.
for
(
const
auto
&
it
:
mod
->
functions
)
{
auto
updated_func
=
SkipFunction
(
it
.
second
)
?
it
.
second
:
pass_func
(
it
.
second
,
pass_ctx
);
new_mod
->
Add
(
it
.
first
,
updated_func
);
}
return
GetRef
<
Module
>
(
mod_node
);
}
void
FunctionPassNode
::
SetContext
(
const
PassContext
&
pass_ctx
)
{
pass_ctx_
=
pass_ctx
;
return
new_mod
;
}
// TODO(zhiics) Create an enum attribute for FunctionNode
...
...
@@ -342,31 +399,23 @@ bool FunctionPassNode::SkipFunction(const Function& func) const {
return
pval
&&
pval
->
value
!=
0
;
}
Sequential
::
Sequential
(
tvm
::
Array
<
Pass
>
passes
,
PassInfo
pass_info
,
tvm
::
Array
<
tvm
::
Expr
>
disabled
)
{
Sequential
::
Sequential
(
tvm
::
Array
<
Pass
>
passes
,
PassInfo
pass_info
)
{
auto
n
=
make_node
<
SequentialNode
>
();
n
->
passes
=
std
::
move
(
passes
);
n
->
pass_info
=
std
::
move
(
pass_info
);
n
->
disabled
=
std
::
move
(
disabled
);
node_
=
std
::
move
(
n
);
}
const
SequentialNode
*
Sequential
::
operator
->
()
const
{
return
static_cast
<
const
SequentialNode
*>
(
this
->
node_
.
get
());
Sequential
::
Sequential
(
tvm
::
Array
<
Pass
>
passes
,
std
::
string
name
)
{
auto
n
=
make_node
<
SequentialNode
>
();
n
->
passes
=
std
::
move
(
passes
);
PassInfo
pass_info
=
PassInfoNode
::
make
(
2
,
std
::
move
(
name
),
{});
n
->
pass_info
=
std
::
move
(
pass_info
);
node_
=
std
::
move
(
n
);
}
// TODO(jroesch, zhiics): we currenlty only sequentially execute each pass in
// a Sequential without the consideration of their orders. The phase
// ordering problem needed to be handled in the future.
Module
SequentialNode
::
operator
()(
const
Module
&
module
)
const
{
Module
mod
=
module
;
for
(
const
Pass
&
pass
:
passes
)
{
CHECK
(
pass
.
defined
())
<<
"Found undefined pass for optimization."
;
const
auto
*
pn
=
pass
.
operator
->
();
mod
=
(
*
pn
)(
mod
);
}
return
mod
;
const
SequentialNode
*
Sequential
::
operator
->
()
const
{
return
static_cast
<
const
SequentialNode
*>
(
this
->
node_
.
get
());
}
void
SequentialNode
::
ResolveDependency
(
const
Module
&
mod
)
{
...
...
@@ -378,18 +427,68 @@ void SequentialNode::ResolveDependency(const Module& mod) {
<<
"
\n
"
;
}
std
::
vector
<
std
::
string
>
SequentialNode
::
DisabledPasses
()
const
{
std
::
vector
<
std
::
string
>
ret
;
std
::
unordered_set
<
std
::
string
>
SequentialNode
::
DisabledPasses
(
const
Array
<
tvm
::
Expr
>&
disabled
)
const
{
std
::
unordered_set
<
std
::
string
>
ret
;
for
(
const
auto
&
it
:
disabled
)
{
const
auto
*
str
=
it
.
as
<
tvm
::
ir
::
StringImm
>
();
CHECK
(
str
)
<<
"disabled passes must be string."
;
ret
.
push_back
(
str
->
value
);
ret
.
emplace
(
str
->
value
);
}
return
ret
;
}
void
SequentialNode
::
SetContext
(
const
PassContext
&
pass_ctx
)
{
pass_ctx_
=
pass_ctx
;
std
::
unordered_set
<
std
::
string
>
SequentialNode
::
RequiredPasses
(
const
Array
<
tvm
::
Expr
>&
required
)
const
{
std
::
unordered_set
<
std
::
string
>
ret
;
for
(
const
auto
&
it
:
required
)
{
const
auto
*
str
=
it
.
as
<
tvm
::
ir
::
StringImm
>
();
CHECK
(
str
)
<<
"disabled passes must be string."
;
ret
.
emplace
(
str
->
value
);
}
return
ret
;
}
bool
SequentialNode
::
pass_enabled
(
const
std
::
string
&
pass_name
)
const
{
PassContext
ctx
=
PassContext
::
Current
();
const
PassContextNode
*
ctx_node
=
ctx
.
operator
->
();
auto
required
=
RequiredPasses
(
ctx_node
->
required_pass
);
auto
disabled
=
DisabledPasses
(
ctx_node
->
required_pass
);
if
(
disabled
.
count
(
pass_name
))
{
return
false
;
}
if
(
required
.
count
(
pass_name
))
{
return
true
;
}
return
ctx_node
->
opt_level
>=
opt_pass_level
[
pass_name
];
}
// TODO(zhiics): we currenlty only sequentially execute each pass in
// a Sequential without the consideration of their orders. The phase
// ordering problem needed to be handled in the future.
Module
SequentialNode
::
operator
()(
const
Module
&
module
,
const
PassContext
&
pass_ctx
)
const
{
const
auto
*
ctx_node
=
pass_ctx
.
operator
->
();
int
opt_level
=
ctx_node
->
opt_level
;
auto
disabled
=
DisabledPasses
(
ctx_node
->
disabled_pass
);
Module
mod
=
module
;
for
(
const
Pass
&
pass
:
passes
)
{
CHECK
(
pass
.
defined
())
<<
"Found undefined pass for optimization."
;
PassInfo
info
=
pass
->
Info
();
const
auto
&
pass_name
=
info
.
operator
->
()
->
name
;
const
auto
&
pass_opt_level
=
info
.
operator
->
()
->
opt_level
;
// Skip the pass if its optimization level is higher that the one of in the
// pass context or if this pass is disabled.
if
(
pass_opt_level
>
opt_level
||
disabled
.
count
(
pass_name
))
{
continue
;
}
const
auto
*
pn
=
pass
.
operator
->
();
mod
=
(
*
pn
)(
mod
,
pass_ctx
);
}
return
mod
;
}
Pass
CreateModulePass
(
...
...
@@ -481,9 +580,8 @@ TVM_REGISTER_API("relay._transform.Sequential")
int
opt_level
=
args
[
1
];
std
::
string
name
=
args
[
2
];
tvm
::
Array
<
tvm
::
Expr
>
required
=
args
[
3
];
tvm
::
Array
<
tvm
::
Expr
>
disabled
=
args
[
4
];
PassInfo
pass_info
=
PassInfoNode
::
make
(
opt_level
,
name
,
required
);
*
ret
=
Sequential
(
passes
,
pass_info
,
disabled
);
*
ret
=
Sequential
(
passes
,
pass_info
);
});
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
...
...
@@ -501,26 +599,58 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
p
->
stream
<<
"]"
;
});
TVM_REGISTER_API
(
"relay._transform.SetContext"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
Pass
pass
=
args
[
0
];
PassContext
pass_ctx
=
args
[
1
];
pass
->
SetContext
(
pass_ctx
);
});
TVM_REGISTER_NODE_TYPE
(
PassContextNode
);
TVM_REGISTER_API
(
"relay._transform.PassContext"
)
.
set_body_typed
(
PassContextNode
::
make
);
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
int
opt_level
=
args
[
0
];
int
fallback_device
=
args
[
1
];
tvm
::
Array
<
tvm
::
Expr
>
required
=
args
[
2
];
tvm
::
Array
<
tvm
::
Expr
>
disabled
=
args
[
3
];
*
ret
=
PassContext
(
opt_level
,
fallback_device
,
required
,
disabled
);
});
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
.
set_dispatch
<
PassContextNode
>
([](
const
PassContextNode
*
node
,
tvm
::
IRPrinter
*
p
)
{
p
->
stream
<<
"TODO(zhiics): printing context"
;
LOG
(
FATAL
)
<<
"PassContext printer has not been implemented yet."
p
->
stream
<<
"Pass context information: "
<<
"
\n
"
;
p
->
stream
<<
"
\t
opt_level: "
<<
node
->
opt_level
<<
"
\n
"
;
p
->
stream
<<
"
\t
fallback device: "
<<
runtime
::
DeviceName
(
node
->
opt_level
)
<<
"
\n
"
;
p
->
stream
<<
"
\t
required passes: ["
<<
node
->
opt_level
;
for
(
const
auto
&
it
:
node
->
required_pass
)
{
p
->
stream
<<
it
<<
" "
;
}
p
->
stream
<<
"]
\n
"
;
p
->
stream
<<
"
\t
disabled passes: ["
<<
node
->
opt_level
;
for
(
const
auto
&
it
:
node
->
disabled_pass
)
{
p
->
stream
<<
it
<<
" "
;
}
p
->
stream
<<
"]"
;
});
class
PassContext
::
Internal
{
public
:
static
void
EnterScope
(
PassContext
pass_ctx
)
{
pass_ctx
.
EnterWithScope
();
}
static
void
ExitScope
(
PassContext
pass_ctx
)
{
pass_ctx
.
ExitWithScope
();
}
};
TVM_REGISTER_API
(
"relay._transform.GetCurrentPassContext"
)
.
set_body_typed
(
PassContext
::
Current
);
TVM_REGISTER_API
(
"relay._transform.EnterPassContext"
)
.
set_body_typed
(
PassContext
::
Internal
::
EnterScope
);
TVM_REGISTER_API
(
"relay._transform.ExitPassContext"
)
.
set_body_typed
(
PassContext
::
Internal
::
ExitScope
);
}
// namespace transform
}
// namespace relay
}
// namespace tvm
tests/python/frontend/coreml/test_forward.py
View file @
138ec7be
...
...
@@ -31,7 +31,7 @@ import model_zoo
def
get_tvm_output
(
func
,
x
,
params
,
target
,
ctx
,
out_shape
=
(
1
,
1000
),
input_name
=
'image'
,
dtype
=
'float32'
):
with
relay
.
build_module
.
build_config
(
opt_level
=
3
):
with
relay
.
transform
.
build_config
(
opt_level
=
3
):
graph
,
lib
,
params
=
relay
.
build
(
func
,
target
,
params
=
params
)
m
=
graph_runtime
.
create
(
graph
,
lib
,
ctx
)
# set inputs
...
...
@@ -72,7 +72,7 @@ def run_tvm_graph(coreml_model, target, ctx, input_data, input_name, output_shap
dtype_dict
=
{
input_name
:
input_data
.
dtype
}
func
,
params
=
relay
.
frontend
.
from_coreml
(
coreml_model
,
shape_dict
)
with
relay
.
build_module
.
build_config
(
opt_level
=
3
):
with
relay
.
transform
.
build_config
(
opt_level
=
3
):
graph
,
lib
,
params
=
relay
.
build
(
func
,
target
,
params
=
params
)
from
tvm.contrib
import
graph_runtime
...
...
tests/python/frontend/keras/test_forward.py
View file @
138ec7be
...
...
@@ -43,7 +43,7 @@ def verify_keras_frontend(keras_model, need_transpose=True):
def
get_tvm_output
(
xs
,
target
,
ctx
,
dtype
=
'float32'
):
shape_dict
=
{
name
:
x
.
shape
for
(
name
,
x
)
in
zip
(
keras_model
.
input_names
,
xs
)}
func
,
params
=
relay
.
frontend
.
from_keras
(
keras_model
,
shape_dict
)
with
relay
.
build_module
.
build_config
(
opt_level
=
2
):
with
relay
.
transform
.
build_config
(
opt_level
=
2
):
graph
,
lib
,
params
=
relay
.
build
(
func
,
target
,
params
=
params
)
m
=
graph_runtime
.
create
(
graph
,
lib
,
ctx
)
for
name
,
x
in
zip
(
keras_model
.
input_names
,
xs
):
...
...
tutorials/frontend/from_tflite.py
View file @
138ec7be
...
...
@@ -144,7 +144,7 @@ func, params = relay.frontend.from_tflite(tflite_model,
# target x86 CPU
target
=
"llvm"
with
relay
.
build_module
.
build_config
(
opt_level
=
3
):
with
relay
.
transform
.
build_config
(
opt_level
=
3
):
graph
,
lib
,
params
=
relay
.
build
(
func
,
target
,
params
=
params
)
######################################################################
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment