Unverified Commit b64a843a by masahi Committed by GitHub

[Torch, QNN] Add missing upcast to uint8 avg_pool conversion (#5089)

* add missing upcast to avgpool

* add avg pool test
parent e5c24d7e
......@@ -172,7 +172,7 @@ def _adaptive_avg_2d():
return _op.nn.adaptive_avg_pool2d(x, output_size=output_size)
if input_types[0] == "quint8":
return qnn_torch.quantized_adaptive_avg_2d(data, func)
return qnn_torch.apply_with_upcast(data, func)
return func(data)
......@@ -484,14 +484,22 @@ def _avg_pool2d():
ceil_mode = int(inputs[4])
count_include_pad = int(inputs[5])
return _op.nn.avg_pool2d(data,
pool_size=pool_size,
strides=strides,
padding=padding,
ceil_mode=ceil_mode,
count_include_pad=count_include_pad)
def func(x):
return _op.nn.avg_pool2d(x,
pool_size=pool_size,
strides=strides,
padding=padding,
ceil_mode=ceil_mode,
count_include_pad=count_include_pad)
if input_types[0] == "quint8":
return qnn_torch.apply_with_upcast(data, func)
return func(data)
return _impl
def _dropout():
def _impl(inputs, input_types):
data = inputs[0]
......
......@@ -359,10 +359,9 @@ def add_quant_params(params, quant_params):
params[qparam.bias_var.name_hint] = tvm.nd.array(qparam.bias)
def quantized_adaptive_avg_2d(data, func_fp32):
# this follows tflite impl
def apply_with_upcast(data, func):
inp = _op.cast(data, dtype="int32")
out = func_fp32(inp)
out = func(inp)
return _op.cast(out, "uint8")
......
......@@ -218,7 +218,6 @@ class MulScalarNegative(nn.Module):
class UpsamplingBilinear(nn.Module):
def __init__(self):
super().__init__()
self.relu = QuantWrapper(nn.ReLU())
self.quant = QuantStub()
self.dequant = DeQuantStub()
......@@ -233,12 +232,25 @@ class UpsamplingBilinear(nn.Module):
pass
class AvgPool2d(nn.Module):
def __init__(self):
super().__init__()
self.pool = QuantWrapper(nn.AvgPool2d(kernel_size=2))
def forward(self, x):
return self.pool(x)
def fuse_model(self):
pass
def test_quantized_modules():
imagenet_ishape = (1, 3, 224, 224)
qmodules = [
("relu", imagenet_ishape, ReLU(), False),
("upsample bilinear", (1, 3, 64, 64), UpsamplingBilinear(), False),
("avgpool", imagenet_ishape, AvgPool2d(), False),
]
for per_channel in [False, True]:
......@@ -276,7 +288,6 @@ def test_quantized_modules():
pt_result = script_module(inp.clone()).numpy()
input_name = get_graph_input_names(script_module)[0]
runtime = get_tvm_runtime(script_module, input_name, ishape)
runtime.set_input(input_name, inp.numpy().copy())
runtime.run()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment