Commit df88c411 by 雾雨魔理沙 Committed by Jared Roesch

save (#3033)

save

save

save

upstream

lint

remove bad changes

fix build

save

save

please the ci god

Update src/relay/pass/partial_eval.cc

Co-Authored-By: Wei Chen <ipondering.weic@gmail.com>

save

fix test

ci is ANGRY

fix rebase problem

fix rebase

add test

save

save

comment
parent 50dd03ca
......@@ -296,13 +296,15 @@ TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t, const Module& mod);
* For example, this pass should turn `let a = 1 in 2` into `2`,
* as the value of the expression does not depend on a.
*
* As another example, `let a = 1 in a` will be optimized into 1.
* As another example, `let a = 1 in a` will be optimized into 1,
* if the flag is turned on.
*
* \param e the expression to optimize.
* \param inline_once whether or not to inline binding used one.
*
* \return the optimized expression.
*/
TVM_DLL Expr DeadCodeElimination(const Expr& e);
TVM_DLL Expr DeadCodeElimination(const Expr& e, bool inline_once = false);
/*!
* \brief Fold constant expressions.
......@@ -435,11 +437,12 @@ TVM_DLL Array<Pattern> UnmatchedCases(const Match& match, const Module& mod);
* It has two benefit: remove runtime overhead, and allow more optimization (typically fusion).
* As a side effect, code size will explode.
*
* \param e the expression,
* \param e the expression
* \param mod the module
*
* \return the optimized expression.
*/
TVM_DLL Expr PartialEval(const Expr& e);
TVM_DLL Expr PartialEval(const Expr& e, const Module& mod);
/*!
* \brief Bind the free variables to a Relay expression.
......
......@@ -356,9 +356,11 @@ TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc<
*
* As another example, `let a = 1 in a` will be optimized into 1.
*
* \param inline_once whether or not to inline binding used one.
*
* \return the pass.
*/
TVM_DLL Pass DeadCodeElimination();
TVM_DLL Pass DeadCodeElimination(bool inline_once = false);
/*!
* \brief Fold constant expressions.
......
......@@ -129,7 +129,7 @@ def well_formed(expr):
Parameters
----------
expr: tvm.relay.Expr
expr : tvm.relay.Expr
The input expression
Returns
......@@ -175,7 +175,7 @@ def free_vars(expr):
Parameters
----------
expr: tvm.relay.Expr
expr : tvm.relay.Expr
The input expression
Returns
......@@ -197,7 +197,7 @@ def bound_vars(expr):
Parameters
----------
expr: tvm.relay.Expr
expr : tvm.relay.Expr
The input expression
Returns
......@@ -213,7 +213,7 @@ def all_vars(expr):
Parameters
----------
expr: tvm.relay.Expr
expr : tvm.relay.Expr
The input expression
Returns
......@@ -229,9 +229,10 @@ def free_type_vars(expr, mod=None):
Parameters
----------
expr: Union[tvm.relay.Expr,tvm.relay.Type]
expr : Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type
mod: tvm.relay.Module, optional
mod : Optional[tvm.relay.Module]
The global module
Returns
......@@ -248,9 +249,10 @@ def bound_type_vars(expr, mod=None):
Parameters
----------
expr: Union[tvm.relay.Expr,tvm.relay.Type]
expr : Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type
mod: tvm.relay.Module, optional
mod : Optional[tvm.relay.Module]
The global module
Returns
......@@ -267,9 +269,9 @@ def all_type_vars(expr, mod=None):
Parameters
----------
expr: Union[tvm.relay.Expr,tvm.relay.Type]
expr : Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type
mod: tvm.relay.Module, optional
mod : Optional[tvm.relay.Module]
The global module
Returns
......@@ -286,12 +288,12 @@ def simplify_inference(expr):
Parameters
----------
e: tvm.relay.Expr
expr : tvm.relay.Expr
The input Expression
Returns
-------
result: tvm.relay.Expr
result : tvm.relay.Expr
An expression which is semantically equal to the input expression,
but with some simplification
"""
......@@ -304,32 +306,34 @@ def canonicalize_ops(expr):
Parameters
----------
e: tvm.relay.Expr
expr : tvm.relay.Expr
The input Expression
Returns
-------
result: tvm.relay.Expr
result : tvm.relay.Expr
An expression without bias_add
"""
return _ir_pass.canonicalize_ops(expr)
def dead_code_elimination(expr):
def dead_code_elimination(expr, inline_once=False):
""" Remove expressions which does not effect the program result (dead code).
Parameters
----------
e: tvm.relay.Expr
expr : tvm.relay.Expr
The input Expression
inline_once : Optional[Bool]
Whether to inline binding that occur only once.
Returns
-------
result: tvm.relay.Expr
result : tvm.relay.Expr
An expression which is semantically equal to the input expression,
but with dead code removed.
"""
return _ir_pass.dead_code_elimination(expr)
return _ir_pass.dead_code_elimination(expr, inline_once)
def alpha_equal(lhs, rhs):
......@@ -337,15 +341,15 @@ def alpha_equal(lhs, rhs):
Parameters
----------
lhs: tvm.relay.Expr
lhs : tvm.relay.Expr
One of the input Expression.
rhs: tvm.relay.Expr
rhs : tvm.relay.Expr
One of the input Expression.
Returns
-------
result: bool
result : bool
True iff lhs is alpha equal to rhs.
"""
return bool(_make._alpha_equal(lhs, rhs))
......@@ -359,15 +363,15 @@ def graph_equal(lhs, rhs):
Parameters
----------
lhs: tvm.relay.Expr
lhs : tvm.relay.Expr
One of the input Expression.
rhs: tvm.relay.Expr
rhs : tvm.relay.Expr
One of the input Expression.
Returns
-------
result: bool
result : bool
True iff lhs is data-flow equivalent to rhs.
"""
return bool(_make._graph_equal(lhs, rhs))
......@@ -378,12 +382,12 @@ def structural_hash(value):
Parameters
----------
expr: tvm.relay.Expr or tvm.relay.Type
expr : Union[tvm.relay.Expr, tvm.relay.Type]
The expression to hash.
Returns
-------
result: int
result : int
The hash value
"""
if isinstance(value, Expr):
......@@ -544,12 +548,12 @@ def to_a_normal_form(expr, mod=None):
expr : tvm.relay.Expr
The input expression.
mod: Optional[tvm.relay.Module]
mod : Optional[tvm.relay.Module]
The global module.
Returns
-------
expr: tvm.relay.Expr
result : tvm.relay.Expr
The output expression.
"""
return _ir_pass.to_a_normal_form(expr, mod)
......@@ -563,7 +567,7 @@ def to_graph_normal_form(expr):
The input expression
Returns
-------
expr : tvm.relay.Expr
result : tvm.relay.Expr
The output expression
"""
return _ir_pass.to_graph_normal_form(expr)
......@@ -612,7 +616,7 @@ def get_total_mac_number(expr):
Returns
-------
ret : int64
result : int64
The number of MACs (multiply-accumulate) of a model
"""
return _ir_pass.GetTotalMacNumber(expr)
......@@ -627,17 +631,17 @@ def eliminate_common_subexpr(expr, fskip=None):
expr : tvm.relay.Expr
The input expression.
fskip: function
fskip : function
The callback function that decides whether an expression should be skipped.
Returns
-------
expr : tvm.relay.Expr
result : tvm.relay.Expr
The output expression.
"""
return _ir_pass.eliminate_common_subexpr(expr, fskip)
def partial_evaluate(expr):
def partial_evaluate(expr, mod=None):
"""
Evaluate the static fragment of the code.
......@@ -646,12 +650,15 @@ def partial_evaluate(expr):
expr : tvm.relay.Expr
The input expression.
mod : Optional[tvm.relay.Module]
The global module
Returns
-------
expr : tvm.relay.Expr
result : tvm.relay.Expr
The output expression.
"""
return _ir_pass.partial_evaluate(expr)
return _ir_pass.partial_evaluate(expr, mod)
def unmatched_cases(match, mod=None):
"""
......
......@@ -220,8 +220,8 @@ TVM_REGISTER_API("relay._make.Call")
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<CallNode>([](const CallNode* node, tvm::IRPrinter* p) {
p->stream << "CallNode(" << node->op << ", " << node->args << ", "
<< node->attrs << ", " << node->type_args << ")";
p->stream << "CallNode(" << node->op << ", " << node->args << ", "
<< node->attrs << ", " << node->type_args << ")";
});
Let LetNode::make(Var var, Expr value, Expr body) {
......@@ -324,7 +324,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
TVM_REGISTER_API("relay._expr.TempExprRealize")
.set_body_typed<Expr(TempExpr)>([](TempExpr temp) {
return temp->Realize();
return temp->Realize();
});
} // namespace relay
......
......@@ -38,10 +38,10 @@ namespace relay {
// calculate the dependency graph from expression
class CalcDep : private ExprVisitor {
public:
static Expr Eliminate(const Expr& e) {
static Expr Eliminate(const Expr& e, bool inline_once) {
CalcDep cd;
cd.Calculate(e);
Eliminator el(cd.expr_map_, cd.use_map_, cd.letrec_set_);
Eliminator el(cd.expr_map_, cd.use_map_, cd.letrec_set_, inline_once);
return el(e);
}
......@@ -117,15 +117,23 @@ class CalcDep : private ExprVisitor {
VarMap<Expr> expr_map_;
VarMap<size_t> use_map_;
VarSet letrec_set_;
bool inline_once_;
explicit Eliminator(const VarMap<Expr>& expr_map,
const VarMap<size_t>& use_map,
const VarSet& letrec_set) :
expr_map_(expr_map), use_map_(use_map), letrec_set_(letrec_set) { }
const VarSet& letrec_set,
bool inline_once) :
expr_map_(expr_map), use_map_(use_map), letrec_set_(letrec_set), inline_once_(inline_once) { }
friend CalcDep;
bool HasLet(const Var& v) {
// TODO(@jroesch): MK fix me
return (use_map_[v] > 0 || (use_map_[v] != 0 && letrec_set_.count(v) != 0));
switch (use_map_[v]) {
case 0:
return false;
case 1:
return letrec_set_.count(v) > 0 || !inline_once_;
default:
return true;
}
}
Expr VisitExpr_(const VarNode* op) final {
......@@ -144,8 +152,8 @@ class CalcDep : private ExprVisitor {
};
};
Expr DeadCodeElimination(const Expr& e) {
return CalcDep::Eliminate(e);
Expr DeadCodeElimination(const Expr& e, bool inline_once) {
return CalcDep::Eliminate(e, inline_once);
}
TVM_REGISTER_API("relay._ir_pass.dead_code_elimination")
......@@ -153,10 +161,10 @@ TVM_REGISTER_API("relay._ir_pass.dead_code_elimination")
namespace transform {
Pass DeadCodeElimination() {
Pass DeadCodeElimination(bool inline_once) {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(DeadCodeElimination(f));
return Downcast<Function>(DeadCodeElimination(f, inline_once));
};
return CreateFunctionPass(pass_func, 1, "DeadCodeElimination", {});
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment