Commit d50f7b66 by Sergey Mironov Committed by Tianqi Chen

Fix missing sigmoid intrinsic in C++ (#2231)

parent 94b309a1
...@@ -492,6 +492,3 @@ def _rule_float_direct(op): ...@@ -492,6 +492,3 @@ def _rule_float_direct(op):
register_intrin_rule("opencl", "exp", _rule_float_direct, override=True) register_intrin_rule("opencl", "exp", _rule_float_direct, override=True)
# default pattern for exp # default pattern for exp
register_intrin_rule("default", "exp", _rule_float_suffix, override=True) register_intrin_rule("default", "exp", _rule_float_suffix, override=True)
# default pattern for sigmoid
register_intrin_rule("default", "sigmoid", lambda op: 1.0 / (1.0 + exp(-op.args[0])))
...@@ -24,6 +24,16 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sqrt") ...@@ -24,6 +24,16 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sqrt")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.pow") TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.pow")
.set_body(DispatchExtern<FloatSuffix>); .set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sigmoid")
.set_body([](const TVMArgs& args, TVMRetValue* rv){
Expr e = args[0];
const Call* call = e.as<Call>();
CHECK(call != nullptr);
auto one = make_const(call->args[0].type(), 1);
*rv = one / (one + exp(-call->args[0]));
});
} // namespace intrin } // namespace intrin
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment