Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
138ec7be
Commit
138ec7be
authored
May 24, 2019
by
Zhi
Committed by
Tianqi Chen
May 24, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay][Transform] merge PassContext and BuildConfig (#3234)
parent
415a270d
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
250 additions
and
143 deletions
+250
-143
docs/api/python/relay/build_module.rst
+0
-8
docs/api/python/relay/transform.rst
+45
-0
include/tvm/relay/transform.h
+81
-11
python/tvm/relay/__init__.py
+2
-1
python/tvm/relay/build_module.py
+13
-85
python/tvm/relay/quantize/quantize.py
+7
-7
python/tvm/relay/transform.py
+98
-27
src/relay/pass/pass_manager.cc
+0
-0
tests/python/frontend/coreml/test_forward.py
+2
-2
tests/python/frontend/keras/test_forward.py
+1
-1
tutorials/frontend/from_tflite.py
+1
-1
No files found.
docs/api/python/relay/build_module.rst
View file @
138ec7be
...
@@ -22,17 +22,9 @@ tvm.relay.build_module
...
@@ -22,17 +22,9 @@ tvm.relay.build_module
.. autofunction:: tvm.relay.build_module.build
.. autofunction:: tvm.relay.build_module.build
.. autofunction:: tvm.relay.build_module.build_config
.. autofunction:: tvm.relay.build_module.optimize
.. autofunction:: tvm.relay.build_module.optimize
.. autofunction:: tvm.relay.build_module.create_executor
.. autofunction:: tvm.relay.build_module.create_executor
.. autoclass:: tvm.relay.build_module.BuildConfig
:members:
.. autofunction:: tvm.relay.build_module.build_config
:members:
.. autoclass:: tvm.relay.build_module.GraphExecutor
.. autoclass:: tvm.relay.build_module.GraphExecutor
:members:
:members:
docs/api/python/relay/transform.rst
0 → 100644
View file @
138ec7be
.. Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
.. http://www.apache.org/licenses/LICENSE-2.0
.. Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
tvm.relay.transform
----------------------
.. automodule:: tvm.relay.transform
.. autofunction:: tvm.relay.transform.build_config
.. autofunction:: tvm.relay.transform.module_pass
.. autofunction:: tvm.relay.transform.function_pass
.. autoclass:: tvm.relay.transform.Pass
:members:
.. autoclass:: tvm.relay.transform.PassInfo
:members:
.. autoclass:: tvm.relay.transform.PassContext
:members:
.. autoclass:: tvm.relay.transform.ModulePass
:members:
.. autoclass:: tvm.relay.transform.FunctionPass
:members:
.. autoclass:: tvm.relay.transform.Sequential
:members:
include/tvm/relay/transform.h
View file @
138ec7be
...
@@ -56,11 +56,13 @@
...
@@ -56,11 +56,13 @@
#ifndef TVM_RELAY_TRANSFORM_H_
#ifndef TVM_RELAY_TRANSFORM_H_
#define TVM_RELAY_TRANSFORM_H_
#define TVM_RELAY_TRANSFORM_H_
#include <tvm/base.h>
#include <tvm/packed_func_ext.h>
#include <tvm/packed_func_ext.h>
#include <tvm/relay/error.h>
#include <tvm/relay/error.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/module.h>
#include <tvm/relay/module.h>
#include <string>
#include <string>
#include <unordered_map>
#include <vector>
#include <vector>
namespace
tvm
{
namespace
tvm
{
...
@@ -83,18 +85,69 @@ class PassContextNode : public RelayNode {
...
@@ -83,18 +85,69 @@ class PassContextNode : public RelayNode {
*/
*/
ErrorReporter
err_reporter
;
ErrorReporter
err_reporter
;
/*! \brief The default optimization level. */
int
opt_level
{
2
};
/*! \brief CPU is the default fallback device for heterogeneous execution. */
int
fallback_device
{
static_cast
<
int
>
(
kDLCPU
)};
/*! \brief The list of required passes. */
tvm
::
Array
<
tvm
::
Expr
>
required_pass
;
/*! \brief The list of disabled passes. */
tvm
::
Array
<
tvm
::
Expr
>
disabled_pass
;
PassContextNode
()
=
default
;
PassContextNode
()
=
default
;
void
VisitAttrs
(
tvm
::
AttrVisitor
*
v
)
final
{
void
VisitAttrs
(
tvm
::
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"opt_level"
,
&
opt_level
);
v
->
Visit
(
"fallback_device"
,
&
fallback_device
);
v
->
Visit
(
"required_pass"
,
&
required_pass
);
v
->
Visit
(
"disabled_pass"
,
&
disabled_pass
);
}
}
TVM_DLL
static
PassContext
make
();
static
constexpr
const
char
*
_type_key
=
"relay.PassContext"
;
static
constexpr
const
char
*
_type_key
=
"relay.PassContext"
;
TVM_DECLARE_NODE_TYPE_INFO
(
PassContextNode
,
RelayNode
);
TVM_DECLARE_NODE_TYPE_INFO
(
PassContextNode
,
RelayNode
);
};
};
TVM_DEFINE_NODE_REF
(
PassContext
,
PassContextNode
)
class
PassContext
:
public
NodeRef
{
public
:
PassContext
()
{}
explicit
PassContext
(
tvm
::
NodePtr
<
Node
>
n
)
:
NodeRef
(
n
)
{}
/*
* \brief Constructor of a `PassContext` object.
*
* \param opt_level The optimization level that will be applied.
* \param fallback_device The fallback device used for heterogeneous
* execution.
* \param required_pass The passes that are required for a context to execute
* other passes.
* \param required_pass The passes that will be disabled during the
* optimization under a context.
*/
TVM_DLL
PassContext
(
int
opt_level
,
int
fallback_device
,
tvm
::
Array
<
tvm
::
Expr
>
required_pass
,
tvm
::
Array
<
tvm
::
Expr
>
disabled_pass
);
// Get the currently used pass context.
TVM_DLL
static
PassContext
Current
();
const
PassContextNode
*
operator
->
()
const
;
using
ContainerType
=
PassContextNode
;
class
Internal
;
private
:
// The entry of a pass context scope.
TVM_DLL
void
EnterWithScope
();
// The exit of a pass context scope.
TVM_DLL
void
ExitWithScope
();
// Classes to get the Python `with` like syntax.
friend
class
Internal
;
friend
class
tvm
::
With
<
PassContext
>
;
};
/*
/*
* \brief The meta data of a pass.
* \brief The meta data of a pass.
...
@@ -150,20 +203,28 @@ class PassNode : public RelayNode {
...
@@ -150,20 +203,28 @@ class PassNode : public RelayNode {
virtual
PassInfo
Info
()
const
=
0
;
virtual
PassInfo
Info
()
const
=
0
;
/*!
/*!
* \brief Set the context information for a pass.
* \brief Execute the optimization pass using a functor. This functor
* internally uses a current pass context.
*
* \param mod The module that an optimization pass runs on.
*
*
* \
param pass_ctx The context information for a certain pass
.
* \
return The updated module
.
*/
*/
virtual
void
SetContext
(
const
PassContext
&
pass_ctx
)
=
0
;
Module
operator
()(
const
Module
&
mod
)
const
{
return
this
->
operator
()(
mod
,
PassContext
::
Current
());
}
/*!
/*!
* \brief Execute the optimization pass using a functor.
* \brief Execute the optimization pass using a functor
under a given pass context
.
*
*
* \param mod The module that an optimization pass runs on.
* \param mod The module that an optimization pass runs on.
* \param pass_ctx The pass context that will be used to help the execution of
* optimizations.
*
*
* \return The updated module.
* \return The updated module.
*/
*/
virtual
Module
operator
()(
const
Module
&
mod
)
const
=
0
;
virtual
Module
operator
()(
const
Module
&
mod
,
const
PassContext
&
pass_ctx
)
const
=
0
;
void
VisitAttrs
(
tvm
::
AttrVisitor
*
v
)
override
{}
void
VisitAttrs
(
tvm
::
AttrVisitor
*
v
)
override
{}
...
@@ -189,13 +250,22 @@ class Sequential : public Pass {
...
@@ -189,13 +250,22 @@ class Sequential : public Pass {
public
:
public
:
/*!
/*!
* \brief The constructor of `Sequential`.
* \brief The constructor of `Sequential`.
*
* \param passes The passes to apply.
* \param passes The passes to apply.
* \param pass_info The pass metadata.
* \param pass_info The pass metadata.
* \param disabled The passes that will not be applied.
*/
*/
TVM_DLL
Sequential
(
tvm
::
Array
<
Pass
>
passes
,
TVM_DLL
Sequential
(
tvm
::
Array
<
Pass
>
passes
,
PassInfo
pass_info
,
PassInfo
pass_info
);
tvm
::
Array
<
tvm
::
Expr
>
disabled
);
/*!
* \brief The constructor of `Sequential`.
*
* \param passes The passes to apply.
* \param name The name of a sequential pass. It's defaulted to "sequential".
* This allows users to only provide a list of passes and execute them
* under a given context.
*/
TVM_DLL
Sequential
(
tvm
::
Array
<
Pass
>
passes
,
std
::
string
name
=
"sequential"
);
Sequential
()
=
default
;
Sequential
()
=
default
;
explicit
Sequential
(
tvm
::
NodePtr
<::
tvm
::
Node
>
n
)
:
Pass
(
n
)
{}
explicit
Sequential
(
tvm
::
NodePtr
<::
tvm
::
Node
>
n
)
:
Pass
(
n
)
{}
...
...
python/tvm/relay/__init__.py
View file @
138ec7be
...
@@ -26,7 +26,8 @@ from . import module
...
@@ -26,7 +26,8 @@ from . import module
from
.
import
adt
from
.
import
adt
from
.
import
ir_pass
from
.
import
ir_pass
from
.
import
transform
from
.
import
transform
from
.build_module
import
build
,
build_config
,
create_executor
from
.build_module
import
build
,
create_executor
from
.transform
import
build_config
from
.
import
prelude
from
.
import
prelude
from
.
import
parser
from
.
import
parser
from
.
import
debug
from
.
import
debug
...
...
python/tvm/relay/build_module.py
View file @
138ec7be
...
@@ -28,81 +28,10 @@ from . import _build_module
...
@@ -28,81 +28,10 @@ from . import _build_module
from
.
import
ir_pass
from
.
import
ir_pass
from
.
import
ty
as
_ty
from
.
import
ty
as
_ty
from
.
import
expr
as
_expr
from
.
import
expr
as
_expr
from
.
import
transform
as
_transform
from
.backend
import
interpreter
as
_interpreter
from
.backend
import
interpreter
as
_interpreter
from
.backend.vm
import
VMExecutor
from
.backend.vm
import
VMExecutor
class
BuildConfig
(
object
):
"""Configuration scope to set a build config option.
Parameters
----------
kwargs
Keyword arguments of configurations to set.
"""
current
=
None
defaults
=
{
"opt_level"
:
2
,
"add_pass"
:
None
,
"disable_pass"
:
None
,
"fallback_device"
:
None
,
}
def
__init__
(
self
,
**
kwargs
):
self
.
_old_scope
=
None
for
k
,
_
in
kwargs
.
items
():
if
k
not
in
BuildConfig
.
defaults
:
raise
ValueError
(
"invalid argument
%
s, candidates are
%
s"
%
(
k
,
BuildConfig
.
defaults
.
keys
()))
self
.
_attr
=
kwargs
def
__getattr__
(
self
,
name
):
if
name
not
in
self
.
_attr
:
return
BuildConfig
.
defaults
[
name
]
return
self
.
_attr
[
name
]
def
__enter__
(
self
):
# pylint: disable=protected-access
self
.
_old_scope
=
BuildConfig
.
current
attr
=
BuildConfig
.
current
.
_attr
.
copy
()
attr
.
update
(
self
.
_attr
)
self
.
_attr
=
attr
BuildConfig
.
current
=
self
return
self
def
__exit__
(
self
,
ptype
,
value
,
trace
):
assert
self
.
_old_scope
BuildConfig
.
current
=
self
.
_old_scope
BuildConfig
.
current
=
BuildConfig
()
def
build_config
(
**
kwargs
):
"""Configure the build behavior by setting config variables.
Parameters
----------
opt_level: int, default=2
Optimization level. See OPT_PASS_LEVEL for level of each pass.
add_pass: set of str
Optimization pass to be added regardless of optimization level.
disable_pass: set of str
Optimization pass to be disabled during optimization.
fallback_device : str or tvm.TVMContext
The fallback device. It is also used as the default device for
operators without specified device during heterogeneous execution.
Returns
-------
config: BuildConfig
The build configuration
"""
return
BuildConfig
(
**
kwargs
)
def
_update_target
(
target
):
def
_update_target
(
target
):
target
=
target
if
target
else
_target
.
current_target
()
target
=
target
if
target
else
_target
.
current_target
()
if
target
is
None
:
if
target
is
None
:
...
@@ -189,7 +118,7 @@ class BuildModule(object):
...
@@ -189,7 +118,7 @@ class BuildModule(object):
return
graph_json
,
mod
,
params
return
graph_json
,
mod
,
params
def
_setup_build_config
(
self
,
params
):
def
_setup_build_config
(
self
,
params
):
cfg
=
BuildConfig
.
current
cfg
=
_transform
.
PassContext
.
current
()
# Set opt_level.
# Set opt_level.
self
.
set_opt_level
(
cfg
.
opt_level
)
self
.
set_opt_level
(
cfg
.
opt_level
)
...
@@ -199,24 +128,24 @@ class BuildModule(object):
...
@@ -199,24 +128,24 @@ class BuildModule(object):
self
.
set_fallback_device
(
cfg
.
fallback_device
)
self
.
set_fallback_device
(
cfg
.
fallback_device
)
# Add required passes.
# Add required passes.
if
cfg
.
ad
d_pass
:
if
cfg
.
require
d_pass
:
passes
=
set
()
passes
=
set
()
if
isinstance
(
cfg
.
ad
d_pass
,
(
list
,
tuple
,
set
)):
if
isinstance
(
cfg
.
require
d_pass
,
(
list
,
tuple
,
set
)):
passes
=
set
(
cfg
.
ad
d_pass
)
passes
=
set
(
cfg
.
require
d_pass
)
else
:
else
:
raise
TypeError
(
"add_pass must be list, tuple, or set, but "
+
raise
TypeError
(
"add_pass must be list, tuple, or set, but "
+
"got {}"
.
format
(
type
(
cfg
.
ad
d_pass
)))
"got {}"
.
format
(
type
(
cfg
.
require
d_pass
)))
for
pass_name
in
passes
:
for
pass_name
in
passes
:
self
.
add_pass
(
pass_name
)
self
.
add_pass
(
pass_name
)
# Add disabled passes.
# Add disabled passes.
if
cfg
.
disable_pass
:
if
cfg
.
disable
d
_pass
:
passes
=
set
()
passes
=
set
()
if
isinstance
(
cfg
.
disable_pass
,
(
list
,
tuple
,
set
)):
if
isinstance
(
cfg
.
disable
d
_pass
,
(
list
,
tuple
,
set
)):
passes
=
set
(
cfg
.
disable_pass
)
passes
=
set
(
cfg
.
disable
d
_pass
)
else
:
else
:
raise
TypeError
(
"disable_pass must be list, tuple, or set, "
+
raise
TypeError
(
"disable_pass must be list, tuple, or set, "
+
"but got {}"
.
format
(
type
(
cfg
.
disable_pass
)))
"but got {}"
.
format
(
type
(
cfg
.
disable
d
_pass
)))
for
pass_name
in
passes
:
for
pass_name
in
passes
:
self
.
disable_pass
(
pass_name
)
self
.
disable_pass
(
pass_name
)
...
@@ -287,12 +216,11 @@ class BuildModule(object):
...
@@ -287,12 +216,11 @@ class BuildModule(object):
fallback_device : str or tvm.TVMContext
fallback_device : str or tvm.TVMContext
The fallback device used for heterogeneous execution.
The fallback device used for heterogeneous execution.
"""
"""
if
isinstance
(
fallback_device
,
str
):
if
isinstance
(
fallback_device
,
(
int
,
str
)
):
fallback_device
=
_nd
.
context
(
fallback_device
)
fallback_device
=
_nd
.
context
(
fallback_device
)
if
not
isinstance
(
fallback_device
,
TVMContext
):
if
not
isinstance
(
fallback_device
,
TVMContext
):
raise
TypeError
(
"fallback_device is expected to be str "
+
raise
TypeError
(
"fallback_device is expected to be str, int, or "
+
"TVMContext, or dict of device name to target, "
+
"TVMContext but received: {}"
.
format
(
type
(
fallback_device
)))
"but received: {}"
.
format
(
type
(
fallback_device
)))
self
.
_set_fallback_device
(
fallback_device
.
device_type
)
self
.
_set_fallback_device
(
fallback_device
.
device_type
)
...
...
python/tvm/relay/quantize/quantize.py
View file @
138ec7be
...
@@ -22,7 +22,7 @@ import numpy as np
...
@@ -22,7 +22,7 @@ import numpy as np
from
.
import
_quantize
from
.
import
_quantize
from
..
import
expr
as
_expr
from
..
import
expr
as
_expr
from
..
import
ir_pass
as
_ir_pass
from
..
import
ir_pass
as
_ir_pass
from
..
import
build_module
as
_build
from
..
import
transform
as
_transform
from
..
import
op
as
_op
from
..
import
op
as
_op
from
...
import
make
as
_make
from
...
import
make
as
_make
from
..base
import
NodeBase
,
register_relay_node
from
..base
import
NodeBase
,
register_relay_node
...
@@ -301,7 +301,7 @@ def optimize(func, params=None):
...
@@ -301,7 +301,7 @@ def optimize(func, params=None):
"FoldConstant"
,
"FoldConstant"
,
"CanonicalizeOps"
]
"CanonicalizeOps"
]
cfg
=
_
build
.
build_config
(
ad
d_pass
=
opt_passes
)
cfg
=
_
transform
.
build_config
(
require
d_pass
=
opt_passes
)
if
params
:
if
params
:
name_dict
=
{}
name_dict
=
{}
...
@@ -321,25 +321,25 @@ def optimize(func, params=None):
...
@@ -321,25 +321,25 @@ def optimize(func, params=None):
bind_dict
[
arg
]
=
_expr
.
const
(
v
)
bind_dict
[
arg
]
=
_expr
.
const
(
v
)
func
=
_expr
.
bind
(
func
,
bind_dict
)
func
=
_expr
.
bind
(
func
,
bind_dict
)
if
"SimplifyInference"
in
cfg
.
ad
d_pass
:
if
"SimplifyInference"
in
cfg
.
require
d_pass
:
func
=
_ir_pass
.
infer_type
(
func
)
func
=
_ir_pass
.
infer_type
(
func
)
func
=
_ir_pass
.
simplify_inference
(
func
)
func
=
_ir_pass
.
simplify_inference
(
func
)
if
"FoldConstant"
in
cfg
.
ad
d_pass
:
if
"FoldConstant"
in
cfg
.
require
d_pass
:
func
=
_ir_pass
.
fold_constant
(
func
)
func
=
_ir_pass
.
fold_constant
(
func
)
if
"FoldScaleAxis"
in
cfg
.
ad
d_pass
:
if
"FoldScaleAxis"
in
cfg
.
require
d_pass
:
func
=
_ir_pass
.
infer_type
(
func
)
func
=
_ir_pass
.
infer_type
(
func
)
func
=
_ir_pass
.
backward_fold_scale_axis
(
func
)
func
=
_ir_pass
.
backward_fold_scale_axis
(
func
)
func
=
_ir_pass
.
infer_type
(
func
)
func
=
_ir_pass
.
infer_type
(
func
)
func
=
_ir_pass
.
forward_fold_scale_axis
(
func
)
func
=
_ir_pass
.
forward_fold_scale_axis
(
func
)
func
=
_ir_pass
.
fold_constant
(
func
)
func
=
_ir_pass
.
fold_constant
(
func
)
if
"CanonicalizeOps"
in
cfg
.
ad
d_pass
:
if
"CanonicalizeOps"
in
cfg
.
require
d_pass
:
func
=
_ir_pass
.
infer_type
(
func
)
func
=
_ir_pass
.
infer_type
(
func
)
func
=
_ir_pass
.
canonicalize_ops
(
func
)
func
=
_ir_pass
.
canonicalize_ops
(
func
)
if
"FoldConstant"
in
cfg
.
ad
d_pass
:
if
"FoldConstant"
in
cfg
.
require
d_pass
:
func
=
_ir_pass
.
fold_constant
(
func
)
func
=
_ir_pass
.
fold_constant
(
func
)
return
func
return
func
...
...
python/tvm/relay/transform.py
View file @
138ec7be
...
@@ -23,8 +23,10 @@ conveniently.
...
@@ -23,8 +23,10 @@ conveniently.
"""
"""
import
types
import
types
from
tvm._ffi.runtime_ctypes
import
TVMContext
from
.
import
_transform
from
.
import
_transform
from
.base
import
RelayNode
,
register_relay_node
from
.base
import
RelayNode
,
register_relay_node
from
..
import
nd
as
_nd
@register_relay_node
@register_relay_node
...
@@ -57,10 +59,102 @@ class PassContext(RelayNode):
...
@@ -57,10 +59,102 @@ class PassContext(RelayNode):
Each pass context contains a number of auxiliary information that is used
Each pass context contains a number of auxiliary information that is used
to help an optimization pass. Such information includes the error reporter
to help an optimization pass. Such information includes the error reporter
to record the errors of during the optimization, etc.
to record the errors of during the optimization, etc.
opt_level : Optional[int]
The optimization level of this pass.
fallback_device : Optional[Union[int, str, TVMContext]]
The fallback device type. It is also used as the default device for
operators that are not annotated during heterogeneous execution.
required_pass : Optional[Union[List[str], Set[str], Tuple[str]]]
The list of passes that are required by a certain pass.
disabled_pass : Optional[Union[List[str], Set[str], Tuple[str]]]
The list of passes that are disabled.
"""
"""
def
__init__
(
self
,
opt_level
=
2
,
fallback_device
=
_nd
.
cpu
(),
required_pass
=
None
,
disabled_pass
=
None
):
if
isinstance
(
fallback_device
,
str
):
fallback_device
=
_nd
.
context
(
fallback_device
)
.
device_type
elif
isinstance
(
fallback_device
,
TVMContext
):
fallback_device
=
fallback_device
.
device_type
if
not
isinstance
(
fallback_device
,
int
):
raise
TypeError
(
"required_pass is expected to be the type of "
+
"int/str/TVMContext."
)
required
=
list
(
required_pass
)
if
required_pass
else
[]
if
not
isinstance
(
required
,
(
list
,
tuple
)):
raise
TypeError
(
"required_pass is expected to be the type of "
+
"list/tuple/set."
)
disabled
=
list
(
disabled_pass
)
if
disabled_pass
else
[]
if
not
isinstance
(
disabled
,
(
list
,
tuple
)):
raise
TypeError
(
"disabled_pass is expected to be the type of "
+
"list/tuple/set."
)
self
.
__init_handle_by_constructor__
(
_transform
.
PassContext
,
opt_level
,
fallback_device
,
required
,
disabled
)
def
__enter__
(
self
):
_transform
.
EnterPassContext
(
self
)
return
self
def
__exit__
(
self
,
ptype
,
value
,
trace
):
_transform
.
ExitPassContext
(
self
)
@staticmethod
def
current
():
"""Return the current pass context."""
return
_transform
.
GetCurrentPassContext
()
def
build_config
(
opt_level
=
2
,
fallback_device
=
_nd
.
cpu
(),
required_pass
=
None
,
disabled_pass
=
None
):
"""Configure the build behavior by setting config variables.
Parameters
----------
opt_level: int, optional
Optimization level. The optimization pass name and level are as the
following:
.. code-block:: python
OPT_PASS_LEVEL = {
"SimplifyInference": 0,
"OpFusion": 1,
"FoldConstant": 2,
"CombineParallelConv2D": 3,
"FoldScaleAxis": 3,
"AlterOpLayout": 3,
"CanonicalizeOps": 3,
"EliminateCommonSubexpr": 3,
}
def
__init__
(
self
):
fallback_device : int, str, or tvm.TVMContext, optional
self
.
__init_handle_by_constructor__
(
_transform
.
PassContext
)
The fallback device. It is also used as the default device for
operators without specified device during heterogeneous execution.
required_pass: set of str, optional
Optimization passes that are required regardless of optimization level.
disabled_pass: set of str, optional
Optimization passes to be disabled during optimization.
Returns
-------
pass_context: PassContext
The pass context for optimizations.
"""
return
PassContext
(
opt_level
,
fallback_device
,
required_pass
,
disabled_pass
)
@register_relay_node
@register_relay_node
...
@@ -70,20 +164,6 @@ class Pass(RelayNode):
...
@@ -70,20 +164,6 @@ class Pass(RelayNode):
conveniently interact with the base class.
conveniently interact with the base class.
"""
"""
def
set_pass_context
(
self
,
pass_ctx
):
"""Setup the pass context for analysis and optimizations. This context
could be shared by different passes for sequential passes.
Parameters
----------
pass_ctx : PassContext
The context that is used to help perform a certain pass or a series
of passes.
"""
if
not
isinstance
(
pass_ctx
,
PassContext
):
raise
TypeError
(
"pass_ctx is expected to be the PassContext type"
)
_transform
.
SetContext
(
self
,
pass_ctx
)
@property
@property
def
info
(
self
):
def
info
(
self
):
"""Get the pass meta."""
"""Get the pass meta."""
...
@@ -150,32 +230,23 @@ class Sequential(Pass):
...
@@ -150,32 +230,23 @@ class Sequential(Pass):
required : Optional[List[str]]
required : Optional[List[str]]
The list of passes that the sequential pass is dependent on.
The list of passes that the sequential pass is dependent on.
disabled : Optional[List[str]]
A list of disabled passes.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
passes
=
None
,
passes
=
None
,
opt_level
=
2
,
opt_level
=
2
,
name
=
"sequential"
,
name
=
"sequential"
,
required
=
None
,
required
=
None
):
disabled
=
None
):
passes
=
passes
if
passes
else
[]
passes
=
passes
if
passes
else
[]
if
not
isinstance
(
passes
,
(
list
,
tuple
)):
if
not
isinstance
(
passes
,
(
list
,
tuple
)):
raise
TypeError
(
"passes must be a list of Pass objects."
)
raise
TypeError
(
"passes must be a list of Pass objects."
)
disabled
=
disabled
if
disabled
else
[]
if
not
isinstance
(
disabled
,
(
list
,
tuple
)):
raise
TypeError
(
"disabled must be a list or tuple of pass names"
)
required
=
required
if
required
else
[]
required
=
required
if
required
else
[]
if
not
isinstance
(
required
,
(
list
,
tuple
)):
if
not
isinstance
(
required
,
(
list
,
tuple
)):
raise
TypeError
(
"Required is expected to be the type of list/tuple."
)
raise
TypeError
(
"Required is expected to be the type of list/tuple."
)
self
.
__init_handle_by_constructor__
(
_transform
.
Sequential
,
self
.
__init_handle_by_constructor__
(
_transform
.
Sequential
,
passes
,
opt_level
,
name
,
required
,
passes
,
opt_level
,
name
,
required
)
disabled
)
def
module_pass
(
pass_func
=
None
,
opt_level
=
None
,
name
=
None
,
required
=
None
):
def
module_pass
(
pass_func
=
None
,
opt_level
=
None
,
name
=
None
,
required
=
None
):
...
...
src/relay/pass/pass_manager.cc
View file @
138ec7be
This diff is collapsed.
Click to expand it.
tests/python/frontend/coreml/test_forward.py
View file @
138ec7be
...
@@ -31,7 +31,7 @@ import model_zoo
...
@@ -31,7 +31,7 @@ import model_zoo
def
get_tvm_output
(
func
,
x
,
params
,
target
,
ctx
,
def
get_tvm_output
(
func
,
x
,
params
,
target
,
ctx
,
out_shape
=
(
1
,
1000
),
input_name
=
'image'
,
dtype
=
'float32'
):
out_shape
=
(
1
,
1000
),
input_name
=
'image'
,
dtype
=
'float32'
):
with
relay
.
build_module
.
build_config
(
opt_level
=
3
):
with
relay
.
transform
.
build_config
(
opt_level
=
3
):
graph
,
lib
,
params
=
relay
.
build
(
func
,
target
,
params
=
params
)
graph
,
lib
,
params
=
relay
.
build
(
func
,
target
,
params
=
params
)
m
=
graph_runtime
.
create
(
graph
,
lib
,
ctx
)
m
=
graph_runtime
.
create
(
graph
,
lib
,
ctx
)
# set inputs
# set inputs
...
@@ -72,7 +72,7 @@ def run_tvm_graph(coreml_model, target, ctx, input_data, input_name, output_shap
...
@@ -72,7 +72,7 @@ def run_tvm_graph(coreml_model, target, ctx, input_data, input_name, output_shap
dtype_dict
=
{
input_name
:
input_data
.
dtype
}
dtype_dict
=
{
input_name
:
input_data
.
dtype
}
func
,
params
=
relay
.
frontend
.
from_coreml
(
coreml_model
,
shape_dict
)
func
,
params
=
relay
.
frontend
.
from_coreml
(
coreml_model
,
shape_dict
)
with
relay
.
build_module
.
build_config
(
opt_level
=
3
):
with
relay
.
transform
.
build_config
(
opt_level
=
3
):
graph
,
lib
,
params
=
relay
.
build
(
func
,
target
,
params
=
params
)
graph
,
lib
,
params
=
relay
.
build
(
func
,
target
,
params
=
params
)
from
tvm.contrib
import
graph_runtime
from
tvm.contrib
import
graph_runtime
...
...
tests/python/frontend/keras/test_forward.py
View file @
138ec7be
...
@@ -43,7 +43,7 @@ def verify_keras_frontend(keras_model, need_transpose=True):
...
@@ -43,7 +43,7 @@ def verify_keras_frontend(keras_model, need_transpose=True):
def
get_tvm_output
(
xs
,
target
,
ctx
,
dtype
=
'float32'
):
def
get_tvm_output
(
xs
,
target
,
ctx
,
dtype
=
'float32'
):
shape_dict
=
{
name
:
x
.
shape
for
(
name
,
x
)
in
zip
(
keras_model
.
input_names
,
xs
)}
shape_dict
=
{
name
:
x
.
shape
for
(
name
,
x
)
in
zip
(
keras_model
.
input_names
,
xs
)}
func
,
params
=
relay
.
frontend
.
from_keras
(
keras_model
,
shape_dict
)
func
,
params
=
relay
.
frontend
.
from_keras
(
keras_model
,
shape_dict
)
with
relay
.
build_module
.
build_config
(
opt_level
=
2
):
with
relay
.
transform
.
build_config
(
opt_level
=
2
):
graph
,
lib
,
params
=
relay
.
build
(
func
,
target
,
params
=
params
)
graph
,
lib
,
params
=
relay
.
build
(
func
,
target
,
params
=
params
)
m
=
graph_runtime
.
create
(
graph
,
lib
,
ctx
)
m
=
graph_runtime
.
create
(
graph
,
lib
,
ctx
)
for
name
,
x
in
zip
(
keras_model
.
input_names
,
xs
):
for
name
,
x
in
zip
(
keras_model
.
input_names
,
xs
):
...
...
tutorials/frontend/from_tflite.py
View file @
138ec7be
...
@@ -144,7 +144,7 @@ func, params = relay.frontend.from_tflite(tflite_model,
...
@@ -144,7 +144,7 @@ func, params = relay.frontend.from_tflite(tflite_model,
# target x86 CPU
# target x86 CPU
target
=
"llvm"
target
=
"llvm"
with
relay
.
build_module
.
build_config
(
opt_level
=
3
):
with
relay
.
transform
.
build_config
(
opt_level
=
3
):
graph
,
lib
,
params
=
relay
.
build
(
func
,
target
,
params
=
params
)
graph
,
lib
,
params
=
relay
.
build
(
func
,
target
,
params
=
params
)
######################################################################
######################################################################
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment