Commit 0ec70800 by Yizhi Liu Committed by Tianqi Chen

fix CorrectLayout for softmax & log_softmax (#1401)

parent a9a4329e
......@@ -345,7 +345,7 @@ NNVM_REGISTER_OP(softmax)
.set_num_outputs(1)
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseArbitraryLayout<1, 1>)
.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseFixedLayoutCopyToOut<1, 1>)
.set_support_level(1)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
......@@ -404,7 +404,7 @@ NNVM_REGISTER_OP(log_softmax)
.set_num_outputs(1)
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseArbitraryLayout<1, 1>)
.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseFixedLayoutCopyToOut<1, 1>)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
......
......@@ -3,7 +3,6 @@ import nnvm.symbol as sym
import nnvm.graph as graph
from nnvm.compiler import graph_attr
# Level 1
def correct_layout(g, layout=None):
if isinstance(g, nnvm.symbol.Symbol):
g = graph.create(g)
......@@ -19,6 +18,7 @@ def correct_layout(g, layout=None):
return g, ldict
# Level 1
def test_dense():
x = sym.Variable("data", shape=(10, 20))
y = sym.dense(x, units=30, name="fc")
......@@ -169,6 +169,19 @@ def test_flatten():
assert(ldict["y"][0] == "__undef__")
def test_softmax():
x = sym.Variable("x", shape=(10, 20, 10, 10))
y = sym.softmax(x, name="y")
g, ldict = correct_layout(y, "NCHW")
assert(ldict["x"][0] == "NCHW")
assert(ldict["y"][0] == "NCHW")
# second pass will insert layout transform
_, ldict = correct_layout(g, "NCHW16c")
assert(ldict["x"][0] == "NCHW16c")
assert(ldict["x_NCHW"][0] == "NCHW")
assert(ldict["y"][0] == "NCHW")
# Level 2
def test_conv2d():
x = sym.Variable("data", shape=(1, 32, 512, 512))
......@@ -327,6 +340,7 @@ if __name__ == "__main__":
test_split()
test_batchnorm()
test_flatten()
test_softmax()
test_conv2d()
test_conv2d_transpose()
test_max_pool2d()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment