Commit 00833e0f by Tianqi Chen

cleanup (#21)

parent 5cf08d6c
......@@ -98,7 +98,6 @@ NNVM_REGISTER_OP(__add_symbol__)
NNVM_REGISTER_OP(exp)
.describe("take exponential")
.set_num_inputs(1)
.attr("inplace_pair", std::make_pair(0, 0))
.attr<FInferShape>("FInferShape", SameShape);
NNVM_REGISTER_OP(cross_device_copy)
......
......@@ -49,20 +49,6 @@ def test_infer_shape():
assert g.json_attr('shape')[jnode_row_ptr[nindex["reshape1"]]] == [2, 4]
assert g.json_attr('shape')[jnode_row_ptr[nindex["add1"]]] == [4, 2]
def test_infer_shape():
x = sym.Variable('x', shape=(4, 2))
y = sym.add(x, x, name='add1')
y = sym.reshape(y, target=(2, 4), name="reshape1")
g = graph.create(y)
g._set_json_attr("shape_attr_key", "shape")
g = g.apply('InferShape')
jgraph = json.loads(g.apply('SaveJSON').json_attr('json'))
jnodes = jgraph['nodes']
jnode_row_ptr = jgraph['node_row_ptr']
nindex = {n['name']: i for i, n in enumerate(jnodes)}
assert g.json_attr('shape')[jnode_row_ptr[nindex["reshape1"]]] == [2, 4]
assert g.json_attr('shape')[jnode_row_ptr[nindex["add1"]]] == [4, 2]
def test_infer_type():
x = sym.Variable('x')
y = sym.add(x, x, name='add1')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment