Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
f03483bf
Commit
f03483bf
authored
Oct 14, 2016
by
Haichen Shen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
checked split
parent
1a18f08e
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
51 additions
and
5 deletions
+51
-5
python/tvm/__init__.py
+1
-0
python/tvm/domain.py
+1
-1
python/tvm/op.py
+2
-2
python/tvm/split.py
+22
-0
python/tvm/tensor.py
+0
-1
tests/python/test_basic.py
+1
-1
tests/python/test_split.py
+24
-0
No files found.
python/tvm/__init__.py
View file @
f03483bf
...
@@ -6,3 +6,4 @@ from .expr import Var, const
...
@@ -6,3 +6,4 @@ from .expr import Var, const
from
.expr_util
import
*
from
.expr_util
import
*
from
.tensor
import
Tensor
from
.tensor
import
Tensor
from
.domain
import
RDom
,
Range
,
infer_range
from
.domain
import
RDom
,
Range
,
infer_range
from
.split
import
Split
python/tvm/domain.py
View file @
f03483bf
...
@@ -17,7 +17,7 @@ class Range(object):
...
@@ -17,7 +17,7 @@ class Range(object):
self
.
extent
=
_expr_util
.
simplify
(
end
-
begin
)
self
.
extent
=
_expr_util
.
simplify
(
end
-
begin
)
def
is_value
(
self
):
def
is_value
(
self
):
return
isinstance
(
self
.
extent
,
_expr
.
ConstExpr
)
and
self
.
exten
d
.
value
==
1
return
isinstance
(
self
.
extent
,
_expr
.
ConstExpr
)
and
self
.
exten
t
.
value
==
1
def
__str__
(
self
):
def
__str__
(
self
):
return
"(
%
s,
%
s)"
%
(
return
"(
%
s,
%
s)"
%
(
...
...
python/tvm/op.py
View file @
f03483bf
...
@@ -6,7 +6,7 @@ constant_canonical_key = '__constant__'
...
@@ -6,7 +6,7 @@ constant_canonical_key = '__constant__'
def
canonical_to_expr
(
c
):
def
canonical_to_expr
(
c
):
elements
=
[]
elements
=
[]
for
k
,
v
in
sorted
(
c
.
items
()):
for
k
,
v
in
sorted
(
c
.
items
()):
if
k
==
constant_canonical_key
:
if
k
==
constant_canonical_key
and
v
!=
0
:
elements
.
append
(
_expr
.
const
(
v
))
elements
.
append
(
_expr
.
const
(
v
))
elif
v
==
0
:
elif
v
==
0
:
continue
continue
...
@@ -87,7 +87,7 @@ class DivOp(BinaryOp):
...
@@ -87,7 +87,7 @@ class DivOp(BinaryOp):
if
isinstance
(
erhs
,
_expr
.
ConstExpr
):
if
isinstance
(
erhs
,
_expr
.
ConstExpr
):
lhs
=
lhs
.
copy
()
lhs
=
lhs
.
copy
()
for
k
,
v
in
lhs
.
items
():
for
k
,
v
in
lhs
.
items
():
lhs
[
k
]
/=
erhs
.
value
lhs
[
k
]
/=
float
(
erhs
.
value
)
return
lhs
return
lhs
elhs
=
canonical_to_expr
(
lhs
)
elhs
=
canonical_to_expr
(
lhs
)
return
{
elhs
/
erhs
:
1
}
return
{
elhs
/
erhs
:
1
}
...
...
python/tvm/split.py
0 → 100644
View file @
f03483bf
from
__future__
import
absolute_import
as
_abs
from
.
import
expr
as
_expr
from
.
import
domain
as
_dom
from
.
import
tensor
as
_tensor
class
Split
(
object
):
def
__init__
(
self
,
dim
,
factor
):
self
.
dim
=
dim
self
.
factor
=
factor
self
.
loop_index
=
_expr
.
Var
(
'loop_index_
%
d_'
%
dim
)
def
infer_inner_domain
(
self
,
domain
):
if
isinstance
(
domain
,
_dom
.
RDom
):
domain
=
domain
.
domain
assert
self
.
dim
<
len
(
domain
)
inner_domain
=
domain
[:]
dim_out_range
=
domain
[
self
.
dim
]
dim_inner_begin
=
dim_out_range
.
begin
+
self
.
loop_index
*
self
.
factor
inner_domain
[
self
.
dim
]
=
_dom
.
Range
(
dim_inner_begin
,
dim_inner_begin
+
self
.
factor
)
return
inner_domain
python/tvm/tensor.py
View file @
f03483bf
...
@@ -25,7 +25,6 @@ class Tensor(object):
...
@@ -25,7 +25,6 @@ class Tensor(object):
self
.
name
=
name
if
name
else
"TensorObj"
self
.
name
=
name
if
name
else
"TensorObj"
self
.
inputs
=
None
self
.
inputs
=
None
self
.
rdom
=
None
def
__call__
(
self
,
*
indices
):
def
__call__
(
self
,
*
indices
):
if
len
(
indices
)
!=
self
.
ndim
:
if
len
(
indices
)
!=
self
.
ndim
:
...
...
tests/python/test_basic.py
View file @
f03483bf
...
@@ -26,6 +26,6 @@ def test_simplify():
...
@@ -26,6 +26,6 @@ def test_simplify():
assert
tvm
.
format_str
(
tvm
.
simplify
(
e4
))
==
'0'
assert
tvm
.
format_str
(
tvm
.
simplify
(
e4
))
==
'0'
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_simplify
()
test_basic
()
test_basic
()
test_bind
()
test_bind
()
test_simplify
()
tests/python/test_split.py
0 → 100644
View file @
f03483bf
import
tvm
def
test_split_dom_infer
():
A
=
tvm
.
Tensor
(
2
,
name
=
'A'
)
rd
=
tvm
.
RDom
(
tvm
.
Range
(
A
.
shape
[
1
]))
split1
=
tvm
.
Split
(
0
,
64
)
split2
=
tvm
.
Split
(
1
,
64
)
split3
=
tvm
.
Split
(
0
,
8
)
dom
=
[
tvm
.
Range
(
A
.
shape
[
0
]),
tvm
.
Range
(
A
.
shape
[
1
])]
dom1
=
split1
.
infer_inner_domain
(
dom
)
dom2
=
split2
.
infer_inner_domain
(
dom1
)
dom3
=
split3
.
infer_inner_domain
(
dom2
)
dom4
=
split3
.
infer_inner_domain
(
rd
)
i1
=
split1
.
loop_index
.
name
i2
=
split2
.
loop_index
.
name
i3
=
split3
.
loop_index
.
name
assert
str
(
dom1
)
==
"[((
%
s * 64), ((
%
s * 64) + 64)), (0, A_shape_1_0)]"
%
(
i1
,
i1
)
assert
str
(
dom2
)
==
"[((
%
s * 64), ((
%
s * 64) + 64)), ((
%
s * 64), ((
%
s * 64) + 64))]"
%
(
i1
,
i1
,
i2
,
i2
)
assert
str
(
dom3
)
==
"[(((
%
s * 64) + (
%
s * 8)), (((
%
s * 64) + (
%
s * 8)) + 8)), ((
%
s * 64), ((
%
s * 64) + 64))]"
%
(
i1
,
i3
,
i1
,
i3
,
i2
,
i2
)
assert
str
(
dom4
)
==
"[((
%
s * 8), ((
%
s * 8) + 8))]"
%
(
i3
,
i3
)
if
__name__
==
"__main__"
:
test_split_dom_infer
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment