Commit e986f87e by Animesh Jain, committed by Tianqi Chen

Adding source types to C++ reduce functions (#1771)

parent 846d9ce0
@@ -9,7 +9,7 @@
 namespace tvm {
 
 Expr sum(Expr source, Array<IterVar> rdom) {
-  Var x("x"), y("y");
+  Var x("x", source.type()), y("y", source.type());
   Expr result = ir::Add::make(x, y);
   Expr identity_element = make_zero(source.type());
   ir::CommReducer combiner =
@@ -18,7 +18,7 @@ Expr sum(Expr source, Array<IterVar> rdom) {
 }
 
 Expr max(Expr source, Array<IterVar> rdom) {
-  Var x("x"), y("y");
+  Var x("x", source.type()), y("y", source.type());
   Expr result = ir::Max::make(x, y);
   Expr identity_element = source.type().min();
   ir::CommReducer combiner =
@@ -27,7 +27,7 @@ Expr max(Expr source, Array<IterVar> rdom) {
 }
 
 Expr min(Expr source, Array<IterVar> rdom) {
-  Var x("x"), y("y");
+  Var x("x", source.type()), y("y", source.type());
   Expr result = ir::Min::make(x, y);
   Expr identity_element = source.type().max();
   ir::CommReducer combiner =
@@ -36,7 +36,7 @@ Expr min(Expr source, Array<IterVar> rdom) {
 }
 
 Expr prod(Expr source, Array<IterVar> rdom) {
-  Var x("x"), y("y");
+  Var x("x", source.type()), y("y", source.type());
   Expr result = ir::Mul::make(x, y);
   Expr identity_element = make_one(source.type());
   ir::CommReducer combiner =
...
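Why the typed vars matter: `Var x("x")` defaults to `int32`, so for a `float32` source the combiner's operands disagreed with the `float32` reduce body and identity element. The Python API already types the combiner from the reduced expression via `tvm.comm_reducer`; the sketch below is an illustrative Python analogue of the patched C++ helpers, not code from this commit, assuming the contemporaneous TVM API (`mysum` is a hypothetical name):

```python
import tvm

# Illustrative analogue of the patched C++ sum() helper: the combiner's
# placeholder vars (x, y) must carry the source dtype. tvm.comm_reducer
# types them from the reduced expression, which is what this commit makes
# the C++ helpers do via Var("x", source.type()).
mysum = tvm.comm_reducer(lambda x, y: x + y,
                         lambda t: tvm.const(0, dtype=t),
                         name='mysum')

n = tvm.var('n')
A = tvm.placeholder((n,), name='A', dtype='float32')  # non-int32 source
k = tvm.reduce_axis((0, n), name='k')
B = tvm.compute((1,), lambda i: mysum(A[k], axis=k), name='B')

s = tvm.create_schedule(B.op)
print(tvm.lower(s, [A, B], simple_mode=True))
```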
@@ -1,4 +1,5 @@
 import tvm
+from topi.nn.pooling import pool
 
 def test_tensor():
     m = tvm.var('m')
@@ -185,6 +186,34 @@ def test_tensor_inputs():
     assert tuple(y.op.input_tensors) == (x,)
 
+def test_tensor_pool():
+    def intrin_pool():
+        A = tvm.placeholder((64, 16, 16), name='A')
+        kh = tvm.reduce_axis((0, 3), name='kh')
+        kw = tvm.reduce_axis((0, 3), name='kw')
+        P = tvm.compute((64, 14, 14),
+                        lambda c, oh, ow: tvm.max(A[c, oh + kh, ow + kw],
+                                                  axis=[kh, kw]),
+                        name='p')
+
+        def intrin_func(ins, outs):
+            dinp = ins[0]
+            dout = outs[0]
+            return tvm.call_packed("op", dinp, dout)
+
+        with tvm.build_config(offset_factor=1):
+            return tvm.decl_tensor_intrin(P.op, intrin_func)
+
+    A = tvm.placeholder((1, 64, 16, 16), name='A')
+    P = pool(data=A, kernel=(3, 3), stride=(1, 1), padding=(0, 0, 0, 0),
+             pool_type='max')
+    s = tvm.create_schedule(P.op)
+    _, oh, _, _ = P.op.axis
+    intrin = intrin_pool()
+    s[P].tensorize(oh, intrin)
+    tvm.lower(s, [A, P])
+
 if __name__ == "__main__":
     test_rank_zero()
     test_tensor_inputs()
@@ -199,3 +228,4 @@ if __name__ == "__main__":
     test_extern_multi_out()
     test_tuple_inputs()
     test_tuple_with_different_deps()
+    test_tensor_pool()
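The new test presumably guards the fix end-to-end: `topi.nn.pooling.pool` lowers through the C++ `max` helper, and `tensorize` structurally matches the intrinsic's `CommReducer` against the compute's, so `int32` combiner vars on a `float32` source would break the match. A hypothetical probe of the combiner's dtype, assuming the node attribute names visible in this diff (`Reduce.combiner`, `CommReducer.lhs`, `Var.dtype`):

```python
import tvm
from topi.nn.pooling import pool

# Hypothetical probe: after this commit the max-pool's reduce combiner
# should carry the source dtype instead of a default int32.
A = tvm.placeholder((1, 64, 16, 16), name='A')  # float32 by default
P = pool(data=A, kernel=(3, 3), stride=(1, 1),
         padding=(0, 0, 0, 0), pool_type='max')
red = P.op.body[0]                     # the Reduce expression
print(red.combiner.lhs[0].dtype)       # expected: 'float32'
```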