Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
62a94c76
Commit
62a94c76
authored
Oct 27, 2018
by
Siju
Committed by
Tianqi Chen
Oct 26, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RELAY]reshape_like (#1950)
parent
4fbb7c89
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
115 additions
and
0 deletions
+115
-0
docs/langref/relay_op.rst
+2
-0
include/tvm/relay/type.h
+5
-0
python/tvm/relay/op/transform.py
+23
-0
src/relay/ir/type.cc
+12
-0
src/relay/op/tensor/transform.cc
+56
-0
tests/python/relay/test_op_level3.py
+17
-0
No files found.
docs/langref/relay_op.rst
View file @
62a94c76
...
...
@@ -78,6 +78,7 @@ This level enables additional math and transform operators.
tvm.relay.ones
tvm.relay.ones_like
tvm.relay.reshape
tvm.relay.reshape_like
tvm.relay.copy
tvm.relay.transpose
tvm.relay.floor
...
...
@@ -189,6 +190,7 @@ Level 3 Definitions
.. autofunction:: tvm.relay.abs
.. autofunction:: tvm.relay.negative
.. autofunction:: tvm.relay.reshape
.. autofunction:: tvm.relay.reshape_like
.. autofunction:: tvm.relay.copy
.. autofunction:: tvm.relay.transpose
.. autofunction:: tvm.relay.take
...
...
include/tvm/relay/type.h
View file @
62a94c76
...
...
@@ -82,6 +82,11 @@ class TensorTypeNode : public BaseTensorTypeNode {
v
->
Visit
(
"span"
,
&
span
);
}
/*! \brief Return product of elements in the shape.
* \return (d1 * d_2 ... * d_n) if shape is (d_1, d_2, ..., d_n) and 1 if shape size is zero.
*/
TVM_DLL
IndexExpr
Size
()
const
;
TVM_DLL
static
TensorType
make
(
Array
<
IndexExpr
>
shape
,
DataType
dtype
);
/*! \brief Construct an scalar containing elements of dtype. */
...
...
python/tvm/relay/op/transform.py
View file @
62a94c76
...
...
@@ -142,6 +142,29 @@ def reshape(data, newshape):
return
_make
.
reshape
(
data
,
list
(
newshape
))
def
reshape_like
(
data
,
shape_like
):
"""Reshapes the input array by the size of another array.
For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation reshapes
the input array into an output array with the same shape as the second input array.
.. note::
Sizes for both array should be compatible.
Parameters
----------
data : relay.Expr
The input data to the operator.
shape_like : tuple of int
The new shape. Should be compatible with the original shape.
Returns
-------
ret : relay.Expr
The computed result.
"""
return
_make
.
reshape_like
(
data
,
shape_like
)
def
take
(
data
,
indices
,
axis
=
None
):
"""Take elements from an array along an axis.
...
...
src/relay/ir/type.cc
View file @
62a94c76
...
...
@@ -22,6 +22,18 @@ TensorType TensorTypeNode::Scalar(DataType dtype) {
return
TensorTypeNode
::
make
({},
dtype
);
}
IndexExpr
TensorTypeNode
::
Size
()
const
{
if
(
shape
.
size
()
==
0
)
{
return
make_const
(
Int
(
64
),
1
);
}
IndexExpr
size
=
shape
[
0
];
for
(
size_t
i
=
1
;
i
<
shape
.
size
();
++
i
)
{
size
*=
shape
[
i
];
}
return
size
;
}
TVM_REGISTER_NODE_TYPE
(
TensorTypeNode
);
TVM_REGISTER_API
(
"relay._make.TensorType"
)
...
...
src/relay/op/tensor/transform.cc
View file @
62a94c76
...
...
@@ -377,6 +377,62 @@ Example::
.
set_support_level
(
3
)
.
add_type_rel
(
"Reshape"
,
ReshapeRel
);
/*!
* \brief ReshapeLikeRel User defined type constraint function.
* \param num_inputs Number of input types in the args.
* \param attrs The additional attributes of the operator.
* \param reporter The reporter to report solution to.
* \return False if the relation has not been resolved, it might be resolved later.
* True if this relation has been resolved.
*/
bool
ReshapeLikeRel
(
const
Array
<
Type
>&
types
,
int
num_inputs
,
const
Attrs
&
attrs
,
const
TypeReporter
&
reporter
)
{
CHECK_EQ
(
types
.
size
(),
3
);
const
auto
*
data
=
types
[
0
].
as
<
TensorTypeNode
>
();
if
(
data
==
nullptr
)
{
return
false
;
}
const
auto
*
reshape_like
=
types
[
1
].
as
<
TensorTypeNode
>
();
if
(
reshape_like
==
nullptr
)
{
return
false
;
}
CHECK
(
reporter
->
AssertEQ
(
data
->
Size
(),
reshape_like
->
Size
()))
<<
"Reshape inputs size should be compatible."
;
reporter
->
Assign
(
types
[
2
],
TensorTypeNode
::
make
(
reshape_like
->
shape
,
data
->
dtype
));
return
true
;
}
Expr
MakeReshapeLike
(
Expr
data
,
Expr
shape_like
)
{
static
const
Op
&
op
=
Op
::
Get
(
"reshape_like"
);
return
CallNode
::
make
(
op
,
{
data
,
shape_like
},
Attrs
(),
{});
}
TVM_REGISTER_API
(
"relay.op._make.reshape_like"
)
.
set_body
([](
const
TVMArgs
&
args
,
TVMRetValue
*
rv
)
{
runtime
::
detail
::
unpack_call
<
Expr
,
2
>
(
MakeReshapeLike
,
args
,
rv
);
});
RELAY_REGISTER_OP
(
"reshape_like"
)
.
describe
(
R"code(Reshapes the input array by the size of another array.
For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation reshapes
the input array into an output array with the same shape as the second input array.
.. note::
Sizes for both array should be compatible.
)code"
TVM_ADD_FILELINE
)
.
set_num_inputs
(
2
)
.
add_argument
(
"data"
,
"Tensor"
,
"The input tensor."
)
.
add_argument
(
"shape_like"
,
"Tensor"
,
"Shape tensor."
)
.
set_support_level
(
3
)
.
add_type_rel
(
"ReshapeLike"
,
ReshapeLikeRel
);
// Take
TVM_REGISTER_NODE_TYPE
(
TakeAttrs
);
...
...
tests/python/relay/test_op_level3.py
View file @
62a94c76
...
...
@@ -88,6 +88,22 @@ def test_reshape_infer_type():
(
n
,
t
,
2000
),
"float32"
)
def
test_reshape_like
():
# concrete shape
x
=
relay
.
var
(
"x"
,
relay
.
TensorType
((
1
,
2
,
3
),
"float32"
))
y
=
relay
.
var
(
"y"
,
relay
.
TensorType
((
1
,
6
),
"float32"
))
z
=
relay
.
reshape_like
(
x
,
y
)
zz
=
relay
.
ir_pass
.
infer_type
(
z
)
assert
zz
.
checked_type
==
relay
.
TensorType
((
1
,
6
),
"float32"
)
# symbolic shape
n
,
c
,
h
,
w
=
tvm
.
var
(
"n"
),
2
,
3
,
tvm
.
var
(
"w"
)
x
=
relay
.
var
(
"x"
,
relay
.
TensorType
((
n
,
c
,
h
,
w
),
"float32"
))
y
=
relay
.
var
(
"y"
,
relay
.
TensorType
((
1
,
8
,
8
),
"float32"
))
z
=
relay
.
reshape_like
(
x
,
y
)
zz
=
relay
.
ir_pass
.
infer_type
(
z
)
assert
zz
.
checked_type
==
relay
.
TensorType
((
1
,
8
,
8
),
"float32"
)
def
test_take_infer_type
():
def
verify_take
(
dshape
,
indices_shape
,
oshape
,
axis
=
None
):
...
...
@@ -187,6 +203,7 @@ if __name__ == "__main__":
test_clip_type
()
test_transpose_infer_type
()
test_reshape_infer_type
()
test_reshape_like
()
test_take_infer_type
()
test_full
()
test_full_like
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment