Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
16fefd89
Commit
16fefd89
authored
Jul 29, 2019
by
雾雨魔理沙
Committed by
Thierry Moreau
Jul 29, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay][VTA] Add ChangeBatch pass (#3656)
* init * lint * lint
parent
a88b2842
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
72 additions
and
4 deletions
+72
-4
python/tvm/relay/expr_functor.py
+4
-3
python/tvm/relay/transform.py
+40
-1
tests/python/relay/test_change_batch.py
+28
-0
No files found.
python/tvm/relay/expr_functor.py
View file @
16fefd89
...
...
@@ -195,9 +195,10 @@ class ExprMutator(ExprFunctor):
and reconstructs the AST.
"""
def
visit_function
(
self
,
fn
):
new_params
=
[
self
.
visit
(
x
)
for
x
in
fn
.
params
]
new_body
=
self
.
visit
(
fn
.
body
)
return
Function
(
list
(
fn
.
params
),
list
(
new_
params
),
new_body
,
fn
.
ret_type
,
fn
.
type_params
,
...
...
@@ -214,8 +215,8 @@ class ExprMutator(ExprFunctor):
new_args
=
[
self
.
visit
(
arg
)
for
arg
in
call
.
args
]
return
Call
(
new_fn
,
new_args
,
call
.
attrs
)
def
visit_var
(
self
,
r
var
):
return
r
var
def
visit_var
(
self
,
var
):
return
var
def
visit_global_id
(
self
,
global_var
):
return
global_var
...
...
python/tvm/relay/transform.py
View file @
16fefd89
...
...
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
# pylint: disable=invalid-name
,arguments-differ,no-else-return,unused-argument,missing-docstring
"""
Relay pass transformation infrastructure.
"""
...
...
@@ -22,7 +22,9 @@ import types
import
inspect
import
functools
import
tvm
from
tvm._ffi.runtime_ctypes
import
TVMContext
from
tvm
import
relay
from
.
import
_transform
from
.base
import
RelayNode
,
register_relay_node
from
..
import
nd
as
_nd
...
...
@@ -908,3 +910,40 @@ def function_pass(pass_func=None, opt_level=None, name=None, required=None):
if
pass_func
:
return
create_function_pass
(
pass_func
)
return
create_function_pass
@function_pass
(
opt_level
=
1
)
class
ChangeBatch
:
"""
Change the batch size.
Parameters
----------
data: Dict[relay.Var, int]
A dictionary of all the params to change.
The keys are all params, and the values is which dimension hold the batch.
batch_size: int
The batch size to change to.
Returns
-------
pass: FunctionPass
The pass.
"""
def
__init__
(
self
,
data
,
batch_size
=
16
):
self
.
data
=
data
self
.
batch_size
=
batch_size
def
transform_function
(
self
,
func
,
mod
,
ctx
):
func
=
relay
.
Function
(
func
.
params
,
func
.
body
,
None
,
func
.
type_params
,
func
.
attrs
)
change_batch
=
self
class
ChangeBatchMutator
(
tvm
.
relay
.
ExprMutator
):
def
visit_var
(
self
,
var
):
if
var
in
change_batch
.
data
:
ty
=
var
.
type_annotation
new_shape
=
list
(
ty
.
shape
)
new_shape
[
change_batch
.
data
[
var
]]
=
change_batch
.
batch_size
return
relay
.
Var
(
var
.
name_hint
,
relay
.
TensorType
(
new_shape
,
ty
.
dtype
))
else
:
return
var
return
ChangeBatchMutator
()
.
visit
(
func
)
tests/python/relay/test_change_batch.py
0 → 100644
View file @
16fefd89
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import
tvm
from
tvm
import
relay
from
tvm.relay.testing
import
resnet
from
tvm.relay
import
transform
def
test_change_batch_resnet
():
net
,
params
=
resnet
.
get_workload
()
new_net
=
transform
.
ChangeBatch
({
net
[
"main"
]
.
params
[
0
]:
0
},
batch_size
=
123
)(
net
)
assert
new_net
[
"main"
]
.
checked_type
.
ret_type
==
relay
.
TensorType
((
123
,
1000
))
if
__name__
==
"__main__"
:
test_change_batch_resnet
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment