Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
07399e02
Unverified
Commit
07399e02
authored
Oct 30, 2018
by
Tianqi Chen
Committed by
GitHub
Oct 30, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RELAY][OP] Maketuple to be resolved when containing incompleteType (#2031)
parent
866d458c
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
34 additions
and
7 deletions
+34
-7
src/relay/op/tensor/transform.cc
+1
-1
src/relay/pass/type_infer.cc
+32
-6
tests/python/relay/test_op_level1.py
+1
-0
No files found.
src/relay/op/tensor/transform.cc
View file @
07399e02
...
...
@@ -140,7 +140,7 @@ bool ConcatenateRel(const Array<Type>& types,
CHECK_EQ
(
types
.
size
(),
2
);
const
auto
*
tensor_tuple
=
types
[
0
].
as
<
TupleTypeNode
>
();
if
(
tensor_tuple
==
nullptr
)
{
CHECK
(
types
[
0
].
as
<
Tupl
eTypeNode
>
())
CHECK
(
types
[
0
].
as
<
Incomplet
eTypeNode
>
())
<<
"cast: expect input type to be TupleType but get "
<<
types
[
0
];
return
false
;
...
...
src/relay/pass/type_infer.cc
View file @
07399e02
...
...
@@ -56,11 +56,31 @@ bool TupleGetItemRel(const Array<Type>& types,
return
true
;
}
bool
MakeTupleRel
(
const
Array
<
Type
>&
types
,
int
num_inputs
,
const
Attrs
&
attrs
,
const
TypeReporter
&
reporter
)
{
CHECK_EQ
(
static_cast
<
size_t
>
(
num_inputs
+
1
),
types
.
size
());
for
(
int
i
=
0
;
i
<
num_inputs
;
++
i
)
{
if
(
types
[
i
].
as
<
IncompleteTypeNode
>
())
return
false
;
}
Array
<
Type
>
fields
;
for
(
int
i
=
0
;
i
<
num_inputs
;
++
i
)
{
fields
.
push_back
(
types
[
i
]);
}
reporter
->
Assign
(
types
[
num_inputs
],
TupleTypeNode
::
make
(
fields
));
return
true
;
}
TVM_REGISTER_NODE_TYPE
(
TupleGetItemAttrs
);
TVM_REGISTER_API
(
"tvm.relay.type_relation.TupleGetItem"
)
.
set_body_typed
<
bool
(
const
Array
<
Type
>&
,
int
,
const
Attrs
&
,
const
TypeReporter
&
)
>
(
TupleGetItemRel
);
TVM_REGISTER_API
(
"tvm.relay.type_relation.MakeTuple"
)
.
set_body_typed
<
bool
(
const
Array
<
Type
>&
,
int
,
const
Attrs
&
,
const
TypeReporter
&
)
>
(
MakeTupleRel
);
struct
ResolvedTypeInfo
{
explicit
ResolvedTypeInfo
(
Type
checked_type
,
Array
<
Type
>
type_args
)
:
checked_type
(
checked_type
),
type_args
(
type_args
)
{}
...
...
@@ -104,6 +124,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
TypeSolver
solver_
;
// relation function
TypeRelationFn
tuple_getitem_rel_
;
TypeRelationFn
make_tuple_rel_
;
// Unify two types
Type
Unify
(
const
Type
&
t1
,
const
Type
&
t2
,
const
Span
&
span
)
{
// TODO(tqchen, jroesch): propagate span to solver
...
...
@@ -154,14 +175,19 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
}
Type
VisitExpr_
(
const
TupleNode
*
op
)
final
{
// TODO(tqchen, jroesch)
// tuple should be a constraint in the type solver
// to handle cases where the field type is not known.
Array
<
Type
>
fields
;
if
(
!
make_tuple_rel_
.
defined
())
{
make_tuple_rel_
=
TypeRelationFn
(
EnvFunc
::
Get
(
"tvm.relay.type_relation.MakeTuple"
).
node_
);
}
Array
<
Type
>
types
;
for
(
Expr
field
:
op
->
fields
)
{
field
s
.
push_back
(
GetType
(
field
));
type
s
.
push_back
(
GetType
(
field
));
}
return
TupleTypeNode
::
make
(
fields
);
Type
rtype
=
IncompleteTypeNode
::
make
(
TypeVarNode
::
Kind
::
kType
);
types
.
push_back
(
rtype
);
solver_
.
AddConstraint
(
TypeRelationNode
::
make
(
make_tuple_rel_
,
types
,
op
->
fields
.
size
(),
Attrs
()));
return
rtype
;
}
Type
VisitExpr_
(
const
TupleGetItemNode
*
op
)
final
{
...
...
tests/python/relay/test_op_level1.py
View file @
07399e02
...
...
@@ -87,6 +87,7 @@ def test_concatenate_infer_type():
zz
=
relay
.
ir_pass
.
infer_type
(
z
)
assert
zz
.
checked_type
==
relay
.
TensorType
((
n
,
t
,
200
))
x
=
relay
.
exp
(
x
)
z
=
relay
.
concatenate
((
x
,
y
),
axis
=
2
)
zz
=
relay
.
ir_pass
.
infer_type
(
z
)
assert
zz
.
checked_type
==
relay
.
TensorType
((
n
,
t
,
200
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment