Commit 669b44c1 by Tatsuya Nishiyama Committed by Tianqi Chen

Fix the gemm conversion in onnx frontend (#1241)

parent 552fa71b
...@@ -186,6 +186,7 @@ class Gemm(OnnxOpConverter): ...@@ -186,6 +186,7 @@ class Gemm(OnnxOpConverter):
inputs[0] = _sym.transpose(inputs[0], axes=(1, 0)) inputs[0] = _sym.transpose(inputs[0], axes=(1, 0))
if not transB: if not transB:
inputs[1] = _sym.transpose(inputs[1], axes=(1, 0)) inputs[1] = _sym.transpose(inputs[1], axes=(1, 0))
inputs[0] = _sym.flatten(inputs[0])
return _sym.dense( return _sym.dense(
alpha * inputs[0], inputs[1], beta * inputs[2], units=channels) alpha * inputs[0], inputs[1], beta * inputs[2], units=channels)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment