Commit b7b74228 authored by Yan Huang, committed by Tianqi Chen

[FIX] several bugs found when using NNVM (#391)

parent a523311d
@@ -146,11 +146,9 @@ inline NodeEntry MakeNode(
   NodePtr p = Node::Create();
   p->attrs.op = nnvm::Op::Get(op_name);
   p->attrs.name = std::move(node_name);
-  if (attrs.size() != 0) {
-    p->attrs.dict = attrs;
-    if (p->attrs.op->attr_parser) {
-      p->attrs.op->attr_parser(&(p->attrs));
-    }
+  p->attrs.dict = attrs;
+  if (p->attrs.op->attr_parser) {
+    p->attrs.op->attr_parser(&(p->attrs));
   }
   p->inputs = std::move(inputs);
   return NodeEntry{p, 0, 0};
...
@@ -151,6 +151,7 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx,
                    GraphAllocator* allocator) {
   static auto& finplace_option = Op::GetAttr<FInplaceOption>("FInplaceOption");
   static auto& finplace_identity = Op::GetAttr<FInplaceIdentity>("FInplaceIdentity");
+  static auto& fignore_inputs = Op::GetAttr<FIgnoreInputs>("FIgnoreInputs");
   // Get reference
   auto &storage = *storage_ptr;
@@ -189,10 +190,13 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx,
       uint32_t eid_in = idx.entry_id(inode.inputs[kv.first]);
       auto sid_out = storage[eid_out];
       auto sid_in = storage[eid_in];
+      bool ignore_all_inputs = (fignore_inputs.count(inode.source->op()) != 0 &&
+                                fignore_inputs[inode.source->op()](
+                                    inode.source->attrs).size() == inode.source->num_inputs());
       if (taken[kv.first] == false &&
           sid_out == GraphAllocator::kBadStorageID &&
           sid_in >= 0 &&
-          (storage_ref_count[sid_in] == 1 || identity[ipair]) &&
+          (storage_ref_count[sid_in] == 1 && !ignore_all_inputs || identity[ipair]) &&
           entry_ref_count[eid_out] > 0 &&
           shape_vec[eid_out].Size() == shape_vec[eid_in].Size() &&
           dtype_vec[eid_out] == dtype_vec[eid_in]) {
@@ -230,7 +234,6 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx,
         storage[eid] = sid;
       }
     // check if certain inputs is ignored.
-    static auto& fignore_inputs = Op::GetAttr<FIgnoreInputs>("FIgnoreInputs");
     std::vector<uint32_t> ignore_inputs;
     if (fignore_inputs.count(inode.source->op()) != 0) {
       ignore_inputs = fignore_inputs[inode.source->op()](inode.source->attrs);
...
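Side note on the memory-planner hunks above: the FIgnoreInputs lookup moves to the top of AllocMemory so the in-place reuse test can consult it, and since C++ `&&` binds tighter than `||`, the new condition reads as "reuse when the input has one remaining reference and the op does not ignore all of its inputs, or when the pair is an identity". A hedged Python paraphrase (the function name is illustrative, not part of NNVM):

```python
def can_share_storage(ref_count, ignore_all_inputs, identity):
    # Mirrors the fixed condition in AllocMemory: an input's storage may be
    # reused in place when it has a single remaining reference and the op
    # does not ignore all of its inputs, or when the output is an identity.
    return (ref_count == 1 and not ignore_all_inputs) or identity
```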
@@ -134,8 +134,11 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(relu)
     NodeEntry zero = MakeNode("zeros_like", n->attrs.name + "_grad_zero",
                               {n->inputs[0]});
     return std::vector<NodeEntry>{
-      MakeNode("greater", n->attrs.name + "_grad",
-               {n->inputs[0], zero}, {{"exclude", "true"}})
+      MakeNode("elemwise_mul", n->attrs.name + "_grad", {
+        ograds[0],
+        MakeNode("greater", n->attrs.name + "_grad_mask",
+                 {n->inputs[0], zero}, {{"exclude", "true"}})
+      })
     };
   })
 .set_support_level(1);
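The relu hunk above fixes the symbolic gradient: the old code returned only the mask `greater(x, 0)` and dropped the incoming output gradient, while the fix multiplies the mask by `ograds[0]`. A minimal NumPy sketch of the intended semantics (names here are illustrative, not the NNVM API):

```python
import numpy as np

def relu_grad(x, ograd):
    # Reference gradient of relu: pass the incoming gradient through where
    # x > 0 and zero it elsewhere.
    mask = (x > 0).astype(x.dtype)   # what the old code returned on its own
    return ograd * mask              # the fix: scale the mask by ograd

x = np.array([-1.0, 0.5, 2.0])
ograd = np.array([10.0, 10.0, 10.0])
print(relu_grad(x, ograd))  # [ 0. 10. 10.]
```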
@@ -249,7 +252,7 @@ axis to be the last item in the input shape.
 .set_attr<FNumVisibleOutputs>("FNumVisibleOutputs", [](const NodeAttrs& attrs) {
   return 1;
 })
-.set_attr<FMutateInputs>("FListMutateInputs", [](const NodeAttrs& attrs) {
+.set_attr<FMutateInputs>("FMutateInputs", [](const NodeAttrs& attrs) {
   return std::vector<uint32_t>{3, 4};
 })
 .set_support_level(1);
...
@@ -33,9 +33,9 @@ inline bool DotShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(lshape[lshape.ndim() - 1], rshape[0])
       << "dot shape inconsistent: " << lshape << " X " << rshape;
-  TShape oshape(lshape.ndim() + rshape.ndim() - 1);
-  for (size_t i = 0; i < lshape.ndim() - 1; i++) oshape[i] = lshape[i];
-  for (size_t i = 1; i < rshape.ndim(); i++) oshape[i + lshape.ndim() - 1] = rshape[i];
+  TShape oshape(lshape.ndim() + rshape.ndim() - 2);
+  for (int i = 0; i < lshape.ndim() - 1; i++) oshape[i] = lshape[i];
+  for (int i = 1; i < rshape.ndim(); i++) oshape[i + lshape.ndim() - 2] = rshape[i];
   NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, oshape);
   return true;
...
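The DotShape hunk corrects an off-by-one in the output rank: dot of an n-d lhs with an m-d rhs contracts the last axis of lhs against the first axis of rhs, so the result has n + m - 2 dimensions, not n + m - 1. A small Python sketch of the corrected rule (the helper name is hypothetical):

```python
def dot_output_shape(lshape, rshape):
    # Drop the contracted axes (last of lhs, first of rhs) and concatenate
    # the remaining dimensions, matching the fixed DotShape.
    assert lshape[-1] == rshape[0], "dot shape inconsistent"
    return tuple(lshape[:-1]) + tuple(rshape[1:])

print(dot_output_shape((10, 20, 30), (30, 40, 50)))  # (10, 20, 40, 50)
```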
@@ -574,7 +574,7 @@ the input array into an output array with the same shape as the second input arr
 )code" NNVM_ADD_FILELINE)
 .add_argument("data", "Tensor", "Input data.")
 .add_argument("shape_like", "Tensor", "Input data.")
-.set_num_inputs(1)
+.set_num_inputs(2)
 .set_num_outputs(1)
 .set_attr<FInferShape>(
   "FInferShape", [](const NodeAttrs& attrs,
@@ -585,7 +585,7 @@ the input array into an output array with the same shape as the second input arr
     NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, in_attrs->at(1));
     return true;
   })
-.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<FInferType>("FInferType", ElemwiseType<2, 1>)
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds) {
...
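The reshape_like hunks align the registration with the operator's documented signature: it consumes two tensors, `data` and `shape_like`, so both the input count and the type-inference template argument become 2. A NumPy sketch of the intended behaviour, for reference only (not the NNVM implementation):

```python
import numpy as np

def reshape_like(data, shape_like):
    # Reshape `data` into the shape of `shape_like`; both must hold the same
    # number of elements.
    assert data.size == shape_like.size
    return data.reshape(shape_like.shape)

print(reshape_like(np.arange(6), np.zeros((2, 3))).shape)  # (2, 3)
```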
@@ -23,6 +23,36 @@ def test_dense():
     assert(sdict["fc_bias"][0] == [30])
 
+
+def test_matmul():
+    a = sym.Variable('a', shape=(10, 20))
+    b = sym.Variable('b', shape=(20, 30))
+    c = sym.matmul(a, b, name="matmul")
+    sdict = infer_shape(c)
+    assert(sdict["matmul"][0] == [10, 30])
+    a = sym.Variable('a', shape=(20, 10))
+    c = sym.matmul(a, b, name="matmul", transpose_a=True)
+    sdict = infer_shape(c)
+    assert(sdict["matmul"][0] == [10, 30])
+    b = sym.Variable('b', shape=(30, 20))
+    c = sym.matmul(a, b, name="matmul", transpose_a=True, transpose_b=True)
+    sdict = infer_shape(c)
+    assert(sdict["matmul"][0] == [10, 30])
+    a = sym.Variable('a', shape=(10, 20))
+    c = sym.matmul(a, b, name="matmul", transpose_b=True)
+    sdict = infer_shape(c)
+    assert(sdict["matmul"][0] == [10, 30])
+    a = sym.Variable('a', shape=(10, 20, 30))
+    b = sym.Variable('b', shape=(30, 40, 50))
+    c = sym.matmul(a, b, name="matmul")
+    sdict = infer_shape(c)
+    assert(sdict["matmul"][0] == [10, 20, 40, 50])
+    a = sym.Variable('a', shape=(30, 20, 10))
+    b = sym.Variable('b', shape=(50, 40, 30))
+    c = sym.matmul(a, b, name="matmul", transpose_a=True, transpose_b=True)
+    sdict = infer_shape(c)
+    assert(sdict["matmul"][0] == [10, 20, 40, 50])
+
 def test_concatenate():
     x1 = sym.Variable("x", shape=(10, 20))
     x2 = sym.Variable("y", shape=(10, 30))
@@ -275,6 +305,7 @@ def test_reduce():
 if __name__ == "__main__":
     test_expand_dims()
     test_dense()
+    test_matmul()
     test_concatenate()
     test_split()
     test_batchnorm()
...
@@ -6,6 +6,13 @@ def test_dense():
     y = sym.dense(x, units=30, name="fc")
     assert y.list_input_names() == ["x", "fc_weight", "fc_bias"]
 
+def test_batch_norm():
+    x = sym.Variable('x')
+    y = sym.dense(x, units=30, name="fc")
+    z = sym.batch_norm(x, name='bn')
+    assert z.list_input_names('aux_state') == ['bn_moving_mean', 'bn_moving_var']
+    assert z.list_input_names('read_only') == ['x', 'bn_gamma', 'bn_beta']
+
 def test_compose():
     x = sym.Variable('x')
     z = sym.Variable('z')
@@ -51,3 +58,4 @@ if __name__ == "__main__":
     test_copy()
     test_default_input()
     test_compose()
+    test_batch_norm()