Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
b3f3ab55
Commit
b3f3ab55
authored
Jul 03, 2019
by
雾雨魔理沙
Committed by
Tianqi Chen
Jul 03, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay] Fix PE (#3482)
parent
287078c3
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
101 additions
and
67 deletions
+101
-67
include/tvm/relay/module.h
+1
-1
src/relay/ir/expr_functor.cc
+20
-6
src/relay/ir/type_functor.cc
+5
-1
src/relay/ir/type_functor.h
+1
-0
src/relay/pass/let_list.h
+1
-1
src/relay/pass/partial_eval.cc
+60
-56
src/relay/pass/type_infer.cc
+1
-0
src/relay/pass/util.cc
+10
-0
tests/python/relay/test_pass_partial_eval.py
+2
-2
No files found.
include/tvm/relay/module.h
View file @
b3f3ab55
...
@@ -55,7 +55,7 @@ struct Module;
...
@@ -55,7 +55,7 @@ struct Module;
* The functional style allows users to construct custom
* The functional style allows users to construct custom
* environments easily, for example each thread can store
* environments easily, for example each thread can store
* a Module while auto-tuning.
* a Module while auto-tuning.
*
*
/
*/
class
ModuleNode
:
public
RelayNode
{
class
ModuleNode
:
public
RelayNode
{
public
:
public
:
...
...
src/relay/ir/expr_functor.cc
View file @
b3f3ab55
...
@@ -6,9 +6,9 @@
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
@@ -18,7 +18,7 @@
...
@@ -18,7 +18,7 @@
*/
*/
/*!
/*!
* Copyright (c) 201
8
by Contributors
* Copyright (c) 201
9
by Contributors
* \file src/tvm/relay/expr_mutator.cc
* \file src/tvm/relay/expr_mutator.cc
* \brief A wrapper around ExprFunctor which functionally updates the AST.
* \brief A wrapper around ExprFunctor which functionally updates the AST.
*
*
...
@@ -26,6 +26,7 @@
...
@@ -26,6 +26,7 @@
* the cost of using functional updates.
* the cost of using functional updates.
*/
*/
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include "type_functor.h"
#include "type_functor.h"
namespace
tvm
{
namespace
tvm
{
...
@@ -353,7 +354,7 @@ TVM_REGISTER_API("relay._analysis.post_order_visit")
...
@@ -353,7 +354,7 @@ TVM_REGISTER_API("relay._analysis.post_order_visit")
});
});
// Implement bind.
// Implement bind.
class
ExprBinder
:
public
ExprMutator
{
class
ExprBinder
:
public
ExprMutator
,
PatternMutator
{
public
:
public
:
explicit
ExprBinder
(
const
tvm
::
Map
<
Var
,
Expr
>&
args_map
)
explicit
ExprBinder
(
const
tvm
::
Map
<
Var
,
Expr
>&
args_map
)
:
args_map_
(
args_map
)
{
:
args_map_
(
args_map
)
{
...
@@ -383,13 +384,26 @@ class ExprBinder : public ExprMutator {
...
@@ -383,13 +384,26 @@ class ExprBinder : public ExprMutator {
}
}
}
}
Pattern
VisitPattern
(
const
Pattern
&
p
)
final
{
return
PatternMutator
::
VisitPattern
(
p
);
}
Clause
VisitClause
(
const
Clause
&
c
)
final
{
Pattern
pat
=
VisitPattern
(
c
->
lhs
);
return
ClauseNode
::
make
(
pat
,
VisitExpr
(
c
->
rhs
));
}
Var
VisitVar
(
const
Var
&
v
)
final
{
return
Downcast
<
Var
>
(
VisitExpr
(
v
));
}
private
:
private
:
const
tvm
::
Map
<
Var
,
Expr
>&
args_map_
;
const
tvm
::
Map
<
Var
,
Expr
>&
args_map_
;
};
};
Expr
Bind
(
const
Expr
&
expr
,
const
tvm
::
Map
<
Var
,
Expr
>&
args_map
)
{
Expr
Bind
(
const
Expr
&
expr
,
const
tvm
::
Map
<
Var
,
Expr
>&
args_map
)
{
if
(
const
FunctionNode
*
func
=
expr
.
as
<
FunctionNode
>
())
{
if
(
const
FunctionNode
*
func
=
expr
.
as
<
FunctionNode
>
())
{
Expr
new_body
=
ExprBinder
(
args_map
).
Mutate
(
func
->
body
);
Expr
new_body
=
ExprBinder
(
args_map
).
VisitExpr
(
func
->
body
);
Array
<
Var
>
new_params
;
Array
<
Var
>
new_params
;
for
(
Var
param
:
func
->
params
)
{
for
(
Var
param
:
func
->
params
)
{
if
(
!
args_map
.
count
(
param
))
{
if
(
!
args_map
.
count
(
param
))
{
...
@@ -406,7 +420,7 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
...
@@ -406,7 +420,7 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
func
->
type_params
,
func
->
type_params
,
func
->
attrs
);
func
->
attrs
);
}
else
{
}
else
{
return
ExprBinder
(
args_map
).
Mutate
(
expr
);
return
ExprBinder
(
args_map
).
VisitExpr
(
expr
);
}
}
}
}
...
...
src/relay/ir/type_functor.cc
View file @
b3f3ab55
...
@@ -92,6 +92,10 @@ void TypeVisitor::VisitType_(const TypeDataNode* op) {
...
@@ -92,6 +92,10 @@ void TypeVisitor::VisitType_(const TypeDataNode* op) {
}
}
}
}
Type
TypeMutator
::
VisitType
(
const
Type
&
t
)
{
return
t
.
defined
()
?
TypeFunctor
<
Type
(
const
Type
&
)
>::
VisitType
(
t
)
:
t
;
}
// Type Mutator.
// Type Mutator.
Array
<
Type
>
TypeMutator
::
MutateArray
(
Array
<
Type
>
arr
)
{
Array
<
Type
>
TypeMutator
::
MutateArray
(
Array
<
Type
>
arr
)
{
// The array will do copy on write
// The array will do copy on write
...
@@ -221,7 +225,7 @@ class TypeBinder : public TypeMutator {
...
@@ -221,7 +225,7 @@ class TypeBinder : public TypeMutator {
};
};
Type
Bind
(
const
Type
&
type
,
const
tvm
::
Map
<
TypeVar
,
Type
>&
args_map
)
{
Type
Bind
(
const
Type
&
type
,
const
tvm
::
Map
<
TypeVar
,
Type
>&
args_map
)
{
return
type
.
defined
()
?
TypeBinder
(
args_map
).
VisitType
(
type
)
:
type
;
return
TypeBinder
(
args_map
).
VisitType
(
type
)
;
}
}
}
// namespace relay
}
// namespace relay
...
...
src/relay/ir/type_functor.h
View file @
b3f3ab55
...
@@ -139,6 +139,7 @@ class TypeVisitor : public TypeFunctor<void(const Type& n)> {
...
@@ -139,6 +139,7 @@ class TypeVisitor : public TypeFunctor<void(const Type& n)> {
// Mutator that transform a type to another one.
// Mutator that transform a type to another one.
class
TypeMutator
:
public
TypeFunctor
<
Type
(
const
Type
&
n
)
>
{
class
TypeMutator
:
public
TypeFunctor
<
Type
(
const
Type
&
n
)
>
{
public
:
public
:
Type
VisitType
(
const
Type
&
t
)
override
;
Type
VisitType_
(
const
TypeVarNode
*
op
)
override
;
Type
VisitType_
(
const
TypeVarNode
*
op
)
override
;
Type
VisitType_
(
const
TensorTypeNode
*
op
)
override
;
Type
VisitType_
(
const
TensorTypeNode
*
op
)
override
;
Type
VisitType_
(
const
IncompleteTypeNode
*
op
)
override
;
Type
VisitType_
(
const
IncompleteTypeNode
*
op
)
override
;
...
...
src/relay/pass/let_list.h
View file @
b3f3ab55
...
@@ -48,7 +48,7 @@ class LetList {
...
@@ -48,7 +48,7 @@ class LetList {
public
:
public
:
~
LetList
()
{
~
LetList
()
{
if
(
lets_
.
size
()
>
0
&&
!
used_
)
{
if
(
lets_
.
size
()
>
0
&&
!
used_
)
{
std
::
cout
<<
"Warning: letlist not used"
<<
std
::
endl
;
LOG
(
WARNING
)
<<
"letlist not used"
;
}
}
}
}
/*!
/*!
...
...
src/relay/pass/partial_eval.cc
View file @
b3f3ab55
...
@@ -64,7 +64,7 @@
...
@@ -64,7 +64,7 @@
* 3: The generated code reuses bindings (although they are not shadowed),
* 3: The generated code reuses bindings (although they are not shadowed),
* so we have to deduplicate them.
* so we have to deduplicate them.
*
*
* 4: In the generated code, multiple VarNode might have same Id.
* 4: In the generated code,
as it call TypeSubst,
multiple VarNode might have same Id.
* While it is permitted, most pass use NodeHash for Var,
* While it is permitted, most pass use NodeHash for Var,
* and having multiple VarNode for same Id break them.
* and having multiple VarNode for same Id break them.
* Thus we remap them to a single Id for now.
* Thus we remap them to a single Id for now.
...
@@ -216,9 +216,9 @@ Static MkSRef() {
...
@@ -216,9 +216,9 @@ Static MkSRef() {
}
}
using
Func
=
std
::
function
<
PStatic
(
const
std
::
vector
<
PStatic
>&
,
using
Func
=
std
::
function
<
PStatic
(
const
std
::
vector
<
PStatic
>&
,
const
Attrs
&
,
const
Attrs
&
,
const
Array
<
Type
>&
,
const
Array
<
Type
>&
,
LetList
*
)
>
;
LetList
*
)
>
;
struct
SFuncNode
:
StaticNode
{
struct
SFuncNode
:
StaticNode
{
Func
func
;
Func
func
;
...
@@ -256,6 +256,7 @@ class Environment {
...
@@ -256,6 +256,7 @@ class Environment {
void
Insert
(
const
Var
&
v
,
const
PStatic
&
ps
)
{
void
Insert
(
const
Var
&
v
,
const
PStatic
&
ps
)
{
CHECK
(
ps
.
defined
());
CHECK
(
ps
.
defined
());
CHECK_EQ
(
env_
.
back
().
locals
.
count
(
v
),
0
);
env_
.
back
().
locals
[
v
]
=
ps
;
env_
.
back
().
locals
[
v
]
=
ps
;
}
}
...
@@ -287,12 +288,17 @@ class Environment {
...
@@ -287,12 +288,17 @@ class Environment {
/*!
/*!
* \brief As our store require rollback, we implement it as a frame.
* \brief As our store require rollback, we implement it as a frame.
* every time we need to copy the store, a new frame is insert.
*
* every time we roll back, a frame is popped.
* Every time we need to copy the store, a new frame is insert.
* Every time we roll back, a frame is popped.
*/
*/
struct
StoreFrame
{
struct
StoreFrame
{
std
::
unordered_map
<
const
SRefNode
*
,
PStatic
>
store
;
std
::
unordered_map
<
const
SRefNode
*
,
PStatic
>
store
;
/*! \brief on unknown effect, history_valid is set to true to signal above frame is outdated */
/*!
* \brief On unknown effect, history_valid is set to true to signal above frame is outdated.
*
* It only outdate the frame above it, but not the current frame.
*/
bool
history_valid
=
true
;
bool
history_valid
=
true
;
explicit
StoreFrame
(
const
std
::
unordered_map
<
const
SRefNode
*
,
PStatic
>&
store
)
:
store
(
store
)
{
}
explicit
StoreFrame
(
const
std
::
unordered_map
<
const
SRefNode
*
,
PStatic
>&
store
)
:
store
(
store
)
{
}
StoreFrame
()
=
default
;
StoreFrame
()
=
default
;
...
@@ -310,6 +316,7 @@ class Store {
...
@@ -310,6 +316,7 @@ class Store {
}
}
void
Insert
(
const
SRefNode
*
r
,
const
PStatic
&
ps
)
{
void
Insert
(
const
SRefNode
*
r
,
const
PStatic
&
ps
)
{
CHECK
(
r
);
store_
.
back
().
store
[
r
]
=
ps
;
store_
.
back
().
store
[
r
]
=
ps
;
}
}
...
@@ -317,19 +324,21 @@ class Store {
...
@@ -317,19 +324,21 @@ class Store {
PStatic
Lookup
(
const
SRefNode
*
r
)
{
PStatic
Lookup
(
const
SRefNode
*
r
)
{
auto
rit
=
store_
.
rbegin
();
auto
rit
=
store_
.
rbegin
();
while
(
rit
!=
store_
.
rend
())
{
while
(
rit
!=
store_
.
rend
())
{
if
(
!
rit
->
history_valid
)
{
return
PStatic
();
}
if
(
rit
->
store
.
find
(
r
)
!=
rit
->
store
.
end
())
{
if
(
rit
->
store
.
find
(
r
)
!=
rit
->
store
.
end
())
{
return
rit
->
store
.
find
(
r
)
->
second
;
return
rit
->
store
.
find
(
r
)
->
second
;
}
}
if
(
!
rit
->
history_valid
)
{
return
PStatic
();
}
++
rit
;
++
rit
;
}
}
return
PStatic
();
return
PStatic
();
}
}
void
Invalidate
()
{
void
Invalidate
()
{
store_
.
back
().
history_valid
=
false
;
StoreFrame
sf
;
sf
.
history_valid
=
false
;
store_
.
push_back
(
sf
);
}
}
private
:
private
:
...
@@ -341,6 +350,10 @@ class Store {
...
@@ -341,6 +350,10 @@ class Store {
store_
->
store_
.
push_back
(
StoreFrame
());
store_
->
store_
.
push_back
(
StoreFrame
());
}
}
~
StoreFrameContext
()
{
~
StoreFrameContext
()
{
// push one history valid frame off.
while
(
!
store_
->
store_
.
back
().
history_valid
)
{
store_
->
store_
.
pop_back
();
}
store_
->
store_
.
pop_back
();
store_
->
store_
.
pop_back
();
}
}
};
};
...
@@ -442,13 +455,7 @@ Function AsFunc(const Expr& e) {
...
@@ -442,13 +455,7 @@ Function AsFunc(const Expr& e) {
class
PartialEvaluator
:
public
ExprFunctor
<
PStatic
(
const
Expr
&
e
,
LetList
*
ll
)
>
,
class
PartialEvaluator
:
public
ExprFunctor
<
PStatic
(
const
Expr
&
e
,
LetList
*
ll
)
>
,
public
PatternFunctor
<
MatchStatus
(
const
Pattern
&
,
const
PStatic
&
)
>
{
public
PatternFunctor
<
MatchStatus
(
const
Pattern
&
,
const
PStatic
&
)
>
{
public
:
public
:
PartialEvaluator
(
const
tvm
::
Array
<
Var
>&
free_vars
,
PartialEvaluator
(
const
Module
&
mod
)
:
mod_
(
mod
)
{
}
const
Module
&
mod
)
:
mod_
(
mod
)
{
for
(
const
Var
&
v
:
free_vars
)
{
env_
.
Insert
(
v
,
NoStatic
(
v
));
}
}
PStatic
VisitExpr
(
const
Expr
&
e
,
LetList
*
ll
)
final
{
PStatic
VisitExpr
(
const
Expr
&
e
,
LetList
*
ll
)
final
{
PStatic
ret
=
ExprFunctor
<
PStatic
(
const
Expr
&
,
LetList
*
)
>::
VisitExpr
(
e
,
ll
);
PStatic
ret
=
ExprFunctor
<
PStatic
(
const
Expr
&
,
LetList
*
)
>::
VisitExpr
(
e
,
ll
);
...
@@ -484,23 +491,23 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
...
@@ -484,23 +491,23 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
return
env_
.
Lookup
(
GetRef
<
Var
>
(
op
));
return
env_
.
Lookup
(
GetRef
<
Var
>
(
op
));
}
}
PStatic
Visit
Expr_
(
const
GlobalVarNode
*
op
,
LetList
*
ll
)
final
{
PStatic
Visit
GlobalVar
(
const
GlobalVar
&
gv
)
{
GlobalVar
gv
=
GetRef
<
GlobalVar
>
(
op
);
CHECK
(
mod_
.
defined
()
);
if
(
gv_map_
.
count
(
gv
)
==
0
)
{
if
(
gv_map_
.
count
(
gv
)
==
0
)
{
if
(
mod_
.
defined
())
{
Function
func
=
mod_
->
Lookup
(
gv
);
Function
func
=
mod_
->
Lookup
(
gv
);
InitializeFuncId
(
func
);
InitializeFuncId
(
func
);
Func
f
=
VisitFuncStatic
(
func
,
gv
);
Func
f
=
VisitFuncStatic
(
func
,
gv
);
gv_map_
.
insert
({
gv
,
HasStatic
(
MkSFunc
(
f
),
gv
)});
gv_map_
.
insert
({
gv
,
HasStatic
(
MkSFunc
(
f
),
gv
)});
func
=
AsFunc
(
PostProcess
(
VisitFuncDynamic
(
func
,
f
)));
func
=
AsFunc
(
PostProcess
(
VisitFuncDynamic
(
func
,
f
)));
mod_
->
Update
(
gv
,
func
);
mod_
->
Update
(
gv
,
func
);
}
else
{
gv_map_
.
insert
({
gv
,
NoStatic
(
gv
)});
}
}
}
return
gv_map_
.
at
(
gv
);
return
gv_map_
.
at
(
gv
);
}
}
PStatic
VisitExpr_
(
const
GlobalVarNode
*
op
,
LetList
*
ll
)
final
{
return
VisitGlobalVar
(
GetRef
<
GlobalVar
>
(
op
));
}
PStatic
VisitExpr_
(
const
LetNode
*
op
,
LetList
*
ll
)
final
{
PStatic
VisitExpr_
(
const
LetNode
*
op
,
LetList
*
ll
)
final
{
env_
.
Insert
(
op
->
var
,
VisitExpr
(
op
->
value
,
ll
));
env_
.
Insert
(
op
->
var
,
VisitExpr
(
op
->
value
,
ll
));
return
VisitExpr
(
op
->
body
,
ll
);
return
VisitExpr
(
op
->
body
,
ll
);
...
@@ -629,7 +636,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
...
@@ -629,7 +636,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
subst
.
Set
(
func
->
type_params
[
i
],
type_args
[
i
]);
subst
.
Set
(
func
->
type_params
[
i
],
type_args
[
i
]);
}
}
for
(
size_t
i
=
type_args
.
size
();
i
<
func
->
type_params
.
size
();
++
i
)
{
for
(
size_t
i
=
type_args
.
size
();
i
<
func
->
type_params
.
size
();
++
i
)
{
subst
.
Set
(
func
->
type_params
[
i
],
Type
(
));
subst
.
Set
(
func
->
type_params
[
i
],
IncompleteTypeNode
::
make
(
kType
));
}
}
std
::
vector
<
Time
>
args_time
;
std
::
vector
<
Time
>
args_time
;
for
(
const
auto
&
v
:
pv
)
{
for
(
const
auto
&
v
:
pv
)
{
...
@@ -672,22 +679,22 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
...
@@ -672,22 +679,22 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
};
};
}
}
Expr
VisitFuncDynamic
(
const
Function
&
func
,
const
Func
&
f
)
{
Expr
VisitFuncDynamic
(
const
Function
&
func
,
const
Func
&
f
)
{
return
store_
.
Extend
<
Expr
>
([
&
]()
{
return
store_
.
Extend
<
Expr
>
([
&
]()
{
store_
.
Invalidate
();
store_
.
Invalidate
();
return
FunctionNode
::
make
(
func
->
params
,
LetList
::
With
([
&
](
LetList
*
ll
)
{
return
FunctionNode
::
make
(
func
->
params
,
std
::
vector
<
PStatic
>
pv
;
LetList
::
With
([
&
](
LetList
*
ll
)
{
for
(
const
auto
&
v
:
func
->
params
)
{
std
::
vector
<
PStatic
>
pv
;
pv
.
push_back
(
NoStatic
(
v
));
for
(
const
auto
&
v
:
func
->
params
)
{
}
pv
.
push_back
(
NoStatic
(
v
));
tvm
::
Array
<
Type
>
type_args
;
}
for
(
const
auto
&
tp
:
func
->
type_params
)
{
tvm
::
Array
<
Type
>
type_args
;
type_args
.
push_back
(
tp
);
for
(
const
auto
&
tp
:
func
->
type_params
)
{
}
type_args
.
push_back
(
tp
);
return
f
(
pv
,
Attrs
(),
type_args
,
ll
)
->
dynamic
;
}
}),
func
->
ret_type
,
func
->
type_params
,
func
->
attrs
);
return
f
(
pv
,
Attrs
(),
type_args
,
ll
)
->
dynamic
;
});
}),
func
->
ret_type
,
func
->
type_params
,
func
->
attrs
);
});
}
}
PStatic
VisitFunc
(
const
Function
&
func
,
LetList
*
ll
)
{
PStatic
VisitFunc
(
const
Function
&
func
,
LetList
*
ll
)
{
...
@@ -1012,17 +1019,14 @@ Expr PostProcess(const Expr& e) {
...
@@ -1012,17 +1019,14 @@ Expr PostProcess(const Expr& e) {
Module
PartialEval
(
const
Module
&
m
)
{
Module
PartialEval
(
const
Module
&
m
)
{
CHECK
(
m
->
entry_func
.
defined
());
CHECK
(
m
->
entry_func
.
defined
());
auto
func
=
m
->
Lookup
(
m
->
entry_func
);
relay
::
partial_eval
::
PartialEvaluator
pe
(
m
);
Expr
ret
=
std
::
vector
<
GlobalVar
>
gvs
;
TransformF
([
&
](
const
Expr
&
e
)
{
for
(
const
auto
&
p
:
m
->
functions
)
{
return
LetList
::
With
([
&
](
LetList
*
ll
)
{
gvs
.
push_back
(
p
.
first
);
relay
::
partial_eval
::
PartialEvaluator
pe
(
FreeVars
(
e
),
m
);
}
pe
.
InitializeFuncId
(
e
);
for
(
const
auto
&
gv
:
gvs
)
{
return
relay
::
partial_eval
::
PostProcess
(
pe
.
VisitExpr
(
e
,
ll
)
->
dynamic
);
pe
.
VisitGlobalVar
(
gv
);
});
}
},
func
);
CHECK
(
ret
->
is_type
<
FunctionNode
>
());
m
->
Update
(
m
->
entry_func
,
Downcast
<
Function
>
(
ret
));
return
m
;
return
m
;
}
}
...
...
src/relay/pass/type_infer.cc
View file @
b3f3ab55
...
@@ -172,6 +172,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
...
@@ -172,6 +172,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
return
it
->
second
.
checked_type
;
return
it
->
second
.
checked_type
;
}
}
Type
ret
=
this
->
VisitExpr
(
expr
);
Type
ret
=
this
->
VisitExpr
(
expr
);
CHECK
(
ret
.
defined
());
KindCheck
(
ret
,
mod_
);
KindCheck
(
ret
,
mod_
);
ResolvedTypeInfo
&
rti
=
type_map_
[
expr
];
ResolvedTypeInfo
&
rti
=
type_map_
[
expr
];
rti
.
checked_type
=
ret
;
rti
.
checked_type
=
ret
;
...
...
src/relay/pass/util.cc
View file @
b3f3ab55
...
@@ -425,6 +425,16 @@ Expr TypeSubst(const Expr& expr, const tvm::Map<TypeVar, Type>& subst_map) {
...
@@ -425,6 +425,16 @@ Expr TypeSubst(const Expr& expr, const tvm::Map<TypeVar, Type>& subst_map) {
Var
VisitVar
(
const
Var
&
v
)
final
{
Var
VisitVar
(
const
Var
&
v
)
final
{
return
Downcast
<
Var
>
(
VisitExpr
(
v
));
return
Downcast
<
Var
>
(
VisitExpr
(
v
));
}
}
Pattern
VisitPattern
(
const
Pattern
&
p
)
final
{
return
PatternMutator
::
VisitPattern
(
p
);
}
Clause
VisitClause
(
const
Clause
&
c
)
final
{
Pattern
pat
=
VisitPattern
(
c
->
lhs
);
return
ClauseNode
::
make
(
pat
,
VisitExpr
(
c
->
rhs
));
}
private
:
private
:
const
tvm
::
Map
<
TypeVar
,
Type
>&
subst_map_
;
const
tvm
::
Map
<
TypeVar
,
Type
>&
subst_map_
;
};
};
...
...
tests/python/relay/test_pass_partial_eval.py
View file @
b3f3ab55
...
@@ -307,10 +307,10 @@ def test_double():
...
@@ -307,10 +307,10 @@ def test_double():
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
test_
empty_ad
()
test_
ref
()
test_tuple
()
test_tuple
()
test_empty_ad
()
test_const_inline
()
test_const_inline
()
test_ref
()
test_ad
()
test_ad
()
test_if_ref
()
test_if_ref
()
test_function_invalidate
()
test_function_invalidate
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment