Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
69aefaa3
Commit
69aefaa3
authored
Jun 22, 2017
by
Haichen Shen
Committed by
Tianqi Chen
Jun 22, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[LANG] Add all and any in the python API (#196)
* [LANG] Add all and any in the python API * compatible with python3
parent
7cc92ace
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
102 additions
and
1 deletions
+102
-1
python/tvm/api.py
+47
-0
python/tvm/expr.py
+7
-0
tests/python/unittest/test_arith_intset.py
+1
-1
tests/python/unittest/test_lang_basic.py
+47
-0
No files found.
python/tvm/api.py
View file @
69aefaa3
...
...
@@ -114,6 +114,53 @@ def var(name="tindex", dtype=int32):
return
_api_internal
.
_Var
(
name
,
dtype
)
def
any
(
*
args
):
"""Create a new experssion of the union of all conditions in the arguments
Parameters
----------
args : list
List of symbolic boolean expressions
Returns
-------
expr: Expr
Expression
"""
if
not
args
:
raise
ValueError
(
"Any must take at least 1 argument"
)
if
len
(
args
)
==
1
:
return
args
[
0
]
ret
=
_make
.
Or
(
args
[
0
],
args
[
1
])
for
i
in
range
(
2
,
len
(
args
)):
ret
=
_make
.
Or
(
ret
,
args
[
i
])
return
ret
def
all
(
*
args
):
"""Create a new experssion of the intersection of all conditions in the
arguments
Parameters
----------
args : list
List of symbolic boolean expressions
Returns
-------
expr: Expr
Expression
"""
if
not
args
:
raise
ValueError
(
"Any must take at least 1 argument"
)
if
len
(
args
)
==
1
:
return
args
[
0
]
ret
=
_make
.
And
(
args
[
0
],
args
[
1
])
for
i
in
range
(
2
,
len
(
args
)):
ret
=
_make
.
And
(
ret
,
args
[
i
])
return
ret
def
placeholder
(
shape
,
dtype
=
None
,
name
=
"placeholder"
):
"""Construct an empty tensor object.
...
...
python/tvm/expr.py
View file @
69aefaa3
...
...
@@ -74,6 +74,13 @@ class ExprOp(object):
def
__ge__
(
self
,
other
):
return
_make
.
GE
(
self
,
other
)
def
__nonzero__
(
self
):
raise
ValueError
(
"Cannot use and / or / not operator to Expr, hint: "
+
"use tvm.all / tvm.any instead"
)
def
__bool__
(
self
):
return
self
.
__nonzero__
()
def
equal
(
self
,
other
):
"""Build an equal check expression with other expr.
...
...
tests/python/unittest/test_arith_intset.py
View file @
69aefaa3
...
...
@@ -50,7 +50,7 @@ def test_check():
assert
res1
.
is_nothing
()
# multiple compare operators
res2
=
tvm
.
arith
.
DeduceBound
(
a
,
a
+
b
>
3
>
c
,
{
b
:
b_s
,
c
:
c_s
},
{})
res2
=
tvm
.
arith
.
DeduceBound
(
a
,
(
a
+
b
>
3
)
>
c
,
{
b
:
b_s
,
c
:
c_s
},
{})
assert
res1
.
is_nothing
()
# multiple target variable
...
...
tests/python/unittest/test_lang_basic.py
View file @
69aefaa3
...
...
@@ -81,6 +81,51 @@ def test_dtype():
y
=
tvm
.
var
(
'y'
)
assert
(
x
>
y
)
.
dtype
==
'uint1'
def
test_any
():
x
=
tvm
.
var
(
'x'
)
y
=
tvm
.
var
(
'y'
)
z
=
tvm
.
var
(
'z'
)
try
:
t
=
x
or
x
assert
False
except
ValueError
:
pass
try
:
tvm
.
any
()
assert
False
except
ValueError
:
pass
assert
str
(
tvm
.
any
(
x
<
y
))
==
'(
%
s <
%
s)'
%
(
x
.
name
,
y
.
name
)
assert
str
(
tvm
.
any
(
x
<
y
,
x
>
z
))
==
'((
%
s <
%
s) || (
%
s >
%
s))'
%
(
x
.
name
,
y
.
name
,
x
.
name
,
z
.
name
)
assert
str
(
tvm
.
any
(
x
<
y
,
y
>
z
+
1
,
x
<
z
*
2
))
==
\
'(((
%
s <
%
s) || (
%
s > (
%
s + 1))) || (
%
s < (
%
s*2)))'
%
(
x
.
name
,
y
.
name
,
y
.
name
,
z
.
name
,
x
.
name
,
z
.
name
)
def
test_all
():
x
=
tvm
.
var
(
'x'
)
y
=
tvm
.
var
(
'y'
)
z
=
tvm
.
var
(
'z'
)
try
:
t
=
x
and
x
assert
False
except
ValueError
:
pass
try
:
tvm
.
all
()
assert
False
except
ValueError
:
pass
assert
str
(
tvm
.
all
(
x
<
y
))
==
'(
%
s <
%
s)'
%
(
x
.
name
,
y
.
name
)
assert
str
(
tvm
.
all
(
x
<
y
,
x
>
z
))
==
'((
%
s <
%
s) && (
%
s >
%
s))'
%
(
x
.
name
,
y
.
name
,
x
.
name
,
z
.
name
)
assert
str
(
tvm
.
all
(
x
<
y
,
y
>
z
+
1
,
x
<
z
*
2
))
==
\
'(((
%
s <
%
s) && (
%
s > (
%
s + 1))) && (
%
s < (
%
s*2)))'
%
(
x
.
name
,
y
.
name
,
y
.
name
,
z
.
name
,
x
.
name
,
z
.
name
)
if
__name__
==
"__main__"
:
test_attr
()
test_const
()
...
...
@@ -92,3 +137,5 @@ if __name__ == "__main__":
test_let
()
test_dir
()
test_dtype
()
test_any
()
test_all
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment