Commit 3e527669 by 雾雨魔理沙 Committed by Tianqi Chen

[RELAY][PASS] Dead Code Elimination (#1776)

parent d8394e87
...@@ -80,7 +80,7 @@ bool AlphaEqual(const Expr& e1, const Expr& e2); ...@@ -80,7 +80,7 @@ bool AlphaEqual(const Expr& e1, const Expr& e2);
*/ */
bool AlphaEqual(const Type& t1, const Type& t2); bool AlphaEqual(const Type& t1, const Type& t2);
/*! brief Check that each Var is only bind once. /*! \brief Check that each Var is only bound once.
* *
* For example, the expression `let x = 1 in let x = 2 in 3` bound x twice. * For example, the expression `let x = 1 in let x = 2 in 3` bound x twice.
* *
...@@ -88,9 +88,9 @@ bool AlphaEqual(const Type& t1, const Type& t2); ...@@ -88,9 +88,9 @@ bool AlphaEqual(const Type& t1, const Type& t2);
* *
* \param e the expression to check. * \param e the expression to check.
* *
* \return true iff all Var in e is bind at most once. * \return true iff all Var in e is bound at most once.
*/ */
bool WellFormed(const Expr & e); bool WellFormed(const Expr& e);
/*! \brief Get free variables from expression e. /*! \brief Get free variables from expression e.
* *
...@@ -100,7 +100,7 @@ bool WellFormed(const Expr & e); ...@@ -100,7 +100,7 @@ bool WellFormed(const Expr & e);
* *
* \return the set of free variable. * \return the set of free variable.
*/ */
tvm::Array<Var> FreeVariables(const Expr & e); tvm::Array<Var> FreeVariables(const Expr& e);
/*! \brief Get free type parameters from expression e. /*! \brief Get free type parameters from expression e.
* *
...@@ -110,7 +110,7 @@ tvm::Array<Var> FreeVariables(const Expr & e); ...@@ -110,7 +110,7 @@ tvm::Array<Var> FreeVariables(const Expr & e);
* *
* \return the set of free type variables. * \return the set of free type variables.
*/ */
tvm::Array<TypeParam> FreeTypeVariables(const Expr & e); tvm::Array<TypeParam> FreeTypeVariables(const Expr& e);
/*! \brief Get free type parameters from type t. /*! \brief Get free type parameters from type t.
* *
...@@ -120,7 +120,20 @@ tvm::Array<TypeParam> FreeTypeVariables(const Expr & e); ...@@ -120,7 +120,20 @@ tvm::Array<TypeParam> FreeTypeVariables(const Expr & e);
* *
* \return the set of free type variables. * \return the set of free type variables.
*/ */
tvm::Array<TypeParam> FreeTypeVariables(const Type & t); tvm::Array<TypeParam> FreeTypeVariables(const Type& t);
/*! \brief Remove expressions which does not effect the program result.
*
* It will remove let binding that are not referenced, and if branch that are not entered.
*
* For example, this pass should turn `let a = 1 in 2` into `2`, as the value of the expression does not depend on a.
* Another example is `if (true) then 1 else 2` will be optimized into 1.
*
* \param e the expression to optimize.
*
* \return the optimized expression.
*/
Expr DeadCodeElimination(const Expr& e);
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
......
...@@ -282,6 +282,21 @@ inline void NDArray::reset() { ...@@ -282,6 +282,21 @@ inline void NDArray::reset() {
} }
} }
/*! \brief return the size of data the DLTensor hold, in term of number of bytes
*
* \param arr the input DLTensor
*
* \return number of bytes of data in the DLTensor.
*/
inline size_t GetDataSize(const DLTensor& arr) {
size_t size = 1;
for (tvm_index_t i = 0; i < arr.ndim; ++i) {
size *= static_cast<size_t>(arr.shape[i]);
}
size *= (arr.dtype.bits * arr.dtype.lanes + 7) / 8;
return size;
}
inline void NDArray::CopyFrom(DLTensor* other) { inline void NDArray::CopyFrom(DLTensor* other) {
CHECK(data_ != nullptr); CHECK(data_ != nullptr);
CopyFromTo(other, &(data_->dl_tensor)); CopyFromTo(other, &(data_->dl_tensor));
......
...@@ -5,3 +5,4 @@ def check_expr(env: Environment, expr: ir.Expr) -> ir.Type: ... ...@@ -5,3 +5,4 @@ def check_expr(env: Environment, expr: ir.Expr) -> ir.Type: ...
def generalize(env: Environment, expr: ir.Expr) -> ir.Expr: ... def generalize(env: Environment, expr: ir.Expr) -> ir.Expr: ...
def _get_checked_type(expr: ir.Expr) -> ir.Type: ... def _get_checked_type(expr: ir.Expr) -> ir.Type: ...
def well_formed(expr: ir.Expr) -> bool: ... def well_formed(expr: ir.Expr) -> bool: ...
def dead_code_elimination(expr: ir.Expr) -> ir.Expr: ...
\ No newline at end of file
...@@ -16,12 +16,12 @@ def _convert_to_value(arg, ctxt=tvm.cpu(0)): ...@@ -16,12 +16,12 @@ def _convert_to_value(arg, ctxt=tvm.cpu(0)):
"""Convert Python values into the appropriate types """Convert Python values into the appropriate types
for the Relay evaluator. for the Relay evaluator.
""" """
if isinstance(arg, int): if isinstance(arg, bool): # bool is subclass of int
return tvm.nd.array(np.array(arg, dtype='uint8'), ctxt)
elif isinstance(arg, int):
return tvm.nd.array(np.array(arg, dtype='int32'), ctxt) return tvm.nd.array(np.array(arg, dtype='int32'), ctxt)
elif isinstance(arg, float): elif isinstance(arg, float):
return tvm.nd.array(arg, ctxt) return tvm.nd.array(arg, ctxt)
elif isinstance(arg, bool):
return tvm.nd.array(np.array(arg, dtype='float32'), ctxt)
elif isinstance(arg, np.ndarray): elif isinstance(arg, np.ndarray):
return tvm.nd.array(arg, ctxt) return tvm.nd.array(arg, ctxt)
elif isinstance(arg, tvm.ndarray.NDArray): elif isinstance(arg, tvm.ndarray.NDArray):
......
...@@ -6,15 +6,16 @@ Exposes an interface for configuring the passes and scripting ...@@ -6,15 +6,16 @@ Exposes an interface for configuring the passes and scripting
them in Python. them in Python.
""" """
from . import _ir_pass from . import _ir_pass
from . import _make
# pylint: disable=invalid-name # pylint: disable=invalid-name
def infer_type(env, expr): def infer_type(env, expr):
"""Infer the type of expr under the context of env """Infer the type of expr under the context of env.
Parameters Parameters
---------- ----------
env : relay.Environment env : relay.Environment
The global environmemt. The global environment.
expr : relay.Expr expr : relay.Expr
The input expression. The input expression.
...@@ -34,3 +35,37 @@ check_kind = _ir_pass.check_kind ...@@ -34,3 +35,37 @@ check_kind = _ir_pass.check_kind
free_vars = _ir_pass.free_vars free_vars = _ir_pass.free_vars
free_type_vars = _ir_pass.free_type_vars free_type_vars = _ir_pass.free_type_vars
def dead_code_elimination(e):
""" Remove expressions which does not effect the program result (dead code).
Parameters
----------
e: relay.Expr
The input Expression
Returns
-------
result: relay.Expr
An expression which is semantically equal to the input expression,
but with dead code removed.
"""
return _ir_pass.dead_code_elimination(e)
def alpha_equal(lhs, rhs):
"""Compare two Relay expr for structural equivalence (alpha equivalence).
Parameters
----------
lhs: relay.Expr
One of the input Expression.
rhs: relay.Expr
One of the input Expression.
Returns
-------
result: bool
True iff lhs is alpha equal to rhs.
"""
return bool(_make._alpha_equal(lhs, rhs))
...@@ -12,7 +12,7 @@ class Type(NodeBase): ...@@ -12,7 +12,7 @@ class Type(NodeBase):
"""Compare two Relay types for structural equivalence using """Compare two Relay types for structural equivalence using
alpha equivalence. alpha equivalence.
""" """
return bool(_make._type_alpha_eq(self, other)) return bool(_make._type_alpha_equal(self, other))
def __ne__(self, other): def __ne__(self, other):
return not self.__eq__(other) return not self.__eq__(other)
......
/*! /*!
* Copyright (c) 2018 by Contributors * Copyright (c) 2018 by Contributors
* \file src/tvm/relay/pass/alpha_eq.cc * \file src/tvm/relay/pass/alpha_eq.cc
* \brief The structral equivalence comparison. * \brief Check that two type are syntactically equal up to alpha equivalence.
*/ */
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/runtime/ndarray.h>
#include "./type_visitor.h" #include "./type_visitor.h"
#include "tvm/relay/pass.h" #include "tvm/relay/pass.h"
...@@ -13,6 +14,25 @@ namespace relay { ...@@ -13,6 +14,25 @@ namespace relay {
using namespace tvm::runtime; using namespace tvm::runtime;
bool SameNDArray(const NDArray& lhs, const NDArray& rhs) {
if (lhs.defined() != rhs.defined()) {
return false;
} else if (lhs.same_as(rhs)) {
return true;
} else {
auto ldt = lhs->dtype;
auto rdt = rhs->dtype;
CHECK_EQ(lhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor";
CHECK_EQ(rhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor";
if (ldt.code == rdt.code && ldt.lanes == rdt.lanes && ldt.bits == rdt.bits) {
size_t s = GetDataSize(*lhs.operator->());
return memcmp(lhs->data, rhs->data, s) == 0;
} else {
return false;
}
}
}
struct TypeAlphaEq : TypeVisitor<const Type&> { struct TypeAlphaEq : TypeVisitor<const Type&> {
tvm::Map<TypeParam, TypeParam> eq_map; tvm::Map<TypeParam, TypeParam> eq_map;
bool equal; bool equal;
...@@ -38,8 +58,8 @@ struct TypeAlphaEq : TypeVisitor<const Type&> { ...@@ -38,8 +58,8 @@ struct TypeAlphaEq : TypeVisitor<const Type&> {
} }
} }
void VisitType_(const TensorTypeNode *tt1, const Type& t2) final { void VisitType_(const TensorTypeNode* tt1, const Type& t2) final {
if (const TensorTypeNode *tt2 = t2.as<TensorTypeNode>()) { if (const TensorTypeNode* tt2 = t2.as<TensorTypeNode>()) {
DataTypeEqual(tt1->dtype, tt2->dtype); DataTypeEqual(tt1->dtype, tt2->dtype);
ShapeEqual(tt1->shape, tt2->shape); ShapeEqual(tt1->shape, tt2->shape);
} else { } else {
...@@ -47,8 +67,8 @@ struct TypeAlphaEq : TypeVisitor<const Type&> { ...@@ -47,8 +67,8 @@ struct TypeAlphaEq : TypeVisitor<const Type&> {
} }
} }
void VisitType_(const IncompleteTypeNode *bt1, const Type& t2) final { void VisitType_(const IncompleteTypeNode* bt1, const Type& t2) final {
if (const IncompleteTypeNode *bt2 = t2.as<IncompleteTypeNode>()) { if (const IncompleteTypeNode* bt2 = t2.as<IncompleteTypeNode>()) {
equal = equal && bt1 == bt2; equal = equal && bt1 == bt2;
return; return;
} else { } else {
...@@ -56,8 +76,8 @@ struct TypeAlphaEq : TypeVisitor<const Type&> { ...@@ -56,8 +76,8 @@ struct TypeAlphaEq : TypeVisitor<const Type&> {
} }
} }
void VisitType_(const TypeParamNode *ti1, const Type& t2) final { void VisitType_(const TypeParamNode* ti1, const Type& t2) final {
if (const TypeParamNode *ti2 = t2.as<TypeParamNode>()) { if (const TypeParamNode* ti2 = t2.as<TypeParamNode>()) {
auto tid1 = GetRef<TypeParam>(ti1); auto tid1 = GetRef<TypeParam>(ti1);
auto tid2 = GetRef<TypeParam>(ti2); auto tid2 = GetRef<TypeParam>(ti2);
...@@ -86,8 +106,8 @@ struct TypeAlphaEq : TypeVisitor<const Type&> { ...@@ -86,8 +106,8 @@ struct TypeAlphaEq : TypeVisitor<const Type&> {
} }
} }
void VisitType_(const FuncTypeNode *op, const Type& t2) final { void VisitType_(const FuncTypeNode* op, const Type& t2) final {
if (const FuncTypeNode *ta2 = t2.as<FuncTypeNode>()) { if (const FuncTypeNode* ta2 = t2.as<FuncTypeNode>()) {
if (op->arg_types.size() != ta2->arg_types.size() if (op->arg_types.size() != ta2->arg_types.size()
|| op->type_params.size() != ta2->type_params.size() || op->type_params.size() != ta2->type_params.size()
|| op->type_constraints.size() != ta2->type_constraints.size()) { || op->type_constraints.size() != ta2->type_constraints.size()) {
...@@ -128,8 +148,8 @@ struct TypeAlphaEq : TypeVisitor<const Type&> { ...@@ -128,8 +148,8 @@ struct TypeAlphaEq : TypeVisitor<const Type&> {
} }
} }
void VisitType_(const TypeRelationNode *tr1, const Type& t2) final { void VisitType_(const TypeRelationNode* tr1, const Type& t2) final {
if (const TypeRelationNode *tr2 = t2.as<TypeRelationNode>()) { if (const TypeRelationNode* tr2 = t2.as<TypeRelationNode>()) {
if (tr1->func != tr2->func if (tr1->func != tr2->func
|| tr1->num_inputs != tr2->num_inputs || tr1->num_inputs != tr2->num_inputs
|| tr1->attrs != tr2->attrs) { || tr1->attrs != tr2->attrs) {
...@@ -153,8 +173,8 @@ struct TypeAlphaEq : TypeVisitor<const Type&> { ...@@ -153,8 +173,8 @@ struct TypeAlphaEq : TypeVisitor<const Type&> {
} }
} }
void VisitType_(const TupleTypeNode *op, const Type& t2) final { void VisitType_(const TupleTypeNode* op, const Type& t2) final {
if (const TupleTypeNode *pt = t2.as<TupleTypeNode>()) { if (const TupleTypeNode* pt = t2.as<TupleTypeNode>()) {
if (op->fields.size() != pt->fields.size()) { if (op->fields.size() != pt->fields.size()) {
equal = false; equal = false;
return; return;
...@@ -185,8 +205,8 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> { ...@@ -185,8 +205,8 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
bool equal; bool equal;
AlphaEq() : eq_map(), equal(true) {} AlphaEq() : eq_map(), equal(true) {}
void VisitExpr_(const VarNode *e1, const Expr& e2) final { void VisitExpr_(const VarNode* e1, const Expr& e2) final {
if (const VarNode *id2 = e2.as<VarNode>()) { if (const VarNode* id2 = e2.as<VarNode>()) {
auto local1 = GetRef<Var>(e1); auto local1 = GetRef<Var>(e1);
auto local2 = GetRef<Var>(id2); auto local2 = GetRef<Var>(id2);
// We handle open terms with this rule assuming variables are identical. // We handle open terms with this rule assuming variables are identical.
...@@ -207,17 +227,17 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> { ...@@ -207,17 +227,17 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
} }
} }
void VisitExpr_(const GlobalVarNode *g1, const Expr& e2) final { void VisitExpr_(const GlobalVarNode* g1, const Expr& e2) final {
if (const GlobalVarNode *g2 = e2.as<GlobalVarNode>()) { if (const GlobalVarNode* g2 = e2.as<GlobalVarNode>()) {
equal = equal && g1 == g2; equal = equal && g1 == g2;
} else { } else {
equal = false; equal = false;
} }
} }
void VisitExpr_(const TupleNode *pl1, const Expr& e2) final { void VisitExpr_(const TupleNode* pl1, const Expr& e2) final {
Tuple prod1 = GetRef<Tuple>(pl1); Tuple prod1 = GetRef<Tuple>(pl1);
if (const TupleNode *pl2 = e2.as<TupleNode>()) { if (const TupleNode* pl2 = e2.as<TupleNode>()) {
Tuple prod2 = GetRef<Tuple>(pl2); Tuple prod2 = GetRef<Tuple>(pl2);
if (prod1->fields.size() != prod2->fields.size()) { if (prod1->fields.size() != prod2->fields.size()) {
equal = false; equal = false;
...@@ -232,8 +252,8 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> { ...@@ -232,8 +252,8 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
} }
} }
void VisitExpr_(const ParamNode *p1, const Expr& e2) final { void VisitExpr_(const ParamNode* p1, const Expr& e2) final {
if (const ParamNode *p2 = e2.as<ParamNode>()) { if (const ParamNode* p2 = e2.as<ParamNode>()) {
eq_map.Set(p1->var, p2->var); eq_map.Set(p1->var, p2->var);
equal = equal && AlphaEqual(p1->type, p2->type); equal = equal && AlphaEqual(p1->type, p2->type);
} else { } else {
...@@ -241,8 +261,8 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> { ...@@ -241,8 +261,8 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
} }
} }
void VisitExpr_(const FunctionNode *func1, const Expr& e2) final { void VisitExpr_(const FunctionNode* func1, const Expr& e2) final {
if (const FunctionNode *func2 = e2.as<FunctionNode>()) { if (const FunctionNode* func2 = e2.as<FunctionNode>()) {
if (func1->params.size() != func2->params.size()) { if (func1->params.size() != func2->params.size()) {
equal = false; equal = false;
return; return;
...@@ -258,8 +278,8 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> { ...@@ -258,8 +278,8 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
} }
} }
void VisitExpr_(const CallNode *op, const Expr& e2) final { void VisitExpr_(const CallNode* op, const Expr& e2) final {
if (const CallNode *call = e2.as<CallNode>()) { if (const CallNode* call = e2.as<CallNode>()) {
this->VisitExpr(op->op, call->op); this->VisitExpr(op->op, call->op);
if (op->args.size() != call->args.size()) { if (op->args.size() != call->args.size()) {
...@@ -276,8 +296,8 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> { ...@@ -276,8 +296,8 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
} }
} }
void VisitExpr_(const LetNode *op, const Expr& e2) final { void VisitExpr_(const LetNode* op, const Expr& e2) final {
if (const LetNode *let = e2.as<LetNode>()) { if (const LetNode* let = e2.as<LetNode>()) {
eq_map.Set(op->var, let->var); eq_map.Set(op->var, let->var);
this->VisitExpr(op->value, let->value); this->VisitExpr(op->value, let->value);
this->VisitExpr(op->body, let->body); this->VisitExpr(op->body, let->body);
...@@ -285,6 +305,36 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> { ...@@ -285,6 +305,36 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
equal = false; equal = false;
} }
} }
void VisitExpr_(const IfNode* op, const Expr& e2) final {
if (const IfNode* i = e2.as<IfNode>()) {
VisitExpr(op->cond, i->cond);
VisitExpr(op->true_branch, i->true_branch);
VisitExpr(op->false_branch, i->false_branch);
} else {
equal = false;
}
}
void VisitExpr_(const OpNode* op, const Expr& e2) final {
if (const OpNode* o = e2.as<OpNode>()) {
equal = equal && op->name == o->name;
} else {
equal = false;
}
}
void VisitExpr_(const ConstantNode* op, const Expr& e2) final {
if (const ConstantNode* c = e2.as<ConstantNode>()) {
if (AlphaEqual(op->tensor_type(), c->tensor_type())) {
equal = equal && SameNDArray(op->data, c->data);
} else {
equal = false;
}
} else {
equal = false;
}
}
}; };
bool AlphaEqual(const Expr& e1, const Expr& e2) { bool AlphaEqual(const Expr& e1, const Expr& e2) {
...@@ -294,15 +344,15 @@ bool AlphaEqual(const Expr& e1, const Expr& e2) { ...@@ -294,15 +344,15 @@ bool AlphaEqual(const Expr& e1, const Expr& e2) {
} }
// TODO(@jroesch): move to correct namespace? // TODO(@jroesch): move to correct namespace?
TVM_REGISTER_API("relay._make._alpha_eq") TVM_REGISTER_API("relay._make._alpha_equal")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
Expr e1 = args[0]; Expr e1 = args[0];
Expr e2 = args[1]; Expr e2 = args[1];
*ret = AlphaEqual(e1, e2); *ret = AlphaEqual(e1, e2);
}); });
TVM_REGISTER_API("relay._make._type_alpha_eq") TVM_REGISTER_API("relay._make._type_alpha_equal")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
Type t1 = args[0]; Type t1 = args[0];
Type t2 = args[1]; Type t2 = args[1];
*ret = AlphaEqual(t1, t2); *ret = AlphaEqual(t1, t2);
......
/*!
* Copyright (c) 2018 by Contributors
*
* \file dead_code.cc
*
* \brief Remove code that does not effect the program result.
*
* The algorithm is implemented by two visitor:
* CalcDep turn an expr into a dependency graph of expr,
* GenLet turn the dependency graph into a let list, taking only the used value.
*/
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include "let_list.h"
namespace tvm {
namespace relay {
bool IsBoolLit(const Expr& e, bool b) {
if (const ConstantNode* c = e.as<ConstantNode>()) {
if (c->is_scalar()) {
auto dt = c->tensor_type()->dtype;
if (dt == UInt(8)) {
return *reinterpret_cast<const uint8_t*>(c->data->data) == b;
} else if (dt == UInt(16)) {
return *reinterpret_cast<const uint16_t*>(c->data->data) == b;
} else if (dt == UInt(32)) {
return *reinterpret_cast<const uint32_t*>(c->data->data) == b;
} else if (dt == UInt(64)) {
return *reinterpret_cast<const uint64_t*>(c->data->data) == b;
} else if (dt == Int(8)) {
return *reinterpret_cast<const int8_t*>(c->data->data) == b;
} else if (dt == Int(16)) {
return *reinterpret_cast<const int16_t*>(c->data->data) == b;
} else if (dt == Int(32)) {
return *reinterpret_cast<const int32_t*>(c->data->data) == b;
} else if (dt == Int(64)) {
return *reinterpret_cast<const int64_t*>(c->data->data) == b;
}
}
}
return false;
}
// calculate the dependency graph from expression
class CalcDep : private ExprMutator {
public:
static Expr Eliminate(const Expr& e) {
CalcDep cd;
auto res = cd(e);
GenLet gl(cd.var_map_);
gl(res);
return gl.lets_.Get(res);
}
private:
struct Binder {
Type t;
Expr e;
Binder(const Type& t, const Expr& e) : t(t), e(e) { }
};
using VarMap = std::unordered_map<Var, Binder, NodeHash, NodeEqual>;
VarMap var_map_;
Expr VisitExpr_(const IfNode* i) final {
auto cond = VisitExpr(i->cond);
if (IsBoolLit(cond, true)) {
return Eliminate(i->true_branch);
} else if (IsBoolLit(cond, false)) {
return Eliminate(i->false_branch);
} else {
return IfNode::make(cond, Eliminate(i->true_branch), Eliminate(i->false_branch));
}
}
Expr VisitExpr_(const LetNode* l) final {
var_map_.insert(std::pair<Var, Binder>(l->var,
Binder(l->value_type,
Eliminate(l->value))));
return VisitExpr(l->body);
}
Expr VisitExpr_(const FunctionNode* f) final {
return FunctionNode::make(f->params, f->ret_type, Eliminate(f->body), f->type_params);
}
// generate the let list from dependency graph
class GenLet : private ExprVisitor {
private:
LetList lets_;
VarMap var_map_;
explicit GenLet(const VarMap& var_map) : var_map_(var_map) { }
friend CalcDep;
void VisitExpr_(const VarNode* vn) final {
Var v = GetRef<Var>(vn);
if (var_map_.count(v) != 0) {
auto val = var_map_.at(v);
var_map_.erase(v);
// erase before visit to handle letrec
VisitExpr(val.e);
// visit before push back so the dependency of dependency is before the dependency
lets_.Push(v, val.t, val.e);
}
}
};
};
Expr DeadCodeElimination(const Expr& e) {
return CalcDep::Eliminate(e);
}
TVM_REGISTER_API("relay._ir_pass.dead_code_elimination")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = DeadCodeElimination(args[0]);
});
} // namespace relay
} // namespace tvm
...@@ -25,15 +25,6 @@ inline void VerifyDataType(DLDataType dtype) { ...@@ -25,15 +25,6 @@ inline void VerifyDataType(DLDataType dtype) {
CHECK_EQ(dtype.bits & (dtype.bits - 1), 0); CHECK_EQ(dtype.bits & (dtype.bits - 1), 0);
} }
inline size_t GetDataSize(const DLTensor& arr) {
size_t size = 1;
for (tvm_index_t i = 0; i < arr.ndim; ++i) {
size *= arr.shape[i];
}
size *= (arr.dtype.bits * arr.dtype.lanes + 7) / 8;
return size;
}
inline size_t GetDataAlignment(const DLTensor& arr) { inline size_t GetDataAlignment(const DLTensor& arr) {
size_t align = (arr.dtype.bits / 8) * arr.dtype.lanes; size_t align = (arr.dtype.bits / 8) * arr.dtype.lanes;
if (align < kAllocAlignment) return kAllocAlignment; if (align < kAllocAlignment) return kAllocAlignment;
......
import tvm
from tvm import relay
from tvm.relay.ir_pass import dead_code_elimination, alpha_equal
from tvm.relay.ir_builder import convert, IRBuilder
from tvm.relay.op import log, add, equal, subtract, concat
class env:
def __init__(self):
self.a = relay.Var("a")
self.b = relay.Var("b")
self.c = relay.Var("c")
self.d = relay.Var("d")
self.e = relay.Var("e")
self.x = relay.Var("x")
self.y = relay.Var("y")
self.z = relay.Var("z")
self.shape = tvm.convert([1, 2, 3])
self.tt = relay.TensorType(self.shape, "float32")
self.int32 = relay.TensorType([], "int32")
self.float32 = relay.TensorType([], "float32")
self.one = convert(1.0)
self.two = convert(2.0)
self.three = convert(3.0)
e = env()
def test_let():
orig = relay.Let(e.x, e.y, e.z, e.tt)
assert alpha_equal(dead_code_elimination(orig), e.z)
def test_used_let():
orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c, e.tt), e.tt)
assert alpha_equal(dead_code_elimination(orig), relay.Let(e.c, e.d, e.c, e.tt))
def test_chain_unused_let():
orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.e, e.tt), e.tt)
assert alpha_equal(dead_code_elimination(orig), e.e)
# make sure we dont infinite loop
def test_recursion():
"""
Program:
let f(n: i32, data: f32) -> f32 = {
if (n == 0) {
return data;
} else {
return f(n - 1, log(data));
}
}
f(2, 10000);
"""
f = relay.Var("f")
n = relay.Var("n")
np = relay.Param(n, e.int32)
data = relay.Var("data")
datap = relay.Param(data, e.float32)
funcbody = relay.If(equal(n, convert(0)), data, f(subtract(n, convert(1.0)), log(data)))
value = relay.Function([np, datap], e.float32, funcbody, [])
orig = relay.Let(f, funcbody, f(convert(2.0), convert(10000.0)), e.float32)
assert alpha_equal(dead_code_elimination(orig), orig)
assert alpha_equal(dead_code_elimination(relay.Let(f, funcbody, e.three, e.float32)), e.three)
def test_op_let():
assert alpha_equal(dead_code_elimination(add(relay.Let(e.a, e.one, e.three, e.float32), e.two)), add(e.three, e.two))
def test_if():
orig = relay.If(convert(True), e.a, e.b)
assert alpha_equal(dead_code_elimination(orig), e.a)
if __name__ == "__main__":
test_let()
test_used_let()
test_chain_unused_let()
test_recursion()
test_op_let()
test_if()
import tvm import tvm
from tvm import relay from tvm import relay
from tvm.relay.ir_pass import alpha_equal
from tvm.relay.ir_builder import convert
def test_tensor_type_alpha_equal():
def test_tensor_type_alpha_eq():
t1 = relay.TensorType((3, 4), "float32") t1 = relay.TensorType((3, 4), "float32")
t2 = relay.TensorType((3, 4), "float32") t2 = relay.TensorType((3, 4), "float32")
t3 = relay.TensorType((3, 4, 5), "float32") t3 = relay.TensorType((3, 4, 5), "float32")
...@@ -13,8 +14,14 @@ def test_tensor_type_alpha_eq(): ...@@ -13,8 +14,14 @@ def test_tensor_type_alpha_eq():
t2 = relay.TensorType((), "float32") t2 = relay.TensorType((), "float32")
assert t1 == t2 assert t1 == t2
def test_constant_alpha_equal():
x = convert(1)
y = convert(2)
assert alpha_equal(x, x)
assert not alpha_equal(x, y)
assert alpha_equal(x, convert(1))
def test_incomplete_type_alpha_eq(): def test_incomplete_type_alpha_equal():
t1 = relay.IncompleteType(relay.Kind.Shape) t1 = relay.IncompleteType(relay.Kind.Shape)
t2 = relay.IncompleteType(relay.Kind.Type) t2 = relay.IncompleteType(relay.Kind.Type)
t3 = relay.IncompleteType(relay.Kind.Type) t3 = relay.IncompleteType(relay.Kind.Type)
...@@ -26,7 +33,7 @@ def test_incomplete_type_alpha_eq(): ...@@ -26,7 +33,7 @@ def test_incomplete_type_alpha_eq():
assert t2 != t3 assert t2 != t3
def test_type_param_alpha_eq(): def test_type_param_alpha_equal():
t1 = relay.TypeParam("v1", relay.Kind.Type) t1 = relay.TypeParam("v1", relay.Kind.Type)
t2 = relay.TypeParam("v2", relay.Kind.Shape) t2 = relay.TypeParam("v2", relay.Kind.Shape)
t3 = relay.TypeParam("v3", relay.Kind.Type) t3 = relay.TypeParam("v3", relay.Kind.Type)
...@@ -48,7 +55,7 @@ def test_type_param_alpha_eq(): ...@@ -48,7 +55,7 @@ def test_type_param_alpha_eq():
assert ft1 != ft3 # kinds still do not match assert ft1 != ft3 # kinds still do not match
def test_func_type_alpha_eq(): def test_func_type_alpha_equal():
t1 = relay.TensorType((1, 2), "float32") t1 = relay.TensorType((1, 2), "float32")
t2 = relay.TensorType((1, 2, 3), "float32") t2 = relay.TensorType((1, 2, 3), "float32")
...@@ -108,7 +115,7 @@ def test_func_type_alpha_eq(): ...@@ -108,7 +115,7 @@ def test_func_type_alpha_eq():
assert ft != more_rels assert ft != more_rels
def test_tuple_type_alpha_eq(): def test_tuple_type_alpha_equal():
t1 = relay.TensorType((1, 2, 3), "float32") t1 = relay.TensorType((1, 2, 3), "float32")
t2 = relay.TensorType((1, 2, 3, 4), "float32") t2 = relay.TensorType((1, 2, 3, 4), "float32")
tp1 = relay.TypeParam("v1", relay.Kind.Type) tp1 = relay.TypeParam("v1", relay.Kind.Type)
...@@ -126,7 +133,7 @@ def test_tuple_type_alpha_eq(): ...@@ -126,7 +133,7 @@ def test_tuple_type_alpha_eq():
assert tup1 != tup4 assert tup1 != tup4
def test_type_relation_alpha_eq(): def test_type_relation_alpha_equal():
t1 = relay.TensorType((1, 2), "float32") t1 = relay.TensorType((1, 2), "float32")
t2 = relay.TensorType((1, 2, 3), "float32") t2 = relay.TensorType((1, 2, 3), "float32")
t3 = relay.TensorType((1, 2, 3, 4), "float32") t3 = relay.TensorType((1, 2, 3, 4), "float32")
...@@ -162,9 +169,9 @@ def test_type_relation_alpha_eq(): ...@@ -162,9 +169,9 @@ def test_type_relation_alpha_eq():
if __name__ == "__main__": if __name__ == "__main__":
test_tensor_type_alpha_eq() test_tensor_type_alpha_equal()
test_incomplete_type_alpha_eq() test_incomplete_type_alpha_equal()
test_type_param_alpha_eq() test_type_param_alpha_equal()
test_func_type_alpha_eq() test_func_type_alpha_equal()
test_tuple_type_alpha_eq() test_tuple_type_alpha_equal()
test_type_relation_alpha_eq() test_type_relation_alpha_equal()
...@@ -120,9 +120,9 @@ def test_recursion(): ...@@ -120,9 +120,9 @@ def test_recursion():
Program: Program:
def f(n: i32, data: f32) -> f32 { def f(n: i32, data: f32) -> f32 {
if (n == 0) { if (n == 0) {
return f(n - 1, log(data));
} else {
return data; return data;
} else {
return f(n - 1, log(data));
} }
} }
f(2, 10000); f(2, 10000);
...@@ -133,9 +133,9 @@ def test_recursion(): ...@@ -133,9 +133,9 @@ def test_recursion():
data = b.param('data', ty='float32') data = b.param('data', ty='float32')
with b.decl(f, n, data): with b.decl(f, n, data):
with b.if_scope(equal(n, convert(0))): with b.if_scope(equal(n, convert(0))):
b.ret(f(subtract(n, convert(1)), log(data)))
with b.else_scope():
b.ret(data) b.ret(data)
with b.else_scope():
b.ret(f(subtract(n, convert(1)), log(data)))
b.ret(f(convert(2.0), convert(10000.0))) b.ret(f(convert(2.0), convert(10000.0)))
assert_decl_has_type(b.env, 'f', func_type( assert_decl_has_type(b.env, 'f', func_type(
['int32', 'float32'], 'float32')) ['int32', 'float32'], 'float32'))
...@@ -160,11 +160,11 @@ def test_concat(): ...@@ -160,11 +160,11 @@ def test_concat():
if __name__ == "__main__": if __name__ == "__main__":
test_dual_op() test_dual_op()
test_recursion() test_recursion()
test_monomorphic_let() test_monomorphic_let()
test_single_op() test_single_op()
test_add_op() test_add_op()
test_add_broadcast_op() test_add_broadcast_op()
test_decl() test_decl()
test_recursion()
test_concat() test_concat()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment