Unverified Commit 06e9542e by masahi Committed by GitHub

[Torch] Add initial control flow support (#4964)

* Add support for prim::If and prim::Loop with test cases

* rebase and fix tests

* add some comments

* simplifying, fix float cast

* parse -> convert

* recursively retrieve ops in get_all_op_names

* use multiple return values from block correctly, simplify loop convert

* choose dtype properly for zeros and ones

* simplifying, replace convert_inputs with _get_relay_input_vars

* fix for while loop with non-input-dependent init cond

* add assert on loop var update

* move the condition around

* better testing for seg models

* rebase fix, disable inception v3 in quant test as it is too slow to
load with torch-1.4 + torchvision 0.5

* simplify and add more comparison op converter
parent c0bc1882
@@ -20,6 +20,7 @@
 """PT: PyTorch frontend."""
 import itertools
 import logging
+import sys
 import numpy as np
@@ -29,6 +30,7 @@ from tvm.ir import module as _module
 from .. import analysis as _analysis
 from .. import expr as _expr
 from .. import op as _op
+from ..loops import while_loop
 from .common import get_relay_op
 from .common import infer_shape as _infer_shape
 from .common import infer_value as _infer_value
@@ -107,9 +109,8 @@ def _select():
     def _impl(inputs, input_types):
         data = inputs[0]
         dim = int(inputs[1])
-        index = int(inputs[2])
-        return _op.transform.take(data, _expr.const(index, dtype="int32"), axis=dim)
+        index = _wrap_const(inputs[2])
+        return _op.transform.take(data, index, axis=dim)
     return _impl

 def _ones():
@@ -126,7 +127,10 @@ def _ones():
         else:
             assert "data type {} could not be parsed in ones op" % (type(data))

-        return _op.full(_expr.const(1), shape, dtype=_convert_data_type(input_types[0]))
+        dtype_map = {6: "float32", 3: "int32"}
+        dtype_id = inputs[1]
+        assert dtype_id in dtype_map, "Unsupported dtype %d" % dtype_id
+        return _op.full(_expr.const(1), shape, dtype=dtype_map[dtype_id])
     return _impl

 def _zeros():
@@ -143,7 +147,10 @@ def _zeros():
         else:
             assert "data type {} could not be parsed in zeros op" % (type(data))

-        return _op.full(_expr.const(0), shape, dtype=_convert_data_type(input_types[0]))
+        dtype_map = {6: "float32", 3: "int32"}
+        dtype_id = inputs[1]
+        assert dtype_id in dtype_map, "Unsupported dtype %d" % dtype_id
+        return _op.full(_expr.const(0), shape, dtype=dtype_map[dtype_id])
     return _impl

 def _relu():
@@ -222,12 +229,10 @@ def _convolution():
         else:
             assert "data type {} could not be parsed in conv op" % (type(weight))

-        # TODO: Add reshape when channel multiplier > 1. Pending PR #4644
         channels = weight_shape[0]
         groups = int(inputs[8])

         if groups > 1:
-            # in torch, groups == in_channels for depth wise conv
             channel_multiplier = channels // groups
             new_weight_shape = (groups, channel_multiplier, weight_shape[2], weight_shape[3])
             weight = _op.transform.reshape(weight, new_weight_shape)
@@ -496,7 +501,7 @@ def _dropout():
     return _impl

 def _reduce(name):
-    def _impl(inputs, attrs, params):
+    def _impl(inputs, input_types):
         data = inputs[0]
         return get_relay_op(name)(data)
     return _impl
@@ -714,7 +719,6 @@ def _upsample(method):
     return _impl

 def _expand_as():
     def _impl(inputs, input_types):
         # TODO: maybe fix this
@@ -724,6 +728,29 @@ def _expand_as():
         return inputs[0]
     return _impl
+def _neg():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        return _op.tensor.negative(data)
+    return _impl
+
+def _tanh():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        return _op.tensor.tanh(data)
+    return _impl
+
+def _Bool():
+    def _impl(inputs, input_types):
+        assert len(inputs) == 1
+        return inputs[0]
+    return _impl
+
+def _Float():
+    def _impl(inputs, input_types):
+        assert len(inputs) == 1
+        return _op.cast(inputs[0], "float32")
+    return _impl

 # Helper functions for operator implementation
@@ -780,6 +807,11 @@ def _convert_elemwise_input(data, input_type):
     else:
         return data

+def _wrap_const(c):
+    if not isinstance(c, _expr.Expr) and not isinstance(c, list):
+        return _expr.const(c)
+    return c
 # Operator mappings
 _convert_map = {
@@ -845,7 +877,16 @@ _convert_map = {
     "aten::detach"              : _identity(),
     "aten::upsample_bilinear2d" : _upsample("bilinear"),
     "aten::upsample_nearest2d"  : _upsample("nearest_neighbor"),
-    "aten::expand_as"           : _expand_as()
+    "aten::expand_as"           : _expand_as(),
+    "aten::lt"                  : _elemwise("less"),
+    "aten::gt"                  : _elemwise("greater"),
+    "aten::le"                  : _elemwise("less_equal"),
+    "aten::ge"                  : _elemwise("greater_equal"),
+    "aten::ne"                  : _elemwise("not_equal"),
+    "aten::Bool"                : _Bool(),
+    "aten::Float"               : _Float(),
+    "aten::neg"                 : _neg(),
+    "aten::tanh"                : _tanh(),
 }

@@ -894,7 +935,8 @@ def _report_missing_conversion(op_names):
     """ Check if all ops in an input graph are supported by TVM """
     known_ops = ["prim::Constant", "prim::GetAttr",
                  "prim::ListConstruct", "prim::ListUnpack",
-                 "prim::TupleConstruct", "prim::TupleUnpack"]
+                 "prim::TupleConstruct", "prim::TupleUnpack",
+                 "prim::If", "prim::Loop"]
     known_ops += list(_convert_map.keys())
     known_ops += list(qnn_torch.convert_map.keys())
@@ -939,9 +981,13 @@ def _get_input_types(op_node):
         input_node_kind = in_ty.kind()
         if input_node_kind == 'TensorType':
             if in_ty.scalarType() is None:
-                input_list_types.append(None)
+                # Tensor's type can be unknown if we use torch.jit.script(...)
+                # Defaults to float for now
+                logging.warning("Untyped Tensor found, assume it is float")
+                input_list_types.append("float")
             else:
                 input_list_types.append(in_ty.scalarType().lower())
         elif input_node_kind == 'ListType':
             input_list_types.append(str(in_ty.getElementType()).lower())
         elif input_node_kind in ['IntType', 'FloatType', 'BoolType',
@@ -1004,15 +1050,10 @@ def _get_operator_nodes(nodes):
     return ops

-def parse_inputs(graph_inputs, input_shapes):
-    """ Return Relay vars from torch input vars """
-    ir_inputs = list(graph_inputs)
-    input_vars = {}
-    for input_name, ir_input in zip(input_shapes, ir_inputs[1:]):
-        input_vars[input_name] = _expr.var(input_name,
-                                           shape=input_shapes[input_name])
-    return input_vars
+def _get_relay_input_vars(input_shapes):
+    """ Return Relay vars from input shapes """
+    return {iname: _expr.var(iname, shape=ishape)
+            for iname, ishape in input_shapes.items()}

 def get_use_chains(root_node, terminate=lambda _: False):
@@ -1055,7 +1096,7 @@ def get_attr_chains(root_getattr_node):
     return get_use_chains(root_getattr_node, terminate)

-def parse_params(graph, state_dict):
+def convert_params(graph, state_dict):
     """
     Return Relay vars and TVM NDArrays for input parameters
     A chain of prim::GetAttr nodes is processed one at a time
@@ -1090,7 +1131,109 @@ def parse_params(graph, state_dict):
     return params, param_tensors, packed_param_map
-def parse_operators(operators, outputs, output_index_map, ret_name):
+def convert_block(block, outputs, output_index_map):
+    """ Translate Torch "Block", used for prim::If and prim::Loop """
+    ops = _get_operator_nodes(block.nodes())
+    ret_names = _get_input_names(block.returnNode())
+    return convert_operators(ops, outputs, output_index_map, ret_names)
+
+def convert_if(if_node, outputs, output_index_map):
+    """ Translate Torch prim::If to Relay If """
+    cond = outputs[output_index_map[if_node.inputsAt(0).debugName()]]
+    blocks = list(if_node.blocks())
+    true_branch = convert_block(blocks[0], outputs, output_index_map)
+    false_branch = convert_block(blocks[1], outputs, output_index_map)
+    assert len(true_branch) == 1 and len(false_branch) == 1
+    return _expr.If(cond, true_branch[0], false_branch[0])
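For reference, a small sketch (with made-up names, not from this patch) of the Relay expression this converter builds: relay.If takes a boolean scalar condition and one expression per branch.

    # sketch of the target Relay construct for prim::If
    from tvm import relay

    x = relay.var("x", shape=(10,), dtype="float32")
    cond = relay.greater(relay.sum(x), relay.const(0.0, "float32"))  # scalar boolean
    true_branch = x + relay.const(1.0, "float32")
    false_branch = x - relay.const(1.0, "float32")
    out = relay.If(cond, true_branch, false_branch)   # each branch yields exactly one value
    func = relay.Function([x], out)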
+def convert_loop(loop_node, outputs, output_index_map):
+    """ Translate Torch prim::Loop to Relay while_loop """
+    def get_input(index):
+        ivalue = loop_node.inputsAt(index)
+        inode = ivalue.node()
+        if inode.kind() == "prim::Constant":
+            return _expr.const(_get_constant(inode))
+        var_name = ivalue.debugName()
+        assert var_name in output_index_map
+        return _wrap_const(outputs[output_index_map[var_name]])
+
+    # Refer to the spec for prim::Loop below
+    # https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/OVERVIEW.md#loops
+    # The first input: %max_trip_count
+    # The second input: %initial_condition
+    # The rest of inputs: loop variables
+    max_loop_count = get_input(0)
+    init_cond = get_input(1)
+    num_loop_var = len(list(loop_node.inputs())) - 2
+    init_vals = [get_input(i + 2) for i in range(num_loop_var)]
+
+    # A while loop always has max_loop_count equal to int64 max
+    # max_loop_count.data (tvm.runtime.NDArray) is -1 here, so call _get_constant again
+    is_while_loop = (isinstance(max_loop_count, _expr.Constant) and
+                     _get_constant(loop_node.inputsAt(0).node()) == sys.maxsize)

+    body_block = list(loop_node.blocks())[0]
+    block_input_names = _get_input_names(body_block)
+
+    def cond(*current_vals):
+        i = current_vals[0]
+        if is_while_loop:
+            return _op.equal(i, _expr.const(True, 'bool'))
+        return _op.less(i, max_loop_count)
+
+    def body(*current_vals):
+        # Update loop variables using the prev iteration outputs
+        assert len(current_vals) == len(block_input_names)
+        for (i, iname) in enumerate(block_input_names):
+            outputs[output_index_map[iname]] = current_vals[i]
+
+        block_outputs = convert_block(body_block, outputs, output_index_map)
+
+        if not is_while_loop:
+            # The iteration variable increment is implicit in Torch, so do it manually
+            # For a while loop, block_outputs[0] is already a boolean,
+            # the result of the termination check
+            incr = _expr.const(1, dtype="int32")
+            block_outputs[0] = current_vals[0] + incr
+
+        return block_outputs
+
+    def get_var(name, val):
+        if isinstance(val, _expr.Constant):
+            return _expr.var(name, shape=val.data.shape, dtype=val.data.dtype)
+        return _expr.var(name)
+
+    if is_while_loop:
+        loop_iter_dtype = "bool"
+        # A while loop with a non-input-dependent condition such as `while i < 10`:
+        # init_cond is int, need to cast to bool to type check
+        if isinstance(init_cond, _expr.Constant):
+            init_cond = _op.cast(init_cond, "bool")
+        init_loop_iter_val = init_cond
+    else:
+        loop_iter_dtype = "int32"
+        # always count from 0
+        init_loop_iter_val = _expr.const(0, dtype="int32")
+
+    name_val_pairs = list(zip(block_input_names,
+                              [init_loop_iter_val] + init_vals))
+    _update_outputs_from_pairs(name_val_pairs, outputs, output_index_map)
+
+    loop_iter_var = _expr.var(block_input_names[0], shape=(),
+                              dtype=loop_iter_dtype)
+    loop_vars = [get_var(name, val) for name, val in name_val_pairs[1:]]
+    loop = while_loop(cond, [loop_iter_var] + loop_vars, body)
+    loop_val = loop(init_loop_iter_val, *init_vals)
+
+    # The first element is a loop counter or boolean condition, ignore it
+    return [_expr.TupleGetItem(loop_val, i + 1) for i in range(num_loop_var)]
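For reference, a small sketch (made-up names) of the target construct: a counting loop built directly with the while_loop helper from tvm.relay.loops, mirroring the for-loop case handled above.

    # sketch of the target Relay construct for prim::Loop
    from tvm import relay
    from tvm.relay.loops import while_loop

    i = relay.var("i", shape=(), dtype="int32")            # loop counter, like loop_iter_var
    acc = relay.var("acc", shape=(10,), dtype="float32")   # one loop variable

    def cond(i, acc):
        return relay.less(i, relay.const(5, "int32"))      # iterate while i < 5

    def body(i, acc):
        return [i + relay.const(1, "int32"), acc + acc]    # bump counter, update loop var

    loop = while_loop(cond, [i, acc], body)
    init = relay.var("x", shape=(10,), dtype="float32")
    final = loop(relay.const(0, "int32"), init)            # tuple of (counter, acc) after the loop
    result = relay.TupleGetItem(final, 1)                  # drop the counter, as convert_loop does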
+def convert_operators(operators, outputs, output_index_map, ret_names):
     """ Convert each Torch IR operator to its Relay equivalent """
     for node_name, op_node in operators:
         operator = op_node.kind()
@@ -1110,17 +1253,35 @@ def parse_operators(operators, outputs, output_index_map, ret_name):
             unpacked_names = _get_output_names(op_node)
             _update_outputs_from_pairs(zip(unpacked_names, inputs[0]),
                                        outputs, output_index_map)
+        elif operator == "prim::If":
+            if_out = convert_if(op_node, outputs, output_index_map)
+            output_index_map[node_name] = len(outputs)
+            outputs.append(if_out)
+        elif operator == "prim::Loop":
+            loop_out = convert_loop(op_node, outputs, output_index_map)
+            unpacked_names = _get_output_names(op_node)
+            assert len(loop_out) == len(unpacked_names)
+            _update_outputs_from_pairs(zip(unpacked_names, loop_out),
+                                       outputs, output_index_map)
         else:
             output_index_map[node_name] = len(outputs)
             relay_op = _convert_map[operator]
             outputs.append(relay_op(inputs, _get_input_types(op_node)))

-    return outputs[output_index_map[ret_name]]
+    return [_wrap_const(outputs[output_index_map[ret_name]])
+            for ret_name in ret_names]
 def get_all_op_names(graph):
     """ Return all operator names in the input graph """
-    return set(node.kind() for node in graph.nodes())
+    nodes = list(graph.nodes())
+    prim_with_blocks = ["prim::If", "prim::Loop"]
+    for prim in prim_with_blocks:
+        prim_nodes = graph.findAllNodes(prim, recurse=True)
+        for prim_node in prim_nodes:
+            for block in prim_node.blocks():
+                nodes += block.nodes()
+    return set(node.kind() for node in nodes)
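To illustrate why this recursion is needed (a hedged sketch; the model is made up): operators nested inside prim::If / prim::Loop blocks do not appear in graph.nodes(), so they have to be collected from the blocks of those nodes.

    # sketch: aten ops inside a loop body are only reachable through the block
    import torch

    class LoopBody(torch.nn.Module):
        def forward(self, x):
            for i in range(x.size(0)):   # becomes prim::Loop with a nested block
                x = torch.tanh(x)        # aten::tanh lives inside that block
            return x

    graph = torch.jit.script(LoopBody()).graph
    top_level = set(node.kind() for node in graph.nodes())        # no aten::tanh here
    nested = set(node.kind()
                 for loop in graph.findAllNodes("prim::Loop", recurse=True)
                 for block in loop.blocks()
                 for node in block.nodes())
    print("aten::tanh" in top_level, "aten::tanh" in nested)      # False True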
 def get_graph_input_names(script_module):
@@ -1167,14 +1328,14 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None):
     _check_input_names(script_module, input_shapes)

     params = script_module.state_dict()
-    input_vars = parse_inputs(graph.inputs(), input_shapes)
-    param_vars, tensors, packed_param_map = parse_params(graph, params)
+    input_vars = _get_relay_input_vars(input_shapes)
+    param_vars, tensors, packed_param_map = convert_params(graph, params)
     tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}

     input_vars.update(param_vars)
     outputs = list(input_vars.values())
     output_index_map = dict(zip(input_vars.keys(), range(len(outputs))))
-    ret_name = _get_input_names(graph.return_node())[0]
+    ret_name = _get_input_names(graph.return_node())

     # For quantized models
     if "aten::quantize_per_tensor" in op_names:
@@ -1186,8 +1347,8 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None):
         qnn_torch.add_quant_params(tvm_params, weight_quant_params)
         _convert_map.update(qnn_torch.convert_map)

-    body = parse_operators(_get_operator_nodes(graph.nodes()), outputs,
-                           output_index_map, ret_name)
-    func = tvm.relay.Function(_analysis.free_vars(body), body)
+    ret = convert_operators(_get_operator_nodes(graph.nodes()), outputs,
+                            output_index_map, ret_name)
+    func = tvm.relay.Function(_analysis.free_vars(ret[0]), ret[0])
     return _module.IRModule.from_expr(func), tvm_params
@@ -347,7 +347,8 @@ def test_quantized_imagenet():
     qmodels += [
         ("resnet18", qresnet.resnet18(pretrained=True), per_channel),
         ("mobilenet_v2", qmobilenet.mobilenet_v2(pretrained=True), per_channel),
-        ("inception_v3", qinception.inception_v3(pretrained=True), per_channel),
+        # disable inception test for now, since loading it takes ~5min on torchvision-0.5
+        #("inception_v3", qinception.inception_v3(pretrained=True), per_channel),
         ("googlenet", qgooglenet(pretrained=True), per_channel),
     ]
@@ -756,7 +756,6 @@ def test_vgg11_bn():
     verify_model("vgg11_bn")
 """

 def test_custom_conversion_map():
     def get_roi_align():
         pool_size = 5
@@ -801,11 +800,193 @@ def test_segmentaton_models():
     inp = [torch.rand((1, 3, 300, 300), dtype=torch.float)]

-    for model in [fcn, deeplab]:
-        # depthwise + dilated convolution not supported on x86
-        # see https://github.com/apache/incubator-tvm/issues/4962
-        verify_model(SegmentationModelWrapper(model.eval()), inp,
-                     ctx_list=[("cuda", tvm.gpu(0))])
+    verify_model(SegmentationModelWrapper(fcn.eval()), inp)
+
+    # depthwise + dilated convolution not supported on x86
+    # see https://github.com/apache/incubator-tvm/issues/4962
+    cuda_ctx = ("cuda", tvm.gpu(0))
+    if cuda_ctx[1].exist:
+        verify_model(SegmentationModelWrapper(deeplab.eval()), inp, [cuda_ctx])
+def verify_script_model(pt_model, ishapes):
+    script_module = torch.jit.script(pt_model)
+    input_names = get_graph_input_names(script_module)
+    input_shapes = dict(zip(input_names, ishapes))
+
+    inputs = [torch.randn(input_shapes[input_name], dtype=torch.float)
+              for input_name in input_names]
+
+    mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
+
+    executor = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0),
+                                     target="llvm")
+    evaluator = executor.evaluate()
+
+    for name, inp in zip(input_names, inputs):
+        params[name] = inp.numpy()
+
+    op_res = evaluator(**params)
+
+    with torch.no_grad():
+        pt_result = pt_model(*inputs)
+
+    if not isinstance(pt_result, torch.Tensor):
+        tvm_res = op_res.asnumpy().item()
+        assert pt_result == tvm_res
+    else:
+        tvm.testing.assert_allclose(op_res.asnumpy(), pt_result.numpy(),
+                                    rtol=1e-5, atol=1e-5)
+def test_control_flow():
+    class SimpleIf(torch.nn.Module):
+        def __init__(self, N, M):
+            super().__init__()
+            self.weight = torch.nn.Parameter(torch.rand(N, M))
+
+        def forward(self, inp):
+            if inp.sum() > 0.:
+                output = self.weight + inp
+            else:
+                output = self.weight - inp
+            return output
+
+    class NestedIf(torch.nn.Module):
+        def __init__(self, N, M):
+            super().__init__()
+            self.weight = torch.nn.Parameter(torch.rand(N, M))
+
+        def forward(self, inp):
+            if inp.sum() > 0.:
+                if inp.mean() > 0.:
+                    output = self.weight + inp
+                else:
+                    output = self.weight - inp
+            else:
+                if inp.mean() >= 0.:
+                    output = self.weight * inp
+                else:
+                    output = self.weight / inp
+            return output
+
+    class ScalarLoop(torch.nn.Module):
+        def forward(self, inp):
+            a = 0
+            for i in range(inp.size(0)):
+                b = i * i
+                b = b + 1
+                a += b
+            if a != 0:
+                a += 1
+            else:
+                a += 2
+            return a
+
+    class SimpleLoop(torch.nn.Module):
+        def forward(self, inp):
+            a = inp
+            for i in range(inp.size(0)):
+                b = a * 2.
+                c = a + b
+                a += c
+            return a
+
+    class LoopWithIf(torch.nn.Module):
+        def forward(self, inp):
+            a = inp
+            for i in range(inp.size(0)):
+                b = a * 2.
+                b = a + b
+                if b.sum() > 0.0:
+                    a += b
+                else:
+                    a -= b
+            return a
+
+    class NestedLoop(torch.nn.Module):
+        def forward(self, inp):
+            a = inp
+            for i in range(inp.size(0)):
+                b = a * float(i)
+                for j in range(inp.size(1)):
+                    a += b * float(j)
+            return a
+
+    class SimpleScalarWhileLoop(torch.nn.Module):
+        def forward(self, inp):
+            a = 1
+            i = 0
+            while i <= inp.size(0):
+                a += i
+                i += 2
+
+            i = 0
+            # also test constant init cond
+            while i < 10:
+                a += i
+                i += 3
+            return a
+
+    class SimpleWhileLoop(torch.nn.Module):
+        def forward(self, inp):
+            a = inp
+            i = 0
+            while i < inp.size(0):
+                a += a * float(i) * 2.0
+                i += 1
+            return a
+
+    models = [
+        SimpleIf(10, 20),
+        NestedIf(10, 20),
+        ScalarLoop(),
+        SimpleLoop(),
+        LoopWithIf(),
+        SimpleScalarWhileLoop(),
+        SimpleWhileLoop(),
+        NestedLoop(),
+    ]
+
+    for pt_model in models:
+        verify_script_model(pt_model.eval(), [(10, 20)])
+def test_simple_rnn():
+    # The mixed tracing and scripting example from
+    # https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html#mixing-scripting-and-tracing
+    class DecisionGate(torch.nn.Module):
+        def forward(self, x):
+            if x.sum() > 0:
+                return x
+            else:
+                return -x
+
+    class Cell(torch.nn.Module):
+        def __init__(self, dg):
+            super(Cell, self).__init__()
+            self.dg = dg
+            self.linear = torch.nn.Linear(4, 4)
+
+        def forward(self, x, h):
+            new_h = torch.tanh(self.dg(self.linear(x)) + h)
+            return new_h, new_h
+
+    class RNNLoop(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            x = torch.rand(10, 4, dtype=torch.float)
+            h = torch.rand(10, 4, dtype=torch.float)
+            self.cell = torch.jit.trace(Cell(DecisionGate()), (x, h))
+
+        def forward(self, xs):
+            h = torch.zeros(10, 4, dtype=torch.float)
+            y = torch.zeros(10, 4, dtype=torch.float)
+            for i in range(xs.size(0)):
+                y, h = self.cell(xs[i], h)
+            return y
+
+    verify_script_model(RNNLoop().eval(), [(10, 10, 4)])
 if __name__ == "__main__":
@@ -860,3 +1041,7 @@ if __name__ == "__main__":
     test_quantized_modules()
     test_quantized_imagenet()
+
+    # Test simple conditionals and loop
+    test_control_flow()
+    test_simple_rnn()