Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
6819145a
Commit
6819145a
authored
Oct 13, 2016
by
tqchen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
checkin domain
parent
bda95817
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
174 additions
and
17 deletions
+174
-17
python/tvm/__init__.py
+1
-0
python/tvm/domain.py
+38
-0
python/tvm/expr.py
+21
-1
python/tvm/expr_util.py
+22
-3
python/tvm/op.py
+23
-0
python/tvm/tensor.py
+48
-11
tests/python/test_tensor.py
+21
-2
No files found.
python/tvm/__init__.py
View file @
6819145a
...
@@ -5,3 +5,4 @@ from .op import *
...
@@ -5,3 +5,4 @@ from .op import *
from
.expr
import
Var
,
const
from
.expr
import
Var
,
const
from
.expr_util
import
*
from
.expr_util
import
*
from
.tensor
import
Tensor
from
.tensor
import
Tensor
from
.domain
import
RDom
,
Range
python/tvm/domain.py
0 → 100644
View file @
6819145a
from
__future__
import
absolute_import
as
_abs
from
.
import
expr
as
_expr
from
.
import
expr_util
as
_expr_util
class
Range
(
object
):
"""Represent a range in one dimension.
"""
def
__init__
(
self
,
begin
,
end
=
None
):
if
end
is
None
:
end
=
begin
begin
=
_expr
.
const
(
0
)
self
.
begin
=
_expr
.
_symbol
(
begin
)
self
.
end
=
_expr
.
_symbol
(
end
)
self
.
extent
=
_expr_util
.
simplify
(
end
-
begin
)
def
__str__
(
self
):
return
"(
%
s,
%
s)"
%
(
_expr_util
.
format_str
(
self
.
begin
),
_expr_util
.
format_str
(
self
.
end
))
def
__repr__
(
self
):
return
self
.
__str__
()
class
RDom
(
object
):
"""reduction Domain
"""
def
__init__
(
self
,
domain
):
if
isinstance
(
domain
,
Range
):
domain
=
[
domain
]
self
.
index
=
[]
self
.
domain
=
domain
for
i
in
range
(
len
(
domain
)):
self
.
index
.
append
(
_expr
.
Var
(
"rd_index_
%
d_"
%
i
))
"""Use list of ranges as domain"""
Domain
=
list
python/tvm/expr.py
View file @
6819145a
...
@@ -108,7 +108,27 @@ class UnaryOpExpr(Expr):
...
@@ -108,7 +108,27 @@ class UnaryOpExpr(Expr):
self
.
src
=
_symbol
(
src
)
self
.
src
=
_symbol
(
src
)
def
children
(
self
):
def
children
(
self
):
return
(
self
.
src
)
return
(
self
.
src
,)
class
ReduceExpr
(
Expr
):
def
__init__
(
self
,
op
,
src
,
rdom
):
self
.
op
=
op
self
.
src
=
src
self
.
rdom
=
rdom
def
children
(
self
):
return
(
self
.
src
,)
class
TensorReadExpr
(
Expr
):
"""Tensor read expression, tensor[indices]"""
def
__init__
(
self
,
tensor
,
indices
):
self
.
tensor
=
tensor
self
.
indices
=
indices
def
children
(
self
):
return
self
.
indices
def
const
(
value
):
def
const
(
value
):
...
...
python/tvm/expr_util.py
View file @
6819145a
...
@@ -2,7 +2,6 @@
...
@@ -2,7 +2,6 @@
from
__future__
import
absolute_import
as
_abs
from
__future__
import
absolute_import
as
_abs
from
.
import
expr
as
_expr
from
.
import
expr
as
_expr
from
.
import
op
as
_op
from
.
import
op
as
_op
from
.
import
tensor
as
_tensor
def
expr_with_new_children
(
e
,
children
):
def
expr_with_new_children
(
e
,
children
):
"""Returns same expr as e but with new children
"""Returns same expr as e but with new children
...
@@ -50,10 +49,27 @@ def transform(e, f):
...
@@ -50,10 +49,27 @@ def transform(e, f):
result : return value of f
result : return value of f
The final result of transformation.
The final result of transformation.
"""
"""
assert
isinstance
(
e
,
_expr
.
Expr
)
if
not
isinstance
(
e
,
_expr
.
Expr
):
raise
TypeError
(
"Cannot handle type
%
s"
%
type
(
e
))
return
f
(
e
,
[
transform
(
c
,
f
)
for
c
in
e
.
children
()])
return
f
(
e
,
[
transform
(
c
,
f
)
for
c
in
e
.
children
()])
def
visit
(
e
,
f
):
"""Apply f to each element of e
Parameters
----------
e : Expr
The input expression.
f : function with signiture (e)
"""
assert
isinstance
(
e
,
_expr
.
Expr
)
for
c
in
e
.
children
():
visit
(
c
,
f
)
f
(
e
)
def
format_str
(
expr
):
def
format_str
(
expr
):
"""change expression to string.
"""change expression to string.
...
@@ -76,12 +92,15 @@ def format_str(expr):
...
@@ -76,12 +92,15 @@ def format_str(expr):
return
str
(
e
.
value
)
return
str
(
e
.
value
)
elif
isinstance
(
e
,
_expr
.
Var
):
elif
isinstance
(
e
,
_expr
.
Var
):
return
e
.
name
return
e
.
name
elif
isinstance
(
e
,
_
tenso
r
.
TensorReadExpr
):
elif
isinstance
(
e
,
_
exp
r
.
TensorReadExpr
):
return
"
%
s(
%
s)"
%
(
e
.
tensor
.
name
,
','
.
join
(
result_children
))
return
"
%
s(
%
s)"
%
(
e
.
tensor
.
name
,
','
.
join
(
result_children
))
elif
isinstance
(
e
,
_expr
.
ReduceExpr
):
return
e
.
op
.
format_reduce_str
(
result_children
[
0
],
e
.
rdom
.
domain
)
else
:
else
:
raise
TypeError
(
"Do not know how to handle type "
+
str
(
type
(
e
)))
raise
TypeError
(
"Do not know how to handle type "
+
str
(
type
(
e
)))
return
transform
(
expr
,
make_str
)
return
transform
(
expr
,
make_str
)
def
simplify
(
expr
):
def
simplify
(
expr
):
"""simplify expression
"""simplify expression
...
...
python/tvm/op.py
View file @
6819145a
...
@@ -22,15 +22,20 @@ def canonical_to_expr(c):
...
@@ -22,15 +22,20 @@ def canonical_to_expr(c):
else
:
else
:
return
_expr
.
const
(
0
)
return
_expr
.
const
(
0
)
class
BinaryOp
(
object
):
class
BinaryOp
(
object
):
"""Base class of binary operator"""
"""Base class of binary operator"""
def
__call__
(
self
,
lhs
,
rhs
):
def
__call__
(
self
,
lhs
,
rhs
):
return
_expr
.
BinaryOpExpr
(
self
,
lhs
,
rhs
)
return
_expr
.
BinaryOpExpr
(
self
,
lhs
,
rhs
)
class
AddOp
(
BinaryOp
):
class
AddOp
(
BinaryOp
):
def
format_str
(
self
,
lhs
,
rhs
):
def
format_str
(
self
,
lhs
,
rhs
):
return
'(
%
s +
%
s)'
%
(
lhs
,
rhs
)
return
'(
%
s +
%
s)'
%
(
lhs
,
rhs
)
def
format_reduce_str
(
self
,
src
,
rd
):
return
"reduce_sum(
%
s, rdom=
%
s)"
%
(
src
,
str
(
rd
))
def
canonical
(
self
,
lhs
,
rhs
):
def
canonical
(
self
,
lhs
,
rhs
):
lhs
=
lhs
.
copy
()
lhs
=
lhs
.
copy
()
for
k
,
v
in
rhs
.
items
():
for
k
,
v
in
rhs
.
items
():
...
@@ -40,6 +45,7 @@ class AddOp(BinaryOp):
...
@@ -40,6 +45,7 @@ class AddOp(BinaryOp):
lhs
[
k
]
=
v
lhs
[
k
]
=
v
return
lhs
return
lhs
class
SubOp
(
BinaryOp
):
class
SubOp
(
BinaryOp
):
def
format_str
(
self
,
lhs
,
rhs
):
def
format_str
(
self
,
lhs
,
rhs
):
return
'(
%
s -
%
s)'
%
(
lhs
,
rhs
)
return
'(
%
s -
%
s)'
%
(
lhs
,
rhs
)
...
@@ -53,6 +59,7 @@ class SubOp(BinaryOp):
...
@@ -53,6 +59,7 @@ class SubOp(BinaryOp):
lhs
[
k
]
=
-
v
lhs
[
k
]
=
-
v
return
lhs
return
lhs
class
MulOp
(
BinaryOp
):
class
MulOp
(
BinaryOp
):
def
format_str
(
self
,
lhs
,
rhs
):
def
format_str
(
self
,
lhs
,
rhs
):
return
'(
%
s *
%
s)'
%
(
lhs
,
rhs
)
return
'(
%
s *
%
s)'
%
(
lhs
,
rhs
)
...
@@ -72,6 +79,7 @@ class MulOp(BinaryOp):
...
@@ -72,6 +79,7 @@ class MulOp(BinaryOp):
return
rhs
return
rhs
return
{
elhs
*
erhs
:
1
}
return
{
elhs
*
erhs
:
1
}
class
DivOp
(
BinaryOp
):
class
DivOp
(
BinaryOp
):
def
format_str
(
self
,
lhs
,
rhs
):
def
format_str
(
self
,
lhs
,
rhs
):
return
'(
%
s /
%
s)'
%
(
lhs
,
rhs
)
return
'(
%
s /
%
s)'
%
(
lhs
,
rhs
)
...
@@ -86,6 +94,7 @@ class DivOp(BinaryOp):
...
@@ -86,6 +94,7 @@ class DivOp(BinaryOp):
elhs
=
canonical_to_expr
(
lhs
)
elhs
=
canonical_to_expr
(
lhs
)
return
{
elhs
/
erhs
:
1
}
return
{
elhs
/
erhs
:
1
}
class
MaxOp
(
BinaryOp
):
class
MaxOp
(
BinaryOp
):
def
format_str
(
self
,
lhs
,
rhs
):
def
format_str
(
self
,
lhs
,
rhs
):
return
'max(
%
s,
%
s)'
%
(
lhs
,
rhs
)
return
'max(
%
s,
%
s)'
%
(
lhs
,
rhs
)
...
@@ -97,6 +106,7 @@ class MaxOp(BinaryOp):
...
@@ -97,6 +106,7 @@ class MaxOp(BinaryOp):
return
lhs
if
ediff
.
value
>=
0
else
rhs
return
lhs
if
ediff
.
value
>=
0
else
rhs
return
{
MaxOp
()(
lhs
,
rhs
):
1
}
return
{
MaxOp
()(
lhs
,
rhs
):
1
}
class
MinOp
(
BinaryOp
):
class
MinOp
(
BinaryOp
):
def
format_str
(
self
,
lhs
,
rhs
):
def
format_str
(
self
,
lhs
,
rhs
):
return
'min(
%
s,
%
s)'
%
(
lhs
,
rhs
)
return
'min(
%
s,
%
s)'
%
(
lhs
,
rhs
)
...
@@ -120,3 +130,16 @@ _expr.__addop__ = add
...
@@ -120,3 +130,16 @@ _expr.__addop__ = add
_expr
.
__subop__
=
sub
_expr
.
__subop__
=
sub
_expr
.
__mulop__
=
mul
_expr
.
__mulop__
=
mul
_expr
.
__divop__
=
div
_expr
.
__divop__
=
div
def
reduce_sum
(
expr
,
rdom
):
return
_expr
.
ReduceExpr
(
add
,
expr
,
rdom
)
def
reduce_prod
(
expr
,
rdom
):
return
_expr
.
ReduceExpr
(
mul
,
expr
,
rdom
)
def
reduce_min
(
expr
,
rdom
):
return
_expr
.
ReduceExpr
(
min
,
expr
,
rdom
)
def
reduce_max
(
expr
,
rdom
):
return
_expr
.
ReduceExpr
(
max
,
expr
,
rdom
)
python/tvm/tensor.py
View file @
6819145a
from
__future__
import
absolute_import
as
_abs
from
__future__
import
absolute_import
as
_abs
from
.
import
expr
as
_expr
from
.
import
expr
as
_expr
from
.
import
expr_util
as
_expr_util
class
TensorReadExpr
(
_expr
.
Expr
):
def
__init__
(
self
,
tensor
,
indices
):
self
.
tensor
=
tensor
self
.
indices
=
indices
def
children
(
self
):
return
self
.
indices
class
Tensor
(
object
):
class
Tensor
(
object
):
def
__init__
(
self
,
ndim
,
fcompute
=
None
,
name
=
None
):
def
__init__
(
self
,
ndim
,
fcompute
=
None
,
name
=
None
,
shape
=
None
):
self
.
ndim
=
ndim
self
.
ndim
=
ndim
if
fcompute
:
if
fcompute
:
arg_names
=
fcompute
.
func_code
.
co_varnames
arg_names
=
fcompute
.
func_code
.
co_varnames
assert
(
len
(
arg_names
)
==
ndim
)
assert
(
len
(
arg_names
)
==
ndim
)
self
.
dim_index
=
[
_expr
.
Var
(
n
)
for
n
in
arg_names
]
self
.
dim_index
=
[
_expr
.
Var
(
n
)
for
n
in
arg_names
]
self
.
expr
=
fcompute
(
*
self
.
dim_index
)
self
.
expr
=
fcompute
(
*
self
.
dim_index
)
if
shape
is
None
:
raise
ValueError
(
"argument shape need to be given for intermediate tensor"
)
self
.
shape
=
shape
else
:
else
:
self
.
expr
=
None
self
.
expr
=
None
self
.
dim_index
=
None
self
.
dim_index
=
None
shape_name
=
'_shape'
shape_name
=
'_shape'
if
name
:
shape_name
=
name
+
shape_name
if
name
:
shape_name
=
name
+
shape_name
self
.
shape
=
tuple
(
_expr
.
Var
(
"
%
s_
%
d_"
%
(
shape_name
,
i
))
for
i
in
range
(
ndim
))
self
.
shape
=
shape
if
shape
else
tuple
(
_expr
.
Var
(
"
%
s_
%
d_"
%
(
shape_name
,
i
))
for
i
in
range
(
ndim
))
self
.
name
=
name
if
name
else
"TensorObj"
self
.
name
=
name
if
name
else
"TensorObj"
self
.
inputs
=
None
def
__call__
(
self
,
*
indices
):
def
__call__
(
self
,
*
indices
):
if
len
(
indices
)
!=
self
.
ndim
:
if
len
(
indices
)
!=
self
.
ndim
:
raise
ValueError
(
"Need to provide
%
d index in tensor slice"
%
self
.
ndim
)
raise
ValueError
(
"Need to provide
%
d index in tensor slice"
%
self
.
ndim
)
return
TensorReadExpr
(
self
,
indices
)
return
_expr
.
TensorReadExpr
(
self
,
indices
)
def
input_tensors
(
self
):
"""List of input tensors to this tensor.
Returns
-------
inputs : list of input tensors
"""
if
self
.
inputs
is
not
None
:
return
self
.
inputs
self
.
inputs
=
[]
if
self
.
expr
:
def
collect
(
e
):
if
isinstance
(
e
,
_expr
.
TensorReadExpr
):
self
.
inputs
.
append
(
e
.
tensor
)
_expr_util
.
visit
(
self
.
expr
,
collect
)
return
self
.
inputs
def
infer_input_domains
(
self
,
out_domain
):
"""Infer the input domains of each domain given output domains
Parameters
----------
out_domain : list of Range
Domain of each dimension.
Returns
-------
in_domains: dict Tensor->Domain
"""
assert
self
.
expr
assert
len
(
out_domain
)
==
len
(
self
.
dim_index
)
index_domains
=
{
self
.
dim_index
[
i
]
:
out_domain
[
i
]
for
i
in
range
(
len
(
out_domain
))
}
def
collect
(
e
):
if
isinstance
(
e
,
_expr
.
TensorReadExpr
):
self
.
inputs
.
append
(
e
.
tensor
)
_expr_util
.
visit
(
self
.
expr
,
collect
)
tests/python/test_tensor.py
View file @
6819145a
...
@@ -3,8 +3,27 @@ import tvm
...
@@ -3,8 +3,27 @@ import tvm
def
test_tensor
():
def
test_tensor
():
A
=
tvm
.
Tensor
(
2
,
name
=
'A'
)
A
=
tvm
.
Tensor
(
2
,
name
=
'A'
)
B
=
tvm
.
Tensor
(
2
,
name
=
'B'
)
B
=
tvm
.
Tensor
(
2
,
name
=
'B'
)
T
=
tvm
.
Tensor
(
3
,
lambda
i
,
j
,
k
:
A
(
i
,
k
)
*
B
(
j
,
k
))
T
=
tvm
.
Tensor
(
3
,
lambda
i
,
j
,
k
:
A
(
i
,
k
)
*
B
(
j
,
k
),
shape
=
(
A
.
shape
[
0
],
B
.
shape
[
0
],
A
.
shape
[
1
]))
print
(
tvm
.
format_str
(
T
.
expr
))
print
(
tvm
.
format_str
(
T
.
expr
))
def
test_tensor_inputs
():
A
=
tvm
.
Tensor
(
2
,
name
=
'A'
)
B
=
tvm
.
Tensor
(
2
,
name
=
'B'
)
T
=
tvm
.
Tensor
(
3
,
lambda
i
,
j
,
k
:
A
(
i
,
k
)
*
B
(
j
,
k
),
shape
=
(
A
.
shape
[
0
],
B
.
shape
[
0
],
A
.
shape
[
1
]))
assert
(
T
.
input_tensors
()
==
[
A
,
B
])
def
test_tensor_reduce
():
A
=
tvm
.
Tensor
(
2
,
name
=
'A'
)
B
=
tvm
.
Tensor
(
2
,
name
=
'B'
)
T
=
tvm
.
Tensor
(
3
,
lambda
i
,
j
,
k
:
A
(
i
,
k
)
*
B
(
j
,
k
),
shape
=
(
A
.
shape
[
0
],
B
.
shape
[
0
],
A
.
shape
[
1
]))
rd
=
tvm
.
RDom
(
tvm
.
Range
(
A
.
shape
[
1
]))
C
=
tvm
.
Tensor
(
2
,
lambda
i
,
j
:
tvm
.
reduce_sum
(
T
(
i
,
j
,
rd
.
index
[
0
]),
rdom
=
rd
),
shape
=
(
A
.
shape
[
0
],
B
.
shape
[
0
]))
print
(
tvm
.
format_str
(
C
.
expr
))
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_tensor
()
test_tensor_inputs
()
test_tensor_reduce
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment