Commit df88c411 by 雾雨魔理沙 Committed by Jared Roesch

save (#3033)

save

save

save

upstream

lint

remove bad changes

fix build

save

save

please the ci god

Update src/relay/pass/partial_eval.cc

Co-Authored-By: Wei Chen <ipondering.weic@gmail.com>

save

fix test

ci is ANGRY

fix rebase problem

fix rebase

add test

save

save

comment
parent 50dd03ca
...@@ -296,13 +296,15 @@ TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t, const Module& mod); ...@@ -296,13 +296,15 @@ TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t, const Module& mod);
* For example, this pass should turn `let a = 1 in 2` into `2`, * For example, this pass should turn `let a = 1 in 2` into `2`,
* as the value of the expression does not depend on a. * as the value of the expression does not depend on a.
* *
* As another example, `let a = 1 in a` will be optimized into 1. * As another example, `let a = 1 in a` will be optimized into 1,
* if the flag is turned on.
* *
* \param e the expression to optimize. * \param e the expression to optimize.
* \param inline_once whether or not to inline binding used one.
* *
* \return the optimized expression. * \return the optimized expression.
*/ */
TVM_DLL Expr DeadCodeElimination(const Expr& e); TVM_DLL Expr DeadCodeElimination(const Expr& e, bool inline_once = false);
/*! /*!
* \brief Fold constant expressions. * \brief Fold constant expressions.
...@@ -435,11 +437,12 @@ TVM_DLL Array<Pattern> UnmatchedCases(const Match& match, const Module& mod); ...@@ -435,11 +437,12 @@ TVM_DLL Array<Pattern> UnmatchedCases(const Match& match, const Module& mod);
* It has two benefit: remove runtime overhead, and allow more optimization (typically fusion). * It has two benefit: remove runtime overhead, and allow more optimization (typically fusion).
* As a side effect, code size will explode. * As a side effect, code size will explode.
* *
* \param e the expression, * \param e the expression
* \param mod the module
* *
* \return the optimized expression. * \return the optimized expression.
*/ */
TVM_DLL Expr PartialEval(const Expr& e); TVM_DLL Expr PartialEval(const Expr& e, const Module& mod);
/*! /*!
* \brief Bind the free variables to a Relay expression. * \brief Bind the free variables to a Relay expression.
......
...@@ -356,9 +356,11 @@ TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc< ...@@ -356,9 +356,11 @@ TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc<
* *
* As another example, `let a = 1 in a` will be optimized into 1. * As another example, `let a = 1 in a` will be optimized into 1.
* *
* \param inline_once whether or not to inline binding used one.
*
* \return the pass. * \return the pass.
*/ */
TVM_DLL Pass DeadCodeElimination(); TVM_DLL Pass DeadCodeElimination(bool inline_once = false);
/*! /*!
* \brief Fold constant expressions. * \brief Fold constant expressions.
......
...@@ -129,7 +129,7 @@ def well_formed(expr): ...@@ -129,7 +129,7 @@ def well_formed(expr):
Parameters Parameters
---------- ----------
expr: tvm.relay.Expr expr : tvm.relay.Expr
The input expression The input expression
Returns Returns
...@@ -175,7 +175,7 @@ def free_vars(expr): ...@@ -175,7 +175,7 @@ def free_vars(expr):
Parameters Parameters
---------- ----------
expr: tvm.relay.Expr expr : tvm.relay.Expr
The input expression The input expression
Returns Returns
...@@ -197,7 +197,7 @@ def bound_vars(expr): ...@@ -197,7 +197,7 @@ def bound_vars(expr):
Parameters Parameters
---------- ----------
expr: tvm.relay.Expr expr : tvm.relay.Expr
The input expression The input expression
Returns Returns
...@@ -213,7 +213,7 @@ def all_vars(expr): ...@@ -213,7 +213,7 @@ def all_vars(expr):
Parameters Parameters
---------- ----------
expr: tvm.relay.Expr expr : tvm.relay.Expr
The input expression The input expression
Returns Returns
...@@ -229,9 +229,10 @@ def free_type_vars(expr, mod=None): ...@@ -229,9 +229,10 @@ def free_type_vars(expr, mod=None):
Parameters Parameters
---------- ----------
expr: Union[tvm.relay.Expr,tvm.relay.Type] expr : Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type The input expression/type
mod: tvm.relay.Module, optional
mod : Optional[tvm.relay.Module]
The global module The global module
Returns Returns
...@@ -248,9 +249,10 @@ def bound_type_vars(expr, mod=None): ...@@ -248,9 +249,10 @@ def bound_type_vars(expr, mod=None):
Parameters Parameters
---------- ----------
expr: Union[tvm.relay.Expr,tvm.relay.Type] expr : Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type The input expression/type
mod: tvm.relay.Module, optional
mod : Optional[tvm.relay.Module]
The global module The global module
Returns Returns
...@@ -267,9 +269,9 @@ def all_type_vars(expr, mod=None): ...@@ -267,9 +269,9 @@ def all_type_vars(expr, mod=None):
Parameters Parameters
---------- ----------
expr: Union[tvm.relay.Expr,tvm.relay.Type] expr : Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type The input expression/type
mod: tvm.relay.Module, optional mod : Optional[tvm.relay.Module]
The global module The global module
Returns Returns
...@@ -286,12 +288,12 @@ def simplify_inference(expr): ...@@ -286,12 +288,12 @@ def simplify_inference(expr):
Parameters Parameters
---------- ----------
e: tvm.relay.Expr expr : tvm.relay.Expr
The input Expression The input Expression
Returns Returns
------- -------
result: tvm.relay.Expr result : tvm.relay.Expr
An expression which is semantically equal to the input expression, An expression which is semantically equal to the input expression,
but with some simplification but with some simplification
""" """
...@@ -304,32 +306,34 @@ def canonicalize_ops(expr): ...@@ -304,32 +306,34 @@ def canonicalize_ops(expr):
Parameters Parameters
---------- ----------
e: tvm.relay.Expr expr : tvm.relay.Expr
The input Expression The input Expression
Returns Returns
------- -------
result: tvm.relay.Expr result : tvm.relay.Expr
An expression without bias_add An expression without bias_add
""" """
return _ir_pass.canonicalize_ops(expr) return _ir_pass.canonicalize_ops(expr)
def dead_code_elimination(expr): def dead_code_elimination(expr, inline_once=False):
""" Remove expressions which does not effect the program result (dead code). """ Remove expressions which does not effect the program result (dead code).
Parameters Parameters
---------- ----------
e: tvm.relay.Expr expr : tvm.relay.Expr
The input Expression The input Expression
inline_once : Optional[Bool]
Whether to inline binding that occur only once.
Returns Returns
------- -------
result: tvm.relay.Expr result : tvm.relay.Expr
An expression which is semantically equal to the input expression, An expression which is semantically equal to the input expression,
but with dead code removed. but with dead code removed.
""" """
return _ir_pass.dead_code_elimination(expr) return _ir_pass.dead_code_elimination(expr, inline_once)
def alpha_equal(lhs, rhs): def alpha_equal(lhs, rhs):
...@@ -337,15 +341,15 @@ def alpha_equal(lhs, rhs): ...@@ -337,15 +341,15 @@ def alpha_equal(lhs, rhs):
Parameters Parameters
---------- ----------
lhs: tvm.relay.Expr lhs : tvm.relay.Expr
One of the input Expression. One of the input Expression.
rhs: tvm.relay.Expr rhs : tvm.relay.Expr
One of the input Expression. One of the input Expression.
Returns Returns
------- -------
result: bool result : bool
True iff lhs is alpha equal to rhs. True iff lhs is alpha equal to rhs.
""" """
return bool(_make._alpha_equal(lhs, rhs)) return bool(_make._alpha_equal(lhs, rhs))
...@@ -359,15 +363,15 @@ def graph_equal(lhs, rhs): ...@@ -359,15 +363,15 @@ def graph_equal(lhs, rhs):
Parameters Parameters
---------- ----------
lhs: tvm.relay.Expr lhs : tvm.relay.Expr
One of the input Expression. One of the input Expression.
rhs: tvm.relay.Expr rhs : tvm.relay.Expr
One of the input Expression. One of the input Expression.
Returns Returns
------- -------
result: bool result : bool
True iff lhs is data-flow equivalent to rhs. True iff lhs is data-flow equivalent to rhs.
""" """
return bool(_make._graph_equal(lhs, rhs)) return bool(_make._graph_equal(lhs, rhs))
...@@ -378,12 +382,12 @@ def structural_hash(value): ...@@ -378,12 +382,12 @@ def structural_hash(value):
Parameters Parameters
---------- ----------
expr: tvm.relay.Expr or tvm.relay.Type expr : Union[tvm.relay.Expr, tvm.relay.Type]
The expression to hash. The expression to hash.
Returns Returns
------- -------
result: int result : int
The hash value The hash value
""" """
if isinstance(value, Expr): if isinstance(value, Expr):
...@@ -544,12 +548,12 @@ def to_a_normal_form(expr, mod=None): ...@@ -544,12 +548,12 @@ def to_a_normal_form(expr, mod=None):
expr : tvm.relay.Expr expr : tvm.relay.Expr
The input expression. The input expression.
mod: Optional[tvm.relay.Module] mod : Optional[tvm.relay.Module]
The global module. The global module.
Returns Returns
------- -------
expr: tvm.relay.Expr result : tvm.relay.Expr
The output expression. The output expression.
""" """
return _ir_pass.to_a_normal_form(expr, mod) return _ir_pass.to_a_normal_form(expr, mod)
...@@ -563,7 +567,7 @@ def to_graph_normal_form(expr): ...@@ -563,7 +567,7 @@ def to_graph_normal_form(expr):
The input expression The input expression
Returns Returns
------- -------
expr : tvm.relay.Expr result : tvm.relay.Expr
The output expression The output expression
""" """
return _ir_pass.to_graph_normal_form(expr) return _ir_pass.to_graph_normal_form(expr)
...@@ -612,7 +616,7 @@ def get_total_mac_number(expr): ...@@ -612,7 +616,7 @@ def get_total_mac_number(expr):
Returns Returns
------- -------
ret : int64 result : int64
The number of MACs (multiply-accumulate) of a model The number of MACs (multiply-accumulate) of a model
""" """
return _ir_pass.GetTotalMacNumber(expr) return _ir_pass.GetTotalMacNumber(expr)
...@@ -627,17 +631,17 @@ def eliminate_common_subexpr(expr, fskip=None): ...@@ -627,17 +631,17 @@ def eliminate_common_subexpr(expr, fskip=None):
expr : tvm.relay.Expr expr : tvm.relay.Expr
The input expression. The input expression.
fskip: function fskip : function
The callback function that decides whether an expression should be skipped. The callback function that decides whether an expression should be skipped.
Returns Returns
------- -------
expr : tvm.relay.Expr result : tvm.relay.Expr
The output expression. The output expression.
""" """
return _ir_pass.eliminate_common_subexpr(expr, fskip) return _ir_pass.eliminate_common_subexpr(expr, fskip)
def partial_evaluate(expr): def partial_evaluate(expr, mod=None):
""" """
Evaluate the static fragment of the code. Evaluate the static fragment of the code.
...@@ -646,12 +650,15 @@ def partial_evaluate(expr): ...@@ -646,12 +650,15 @@ def partial_evaluate(expr):
expr : tvm.relay.Expr expr : tvm.relay.Expr
The input expression. The input expression.
mod : Optional[tvm.relay.Module]
The global module
Returns Returns
------- -------
expr : tvm.relay.Expr result : tvm.relay.Expr
The output expression. The output expression.
""" """
return _ir_pass.partial_evaluate(expr) return _ir_pass.partial_evaluate(expr, mod)
def unmatched_cases(match, mod=None): def unmatched_cases(match, mod=None):
""" """
......
...@@ -220,8 +220,8 @@ TVM_REGISTER_API("relay._make.Call") ...@@ -220,8 +220,8 @@ TVM_REGISTER_API("relay._make.Call")
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<CallNode>([](const CallNode* node, tvm::IRPrinter* p) { .set_dispatch<CallNode>([](const CallNode* node, tvm::IRPrinter* p) {
p->stream << "CallNode(" << node->op << ", " << node->args << ", " p->stream << "CallNode(" << node->op << ", " << node->args << ", "
<< node->attrs << ", " << node->type_args << ")"; << node->attrs << ", " << node->type_args << ")";
}); });
Let LetNode::make(Var var, Expr value, Expr body) { Let LetNode::make(Var var, Expr value, Expr body) {
...@@ -324,7 +324,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -324,7 +324,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
TVM_REGISTER_API("relay._expr.TempExprRealize") TVM_REGISTER_API("relay._expr.TempExprRealize")
.set_body_typed<Expr(TempExpr)>([](TempExpr temp) { .set_body_typed<Expr(TempExpr)>([](TempExpr temp) {
return temp->Realize(); return temp->Realize();
}); });
} // namespace relay } // namespace relay
......
...@@ -38,10 +38,10 @@ namespace relay { ...@@ -38,10 +38,10 @@ namespace relay {
// calculate the dependency graph from expression // calculate the dependency graph from expression
class CalcDep : private ExprVisitor { class CalcDep : private ExprVisitor {
public: public:
static Expr Eliminate(const Expr& e) { static Expr Eliminate(const Expr& e, bool inline_once) {
CalcDep cd; CalcDep cd;
cd.Calculate(e); cd.Calculate(e);
Eliminator el(cd.expr_map_, cd.use_map_, cd.letrec_set_); Eliminator el(cd.expr_map_, cd.use_map_, cd.letrec_set_, inline_once);
return el(e); return el(e);
} }
...@@ -117,15 +117,23 @@ class CalcDep : private ExprVisitor { ...@@ -117,15 +117,23 @@ class CalcDep : private ExprVisitor {
VarMap<Expr> expr_map_; VarMap<Expr> expr_map_;
VarMap<size_t> use_map_; VarMap<size_t> use_map_;
VarSet letrec_set_; VarSet letrec_set_;
bool inline_once_;
explicit Eliminator(const VarMap<Expr>& expr_map, explicit Eliminator(const VarMap<Expr>& expr_map,
const VarMap<size_t>& use_map, const VarMap<size_t>& use_map,
const VarSet& letrec_set) : const VarSet& letrec_set,
expr_map_(expr_map), use_map_(use_map), letrec_set_(letrec_set) { } bool inline_once) :
expr_map_(expr_map), use_map_(use_map), letrec_set_(letrec_set), inline_once_(inline_once) { }
friend CalcDep; friend CalcDep;
bool HasLet(const Var& v) { bool HasLet(const Var& v) {
// TODO(@jroesch): MK fix me switch (use_map_[v]) {
return (use_map_[v] > 0 || (use_map_[v] != 0 && letrec_set_.count(v) != 0)); case 0:
return false;
case 1:
return letrec_set_.count(v) > 0 || !inline_once_;
default:
return true;
}
} }
Expr VisitExpr_(const VarNode* op) final { Expr VisitExpr_(const VarNode* op) final {
...@@ -144,8 +152,8 @@ class CalcDep : private ExprVisitor { ...@@ -144,8 +152,8 @@ class CalcDep : private ExprVisitor {
}; };
}; };
Expr DeadCodeElimination(const Expr& e) { Expr DeadCodeElimination(const Expr& e, bool inline_once) {
return CalcDep::Eliminate(e); return CalcDep::Eliminate(e, inline_once);
} }
TVM_REGISTER_API("relay._ir_pass.dead_code_elimination") TVM_REGISTER_API("relay._ir_pass.dead_code_elimination")
...@@ -153,10 +161,10 @@ TVM_REGISTER_API("relay._ir_pass.dead_code_elimination") ...@@ -153,10 +161,10 @@ TVM_REGISTER_API("relay._ir_pass.dead_code_elimination")
namespace transform { namespace transform {
Pass DeadCodeElimination() { Pass DeadCodeElimination(bool inline_once) {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func = runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) { [=](Function f, Module m, PassContext pc) {
return Downcast<Function>(DeadCodeElimination(f)); return Downcast<Function>(DeadCodeElimination(f, inline_once));
}; };
return CreateFunctionPass(pass_func, 1, "DeadCodeElimination", {}); return CreateFunctionPass(pass_func, 1, "DeadCodeElimination", {});
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment