Commit 16fefd89 by 雾雨魔理沙 Committed by Thierry Moreau

[Relay][VTA] Add ChangeBatch pass (#3656)

* init

* lint

* lint
parent a88b2842
......@@ -195,9 +195,10 @@ class ExprMutator(ExprFunctor):
and reconstructs the AST.
"""
def visit_function(self, fn):
new_params = [self.visit(x) for x in fn.params]
new_body = self.visit(fn.body)
return Function(
list(fn.params),
list(new_params),
new_body,
fn.ret_type,
fn.type_params,
......@@ -214,8 +215,8 @@ class ExprMutator(ExprFunctor):
new_args = [self.visit(arg) for arg in call.args]
return Call(new_fn, new_args, call.attrs)
def visit_var(self, rvar):
return rvar
def visit_var(self, var):
return var
def visit_global_id(self, global_var):
return global_var
......
......@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
# pylint: disable=invalid-name,arguments-differ,no-else-return,unused-argument,missing-docstring
"""
Relay pass transformation infrastructure.
"""
......@@ -22,7 +22,9 @@ import types
import inspect
import functools
import tvm
from tvm._ffi.runtime_ctypes import TVMContext
from tvm import relay
from . import _transform
from .base import RelayNode, register_relay_node
from .. import nd as _nd
......@@ -908,3 +910,40 @@ def function_pass(pass_func=None, opt_level=None, name=None, required=None):
if pass_func:
return create_function_pass(pass_func)
return create_function_pass
@function_pass(opt_level=1)
class ChangeBatch:
"""
Change the batch size.
Parameters
----------
data: Dict[relay.Var, int]
A dictionary of all the params to change.
The keys are all params, and the values is which dimension hold the batch.
batch_size: int
The batch size to change to.
Returns
-------
pass: FunctionPass
The pass.
"""
def __init__(self, data, batch_size=16):
self.data = data
self.batch_size = batch_size
def transform_function(self, func, mod, ctx):
func = relay.Function(func.params, func.body, None, func.type_params, func.attrs)
change_batch = self
class ChangeBatchMutator(tvm.relay.ExprMutator):
def visit_var(self, var):
if var in change_batch.data:
ty = var.type_annotation
new_shape = list(ty.shape)
new_shape[change_batch.data[var]] = change_batch.batch_size
return relay.Var(var.name_hint, relay.TensorType(new_shape, ty.dtype))
else:
return var
return ChangeBatchMutator().visit(func)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import relay
from tvm.relay.testing import resnet
from tvm.relay import transform
def test_change_batch_resnet():
net, params = resnet.get_workload()
new_net = transform.ChangeBatch({net["main"].params[0]: 0}, batch_size=123)(net)
assert new_net["main"].checked_type.ret_type == relay.TensorType((123, 1000))
if __name__ == "__main__":
test_change_batch_resnet()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment