Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
fe0eac94
Commit
fe0eac94
authored
Nov 27, 2018
by
Siju
Committed by
Tianqi Chen
Nov 26, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RELAY]take and transpose comp and schd (#2135)
parent
bcacb764
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
100 additions
and
21 deletions
+100
-21
nnvm/src/top/tensor/transform.cc
+1
-1
python/tvm/relay/op/_transform.py
+2
-0
src/relay/op/tensor/transform.cc
+29
-2
tests/python/relay/test_op_level3.py
+47
-0
topi/include/topi/transform.h
+21
-18
No files found.
nnvm/src/top/tensor/transform.cc
View file @
fe0eac94
...
@@ -874,7 +874,7 @@ Examples::
...
@@ -874,7 +874,7 @@ Examples::
const
Array
<
Tensor
>&
inputs
,
const
Array
<
Tensor
>&
inputs
,
const
Array
<
Tensor
>&
out_info
)
{
const
Array
<
Tensor
>&
out_info
)
{
const
TransposeParam
&
param
=
nnvm
::
get
<
TransposeParam
>
(
attrs
.
parsed
);
const
TransposeParam
&
param
=
nnvm
::
get
<
TransposeParam
>
(
attrs
.
parsed
);
auto
axes
=
ShapeToArray
(
param
.
axes
);
auto
axes
=
ShapeTo
Int
Array
(
param
.
axes
);
return
Array
<
Tensor
>
{
topi
::
transpose
(
inputs
[
0
],
axes
)
};
return
Array
<
Tensor
>
{
topi
::
transpose
(
inputs
[
0
],
axes
)
};
})
})
.
set_attr
<
FGradient
>
(
.
set_attr
<
FGradient
>
(
...
...
python/tvm/relay/op/_transform.py
View file @
fe0eac94
...
@@ -15,3 +15,5 @@ _reg.register_schedule("cast", schedule_broadcast)
...
@@ -15,3 +15,5 @@ _reg.register_schedule("cast", schedule_broadcast)
_reg
.
register_schedule
(
"strided_slice"
,
schedule_injective
)
_reg
.
register_schedule
(
"strided_slice"
,
schedule_injective
)
_reg
.
register_schedule
(
"slice_like"
,
schedule_injective
)
_reg
.
register_schedule
(
"slice_like"
,
schedule_injective
)
_reg
.
register_schedule
(
"split"
,
schedule_injective
)
_reg
.
register_schedule
(
"split"
,
schedule_injective
)
_reg
.
register_schedule
(
"take"
,
schedule_injective
)
_reg
.
register_schedule
(
"transpose"
,
schedule_injective
)
src/relay/op/tensor/transform.cc
View file @
fe0eac94
...
@@ -282,6 +282,15 @@ bool TransposeRel(const Array<Type>& types,
...
@@ -282,6 +282,15 @@ bool TransposeRel(const Array<Type>& types,
return
true
;
return
true
;
}
}
Array
<
Tensor
>
TransposeCompute
(
const
Attrs
&
attrs
,
const
Array
<
Tensor
>&
inputs
,
const
Type
&
out_type
,
const
Target
&
target
)
{
const
auto
*
param
=
attrs
.
as
<
TransposeAttrs
>
();
CHECK
(
param
!=
nullptr
);
return
Array
<
Tensor
>
{
topi
::
transpose
(
inputs
[
0
],
param
->
axes
)
};
}
Expr
MakeTranspose
(
Expr
data
,
Expr
MakeTranspose
(
Expr
data
,
Array
<
Integer
>
axes
)
{
Array
<
Integer
>
axes
)
{
auto
attrs
=
make_node
<
TransposeAttrs
>
();
auto
attrs
=
make_node
<
TransposeAttrs
>
();
...
@@ -307,7 +316,9 @@ RELAY_REGISTER_OP("transpose")
...
@@ -307,7 +316,9 @@ RELAY_REGISTER_OP("transpose")
.
set_attrs_type_key
(
"relay.attrs.TransposeAttrs"
)
.
set_attrs_type_key
(
"relay.attrs.TransposeAttrs"
)
.
add_argument
(
"data"
,
"Tensor"
,
"The input tensor."
)
.
add_argument
(
"data"
,
"Tensor"
,
"The input tensor."
)
.
set_support_level
(
3
)
.
set_support_level
(
3
)
.
add_type_rel
(
"Transpose"
,
TransposeRel
);
.
add_type_rel
(
"Transpose"
,
TransposeRel
)
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
TransposeCompute
)
.
set_attr
<
TOpPattern
>
(
"TOpPattern"
,
kInjective
);
/* relay.reshape */
/* relay.reshape */
...
@@ -575,6 +586,19 @@ bool TakeRel(const Array<Type>& types,
...
@@ -575,6 +586,19 @@ bool TakeRel(const Array<Type>& types,
return
true
;
return
true
;
}
}
Array
<
Tensor
>
TakeCompute
(
const
Attrs
&
attrs
,
const
Array
<
Tensor
>&
inputs
,
const
Type
&
out_type
,
const
Target
&
target
)
{
const
auto
*
param
=
attrs
.
as
<
TakeAttrs
>
();
CHECK
(
param
!=
nullptr
);
if
(
!
param
->
axis
.
defined
())
{
return
Array
<
Tensor
>
{
topi
::
take
(
inputs
[
0
],
inputs
[
1
])
};
}
else
{
return
Array
<
Tensor
>
{
topi
::
take
(
inputs
[
0
],
inputs
[
1
],
param
->
axis
)
};
}
}
Expr
MakeTake
(
Expr
data
,
Expr
MakeTake
(
Expr
data
,
Expr
indices
,
Expr
indices
,
Integer
axis
)
{
Integer
axis
)
{
...
@@ -617,7 +641,10 @@ Examples::
...
@@ -617,7 +641,10 @@ Examples::
.
add_argument
(
"data"
,
"Tensor"
,
"The input tensor."
)
.
add_argument
(
"data"
,
"Tensor"
,
"The input tensor."
)
.
add_argument
(
"indices"
,
"Tensor"
,
"The indices tensor."
)
.
add_argument
(
"indices"
,
"Tensor"
,
"The indices tensor."
)
.
set_support_level
(
2
)
.
set_support_level
(
2
)
.
add_type_rel
(
"Take"
,
TakeRel
);
.
add_type_rel
(
"Take"
,
TakeRel
)
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
TakeCompute
)
.
set_attr
<
TOpPattern
>
(
"TOpPattern"
,
kInjective
);
// Init ops
// Init ops
TVM_REGISTER_NODE_TYPE
(
InitOpAttrs
);
TVM_REGISTER_NODE_TYPE
(
InitOpAttrs
);
...
...
tests/python/relay/test_op_level3.py
View file @
fe0eac94
...
@@ -87,6 +87,22 @@ def test_transpose_infer_type():
...
@@ -87,6 +87,22 @@ def test_transpose_infer_type():
assert
yy
.
checked_type
==
relay
.
TensorType
(
assert
yy
.
checked_type
==
relay
.
TensorType
(
(
t
,
n
,
100
),
"float32"
)
(
t
,
n
,
100
),
"float32"
)
def
test_transpose
():
def
verify_transpose
(
dshape
,
axes
):
x
=
relay
.
var
(
"x"
,
relay
.
TensorType
(
dshape
,
"float32"
))
z
=
relay
.
transpose
(
x
,
axes
=
axes
)
func
=
relay
.
Function
([
x
],
z
)
x_data
=
np
.
random
.
uniform
(
low
=-
1
,
high
=
1
,
size
=
dshape
)
.
astype
(
"float32"
)
ref_res
=
np
.
transpose
(
x_data
,
axes
=
axes
)
for
target
,
ctx
in
ctx_list
():
for
kind
in
[
"graph"
,
"debug"
]:
intrp
=
relay
.
create_executor
(
kind
,
ctx
=
ctx
,
target
=
target
)
op_res
=
intrp
.
evaluate
(
func
)(
x_data
)
tvm
.
testing
.
assert_allclose
(
op_res
.
asnumpy
(),
ref_res
,
rtol
=
1e-5
)
verify_transpose
((
2
,
3
,
4
),
(
0
,
2
,
1
))
def
test_squeeze_infer_type
():
def
test_squeeze_infer_type
():
n
,
t
,
d
=
1
,
4
,
1
n
,
t
,
d
=
1
,
4
,
1
...
@@ -202,6 +218,35 @@ def test_take_infer_type():
...
@@ -202,6 +218,35 @@ def test_take_infer_type():
verify_take
((
d1
,
d2
),
(
d3
,
d4
,
d5
),
(
d1
,
d3
,
d4
,
d5
),
1
)
verify_take
((
d1
,
d2
),
(
d3
,
d4
,
d5
),
(
d1
,
d3
,
d4
,
d5
),
1
)
verify_take
((
d1
,
d2
,
d3
,
d4
),
(
d5
,
d6
),
(
d1
,
d2
,
d5
,
d6
,
d4
),
-
2
)
verify_take
((
d1
,
d2
,
d3
,
d4
),
(
d5
,
d6
),
(
d1
,
d2
,
d5
,
d6
,
d4
),
-
2
)
def
test_take
():
def
verify_take
(
src_shape
,
indices_src
,
axis
=
None
):
src_dtype
=
"float32"
indices_dtype
=
"int32"
indices_src
=
np
.
array
(
indices_src
,
dtype
=
indices_dtype
)
x
=
relay
.
var
(
"x"
,
relay
.
TensorType
(
src_shape
,
src_dtype
))
indices
=
relay
.
var
(
"indices"
,
relay
.
TensorType
(
indices_src
.
shape
,
indices_dtype
))
z
=
relay
.
take
(
x
,
indices
,
axis
=
axis
)
func
=
relay
.
Function
([
x
,
indices
],
z
)
x_data
=
np
.
random
.
uniform
(
low
=-
1
,
high
=
1
,
size
=
src_shape
)
.
astype
(
src_dtype
)
ref_res
=
np
.
take
(
x_data
,
indices
=
indices_src
,
axis
=
axis
)
for
target
,
ctx
in
ctx_list
():
for
kind
in
[
"graph"
,
"debug"
]:
intrp
=
relay
.
create_executor
(
kind
,
ctx
=
ctx
,
target
=
target
)
op_res
=
intrp
.
evaluate
(
func
)(
x_data
,
indices_src
)
tvm
.
testing
.
assert_allclose
(
op_res
.
asnumpy
(),
ref_res
,
rtol
=
1e-5
)
verify_take
((
4
,),
[
1
])
verify_take
((
4
,),
[[
0
,
1
,
2
,
3
]])
verify_take
((
3
,
3
,
3
),
[[
11
,
25
]])
verify_take
((
4
,),
[[
0
,
1
],[
2
,
3
]])
verify_take
((
4
,),
[
1
],
0
)
verify_take
((
2
,
2
),
[[[
1
,
0
],[
0
,
1
]]],
0
)
verify_take
((
2
,
2
),
[[[
1
,
0
],[
0
,
1
]]],
1
)
verify_take
((
4
,
3
,
5
,
6
),
[[
2
,
1
,
0
,
0
]],
-
2
)
def
test_split_infer_type
():
def
test_split_infer_type
():
def
verify_split
(
dshape
,
indices_or_sections
,
ret_type
,
axis
=
None
):
def
verify_split
(
dshape
,
indices_or_sections
,
ret_type
,
axis
=
None
):
x
=
relay
.
var
(
"x"
,
relay
.
ty
.
TensorType
(
dshape
,
"float32"
))
x
=
relay
.
var
(
"x"
,
relay
.
ty
.
TensorType
(
dshape
,
"float32"
))
...
@@ -360,11 +405,13 @@ if __name__ == "__main__":
...
@@ -360,11 +405,13 @@ if __name__ == "__main__":
test_unary_identity
()
test_unary_identity
()
test_clip
()
test_clip
()
test_transpose_infer_type
()
test_transpose_infer_type
()
test_transpose
()
test_reshape_infer_type
()
test_reshape_infer_type
()
test_reshape
()
test_reshape
()
test_reshape_like_infer_type
()
test_reshape_like_infer_type
()
test_reshape_like
()
test_reshape_like
()
test_take_infer_type
()
test_take_infer_type
()
test_take
()
test_full
()
test_full
()
test_full_like
()
test_full_like
()
test_infer_type_leaky_relu
()
test_infer_type_leaky_relu
()
...
...
topi/include/topi/transform.h
View file @
fe0eac94
...
@@ -86,42 +86,45 @@ inline Tensor expand_dims(const Tensor& x,
...
@@ -86,42 +86,45 @@ inline Tensor expand_dims(const Tensor& x,
* \return A Tensor whose op member is the transpose operation
* \return A Tensor whose op member is the transpose operation
*/
*/
inline
Tensor
transpose
(
const
Tensor
&
x
,
inline
Tensor
transpose
(
const
Tensor
&
x
,
Array
<
Exp
r
>
axes
,
Array
<
Intege
r
>
axes
,
std
::
string
name
=
"tensor"
,
std
::
string
name
=
"tensor"
,
std
::
string
tag
=
kInjective
)
{
std
::
string
tag
=
kInjective
)
{
if
(
axes
.
size
()
==
0
)
{
if
(
!
axes
.
defined
()
||
axes
.
size
()
==
0
)
{
axes
=
Array
<
Exp
r
>
();
axes
=
Array
<
Intege
r
>
();
for
(
int
i
=
static_cast
<
int
>
(
x
->
shape
.
size
())
-
1
;
i
>=
0
;
--
i
)
{
for
(
int
i
=
static_cast
<
int
>
(
x
->
shape
.
size
())
-
1
;
i
>=
0
;
--
i
)
{
axes
.
push_back
(
i
);
axes
.
push_back
(
i
);
}
}
}
}
auto
axes_val
=
GetConstIntValues
(
axes
,
"axes"
);
Array
<
Expr
>
new_shape
;
for
(
size_t
i
=
0
;
i
<
axes_val
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
axes
.
size
();
++
i
)
{
int
axis
=
axes_val
[
i
];
int
axis
=
static_cast
<
int
>
(
axes
[
i
]
->
value
);
if
(
axes_val
[
i
]
<
0
)
{
int
new_axis
=
axis
;
axes_val
[
i
]
=
static_cast
<
int
>
(
x
->
shape
.
size
())
+
axes_val
[
i
];
if
(
axis
<
0
)
{
new_axis
=
static_cast
<
int
>
(
x
->
shape
.
size
())
+
axis
;
axes
.
Set
(
i
,
new_axis
);
}
}
CHECK
((
0
<=
axes_val
[
i
])
&&
(
axes_val
[
i
]
<
static_cast
<
int
>
(
x
->
shape
.
size
())))
CHECK
((
new_axis
>=
0
)
&&
(
new_axis
<
static_cast
<
int
>
(
x
->
shape
.
size
())))
<<
"axis="
<<
axis
<<
" is invalid for the "
<<
"axis="
<<
axis
<<
" is invalid for the "
<<
static_cast
<
int
>
(
x
->
shape
.
size
())
<<
"-dimensional input tensor"
;
<<
static_cast
<
int
>
(
x
->
shape
.
size
())
<<
"-dimensional input tensor"
;
CHECK
(
1
==
std
::
count
(
std
::
begin
(
axes_val
),
std
::
end
(
axes_val
),
axes_val
[
i
]))
for
(
size_t
j
=
0
;
j
<
axes
.
size
();
++
j
)
{
<<
"repeated axis in transpose"
;
if
(
i
!=
j
)
{
CHECK
(
new_axis
!=
static_cast
<
int
>
(
axes
[
j
]
->
value
))
<<
"repeated axis in transpose"
;
}
}
Array
<
Expr
>
new_shape
;
for
(
size_t
i
=
0
;
i
<
axes_val
.
size
();
++
i
)
{
new_shape
.
push_back
(
x
->
shape
[
axes_val
[
i
]]);
}
}
new_shape
.
push_back
(
x
->
shape
[
new_axis
]);
}
return
compute
(
return
compute
(
new_shape
,
[
&
](
const
Array
<
Var
>&
indices
)
{
new_shape
,
[
&
](
const
Array
<
Var
>&
indices
)
{
std
::
vector
<
Expr
>
idx
;
std
::
vector
<
Expr
>
idx
;
for
(
size_t
i
=
0
;
i
<
axes
_val
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
axes
.
size
();
++
i
)
{
idx
.
push_back
(
1
);
idx
.
push_back
(
1
);
}
}
for
(
size_t
i
=
0
;
i
<
axes_val
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
axes
.
size
();
++
i
)
{
idx
[
axes_val
[
i
]]
=
indices
[
i
];
int
axis
=
static_cast
<
int
>
(
axes
[
i
]
->
value
);
idx
[
axis
]
=
indices
[
i
];
}
}
return
x
(
idx
);
return
x
(
idx
);
},
name
,
tag
);
},
name
,
tag
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment