Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
dd9d76ac
Commit
dd9d76ac
authored
Oct 31, 2018
by
ziheng
Committed by
Tianqi Chen
Oct 31, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RELAY/PASS] Simplify inference. (#2033)
parent
2f9ab71e
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
241 additions
and
2 deletions
+241
-2
nnvm/tests/python/compiler/test_simplify_inference.py
+0
-1
python/tvm/relay/expr.py
+60
-1
python/tvm/relay/ir_pass.py
+15
-0
python/tvm/relay/op/__init__.py
+8
-0
src/relay/pass/pattern_util.h
+34
-0
src/relay/pass/simplify_inference.cc
+77
-0
tests/python/relay/test_pass_simplify_inference.py
+47
-0
No files found.
nnvm/tests/python/compiler/test_simplify_inference.py
View file @
dd9d76ac
...
@@ -10,7 +10,6 @@ def test_simplify_batchnorm():
...
@@ -10,7 +10,6 @@ def test_simplify_batchnorm():
scale
=
sym
.
elemwise_mul
(
1
/
sym
.
sqrt
(
moving_var
+
epsilon
),
gamma
)
scale
=
sym
.
elemwise_mul
(
1
/
sym
.
sqrt
(
moving_var
+
epsilon
),
gamma
)
shift
=
sym
.
elemwise_add
(
shift
=
sym
.
elemwise_add
(
sym
.
elemwise_mul
(
sym
.
negative
(
moving_mean
),
scale
),
beta
)
sym
.
elemwise_mul
(
sym
.
negative
(
moving_mean
),
scale
),
beta
)
shape
=
[
-
1
if
i
==
axis
else
1
for
i
in
range
(
len
(
shape
))]
# for 2D
# for 2D
num_newaxis
=
len
(
shape
)
-
axis
-
1
num_newaxis
=
len
(
shape
)
-
axis
-
1
if
num_newaxis
:
if
num_newaxis
:
...
...
python/tvm/relay/expr.py
View file @
dd9d76ac
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""The expression nodes of Relay."""
"""The expression nodes of Relay."""
from
__future__
import
absolute_import
from
__future__
import
absolute_import
from
numbers
import
Number
as
_Number
import
numpy
as
_np
import
numpy
as
_np
from
.base
import
RelayNode
,
register_relay_node
from
.base
import
RelayNode
,
register_relay_node
...
@@ -11,6 +12,8 @@ from .._ffi import base as _base
...
@@ -11,6 +12,8 @@ from .._ffi import base as _base
from
..
import
nd
as
_nd
from
..
import
nd
as
_nd
from
..
import
convert
from
..
import
convert
# will be registered afterwards
_op_make
=
None
class
Expr
(
RelayNode
):
class
Expr
(
RelayNode
):
"""The base type for all Relay expressions."""
"""The base type for all Relay expressions."""
...
@@ -48,6 +51,62 @@ class Expr(RelayNode):
...
@@ -48,6 +51,62 @@ class Expr(RelayNode):
"""
"""
return
_make
.
dtype_cast
(
self
,
dtype
)
return
_make
.
dtype_cast
(
self
,
dtype
)
def
__add__
(
self
,
other
):
if
isinstance
(
other
,
Expr
):
return
_op_make
.
add
(
self
,
other
)
elif
isinstance
(
other
,
_Number
):
raise
TypeError
(
'convert "
%
s" with `const` first'
%
str
(
other
))
else
:
raise
TypeError
(
"type
%
s not supported"
%
str
(
type
(
other
)))
def
__radd__
(
self
,
other
):
return
self
.
__add__
(
other
)
def
__sub__
(
self
,
other
):
if
isinstance
(
other
,
Expr
):
return
_op_make
.
subtract
(
self
,
other
)
elif
isinstance
(
other
,
_Number
):
raise
TypeError
(
'convert "
%
s" with `const` first'
%
str
(
other
))
else
:
raise
TypeError
(
"type
%
s not supported"
%
str
(
type
(
other
)))
def
__rsub__
(
self
,
other
):
if
isinstance
(
other
,
_Number
):
raise
TypeError
(
'convert "
%
s" with `const` first'
%
str
(
other
))
else
:
raise
TypeError
(
"type
%
s not supported"
%
str
(
type
(
other
)))
def
__mul__
(
self
,
other
):
if
isinstance
(
other
,
Expr
):
return
_op_make
.
multiply
(
self
,
other
)
elif
isinstance
(
other
,
_Number
):
raise
TypeError
(
'convert "
%
s" with `const` first'
%
str
(
other
))
else
:
raise
TypeError
(
"type
%
s not supported"
%
str
(
type
(
other
)))
def
__rmul__
(
self
,
other
):
return
self
.
__mul__
(
other
)
def
__div__
(
self
,
other
):
if
isinstance
(
other
,
Expr
):
return
_op_make
.
divide
(
self
,
other
)
elif
isinstance
(
other
,
_Number
):
raise
TypeError
(
'convert "
%
s" with `const` first'
%
str
(
other
))
else
:
raise
TypeError
(
"type
%
s not supported"
%
str
(
type
(
other
)))
def
__rdiv__
(
self
,
other
):
if
isinstance
(
other
,
_Number
):
raise
TypeError
(
'convert "
%
s" with `const` first'
%
str
(
other
))
else
:
raise
TypeError
(
"type
%
s not supported"
%
str
(
type
(
other
)))
def
__truediv__
(
self
,
other
):
return
self
.
__div__
(
other
)
def
__rtruediv__
(
self
,
other
):
return
self
.
__rdiv__
(
other
)
@register_relay_node
@register_relay_node
class
Constant
(
Expr
):
class
Constant
(
Expr
):
...
@@ -305,7 +364,7 @@ class TupleWrapper(object):
...
@@ -305,7 +364,7 @@ class TupleWrapper(object):
def
__repr__
(
self
):
def
__repr__
(
self
):
return
(
"TupleWrapper("
+
self
.
tuple_value
.
__repr__
()
+
return
(
"TupleWrapper("
+
self
.
tuple_value
.
__repr__
()
+
", "
+
s
elf
.
size
+
")"
)
", "
+
s
tr
(
self
.
size
)
+
")"
)
def
astype
(
self
,
_
):
def
astype
(
self
,
_
):
raise
TypeError
(
"astype cannot be used on tuple"
)
raise
TypeError
(
"astype cannot be used on tuple"
)
...
...
python/tvm/relay/ir_pass.py
View file @
dd9d76ac
...
@@ -160,6 +160,21 @@ def free_type_vars(expr):
...
@@ -160,6 +160,21 @@ def free_type_vars(expr):
"""
"""
return
_ir_pass
.
free_type_vars
(
expr
)
return
_ir_pass
.
free_type_vars
(
expr
)
def
simplify_inference
(
expr
):
""" Simplify the data-flow graph for inference phase.
Parameters
----------
e: tvm.relay.Expr
The input Expression
Returns
-------
result: tvm.relay.Expr
An expression which is semantically equal to the input expression,
but with some simplification
"""
return
_ir_pass
.
simplify_inference
(
expr
)
def
dead_code_elimination
(
expr
):
def
dead_code_elimination
(
expr
):
""" Remove expressions which does not effect the program result (dead code).
""" Remove expressions which does not effect the program result (dead code).
...
...
python/tvm/relay/op/__init__.py
View file @
dd9d76ac
...
@@ -15,3 +15,11 @@ from . import vision
...
@@ -15,3 +15,11 @@ from . import vision
from
.
import
_tensor
from
.
import
_tensor
from
..expr
import
Expr
from
..expr
import
Expr
from
..base
import
register_relay_node
from
..base
import
register_relay_node
def
_register_op_make
():
from
.
import
_make
from
..
import
expr
expr
.
_op_make
=
_make
_register_op_make
()
src/relay/pass/pattern_util.h
View file @
dd9d76ac
...
@@ -120,6 +120,40 @@ inline bool IsDepthwiseConv2D(const Call& call,
...
@@ -120,6 +120,40 @@ inline bool IsDepthwiseConv2D(const Call& call,
}
}
/*!
* \brief Create a Constant with a scalar
*
* \param dtype The data type.
* \param value The value of the scalar.
* \return A Constant.
*/
template
<
typename
T
>
inline
Constant
MakeConstantScalar
(
DataType
dtype
,
T
value
)
{
CHECK_EQ
(
sizeof
(
T
)
*
8
,
dtype
.
bits
())
<<
"data type mismatch"
;
runtime
::
NDArray
arr
=
runtime
::
NDArray
::
Empty
({},
Type2TVMType
(
dtype
),
{
kDLCPU
,
0
});
*
static_cast
<
T
*>
(
arr
->
data
)
=
value
;
return
ConstantNode
::
make
(
arr
);
}
inline
Expr
Negative
(
Expr
x
)
{
static
const
Op
&
op
=
Op
::
Get
(
"negative"
);
return
CallNode
::
make
(
op
,
{
x
},
Attrs
(),
{});
}
inline
Expr
Sqrt
(
Expr
x
)
{
static
const
Op
&
op
=
Op
::
Get
(
"sqrt"
);
return
CallNode
::
make
(
op
,
{
x
},
Attrs
(),
{});
}
inline
Expr
Add
(
Expr
lhs
,
Expr
rhs
)
{
static
const
Op
&
op
=
Op
::
Get
(
"add"
);
return
CallNode
::
make
(
op
,
{
lhs
,
rhs
},
Attrs
(),
{});
}
inline
Expr
Multiply
(
Expr
lhs
,
Expr
rhs
)
{
inline
Expr
Multiply
(
Expr
lhs
,
Expr
rhs
)
{
static
const
Op
&
op
=
Op
::
Get
(
"multiply"
);
static
const
Op
&
op
=
Op
::
Get
(
"multiply"
);
return
CallNode
::
make
(
op
,
{
lhs
,
rhs
},
Attrs
(),
{});
return
CallNode
::
make
(
op
,
{
lhs
,
rhs
},
Attrs
(),
{});
...
...
src/relay/pass/simplify_inference.cc
0 → 100644
View file @
dd9d76ac
/*!
* Copyright (c) 2018 by Contributors
* \file simplify_inference.cc
*/
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
#include "./pattern_util.h"
namespace
tvm
{
namespace
relay
{
Expr
BatchNormToInferUnpack
(
const
Attrs
attrs
,
Expr
data
,
Expr
gamma
,
Expr
beta
,
Expr
moving_mean
,
Expr
moving_var
)
{
const
auto
param
=
attrs
.
as
<
BatchNormAttrs
>
();
Expr
epsilon
=
MakeConstantScalar
(
Float
(
32
),
static_cast
<
float
>
(
param
->
epsilon
));
Expr
var_add_eps
=
Add
(
moving_var
,
epsilon
);
Expr
sqrt_var
=
Sqrt
(
var_add_eps
);
Expr
scale
=
Divide
(
MakeConstantScalar
(
Float
(
32
),
1.0
f
),
sqrt_var
);
if
(
param
->
scale
)
{
scale
=
Multiply
(
scale
,
gamma
);
}
Expr
neg_mean
=
Negative
(
moving_mean
);
Expr
shift
=
Multiply
(
neg_mean
,
scale
);
if
(
param
->
center
)
{
shift
=
Add
(
shift
,
beta
);
}
int
axis
=
param
->
axis
;
const
auto
*
tdata
=
data
->
type_as
<
TensorTypeNode
>
();
scale
=
ExpandBiasToMatchAxis
(
scale
,
tdata
->
shape
.
size
(),
{
axis
});
shift
=
ExpandBiasToMatchAxis
(
shift
,
tdata
->
shape
.
size
(),
{
axis
});
Expr
out
=
Multiply
(
data
,
scale
);
out
=
Add
(
out
,
shift
);
return
out
;
}
class
InferenceSimplifier
:
public
ExprMutator
{
public
:
Expr
VisitExpr_
(
const
TupleGetItemNode
*
n
)
final
{
static
const
Op
&
batch_norm
=
Op
::
Get
(
"nn.batch_norm"
);
static
const
Op
&
dropout
=
Op
::
Get
(
"nn.dropout"
);
Expr
new_e
=
ExprMutator
::
VisitExpr_
(
n
);
const
auto
*
new_n
=
new_e
.
as
<
TupleGetItemNode
>
();
if
(
new_n
->
index
!=
0
)
{
return
new_e
;
}
if
(
const
auto
*
call
=
new_n
->
tuple
.
as
<
CallNode
>
())
{
if
(
call
->
op
.
same_as
(
batch_norm
))
{
return
BatchNormToInferUnpack
(
call
->
attrs
,
call
->
args
[
0
],
call
->
args
[
1
],
call
->
args
[
2
],
call
->
args
[
3
],
call
->
args
[
4
]);
}
else
if
(
call
->
op
.
same_as
(
dropout
))
{
return
call
->
args
[
0
];
}
}
return
new_e
;
}
};
Expr
SimplifyInference
(
const
Expr
&
e
)
{
return
InferenceSimplifier
().
Mutate
(
e
);
}
TVM_REGISTER_API
(
"relay._ir_pass.simplify_inference"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
SimplifyInference
(
args
[
0
]);
});
}
// namespace relay
}
// namespace tvm
tests/python/relay/test_pass_simplify_inference.py
0 → 100644
View file @
dd9d76ac
from
tvm
import
relay
as
rly
from
tvm.relay.ir_pass
import
simplify_inference
,
alpha_equal
def
test_simplify_batchnorm
():
def
simple_bn
(
x
,
gamma
,
beta
,
moving_mean
,
moving_var
,
axis
=
1
,
epsilon
=
1e-5
,
shape
=
None
):
# expect = (x - moving_mean) / sqrt(moving_var + eps) * gamma + beta
scale
=
rly
.
multiply
(
rly
.
const
(
1
,
'float32'
)
/
rly
.
sqrt
(
moving_var
+
rly
.
const
(
epsilon
,
'float32'
)),
gamma
)
shift
=
rly
.
add
(
rly
.
multiply
(
rly
.
negative
(
moving_mean
),
scale
),
beta
)
num_newaxis
=
len
(
shape
)
-
(
axis
+
1
)
if
num_newaxis
:
scale
=
rly
.
expand_dims
(
scale
,
axis
=
1
,
num_newaxis
=
num_newaxis
)
shift
=
rly
.
expand_dims
(
shift
,
axis
=
1
,
num_newaxis
=
num_newaxis
)
return
x
*
scale
+
shift
def
check
(
dim
,
axis
,
nstep
):
eps
=
0.01
ttype1
=
rly
.
TensorType
(
tuple
(
10
for
i
in
range
(
dim
)),
'float32'
)
ttype2
=
rly
.
TensorType
((
10
,),
'float32'
)
x
=
rly
.
var
(
"x"
,
ttype1
)
beta
=
rly
.
var
(
"beta"
,
ttype2
)
gamma
=
rly
.
var
(
"gamma"
,
ttype2
)
moving_var
=
rly
.
var
(
"moving_var"
,
ttype2
)
moving_mean
=
rly
.
var
(
"moving_mean"
,
ttype2
)
y1
,
y2
=
x
,
x
for
_
in
range
(
nstep
):
y1
,
_
,
_
=
rly
.
nn
.
batch_norm
(
y1
+
rly
.
const
(
1
,
'float32'
),
gamma
,
beta
,
moving_mean
,
moving_var
,
epsilon
=
eps
,
axis
=
axis
)
y1
=
rly
.
nn
.
dropout
(
y1
)
y1
=
rly
.
ir_pass
.
infer_type
(
y1
)
y1
=
simplify_inference
(
y1
)
y2
=
simple_bn
(
y2
+
rly
.
const
(
1
,
'float32'
),
gamma
,
beta
,
moving_mean
,
moving_var
,
epsilon
=
eps
,
axis
=
axis
,
shape
=
ttype1
.
shape
)
assert
rly
.
ir_pass
.
graph_equal
(
y1
,
y2
)
check
(
2
,
1
,
1
)
check
(
4
,
1
,
1
)
check
(
4
,
0
,
3
)
if
__name__
==
"__main__"
:
test_simplify_batchnorm
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment