Commit 9c638f06 by Yao Wang Committed by Zhi

[Relay][Pass]Improve memory_allocation pass to support multiple i/o dynamic kernels (#4595)

* Add more shape funcs

* Fix test

* Enhance test_any_concat

* Fix pylint

* Minor fix test

* Fix pylint

* Minor refactor

* Add test any for elemwise
parent e69bd128
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
# pylint: disable=no-else-return,invalid-name,len-as-condition # pylint: disable=no-else-return,invalid-name,len-as-condition,too-many-nested-blocks
""" """
A pass for manifesting explicit memory allocations. A pass for manifesting explicit memory allocations.
""" """
...@@ -173,6 +173,8 @@ class ManifestAllocPass(ExprMutator): ...@@ -173,6 +173,8 @@ class ManifestAllocPass(ExprMutator):
new_args = [self.visit(arg) for arg in call.args] new_args = [self.visit(arg) for arg in call.args]
ins = expr.Tuple(new_args) ins = expr.Tuple(new_args)
ret_type = call.checked_type ret_type = call.checked_type
view = LinearizeRetType(ret_type)
out_types = view.unpack()
is_dynamic = ret_type.is_dynamic() is_dynamic = ret_type.is_dynamic()
# TODO(@jroesch): restore this code, more complex then it seems # TODO(@jroesch): restore this code, more complex then it seems
...@@ -180,26 +182,37 @@ class ManifestAllocPass(ExprMutator): ...@@ -180,26 +182,37 @@ class ManifestAllocPass(ExprMutator):
# is_dynamic = is_dynamic or arg.checked_type.is_dynamic() # is_dynamic = is_dynamic or arg.checked_type.is_dynamic()
if is_dynamic: if is_dynamic:
assert isinstance(ret_type, ty.TensorType)
shape_func_ins = [] shape_func_ins = []
engine = compile_engine.get() engine = compile_engine.get()
cfunc = engine.lower_shape_func(call.op, self.target_host) cfunc = engine.lower_shape_func(call.op, self.target_host)
input_states = cfunc.shape_func_param_states input_states = cfunc.shape_func_param_states
is_inputs = [] is_inputs = []
input_pos = 0
for i, (arg, state) in enumerate(zip(new_args, input_states)): for i, (arg, state) in enumerate(zip(new_args, input_states)):
state = int(state) state = int(state)
# Pass Shapes # Pass Shapes
if state == 2: if state == 2:
sh_of = self.visit(self.shape_of(arg)) if isinstance(arg.type_annotation, ty.TupleType):
shape_func_ins.append( for j in range(len(arg.type_annotation.fields)):
scope.let("in_shape_{0}".format(i), sh_of)) let_in_arg = scope.let("in_arg_{0}".format(input_pos + j),
expr.TupleGetItem(arg, j))
sh_of = self.visit(self.shape_of(let_in_arg))
shape_func_ins.append(
scope.let("in_shape_{0}".format(input_pos + j), sh_of))
input_pos += len(arg.type_annotation.fields)
else:
sh_of = self.visit(self.shape_of(arg))
shape_func_ins.append(
scope.let("in_shape_{0}".format(input_pos), sh_of))
input_pos += 1
is_inputs.append(0) is_inputs.append(0)
# Pass Inputs # Pass Inputs
elif state == 1: elif state == 1:
new_arg = self.visit(arg) new_arg = self.visit(arg)
shape_func_ins.append( shape_func_ins.append(
scope.let("in_shape_{0}".format(i), new_arg)) scope.let("in_shape_{0}".format(input_pos), new_arg))
input_pos += 1
is_inputs.append(1) is_inputs.append(1)
# TODO(@jroesch): handle 3rd case # TODO(@jroesch): handle 3rd case
else: else:
...@@ -219,9 +232,6 @@ class ManifestAllocPass(ExprMutator): ...@@ -219,9 +232,6 @@ class ManifestAllocPass(ExprMutator):
scope.let("shape_func", shape_call) scope.let("shape_func", shape_call)
out_types = []
out_types.append(call.checked_type)
storages = [] storages = []
for out_shape, out_type in zip(out_shapes, out_types): for out_shape, out_type in zip(out_shapes, out_types):
size = self.compute_storage_in_relay( size = self.compute_storage_in_relay(
...@@ -242,15 +252,13 @@ class ManifestAllocPass(ExprMutator): ...@@ -242,15 +252,13 @@ class ManifestAllocPass(ExprMutator):
alloc = scope.let("out_{i}".format(i=i), alloc) alloc = scope.let("out_{i}".format(i=i), alloc)
outs.append(alloc) outs.append(alloc)
invoke = self.invoke_tvm(call.op, ins, expr.Tuple(outs)) tuple_outs = expr.Tuple(outs)
invoke = self.invoke_tvm(call.op, ins, tuple_outs)
scope.let("", invoke) scope.let("", invoke)
return outs[0] return outs[0] if len(outs) == 1 else tuple_outs
else: else:
view = LinearizeRetType(ret_type)
out_tys = view.unpack()
outs = [] outs = []
for i, out_ty in enumerate(out_tys): for i, out_ty in enumerate(out_types):
out = self.make_static_allocation(scope, out_ty, i) out = self.make_static_allocation(scope, out_ty, i)
outs.append(out) outs.append(out)
......
...@@ -18,9 +18,11 @@ ...@@ -18,9 +18,11 @@
"""Backend compiler related feature registration""" """Backend compiler related feature registration"""
from __future__ import absolute_import from __future__ import absolute_import
import topi import topi
from topi.util import get_const_tuple
from .op import register_compute, register_schedule, register_pattern, register_shape_func from .op import register_compute, register_schedule, register_pattern, register_shape_func
from .op import schedule_injective, OpPattern from .op import schedule_injective, OpPattern
from ...hybrid import script from ...hybrid import script
from ...api import convert
schedule_broadcast = schedule_injective schedule_broadcast = schedule_injective
schedule_elemwise = schedule_injective schedule_elemwise = schedule_injective
...@@ -120,20 +122,20 @@ def _cast_shape_function(x): ...@@ -120,20 +122,20 @@ def _cast_shape_function(x):
def cast_shape_func(attrs, inputs, out_ndims): def cast_shape_func(attrs, inputs, out_ndims):
return [_cast_shape_function(*inputs)] return [_cast_shape_function(*inputs)]
# shape func
@script @script
def _full_shape_func(x): def _full_shape_func(shape):
out_ndim = len(x) out_ndim = len(shape)
out = output_tensor((out_ndim,), "int64") out = output_tensor((out_ndim,), "int64")
for i in const_range(out_ndim): for i in const_range(out_ndim):
out[i] = x[i] out[i] = int64(shape[i])
return out return out
def full_shape_func(attrs, inputs, out_ndims): def full_shape_func(attrs, inputs, out_ndims):
""" """
Shape func for zeros, zeros_like, ones, ones_like. Shape func for zeros, zeros_like, ones, ones_like.
""" """
return [_full_shape_func(*inputs)] shape = get_const_tuple(attrs.shape)
return [_full_shape_func(convert(shape))]
@script @script
def _broadcast_shape_func(x, y, ndim): def _broadcast_shape_func(x, y, ndim):
...@@ -177,9 +179,11 @@ def elemwise_shape_func(attrs, inputs, _): ...@@ -177,9 +179,11 @@ def elemwise_shape_func(attrs, inputs, _):
register_shape_func("cast", False, cast_shape_func) register_shape_func("cast", False, cast_shape_func)
register_shape_func("zeros", False, full_shape_func) register_shape_func("zeros", False, full_shape_func)
register_shape_func("zeros_like", False, full_shape_func) register_shape_func("zeros_like", False, elemwise_shape_func)
register_shape_func("ones", False, full_shape_func) register_shape_func("ones", False, full_shape_func)
register_shape_func("ones_like", False, full_shape_func) register_shape_func("ones_like", False, elemwise_shape_func)
register_shape_func("full", False, full_shape_func)
register_shape_func("full_like", False, elemwise_shape_func)
register_shape_func("add", False, broadcast_shape_func) register_shape_func("add", False, broadcast_shape_func)
register_shape_func("subtract", False, broadcast_shape_func) register_shape_func("subtract", False, broadcast_shape_func)
...@@ -196,6 +200,9 @@ register_shape_func("less", False, broadcast_shape_func) ...@@ -196,6 +200,9 @@ register_shape_func("less", False, broadcast_shape_func)
register_shape_func("less_equal", False, broadcast_shape_func) register_shape_func("less_equal", False, broadcast_shape_func)
register_shape_func("greater", False, broadcast_shape_func) register_shape_func("greater", False, broadcast_shape_func)
register_shape_func("greater_equal", False, broadcast_shape_func) register_shape_func("greater_equal", False, broadcast_shape_func)
register_shape_func("maximum", False, broadcast_shape_func)
register_shape_func("minimum", False, broadcast_shape_func)
register_shape_func("sqrt", False, elemwise_shape_func) register_shape_func("sqrt", False, elemwise_shape_func)
register_shape_func("negative", False, elemwise_shape_func) register_shape_func("negative", False, elemwise_shape_func)
register_shape_func("exp", False, elemwise_shape_func)
...@@ -452,24 +452,8 @@ def transpose_shape_func(attrs, inputs, _): ...@@ -452,24 +452,8 @@ def transpose_shape_func(attrs, inputs, _):
@script @script
def _squeeze_shape_func(data_shape, keep_axes): def _squeeze_shape_func(data_shape, keep_axes):
out = output_tensor((len(keep_axes),), "int64") out = output_tensor((len(keep_axes),), "int64")
if len(keep_axes) == 0: for i in const_range(len(keep_axes)):
out_size = 0 out[i] = data_shape[keep_axes[i]]
for i in const_range(data_shape.shape[0]):
if data_shape[i] != 1:
out_size += 1
if out_size == 0:
out_size = 1
out = output_tensor((out_size,), "int64")
out[0] = int64(1)
pos = 0
for i in const_range(data_shape.shape[0]):
if data_shape[i] != 1:
out[pos] = data_shape[i]
pos += 1
else:
for i in const_range(len(keep_axes)):
out[i] = data_shape[keep_axes[i]]
return out return out
...@@ -485,7 +469,16 @@ def squeeze_shape_func(attrs, inputs, _): ...@@ -485,7 +469,16 @@ def squeeze_shape_func(attrs, inputs, _):
if i not in axis: if i not in axis:
keep_axes.append(i) keep_axes.append(i)
return [_squeeze_shape_func(inputs[0], convert(keep_axes))] # Due to current relay type system, it is possible even
# a static kernel function needs shape function. To handle
# this case, we allow axis to be None in squeeze shape func
# for now.
# TODO(kevinthesun): Enhance relay type system to avoid this.
if keep_axes:
out = _squeeze_shape_func(inputs[0], convert(keep_axes))
else:
out = tvm.compute((), lambda *indices: 0)
return [out]
@script @script
def _reshape_like_shape_func(target_shape): def _reshape_like_shape_func(target_shape):
...@@ -527,9 +520,56 @@ def _tile_shape_func(data, reps, ndim, tndim, rndim): ...@@ -527,9 +520,56 @@ def _tile_shape_func(data, reps, ndim, tndim, rndim):
@_reg.register_shape_func("tile", False) @_reg.register_shape_func("tile", False)
def tile_shape_func(attrs, inputs, _): def tile_shape_func(attrs, inputs, _):
"""
Shape function for tile op.
"""
reps = get_const_tuple(attrs.reps) reps = get_const_tuple(attrs.reps)
ndim = inputs[0].shape[0].value ndim = inputs[0].shape[0].value
rndim = len(reps) rndim = len(reps)
tndim = ndim if ndim > rndim else rndim tndim = ndim if ndim > rndim else rndim
return [_tile_shape_func(inputs[0], convert(reps), convert(ndim), return [_tile_shape_func(inputs[0], convert(reps), convert(ndim),
convert(tndim), convert(rndim))] convert(tndim), convert(rndim))]
@script
def _split_shape_func(data_shape, index, indices_or_sections, axis):
out = output_tensor((data_shape.shape[0],), "int64")
if len(indices_or_sections) == 1:
for i in const_range(data_shape.shape[0]):
if i == axis:
out[i] = ceil_div(data_shape[axis], indices_or_sections[0])
else:
out[i] = data_shape[i]
else:
start = int64(0)
if index > 0:
start = int64(indices_or_sections[index - 1])
end = data_shape[axis]
if index < len(indices_or_sections):
end = int64(indices_or_sections[index])
for i in const_range(data_shape.shape[0]):
if i == axis:
out[i] = end - start
else:
out[i] = data_shape[i]
return out
@_reg.register_shape_func("split", False)
def split_shape_func(attrs, inputs, _):
"""
Shape function for split op.
"""
if isinstance(attrs.indices_or_sections, (int, tvm.expr.IntImm)):
indices_or_sections = get_const_int(attrs.indices_or_sections)
else:
indices_or_sections = get_const_tuple(attrs.indices_or_sections)
axis = get_const_int(attrs.axis)
num_out = indices_or_sections if isinstance(indices_or_sections, int) \
else len(indices_or_sections) + 1
if isinstance(indices_or_sections, int):
indices_or_sections = [indices_or_sections]
return [_split_shape_func(inputs[0],
convert(i),
convert(indices_or_sections),
convert(axis)) for i in range(num_out)]
...@@ -59,6 +59,22 @@ def test_any_broadcast(): ...@@ -59,6 +59,22 @@ def test_any_broadcast():
verify_any_broadcast((relay.Any(),), (3, 2), (2,), (3, 2), relay.add, np.add) verify_any_broadcast((relay.Any(),), (3, 2), (2,), (3, 2), relay.add, np.add)
verify_any_broadcast((relay.Any(), 2), (3, 2), (3, 2), (3, 2), relay.add, np.add) verify_any_broadcast((relay.Any(), 2), (3, 2), (3, 2), (3, 2), relay.add, np.add)
def verify_any_elemwise(x_shape, x_np_shape, op, np_op):
dtype = 'float32'
x = relay.var('x', shape=x_shape, dtype=dtype)
mod = relay.module.Module()
mod["main"] = relay.Function([x], op(x))
x_np = np.random.uniform(size=x_np_shape).astype(dtype)
res_np = np_op(x_np)
for kind in ["debug", "vm"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(x_np)
tvm.testing.assert_allclose(result.asnumpy(), res_np)
def test_any_elemwise():
verify_any_elemwise((relay.Any(),), (3,), relay.sqrt, np.sqrt)
verify_any_elemwise((relay.Any(), 2), (5, 2), relay.negative, np.negative)
verify_any_elemwise((relay.Any(), relay.Any()), (5, 4), relay.exp, np.exp)
def test_any_broadcast_fail(): def test_any_broadcast_fail():
# Test broadcast with incompatible values at runtime # Test broadcast with incompatible values at runtime
...@@ -107,12 +123,14 @@ def test_any_full(): ...@@ -107,12 +123,14 @@ def test_any_full():
def test_any_concat(): def test_any_concat():
x = relay.var('x', shape=(relay.Any(), 2), dtype="float32") x = relay.var('x', shape=(relay.Any(), 2), dtype="float32")
y = relay.var('y', shape=(1, 2), dtype="float32") y = relay.var('y', shape=(1, 2), dtype="float32")
z = relay.op.concatenate([x, y], axis=0) xx = x - relay.expr.const(3.0)
yy = y * relay.expr.const(5.0)
z = relay.op.concatenate([xx, yy], axis=0)
mod = relay.module.Module() mod = relay.module.Module()
mod["main"] = relay.Function([x, y], z) mod["main"] = relay.Function([x, y], z)
x_np = np.random.uniform(size=(3, 2)).astype('float32') x_np = np.random.uniform(size=(3, 2)).astype('float32')
y_np = np.random.uniform(size=(1, 2)).astype('float32') y_np = np.random.uniform(size=(1, 2)).astype('float32')
ref = np.concatenate([x_np, y_np], axis=0) ref = np.concatenate([x_np - 3.0, y_np * 5.0], axis=0)
for kind in ["debug", "vm"]: for kind in ["debug", "vm"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(x_np, y_np) result = ex.evaluate()(x_np, y_np)
...@@ -417,6 +435,24 @@ def test_any_global_pool2d(): ...@@ -417,6 +435,24 @@ def test_any_global_pool2d():
verify_any_global_pool2d("max", (relay.Any(), 3, relay.Any(), relay.Any(), 4), verify_any_global_pool2d("max", (relay.Any(), 3, relay.Any(), relay.Any(), 4),
"NCHW4c", (2, 3, 220, 220, 4), (2, 3, 1, 1, 4)) "NCHW4c", (2, 3, 220, 220, 4), (2, 3, 1, 1, 4))
def verify_any_split(data_shape, indices_or_sections, axis, static_data_shape, ref_out_shape):
mod = relay.Module()
dtype = "float32"
data = relay.var('data', shape=data_shape, dtype=dtype)
y = relay.split(data, indices_or_sections, axis)
mod["main"] = relay.Function([data], y.astuple())
data_np = np.random.uniform(size=static_data_shape).astype(dtype)
for kind in ["vm"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(data_np)
for ret, ref_ret in zip(result, ref_out_shape):
assert ret.asnumpy().shape == ref_ret, \
"Shape mismatch: expect %s but got %s." % (str(ref_ret), str(ret.asnumpy().shape))
def test_any_split():
verify_any_split((relay.Any(), 4), 2, 1, (9, 4), [(9, 2), (9, 2)])
verify_any_split((relay.Any(), 12), (1, 4, 8), 1, (7, 12), [(7, 1), (7, 3), (7, 4)])
def test_any_batch_flatten(): def test_any_batch_flatten():
mod = relay.Module() mod = relay.Module()
dtype = "float32" dtype = "float32"
...@@ -601,11 +637,13 @@ def test_recursive_concat_with_wrong_annotation(): ...@@ -601,11 +637,13 @@ def test_recursive_concat_with_wrong_annotation():
if __name__ == "__main__": if __name__ == "__main__":
test_any_full() test_any_full()
test_any_broadcast() test_any_broadcast()
test_any_elemwise()
test_any_broadcast_fail() test_any_broadcast_fail()
test_any_concat() test_any_concat()
test_any_reshape() test_any_reshape()
test_any_take() test_any_take()
test_any_tile() test_any_tile()
test_any_split()
test_any_shape_of() test_any_shape_of()
test_any_reduce() test_any_reduce()
test_any_layout_transform() test_any_layout_transform()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment