Commit 0ec70800 by Yizhi Liu Committed by Tianqi Chen

fix CorrectLayout for softmax & log_softmax (#1401)

parent a9a4329e
...@@ -345,7 +345,7 @@ NNVM_REGISTER_OP(softmax) ...@@ -345,7 +345,7 @@ NNVM_REGISTER_OP(softmax)
.set_num_outputs(1) .set_num_outputs(1)
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>) .set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseArbitraryLayout<1, 1>) .set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseFixedLayoutCopyToOut<1, 1>)
.set_support_level(1) .set_support_level(1)
.set_attr<FTVMCompute>( .set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs, "FTVMCompute", [](const NodeAttrs& attrs,
...@@ -404,7 +404,7 @@ NNVM_REGISTER_OP(log_softmax) ...@@ -404,7 +404,7 @@ NNVM_REGISTER_OP(log_softmax)
.set_num_outputs(1) .set_num_outputs(1)
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>) .set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseArbitraryLayout<1, 1>) .set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseFixedLayoutCopyToOut<1, 1>)
.set_attr<FTVMCompute>( .set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs, "FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
......
...@@ -3,7 +3,6 @@ import nnvm.symbol as sym ...@@ -3,7 +3,6 @@ import nnvm.symbol as sym
import nnvm.graph as graph import nnvm.graph as graph
from nnvm.compiler import graph_attr from nnvm.compiler import graph_attr
# Level 1
def correct_layout(g, layout=None): def correct_layout(g, layout=None):
if isinstance(g, nnvm.symbol.Symbol): if isinstance(g, nnvm.symbol.Symbol):
g = graph.create(g) g = graph.create(g)
...@@ -19,6 +18,7 @@ def correct_layout(g, layout=None): ...@@ -19,6 +18,7 @@ def correct_layout(g, layout=None):
return g, ldict return g, ldict
# Level 1
def test_dense(): def test_dense():
x = sym.Variable("data", shape=(10, 20)) x = sym.Variable("data", shape=(10, 20))
y = sym.dense(x, units=30, name="fc") y = sym.dense(x, units=30, name="fc")
...@@ -169,6 +169,19 @@ def test_flatten(): ...@@ -169,6 +169,19 @@ def test_flatten():
assert(ldict["y"][0] == "__undef__") assert(ldict["y"][0] == "__undef__")
def test_softmax():
    """Check FCorrectLayout for softmax: the op keeps a fixed layout (NCHW)."""
    data = sym.Variable("x", shape=(10, 20, 10, 10))
    out = sym.softmax(data, name="y")
    # First pass: input and output both stay NCHW.
    g, ldict = correct_layout(out, "NCHW")
    assert ldict["x"][0] == "NCHW"
    assert ldict["y"][0] == "NCHW"
    # Second pass with a sub-layout: a layout-transform node ("x_NCHW")
    # is inserted so softmax still sees NCHW.
    _, ldict = correct_layout(g, "NCHW16c")
    assert ldict["x"][0] == "NCHW16c"
    assert ldict["x_NCHW"][0] == "NCHW"
    assert ldict["y"][0] == "NCHW"
# Level 2 # Level 2
def test_conv2d(): def test_conv2d():
x = sym.Variable("data", shape=(1, 32, 512, 512)) x = sym.Variable("data", shape=(1, 32, 512, 512))
...@@ -327,6 +340,7 @@ if __name__ == "__main__": ...@@ -327,6 +340,7 @@ if __name__ == "__main__":
test_split() test_split()
test_batchnorm() test_batchnorm()
test_flatten() test_flatten()
test_softmax()
test_conv2d() test_conv2d()
test_conv2d_transpose() test_conv2d_transpose()
test_max_pool2d() test_max_pool2d()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment