Unverified Commit 51af454a by Animesh Jain Committed by GitHub

[Relay][FastMath] Relay pass to use fast exp/tanh (#4873)

* [Relay][FastMath] Relay pass to use fast exp/tanh

* Adding required_pass to the tests.

* FastMath test changes.
parent 900d99cd
...@@ -164,6 +164,13 @@ TVM_DLL Pass PartialEval(); ...@@ -164,6 +164,13 @@ TVM_DLL Pass PartialEval();
TVM_DLL Pass SimplifyInference(); TVM_DLL Pass SimplifyInference();
/*! /*!
* \brief Replaces non linear activation functions with their fast but approximate counterparts.
*
* \return The Pass.
*/
TVM_DLL Pass FastMath();
/*!
* \brief Infer the type of an expression. * \brief Infer the type of an expression.
* *
* The result of type checking is a new expression with unambigous * The result of type checking is a new expression with unambigous
......
...@@ -57,7 +57,8 @@ def build_config(opt_level=2, ...@@ -57,7 +57,8 @@ def build_config(opt_level=2,
"CanonicalizeCast": 3, "CanonicalizeCast": 3,
"EliminateCommonSubexpr": 3, "EliminateCommonSubexpr": 3,
"CombineParallelConv2D": 4, "CombineParallelConv2D": 4,
"CombineParallelDense": 4 "CombineParallelDense": 4,
"FastMath": 4
} }
fallback_device : int, str, or tvmContext, optional fallback_device : int, str, or tvmContext, optional
...@@ -175,11 +176,22 @@ def SimplifyInference(): ...@@ -175,11 +176,22 @@ def SimplifyInference():
Returns Returns
------- -------
ret: tvm.relay.Pass ret: tvm.relay.Pass
The registered to perform operator simplification. The registered pass to perform operator simplification.
""" """
return _transform.SimplifyInference() return _transform.SimplifyInference()
def FastMath():
""" Converts the expensive non linear functions to their fast but approximate counterparts.
Returns
-------
ret: tvm.relay.Pass
The registered pass to perform fast math operations.
"""
return _transform.FastMath()
def CanonicalizeOps(): def CanonicalizeOps():
"""Canonicalize special operators to basic operators. """Canonicalize special operators to basic operators.
This can simplify followed analysis, e.g. expanding bias_add to This can simplify followed analysis, e.g. expanding bias_add to
......
...@@ -305,6 +305,9 @@ class RelayBuildModule : public runtime::ModuleNode { ...@@ -305,6 +305,9 @@ class RelayBuildModule : public runtime::ModuleNode {
if (targets.size() == 1) { if (targets.size() == 1) {
pass_seqs.push_back(transform::AlterOpLayout()); pass_seqs.push_back(transform::AlterOpLayout());
} }
// Fast math optimizations.
pass_seqs.push_back(transform::FastMath());
pass_seqs.push_back(transform::FoldConstant()); pass_seqs.push_back(transform::FoldConstant());
// Create a sequential pass and perform optimizations. // Create a sequential pass and perform optimizations.
......
...@@ -95,6 +95,17 @@ RELAY_REGISTER_UNARY_OP("exp") ...@@ -95,6 +95,17 @@ RELAY_REGISTER_UNARY_OP("exp")
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::exp)); .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::exp));
RELAY_REGISTER_UNARY_OP("fast_exp")
.describe(R"code(Returns the fast_exp input array, computed element-wise.
.. math::
\fast_exp(x)
)code" TVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fast_exp));
RELAY_REGISTER_UNARY_OP("erf") RELAY_REGISTER_UNARY_OP("erf")
.describe(R"code(Returns the error function value for input array, computed element-wise. .describe(R"code(Returns the error function value for input array, computed element-wise.
...@@ -250,6 +261,17 @@ RELAY_REGISTER_UNARY_OP("tanh") ...@@ -250,6 +261,17 @@ RELAY_REGISTER_UNARY_OP("tanh")
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::tanh)); .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::tanh));
RELAY_REGISTER_UNARY_OP("fast_tanh")
.describe(R"code(Returns the fast_tanh of input array, computed element-wise.
.. math::
Y = sinh(X) / cosh(X)
)code" TVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fast_tanh));
RELAY_REGISTER_UNARY_OP("negative") RELAY_REGISTER_UNARY_OP("negative")
.describe(R"code(Returns the numeric negative of input array, computed element-wise. .describe(R"code(Returns the numeric negative of input array, computed element-wise.
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file fast_math.cc
* \brief Replaces non linear activation functions with their fast but approximate counterparts.
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/op.h>
#include "pattern_util.h"
namespace tvm {
namespace relay {
class FastMathMutator : public ExprMutator {
public:
FastMathMutator()
: exp_op_(Op::Get("exp")),
tanh_op_(Op::Get("tanh")) {}
Expr VisitExpr_(const CallNode* n) {
auto new_n = ExprMutator::VisitExpr_(n);
if (n->op == exp_op_) {
return FastExp(new_n.as<CallNode>()->args[0]);
} else if (n->op == tanh_op_) {
return FastTanh(new_n.as<CallNode>()->args[0]);
}
return new_n;
}
private:
// Cache the following ops. They will be used in the passes repeatedly for
// operator equivalence checking so that the registry lookup overhead can be
// reduced.
const Op& exp_op_;
const Op& tanh_op_;
};
Expr FastMath(const Expr& e) {
return FastMathMutator().Mutate(e);
}
namespace transform {
Pass FastMath() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(FastMath(f));
};
return CreateFunctionPass(pass_func, 4, "FastMath",
{tir::StringImmNode::make("InferType")});
}
TVM_REGISTER_GLOBAL("relay._transform.FastMath")
.set_body_typed(FastMath);
} // namespace transform
} // namespace relay
} // namespace tvm
...@@ -316,6 +316,16 @@ inline Expr Exp(Expr e) { ...@@ -316,6 +316,16 @@ inline Expr Exp(Expr e) {
return CallNode::make(op, {e}); return CallNode::make(op, {e});
} }
inline Expr FastExp(Expr e) {
static const Op& op = Op::Get("fast_exp");
return CallNode::make(op, {e});
}
inline Expr FastTanh(Expr e) {
static const Op& op = Op::Get("fast_tanh");
return CallNode::make(op, {e});
}
inline Expr Log(Expr e) { inline Expr Log(Expr e) {
static const Op& op = Op::Get("log"); static const Op& op = Op::Get("log");
return CallNode::make(op, {e}); return CallNode::make(op, {e});
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm.ir import IRModule
from tvm import relay
from tvm.relay.transform import FastMath
def test_exp():
x = relay.var("x", shape=(1, 16, 16, 16), dtype="float32")
y = relay.exp(x)
func = relay.Function([x], y)
mod = tvm.IRModule.from_expr(func)
fast_mod = FastMath()(mod)
assert "fast_exp" in fast_mod.astext()
# Check that FastMath option works for relay.build.
with relay.build_config(opt_level=3, required_pass=['FastMath']):
fast_mod = relay.optimize(mod, target='llvm', params=None)
assert "fast_exp" in fast_mod[0].astext()
def test_tanh():
x = relay.var("x", shape=(1, 16, 16, 16), dtype="float32")
y = relay.tanh(x)
func = relay.Function([x], y)
mod = tvm.IRModule.from_expr(func)
fast_mod = FastMath()(mod)
assert "fast_tanh" in fast_mod.astext()
# Check that FastMath option works for relay.build.
with relay.build_config(opt_level=3, required_pass=['FastMath']):
fast_mod = relay.optimize(mod, target='llvm', params=None)
assert "fast_tanh" in fast_mod[0].astext()
if __name__ == "__main__":
test_exp()
test_tanh()
...@@ -58,6 +58,7 @@ TOPI_DECLARE_UNARY_OP(cos); ...@@ -58,6 +58,7 @@ TOPI_DECLARE_UNARY_OP(cos);
TOPI_DECLARE_UNARY_OP(sin); TOPI_DECLARE_UNARY_OP(sin);
TOPI_DECLARE_UNARY_OP(atan); TOPI_DECLARE_UNARY_OP(atan);
TOPI_DECLARE_UNARY_OP(isnan); TOPI_DECLARE_UNARY_OP(isnan);
TOPI_DECLARE_UNARY_OP(tanh);
/* /*
* \brief Fast_tanh_float implementation from Eigen * \brief Fast_tanh_float implementation from Eigen
...@@ -113,9 +114,9 @@ inline Tensor fast_tanh_float(const Tensor& in, ...@@ -113,9 +114,9 @@ inline Tensor fast_tanh_float(const Tensor& in,
* *
* \return A Tensor whose op member is tanh * \return A Tensor whose op member is tanh
*/ */
inline Tensor tanh(const Tensor& x, inline Tensor fast_tanh(const Tensor& x,
std::string name = "T_tanh", std::string name = "T_fast_tanh",
std::string tag = kElementWise) { std::string tag = kElementWise) {
if (x->dtype == DataType::Float(32)) { if (x->dtype == DataType::Float(32)) {
// invoke fast_tanh_float implementation // invoke fast_tanh_float implementation
return fast_tanh_float(x, name, tag); return fast_tanh_float(x, name, tag);
......
...@@ -467,3 +467,19 @@ def fast_exp(x): ...@@ -467,3 +467,19 @@ def fast_exp(x):
The result. The result.
""" """
return cpp.fast_exp(x, x.dtype, tag.ELEMWISE) return cpp.fast_exp(x, x.dtype, tag.ELEMWISE)
def fast_tanh(x):
"""Take tanhonential of input x using fast_tanh implementation
Parameters
----------
x : tvm.Tensor
Input argument.
Returns
-------
y : tvm.Tensor
The result.
"""
return cpp.fast_tanh(x, x.dtype, tag.ELEMWISE)
...@@ -188,7 +188,10 @@ TVM_REGISTER_GLOBAL("topi.tanh") ...@@ -188,7 +188,10 @@ TVM_REGISTER_GLOBAL("topi.tanh")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = tanh(args[0]); *rv = tanh(args[0]);
}); });
TVM_REGISTER_GLOBAL("topi.fast_tanh")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = fast_tanh(args[0]);
});
TVM_REGISTER_GLOBAL("topi.atan") TVM_REGISTER_GLOBAL("topi.atan")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = atan(args[0]); *rv = atan(args[0]);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment