Commit d5164466 by lixiaoquan Committed by Yizhi Liu

[NNVM] Fix dtype of output of pad. (#2331)

Dtype of output of pad should follows input, but if dtype of input is not float,
  output will still be float becase pad_value is float.
parent 395804e5
...@@ -620,7 +620,8 @@ NNVM_REGISTER_OP(pad) ...@@ -620,7 +620,8 @@ NNVM_REGISTER_OP(pad)
for (size_t i = 0; i < pad_width.ndim(); ++i) { for (size_t i = 0; i < pad_width.ndim(); ++i) {
pad_after.push_back(tvm::make_const(tvm::Int(32), pad_width[i][1])); pad_after.push_back(tvm::make_const(tvm::Int(32), pad_width[i][1]));
} }
return Array<Tensor>{ topi::pad(inputs[0], pad_before, pad_after, param.pad_value) }; return Array<Tensor>{ topi::pad(inputs[0], pad_before, pad_after,
tvm::make_const(inputs[0]->dtype, param.pad_value)) };
}) })
.set_support_level(1); .set_support_level(1);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment