Unverified Commit 0145cd50 by masahi Committed by GitHub

[Torch] Support Python list, more realistic recurrent networks (#5306)

* use funcs from prelude, pass around convert_map

* get relay input type from user ishape

* handle tuple unpack

* experimenting with static tensor array

* use prelude concat instead of cons + rev

* minor clean up

* fix layer norm conversion bug, unwrap tensor array

* add infer shape on tensor array

* pass around prelude for now

* compile worked but runtime error

* fix tensor array wrapping

* begin list dynamic test

* is_list_dynamic first version

* finish dynamic list test

* a few fixes

* use shape_of function if Any is found

* improve size conversion

* working on adding free vars to loop block

* fixed inlined inner loop issue

* clean up free var handling

* add support for tensor array concat

* adding ta concat on last axis

* fix concat, but got runtime error

* disable concat on axis -1 for now

* add lstm tests

* revert unrelated change

* fix stacked bidir test

* minor fix to test

* relax tol a bit, revert dnnl change to avoid conflict

* simplify infer type, use input tensor shape rather than concat shape

* more shape fix
parent cd0d52da
@@ -25,20 +25,95 @@ import sys
import numpy as np
import tvm
from tvm.ir import module as _module
from .. import analysis as _analysis
from .. import expr as _expr
from .. import op as _op
from ..ty import TupleType, TensorType, Any
from ..loops import while_loop
from .common import get_relay_op
from .common import infer_shape as _infer_shape
from .common import infer_value as _infer_value
from .common import infer_type as _infer_type
from ..prelude import Prelude, StaticTensorArrayOps
from . import qnn_torch
__all__ = ["from_pytorch"]
# List ADT utilities
def _infer_type_with_prelude(val, prelude):
body = _infer_type(val, prelude.mod)
return body.checked_type
def _convert_to_list_adt(py_lst, prelude):
elem_tys = [_infer_type_with_prelude(elem, prelude) for elem in py_lst]
msg = "List elements should have identical types"
assert all(map(lambda ty: ty == elem_tys[0], elem_tys)), msg
adt_lst = prelude.nil()
for elem in reversed(py_lst):
adt_lst = prelude.cons(elem, adt_lst)
return adt_lst
def _map_tensor_array_constructor(adt_lst, prelude, shape):
static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape)
static_tensor_array_ops.register()
tensor_create = prelude.get_var_static('tensor_constructor', "float32", shape)
return prelude.map(tensor_create, adt_lst)
def _convert_to_tensor_array(adt_lst, prelude):
if prelude.length(adt_lst) == 0:
return prelude.nil()
checked_type = _infer_type_with_prelude(prelude.hd(adt_lst), prelude)
shape = checked_type.shape
tensor_array = _map_tensor_array_constructor(adt_lst, prelude, shape)
return tensor_array, tuple(shape)
def _should_construct_dynamic_list(list_construct_node):
# if this list is element-accessed or modified at runtime, generate List ADT
def is_used_by_list_add(uses):
for use in uses:
op_name = use.user.kind()
output_type = _get_node_type(use.user)
if op_name in ["aten::add", "aten::add_"] and output_type == "ListType":
return True
return False
def inplace_add_to_add(op_name):
if op_name == "aten::add_":
return "aten::add"
else:
return op_name
uses = _get_uses(list_construct_node)
for loop_use in filter(lambda use: use.user.kind() == "prim::Loop", uses):
block_input_index = loop_use.offset - 1
block = list(loop_use.user.blocks())[0]
list_loop_var = list(block.inputs())[block_input_index]
uses += _get_uses(list_loop_var.node())
op_names = map(inplace_add_to_add, set(use.user.kind() for use in uses))
list_ops = set(["aten::add", "aten::__getitem__", "aten::stack"])
intersect = list_ops.intersection(op_names)
if len(intersect) > 0 and intersect != set(["aten::add"]):
return True
if is_used_by_list_add(filter(lambda use: use.user.kind() != "prim::Loop", uses)):
return True
return False
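
As an illustration, here is a hypothetical TorchScript function (mirroring the patterns used in the LSTM tests added later in this diff) whose prim::ListConstruct this heuristic classifies as dynamic: the list is grown with += inside a prim::Loop and later passed to aten::stack, so it has to become a Relay List ADT rather than stay a Python list.

import torch

@torch.jit.script
def collect(xs: torch.Tensor) -> torch.Tensor:
    outputs = []                  # prim::ListConstruct
    for i in range(xs.size(0)):   # prim::Loop
        outputs += [xs[i] * 2.0]  # list concat shows up as aten::add_ with a ListType output
    return torch.stack(outputs)   # aten::stack on the accumulated list
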
# operator implementation
def _elemwise(name):
def _impl(inputs, input_types):
@@ -103,11 +178,27 @@ def _unsqueeze():
return _op.transform.expand_dims(data, int(axis), 1)
return _impl
def _concatenate():
def _concatenate(prelude):
def tensor_array_concat(lst, axis):
assert axis == 0, "Tensor array concat supported only for axis 0"
tensor_array, shape = _convert_to_tensor_array(lst, prelude)
concat_shape = (Any(),) + shape[1:]
static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape)
static_tensor_array_ops.define_tensor_get_data(concat_shape)
concat = prelude.get_var_static('tensor_array_concat', "float32", shape)
concatenated = concat(tensor_array)
get_tensor = prelude.get_var_static('tensor_get_data', "float32", shape)
return get_tensor(concatenated)
def _impl(inputs, input_types):
data = inputs[0]
axis = inputs[1]
if not isinstance(data, list):
return tensor_array_concat(data, axis)
if isinstance(data, _expr.Expr):
data = [data]
@@ -130,7 +221,7 @@ def _slice():
else:
end = data.shape
begin = [0]*len(end)
begin = [0] * len(end)
dim = int(inputs[1])
begin[dim] = int(inputs[2])
@@ -371,7 +462,7 @@ def _maxpool_2d():
ceil_mode = int(inputs[5])
if dilation != (1, 1):
msg = "MaxPool2d with dilation %s is not implemented" % (str(dilation), )
msg = "MaxPool2d with dilation %s is not implemented" % (str(dilation))
raise NotImplementedError(msg)
return _op.nn.max_pool2d(data, pool_size, strides, padding, "NCHW", ceil_mode)
@@ -388,7 +479,7 @@ def _maxpool_1d():
ceil_mode = int(inputs[5])
if dilation != (1,):
msg = "MaxPool1d with dilation %s is not implemented" % (str(dilation), )
msg = "MaxPool1d with dilation %s is not implemented" % (str(dilation))
raise NotImplementedError(msg)
return _op.nn.max_pool1d(data, pool_size, strides, padding, "NCW", ceil_mode)
@@ -404,7 +495,7 @@ def _maxpool_3d():
dilation = _infer_shape(inputs[4])
ceil_mode = int(inputs[5])
if dilation != (1, 1, 1):
msg = "MaxPool3d with dilation %s is not implemented" % (str(dilation), )
msg = "MaxPool3d with dilation %s is not implemented" % (str(dilation))
raise NotImplementedError(msg)
return _op.nn.max_pool3d(data,
@@ -618,13 +709,13 @@ def _layer_norm():
scale=True)
return _impl
def _transpose():
def _transpose(prelude):
def _impl(inputs, input_types):
data = inputs[0]
import torch
if isinstance(data, _expr.Expr):
ndims = len(_infer_shape(data))
ndims = len(_infer_shape(data, prelude.mod))
elif isinstance(data, list):
ndims = data
elif isinstance(data, (torch.Tensor, np.ndarray)):
@@ -693,15 +784,30 @@ def _dense():
return dense_out
return _impl
def _size():
def _size(prelude):
def _impl_dynamic(inp, axis):
shape_dynamic = _op.shape_of(inp)
if axis is not None:
return _op.take(shape_dynamic, _expr.const(axis), 0)
return shape_dynamic
def _impl(inputs, input_types):
shape = _infer_shape(inputs[0])
shape = _infer_shape(inputs[0], prelude.mod)
axis = None
if len(inputs) > 1:
axis = int(inputs[1])
if any(map(lambda s: isinstance(s, tvm.tir.expr.Any), shape)):
if axis is None or isinstance(shape[axis], tvm.tir.expr.Any):
return _impl_dynamic(inputs[0], axis)
if axis is not None:
return shape[axis]
return shape
return _impl
def _numtotensor():
def _impl(inputs, input_types):
val = inputs[0]
@@ -862,7 +968,7 @@ def _mean():
return _impl
def _chunk():
def _chunk(prelude):
def _impl(inputs, input_types):
data = inputs[0]
@@ -870,7 +976,7 @@ def _chunk():
axis = int(inputs[2])
if isinstance(data, _expr.Expr):
inferred_shape = _infer_shape(data)
inferred_shape = _infer_shape(data, prelude.mod)
shape = []
for infer in inferred_shape:
@@ -894,7 +1000,6 @@ def _chunk():
chunk_out = _op.transform.strided_slice(data, begin, end, stride)
chunks.append(chunk_out)
if dim % num_chunks:
begin = [0] * len(shape)
end = shape[:]
@@ -1077,6 +1182,49 @@ def _Float():
return _op.cast(inputs[0], "float32")
return _impl
def _mm():
def _impl(inputs, input_types):
return _op.nn.dense(inputs[0], inputs[1])
return _impl
def _list_getitem(prelude):
def _impl(inputs, input_types):
return prelude.nth(inputs[0], _wrap_const(inputs[1]))
return _impl
def _list_len(prelude):
def _impl(inputs, input_types):
return prelude.length(inputs[0])
return _impl
def _add(prelude):
# add_ is overloaded for tensor add and list concat
def _impl(inputs, input_types):
if input_types[0] == "ListType":
return prelude.concat(inputs[0], inputs[1])
return _elemwise("add")(inputs, input_types)
return _impl
def _tensor_array_stack(prelude):
def _impl(inputs, input_types):
tensor_array, shape = _convert_to_tensor_array(inputs[0], prelude)
stack = prelude.get_var_static('tensor_array_stack', "float32", shape)
stacked = stack(tensor_array)
stacked_shape = (Any(),) + shape
static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape)
static_tensor_array_ops.define_tensor_get_data(stacked_shape)
# passing stacked_shape below gives "'Prelude' object has no attribute" error
get_tensor = prelude.get_var_static('tensor_get_data', "float32", shape)
return get_tensor(stacked)
return _impl
# Helper functions for operator implementation
def _convert_dtype_value(val):
convert_torch_dtype_map = {7:"torch.float64",
@@ -1148,16 +1296,14 @@ def _convert_elemwise_input(data, input_type):
return data
def _wrap_const(c):
if not isinstance(c, _expr.Expr) and not isinstance(c, list):
if not isinstance(c, (_expr.Expr, list, tvm.tir.expr.Any)):
return _expr.const(c)
return c
# Operator mappings
_convert_map = {
def _get_convert_map(prelude):
convert_map = {
"aten::device" : _none(),
"aten::add" : _elemwise("add"),
"aten::add_" : _elemwise("add"),
"aten::sub" : _elemwise("subtract"),
"aten::sub_" : _elemwise("subtract"),
"aten::max" : _elemwise("maximum"),
@@ -1165,10 +1311,10 @@ _convert_map = {
"aten::mul" : _elemwise("multiply"),
"aten::mul_" : _elemwise("multiply"),
"aten::pow" : _elemwise("power"),
"aten::div" : _elemwise("divide"),
"aten::div_" : _elemwise("divide"),
"aten::abs" : _abs(),
"aten::arange" : _arange(),
"aten::div" : _elemwise("divide"),
"aten::div_" : _elemwise("divide"),
"aten::ones" : _ones(),
"aten::zeros" : _zeros(),
"aten::reciprocal" : _reciprocal(),
@@ -1177,7 +1323,7 @@ _convert_map = {
"aten::to" : _to(),
"aten::squeeze" : _squeeze(),
"aten::unsqueeze" : _unsqueeze(),
"aten::cat" : _concatenate(),
"aten::cat" : _concatenate(prelude),
"aten::slice" : _slice(),
"aten::split" : _split(),
"aten::split_with_sizes" : _split_with_sizes(),
@@ -1207,12 +1353,12 @@ _convert_map = {
"aten::batch_norm" : _batch_norm(),
"aten::instance_norm" : _instance_norm(),
"aten::layer_norm" : _layer_norm(),
"aten::transpose" : _transpose(),
"aten::transpose_" : _transpose(),
"aten::t" : _transpose(),
"aten::transpose" : _transpose(prelude),
"aten::transpose_" : _transpose(prelude),
"aten::t" : _transpose(prelude),
"aten::flatten" : _flatten(),
"aten::addmm" : _dense(),
"aten::size" : _size(),
"aten::size" : _size(prelude),
"aten::view" : _view(),
"aten::reshape" : _reshape(),
"aten::clone" : _clone(),
@@ -1226,14 +1372,13 @@ _convert_map = {
"aten::feature_dropout" : _dropout(),
"aten::alpha_dropout" : _dropout(),
"aten::mean" : _mean(),
"aten::chunk" : _chunk(),
"aten::chunk" : _chunk(prelude),
"aten::matmul" : _matmul(),
"aten::expand" : _expand(),
"aten::Int" : _int(),
"prim::NumToTensor" : _numtotensor(),
"prim::ListUnpack" : _identity(),
"aten::constant_pad_nd" : _pad(),
"aten::permute" : _transpose(),
"aten::permute" : _transpose(prelude),
"aten::sum" : _reduce("sum"),
"aten::prod" : _reduce("prod"),
"aten::sqrt" : _sqrt(),
@@ -1252,8 +1397,16 @@ _convert_map = {
"aten::neg" : _neg(),
"aten::tanh" : _tanh(),
"aten::adaptive_avg_pool3d" : _adaptive_avg_pool_3d(),
"aten::adaptive_max_pool3d" : _adaptive_max_pool_3d()
}
"aten::adaptive_max_pool3d" : _adaptive_max_pool_3d(),
"aten::mm" : _matmul(),
"relay::tensor_array_stack" : _tensor_array_stack(prelude),
"aten::add" : _add(prelude),
"aten::add_" : _add(prelude),
"aten::stack" : _tensor_array_stack(prelude),
"aten::__getitem__" : _list_getitem(prelude),
"aten::len" : _list_len(prelude),
}
return convert_map
def _run_jit_passes(graph):
@@ -1289,13 +1442,29 @@ def _get_op_inputs(op_node, outputs):
return [outputs[name] for name in _get_input_names(op_node)]
def _report_missing_conversion(op_names):
def _get_node_type(node):
assert node.outputsSize() == 1
return node.output().type().kind()
def _get_uses(node):
uses = []
for output in node.outputs():
uses += output.uses()
return uses
def _get_users(node):
return [use.user for use in _get_uses(node)]
def _report_missing_conversion(op_names, convert_map):
""" Check if all ops in an input graph are supported by TVM """
known_ops = ["prim::Constant", "prim::GetAttr",
"prim::ListConstruct", "prim::ListUnpack",
"prim::TupleConstruct", "prim::TupleUnpack",
"prim::If", "prim::Loop"]
known_ops += list(_convert_map.keys())
known_ops += list(convert_map.keys())
known_ops += list(qnn_torch.convert_map.keys())
missing = [op_name for op_name in op_names
@@ -1361,7 +1530,7 @@ def _get_input_types(op_node):
input_list_types.append(in_ty.scalarType().lower())
elif input_node_kind == 'ListType':
input_list_types.append(str(in_ty.getElementType()).lower())
input_list_types.append("ListType")
elif input_node_kind in ['IntType', 'FloatType', 'BoolType',
'StringType', 'OptionalType']:
input_list_types.append(str(in_ty).lower())
@@ -1422,21 +1591,69 @@ def _get_operator_nodes(nodes):
return ops
def _get_relay_input_vars(graph, input_shapes):
def _get_graph_input_names(graph):
""" Get the graph input names (use after graph copy and run jit passes) """
# Variable names could change the first time a copy is made and after
# _run_jit_passes is called; it is expected that those functions have already been invoked
ir_inputs = _get_input_names(graph)
return ir_inputs[1:] # remove self at the 0th arg
def _get_relay_input_vars(graph, input_shapes, prelude):
"""
Return Relay vars from input shapes and create entries based on
expected graph inputs - to allow translation
"""
def get_relay_ty(ishape):
if _is_int_seq(ishape) or len(ishape) == 0:
return TensorType(ishape)
elif isinstance(ishape, tuple):
return TupleType([get_relay_ty(elem) for elem in ishape])
elif isinstance(ishape, list):
assert len(ishape) > 0
elem_tys = [get_relay_ty(s) for s in ishape]
msg = "List elements should have identical types"
assert all(map(lambda ty: ty == elem_tys[0], elem_tys)), msg
return prelude.l(elem_tys[0])
raise NotImplementedError("unsupported input type")
input_types = [(tup[0], get_relay_ty(tup[1])) for tup in input_shapes]
input_vars = {}
ir_inputs = _get_graph_input_names(graph)
for ir_input, (name, shape) in zip(ir_inputs, input_shapes):
inp = _expr.var(name, shape=shape)
for ir_input, (name, itype) in zip(ir_inputs, input_types):
inp = _expr.var(name, type_annotation=itype)
# Translate from graph input to user input name
input_vars[ir_input] = inp
return input_vars
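
To make get_relay_ty concrete, here is what the input shape specs from the LSTM tests added later in this diff (sketched with their actual values) map to: a tuple of shapes becomes a TupleType of TensorTypes, while a Python list of identically typed entries becomes a Relay List ADT via prelude.l.

state_shape = (2, 4)  # (batch, hidden_size) in the test below

# tuple of shapes -> TupleType([TensorType, TensorType])
input_shapes = [("input", (5, 2, 3)),
                ("states", (state_shape, state_shape))]

# Python list of shape tuples -> List ADT whose element type is the TupleType above
input_shapes_stacked = [("input", (5, 2, 3)),
                        ("states", [(state_shape, state_shape),
                                    (state_shape, state_shape)])]
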
def _unpack_tuple(tup):
def unpack(tup, num_fields):
return [_expr.TupleGetItem(tup, i) for i in range(num_fields)]
if isinstance(tup, _expr.Tuple):
return unpack(tup, len(tup.fields))
elif isinstance(tup.type_annotation, TupleType):
return unpack(tup, len(tup.type_annotation.fields))
# shouldn't happen
assert False
def _get_free_vars_from_block(block):
block_inp_names = _get_input_names(block)
bound_names = block_inp_names
free_vars = set()
for node in block.nodes():
inp_names = _get_input_names(node)
list_diff = [name for name in inp_names if name not in bound_names]
free_vars.update(list_diff)
bound_names += _get_output_names(node)
return free_vars
def get_use_chains(root_node, terminate=lambda _: False):
"""
Track a chain of users of this node forward, returning a list of chains
@@ -1446,9 +1663,7 @@ def get_use_chains(root_node, terminate=lambda _: False):
return itertools.chain.from_iterable(lists)
def inner(current, accum):
users = []
for output in current.outputs():
users += [use.user for use in output.uses()]
users = _get_users(current)
if not users or terminate(users):
return [accum]
@@ -1512,24 +1727,24 @@ def convert_params(graph, state_dict):
return params, param_tensors, packed_param_map
def convert_block(block, outputs):
def convert_block(block, outputs, convert_map, prelude):
""" Translate Torch "Block", used for prim::If and prim::Loop """
ops = _get_operator_nodes(block.nodes())
ret_names = _get_input_names(block.returnNode())
return convert_operators(ops, outputs, ret_names)
return convert_operators(ops, outputs, ret_names, convert_map, prelude)
def convert_if(if_node, outputs):
def convert_if(if_node, outputs, convert_map, prelude):
""" Translate Torch prim::If to Relay If """
cond = outputs[if_node.inputsAt(0).debugName()]
blocks = list(if_node.blocks())
true_branch = convert_block(blocks[0], outputs)
false_branch = convert_block(blocks[1], outputs)
true_branch = convert_block(blocks[0], outputs, convert_map, prelude)
false_branch = convert_block(blocks[1], outputs, convert_map, prelude)
assert len(true_branch) == 1 and len(false_branch) == 1
return _expr.If(cond, true_branch[0], false_branch[0])
def convert_loop(loop_node, outputs):
def convert_loop(loop_node, outputs, convert_map, prelude):
""" Translate Torch prim::Loop to Relay while_loop """
def get_input(index):
ivalue = loop_node.inputsAt(index)
@@ -1555,8 +1770,54 @@ def convert_loop(loop_node, outputs):
is_while_loop = (isinstance(max_loop_count, _expr.Constant) and
_get_constant(loop_node.inputsAt(0).node()) == sys.maxsize)
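# Background: TorchScript canonicalizes both loop forms into
# prim::Loop(max_trip_count, initial_condition, ...). A "for i in range(n)" loop arrives
# with max_trip_count == n and a constant-true condition, whereas a "while cond:" loop
# arrives with max_trip_count set to sys.maxsize, which is what the check above detects.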
if is_while_loop:
loop_iter_dtype = "bool"
# while loop with non input dependent condition such as while i < 10:
# init_cond is int, need to cast to bool to type check
if isinstance(init_cond, _expr.Constant):
init_cond = _op.cast(init_cond, "bool")
init_loop_iter_val = init_cond
else:
loop_iter_dtype = "int32"
# always count from 0
init_loop_iter_val = _expr.const(0, dtype="int32")
body_block = list(loop_node.blocks())[0]
block_input_names = _get_input_names(body_block)
num_block_inputs = len(block_input_names)
name_val_pairs = list(zip(block_input_names,
[init_loop_iter_val] + init_vals))
outputs.update(name_val_pairs)
def get_var(name, val):
if val:
checked_type = _infer_type_with_prelude(val, prelude)
return _expr.var(name, type_annotation=checked_type)
return _expr.var(name)
loop_iter_var = _expr.var(block_input_names[0], shape=(),
dtype=loop_iter_dtype)
loop_vars = [get_var(name, val) for name, val in name_val_pairs[1:]]
# Add non constant free variables to loop variables to prevent code blow up
# Without this, if there are two for loops in a row, which often happens
# if the outer loop is unrolled, the computation corresponding to the first for loop
# is inlined inside loop body, turning O(N) + O(N) computation into O(N^2).
# This issue was found when converting the stacked LSTM test. Torch does not add the output
# of the earlier loop into the loop variables of the next loop.
# So the variable corresponding to the first loop output appears free in the second loop body.
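# For example, when StackedLSTM's Python loop over self.layers is unrolled, the scripted body
# is roughly:
#     output, out_state = rnn_layer0(input, state0)   # first prim::Loop over time steps
#     output, out_state = rnn_layer1(output, state1)  # "output" is free inside this loop's body
# Adding "output" as a loop variable here prevents the first loop's computation from being
# re-inlined into the second loop's body.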
free_vars = [var for var in _get_free_vars_from_block(body_block)
if var in outputs and not isinstance(outputs[var], (_expr.Constant, int, float))
and outputs[var]]
prev_outputs = {}
for name in free_vars:
prev_output = outputs[name]
new_loop_var = get_var(name, prev_output)
prev_outputs[name] = prev_output
outputs[name] = new_loop_var
loop_vars.append(new_loop_var)
init_vals.append(prev_output)
def cond(*current_vals):
i = current_vals[0]
@@ -1568,11 +1829,16 @@ def convert_loop(loop_node, outputs):
def body(*current_vals):
# Update loop variables using the prev iteration outputs
assert len(current_vals) == len(block_input_names)
for (i, iname) in enumerate(block_input_names):
outputs[iname] = current_vals[i]
assert len(current_vals) == num_block_inputs + len(free_vars)
for (i, val) in enumerate(current_vals):
if i < num_block_inputs:
outputs[block_input_names[i]] = val
else:
outputs[free_vars[i-num_block_inputs]] = val
block_outputs = convert_block(body_block, outputs)
block_outputs = convert_block(body_block, outputs, convert_map, prelude)
block_outputs += [outputs[name] for name in free_vars]
if not is_while_loop:
# iter var increment implicit in torch, so do it manually
@@ -1583,38 +1849,17 @@ def convert_loop(loop_node, outputs):
return block_outputs
def get_var(name, val):
if isinstance(val, _expr.Constant):
return _expr.var(name, shape=val.data.shape, dtype=val.data.dtype)
return _expr.var(name)
if is_while_loop:
loop_iter_dtype = "bool"
# while loop with non input dependent condition such as while i < 10:
# init_cond is int, need to cast to bool to type check
if isinstance(init_cond, _expr.Constant):
init_cond = _op.cast(init_cond, "bool")
init_loop_iter_val = init_cond
else:
loop_iter_dtype = "int32"
# always count from 0
init_loop_iter_val = _expr.const(0, dtype="int32")
name_val_pairs = list(zip(block_input_names,
[init_loop_iter_val] + init_vals))
outputs.update(name_val_pairs)
loop_iter_var = _expr.var(block_input_names[0], shape=(),
dtype=loop_iter_dtype)
loop_vars = [get_var(name, val) for name, val in name_val_pairs[1:]]
loop = while_loop(cond, [loop_iter_var] + loop_vars, body)
loop_val = loop(init_loop_iter_val, *init_vals)
# restore original output values for free vars
outputs.update(prev_outputs)
# The first element is a loop counter or boolean condition, ignore it
return [_expr.TupleGetItem(loop_val, i+1) for i in range(num_loop_var)]
def convert_operators(operators, outputs, ret_names):
def convert_operators(operators, outputs, ret_names, convert_map, prelude):
""" Convert each Torch IR operators to Relay equivalent """
for node_name, op_node in operators:
operator = op_node.kind()
@@ -1622,24 +1867,33 @@ def convert_operators(operators, outputs, ret_names):
if operator == "prim::Constant":
outputs[node_name] = _get_constant(op_node)
elif operator == 'prim::ListConstruct' and _is_int_seq(inputs):
elif operator == "prim::ListConstruct" and _is_int_seq(inputs):
outputs[node_name] = _expr.var(node_name, shape=inputs)
elif operator in ['prim::ListConstruct', 'prim::TupleConstruct']:
elif operator == "prim::ListConstruct" and _should_construct_dynamic_list(op_node):
outputs[node_name] = _convert_to_list_adt(inputs, prelude)
elif operator == "prim::ListConstruct":
# This assumes that no more elements will be appended to this list
# In this case, we keep the Python list
outputs[node_name] = inputs
elif operator in ["prim::ListUnpack", 'prim::TupleUnpack']:
elif operator == "prim::TupleConstruct":
outputs[node_name] = _expr.Tuple(inputs)
elif operator in ["prim::ListUnpack", "prim::TupleUnpack"]:
assert len(inputs) == 1
unpacked_names = _get_output_names(op_node)
outputs.update(zip(unpacked_names, inputs[0]))
if isinstance(inputs[0], (list, _expr.TupleWrapper)):
unpacked = inputs[0]
else:
unpacked = _unpack_tuple(inputs[0])
outputs.update(zip(_get_output_names(op_node), unpacked))
elif operator == "prim::If":
if_out = convert_if(op_node, outputs)
if_out = convert_if(op_node, outputs, convert_map, prelude)
outputs[node_name] = if_out
elif operator == "prim::Loop":
loop_out = convert_loop(op_node, outputs)
loop_out = convert_loop(op_node, outputs, convert_map, prelude)
unpacked_names = _get_output_names(op_node)
assert len(loop_out) == len(unpacked_names)
outputs.update(zip(unpacked_names, loop_out))
else:
relay_op = _convert_map[operator]
relay_op = convert_map[operator]
relay_out = relay_op(inputs, _get_input_types(op_node))
if isinstance(relay_out, tuple):
@@ -1666,14 +1920,6 @@ def get_all_op_names(graph):
return set(node.kind() for node in nodes)
def _get_graph_input_names(graph):
""" Get the graph input names (use after graph copy and run jit passes) """
# Variable names could change the first time a copy is made and after
# _run_jit_passes is called; it is expected that those functions have already been invoked
ir_inputs = _get_input_names(graph)
return ir_inputs[1:] # remove self at the 0th arg
def from_pytorch(script_module, input_shapes, custom_convert_map=None):
""" Load PyTorch model in the form of a scripted PyTorch model and convert into relay.
The companion parameters will be handled automatically.
@@ -1700,18 +1946,23 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None):
params : dict of str to tvm.runtime.NDArray
Dict of converted parameters stored in tvm.runtime.ndarray format
"""
mod = tvm.IRModule()
prelude = Prelude(mod)
convert_map = _get_convert_map(prelude)
graph = script_module.graph.copy()
_run_jit_passes(graph)
if custom_convert_map:
_convert_map.update(custom_convert_map)
convert_map.update(custom_convert_map)
op_names = get_all_op_names(graph)
_report_missing_conversion(op_names)
_report_missing_conversion(op_names, convert_map)
_check_inputs(graph, input_shapes)
params = script_module.state_dict()
outputs = _get_relay_input_vars(graph, input_shapes)
outputs = _get_relay_input_vars(graph, input_shapes, prelude)
param_vars, tensors, packed_param_map = convert_params(graph, params)
tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}
@@ -1726,14 +1977,11 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None):
packed_param_map,
weight_quant_params)
qnn_torch.add_quant_params(tvm_params, weight_quant_params)
_convert_map.update(qnn_torch.convert_map)
convert_map.update(qnn_torch.convert_map)
ret = convert_operators(_get_operator_nodes(graph.nodes()),
outputs, ret_name)
if isinstance(ret[0], list):
ret[0] = _expr.Tuple(ret[0])
outputs, ret_name, convert_map, prelude)
func = tvm.relay.Function(_analysis.free_vars(ret[0]), ret[0])
mod["main"] = tvm.relay.Function(_analysis.free_vars(ret[0]), ret[0])
return _module.IRModule.from_expr(func), tvm_params
return mod, tvm_params
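
For reference, a minimal end-to-end sketch of the updated frontend (a hypothetical model and illustrative shapes; the converter and executor calls mirror the test file added below):

import numpy as np
import torch
import tvm
from tvm import relay
from tvm.relay.frontend.pytorch import from_pytorch

class Accumulate(torch.nn.Module):
    # builds a Python list in a loop and stacks it, exercising the new List ADT path
    def forward(self, xs):
        outputs = []
        for i in range(xs.size(0)):
            outputs += [xs[i] + 1.0]
        return torch.stack(outputs)

script_module = torch.jit.script(Accumulate().eval())
mod, params = from_pytorch(script_module, [("input", (5, 2, 3))])

# models using the List ADT / tensor arrays are executed with the Relay VM
executor = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0), target="llvm")
res = executor.evaluate()(np.random.randn(5, 2, 3).astype("float32"), **params)

List-typed graph inputs (such as "states" in the stacked LSTM tests) are declared by nesting a Python list inside input_shapes and are fed at run time as VM ADT objects, e.g. via a helper like convert_list_to_vmobj in the test file below.
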
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
""" Tests on torch lstm model conversion """
# originally from https://github.com/pytorch/pytorch/blob/master/benchmarks/fastrnns/custom_lstms.py
# described in https://pytorch.org/blog/optimizing-cuda-rnn-with-torchscript/
import numpy as np
import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.jit as jit
from typing import List, Tuple
from torch import Tensor
import tvm
from tvm import relay
from tvm.relay.frontend.pytorch import from_pytorch
from tvm.relay.prelude import Prelude
from tvm.runtime.container import ADT, tuple_object
class LayerNormLSTMCell(jit.ScriptModule):
def __init__(self, input_size, hidden_size):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size))
self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size))
ln = nn.LayerNorm
self.layernorm_i = ln(4 * hidden_size)
self.layernorm_h = ln(4 * hidden_size)
self.layernorm_c = ln(hidden_size)
@jit.script_method
def forward(self, input, state):
# type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
hx, cx = state
igates = self.layernorm_i(torch.mm(input, self.weight_ih.t()))
hgates = self.layernorm_h(torch.mm(hx, self.weight_hh.t()))
gates = igates + hgates
ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
ingate = torch.sigmoid(ingate)
forgetgate = torch.sigmoid(forgetgate)
cellgate = torch.tanh(cellgate)
outgate = torch.sigmoid(outgate)
cy = self.layernorm_c((forgetgate * cx) + (ingate * cellgate))
hy = outgate * torch.tanh(cy)
return hy, (hy, cy)
class LSTMLayer(jit.ScriptModule):
def __init__(self, cell, *cell_args):
super().__init__()
self.cell = cell(*cell_args)
@jit.script_method
def forward(self, input, state):
# type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
outputs = []
for i in range(input.size(0)):
out, state = self.cell(input[i], state)
outputs += [out]
return torch.stack(outputs), state
class ReverseLSTMLayer(jit.ScriptModule):
def __init__(self, cell, *cell_args):
super(ReverseLSTMLayer, self).__init__()
self.cell = cell(*cell_args)
@jit.script_method
def forward(self, inputs, state):
# type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
outputs = jit.annotate(List[Tensor], [])
seq_len = inputs.size(0)
for i in range(seq_len):
out, state = self.cell(inputs[seq_len - i - 1], state)
# workaround for the lack of list rev support
outputs = [out] + outputs
return torch.stack(outputs), state
class BidirLSTMLayer(jit.ScriptModule):
__constants__ = ['directions']
def __init__(self, cell, *cell_args):
super(BidirLSTMLayer, self).__init__()
self.directions = nn.ModuleList([
LSTMLayer(cell, *cell_args),
ReverseLSTMLayer(cell, *cell_args),
])
@jit.script_method
def forward(self, input, states):
# type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]
# List[LSTMState]: [forward LSTMState, backward LSTMState]
outputs = jit.annotate(List[Tensor], [])
output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
for (i, direction) in enumerate(self.directions):
state = states[i]
out, out_state = direction(input, state)
outputs += [out]
output_states += [out_state]
# tensor array concat assumes axis == 0 for now
# return torch.cat(outputs, -1), output_states
return torch.cat(outputs, 0), output_states
def init_stacked_lstm(num_layers, layer, first_layer_args, other_layer_args):
layers = [layer(*first_layer_args)] + [layer(*other_layer_args)
for _ in range(num_layers - 1)]
return nn.ModuleList(layers)
class StackedLSTM(jit.ScriptModule):
__constants__ = ['layers'] # Necessary for iterating through self.layers
def __init__(self, num_layers, layer, first_layer_args, other_layer_args):
super().__init__()
self.layers = init_stacked_lstm(num_layers, layer, first_layer_args,
other_layer_args)
@jit.script_method
def forward(self, input, states):
# type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]
# List[LSTMState]: One state per layer
output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
output = input
for (i, rnn_layer) in enumerate(self.layers):
state = states[i]
output, out_state = rnn_layer(output, state)
output_states += [out_state]
return output, output_states
class StackedBidirLSTM(jit.ScriptModule):
__constants__ = ['layers'] # Necessary for iterating through self.layers
def __init__(self, num_layers, layer, first_layer_args, other_layer_args):
super(StackedBidirLSTM, self).__init__()
self.layers = init_stacked_lstm(num_layers, layer, first_layer_args,
other_layer_args)
@jit.script_method
def forward(self, input, states):
# type: (Tensor, List[List[Tuple[Tensor, Tensor]]]) -> Tuple[Tensor, List[List[Tuple[Tensor, Tensor]]]]
# List[List[LSTMState]]: The outer list is for layers,
# inner list is for directions.
output_states = jit.annotate(List[List[Tuple[Tensor, Tensor]]], [])
output = input
for (i, rnn_layer) in enumerate(self.layers):
state = states[i]
output, out_state = rnn_layer(output, state)
output_states += [out_state]
return output, output_states
def lstm(input_size, hidden_size):
return LSTMLayer(LayerNormLSTMCell, input_size, hidden_size)
def stacked_lstm(input_size, hidden_size, num_layers):
return StackedLSTM(num_layers, LSTMLayer,
first_layer_args=[LayerNormLSTMCell, input_size, hidden_size],
other_layer_args=[LayerNormLSTMCell, hidden_size, hidden_size])
def bidir_lstm(input_size, hidden_size):
return BidirLSTMLayer(LayerNormLSTMCell, input_size, hidden_size)
def stacked_bidir_lstm(input_size, hidden_size, num_layers):
return StackedBidirLSTM(num_layers, BidirLSTMLayer,
first_layer_args=[LayerNormLSTMCell, input_size, hidden_size],
other_layer_args=[LayerNormLSTMCell, hidden_size, hidden_size])
def vmobj_to_list(o, dtype="float32"):
if isinstance(o, tvm.nd.NDArray):
return [o]
elif isinstance(o, tvm.runtime.container.ADT):
result = []
for f in o:
result.extend(vmobj_to_list(f, dtype))
return result
else:
raise RuntimeError("Unknown object type: %s" % type(o))
def assert_equal(tvm_result, torch_result):
if isinstance(torch_result, (tuple, list)):
assert isinstance(tvm_result, list)
for tvm_res, pt_res in zip(tvm_result, torch_result):
assert_equal(tvm_res, pt_res)
elif isinstance(torch_result, torch.Tensor):
tvm.testing.assert_allclose(tvm_result.asnumpy(), torch_result.numpy(),
rtol=1e-4, atol=1e-4)
def run_and_compare(mod, params, pt_result):
executor = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0), target="llvm")
evaluator = executor.evaluate()
exec_res = evaluator(**params)
def flatten(nested):
res = []
for r in nested:
if isinstance(r, torch.Tensor):
res.append(r)
else:
res.extend(flatten(r))
return res
if isinstance(exec_res, tvm.runtime.container.ADT):
assert not isinstance(pt_result, torch.Tensor)
tvm_res = vmobj_to_list(exec_res)
torch_res = flatten(pt_result)
else:
tvm_res = exec_res
torch_res = pt_result
assert_equal(tvm_res, torch_res)
def convert_list_to_vmobj(py_lst):
def wrap_nd_array(arr):
return tvm.nd.array(arr, ctx=tvm.cpu(0))
mod = tvm.IRModule()
prelude = Prelude(mod)
adt_lst = ADT(prelude.nil.tag, [])
for elem in reversed(py_lst):
if isinstance(elem, np.ndarray):
vmobj = wrap_nd_array(elem)
elif isinstance(elem, tuple):
vmobj = tuple_object([wrap_nd_array(e) for e in elem])
elif isinstance(elem, list):
vmobj = convert_list_to_vmobj(elem)
adt_lst = ADT(prelude.cons.tag, [vmobj, adt_lst])
return adt_lst
def custom_lstm_test():
input_name = "input"
states_name = "states"
seq_len = 5
batch = 2
input_size = 3
hidden_size = 4
num_layers = 3
state_tensor_shape = (batch, hidden_size)
inp = torch.randn(seq_len, batch, input_size)
input_shapes = [(input_name, (seq_len, batch, input_size)),
(states_name, (state_tensor_shape, state_tensor_shape))]
input_shapes_stacked = [(input_name, (seq_len, batch, input_size)),
(states_name, [(state_tensor_shape, state_tensor_shape),
(state_tensor_shape, state_tensor_shape)])]
input_shapes_stacked_bidir = [(input_name, (seq_len, batch, input_size)),
(states_name, [[(state_tensor_shape,
state_tensor_shape)
for _ in range(2)]
for _ in range(num_layers)])]
states = [(torch.randn(state_tensor_shape),
torch.randn(state_tensor_shape))
for _ in range(num_layers)]
bidir_states = [(torch.randn(state_tensor_shape),
torch.randn(state_tensor_shape))
for _ in range(2)]
stacked_bidir_states = [[(torch.randn(state_tensor_shape),
torch.randn(state_tensor_shape))
for _ in range(2)]
for _ in range(num_layers)]
models = [
(lstm(input_size, hidden_size).eval(), states[0], input_shapes),
(stacked_lstm(input_size, hidden_size, num_layers).eval(), states, input_shapes_stacked),
(bidir_lstm(input_size, hidden_size).eval(), bidir_states, input_shapes_stacked),
(stacked_bidir_lstm(input_size, hidden_size, num_layers).eval(),
stacked_bidir_states, input_shapes_stacked_bidir)
]
for (raw_model, states, input_shapes) in models:
script_module = torch.jit.script(raw_model)
mod, params = from_pytorch(script_module, input_shapes)
with torch.no_grad():
pt_result = raw_model(inp.clone(), states)
params[input_name] = inp.numpy()
if isinstance(states, tuple):
states_np = tuple(st.numpy() for st in states)
elif isinstance(states, list) and isinstance(states[0], torch.Tensor):
states_np = [st.numpy() for st in states]
elif isinstance(states, list) and isinstance(states[0], tuple):
states_np = [tuple(st.numpy() for st in states[i])
for i in range(len(states))]
elif isinstance(states, list) and isinstance(states[0], list):
states_np = [[tuple(st.numpy() for st in states)
for states in states[layer]]
for layer in range(num_layers)]
else:
assert False
if isinstance(states_np, list):
params[states_name] = convert_list_to_vmobj(states_np)
else:
params[states_name] = states_np
run_and_compare(mod, params, pt_result)
@@ -543,7 +543,7 @@ def test_forward_maxpool1d():
input_data)
verify_model(torch.nn.MaxPool1d(kernel_size=10).eval(),
input_data)
verify_model( torch.nn.MaxPool1d(kernel_size=4,
verify_model(torch.nn.MaxPool1d(kernel_size=4,
padding=2,
stride=2).eval(),
input_data)
@@ -1363,3 +1363,8 @@ if __name__ == "__main__":
# Test simple conditionals and loop
test_control_flow()
test_simple_rnn()
# More complex recurrent models
from lstm_test import custom_lstm_test
custom_lstm_test()