Commit b7b74228 by Yan Huang Committed by Tianqi Chen

[FIX] several bugs found when using NNVM (#391)

parent a523311d
......@@ -146,11 +146,9 @@ inline NodeEntry MakeNode(
NodePtr p = Node::Create();
p->attrs.op = nnvm::Op::Get(op_name);
p->attrs.name = std::move(node_name);
if (attrs.size() != 0) {
p->attrs.dict = attrs;
if (p->attrs.op->attr_parser) {
p->attrs.op->attr_parser(&(p->attrs));
}
p->attrs.dict = attrs;
if (p->attrs.op->attr_parser) {
p->attrs.op->attr_parser(&(p->attrs));
}
p->inputs = std::move(inputs);
return NodeEntry{p, 0, 0};
......
......@@ -151,6 +151,7 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx,
GraphAllocator* allocator) {
static auto& finplace_option = Op::GetAttr<FInplaceOption>("FInplaceOption");
static auto& finplace_identity = Op::GetAttr<FInplaceIdentity>("FInplaceIdentity");
static auto& fignore_inputs = Op::GetAttr<FIgnoreInputs>("FIgnoreInputs");
// Get reference
auto &storage = *storage_ptr;
......@@ -189,10 +190,13 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx,
uint32_t eid_in = idx.entry_id(inode.inputs[kv.first]);
auto sid_out = storage[eid_out];
auto sid_in = storage[eid_in];
bool ignore_all_inputs = (fignore_inputs.count(inode.source->op()) != 0 &&
fignore_inputs[inode.source->op()](
inode.source->attrs).size() == inode.source->num_inputs());
if (taken[kv.first] == false &&
sid_out == GraphAllocator::kBadStorageID &&
sid_in >= 0 &&
(storage_ref_count[sid_in] == 1 || identity[ipair]) &&
(storage_ref_count[sid_in] == 1 && !ignore_all_inputs || identity[ipair]) &&
entry_ref_count[eid_out] > 0 &&
shape_vec[eid_out].Size() == shape_vec[eid_in].Size() &&
dtype_vec[eid_out] == dtype_vec[eid_in]) {
......@@ -230,7 +234,6 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx,
storage[eid] = sid;
}
// check if certain inputs is ignored.
static auto& fignore_inputs = Op::GetAttr<FIgnoreInputs>("FIgnoreInputs");
std::vector<uint32_t> ignore_inputs;
if (fignore_inputs.count(inode.source->op()) != 0) {
ignore_inputs = fignore_inputs[inode.source->op()](inode.source->attrs);
......
......@@ -134,8 +134,11 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(relu)
NodeEntry zero = MakeNode("zeros_like", n->attrs.name + "_grad_zero",
{n->inputs[0]});
return std::vector<NodeEntry>{
MakeNode("greater", n->attrs.name + "_grad",
{n->inputs[0], zero}, {{"exclude", "true"}})
MakeNode("elemwise_mul", n->attrs.name + "_grad", {
ograds[0],
MakeNode("greater", n->attrs.name + "_grad_mask",
{n->inputs[0], zero}, {{"exclude", "true"}})
})
};
})
.set_support_level(1);
......@@ -249,7 +252,7 @@ axis to be the last item in the input shape.
.set_attr<FNumVisibleOutputs>("FNumVisibleOutputs", [](const NodeAttrs& attrs) {
return 1;
})
.set_attr<FMutateInputs>("FListMutateInputs", [](const NodeAttrs& attrs) {
.set_attr<FMutateInputs>("FMutateInputs", [](const NodeAttrs& attrs) {
return std::vector<uint32_t>{3, 4};
})
.set_support_level(1);
......
......@@ -33,9 +33,9 @@ inline bool DotShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(lshape[lshape.ndim() - 1], rshape[0])
<< "dot shape inconsistent: " << lshape << " X " << rshape;
TShape oshape(lshape.ndim() + rshape.ndim() - 1);
for (size_t i = 0; i < lshape.ndim() - 1; i++) oshape[i] = lshape[i];
for (size_t i = 1; i < rshape.ndim(); i++) oshape[i + lshape.ndim() - 1] = rshape[i];
TShape oshape(lshape.ndim() + rshape.ndim() - 2);
for (int i = 0; i < lshape.ndim() - 1; i++) oshape[i] = lshape[i];
for (int i = 1; i < rshape.ndim(); i++) oshape[i + lshape.ndim() - 2] = rshape[i];
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, oshape);
return true;
......
......@@ -574,7 +574,7 @@ the input array into an output array with the same shape as the second input arr
)code" NNVM_ADD_FILELINE)
.add_argument("data", "Tensor", "Input data.")
.add_argument("shape_like", "Tensor", "Input data.")
.set_num_inputs(1)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr<FInferShape>(
"FInferShape", [](const NodeAttrs& attrs,
......@@ -585,7 +585,7 @@ the input array into an output array with the same shape as the second input arr
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, in_attrs->at(1));
return true;
})
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferType>("FInferType", ElemwiseType<2, 1>)
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
......
......@@ -23,6 +23,36 @@ def test_dense():
assert(sdict["fc_bias"][0] == [30])
def test_matmul():
a = sym.Variable('a', shape=(10, 20))
b = sym.Variable('b', shape=(20, 30))
c = sym.matmul(a, b, name="matmul")
sdict = infer_shape(c)
assert(sdict["matmul"][0] == [10, 30])
a = sym.Variable('a', shape=(20, 10))
c = sym.matmul(a, b, name="matmul", transpose_a=True)
sdict = infer_shape(c)
assert(sdict["matmul"][0] == [10, 30])
b = sym.Variable('b', shape=(30, 20))
c = sym.matmul(a, b, name="matmul", transpose_a=True, transpose_b=True)
sdict = infer_shape(c)
assert(sdict["matmul"][0] == [10, 30])
a = sym.Variable('a', shape=(10, 20))
c = sym.matmul(a, b, name="matmul", transpose_b=True)
sdict = infer_shape(c)
assert(sdict["matmul"][0] == [10, 30])
a = sym.Variable('a', shape=(10, 20, 30))
b = sym.Variable('b', shape=(30, 40, 50))
c = sym.matmul(a, b, name="matmul")
sdict = infer_shape(c)
assert(sdict["matmul"][0] == [10, 20, 40, 50])
a = sym.Variable('a', shape=(30, 20, 10))
b = sym.Variable('b', shape=(50, 40, 30))
c = sym.matmul(a, b, name="matmul", transpose_a=True, transpose_b=True)
sdict = infer_shape(c)
assert(sdict["matmul"][0] == [10, 20, 40, 50])
def test_concatenate():
x1 = sym.Variable("x", shape=(10, 20))
x2 = sym.Variable("y", shape=(10, 30))
......@@ -275,6 +305,7 @@ def test_reduce():
if __name__ == "__main__":
test_expand_dims()
test_dense()
test_matmul()
test_concatenate()
test_split()
test_batchnorm()
......
......@@ -6,6 +6,13 @@ def test_dense():
y = sym.dense(x, units=30, name="fc")
assert y.list_input_names() == ["x", "fc_weight", "fc_bias"]
def test_batch_norm():
x = sym.Variable('x')
y = sym.dense(x, units=30, name="fc")
z = sym.batch_norm(x, name='bn')
assert z.list_input_names('aux_state') == ['bn_moving_mean', 'bn_moving_var']
assert z.list_input_names('read_only') == ['x', 'bn_gamma', 'bn_beta']
def test_compose():
x = sym.Variable('x')
z = sym.Variable('z')
......@@ -51,3 +58,4 @@ if __name__ == "__main__":
test_copy()
test_default_input()
test_compose()
test_batch_norm()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment