Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
e6e9b371
Commit
e6e9b371
authored
Oct 27, 2018
by
Siva
Committed by
Tianqi Chen
Oct 26, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RELAY][OP] Split (#1876)
parent
cbf4fdbb
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
200 additions
and
2 deletions
+200
-2
docs/langref/relay_op.rst
+2
-0
include/tvm/relay/attrs/transform.h
+16
-0
nnvm/src/top/tensor/transform.cc
+1
-1
python/tvm/relay/expr.py
+11
-0
python/tvm/relay/op/transform.py
+34
-1
src/lang/attr_functor.h
+4
-0
src/lang/attrs.cc
+2
-0
src/relay/op/tensor/transform.cc
+97
-0
tests/python/relay/test_op_level3.py
+33
-0
No files found.
docs/langref/relay_op.rst
View file @
e6e9b371
...
@@ -94,6 +94,7 @@ This level enables additional math and transform operators.
...
@@ -94,6 +94,7 @@ This level enables additional math and transform operators.
tvm.relay.full
tvm.relay.full
tvm.relay.full_like
tvm.relay.full_like
tvm.relay.cast
tvm.relay.cast
tvm.relay.split
**Level 4: Broadcast and Reductions**
**Level 4: Broadcast and Reductions**
...
@@ -198,6 +199,7 @@ Level 3 Definitions
...
@@ -198,6 +199,7 @@ Level 3 Definitions
.. autofunction:: tvm.relay.full
.. autofunction:: tvm.relay.full
.. autofunction:: tvm.relay.full_like
.. autofunction:: tvm.relay.full_like
.. autofunction:: tvm.relay.cast
.. autofunction:: tvm.relay.cast
.. autofunction:: tvm.relay.split
Level 4 Definitions
Level 4 Definitions
...
...
include/tvm/relay/attrs/transform.h
View file @
e6e9b371
...
@@ -106,6 +106,22 @@ struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
...
@@ -106,6 +106,22 @@ struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
}
}
};
// struct SqueezeAttrs
};
// struct SqueezeAttrs
struct
SplitAttrs
:
public
tvm
::
AttrsNode
<
SplitAttrs
>
{
NodeRef
indices_or_sections
;
int
axis
;
TVM_DECLARE_ATTRS
(
SplitAttrs
,
"relay.attrs.SplitAttrs"
)
{
TVM_ATTR_FIELD
(
indices_or_sections
)
.
describe
(
"Indices or sections to split into. Accepts an int or a tuple"
"If indices_or_sections is an integer, the input will be divided equally"
"along given axis. If such a split is not possible, an error is raised."
"If indices_or_sections is a tuple of sorted integers,"
"the entries indicate where along axis the array is split."
);
TVM_ATTR_FIELD
(
axis
).
set_default
(
0
)
.
describe
(
"the axis to be splitted."
);
}
};
}
// namespace relay
}
// namespace relay
}
// namespace tvm
}
// namespace tvm
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_
nnvm/src/top/tensor/transform.cc
View file @
e6e9b371
...
@@ -427,7 +427,7 @@ along which to split the array.
...
@@ -427,7 +427,7 @@ along which to split the array.
return
Array
<
Tensor
>
{
topi
::
split
(
inputs
[
0
],
indices
,
param
.
axis
)
};
return
Array
<
Tensor
>
{
topi
::
split
(
inputs
[
0
],
indices
,
param
.
axis
)
};
}
}
})
})
.
set_support_level
(
1
);
.
set_support_level
(
3
);
// cast
// cast
DMLC_REGISTER_PARAMETER
(
CastParam
);
DMLC_REGISTER_PARAMETER
(
CastParam
);
...
...
python/tvm/relay/expr.py
View file @
e6e9b371
...
@@ -5,6 +5,7 @@ from __future__ import absolute_import
...
@@ -5,6 +5,7 @@ from __future__ import absolute_import
import
numpy
as
_np
import
numpy
as
_np
from
.base
import
RelayNode
,
register_relay_node
from
.base
import
RelayNode
,
register_relay_node
from
.
import
_make
from
.
import
_make
from
.
import
_expr
from
.
import
ty
as
_ty
from
.
import
ty
as
_ty
from
.._ffi
import
base
as
_base
from
.._ffi
import
base
as
_base
from
..
import
nd
as
_nd
from
..
import
nd
as
_nd
...
@@ -284,6 +285,16 @@ class TupleWrapper(object):
...
@@ -284,6 +285,16 @@ class TupleWrapper(object):
as an argument to an FFI function."""
as an argument to an FFI function."""
return
self
.
tuple_value
return
self
.
tuple_value
def
astext
(
self
):
"""Get the text format of the tuple expression.
Returns
-------
text : str
The text format of the tuple expression.
"""
return
_expr
.
_text_print
(
self
.
tuple_value
)
def
__getitem__
(
self
,
index
):
def
__getitem__
(
self
,
index
):
if
index
>=
len
(
self
):
if
index
>=
len
(
self
):
raise
IndexError
(
"Tuple index out of range"
)
raise
IndexError
(
"Tuple index out of range"
)
...
...
python/tvm/relay/op/transform.py
View file @
e6e9b371
"""Transform operators."""
"""Transform operators."""
from
.
import
_make
from
.
import
_make
from
..expr
import
TupleWrapper
def
expand_dims
(
data
,
axis
,
num_newaxis
=
1
):
def
expand_dims
(
data
,
axis
,
num_newaxis
=
1
):
...
@@ -146,7 +147,7 @@ def take(data, indices, axis=None):
...
@@ -146,7 +147,7 @@ def take(data, indices, axis=None):
Parameters
Parameters
----------
----------
a : relay.Expr
dat
a : relay.Expr
The source array.
The source array.
indices : rely.Expr
indices : rely.Expr
...
@@ -280,3 +281,35 @@ def collapse_sum_like(data, collapse_type):
...
@@ -280,3 +281,35 @@ def collapse_sum_like(data, collapse_type):
The resulting tensor.
The resulting tensor.
"""
"""
return
_make
.
collapse_sum_like
(
data
,
collapse_type
)
return
_make
.
collapse_sum_like
(
data
,
collapse_type
)
def
split
(
data
,
indices_or_sections
,
axis
=
0
):
"""Split input tensor along axis by sections or indices.
If indices_or_sections is an integer, the input will be divided equally
along given axis. If such a split is not possible, an error is raised.
If indices_or_sections is a tuple of sorted integers,
the entries indicate where along axis the array is split.
Parameters
----------
data : relay.Expr
The source array.
indices_or_sections : int or tuple of int
Indices or sections to split into. Accepts an int or a tuple
axis : int, optional
The axis over which to split.
Returns
-------
ret : relay.Tuple([relay.Expr, relay.Expr])
The computed result.
"""
if
isinstance
(
indices_or_sections
,
int
):
ret_size
=
indices_or_sections
else
:
ret_size
=
len
(
indices_or_sections
)
+
1
return
TupleWrapper
(
_make
.
split
(
data
,
indices_or_sections
,
axis
),
ret_size
)
src/lang/attr_functor.h
View file @
e6e9b371
...
@@ -64,6 +64,7 @@ class AttrFunctor<R(const NodeRef& n, Args...)> {
...
@@ -64,6 +64,7 @@ class AttrFunctor<R(const NodeRef& n, Args...)> {
virtual
R
VisitAttr_
(
const
ir
::
Add
*
op
,
Args
...
args
)
ATTR_FUNCTOR_DEFAULT
;
virtual
R
VisitAttr_
(
const
ir
::
Add
*
op
,
Args
...
args
)
ATTR_FUNCTOR_DEFAULT
;
virtual
R
VisitAttr_
(
const
ir
::
Sub
*
op
,
Args
...
args
)
ATTR_FUNCTOR_DEFAULT
;
virtual
R
VisitAttr_
(
const
ir
::
Sub
*
op
,
Args
...
args
)
ATTR_FUNCTOR_DEFAULT
;
virtual
R
VisitAttr_
(
const
ir
::
Mul
*
op
,
Args
...
args
)
ATTR_FUNCTOR_DEFAULT
;
virtual
R
VisitAttr_
(
const
ir
::
Mul
*
op
,
Args
...
args
)
ATTR_FUNCTOR_DEFAULT
;
virtual
R
VisitAttr_
(
const
ir
::
Div
*
op
,
Args
...
args
)
ATTR_FUNCTOR_DEFAULT
;
virtual
R
VisitAttr_
(
const
ir
::
Mod
*
op
,
Args
...
args
)
ATTR_FUNCTOR_DEFAULT
;
virtual
R
VisitAttr_
(
const
ir
::
Mod
*
op
,
Args
...
args
)
ATTR_FUNCTOR_DEFAULT
;
virtual
R
VisitAttr_
(
const
ir
::
Min
*
op
,
Args
...
args
)
ATTR_FUNCTOR_DEFAULT
;
virtual
R
VisitAttr_
(
const
ir
::
Min
*
op
,
Args
...
args
)
ATTR_FUNCTOR_DEFAULT
;
virtual
R
VisitAttr_
(
const
ir
::
Max
*
op
,
Args
...
args
)
ATTR_FUNCTOR_DEFAULT
;
virtual
R
VisitAttr_
(
const
ir
::
Max
*
op
,
Args
...
args
)
ATTR_FUNCTOR_DEFAULT
;
...
@@ -96,6 +97,7 @@ class AttrFunctor<R(const NodeRef& n, Args...)> {
...
@@ -96,6 +97,7 @@ class AttrFunctor<R(const NodeRef& n, Args...)> {
ATTR_FUNCTOR_DISPATCH
(
Add
);
ATTR_FUNCTOR_DISPATCH
(
Add
);
ATTR_FUNCTOR_DISPATCH
(
Sub
);
ATTR_FUNCTOR_DISPATCH
(
Sub
);
ATTR_FUNCTOR_DISPATCH
(
Mul
);
ATTR_FUNCTOR_DISPATCH
(
Mul
);
ATTR_FUNCTOR_DISPATCH
(
Div
);
ATTR_FUNCTOR_DISPATCH
(
Min
);
ATTR_FUNCTOR_DISPATCH
(
Min
);
ATTR_FUNCTOR_DISPATCH
(
Max
);
ATTR_FUNCTOR_DISPATCH
(
Max
);
ATTR_FUNCTOR_DISPATCH
(
GE
);
ATTR_FUNCTOR_DISPATCH
(
GE
);
...
@@ -135,6 +137,7 @@ class AttrsEqualHandler :
...
@@ -135,6 +137,7 @@ class AttrsEqualHandler :
bool
VisitAttr_
(
const
ir
::
Add
*
lhs
,
const
NodeRef
&
other
)
final
;
bool
VisitAttr_
(
const
ir
::
Add
*
lhs
,
const
NodeRef
&
other
)
final
;
bool
VisitAttr_
(
const
ir
::
Sub
*
lhs
,
const
NodeRef
&
other
)
final
;
bool
VisitAttr_
(
const
ir
::
Sub
*
lhs
,
const
NodeRef
&
other
)
final
;
bool
VisitAttr_
(
const
ir
::
Mul
*
lhs
,
const
NodeRef
&
other
)
final
;
bool
VisitAttr_
(
const
ir
::
Mul
*
lhs
,
const
NodeRef
&
other
)
final
;
bool
VisitAttr_
(
const
ir
::
Div
*
lhs
,
const
NodeRef
&
other
)
final
;
bool
VisitAttr_
(
const
ir
::
Mod
*
lhs
,
const
NodeRef
&
other
)
final
;
bool
VisitAttr_
(
const
ir
::
Mod
*
lhs
,
const
NodeRef
&
other
)
final
;
bool
VisitAttr_
(
const
ir
::
Min
*
lhs
,
const
NodeRef
&
other
)
final
;
bool
VisitAttr_
(
const
ir
::
Min
*
lhs
,
const
NodeRef
&
other
)
final
;
bool
VisitAttr_
(
const
ir
::
Max
*
lhs
,
const
NodeRef
&
other
)
final
;
bool
VisitAttr_
(
const
ir
::
Max
*
lhs
,
const
NodeRef
&
other
)
final
;
...
@@ -174,6 +177,7 @@ class AttrsHashHandler :
...
@@ -174,6 +177,7 @@ class AttrsHashHandler :
size_t
VisitAttr_
(
const
ir
::
Add
*
op
)
final
;
size_t
VisitAttr_
(
const
ir
::
Add
*
op
)
final
;
size_t
VisitAttr_
(
const
ir
::
Sub
*
op
)
final
;
size_t
VisitAttr_
(
const
ir
::
Sub
*
op
)
final
;
size_t
VisitAttr_
(
const
ir
::
Mul
*
op
)
final
;
size_t
VisitAttr_
(
const
ir
::
Mul
*
op
)
final
;
size_t
VisitAttr_
(
const
ir
::
Div
*
op
)
final
;
size_t
VisitAttr_
(
const
ir
::
Mod
*
op
)
final
;
size_t
VisitAttr_
(
const
ir
::
Mod
*
op
)
final
;
size_t
VisitAttr_
(
const
ir
::
Min
*
op
)
final
;
size_t
VisitAttr_
(
const
ir
::
Min
*
op
)
final
;
size_t
VisitAttr_
(
const
ir
::
Max
*
op
)
final
;
size_t
VisitAttr_
(
const
ir
::
Max
*
op
)
final
;
...
...
src/lang/attrs.cc
View file @
e6e9b371
...
@@ -132,6 +132,7 @@ bool AttrsEqualHandler::VisitAttr_(const StrMapNode* lhs, const NodeRef& other)
...
@@ -132,6 +132,7 @@ bool AttrsEqualHandler::VisitAttr_(const StrMapNode* lhs, const NodeRef& other)
TVM_DEFINE_ATTRS_BINOP_EQUAL
(
Add
);
TVM_DEFINE_ATTRS_BINOP_EQUAL
(
Add
);
TVM_DEFINE_ATTRS_BINOP_EQUAL
(
Sub
);
TVM_DEFINE_ATTRS_BINOP_EQUAL
(
Sub
);
TVM_DEFINE_ATTRS_BINOP_EQUAL
(
Mul
);
TVM_DEFINE_ATTRS_BINOP_EQUAL
(
Mul
);
TVM_DEFINE_ATTRS_BINOP_EQUAL
(
Div
);
TVM_DEFINE_ATTRS_BINOP_EQUAL
(
Mod
);
TVM_DEFINE_ATTRS_BINOP_EQUAL
(
Mod
);
TVM_DEFINE_ATTRS_BINOP_EQUAL
(
Max
);
TVM_DEFINE_ATTRS_BINOP_EQUAL
(
Max
);
TVM_DEFINE_ATTRS_BINOP_EQUAL
(
Min
);
TVM_DEFINE_ATTRS_BINOP_EQUAL
(
Min
);
...
@@ -243,6 +244,7 @@ size_t AttrsHashHandler::VisitAttr_(const StrMapNode* lhs) {
...
@@ -243,6 +244,7 @@ size_t AttrsHashHandler::VisitAttr_(const StrMapNode* lhs) {
TVM_DEFINE_ATTRS_BINOP_HASH
(
Add
);
TVM_DEFINE_ATTRS_BINOP_HASH
(
Add
);
TVM_DEFINE_ATTRS_BINOP_HASH
(
Sub
);
TVM_DEFINE_ATTRS_BINOP_HASH
(
Sub
);
TVM_DEFINE_ATTRS_BINOP_HASH
(
Mul
);
TVM_DEFINE_ATTRS_BINOP_HASH
(
Mul
);
TVM_DEFINE_ATTRS_BINOP_HASH
(
Div
);
TVM_DEFINE_ATTRS_BINOP_HASH
(
Mod
);
TVM_DEFINE_ATTRS_BINOP_HASH
(
Mod
);
TVM_DEFINE_ATTRS_BINOP_HASH
(
Max
);
TVM_DEFINE_ATTRS_BINOP_HASH
(
Max
);
TVM_DEFINE_ATTRS_BINOP_HASH
(
Min
);
TVM_DEFINE_ATTRS_BINOP_HASH
(
Min
);
...
...
src/relay/op/tensor/transform.cc
View file @
e6e9b371
...
@@ -6,12 +6,14 @@
...
@@ -6,12 +6,14 @@
#include <tvm/relay/op.h>
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/ir_operator.h>
#include <tvm/ir_operator.h>
#include <tvm/ir.h>
#include <vector>
#include <vector>
#include "../op_common.h"
#include "../op_common.h"
namespace
tvm
{
namespace
tvm
{
namespace
relay
{
namespace
relay
{
using
ir
::
IntImm
;
// relay.cast
// relay.cast
TVM_REGISTER_NODE_TYPE
(
CastAttrs
);
TVM_REGISTER_NODE_TYPE
(
CastAttrs
);
...
@@ -834,5 +836,100 @@ RELAY_REGISTER_OP("broadcast_to_like")
...
@@ -834,5 +836,100 @@ RELAY_REGISTER_OP("broadcast_to_like")
.
set_support_level
(
10
)
.
set_support_level
(
10
)
.
add_type_rel
(
"BroadCastToLike"
,
BroadCastToLikeRel
);
.
add_type_rel
(
"BroadCastToLike"
,
BroadCastToLikeRel
);
// Split
TVM_REGISTER_NODE_TYPE
(
SplitAttrs
);
bool
SplitRel
(
const
Array
<
Type
>&
types
,
int
num_inputs
,
const
Attrs
&
attrs
,
const
TypeReporter
&
reporter
)
{
// `types` contains: [data, result]
CHECK_EQ
(
types
.
size
(),
2
);
const
auto
*
data
=
types
[
0
].
as
<
TensorTypeNode
>
();
CHECK
(
data
!=
nullptr
);
CHECK_NE
(
data
->
shape
.
size
(),
0
)
<<
"Input shape cannot be empty"
;
const
auto
param
=
attrs
.
as
<
SplitAttrs
>
();
CHECK
(
param
!=
nullptr
);
auto
axis
=
param
->
axis
;
if
(
axis
<
0
)
{
axis
+=
data
->
shape
.
size
();
}
CHECK_LT
(
axis
,
data
->
shape
.
size
())
<<
"axis should be within the input dimension range."
;
CHECK_GT
(
axis
,
0
)
<<
"axis should be within the input dimension range."
;
if
(
const
IntImm
*
sections
=
param
->
indices_or_sections
.
as
<
IntImm
>
())
{
CHECK
(
reporter
->
Assert
(
data
->
shape
[
axis
]
%
sections
->
value
==
make_zero
(
Int
(
64
))))
<<
"indices_or_sections need to be able to divide input.shape[axis]"
;
std
::
vector
<
Type
>
fields
;
for
(
int
i
=
0
;
i
<
sections
->
value
;
++
i
)
{
std
::
vector
<
IndexExpr
>&&
oshape
=
AsVector
(
data
->
shape
);
oshape
[
axis
]
/=
int32_t
(
sections
->
value
);
auto
vec_type
=
TensorTypeNode
::
make
(
oshape
,
data
->
dtype
);
fields
.
push_back
(
vec_type
);
}
reporter
->
Assign
(
types
[
1
],
TupleTypeNode
::
make
(
Array
<
Type
>
(
fields
)));
}
else
{
auto
indices
=
param
->
indices_or_sections
.
as
<
ArrayNode
>
()
->
data
;
auto
begin
=
IndexExpr
(
make_zero
(
Int
(
32
)));
std
::
vector
<
Type
>
fields
;
for
(
uint
i
=
0
;
i
<
indices
.
size
();
++
i
)
{
CHECK
(
reporter
->
Assert
(
IndexExpr
(
indices
[
i
])
>
begin
))
<<
"indices_or_sections need to be a sorted ascending list"
;
std
::
vector
<
IndexExpr
>&&
oshape
=
AsVector
(
data
->
shape
);
oshape
[
axis
]
=
IndexExpr
(
indices
[
i
])
-
begin
;
begin
=
IndexExpr
(
indices
[
i
]);
auto
vec_type
=
TensorTypeNode
::
make
(
oshape
,
data
->
dtype
);
fields
.
push_back
(
vec_type
);
}
CHECK
(
reporter
->
Assert
(
begin
<
data
->
shape
[
axis
]))
<<
"The sum of sections must match the input.shape[axis]"
;
std
::
vector
<
IndexExpr
>&&
oshape
=
AsVector
(
data
->
shape
);
oshape
[
axis
]
=
data
->
shape
[
axis
]
-
begin
;
auto
vec_type
=
TensorTypeNode
::
make
(
oshape
,
data
->
dtype
);
fields
.
push_back
(
vec_type
);
reporter
->
Assign
(
types
[
1
],
TupleTypeNode
::
make
(
Array
<
Type
>
(
fields
)));
}
return
true
;
}
Expr
MakeSplit
(
Expr
data
,
NodeRef
indices_or_sections
,
int
axis
)
{
auto
attrs
=
make_node
<
SplitAttrs
>
();
attrs
->
axis
=
axis
;
attrs
->
indices_or_sections
=
std
::
move
(
indices_or_sections
);
static
const
Op
&
op
=
Op
::
Get
(
"split"
);
return
CallNode
::
make
(
op
,
{
data
},
Attrs
(
attrs
),
{});
}
TVM_REGISTER_API
(
"relay.op._make.split"
)
.
set_body
([](
const
TVMArgs
&
args
,
TVMRetValue
*
rv
)
{
if
(
args
.
type_codes
[
1
]
==
kDLInt
)
{
*
rv
=
MakeSplit
(
args
[
0
],
make_const
(
Int
(
64
),
int64_t
(
args
[
1
])),
args
[
2
]);
}
else
{
*
rv
=
MakeSplit
(
args
[
0
],
args
[
1
],
args
[
2
]);
}
});
RELAY_REGISTER_OP
(
"split"
)
.
describe
(
R"code(Splits an array along a particular axis into multiple sub-arrays.
Indices or sections to split into. Accepts an int or a tuple
If indices_or_sections is an integer, the input will be divided equally
along given axis. If such a split is not possible, an error is raised.
If indices_or_sections is a tuple of sorted integers,
the entries indicate where along axis the array is split.
)code"
TVM_ADD_FILELINE
)
.
set_attrs_type_key
(
"relay.attrs.SplitAttrs"
)
.
set_num_inputs
(
1
)
.
add_argument
(
"data"
,
"Tensor"
,
"The input tensor."
)
.
set_support_level
(
3
)
.
add_type_rel
(
"Split"
,
SplitRel
);
}
// namespace relay
}
// namespace relay
}
// namespace tvm
}
// namespace tvm
tests/python/relay/test_op_level3.py
View file @
e6e9b371
...
@@ -107,6 +107,38 @@ def test_take_infer_type():
...
@@ -107,6 +107,38 @@ def test_take_infer_type():
verify_take
((
d1
,
d2
),
(
d3
,
d4
,
d5
),
(
d1
,
d3
,
d4
,
d5
),
1
)
verify_take
((
d1
,
d2
),
(
d3
,
d4
,
d5
),
(
d1
,
d3
,
d4
,
d5
),
1
)
verify_take
((
d1
,
d2
,
d3
,
d4
),
(
d5
,
d6
),
(
d1
,
d2
,
d5
,
d6
,
d4
),
-
2
)
verify_take
((
d1
,
d2
,
d3
,
d4
),
(
d5
,
d6
),
(
d1
,
d2
,
d5
,
d6
,
d4
),
-
2
)
def
test_split_infer_type
():
def
verify_split
(
dshape
,
indices_or_sections
,
ret_type
,
axis
=
None
):
x
=
relay
.
var
(
"x"
,
relay
.
ty
.
TensorType
(
dshape
,
"float32"
))
y
=
relay
.
split
(
x
,
indices_or_sections
,
axis
=
axis
)
y
.
astext
()
yy
=
relay
.
ir_pass
.
infer_type
(
y
.
astuple
())
assert
yy
.
checked_type
==
ret_type
d1
,
d2
,
d3
,
d4
=
tvm
.
var
(
"d1"
),
tvm
.
var
(
"d2"
),
tvm
.
var
(
"d3"
),
tvm
.
var
(
"d4"
)
axis
=
tvm
.
var
(
"axis"
)
verify_split
((
5
,
5
,
2
,
2
),
5
,
relay
.
ty
.
TupleType
(
tvm
.
convert
([
relay
.
ty
.
TensorType
((
5
,
1
,
2
,
2
),
"float32"
),
relay
.
ty
.
TensorType
((
5
,
1
,
2
,
2
),
"float32"
),
relay
.
ty
.
TensorType
((
5
,
1
,
2
,
2
),
"float32"
),
relay
.
ty
.
TensorType
((
5
,
1
,
2
,
2
),
"float32"
),
relay
.
ty
.
TensorType
((
5
,
1
,
2
,
2
),
"float32"
)])),
axis
=
1
)
verify_split
((
d1
,
d2
,
d3
,
d4
),
4
,
relay
.
ty
.
TupleType
(
tvm
.
convert
([
relay
.
ty
.
TensorType
((
d1
,
d2
,
d3
/
4
,
d4
),
"float32"
),
relay
.
ty
.
TensorType
((
d1
,
d2
,
d3
/
4
,
d4
),
"float32"
),
relay
.
ty
.
TensorType
((
d1
,
d2
,
d3
/
4
,
d4
),
"float32"
),
relay
.
ty
.
TensorType
((
d1
,
d2
,
d3
/
4
,
d4
),
"float32"
)])),
axis
=
2
)
verify_split
((
d1
,
d2
,
d3
,
d4
),
(
2
,
4
,
7
),
relay
.
ty
.
TupleType
(
tvm
.
convert
([
relay
.
ty
.
TensorType
((
d1
,
2
,
d3
,
d4
),
"float32"
),
relay
.
ty
.
TensorType
((
d1
,
2
,
d3
,
d4
),
"float32"
),
relay
.
ty
.
TensorType
((
d1
,
3
,
d3
,
d4
),
"float32"
),
relay
.
ty
.
TensorType
((
d1
,
(
d2
-
7
),
d3
,
d4
),
"float32"
)])),
axis
=
1
)
def
test_full
():
def
test_full
():
# default settings: match input dtype
# default settings: match input dtype
...
@@ -161,3 +193,4 @@ if __name__ == "__main__":
...
@@ -161,3 +193,4 @@ if __name__ == "__main__":
test_infer_type_leaky_relu
()
test_infer_type_leaky_relu
()
test_squeeze_infer_type
()
test_squeeze_infer_type
()
test_squeeze_bad_axes_infer_type
()
test_squeeze_bad_axes_infer_type
()
test_split_infer_type
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment