Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
201cfdc5
Commit
201cfdc5
authored
Oct 15, 2018
by
雾雨魔理沙
Committed by
Tianqi Chen
Oct 15, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay] [Op] Squeeze (#1858)
parent
47b8c36d
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
149 additions
and
3 deletions
+149
-3
include/tvm/relay/attrs/transform.h
+14
-0
python/tvm/relay/op/transform.py
+24
-1
src/relay/op/tensor/transform.cc
+70
-2
tests/python/relay/test_op_level3.py
+41
-0
No files found.
include/tvm/relay/attrs/transform.h
View file @
201cfdc5
...
@@ -82,6 +82,20 @@ struct InitOpAttrs : public tvm::AttrsNode<InitOpAttrs> {
...
@@ -82,6 +82,20 @@ struct InitOpAttrs : public tvm::AttrsNode<InitOpAttrs> {
}
}
};
// struct InitOpAttrs
};
// struct InitOpAttrs
/*! \brief Attributes used in squeeze operators */
struct
SqueezeAttrs
:
public
tvm
::
AttrsNode
<
SqueezeAttrs
>
{
Array
<
IndexExpr
>
axes
;
TVM_DECLARE_ATTRS
(
SqueezeAttrs
,
"relay.attrs.SqueezeAttrs"
)
{
TVM_ATTR_FIELD
(
axes
)
.
describe
(
"The axes to squeeze in the input tensor."
"If `axes = []`, all axis of dimension 1 get squeezed;"
"Else, the dimension in axes get squeezed."
"It is an error if an axes does not has dimension 1."
)
.
set_default
(
Array
<
IndexExpr
>
({}));
}
};
// struct SqueezeAttrs
}
// namespace relay
}
// namespace relay
}
// namespace tvm
}
// namespace tvm
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_
python/tvm/relay/op/transform.py
View file @
201cfdc5
...
@@ -42,12 +42,35 @@ def transpose(data, axes=None):
...
@@ -42,12 +42,35 @@ def transpose(data, axes=None):
Returns
Returns
-------
-------
result : relay.Expr
result : relay.Expr
The
reshap
ed result.
The
transpos
ed result.
"""
"""
axes
=
axes
or
[]
axes
=
axes
or
[]
return
_make
.
transpose
(
data
,
list
(
axes
))
return
_make
.
transpose
(
data
,
list
(
axes
))
def
squeeze
(
data
,
axes
=
None
):
"""Squeeze axes in the array.
Parameters
----------
data : relay.Expr
The input data to the operator.
axes : None or List[int]
Axes to remove.
If axes = [] or = None, remove all axis of dimensions 1.
Otherwise, remove all axis in axes.
If any axis in axes has dimension that does not equal 1, it is an error.
Returns
-------
result : relay.Expr
The squeezed result.
"""
axes
=
axes
or
[]
return
_make
.
squeeze
(
data
,
list
(
axes
))
def
reshape
(
data
,
newshape
):
def
reshape
(
data
,
newshape
):
"""Reshapes the input array.
"""Reshapes the input array.
...
...
src/relay/op/tensor/transform.cc
View file @
201cfdc5
...
@@ -80,8 +80,6 @@ RELAY_REGISTER_OP("expand_dims")
...
@@ -80,8 +80,6 @@ RELAY_REGISTER_OP("expand_dims")
.
set_support_level
(
1
)
.
set_support_level
(
1
)
.
add_type_rel
(
"ExpandDims"
,
ExpandDimsRel
);
.
add_type_rel
(
"ExpandDims"
,
ExpandDimsRel
);
/* relay.concatenate */
TVM_REGISTER_NODE_TYPE
(
ConcatenateAttrs
);
TVM_REGISTER_NODE_TYPE
(
ConcatenateAttrs
);
bool
ConcatenateRel
(
const
Array
<
Type
>&
types
,
bool
ConcatenateRel
(
const
Array
<
Type
>&
types
,
...
@@ -633,5 +631,75 @@ Examples::
...
@@ -633,5 +631,75 @@ Examples::
.
set_support_level
(
4
)
.
set_support_level
(
4
)
.
add_type_rel
(
"Where"
,
WhereRel
);
.
add_type_rel
(
"Where"
,
WhereRel
);
Expr
MakeSqueeze
(
Expr
data
,
Array
<
IndexExpr
>
axes
)
{
auto
attrs
=
make_node
<
SqueezeAttrs
>
();
attrs
->
axes
=
std
::
move
(
axes
);
static
const
Op
&
op
=
Op
::
Get
(
"squeeze"
);
return
CallNode
::
make
(
op
,
{
data
},
Attrs
(
attrs
),
{});
}
TVM_REGISTER_API
(
"relay.op._make.squeeze"
)
.
set_body
([](
const
TVMArgs
&
args
,
TVMRetValue
*
rv
)
{
runtime
::
detail
::
unpack_call
<
Expr
,
2
>
(
MakeSqueeze
,
args
,
rv
);
});
bool
SqueezeRel
(
const
Array
<
Type
>&
types
,
int
num_inputs
,
const
Attrs
&
attrs
,
const
TypeReporter
&
reporter
)
{
CHECK_EQ
(
types
.
size
(),
2
);
const
auto
*
data
=
types
[
0
].
as
<
TensorTypeNode
>
();
if
(
data
==
nullptr
)
{
return
false
;
}
const
auto
*
param
=
attrs
.
as
<
SqueezeAttrs
>
();
CHECK
(
param
!=
nullptr
);
std
::
vector
<
IndexExpr
>
result_shape
;
// if axes is empty, squeeze all axes of dimension 1
if
(
param
->
axes
.
size
()
==
0
)
{
for
(
const
auto
&
e
:
data
->
shape
)
{
const
int64_t
*
axis_ptr
=
as_const_int
(
e
);
CHECK
(
axis_ptr
!=
nullptr
)
<<
"the axes attribute must be concrete"
;
if
(
*
axis_ptr
!=
1
)
{
result_shape
.
push_back
(
e
);
}
}
}
else
{
// pair up original shape with a boolean which control whether it will be in the final shape.
std
::
vector
<
std
::
pair
<
IndexExpr
,
bool
>
>
original_shape
;
for
(
const
auto
&
e
:
data
->
shape
)
{
original_shape
.
push_back
(
std
::
pair
<
IndexExpr
,
bool
>
(
e
,
true
));
}
for
(
const
auto
&
e
:
param
->
axes
)
{
const
int64_t
*
axis_ptr
=
as_const_int
(
e
);
CHECK
(
axis_ptr
!=
nullptr
);
original_shape
.
at
(
*
axis_ptr
).
second
=
false
;
}
for
(
const
auto
p
:
original_shape
)
{
if
(
p
.
second
)
{
result_shape
.
push_back
(
p
.
first
);
}
else
{
const
int64_t
*
axis_ptr
=
as_const_int
(
p
.
first
);
CHECK
(
axis_ptr
!=
nullptr
)
<<
"cannot get concrete shape of input tensor"
;
CHECK_EQ
(
*
axis_ptr
,
1
)
<<
"cannot squeeze axis with dimension not equal to 1"
;
}
}
}
reporter
->
Assign
(
types
[
1
],
TensorTypeNode
::
make
(
result_shape
,
data
->
dtype
));
return
true
;
}
RELAY_REGISTER_OP
(
"squeeze"
)
.
describe
(
R"code(Squeeze the input tensor at the dimensions given by axes
- **data**: The input data to the operator.
)code"
TVM_ADD_FILELINE
)
.
set_num_inputs
(
1
)
.
add_argument
(
"data"
,
"Tensor"
,
"The input tensor."
)
.
set_support_level
(
3
)
.
add_type_rel
(
"Squeeze"
,
SqueezeRel
);
}
// namespace relay
}
// namespace relay
}
// namespace tvm
}
// namespace tvm
tests/python/relay/test_op_level3.py
View file @
201cfdc5
...
@@ -6,6 +6,7 @@ from tvm import relay
...
@@ -6,6 +6,7 @@ from tvm import relay
from
tvm.relay.ir_pass
import
infer_type
from
tvm.relay.ir_pass
import
infer_type
from
tvm.relay.ir_builder
import
IRBuilder
,
func_type
from
tvm.relay.ir_builder
import
IRBuilder
,
func_type
from
tvm.relay.env
import
Environment
from
tvm.relay.env
import
Environment
from
nose.tools
import
raises
def
test_zeros_ones
():
def
test_zeros_ones
():
for
op
in
[
relay
.
zeros
,
relay
.
ones
]:
for
op
in
[
relay
.
zeros
,
relay
.
ones
]:
...
@@ -67,6 +68,44 @@ def test_transpose_infer_type():
...
@@ -67,6 +68,44 @@ def test_transpose_infer_type():
(
t
,
n
,
100
),
"float32"
)
(
t
,
n
,
100
),
"float32"
)
def
test_squeeze_default_axes_infer_type
():
ib
=
relay
.
ir_builder
.
IRBuilder
()
n
,
t
,
d
=
1
,
4
,
1
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
n
,
t
,
d
),
"float32"
))
with
ib
.
function
(
x
)
as
func
:
ib
.
ret
(
relay
.
squeeze
(
x
))
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
assert
ftype
.
ret_type
==
relay
.
ty
.
TensorType
(
(
4
,),
"float32"
)
def
test_squeeze_axes_infer_type
():
ib
=
relay
.
ir_builder
.
IRBuilder
()
n
,
t
,
d
=
1
,
4
,
1
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
n
,
t
,
d
),
"float32"
))
with
ib
.
function
(
x
)
as
func
:
ib
.
ret
(
relay
.
squeeze
(
x
,
axes
=
(
2
,)))
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
assert
ftype
.
ret_type
==
relay
.
ty
.
TensorType
(
(
1
,
4
),
"float32"
)
@raises
(
tvm
.
_ffi
.
base
.
TVMError
)
def
test_squeeze_bad_axes_infer_type
():
ib
=
relay
.
ir_builder
.
IRBuilder
()
n
,
t
,
d
=
1
,
4
,
1
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
n
,
t
,
d
),
"float32"
))
with
ib
.
function
(
x
)
as
func
:
ib
.
ret
(
relay
.
squeeze
(
x
,
axes
=
(
1
,)))
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
def
test_reshape_infer_type
():
def
test_reshape_infer_type
():
ib
=
relay
.
ir_builder
.
IRBuilder
()
ib
=
relay
.
ir_builder
.
IRBuilder
()
n
,
t
,
d1
,
d2
=
tvm
.
var
(
"n"
),
tvm
.
var
(
"t"
),
100
,
20
n
,
t
,
d1
,
d2
=
tvm
.
var
(
"n"
),
tvm
.
var
(
"t"
),
100
,
20
...
@@ -181,3 +220,5 @@ if __name__ == "__main__":
...
@@ -181,3 +220,5 @@ if __name__ == "__main__":
test_take_infer_type
()
test_take_infer_type
()
test_full
()
test_full
()
test_full_like
()
test_full_like
()
test_squeeze_axes_infer_type
()
test_squeeze_default_axes_infer_type
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment