Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
672203f2
Commit
672203f2
authored
Aug 02, 2019
by
雾雨魔理沙
Committed by
Thierry Moreau
Aug 02, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay] [Error] Fix error in partial evaluator (#3693)
* fix * lint
parent
8ad36a17
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
37 additions
and
13 deletions
+37
-13
src/relay/pass/partial_eval.cc
+27
-12
tests/python/relay/test_pass_partial_eval.py
+10
-1
No files found.
src/relay/pass/partial_eval.cc
View file @
672203f2
...
...
@@ -131,7 +131,7 @@ Expr PostProcess(const Expr&);
/*! \brief The base container type of Relay values. */
class
StaticNode
:
public
RelayNode
{
public
:
static
constexpr
const
char
*
_type_key
=
"relay.
Value
"
;
static
constexpr
const
char
*
_type_key
=
"relay.
Static
"
;
TVM_DECLARE_BASE_NODE_INFO
(
ValueNode
,
RelayNode
);
};
...
...
@@ -161,6 +161,7 @@ struct PStaticNode : Node {
PStaticNode
(
const
Static
&
pstatic
,
const
Expr
&
dynamic
)
:
pstatic
(
pstatic
),
dynamic
(
dynamic
),
created_time
(
time
())
{
}
explicit
PStaticNode
(
const
Expr
&
dynamic
)
:
PStaticNode
(
Static
(),
dynamic
)
{
}
static
constexpr
const
char
*
_type_key
=
"relay.PStatic"
;
TVM_DECLARE_NODE_TYPE_INFO
(
PStaticNode
,
Node
);
};
...
...
@@ -169,6 +170,7 @@ RELAY_DEFINE_NODE_REF(PStatic, PStaticNode, NodeRef);
struct
STupleNode
:
StaticNode
{
std
::
vector
<
PStatic
>
fields
;
explicit
STupleNode
(
const
std
::
vector
<
PStatic
>&
fields
)
:
fields
(
fields
)
{
}
static
constexpr
const
char
*
_type_key
=
"relay.STuple"
;
TVM_DECLARE_NODE_TYPE_INFO
(
STupleNode
,
StaticNode
);
};
...
...
@@ -181,7 +183,8 @@ Static MkSTuple(const std::vector<PStatic>& fields) {
struct
STensorNode
:
StaticNode
{
runtime
::
NDArray
data
;
explicit
STensorNode
(
const
NDArray
&
data
)
:
data
(
data
)
{
}
TVM_DECLARE_NODE_TYPE_INFO
(
STupleNode
,
StaticNode
);
static
constexpr
const
char
*
_type_key
=
"relay.STensor"
;
TVM_DECLARE_NODE_TYPE_INFO
(
STensorNode
,
StaticNode
);
};
RELAY_DEFINE_NODE_REF
(
STensor
,
STensorNode
,
Value
);
...
...
@@ -195,6 +198,7 @@ struct SConstructorNode : StaticNode {
std
::
vector
<
PStatic
>
fields
;
SConstructorNode
(
const
Constructor
&
constructor
,
const
std
::
vector
<
PStatic
>&
fields
)
:
constructor
(
constructor
),
fields
(
fields
)
{
}
static
constexpr
const
char
*
_type_key
=
"relay.SConstructor"
;
TVM_DECLARE_NODE_TYPE_INFO
(
SConstructorNode
,
StaticNode
);
};
...
...
@@ -205,6 +209,7 @@ Static MkSConstructor(const Constructor& constructor, const std::vector<PStatic>
}
struct
SRefNode
:
StaticNode
{
static
constexpr
const
char
*
_type_key
=
"relay.SRef"
;
// we will use the address as the guid for hashing
TVM_DECLARE_NODE_TYPE_INFO
(
SRefNode
,
StaticNode
);
};
...
...
@@ -223,6 +228,7 @@ using Func = std::function<PStatic(const std::vector<PStatic>&,
struct
SFuncNode
:
StaticNode
{
Func
func
;
explicit
SFuncNode
(
const
Func
&
func
)
:
func
(
func
)
{
}
static
constexpr
const
char
*
_type_key
=
"relay.SFunc"
;
TVM_DECLARE_NODE_TYPE_INFO
(
SFuncNode
,
StaticNode
);
};
...
...
@@ -711,8 +717,14 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
return
VisitFunc
(
GetRef
<
Function
>
(
op
),
ll
);
}
struct
ReflectError
:
dmlc
::
Error
{
ReflectError
()
:
dmlc
::
Error
(
"static value not found"
)
{
}
};
Expr
Reflect
(
const
PStatic
&
st
)
{
if
(
const
STensorNode
*
op
=
st
->
pstatic
.
as
<
STensorNode
>
())
{
if
(
!
st
->
pstatic
.
defined
())
{
throw
ReflectError
();
}
else
if
(
const
STensorNode
*
op
=
st
->
pstatic
.
as
<
STensorNode
>
())
{
return
ConstantNode
::
make
(
op
->
data
);
}
else
if
(
const
STupleNode
*
op
=
st
->
pstatic
.
as
<
STupleNode
>
())
{
tvm
::
Array
<
Expr
>
fields
;
...
...
@@ -721,7 +733,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
}
return
TupleNode
::
make
(
fields
);
}
else
{
LOG
(
FATAL
)
<<
"Unknown case
"
;
LOG
(
FATAL
)
<<
"Unknown case
: "
<<
st
->
dynamic
;
throw
;
}
}
...
...
@@ -767,19 +779,22 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
for
(
const
PStatic
&
ps
:
pv
)
{
ns_args
.
push_back
(
ps
->
dynamic
);
}
PStatic
ns
=
NoStatic
(
ll
->
Push
(
CallNode
::
make
(
expr
,
ns_args
,
attrs
,
type_args
)));
auto
ns
=
[
&
]()
{
return
NoStatic
(
ll
->
Push
(
CallNode
::
make
(
expr
,
ns_args
,
attrs
,
type_args
)));
};
if
(
StatefulOp
(
expr
))
{
return
ns
;
return
ns
()
;
}
t
vm
::
Array
<
Expr
>
args
;
for
(
const
PStatic
&
ps
:
pv
)
{
if
(
ps
->
pstatic
.
defined
()
)
{
t
ry
{
tvm
::
Array
<
Expr
>
args
;
for
(
const
PStatic
&
ps
:
pv
)
{
args
.
push_back
(
Reflect
(
ps
));
}
else
{
return
ns
;
}
return
ConstEvaluate
(
CallNode
::
make
(
expr
,
args
,
attrs
,
type_args
),
ll
);
}
catch
(
const
ReflectError
&
)
{
return
ns
();
}
return
ConstEvaluate
(
CallNode
::
make
(
expr
,
args
,
attrs
,
type_args
),
ll
);
};
}
...
...
tests/python/relay/test_pass_partial_eval.py
View file @
672203f2
...
...
@@ -18,7 +18,7 @@
import
numpy
as
np
import
tvm
from
tvm
import
relay
from
tvm.relay.analysis
import
alpha_equal
from
tvm.relay.analysis
import
alpha_equal
,
assert_alpha_equal
from
tvm.relay.prelude
import
Prelude
from
tvm.relay
import
op
,
create_executor
,
transform
from
tvm.relay
import
Var
,
TypeVar
,
TupleGetItem
,
Let
,
Function
,
const
,
RefRead
,
RefWrite
,
RefCreate
...
...
@@ -306,6 +306,14 @@ def test_double():
assert
alpha_equal
(
res
.
body
,
make_nat_expr
(
p
,
6
))
def
test_concat
():
t
=
relay
.
TensorType
([
10
],
"float32"
)
x
=
Var
(
"x"
,
t
)
y
=
Var
(
"x"
,
t
)
orig
=
run_infer_type
(
Function
([
x
,
y
],
op
.
concatenate
([
x
,
y
],
axis
=
0
)))
assert_alpha_equal
(
orig
,
dcpe
(
orig
))
if
__name__
==
'__main__'
:
test_ref
()
test_tuple
()
...
...
@@ -323,3 +331,4 @@ if __name__ == '__main__':
test_nat_id
()
test_global_match_nat_id
()
test_match_nat_id
()
test_concat
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment