Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
6e36da35
Unverified
Commit
6e36da35
authored
Apr 16, 2020
by
Samuel
Committed by
GitHub
Apr 16, 2020
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[TOPI][PYTORCH]Logical & Bitwise operator support (#5341)
parent
cc8cacb1
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
222 additions
and
2 deletions
+222
-2
docs/api/python/topi.rst
+2
-0
docs/langref/relay_op.rst
+1
-0
python/tvm/relay/frontend/pytorch.py
+65
-1
python/tvm/relay/op/_tensor.py
+2
-0
python/tvm/relay/op/tensor.py
+17
-0
src/relay/op/tensor/binary.cc
+6
-0
tests/python/frontend/pytorch/test_forward.py
+94
-1
topi/include/topi/broadcast.h
+13
-0
topi/python/topi/broadcast.py
+19
-0
topi/src/broadcast.cc
+1
-0
topi/tests/python/test_topi_broadcast.py
+2
-0
No files found.
docs/api/python/topi.rst
View file @
6e36da35
...
...
@@ -99,6 +99,7 @@ List of operators
topi.logical_and
topi.logical_or
topi.logical_not
topi.logical_xor
topi.arange
topi.stack
topi.repeat
...
...
@@ -193,6 +194,7 @@ topi
.. autofunction:: topi.logical_and
.. autofunction:: topi.logical_or
.. autofunction:: topi.logical_not
.. autofunction:: topi.logical_xor
topi.nn
~~~~~~~
...
...
docs/langref/relay_op.rst
View file @
6e36da35
...
...
@@ -150,6 +150,7 @@ This level enables additional math and transform operators.
tvm.relay.logical_and
tvm.relay.logical_or
tvm.relay.logical_not
tvm.relay.logical_xor
tvm.relay.maximum
tvm.relay.minimum
tvm.relay.power
...
...
python/tvm/relay/frontend/pytorch.py
View file @
6e36da35
...
...
@@ -1168,7 +1168,6 @@ def _ceil():
def
_clamp
():
def
_impl
(
inputs
,
input_types
):
print
(
inputs
,
input_types
)
data
=
inputs
[
0
]
amin
=
inputs
[
1
]
if
inputs
[
1
]
else
np
.
finfo
(
np
.
float32
)
.
min
amax
=
inputs
[
2
]
if
inputs
[
2
]
else
np
.
finfo
(
np
.
float32
)
.
max
...
...
@@ -1298,6 +1297,67 @@ def _mm():
return
_impl
def
_bitwise_not
():
def
_impl
(
inputs
,
input_types
):
data
=
inputs
[
0
]
# The input tensor must be of integral or Boolean types.
# For bool tensors, it computes the logical NOT
if
input_types
[
0
]
==
"bool"
:
out
=
_op
.
logical_not
(
_op
.
cast
(
data
,
"bool"
))
else
:
out
=
_op
.
bitwise_not
(
_op
.
cast
(
data
,
"int"
))
return
out
return
_impl
def
_bitwise_xor
():
def
_impl
(
inputs
,
input_types
):
lhs
=
inputs
[
0
]
import
torch
if
isinstance
(
inputs
[
1
],
_expr
.
Var
):
rhs
=
inputs
[
1
]
elif
isinstance
(
inputs
[
1
],
torch
.
Tensor
):
rhs
=
_wrap_const
(
inputs
[
1
]
.
numpy
())
else
:
msg
=
"Data type
%
s could not be parsed in bitwise_xor operator."
%
(
type
(
inputs
[
1
]))
raise
AssertionError
(
msg
)
lhs
=
_op
.
cast
(
lhs
,
"bool"
)
if
input_types
[
0
]
==
"bool"
else
_op
.
cast
(
lhs
,
"int"
)
rhs
=
_op
.
cast
(
rhs
,
"bool"
)
if
input_types
[
1
]
==
"bool"
else
_op
.
cast
(
rhs
,
"int"
)
return
_op
.
bitwise_xor
(
lhs
,
rhs
)
return
_impl
def
_logical_not
():
def
_impl
(
inputs
,
input_types
):
data
=
inputs
[
0
]
return
_op
.
logical_not
(
_op
.
cast
(
data
,
"bool"
))
return
_impl
def
_logical_xor
():
def
_impl
(
inputs
,
input_types
):
lhs
=
_op
.
cast
(
inputs
[
0
],
"bool"
)
import
torch
if
isinstance
(
inputs
[
1
],
_expr
.
Var
):
rhs
=
inputs
[
1
]
elif
isinstance
(
inputs
[
1
],
torch
.
Tensor
):
rhs
=
_wrap_const
(
inputs
[
1
]
.
numpy
())
else
:
msg
=
"Data type
%
s could not be parsed in logical_xor operator."
%
(
type
(
inputs
[
1
]))
raise
AssertionError
(
msg
)
rhs
=
_op
.
cast
(
rhs
,
"bool"
)
return
_op
.
logical_xor
(
lhs
,
rhs
)
return
_impl
def
_isfinite
():
def
_impl
(
inputs
,
input_types
):
return
_op
.
isfinite
(
inputs
[
0
])
...
...
@@ -1524,6 +1584,10 @@ def _get_convert_map(prelude):
"aten::ge"
:
_elemwise
(
"greater_equal"
),
"aten::ne"
:
_elemwise
(
"not_equal"
),
"aten::eq"
:
_elemwise
(
"equal"
),
"aten::logical_not"
:
_logical_not
(),
"aten::logical_xor"
:
_logical_xor
(),
"aten::bitwise_not"
:
_bitwise_not
(),
"aten::bitwise_xor"
:
_bitwise_xor
(),
"aten::isfinite"
:
_isfinite
(),
"aten::isnan"
:
_isnan
(),
"aten::Bool"
:
_Bool
(),
...
...
python/tvm/relay/op/_tensor.py
View file @
6e36da35
...
...
@@ -53,6 +53,7 @@ register_broadcast_schedule("copy")
register_broadcast_schedule
(
"logical_not"
)
register_broadcast_schedule
(
"logical_and"
)
register_broadcast_schedule
(
"logical_or"
)
register_broadcast_schedule
(
"logical_xor"
)
register_broadcast_schedule
(
"bitwise_not"
)
register_broadcast_schedule
(
"bitwise_and"
)
register_broadcast_schedule
(
"bitwise_or"
)
...
...
@@ -205,6 +206,7 @@ register_shape_func("mod", False, broadcast_shape_func)
register_shape_func
(
"floor_mod"
,
False
,
broadcast_shape_func
)
register_shape_func
(
"logical_and"
,
False
,
broadcast_shape_func
)
register_shape_func
(
"logical_or"
,
False
,
broadcast_shape_func
)
register_shape_func
(
"logical_xor"
,
False
,
broadcast_shape_func
)
register_shape_func
(
"bitwise_not"
,
False
,
broadcast_shape_func
)
register_shape_func
(
"bitwise_and"
,
False
,
broadcast_shape_func
)
register_shape_func
(
"bitwise_or"
,
False
,
broadcast_shape_func
)
...
...
python/tvm/relay/op/tensor.py
View file @
6e36da35
...
...
@@ -537,6 +537,23 @@ def logical_or(lhs, rhs):
return
_make
.
logical_or
(
lhs
,
rhs
)
def
logical_xor
(
lhs
,
rhs
):
"""logical XOR with numpy-style broadcasting.
Parameters
----------
lhs : relay.Expr
The left hand side input data
rhs : relay.Expr
The right hand side input data
Returns
-------
result : relay.Expr
The computed result.
"""
return
_make
.
logical_xor
(
lhs
,
rhs
)
def
bitwise_and
(
lhs
,
rhs
):
"""bitwise AND with numpy-style broadcasting.
...
...
src/relay/op/tensor/binary.cc
View file @
6e36da35
...
...
@@ -123,6 +123,12 @@ RELAY_REGISTER_BINARY_OP("logical_or")
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
RELAY_BINARY_COMPUTE
(
topi
::
logical_or
));
RELAY_REGISTER_BINARY_OP
(
"logical_xor"
)
.
describe
(
"Elementwise logical XOR with broadcasting"
)
.
set_support_level
(
4
)
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
RELAY_BINARY_COMPUTE
(
topi
::
logical_xor
));
RELAY_REGISTER_BINARY_OP
(
"bitwise_and"
)
.
describe
(
"Elementwise bitwise AND with broadcasting"
)
.
set_support_level
(
4
)
...
...
tests/python/frontend/pytorch/test_forward.py
View file @
6e36da35
...
...
@@ -159,7 +159,7 @@ def verify_model(model_name, input_data=[],
if
isinstance
(
baseline_outputs
,
tuple
):
baseline_outputs
=
tuple
(
out
.
cpu
()
.
numpy
()
for
out
in
baseline_outputs
)
else
:
baseline_outputs
=
(
baseline_outputs
.
float
()
.
cpu
()
.
numpy
(),)
baseline_outputs
=
(
baseline_outputs
.
cpu
()
.
numpy
(),)
trace
=
torch
.
jit
.
trace
(
baseline_model
,
baseline_input
)
.
float
()
.
eval
()
...
...
@@ -1600,6 +1600,95 @@ def test_forward_topk():
verify_model
(
Topk6
()
.
float
()
.
eval
(),
input_data
=
input_data
)
def
test_forward_logical_not
():
torch
.
set_grad_enabled
(
False
)
class
LogicalNot1
(
Module
):
def
forward
(
self
,
*
args
):
return
torch
.
logical_not
(
args
[
0
])
input_data
=
torch
.
tensor
([
True
,
False
])
verify_model
(
LogicalNot1
()
.
float
()
.
eval
(),
input_data
=
input_data
)
input_data
=
torch
.
tensor
([
0
,
1
,
-
10
],
dtype
=
torch
.
int8
)
verify_model
(
LogicalNot1
()
.
float
()
.
eval
(),
input_data
=
input_data
)
input_data
=
torch
.
tensor
([
0.
,
1.5
,
-
10.
],
dtype
=
torch
.
double
)
verify_model
(
LogicalNot1
()
.
float
()
.
eval
(),
input_data
=
input_data
)
input_data
=
torch
.
tensor
([
0.
,
1.
,
-
10.
],
dtype
=
torch
.
int32
)
verify_model
(
LogicalNot1
()
.
float
()
.
eval
(),
input_data
=
input_data
)
def
test_forward_bitwise_not
():
torch
.
set_grad_enabled
(
False
)
class
BitwiseNot1
(
Module
):
def
forward
(
self
,
*
args
):
return
torch
.
bitwise_not
(
args
[
0
])
input_data
=
torch
.
tensor
([
0
,
1
,
-
10
],
dtype
=
torch
.
int8
)
verify_model
(
BitwiseNot1
()
.
float
()
.
eval
(),
input_data
=
input_data
)
input_data
=
torch
.
tensor
([
0.
,
1.
,
-
10.
],
dtype
=
torch
.
int32
)
verify_model
(
BitwiseNot1
()
.
float
()
.
eval
(),
input_data
=
input_data
)
input_data
=
torch
.
tensor
([
True
,
False
])
verify_model
(
BitwiseNot1
()
.
float
()
.
eval
(),
input_data
=
input_data
)
def
test_forward_bitwise_xor
():
torch
.
set_grad_enabled
(
False
)
class
BitwiseXor1
(
Module
):
def
forward
(
self
,
*
args
):
return
torch
.
bitwise_xor
(
args
[
0
],
args
[
1
])
class
BitwiseXor2
(
Module
):
def
forward
(
self
,
*
args
):
rhs
=
torch
.
tensor
([
1
,
0
,
3
],
dtype
=
torch
.
int8
)
if
torch
.
cuda
.
is_available
():
rhs
=
rhs
.
cuda
()
return
torch
.
bitwise_xor
(
args
[
0
],
rhs
)
lhs
=
torch
.
tensor
([
-
1
,
-
2
,
3
],
dtype
=
torch
.
int8
)
rhs
=
torch
.
tensor
([
1
,
0
,
3
],
dtype
=
torch
.
int8
)
verify_model
(
BitwiseXor1
()
.
float
()
.
eval
(),
input_data
=
[
lhs
,
rhs
])
lhs
=
torch
.
tensor
([
True
,
True
,
False
])
rhs
=
torch
.
tensor
([
False
,
True
,
False
])
verify_model
(
BitwiseXor1
()
.
float
()
.
eval
(),
input_data
=
[
lhs
,
rhs
])
lhs
=
torch
.
tensor
([
-
1
,
-
2
,
3
],
dtype
=
torch
.
int8
)
verify_model
(
BitwiseXor2
()
.
float
()
.
eval
(),
input_data
=
[
lhs
])
def
test_forward_logical_xor
():
torch
.
set_grad_enabled
(
False
)
class
LogicalXor1
(
Module
):
def
forward
(
self
,
*
args
):
return
torch
.
logical_xor
(
args
[
0
],
args
[
1
])
class
LogicalXor2
(
Module
):
def
forward
(
self
,
*
args
):
rhs
=
torch
.
tensor
([
1
,
0
,
3
],
dtype
=
torch
.
int8
)
if
torch
.
cuda
.
is_available
():
rhs
=
rhs
.
cuda
()
return
torch
.
logical_xor
(
args
[
0
],
rhs
)
lhs
=
torch
.
tensor
([
-
1
,
-
2
,
3
],
dtype
=
torch
.
int8
)
rhs
=
torch
.
tensor
([
1
,
0
,
3
],
dtype
=
torch
.
int8
)
verify_model
(
LogicalXor1
()
.
float
()
.
eval
(),
input_data
=
[
lhs
,
rhs
])
lhs
=
torch
.
tensor
([
True
,
True
,
False
])
rhs
=
torch
.
tensor
([
False
,
True
,
False
])
verify_model
(
LogicalXor1
()
.
float
()
.
eval
(),
input_data
=
[
lhs
,
rhs
])
lhs
=
torch
.
tensor
([
-
1
,
-
2
,
3
],
dtype
=
torch
.
int8
)
verify_model
(
LogicalXor2
()
.
float
()
.
eval
(),
input_data
=
[
lhs
])
if
__name__
==
"__main__"
:
# Single operator tests
test_forward_add
()
...
...
@@ -1663,6 +1752,10 @@ if __name__ == "__main__":
test_forward_clamp
()
test_forward_floor
()
test_forward_round
()
test_forward_logical_not
()
test_forward_bitwise_not
()
test_forward_bitwise_xor
()
test_forward_logical_xor
()
test_forward_isfinite
()
test_forward_isnan
()
test_forward_isinf
()
...
...
topi/include/topi/broadcast.h
View file @
6e36da35
...
...
@@ -141,6 +141,19 @@ TOPI_DEFINE_BCAST_OP(logical_or, { return a || b; });
TOPI_DEFINE_OP_OVERLOAD
(
operator
||
,
logical_or
);
/*!
* \fn logical_xor
* \brief Compute A ^ B with auto-broadcasting.
*
* \param A The first tensor, or Expr
* \param B The second tensor, or Expr
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return The result.
*/
TOPI_DEFINE_BCAST_OP
(
logical_xor
,
{
return
a
^
b
;
});
/*!
* \fn bitwise_and
* \brief Compute A & B with auto-broadcasting.
*
...
...
topi/python/topi/broadcast.py
View file @
6e36da35
...
...
@@ -420,6 +420,25 @@ def logical_or(lhs, rhs):
return
_cpp
.
logical_or
(
lhs
,
rhs
)
def
logical_xor
(
lhs
,
rhs
):
"""Compute element-wise logical xor of data.
Parameters
----------
lhs : tvm.te.Tensor or Expr
The left operand
rhs : tvm.te.Tensor or Expr
The right operand
Returns
-------
ret : tvm.te.Tensor or Expr
Returns Expr if both operands are Expr.
Otherwise returns Tensor.
"""
return
_cpp
.
logical_xor
(
lhs
,
rhs
)
def
bitwise_and
(
lhs
,
rhs
):
"""Compute element-wise bitwise and of data.
...
...
topi/src/broadcast.cc
View file @
6e36da35
...
...
@@ -65,6 +65,7 @@ TOPI_REGISTER_BCAST_OP("topi.power", topi::power);
TOPI_REGISTER_BCAST_OP
(
"topi.left_shift"
,
topi
::
left_shift
);
TOPI_REGISTER_BCAST_OP
(
"topi.logical_and"
,
topi
::
logical_and
);
TOPI_REGISTER_BCAST_OP
(
"topi.logical_or"
,
topi
::
logical_or
);
TOPI_REGISTER_BCAST_OP
(
"topi.logical_xor"
,
topi
::
logical_xor
);
TOPI_REGISTER_BCAST_OP
(
"topi.bitwise_and"
,
topi
::
bitwise_and
);
TOPI_REGISTER_BCAST_OP
(
"topi.bitwise_or"
,
topi
::
bitwise_or
);
TOPI_REGISTER_BCAST_OP
(
"topi.bitwise_xor"
,
topi
::
bitwise_xor
);
...
...
topi/tests/python/test_topi_broadcast.py
View file @
6e36da35
...
...
@@ -355,6 +355,8 @@ def test_logical_binary_ele():
test_apply
(
topi
.
logical_and
,
"logical_and"
,
np
.
logical_and
,
[
True
,
False
],
[
False
,
False
])
test_apply
(
topi
.
logical_or
,
"logical_or"
,
np
.
logical_or
,
True
,
False
)
test_apply
(
topi
.
logical_or
,
"logical_or"
,
np
.
logical_or
,
[
True
,
False
],
[
False
,
False
])
test_apply
(
topi
.
logical_xor
,
"logical_xor"
,
np
.
logical_xor
,
True
,
False
)
test_apply
(
topi
.
logical_xor
,
"logical_xor"
,
np
.
logical_xor
,
[
True
,
False
],
[
False
,
False
])
def
test_bitwise_and
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment