Commit 55599b93 by Siju Committed by Yizhi Liu

[Relay]resize op compute and schedule (#2172)

parent 7af48f1a
...@@ -2,3 +2,4 @@ ...@@ -2,3 +2,4 @@
"""Image network related operators.""" """Image network related operators."""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from .image import * from .image import *
from ._image import *
#pylint: disable=invalid-name, unused-argument
"""Backend compiler related feature registration"""
from __future__ import absolute_import
from ..op import register_schedule, schedule_injective
# resize
register_schedule("image.resize", schedule_injective)
...@@ -5,7 +5,10 @@ ...@@ -5,7 +5,10 @@
*/ */
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/attrs/image.h> #include <tvm/relay/attrs/image.h>
#include <topi/elemwise.h>
#include <topi/image/resize.h>
#include "../layout.h" #include "../layout.h"
#include "../op_common.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -40,6 +43,29 @@ bool ResizeRel(const Array<Type>& types, ...@@ -40,6 +43,29 @@ bool ResizeRel(const Array<Type>& types,
return true; return true;
} }
Array<Tensor> ResizeCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
const auto* param = attrs.as<ResizeAttrs>();
CHECK(param != nullptr);
CHECK(param->layout == "NCHW" || param->layout == "NHWC");
const auto* out_ttype = out_type.as<TensorTypeNode>();
CHECK(out_ttype != nullptr);
Array<IndexExpr> oshape;
if (param->layout == "NCHW") {
oshape.push_back(out_ttype->shape[2]);
oshape.push_back(out_ttype->shape[3]);
} else if (param->layout == "NHWC") {
oshape.push_back(out_ttype->shape[1]);
oshape.push_back(out_ttype->shape[2]);
}
return Array<Tensor>{ topi::image::resize(inputs[0],
oshape,
param->layout,
param->align_corners,
param->method) };
}
// Positional relay function to create image operator // Positional relay function to create image operator
// used by frontend FFI. // used by frontend FFI.
...@@ -82,7 +108,9 @@ RELAY_REGISTER_OP("image.resize") ...@@ -82,7 +108,9 @@ RELAY_REGISTER_OP("image.resize")
.set_num_inputs(1) .set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.") .add_argument("data", "Tensor", "The input tensor.")
.set_support_level(5) .set_support_level(5)
.add_type_rel("Resize", ResizeRel); .add_type_rel("Resize", ResizeRel)
.set_attr<FTVMCompute>("FTVMCompute", ResizeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
""" Support level5 operator test cases. """ Support level5 operator test cases.
""" """
import numpy as np
import tvm import tvm
from tvm import relay from tvm import relay
from tvm.relay.testing import ctx_list
import topi.testing
def test_resize_infer_type(): def test_resize_infer_type():
n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
...@@ -17,6 +20,33 @@ def test_resize_infer_type(): ...@@ -17,6 +20,33 @@ def test_resize_infer_type():
zz = relay.ir_pass.infer_type(z) zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.TensorType((n, c, 100, 200), "int8") assert zz.checked_type == relay.TensorType((n, c, 100, 200), "int8")
def test_resize():
def verify_resize(dshape, scale, method, layout):
if layout == "NHWC":
size = (dshape[1] * scale, dshape[2] * scale)
else:
size = (dshape[2] * scale, dshape[3] * scale)
x_data = np.random.uniform(size=dshape).astype("float32")
if method == "BILINEAR":
ref_res = topi.testing.bilinear_resize_python(x_data, size, layout)
else:
ref_res = topi.testing.upsampling_python(x_data, scale, layout)
x = relay.var("x", relay.TensorType(dshape, "float32"))
z = relay.image.resize(x, size, layout, method, False)
assert "size=" in z.astext()
zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.TensorType(ref_res.shape, "float32")
func = relay.Function([x], z)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
for method in ["BILINEAR", "NEAREST_NEIGHBOR"]:
for layout in ["NHWC", "NCHW"]:
verify_resize((1, 4, 4, 4), 2, method, layout)
def test_multibox_prior(): def test_multibox_prior():
sizes = (0.3, 1.5, 0.7) sizes = (0.3, 1.5, 0.7)
...@@ -74,5 +104,6 @@ def test_nms(): ...@@ -74,5 +104,6 @@ def test_nms():
if __name__ == "__main__": if __name__ == "__main__":
test_resize_infer_type() test_resize_infer_type()
test_resize()
test_multibox_prior() test_multibox_prior()
test_nms() test_nms()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment