Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
3e527669
Commit
3e527669
authored
Oct 06, 2018
by
雾雨魔理沙
Committed by
Tianqi Chen
Oct 06, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RELAY][PASS] Dead Code Elimination (#1776)
parent
d8394e87
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
382 additions
and
74 deletions
+382
-74
include/tvm/relay/pass.h
+19
-6
include/tvm/runtime/ndarray.h
+15
-0
python/tvm/relay/_ir_pass.pyi
+3
-2
python/tvm/relay/ir_builder.py
+3
-3
python/tvm/relay/ir_pass.py
+37
-2
python/tvm/relay/ty.py
+1
-1
src/relay/pass/alpha_eq.cc
+81
-31
src/relay/pass/dead_code.cc
+119
-0
src/runtime/ndarray.cc
+2
-11
tests/python/relay/test_dead_code_elimination.py
+77
-0
tests/python/relay/test_pass_alpha_equal.py
+20
-13
tests/python/relay/test_type_infer.py
+5
-5
No files found.
include/tvm/relay/pass.h
View file @
3e527669
...
...
@@ -80,7 +80,7 @@ bool AlphaEqual(const Expr& e1, const Expr& e2);
*/
bool
AlphaEqual
(
const
Type
&
t1
,
const
Type
&
t2
);
/*!
brief Check that each Var is only bi
nd once.
/*!
\brief Check that each Var is only bou
nd once.
*
* For example, the expression `let x = 1 in let x = 2 in 3` bound x twice.
*
...
...
@@ -88,9 +88,9 @@ bool AlphaEqual(const Type& t1, const Type& t2);
*
* \param e the expression to check.
*
* \return true iff all Var in e is b
i
nd at most once.
* \return true iff all Var in e is b
ou
nd at most once.
*/
bool
WellFormed
(
const
Expr
&
e
);
bool
WellFormed
(
const
Expr
&
e
);
/*! \brief Get free variables from expression e.
*
...
...
@@ -100,7 +100,7 @@ bool WellFormed(const Expr & e);
*
* \return the set of free variable.
*/
tvm
::
Array
<
Var
>
FreeVariables
(
const
Expr
&
e
);
tvm
::
Array
<
Var
>
FreeVariables
(
const
Expr
&
e
);
/*! \brief Get free type parameters from expression e.
*
...
...
@@ -110,7 +110,7 @@ tvm::Array<Var> FreeVariables(const Expr & e);
*
* \return the set of free type variables.
*/
tvm
::
Array
<
TypeParam
>
FreeTypeVariables
(
const
Expr
&
e
);
tvm
::
Array
<
TypeParam
>
FreeTypeVariables
(
const
Expr
&
e
);
/*! \brief Get free type parameters from type t.
*
...
...
@@ -120,7 +120,20 @@ tvm::Array<TypeParam> FreeTypeVariables(const Expr & e);
*
* \return the set of free type variables.
*/
tvm
::
Array
<
TypeParam
>
FreeTypeVariables
(
const
Type
&
t
);
tvm
::
Array
<
TypeParam
>
FreeTypeVariables
(
const
Type
&
t
);
/*! \brief Remove expressions which does not effect the program result.
*
* It will remove let binding that are not referenced, and if branch that are not entered.
*
* For example, this pass should turn `let a = 1 in 2` into `2`, as the value of the expression does not depend on a.
* Another example is `if (true) then 1 else 2` will be optimized into 1.
*
* \param e the expression to optimize.
*
* \return the optimized expression.
*/
Expr
DeadCodeElimination
(
const
Expr
&
e
);
}
// namespace relay
}
// namespace tvm
...
...
include/tvm/runtime/ndarray.h
View file @
3e527669
...
...
@@ -282,6 +282,21 @@ inline void NDArray::reset() {
}
}
/*! \brief return the size of data the DLTensor hold, in term of number of bytes
*
* \param arr the input DLTensor
*
* \return number of bytes of data in the DLTensor.
*/
inline
size_t
GetDataSize
(
const
DLTensor
&
arr
)
{
size_t
size
=
1
;
for
(
tvm_index_t
i
=
0
;
i
<
arr
.
ndim
;
++
i
)
{
size
*=
static_cast
<
size_t
>
(
arr
.
shape
[
i
]);
}
size
*=
(
arr
.
dtype
.
bits
*
arr
.
dtype
.
lanes
+
7
)
/
8
;
return
size
;
}
inline
void
NDArray
::
CopyFrom
(
DLTensor
*
other
)
{
CHECK
(
data_
!=
nullptr
);
CopyFromTo
(
other
,
&
(
data_
->
dl_tensor
));
...
...
python/tvm/relay/_ir_pass.pyi
View file @
3e527669
...
...
@@ -4,4 +4,5 @@ from . import ir
def check_expr(env: Environment, expr: ir.Expr) -> ir.Type: ...
def generalize(env: Environment, expr: ir.Expr) -> ir.Expr: ...
def _get_checked_type(expr: ir.Expr) -> ir.Type: ...
def well_formed(expr: ir.Expr) -> bool: ...
\ No newline at end of file
def well_formed(expr: ir.Expr) -> bool: ...
def dead_code_elimination(expr: ir.Expr) -> ir.Expr: ...
\ No newline at end of file
python/tvm/relay/ir_builder.py
View file @
3e527669
...
...
@@ -16,12 +16,12 @@ def _convert_to_value(arg, ctxt=tvm.cpu(0)):
"""Convert Python values into the appropriate types
for the Relay evaluator.
"""
if
isinstance
(
arg
,
int
):
if
isinstance
(
arg
,
bool
):
# bool is subclass of int
return
tvm
.
nd
.
array
(
np
.
array
(
arg
,
dtype
=
'uint8'
),
ctxt
)
elif
isinstance
(
arg
,
int
):
return
tvm
.
nd
.
array
(
np
.
array
(
arg
,
dtype
=
'int32'
),
ctxt
)
elif
isinstance
(
arg
,
float
):
return
tvm
.
nd
.
array
(
arg
,
ctxt
)
elif
isinstance
(
arg
,
bool
):
return
tvm
.
nd
.
array
(
np
.
array
(
arg
,
dtype
=
'float32'
),
ctxt
)
elif
isinstance
(
arg
,
np
.
ndarray
):
return
tvm
.
nd
.
array
(
arg
,
ctxt
)
elif
isinstance
(
arg
,
tvm
.
ndarray
.
NDArray
):
...
...
python/tvm/relay/ir_pass.py
View file @
3e527669
...
...
@@ -6,15 +6,16 @@ Exposes an interface for configuring the passes and scripting
them in Python.
"""
from
.
import
_ir_pass
from
.
import
_make
# pylint: disable=invalid-name
def
infer_type
(
env
,
expr
):
"""Infer the type of expr under the context of env
"""Infer the type of expr under the context of env
.
Parameters
----------
env : relay.Environment
The global environme
m
t.
The global environme
n
t.
expr : relay.Expr
The input expression.
...
...
@@ -34,3 +35,37 @@ check_kind = _ir_pass.check_kind
free_vars
=
_ir_pass
.
free_vars
free_type_vars
=
_ir_pass
.
free_type_vars
def
dead_code_elimination
(
e
):
""" Remove expressions which does not effect the program result (dead code).
Parameters
----------
e: relay.Expr
The input Expression
Returns
-------
result: relay.Expr
An expression which is semantically equal to the input expression,
but with dead code removed.
"""
return
_ir_pass
.
dead_code_elimination
(
e
)
def
alpha_equal
(
lhs
,
rhs
):
"""Compare two Relay expr for structural equivalence (alpha equivalence).
Parameters
----------
lhs: relay.Expr
One of the input Expression.
rhs: relay.Expr
One of the input Expression.
Returns
-------
result: bool
True iff lhs is alpha equal to rhs.
"""
return
bool
(
_make
.
_alpha_equal
(
lhs
,
rhs
))
python/tvm/relay/ty.py
View file @
3e527669
...
...
@@ -12,7 +12,7 @@ class Type(NodeBase):
"""Compare two Relay types for structural equivalence using
alpha equivalence.
"""
return
bool
(
_make
.
_type_alpha_eq
(
self
,
other
))
return
bool
(
_make
.
_type_alpha_eq
ual
(
self
,
other
))
def
__ne__
(
self
,
other
):
return
not
self
.
__eq__
(
other
)
...
...
src/relay/pass/alpha_eq.cc
View file @
3e527669
/*!
* Copyright (c) 2018 by Contributors
* \file src/tvm/relay/pass/alpha_eq.cc
* \brief
The structral equivalence comparison
.
* \brief
Check that two type are syntactically equal up to alpha equivalence
.
*/
#include <tvm/ir_pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/runtime/ndarray.h>
#include "./type_visitor.h"
#include "tvm/relay/pass.h"
...
...
@@ -13,6 +14,25 @@ namespace relay {
using
namespace
tvm
::
runtime
;
bool
SameNDArray
(
const
NDArray
&
lhs
,
const
NDArray
&
rhs
)
{
if
(
lhs
.
defined
()
!=
rhs
.
defined
())
{
return
false
;
}
else
if
(
lhs
.
same_as
(
rhs
))
{
return
true
;
}
else
{
auto
ldt
=
lhs
->
dtype
;
auto
rdt
=
rhs
->
dtype
;
CHECK_EQ
(
lhs
->
ctx
.
device_type
,
kDLCPU
)
<<
"can only compare CPU tensor"
;
CHECK_EQ
(
rhs
->
ctx
.
device_type
,
kDLCPU
)
<<
"can only compare CPU tensor"
;
if
(
ldt
.
code
==
rdt
.
code
&&
ldt
.
lanes
==
rdt
.
lanes
&&
ldt
.
bits
==
rdt
.
bits
)
{
size_t
s
=
GetDataSize
(
*
lhs
.
operator
->
());
return
memcmp
(
lhs
->
data
,
rhs
->
data
,
s
)
==
0
;
}
else
{
return
false
;
}
}
}
struct
TypeAlphaEq
:
TypeVisitor
<
const
Type
&>
{
tvm
::
Map
<
TypeParam
,
TypeParam
>
eq_map
;
bool
equal
;
...
...
@@ -38,8 +58,8 @@ struct TypeAlphaEq : TypeVisitor<const Type&> {
}
}
void
VisitType_
(
const
TensorTypeNode
*
tt1
,
const
Type
&
t2
)
final
{
if
(
const
TensorTypeNode
*
tt2
=
t2
.
as
<
TensorTypeNode
>
())
{
void
VisitType_
(
const
TensorTypeNode
*
tt1
,
const
Type
&
t2
)
final
{
if
(
const
TensorTypeNode
*
tt2
=
t2
.
as
<
TensorTypeNode
>
())
{
DataTypeEqual
(
tt1
->
dtype
,
tt2
->
dtype
);
ShapeEqual
(
tt1
->
shape
,
tt2
->
shape
);
}
else
{
...
...
@@ -47,8 +67,8 @@ struct TypeAlphaEq : TypeVisitor<const Type&> {
}
}
void
VisitType_
(
const
IncompleteTypeNode
*
bt1
,
const
Type
&
t2
)
final
{
if
(
const
IncompleteTypeNode
*
bt2
=
t2
.
as
<
IncompleteTypeNode
>
())
{
void
VisitType_
(
const
IncompleteTypeNode
*
bt1
,
const
Type
&
t2
)
final
{
if
(
const
IncompleteTypeNode
*
bt2
=
t2
.
as
<
IncompleteTypeNode
>
())
{
equal
=
equal
&&
bt1
==
bt2
;
return
;
}
else
{
...
...
@@ -56,8 +76,8 @@ struct TypeAlphaEq : TypeVisitor<const Type&> {
}
}
void
VisitType_
(
const
TypeParamNode
*
ti1
,
const
Type
&
t2
)
final
{
if
(
const
TypeParamNode
*
ti2
=
t2
.
as
<
TypeParamNode
>
())
{
void
VisitType_
(
const
TypeParamNode
*
ti1
,
const
Type
&
t2
)
final
{
if
(
const
TypeParamNode
*
ti2
=
t2
.
as
<
TypeParamNode
>
())
{
auto
tid1
=
GetRef
<
TypeParam
>
(
ti1
);
auto
tid2
=
GetRef
<
TypeParam
>
(
ti2
);
...
...
@@ -86,8 +106,8 @@ struct TypeAlphaEq : TypeVisitor<const Type&> {
}
}
void
VisitType_
(
const
FuncTypeNode
*
op
,
const
Type
&
t2
)
final
{
if
(
const
FuncTypeNode
*
ta2
=
t2
.
as
<
FuncTypeNode
>
())
{
void
VisitType_
(
const
FuncTypeNode
*
op
,
const
Type
&
t2
)
final
{
if
(
const
FuncTypeNode
*
ta2
=
t2
.
as
<
FuncTypeNode
>
())
{
if
(
op
->
arg_types
.
size
()
!=
ta2
->
arg_types
.
size
()
||
op
->
type_params
.
size
()
!=
ta2
->
type_params
.
size
()
||
op
->
type_constraints
.
size
()
!=
ta2
->
type_constraints
.
size
())
{
...
...
@@ -128,8 +148,8 @@ struct TypeAlphaEq : TypeVisitor<const Type&> {
}
}
void
VisitType_
(
const
TypeRelationNode
*
tr1
,
const
Type
&
t2
)
final
{
if
(
const
TypeRelationNode
*
tr2
=
t2
.
as
<
TypeRelationNode
>
())
{
void
VisitType_
(
const
TypeRelationNode
*
tr1
,
const
Type
&
t2
)
final
{
if
(
const
TypeRelationNode
*
tr2
=
t2
.
as
<
TypeRelationNode
>
())
{
if
(
tr1
->
func
!=
tr2
->
func
||
tr1
->
num_inputs
!=
tr2
->
num_inputs
||
tr1
->
attrs
!=
tr2
->
attrs
)
{
...
...
@@ -153,8 +173,8 @@ struct TypeAlphaEq : TypeVisitor<const Type&> {
}
}
void
VisitType_
(
const
TupleTypeNode
*
op
,
const
Type
&
t2
)
final
{
if
(
const
TupleTypeNode
*
pt
=
t2
.
as
<
TupleTypeNode
>
())
{
void
VisitType_
(
const
TupleTypeNode
*
op
,
const
Type
&
t2
)
final
{
if
(
const
TupleTypeNode
*
pt
=
t2
.
as
<
TupleTypeNode
>
())
{
if
(
op
->
fields
.
size
()
!=
pt
->
fields
.
size
())
{
equal
=
false
;
return
;
...
...
@@ -185,8 +205,8 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
bool
equal
;
AlphaEq
()
:
eq_map
(),
equal
(
true
)
{}
void
VisitExpr_
(
const
VarNode
*
e1
,
const
Expr
&
e2
)
final
{
if
(
const
VarNode
*
id2
=
e2
.
as
<
VarNode
>
())
{
void
VisitExpr_
(
const
VarNode
*
e1
,
const
Expr
&
e2
)
final
{
if
(
const
VarNode
*
id2
=
e2
.
as
<
VarNode
>
())
{
auto
local1
=
GetRef
<
Var
>
(
e1
);
auto
local2
=
GetRef
<
Var
>
(
id2
);
// We handle open terms with this rule assuming variables are identical.
...
...
@@ -207,17 +227,17 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
}
}
void
VisitExpr_
(
const
GlobalVarNode
*
g1
,
const
Expr
&
e2
)
final
{
if
(
const
GlobalVarNode
*
g2
=
e2
.
as
<
GlobalVarNode
>
())
{
void
VisitExpr_
(
const
GlobalVarNode
*
g1
,
const
Expr
&
e2
)
final
{
if
(
const
GlobalVarNode
*
g2
=
e2
.
as
<
GlobalVarNode
>
())
{
equal
=
equal
&&
g1
==
g2
;
}
else
{
equal
=
false
;
}
}
void
VisitExpr_
(
const
TupleNode
*
pl1
,
const
Expr
&
e2
)
final
{
void
VisitExpr_
(
const
TupleNode
*
pl1
,
const
Expr
&
e2
)
final
{
Tuple
prod1
=
GetRef
<
Tuple
>
(
pl1
);
if
(
const
TupleNode
*
pl2
=
e2
.
as
<
TupleNode
>
())
{
if
(
const
TupleNode
*
pl2
=
e2
.
as
<
TupleNode
>
())
{
Tuple
prod2
=
GetRef
<
Tuple
>
(
pl2
);
if
(
prod1
->
fields
.
size
()
!=
prod2
->
fields
.
size
())
{
equal
=
false
;
...
...
@@ -232,8 +252,8 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
}
}
void
VisitExpr_
(
const
ParamNode
*
p1
,
const
Expr
&
e2
)
final
{
if
(
const
ParamNode
*
p2
=
e2
.
as
<
ParamNode
>
())
{
void
VisitExpr_
(
const
ParamNode
*
p1
,
const
Expr
&
e2
)
final
{
if
(
const
ParamNode
*
p2
=
e2
.
as
<
ParamNode
>
())
{
eq_map
.
Set
(
p1
->
var
,
p2
->
var
);
equal
=
equal
&&
AlphaEqual
(
p1
->
type
,
p2
->
type
);
}
else
{
...
...
@@ -241,8 +261,8 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
}
}
void
VisitExpr_
(
const
FunctionNode
*
func1
,
const
Expr
&
e2
)
final
{
if
(
const
FunctionNode
*
func2
=
e2
.
as
<
FunctionNode
>
())
{
void
VisitExpr_
(
const
FunctionNode
*
func1
,
const
Expr
&
e2
)
final
{
if
(
const
FunctionNode
*
func2
=
e2
.
as
<
FunctionNode
>
())
{
if
(
func1
->
params
.
size
()
!=
func2
->
params
.
size
())
{
equal
=
false
;
return
;
...
...
@@ -258,8 +278,8 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
}
}
void
VisitExpr_
(
const
CallNode
*
op
,
const
Expr
&
e2
)
final
{
if
(
const
CallNode
*
call
=
e2
.
as
<
CallNode
>
())
{
void
VisitExpr_
(
const
CallNode
*
op
,
const
Expr
&
e2
)
final
{
if
(
const
CallNode
*
call
=
e2
.
as
<
CallNode
>
())
{
this
->
VisitExpr
(
op
->
op
,
call
->
op
);
if
(
op
->
args
.
size
()
!=
call
->
args
.
size
())
{
...
...
@@ -276,8 +296,8 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
}
}
void
VisitExpr_
(
const
LetNode
*
op
,
const
Expr
&
e2
)
final
{
if
(
const
LetNode
*
let
=
e2
.
as
<
LetNode
>
())
{
void
VisitExpr_
(
const
LetNode
*
op
,
const
Expr
&
e2
)
final
{
if
(
const
LetNode
*
let
=
e2
.
as
<
LetNode
>
())
{
eq_map
.
Set
(
op
->
var
,
let
->
var
);
this
->
VisitExpr
(
op
->
value
,
let
->
value
);
this
->
VisitExpr
(
op
->
body
,
let
->
body
);
...
...
@@ -285,6 +305,36 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
equal
=
false
;
}
}
void
VisitExpr_
(
const
IfNode
*
op
,
const
Expr
&
e2
)
final
{
if
(
const
IfNode
*
i
=
e2
.
as
<
IfNode
>
())
{
VisitExpr
(
op
->
cond
,
i
->
cond
);
VisitExpr
(
op
->
true_branch
,
i
->
true_branch
);
VisitExpr
(
op
->
false_branch
,
i
->
false_branch
);
}
else
{
equal
=
false
;
}
}
void
VisitExpr_
(
const
OpNode
*
op
,
const
Expr
&
e2
)
final
{
if
(
const
OpNode
*
o
=
e2
.
as
<
OpNode
>
())
{
equal
=
equal
&&
op
->
name
==
o
->
name
;
}
else
{
equal
=
false
;
}
}
void
VisitExpr_
(
const
ConstantNode
*
op
,
const
Expr
&
e2
)
final
{
if
(
const
ConstantNode
*
c
=
e2
.
as
<
ConstantNode
>
())
{
if
(
AlphaEqual
(
op
->
tensor_type
(),
c
->
tensor_type
()))
{
equal
=
equal
&&
SameNDArray
(
op
->
data
,
c
->
data
);
}
else
{
equal
=
false
;
}
}
else
{
equal
=
false
;
}
}
};
bool
AlphaEqual
(
const
Expr
&
e1
,
const
Expr
&
e2
)
{
...
...
@@ -294,15 +344,15 @@ bool AlphaEqual(const Expr& e1, const Expr& e2) {
}
// TODO(@jroesch): move to correct namespace?
TVM_REGISTER_API
(
"relay._make._alpha_eq"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
TVM_REGISTER_API
(
"relay._make._alpha_eq
ual
"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
Expr
e1
=
args
[
0
];
Expr
e2
=
args
[
1
];
*
ret
=
AlphaEqual
(
e1
,
e2
);
});
TVM_REGISTER_API
(
"relay._make._type_alpha_eq"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
TVM_REGISTER_API
(
"relay._make._type_alpha_eq
ual
"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
Type
t1
=
args
[
0
];
Type
t2
=
args
[
1
];
*
ret
=
AlphaEqual
(
t1
,
t2
);
...
...
src/relay/pass/dead_code.cc
0 → 100644
View file @
3e527669
/*!
* Copyright (c) 2018 by Contributors
*
* \file dead_code.cc
*
* \brief Remove code that does not effect the program result.
*
* The algorithm is implemented by two visitor:
* CalcDep turn an expr into a dependency graph of expr,
* GenLet turn the dependency graph into a let list, taking only the used value.
*/
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include "let_list.h"
namespace
tvm
{
namespace
relay
{
bool
IsBoolLit
(
const
Expr
&
e
,
bool
b
)
{
if
(
const
ConstantNode
*
c
=
e
.
as
<
ConstantNode
>
())
{
if
(
c
->
is_scalar
())
{
auto
dt
=
c
->
tensor_type
()
->
dtype
;
if
(
dt
==
UInt
(
8
))
{
return
*
reinterpret_cast
<
const
uint8_t
*>
(
c
->
data
->
data
)
==
b
;
}
else
if
(
dt
==
UInt
(
16
))
{
return
*
reinterpret_cast
<
const
uint16_t
*>
(
c
->
data
->
data
)
==
b
;
}
else
if
(
dt
==
UInt
(
32
))
{
return
*
reinterpret_cast
<
const
uint32_t
*>
(
c
->
data
->
data
)
==
b
;
}
else
if
(
dt
==
UInt
(
64
))
{
return
*
reinterpret_cast
<
const
uint64_t
*>
(
c
->
data
->
data
)
==
b
;
}
else
if
(
dt
==
Int
(
8
))
{
return
*
reinterpret_cast
<
const
int8_t
*>
(
c
->
data
->
data
)
==
b
;
}
else
if
(
dt
==
Int
(
16
))
{
return
*
reinterpret_cast
<
const
int16_t
*>
(
c
->
data
->
data
)
==
b
;
}
else
if
(
dt
==
Int
(
32
))
{
return
*
reinterpret_cast
<
const
int32_t
*>
(
c
->
data
->
data
)
==
b
;
}
else
if
(
dt
==
Int
(
64
))
{
return
*
reinterpret_cast
<
const
int64_t
*>
(
c
->
data
->
data
)
==
b
;
}
}
}
return
false
;
}
// calculate the dependency graph from expression
class
CalcDep
:
private
ExprMutator
{
public
:
static
Expr
Eliminate
(
const
Expr
&
e
)
{
CalcDep
cd
;
auto
res
=
cd
(
e
);
GenLet
gl
(
cd
.
var_map_
);
gl
(
res
);
return
gl
.
lets_
.
Get
(
res
);
}
private
:
struct
Binder
{
Type
t
;
Expr
e
;
Binder
(
const
Type
&
t
,
const
Expr
&
e
)
:
t
(
t
),
e
(
e
)
{
}
};
using
VarMap
=
std
::
unordered_map
<
Var
,
Binder
,
NodeHash
,
NodeEqual
>
;
VarMap
var_map_
;
Expr
VisitExpr_
(
const
IfNode
*
i
)
final
{
auto
cond
=
VisitExpr
(
i
->
cond
);
if
(
IsBoolLit
(
cond
,
true
))
{
return
Eliminate
(
i
->
true_branch
);
}
else
if
(
IsBoolLit
(
cond
,
false
))
{
return
Eliminate
(
i
->
false_branch
);
}
else
{
return
IfNode
::
make
(
cond
,
Eliminate
(
i
->
true_branch
),
Eliminate
(
i
->
false_branch
));
}
}
Expr
VisitExpr_
(
const
LetNode
*
l
)
final
{
var_map_
.
insert
(
std
::
pair
<
Var
,
Binder
>
(
l
->
var
,
Binder
(
l
->
value_type
,
Eliminate
(
l
->
value
))));
return
VisitExpr
(
l
->
body
);
}
Expr
VisitExpr_
(
const
FunctionNode
*
f
)
final
{
return
FunctionNode
::
make
(
f
->
params
,
f
->
ret_type
,
Eliminate
(
f
->
body
),
f
->
type_params
);
}
// generate the let list from dependency graph
class
GenLet
:
private
ExprVisitor
{
private
:
LetList
lets_
;
VarMap
var_map_
;
explicit
GenLet
(
const
VarMap
&
var_map
)
:
var_map_
(
var_map
)
{
}
friend
CalcDep
;
void
VisitExpr_
(
const
VarNode
*
vn
)
final
{
Var
v
=
GetRef
<
Var
>
(
vn
);
if
(
var_map_
.
count
(
v
)
!=
0
)
{
auto
val
=
var_map_
.
at
(
v
);
var_map_
.
erase
(
v
);
// erase before visit to handle letrec
VisitExpr
(
val
.
e
);
// visit before push back so the dependency of dependency is before the dependency
lets_
.
Push
(
v
,
val
.
t
,
val
.
e
);
}
}
};
};
Expr
DeadCodeElimination
(
const
Expr
&
e
)
{
return
CalcDep
::
Eliminate
(
e
);
}
TVM_REGISTER_API
(
"relay._ir_pass.dead_code_elimination"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
DeadCodeElimination
(
args
[
0
]);
});
}
// namespace relay
}
// namespace tvm
src/runtime/ndarray.cc
View file @
3e527669
...
...
@@ -25,15 +25,6 @@ inline void VerifyDataType(DLDataType dtype) {
CHECK_EQ
(
dtype
.
bits
&
(
dtype
.
bits
-
1
),
0
);
}
inline
size_t
GetDataSize
(
const
DLTensor
&
arr
)
{
size_t
size
=
1
;
for
(
tvm_index_t
i
=
0
;
i
<
arr
.
ndim
;
++
i
)
{
size
*=
arr
.
shape
[
i
];
}
size
*=
(
arr
.
dtype
.
bits
*
arr
.
dtype
.
lanes
+
7
)
/
8
;
return
size
;
}
inline
size_t
GetDataAlignment
(
const
DLTensor
&
arr
)
{
size_t
align
=
(
arr
.
dtype
.
bits
/
8
)
*
arr
.
dtype
.
lanes
;
if
(
align
<
kAllocAlignment
)
return
kAllocAlignment
;
...
...
@@ -129,8 +120,8 @@ DLManagedTensor* NDArray::ToDLPack() const {
}
NDArray
NDArray
::
Empty
(
std
::
vector
<
int64_t
>
shape
,
DLDataType
dtype
,
DLContext
ctx
)
{
DLDataType
dtype
,
DLContext
ctx
)
{
NDArray
ret
=
Internal
::
Create
(
shape
,
dtype
,
ctx
);
// setup memory content
size_t
size
=
GetDataSize
(
ret
.
data_
->
dl_tensor
);
...
...
tests/python/relay/test_dead_code_elimination.py
0 → 100644
View file @
3e527669
import
tvm
from
tvm
import
relay
from
tvm.relay.ir_pass
import
dead_code_elimination
,
alpha_equal
from
tvm.relay.ir_builder
import
convert
,
IRBuilder
from
tvm.relay.op
import
log
,
add
,
equal
,
subtract
,
concat
class
env
:
def
__init__
(
self
):
self
.
a
=
relay
.
Var
(
"a"
)
self
.
b
=
relay
.
Var
(
"b"
)
self
.
c
=
relay
.
Var
(
"c"
)
self
.
d
=
relay
.
Var
(
"d"
)
self
.
e
=
relay
.
Var
(
"e"
)
self
.
x
=
relay
.
Var
(
"x"
)
self
.
y
=
relay
.
Var
(
"y"
)
self
.
z
=
relay
.
Var
(
"z"
)
self
.
shape
=
tvm
.
convert
([
1
,
2
,
3
])
self
.
tt
=
relay
.
TensorType
(
self
.
shape
,
"float32"
)
self
.
int32
=
relay
.
TensorType
([],
"int32"
)
self
.
float32
=
relay
.
TensorType
([],
"float32"
)
self
.
one
=
convert
(
1.0
)
self
.
two
=
convert
(
2.0
)
self
.
three
=
convert
(
3.0
)
e
=
env
()
def
test_let
():
orig
=
relay
.
Let
(
e
.
x
,
e
.
y
,
e
.
z
,
e
.
tt
)
assert
alpha_equal
(
dead_code_elimination
(
orig
),
e
.
z
)
def
test_used_let
():
orig
=
relay
.
Let
(
e
.
a
,
e
.
b
,
relay
.
Let
(
e
.
c
,
e
.
d
,
e
.
c
,
e
.
tt
),
e
.
tt
)
assert
alpha_equal
(
dead_code_elimination
(
orig
),
relay
.
Let
(
e
.
c
,
e
.
d
,
e
.
c
,
e
.
tt
))
def
test_chain_unused_let
():
orig
=
relay
.
Let
(
e
.
a
,
e
.
b
,
relay
.
Let
(
e
.
c
,
e
.
d
,
e
.
e
,
e
.
tt
),
e
.
tt
)
assert
alpha_equal
(
dead_code_elimination
(
orig
),
e
.
e
)
# make sure we dont infinite loop
def
test_recursion
():
"""
Program:
let f(n: i32, data: f32) -> f32 = {
if (n == 0) {
return data;
} else {
return f(n - 1, log(data));
}
}
f(2, 10000);
"""
f
=
relay
.
Var
(
"f"
)
n
=
relay
.
Var
(
"n"
)
np
=
relay
.
Param
(
n
,
e
.
int32
)
data
=
relay
.
Var
(
"data"
)
datap
=
relay
.
Param
(
data
,
e
.
float32
)
funcbody
=
relay
.
If
(
equal
(
n
,
convert
(
0
)),
data
,
f
(
subtract
(
n
,
convert
(
1.0
)),
log
(
data
)))
value
=
relay
.
Function
([
np
,
datap
],
e
.
float32
,
funcbody
,
[])
orig
=
relay
.
Let
(
f
,
funcbody
,
f
(
convert
(
2.0
),
convert
(
10000.0
)),
e
.
float32
)
assert
alpha_equal
(
dead_code_elimination
(
orig
),
orig
)
assert
alpha_equal
(
dead_code_elimination
(
relay
.
Let
(
f
,
funcbody
,
e
.
three
,
e
.
float32
)),
e
.
three
)
def
test_op_let
():
assert
alpha_equal
(
dead_code_elimination
(
add
(
relay
.
Let
(
e
.
a
,
e
.
one
,
e
.
three
,
e
.
float32
),
e
.
two
)),
add
(
e
.
three
,
e
.
two
))
def
test_if
():
orig
=
relay
.
If
(
convert
(
True
),
e
.
a
,
e
.
b
)
assert
alpha_equal
(
dead_code_elimination
(
orig
),
e
.
a
)
if
__name__
==
"__main__"
:
test_let
()
test_used_let
()
test_chain_unused_let
()
test_recursion
()
test_op_let
()
test_if
()
tests/python/relay/test_pass_alpha_eq.py
→
tests/python/relay/test_pass_alpha_eq
ual
.py
View file @
3e527669
import
tvm
from
tvm
import
relay
from
tvm.relay.ir_pass
import
alpha_equal
from
tvm.relay.ir_builder
import
convert
def
test_tensor_type_alpha_eq
():
def
test_tensor_type_alpha_equal
():
t1
=
relay
.
TensorType
((
3
,
4
),
"float32"
)
t2
=
relay
.
TensorType
((
3
,
4
),
"float32"
)
t3
=
relay
.
TensorType
((
3
,
4
,
5
),
"float32"
)
...
...
@@ -13,8 +14,14 @@ def test_tensor_type_alpha_eq():
t2
=
relay
.
TensorType
((),
"float32"
)
assert
t1
==
t2
def
test_constant_alpha_equal
():
x
=
convert
(
1
)
y
=
convert
(
2
)
assert
alpha_equal
(
x
,
x
)
assert
not
alpha_equal
(
x
,
y
)
assert
alpha_equal
(
x
,
convert
(
1
))
def
test_incomplete_type_alpha_eq
():
def
test_incomplete_type_alpha_eq
ual
():
t1
=
relay
.
IncompleteType
(
relay
.
Kind
.
Shape
)
t2
=
relay
.
IncompleteType
(
relay
.
Kind
.
Type
)
t3
=
relay
.
IncompleteType
(
relay
.
Kind
.
Type
)
...
...
@@ -26,7 +33,7 @@ def test_incomplete_type_alpha_eq():
assert
t2
!=
t3
def
test_type_param_alpha_eq
():
def
test_type_param_alpha_eq
ual
():
t1
=
relay
.
TypeParam
(
"v1"
,
relay
.
Kind
.
Type
)
t2
=
relay
.
TypeParam
(
"v2"
,
relay
.
Kind
.
Shape
)
t3
=
relay
.
TypeParam
(
"v3"
,
relay
.
Kind
.
Type
)
...
...
@@ -48,7 +55,7 @@ def test_type_param_alpha_eq():
assert
ft1
!=
ft3
# kinds still do not match
def
test_func_type_alpha_eq
():
def
test_func_type_alpha_eq
ual
():
t1
=
relay
.
TensorType
((
1
,
2
),
"float32"
)
t2
=
relay
.
TensorType
((
1
,
2
,
3
),
"float32"
)
...
...
@@ -108,7 +115,7 @@ def test_func_type_alpha_eq():
assert
ft
!=
more_rels
def
test_tuple_type_alpha_eq
():
def
test_tuple_type_alpha_eq
ual
():
t1
=
relay
.
TensorType
((
1
,
2
,
3
),
"float32"
)
t2
=
relay
.
TensorType
((
1
,
2
,
3
,
4
),
"float32"
)
tp1
=
relay
.
TypeParam
(
"v1"
,
relay
.
Kind
.
Type
)
...
...
@@ -126,7 +133,7 @@ def test_tuple_type_alpha_eq():
assert
tup1
!=
tup4
def
test_type_relation_alpha_eq
():
def
test_type_relation_alpha_eq
ual
():
t1
=
relay
.
TensorType
((
1
,
2
),
"float32"
)
t2
=
relay
.
TensorType
((
1
,
2
,
3
),
"float32"
)
t3
=
relay
.
TensorType
((
1
,
2
,
3
,
4
),
"float32"
)
...
...
@@ -162,9 +169,9 @@ def test_type_relation_alpha_eq():
if
__name__
==
"__main__"
:
test_tensor_type_alpha_eq
()
test_incomplete_type_alpha_eq
()
test_type_param_alpha_eq
()
test_func_type_alpha_eq
()
test_tuple_type_alpha_eq
()
test_type_relation_alpha_eq
()
test_tensor_type_alpha_eq
ual
()
test_incomplete_type_alpha_eq
ual
()
test_type_param_alpha_eq
ual
()
test_func_type_alpha_eq
ual
()
test_tuple_type_alpha_eq
ual
()
test_type_relation_alpha_eq
ual
()
tests/python/relay/test_type_infer.py
View file @
3e527669
...
...
@@ -120,9 +120,9 @@ def test_recursion():
Program:
def f(n: i32, data: f32) -> f32 {
if (n == 0) {
return f(n - 1, log(data));
} else {
return data;
} else {
return f(n - 1, log(data));
}
}
f(2, 10000);
...
...
@@ -133,9 +133,9 @@ def test_recursion():
data
=
b
.
param
(
'data'
,
ty
=
'float32'
)
with
b
.
decl
(
f
,
n
,
data
):
with
b
.
if_scope
(
equal
(
n
,
convert
(
0
))):
b
.
ret
(
f
(
subtract
(
n
,
convert
(
1
)),
log
(
data
)))
with
b
.
else_scope
():
b
.
ret
(
data
)
with
b
.
else_scope
():
b
.
ret
(
f
(
subtract
(
n
,
convert
(
1
)),
log
(
data
)))
b
.
ret
(
f
(
convert
(
2.0
),
convert
(
10000.0
)))
assert_decl_has_type
(
b
.
env
,
'f'
,
func_type
(
[
'int32'
,
'float32'
],
'float32'
))
...
...
@@ -160,11 +160,11 @@ def test_concat():
if
__name__
==
"__main__"
:
test_dual_op
()
test_recursion
()
test_monomorphic_let
()
test_single_op
()
test_add_op
()
test_add_broadcast_op
()
test_decl
()
test_recursion
()
test_concat
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment