Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
dcddd208
Commit
dcddd208
authored
Oct 13, 2016
by
tqchen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
finish tensor dom infer
parent
6819145a
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
131 additions
and
16 deletions
+131
-16
python/tvm/__init__.py
+1
-1
python/tvm/domain.py
+74
-5
python/tvm/op.py
+0
-2
python/tvm/tensor.py
+29
-8
tests/python/test_domain.py
+27
-0
No files found.
python/tvm/__init__.py
View file @
dcddd208
...
...
@@ -5,4 +5,4 @@ from .op import *
from
.expr
import
Var
,
const
from
.expr_util
import
*
from
.tensor
import
Tensor
from
.domain
import
RDom
,
Range
from
.domain
import
RDom
,
Range
,
infer_range
python/tvm/domain.py
View file @
dcddd208
from
__future__
import
absolute_import
as
_abs
from
.
import
expr
as
_expr
from
.
import
expr_util
as
_expr_util
from
.
import
op
as
_op
class
Range
(
object
):
"""Represent a range in one dimension.
...
...
@@ -10,10 +10,15 @@ class Range(object):
if
end
is
None
:
end
=
begin
begin
=
_expr
.
const
(
0
)
self
.
begin
=
_expr
.
_symbol
(
begin
)
self
.
end
=
_expr
.
_symbol
(
end
)
begin
=
_expr_util
.
simplify
(
_expr
.
_symbol
(
begin
))
end
=
_expr_util
.
simplify
(
_expr
.
_symbol
(
end
))
self
.
begin
=
begin
self
.
end
=
end
self
.
extent
=
_expr_util
.
simplify
(
end
-
begin
)
def
is_value
(
self
):
return
isinstance
(
self
.
extent
,
_expr
.
ConstExpr
)
and
self
.
extend
.
value
==
1
def
__str__
(
self
):
return
"(
%
s,
%
s)"
%
(
_expr_util
.
format_str
(
self
.
begin
),
...
...
@@ -22,9 +27,13 @@ class Range(object):
def
__repr__
(
self
):
return
self
.
__str__
()
class
RangeInferError
(
ValueError
):
pass
class
RDom
(
object
):
"""reduction Domain
"""
"""Reduction Domain."""
def
__init__
(
self
,
domain
):
if
isinstance
(
domain
,
Range
):
domain
=
[
domain
]
...
...
@@ -36,3 +45,63 @@ class RDom(object):
"""Use list of ranges as domain"""
Domain
=
list
def
_combine_range_binary_op
(
op
,
lhs
,
rhs
):
if
op
==
_op
.
add
:
return
Range
(
lhs
.
begin
+
rhs
.
begin
,
lhs
.
end
+
rhs
.
end
-
1
)
elif
op
==
_op
.
sub
:
return
Range
(
lhs
.
begin
-
rhs
.
end
+
1
,
lhs
.
end
-
rhs
.
begin
)
elif
op
==
_op
.
mul
:
v
=
None
if
lhs
.
is_value
():
v
=
lhs
.
begin
.
value
e
=
rhs
elif
rhs
.
is_value
():
v
=
rhs
.
begin
.
value
e
=
lhs
if
v
==
-
1
:
return
Range
(
-
e
.
end
,
-
e
.
begin
)
raise
InferRangeError
(
"donot know how to infer range for
%
s"
%
type
(
op
))
def
infer_range
(
e
,
range_dict
,
allow_unbind_var
=
True
):
"""Infer the range of result e given range of variables.
Parameters
----------
expr : Expr
Input expression
range_dict : dict of Var->Range
The variables to be replaced.
allow_unbind_var: bool
Whether allow unbinded variables
"""
def
combine_range
(
e
,
result_children
):
if
isinstance
(
e
,
_expr
.
ConstExpr
):
return
Range
(
e
,
e
+
1
)
elif
isinstance
(
e
,
_expr
.
BinaryOpExpr
):
return
_combine_range_binary_op
(
e
.
op
,
result_children
[
0
],
result_children
[
1
])
elif
isinstance
(
e
,
_expr
.
Var
):
if
e
in
range_dict
:
return
range_dict
[
e
]
else
:
if
allow_unbind_var
:
return
Range
(
e
,
e
+
1
)
else
:
raise
ValueError
(
"Cannot find var
%
s in range_dict"
%
e
.
name
)
else
:
raise
InferRangeError
(
"cannot infer range for
%
s"
%
_expr_util
.
format_str
(
e
))
return
_expr_util
.
transform
(
e
,
combine_range
)
def
union_range
(
lhs
,
rhs
):
if
lhs
is
None
:
return
rhs
if
rhs
is
None
:
return
lhs
begin
=
_op
.
min
(
lhs
.
begin
,
rhs
.
begin
)
end
=
_op
.
max
(
rhs
.
end
,
lhs
.
end
)
return
Range
(
begin
,
end
)
python/tvm/op.py
View file @
dcddd208
...
...
@@ -22,7 +22,6 @@ def canonical_to_expr(c):
else
:
return
_expr
.
const
(
0
)
class
BinaryOp
(
object
):
"""Base class of binary operator"""
def
__call__
(
self
,
lhs
,
rhs
):
...
...
@@ -45,7 +44,6 @@ class AddOp(BinaryOp):
lhs
[
k
]
=
v
return
lhs
class
SubOp
(
BinaryOp
):
def
format_str
(
self
,
lhs
,
rhs
):
return
'(
%
s -
%
s)'
%
(
lhs
,
rhs
)
...
...
python/tvm/tensor.py
View file @
dcddd208
from
__future__
import
absolute_import
as
_abs
from
.
import
expr
as
_expr
from
.
import
expr_util
as
_expr_util
from
.
import
domain
as
_dom
class
Tensor
(
object
):
...
...
@@ -39,16 +40,17 @@ class Tensor(object):
"""
if
self
.
inputs
is
not
None
:
return
self
.
inputs
self
.
inputs
=
[]
inputs
=
[]
if
self
.
expr
:
def
collect
(
e
):
if
isinstance
(
e
,
_expr
.
TensorReadExpr
):
self
.
inputs
.
append
(
e
.
tensor
)
inputs
.
append
(
e
.
tensor
)
_expr_util
.
visit
(
self
.
expr
,
collect
)
self
.
inputs
=
set
(
inputs
)
return
self
.
inputs
def
infer_input_domains
(
self
,
out_domain
):
"""Infer the input domains of each domain
given output domains
def
infer_input_domains
(
self
,
out_domain
,
inputs
):
"""Infer the input domains of each domain
in given inputs list.
Parameters
----------
...
...
@@ -64,7 +66,26 @@ class Tensor(object):
index_domains
=
{
self
.
dim_index
[
i
]
:
out_domain
[
i
]
for
i
in
range
(
len
(
out_domain
))
}
def
collect
(
e
):
if
isinstance
(
e
,
_expr
.
TensorReadExpr
):
self
.
inputs
.
append
(
e
.
tensor
)
_expr_util
.
visit
(
self
.
expr
,
collect
)
iset
=
{}
for
t
in
inputs
:
assert
t
in
self
.
input_tensors
()
iset
[
t
]
=
[]
def
prepare
(
e
):
if
isinstance
(
e
,
_expr
.
ReduceExpr
):
rd
=
e
.
rdom
for
i
in
range
(
len
(
rd
.
domain
)):
index_domains
[
rd
.
index
[
i
]]
=
rd
.
domain
[
i
]
elif
isinstance
(
e
,
_expr
.
TensorReadExpr
):
if
e
.
tensor
in
iset
:
iset
[
e
.
tensor
]
.
append
(
e
)
_expr_util
.
visit
(
self
.
expr
,
prepare
)
result
=
{}
for
k
,
v
in
iset
.
items
():
dm
=
[
None
]
*
len
(
v
[
0
]
.
indices
)
for
e
in
v
:
for
i
,
idx
in
enumerate
(
e
.
indices
):
dm
[
i
]
=
_dom
.
union_range
(
dm
[
i
],
_dom
.
infer_range
(
idx
,
index_domains
,
allow_unbind_var
=
False
))
result
[
k
]
=
dm
return
result
tests/python/test_domain.py
0 → 100644
View file @
dcddd208
import
tvm
def
test_range_infer
():
x
=
tvm
.
Var
(
'x'
)
y
=
tvm
.
Var
(
'y'
)
t
=
tvm
.
Var
(
't'
)
z
=
x
+
y
+
t
zr
=
tvm
.
infer_range
(
z
,
{
x
:
tvm
.
Range
(
10
,
20
),
y
:
tvm
.
Range
(
10
,
11
)})
assert
str
(
zr
)
==
"((t0 + 20), (t0 + 30))"
def
test_tensor_dom_infer
():
A
=
tvm
.
Tensor
(
2
,
name
=
'A'
)
B
=
tvm
.
Tensor
(
2
,
name
=
'B'
)
T
=
tvm
.
Tensor
(
3
,
lambda
i
,
j
,
k
:
A
(
i
,
k
)
*
B
(
j
,
k
),
shape
=
(
A
.
shape
[
0
],
B
.
shape
[
0
],
A
.
shape
[
1
]))
rd
=
tvm
.
RDom
(
tvm
.
Range
(
A
.
shape
[
1
]))
C
=
tvm
.
Tensor
(
2
,
lambda
i
,
j
:
tvm
.
reduce_sum
(
T
(
i
,
j
,
rd
.
index
[
0
]),
rdom
=
rd
),
shape
=
(
A
.
shape
[
0
],
B
.
shape
[
0
]))
cdom
=
[
tvm
.
Range
(
0
,
10
),
tvm
.
Range
(
1
,
11
)]
tdom
=
C
.
infer_input_domains
(
cdom
,
inputs
=
[
T
])[
T
]
assert
str
(
tdom
[
0
])
==
"(0, 10)"
if
__name__
==
"__main__"
:
test_range_infer
()
test_tensor_dom_infer
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment