Commit e986f87e by Animesh Jain Committed by Tianqi Chen

Adding source types to C++ reduce functions (#1771)

parent 846d9ce0
......@@ -9,7 +9,7 @@
namespace tvm {
Expr sum(Expr source, Array<IterVar> rdom) {
Var x("x"), y("y");
Var x("x", source.type()), y("y", source.type());
Expr result = ir::Add::make(x, y);
Expr identity_element = make_zero(source.type());
ir::CommReducer combiner =
......@@ -18,7 +18,7 @@ Expr sum(Expr source, Array<IterVar> rdom) {
}
Expr max(Expr source, Array<IterVar> rdom) {
Var x("x"), y("y");
Var x("x", source.type()), y("y", source.type());
Expr result = ir::Max::make(x, y);
Expr identity_element = source.type().min();
ir::CommReducer combiner =
......@@ -27,7 +27,7 @@ Expr max(Expr source, Array<IterVar> rdom) {
}
Expr min(Expr source, Array<IterVar> rdom) {
Var x("x"), y("y");
Var x("x", source.type()), y("y", source.type());
Expr result = ir::Min::make(x, y);
Expr identity_element = source.type().max();
ir::CommReducer combiner =
......@@ -36,7 +36,7 @@ Expr min(Expr source, Array<IterVar> rdom) {
}
Expr prod(Expr source, Array<IterVar> rdom) {
Var x("x"), y("y");
Var x("x", source.type()), y("y", source.type());
Expr result = ir::Mul::make(x, y);
Expr identity_element = make_one(source.type());
ir::CommReducer combiner =
......
import tvm
from topi.nn.pooling import pool
def test_tensor():
m = tvm.var('m')
......@@ -185,6 +186,34 @@ def test_tensor_inputs():
assert tuple(y.op.input_tensors) == (x,)
def test_tensor_pool():
def intrin_pool():
A = tvm.placeholder((64, 16, 16), name='A')
kh = tvm.reduce_axis((0, 3), name='kh')
kw = tvm.reduce_axis((0, 3), name='kw')
P = tvm.compute((64, 14, 14),
lambda c, oh, ow: tvm.max(A[c, oh + kh, ow + kw],
axis=[kh, kw]),
name='p')
def intrin_func(ins, outs):
dinp = ins[0]
dout = outs[0]
return tvm.call_packed("op", dinp, dout)
with tvm.build_config(offset_factor=1):
return tvm.decl_tensor_intrin(P.op, intrin_func)
A = tvm.placeholder((1, 64, 16, 16), name='A')
P = pool(data=A, kernel=(3, 3), stride=(1, 1), padding=(0, 0, 0, 0),
pool_type='max')
s = tvm.create_schedule(P.op)
_, oh, _, _ = P.op.axis
intrin = intrin_pool()
s[P].tensorize(oh, intrin)
tvm.lower(s, [A, P])
if __name__ == "__main__":
test_rank_zero()
test_tensor_inputs()
......@@ -199,3 +228,4 @@ if __name__ == "__main__":
test_extern_multi_out()
test_tuple_inputs()
test_tuple_with_different_deps()
test_tensor_pool()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment