Unverified Commit b64a843a by masahi Committed by GitHub

[Torch, QNN] Add missing upcast to uint8 avg_pool conversion (#5089)

* add missing upcast to avgpool

* add avg pool test
parent e5c24d7e
...@@ -172,7 +172,7 @@ def _adaptive_avg_2d(): ...@@ -172,7 +172,7 @@ def _adaptive_avg_2d():
return _op.nn.adaptive_avg_pool2d(x, output_size=output_size) return _op.nn.adaptive_avg_pool2d(x, output_size=output_size)
if input_types[0] == "quint8": if input_types[0] == "quint8":
return qnn_torch.quantized_adaptive_avg_2d(data, func) return qnn_torch.apply_with_upcast(data, func)
return func(data) return func(data)
...@@ -484,14 +484,22 @@ def _avg_pool2d(): ...@@ -484,14 +484,22 @@ def _avg_pool2d():
ceil_mode = int(inputs[4]) ceil_mode = int(inputs[4])
count_include_pad = int(inputs[5]) count_include_pad = int(inputs[5])
return _op.nn.avg_pool2d(data, def func(x):
return _op.nn.avg_pool2d(x,
pool_size=pool_size, pool_size=pool_size,
strides=strides, strides=strides,
padding=padding, padding=padding,
ceil_mode=ceil_mode, ceil_mode=ceil_mode,
count_include_pad=count_include_pad) count_include_pad=count_include_pad)
if input_types[0] == "quint8":
return qnn_torch.apply_with_upcast(data, func)
return func(data)
return _impl return _impl
def _dropout(): def _dropout():
def _impl(inputs, input_types): def _impl(inputs, input_types):
data = inputs[0] data = inputs[0]
......
...@@ -359,10 +359,9 @@ def add_quant_params(params, quant_params): ...@@ -359,10 +359,9 @@ def add_quant_params(params, quant_params):
params[qparam.bias_var.name_hint] = tvm.nd.array(qparam.bias) params[qparam.bias_var.name_hint] = tvm.nd.array(qparam.bias)
def quantized_adaptive_avg_2d(data, func_fp32): def apply_with_upcast(data, func):
# this follows tflite impl
inp = _op.cast(data, dtype="int32") inp = _op.cast(data, dtype="int32")
out = func_fp32(inp) out = func(inp)
return _op.cast(out, "uint8") return _op.cast(out, "uint8")
......
...@@ -218,7 +218,6 @@ class MulScalarNegative(nn.Module): ...@@ -218,7 +218,6 @@ class MulScalarNegative(nn.Module):
class UpsamplingBilinear(nn.Module): class UpsamplingBilinear(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.relu = QuantWrapper(nn.ReLU())
self.quant = QuantStub() self.quant = QuantStub()
self.dequant = DeQuantStub() self.dequant = DeQuantStub()
...@@ -233,12 +232,25 @@ class UpsamplingBilinear(nn.Module): ...@@ -233,12 +232,25 @@ class UpsamplingBilinear(nn.Module):
pass pass
class AvgPool2d(nn.Module):
def __init__(self):
super().__init__()
self.pool = QuantWrapper(nn.AvgPool2d(kernel_size=2))
def forward(self, x):
return self.pool(x)
def fuse_model(self):
pass
def test_quantized_modules(): def test_quantized_modules():
imagenet_ishape = (1, 3, 224, 224) imagenet_ishape = (1, 3, 224, 224)
qmodules = [ qmodules = [
("relu", imagenet_ishape, ReLU(), False), ("relu", imagenet_ishape, ReLU(), False),
("upsample bilinear", (1, 3, 64, 64), UpsamplingBilinear(), False), ("upsample bilinear", (1, 3, 64, 64), UpsamplingBilinear(), False),
("avgpool", imagenet_ishape, AvgPool2d(), False),
] ]
for per_channel in [False, True]: for per_channel in [False, True]:
...@@ -276,7 +288,6 @@ def test_quantized_modules(): ...@@ -276,7 +288,6 @@ def test_quantized_modules():
pt_result = script_module(inp.clone()).numpy() pt_result = script_module(inp.clone()).numpy()
input_name = get_graph_input_names(script_module)[0] input_name = get_graph_input_names(script_module)[0]
runtime = get_tvm_runtime(script_module, input_name, ishape) runtime = get_tvm_runtime(script_module, input_name, ishape)
runtime.set_input(input_name, inp.numpy().copy()) runtime.set_input(input_name, inp.numpy().copy())
runtime.run() runtime.run()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment