Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
b68d9dc0
Commit
b68d9dc0
authored
Oct 09, 2018
by
Siva
Committed by
Tianqi Chen
Oct 08, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RELAY][OP] take (#1863)
parent
64d3393e
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
146 additions
and
1 deletions
+146
-1
docs/langref/relay_op.rst
+2
-0
include/tvm/relay/attrs/transform.h
+9
-0
nnvm/src/top/tensor/transform.cc
+1
-1
python/tvm/relay/op/transform.py
+23
-0
src/relay/op/tensor/transform.cc
+89
-0
tests/python/relay/test_op_level3.py
+22
-0
No files found.
docs/langref/relay_op.rst
View file @
b68d9dc0
...
@@ -73,6 +73,7 @@ This level enables additional math and transform operators.
...
@@ -73,6 +73,7 @@ This level enables additional math and transform operators.
tvm.relay.round
tvm.relay.round
tvm.relay.abs
tvm.relay.abs
tvm.relay.negative
tvm.relay.negative
tvm.relay.take
...
@@ -143,6 +144,7 @@ Level 3 Definitions
...
@@ -143,6 +144,7 @@ Level 3 Definitions
.. autofunction:: tvm.relay.reshape
.. autofunction:: tvm.relay.reshape
.. autofunction:: tvm.relay.copy
.. autofunction:: tvm.relay.copy
.. autofunction:: tvm.relay.transpose
.. autofunction:: tvm.relay.transpose
.. autofunction:: tvm.relay.take
Level 3 Definitions
Level 3 Definitions
-------------------
-------------------
...
...
include/tvm/relay/attrs/transform.h
View file @
b68d9dc0
...
@@ -59,6 +59,15 @@ struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
...
@@ -59,6 +59,15 @@ struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
}
}
};
// struct ReshapeAttrs
};
// struct ReshapeAttrs
struct
TakeAttrs
:
public
tvm
::
AttrsNode
<
TakeAttrs
>
{
IndexExpr
axis
;
TVM_DECLARE_ATTRS
(
TakeAttrs
,
"relay.attrs.TakeAttrs"
)
{
TVM_ATTR_FIELD
(
axis
).
set_default
(
NullValue
<
IndexExpr
>
())
.
describe
(
"The axis over which to select values."
);
}
};
}
// namespace relay
}
// namespace relay
}
// namespace tvm
}
// namespace tvm
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_
nnvm/src/top/tensor/transform.cc
View file @
b68d9dc0
...
@@ -1135,7 +1135,7 @@ Examples::
...
@@ -1135,7 +1135,7 @@ Examples::
.
set_attr
<
FCorrectLayout
>
(
"FCorrectLayout"
,
TakeCorrectLayout
)
.
set_attr
<
FCorrectLayout
>
(
"FCorrectLayout"
,
TakeCorrectLayout
)
.
set_num_inputs
(
2
)
.
set_num_inputs
(
2
)
.
set_num_outputs
(
1
)
.
set_num_outputs
(
1
)
.
set_support_level
(
1
)
.
set_support_level
(
3
)
.
set_attr
<
FTVMCompute
>
(
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
[](
const
NodeAttrs
&
attrs
,
"FTVMCompute"
,
[](
const
NodeAttrs
&
attrs
,
const
Array
<
Tensor
>&
inputs
,
const
Array
<
Tensor
>&
inputs
,
...
...
python/tvm/relay/op/transform.py
View file @
b68d9dc0
...
@@ -116,3 +116,26 @@ def reshape(data, newshape):
...
@@ -116,3 +116,26 @@ def reshape(data, newshape):
if
isinstance
(
newshape
,
int
):
if
isinstance
(
newshape
,
int
):
newshape
=
[
newshape
]
newshape
=
[
newshape
]
return
_make
.
reshape
(
data
,
list
(
newshape
))
return
_make
.
reshape
(
data
,
list
(
newshape
))
def
take
(
data
,
indices
,
axis
=
None
):
"""Take elements from an array along an axis.
Parameters
----------
a : relay.Expr
The source array.
indices : rely.Expr
The indices of the values to extract.
axis : int, optional
The axis over which to select values. By default,
the flattened input array is used.
Returns
-------
ret : relay.Expr
The computed result.
"""
return
_make
.
take
(
data
,
indices
,
axis
)
src/relay/op/tensor/transform.cc
View file @
b68d9dc0
...
@@ -315,5 +315,94 @@ Example::
...
@@ -315,5 +315,94 @@ Example::
.
set_support_level
(
3
)
.
set_support_level
(
3
)
.
add_type_rel
(
"Reshape"
,
ReshapeRel
);
.
add_type_rel
(
"Reshape"
,
ReshapeRel
);
// Take
TVM_REGISTER_NODE_TYPE
(
TakeAttrs
);
bool
TakeRel
(
const
Array
<
Type
>&
types
,
int
num_inputs
,
const
Attrs
&
attrs
,
const
TypeReporter
&
reporter
)
{
// `types` contains: [data, indices, result]
CHECK_EQ
(
types
.
size
(),
3
);
const
auto
*
data
=
types
[
0
].
as
<
TensorTypeNode
>
();
CHECK
(
data
!=
nullptr
);
const
auto
*
indices
=
types
[
1
].
as
<
TensorTypeNode
>
();
CHECK
(
indices
!=
nullptr
);
const
auto
param
=
attrs
.
as
<
TakeAttrs
>
();
CHECK
(
param
!=
nullptr
);
if
(
!
param
->
axis
.
defined
())
{
std
::
vector
<
IndexExpr
>&&
oshape
=
AsVector
(
indices
->
shape
);
reporter
->
Assign
(
types
[
2
],
TensorTypeNode
::
make
(
oshape
,
data
->
dtype
));
return
true
;
}
std
::
vector
<
IndexExpr
>
oshape
;
const
auto
ndim_data
=
static_cast
<
int
>
(
data
->
shape
.
size
());
const
auto
ndim_indices
=
static_cast
<
int
>
(
indices
->
shape
.
size
());
auto
axis
=
(
*
as_const_int
(
param
->
axis
));
if
(
axis
<
0
)
axis
+=
ndim_data
;
CHECK_LE
(
axis
,
ndim_data
)
<<
"axis should be with in data shape"
<<
", but got = "
<<
axis
;
oshape
.
reserve
(
ndim_data
-
1
+
ndim_indices
);
for
(
int
i
=
0
;
i
<
axis
;
++
i
)
{
oshape
.
emplace_back
(
data
->
shape
[
i
]);
}
for
(
int
i
=
0
;
i
<
ndim_indices
;
++
i
)
{
oshape
.
emplace_back
(
indices
->
shape
[
i
]);
}
for
(
int
i
=
axis
+
1
;
i
<
ndim_data
;
++
i
)
{
oshape
.
emplace_back
(
data
->
shape
[
i
]);
}
reporter
->
Assign
(
types
[
2
],
TensorTypeNode
::
make
(
oshape
,
data
->
dtype
));
return
true
;
}
Expr
MakeTake
(
Expr
data
,
Expr
indices
,
IndexExpr
axis
)
{
auto
attrs
=
make_node
<
TakeAttrs
>
();
attrs
->
axis
=
axis
;
static
const
Op
&
op
=
Op
::
Get
(
"take"
);
return
CallNode
::
make
(
op
,
{
data
,
indices
},
Attrs
(
attrs
),
{});
}
TVM_REGISTER_API
(
"relay.op._make.take"
)
.
set_body
([](
const
TVMArgs
&
args
,
TVMRetValue
*
rv
)
{
runtime
::
detail
::
unpack_call
<
Expr
,
3
>
(
MakeTake
,
args
,
rv
);
});
RELAY_REGISTER_OP
(
"take"
)
.
describe
(
R"code(Take elements from an array along an axis.
When axis is not None, this function does the same thing as 'fancy' indexing
(indexing arrays using arrays); however, it can be easier to use if you need
elements along a given axis.
**Note** that when axis is none the flattened input array is used.
Examples::
a = [[ 1, 2],
[ 3, 4]]
indices = [3, 0, 2]
take(a, indices) = [ 4, 1, 3]
a = [[ 1., 2.],
[ 3., 4.]]
indices = [1, 0]
take(a, indices, axis=1) = [[ 2., 1.],
[ 4., 3.]]
)code"
TVM_ADD_FILELINE
)
.
set_num_inputs
(
2
)
.
add_argument
(
"data"
,
"Tensor"
,
"The input tensor."
)
.
add_argument
(
"indices"
,
"Tensor"
,
"The indices tensor."
)
.
set_support_level
(
2
)
.
add_type_rel
(
"Take"
,
TakeRel
);
}
// namespace relay
}
// namespace relay
}
// namespace tvm
}
// namespace tvm
tests/python/relay/test_op_level3.py
View file @
b68d9dc0
...
@@ -91,6 +91,27 @@ def test_single_op():
...
@@ -91,6 +91,27 @@ def test_single_op():
tvm
.
relay
.
round
,
tvm
.
relay
.
abs
,
tvm
.
relay
.
negative
]:
tvm
.
relay
.
round
,
tvm
.
relay
.
abs
,
tvm
.
relay
.
negative
]:
check_single_op
(
opfunc
)
check_single_op
(
opfunc
)
def
test_take_infer_type
():
def
verify_take
(
dshape
,
indices_shape
,
oshape
,
axis
=
None
):
ib
=
relay
.
ir_builder
.
IRBuilder
()
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
(
dshape
,
"float32"
))
indices
=
ib
.
param
(
"indices"
,
relay
.
ty
.
TensorType
(
indices_shape
,
"int32"
))
with
ib
.
function
(
x
,
indices
)
as
func
:
ib
.
ret
(
relay
.
take
(
x
.
var
,
indices
.
var
,
axis
=
axis
))
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
assert
ftype
.
ret_type
==
relay
.
ty
.
TensorType
(
oshape
,
"float32"
)
d1
,
d2
,
d3
=
tvm
.
var
(
"d1"
),
tvm
.
var
(
"d2"
),
tvm
.
var
(
"d3"
)
d4
,
d5
,
d6
=
tvm
.
var
(
"d4"
),
tvm
.
var
(
"d5"
),
tvm
.
var
(
"d6"
)
verify_take
((
d1
,),
(
1
,),
(
1
,),
0
)
verify_take
((
4
,),
(
d1
,
d2
),
(
d1
,
d2
))
verify_take
((
3
,
3
,
3
),
(
1
,
d2
),
(
1
,
d2
))
verify_take
((
d1
,
d2
),
(
d3
,
d4
,
d5
),
(
d3
,
d4
,
d5
,
d2
),
0
)
verify_take
((
d1
,
d2
),
(
d3
,
d4
,
d5
),
(
d1
,
d3
,
d4
,
d5
),
1
)
verify_take
((
d1
,
d2
,
d3
,
d4
),
(
d5
,
d6
),
(
d1
,
d2
,
d5
,
d6
,
d4
),
-
2
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_single_op
()
test_single_op
()
...
@@ -99,3 +120,4 @@ if __name__ == "__main__":
...
@@ -99,3 +120,4 @@ if __name__ == "__main__":
test_copy_infer_type
()
test_copy_infer_type
()
test_transpose_infer_type
()
test_transpose_infer_type
()
test_reshape_infer_type
()
test_reshape_infer_type
()
test_take_infer_type
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment