Commit cf2f5197 by William Moses Committed by Tianqi Chen

Fix issue relating to serialization of reducer (#282)

parent cf4e7775
...@@ -108,6 +108,7 @@ struct Reduce : public ExprNode<Reduce> { ...@@ -108,6 +108,7 @@ struct Reduce : public ExprNode<Reduce> {
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) final {
v->Visit("dtype", &type); v->Visit("dtype", &type);
v->Visit("combiner", &combiner);
v->Visit("source", &source); v->Visit("source", &source);
v->Visit("axis", &axis); v->Visit("axis", &axis);
v->Visit("condition", &condition); v->Visit("condition", &condition);
......
...@@ -97,6 +97,7 @@ Expr Reduce::make(CommReducer combiner, Array<Expr> source, ...@@ -97,6 +97,7 @@ Expr Reduce::make(CommReducer combiner, Array<Expr> source,
return Expr(n); return Expr(n);
} }
TVM_REGISTER_NODE_TYPE(CommReducerNode);
TVM_REGISTER_NODE_TYPE(Reduce); TVM_REGISTER_NODE_TYPE(Reduce);
TVM_REGISTER_NODE_TYPE(AttrStmt); TVM_REGISTER_NODE_TYPE(AttrStmt);
......
...@@ -24,7 +24,16 @@ def test_make_node(): ...@@ -24,7 +24,16 @@ def test_make_node():
assert AA.op == A.op assert AA.op == A.op
assert AA.value_index == A.value_index assert AA.value_index == A.value_index
def test_make_sum():
A = tvm.placeholder((2, 10), name='A')
k = tvm.reduce_axis((0,10), "k")
B = tvm.compute((2,), lambda i: tvm.sum(A[i, k], axis=k), name="B")
json_str = tvm.save_json(B)
BB = tvm.load_json(json_str)
assert B.op.body[0].combiner.handle.value != 0
assert BB.op.body[0].combiner.handle.value != 0
if __name__ == "__main__": if __name__ == "__main__":
test_make_node() test_make_node()
test_const_saveload_json() test_const_saveload_json()
test_make_sum()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment