Commit d50f7b66 by Sergey Mironov Committed by Tianqi Chen

Fix missing sigmoid intrinsic in C++ (#2231)

parent 94b309a1
......@@ -492,6 +492,3 @@ def _rule_float_direct(op):
register_intrin_rule("opencl", "exp", _rule_float_direct, override=True)
# default pattern for exp
register_intrin_rule("default", "exp", _rule_float_suffix, override=True)
# default pattern for sigmoid
register_intrin_rule("default", "sigmoid", lambda op: 1.0 / (1.0 + exp(-op.args[0])))
......@@ -24,6 +24,16 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sqrt")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.pow")
.set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sigmoid")
.set_body([](const TVMArgs& args, TVMRetValue* rv){
Expr e = args[0];
const Call* call = e.as<Call>();
CHECK(call != nullptr);
auto one = make_const(call->args[0].type(), 1);
*rv = one / (one + exp(-call->args[0]));
});
} // namespace intrin
} // namespace codegen
} // namespace tvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment