Unverified Commit 06e9542e by masahi Committed by GitHub

[Torch] Add initial control flow support (#4964)

* Add support for prim::If and prim::Loop with test cases

* rebase and fix tests

* add some comments

* simplifying, fix float cast

* parse -> convert

* recursivly retrive ops in get_all_op_names

* use multiple return values from block correctly, simplify loop convert

* choose dtype properly for zeros and ones

* simplifying, replace convert_inputs with _get_relay_input_vars

* fix for while loop with non input dependent init cond

* add assert on loop var update

* move the condition around

* better testing for seg models

* rebase fix, disable inception v3 in quant test as it is too slow to
load with torch-1.4 + torchvision 0.5

* simplify and add more comparison op converter
parent c0bc1882
......@@ -347,7 +347,8 @@ def test_quantized_imagenet():
qmodels += [
("resnet18", qresnet.resnet18(pretrained=True), per_channel),
("mobilenet_v2", qmobilenet.mobilenet_v2(pretrained=True), per_channel),
("inception_v3", qinception.inception_v3(pretrained=True), per_channel),
# disable inception test for now, since loading it takes ~5min on torchvision-0.5
#("inception_v3", qinception.inception_v3(pretrained=True), per_channel),
("googlenet", qgooglenet(pretrained=True), per_channel),
]
......
......@@ -756,7 +756,6 @@ def test_vgg11_bn():
verify_model("vgg11_bn")
"""
def test_custom_conversion_map():
def get_roi_align():
pool_size = 5
......@@ -801,11 +800,193 @@ def test_segmentaton_models():
inp = [torch.rand((1, 3, 300, 300), dtype=torch.float)]
for model in [fcn, deeplab]:
# depthwise + dilated covolution not supported on x86
# see https://github.com/apache/incubator-tvm/issues/4962
verify_model(SegmentationModelWrapper(model.eval()), inp,
ctx_list=[("cuda", tvm.gpu(0))])
verify_model(SegmentationModelWrapper(fcn.eval()), inp)
# depthwise + dilated covolution not supported on x86
# see https://github.com/apache/incubator-tvm/issues/4962
cuda_ctx = ("cuda", tvm.gpu(0))
if cuda_ctx[1].exist:
verify_model(SegmentationModelWrapper(deeplab.eval()), inp, [cuda_ctx])
def verify_script_model(pt_model, ishapes):
script_module = torch.jit.script(pt_model)
input_names = get_graph_input_names(script_module)
input_shapes = dict(zip(input_names, ishapes))
inputs = [torch.randn(input_shapes[input_name], dtype=torch.float)
for input_name in input_names]
mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
executor = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0),
target="llvm")
evaluator = executor.evaluate()
for name, inp in zip(input_names, inputs):
params[name] = inp.numpy()
op_res = evaluator(**params)
with torch.no_grad():
pt_result = pt_model(*inputs)
if not isinstance(pt_result, torch.Tensor):
tvm_res = op_res.asnumpy().item()
assert pt_result == tvm_res
else:
tvm.testing.assert_allclose(op_res.asnumpy(), pt_result.numpy(),
rtol=1e-5, atol=1e-5)
def test_control_flow():
class SimpleIf(torch.nn.Module):
def __init__(self, N, M):
super().__init__()
self.weight = torch.nn.Parameter(torch.rand(N, M))
def forward(self, inp):
if inp.sum() > 0.:
output = self.weight + inp
else:
output = self.weight - inp
return output
class NestedIf(torch.nn.Module):
def __init__(self, N, M):
super().__init__()
self.weight = torch.nn.Parameter(torch.rand(N, M))
def forward(self, inp):
if inp.sum() > 0.:
if inp.mean() > 0.:
output = self.weight + inp
else:
output = self.weight - inp
else:
if inp.mean() >= 0.:
output = self.weight * inp
else:
output = self.weight / inp
return output
class ScalarLoop(torch.nn.Module):
def forward(self, inp):
a = 0
for i in range(inp.size(0)):
b = i * i
b = b + 1
a += b
if a != 0:
a += 1
else:
a += 2
return a
class SimpleLoop(torch.nn.Module):
def forward(self, inp):
a = inp
for i in range(inp.size(0)):
b = a * 2.
c = a + b
a += c
return a
class LoopWithIf(torch.nn.Module):
def forward(self, inp):
a = inp
for i in range(inp.size(0)):
b = a * 2.
b = a + b
if b.sum() > 0.0:
a += b
else:
a -= b
return a
class NestedLoop(torch.nn.Module):
def forward(self, inp):
a = inp
for i in range(inp.size(0)):
b = a * float(i)
for j in range(inp.size(1)):
a += b * float(j)
return a
class SimpleScalarWhileLoop(torch.nn.Module):
def forward(self, inp):
a = 1
i = 0
while i <= inp.size(0):
a += i
i += 2
i = 0
# also test constant init cond
while i < 10:
a += i
i += 3
return a
class SimpleWhileLoop(torch.nn.Module):
def forward(self, inp):
a = inp
i = 0
while i < inp.size(0):
a += a * float(i) * 2.0
i += 1
return a
models = [
SimpleIf(10, 20),
NestedIf(10, 20),
ScalarLoop(),
SimpleLoop(),
LoopWithIf(),
SimpleScalarWhileLoop(),
SimpleWhileLoop(),
NestedLoop(),
]
for pt_model in models:
verify_script_model(pt_model.eval(), [(10, 20)])
def test_simple_rnn():
# The mixed tracing and scripting example from
# https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html#mixing-scripting-and-tracing
class DecisionGate(torch.nn.Module):
def forward(self, x):
if x.sum() > 0:
return x
else:
return -x
class Cell(torch.nn.Module):
def __init__(self, dg):
super(Cell, self).__init__()
self.dg = dg
self.linear = torch.nn.Linear(4, 4)
def forward(self, x, h):
new_h = torch.tanh(self.dg(self.linear(x)) + h)
return new_h, new_h
class RNNLoop(torch.nn.Module):
def __init__(self):
super().__init__()
x = torch.rand(10, 4, dtype=torch.float)
h = torch.rand(10, 4, dtype=torch.float)
self.cell = torch.jit.trace(Cell(DecisionGate()), (x, h))
def forward(self, xs):
h = torch.zeros(10, 4, dtype=torch.float)
y = torch.zeros(10, 4, dtype=torch.float)
for i in range(xs.size(0)):
y, h = self.cell(xs[i], h)
return y
verify_script_model(RNNLoop().eval(), [(10, 10, 4)])
if __name__ == "__main__":
......@@ -860,3 +1041,7 @@ if __name__ == "__main__":
test_quantized_modules()
test_quantized_imagenet()
# Test simple conditionals and loop
test_control_flow()
test_simple_rnn()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment