Unverified Commit 51af454a by Animesh Jain Committed by GitHub

[Relay][FastMath] Relay pass to use fast exp/tanh (#4873)

* [Relay][FastMath] Relay pass to use fast exp/tanh

* Adding required_pass to the tests.

* FastMath test changes.
parent 900d99cd
......@@ -164,6 +164,13 @@ TVM_DLL Pass PartialEval();
TVM_DLL Pass SimplifyInference();
/*!
* \brief Replaces non linear activation functions with their fast but approximate counterparts.
*
* \return The Pass.
*/
TVM_DLL Pass FastMath();
/*!
* \brief Infer the type of an expression.
*
* The result of type checking is a new expression with unambigous
......
......@@ -57,7 +57,8 @@ def build_config(opt_level=2,
"CanonicalizeCast": 3,
"EliminateCommonSubexpr": 3,
"CombineParallelConv2D": 4,
"CombineParallelDense": 4
"CombineParallelDense": 4,
"FastMath": 4
}
fallback_device : int, str, or tvmContext, optional
......@@ -175,11 +176,22 @@ def SimplifyInference():
Returns
-------
ret: tvm.relay.Pass
The registered to perform operator simplification.
The registered pass to perform operator simplification.
"""
return _transform.SimplifyInference()
def FastMath():
""" Converts the expensive non linear functions to their fast but approximate counterparts.
Returns
-------
ret: tvm.relay.Pass
The registered pass to perform fast math operations.
"""
return _transform.FastMath()
def CanonicalizeOps():
"""Canonicalize special operators to basic operators.
This can simplify followed analysis, e.g. expanding bias_add to
......
......@@ -305,6 +305,9 @@ class RelayBuildModule : public runtime::ModuleNode {
if (targets.size() == 1) {
pass_seqs.push_back(transform::AlterOpLayout());
}
// Fast math optimizations.
pass_seqs.push_back(transform::FastMath());
pass_seqs.push_back(transform::FoldConstant());
// Create a sequential pass and perform optimizations.
......
......@@ -95,6 +95,17 @@ RELAY_REGISTER_UNARY_OP("exp")
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::exp));
RELAY_REGISTER_UNARY_OP("fast_exp")
.describe(R"code(Returns the fast_exp input array, computed element-wise.
.. math::
\fast_exp(x)
)code" TVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fast_exp));
RELAY_REGISTER_UNARY_OP("erf")
.describe(R"code(Returns the error function value for input array, computed element-wise.
......@@ -250,6 +261,17 @@ RELAY_REGISTER_UNARY_OP("tanh")
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::tanh));
RELAY_REGISTER_UNARY_OP("fast_tanh")
.describe(R"code(Returns the fast_tanh of input array, computed element-wise.
.. math::
Y = sinh(X) / cosh(X)
)code" TVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fast_tanh));
RELAY_REGISTER_UNARY_OP("negative")
.describe(R"code(Returns the numeric negative of input array, computed element-wise.
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file fast_math.cc
* \brief Replaces non linear activation functions with their fast but approximate counterparts.
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/op.h>
#include "pattern_util.h"
namespace tvm {
namespace relay {
class FastMathMutator : public ExprMutator {
public:
FastMathMutator()
: exp_op_(Op::Get("exp")),
tanh_op_(Op::Get("tanh")) {}
Expr VisitExpr_(const CallNode* n) {
auto new_n = ExprMutator::VisitExpr_(n);
if (n->op == exp_op_) {
return FastExp(new_n.as<CallNode>()->args[0]);
} else if (n->op == tanh_op_) {
return FastTanh(new_n.as<CallNode>()->args[0]);
}
return new_n;
}
private:
// Cache the following ops. They will be used in the passes repeatedly for
// operator equivalence checking so that the registry lookup overhead can be
// reduced.
const Op& exp_op_;
const Op& tanh_op_;
};
Expr FastMath(const Expr& e) {
return FastMathMutator().Mutate(e);
}
namespace transform {
Pass FastMath() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(FastMath(f));
};
return CreateFunctionPass(pass_func, 4, "FastMath",
{tir::StringImmNode::make("InferType")});
}
TVM_REGISTER_GLOBAL("relay._transform.FastMath")
.set_body_typed(FastMath);
} // namespace transform
} // namespace relay
} // namespace tvm
......@@ -316,6 +316,16 @@ inline Expr Exp(Expr e) {
return CallNode::make(op, {e});
}
inline Expr FastExp(Expr e) {
static const Op& op = Op::Get("fast_exp");
return CallNode::make(op, {e});
}
inline Expr FastTanh(Expr e) {
static const Op& op = Op::Get("fast_tanh");
return CallNode::make(op, {e});
}
inline Expr Log(Expr e) {
static const Op& op = Op::Get("log");
return CallNode::make(op, {e});
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm.ir import IRModule
from tvm import relay
from tvm.relay.transform import FastMath
def test_exp():
x = relay.var("x", shape=(1, 16, 16, 16), dtype="float32")
y = relay.exp(x)
func = relay.Function([x], y)
mod = tvm.IRModule.from_expr(func)
fast_mod = FastMath()(mod)
assert "fast_exp" in fast_mod.astext()
# Check that FastMath option works for relay.build.
with relay.build_config(opt_level=3, required_pass=['FastMath']):
fast_mod = relay.optimize(mod, target='llvm', params=None)
assert "fast_exp" in fast_mod[0].astext()
def test_tanh():
x = relay.var("x", shape=(1, 16, 16, 16), dtype="float32")
y = relay.tanh(x)
func = relay.Function([x], y)
mod = tvm.IRModule.from_expr(func)
fast_mod = FastMath()(mod)
assert "fast_tanh" in fast_mod.astext()
# Check that FastMath option works for relay.build.
with relay.build_config(opt_level=3, required_pass=['FastMath']):
fast_mod = relay.optimize(mod, target='llvm', params=None)
assert "fast_tanh" in fast_mod[0].astext()
if __name__ == "__main__":
test_exp()
test_tanh()
......@@ -58,6 +58,7 @@ TOPI_DECLARE_UNARY_OP(cos);
TOPI_DECLARE_UNARY_OP(sin);
TOPI_DECLARE_UNARY_OP(atan);
TOPI_DECLARE_UNARY_OP(isnan);
TOPI_DECLARE_UNARY_OP(tanh);
/*
* \brief Fast_tanh_float implementation from Eigen
......@@ -113,8 +114,8 @@ inline Tensor fast_tanh_float(const Tensor& in,
*
* \return A Tensor whose op member is tanh
*/
inline Tensor tanh(const Tensor& x,
std::string name = "T_tanh",
inline Tensor fast_tanh(const Tensor& x,
std::string name = "T_fast_tanh",
std::string tag = kElementWise) {
if (x->dtype == DataType::Float(32)) {
// invoke fast_tanh_float implementation
......
......@@ -467,3 +467,19 @@ def fast_exp(x):
The result.
"""
return cpp.fast_exp(x, x.dtype, tag.ELEMWISE)
def fast_tanh(x):
"""Take tanhonential of input x using fast_tanh implementation
Parameters
----------
x : tvm.Tensor
Input argument.
Returns
-------
y : tvm.Tensor
The result.
"""
return cpp.fast_tanh(x, x.dtype, tag.ELEMWISE)
......@@ -188,7 +188,10 @@ TVM_REGISTER_GLOBAL("topi.tanh")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = tanh(args[0]);
});
TVM_REGISTER_GLOBAL("topi.fast_tanh")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = fast_tanh(args[0]);
});
TVM_REGISTER_GLOBAL("topi.atan")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = atan(args[0]);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment