Commit 57a74936, authored Jan 06, 2017 by tqchen

Rename dim_var to axis, update testcases

Parent: ff26cd68

Showing 11 changed files with 69 additions and 45 deletions:

  include/tvm/c_api.h                   +4   -2
  include/tvm/operation.h               +4   -4
  python/tvm/_ctypes/_api.py            +10  -2
  src/c_api/c_api.cc                    +10  -2
  src/lang/expr.cc                      +3   -3
  src/lang/operation.cc                 +10  -10
  tests/python/test_basic.py            +9   -2
  tests/python/test_bound_inference.py  +9   -9
  tests/python/test_inline.py           +3   -4
  tests/python/test_schedule.py         +6   -6
  tests/python/test_tensor.py           +1   -1

include/tvm/c_api.h
@@ -129,12 +129,14 @@ TVM_DLL int TVMNodeFree(NodeHandle handle);
  * \param handle The node handle
  * \param key The attribute name
  * \param out_value The attribute value
- * \param out_typeid The typeif of the attribute.
+ * \param out_typeid The typeid of the attribute.
+ * \param out_success Whether get is successful.
  */
 TVM_DLL int TVMNodeGetAttr(NodeHandle handle,
                            const char* key,
                            ArgVariant* out_value,
-                           int* out_typeid);
+                           int* out_typeid,
+                           int* out_success);
 /*!
  * \brief get attributes names in the node.

include/tvm/operation.h
@@ -17,8 +17,8 @@ namespace tvm {
  */
 class ComputeOpNode : public OperationNode {
  public:
-  /*! \brief Iteration variables over the dimensions */
-  Array<IterVar> dim_var;
+  /*! \brief IterVar on each axis */
+  Array<IterVar> axis;
   /*! \brief the compute expression */
   Expr body;
   /*! \brief constructor */
@@ -34,11 +34,11 @@ class ComputeOpNode : public OperationNode {
   void VisitAttrs(AttrVisitor* v) final {
     v->Visit("name", &name);
-    v->Visit("dim_var", &dim_var);
+    v->Visit("axis", &axis);
     v->Visit("body", &body);
   }
   static Operation make(std::string name,
-                        Array<IterVar> dim_var,
+                        Array<IterVar> axis,
                         Expr body);
   static constexpr const char* _type_key = "ComputeOp";
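
The rename is user-visible: a compute op's loop iteration variables are now reached through op.axis rather than op.dim_var. A minimal sketch of the effect, assuming the early TVM Python API exercised by the updated tests later in this commit:

import tvm

m = tvm.Var('m')
A = tvm.placeholder((m,), name='A')
T = tvm.compute((m,), lambda i: A[i] + 1, name='T')
sch = tvm.Schedule(T.op, scope="shared")
# Loop itervars now live on op.axis (formerly op.dim_var):
xo, xi = sch.split(T.op.axis[0], factor=10)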

python/tvm/_ctypes/_api.py
@@ -72,10 +72,18 @@ class NodeBase(object):
     def __getattr__(self, name):
         ret_val = ArgVariant()
         ret_typeid = ctypes.c_int()
+        ret_success = ctypes.c_int()
         check_call(_LIB.TVMNodeGetAttr(
             self.handle, c_str(name),
-            ctypes.byref(ret_val), ctypes.byref(ret_typeid)))
-        return RET_SWITCH[ret_typeid.value](ret_val)
+            ctypes.byref(ret_val),
+            ctypes.byref(ret_typeid),
+            ctypes.byref(ret_success)))
+        value = RET_SWITCH[ret_typeid.value](ret_val)
+        if not ret_success.value:
+            raise AttributeError(
+                "'%s' object has no attribute '%s'" % (
+                    str(type(self)), name))
+        return value

     def __hash__(self):
         return _function_internal._raw_ptr(self)
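
The practical effect of the new ret_success flag, mirroring the test added in tests/python/test_basic.py below: looking up a nonexistent attribute now raises AttributeError instead of decoding a null return value.

import tvm

a = tvm.convert(1)
assert a.value == 1
try:
    a.no_field          # no such attribute on this node
    assert False
except AttributeError:
    pass                # the C API reported ret_success == 0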

src/c_api/c_api.cc
@@ -37,6 +37,7 @@ using TVMAPINode = std::shared_ptr<Node>;
 struct APIAttrGetter : public AttrVisitor {
   std::string skey;
   APIVariantValue* ret;
+  bool found_node_ref{false};

   void Visit(const char* key, double* value) final {
     if (skey == key) *ret = value[0];
@@ -62,7 +63,10 @@ struct APIAttrGetter : public AttrVisitor {
     if (skey == key) *ret = value[0];
   }
   void Visit(const char* key, NodeRef* value) final {
-    if (skey == key) *ret = value[0];
+    if (skey == key) {
+      *ret = value[0];
+      found_node_ref = true;
+    }
   }
 };
@@ -198,7 +202,8 @@ int TVMNodeFree(NodeHandle handle) {
 int TVMNodeGetAttr(NodeHandle handle,
                    const char* key,
                    ArgVariant* ret_val,
-                   int* ret_typeid) {
+                   int* ret_typeid,
+                   int* ret_success) {
   TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
   API_BEGIN();
   ret->ret_value.type_id = kNull;
@@ -209,11 +214,14 @@ int TVMNodeGetAttr(NodeHandle handle,
   if (getter.skey == "type_key") {
     ret_val->v_str = (*tnode)->type_key();
     *ret_typeid = kStr;
+    *ret_success = 1;
   } else {
     (*tnode)->VisitAttrs(&getter);
     if (ret->ret_value.type_id != kNull) {
       ret->SetReturn(ret_val, ret_typeid);
+      *ret_success = 1;
     } else {
+      *ret_success = getter.found_node_ref ? 1 : 0;
       *ret_typeid = kNull;
     }
   }
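
A note on found_node_ref: an attribute holding a currently-null NodeRef leaves the visitor's returned type_id at kNull, which was previously indistinguishable from "no such attribute". The flag records that the key matched a NodeRef field, so the lookup still counts as successful. A rough Python sketch of the resulting logic (pseudocode, not part of the commit):

def attr_lookup_success(found_value, found_node_ref):
    # Success if the visitor produced a value, or if it matched a
    # NodeRef field that happens to be null right now.
    return 1 if (found_value or found_node_ref) else 0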

src/lang/expr.cc
@@ -13,10 +13,10 @@ DMLC_REGISTRY_ENABLE(::tvm::NodeFactoryReg);
 }  // namespace dmlc

 namespace tvm {

 Range::Range(Expr begin, Expr end)
-    : Range(std::make_shared<Halide::IR::RangeNode>(begin, end - begin)) {
+    : Range(std::make_shared<Halide::IR::RangeNode>(
+          // TODO(tqchen) add simplify to end - begin
+          begin, is_zero(begin) ? end : (end - begin))) {
 }

 Range Range::make_with_min_extent(Expr min, Expr extent) {
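
The is_zero special case avoids materializing an unsimplified end - 0 subtraction for ranges that start at zero (the common case for shapes created by tvm.compute); per the TODO, general simplification of end - begin is still pending. The equivalent extent computation, sketched in Python (is_zero stands in for the C++ helper of the same name):

def range_extent(begin, end):
    # extent = end - begin, but when begin is 0 reuse end directly so
    # the IR does not carry an unsimplified (end - 0) expression.
    return end if is_zero(begin) else end - begin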

src/lang/operation.cc
@@ -18,27 +18,27 @@ Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name) {
   auto op_node = std::make_shared<ComputeOpNode>();
   // compute dimension.
   size_t ndim = shape.size();
-  std::vector<IterVar> dim_var;
+  std::vector<IterVar> axis;
   std::vector<Var> args;
   for (size_t i = 0; i < ndim; ++i) {
     std::ostringstream os;
-    os << "dim_var" << i;
-    dim_var.push_back(IterVar(Range(0, shape[i]), os.str()));
-    args.push_back(dim_var.back()->var);
+    os << "ax" << i;
+    axis.emplace_back(IterVar(Range(0, shape[i]), os.str()));
+    args.push_back(axis.back()->var);
   }
-  op_node->dim_var = Array<IterVar>(dim_var);
+  op_node->axis = Array<IterVar>(axis);
   op_node->body = fcompute(args);
   op_node->name = name;
   return Operation(op_node).output(0);
 }

 Operation ComputeOpNode::make(std::string name,
-                              Array<IterVar> dim_var,
+                              Array<IterVar> axis,
                               Expr body) {
   auto n = std::make_shared<ComputeOpNode>();
   n->name = name;
-  n->dim_var = dim_var;
+  n->axis = axis;
   n->body = body;
   return Operation(n);
 }
@@ -54,7 +54,7 @@ Tensor Operation::output(size_t i) const {
 }

 Array<IterVar> ComputeOpNode::root_iter_vars() const {
-  return dim_var;
+  return axis;
 }

 std::string ComputeOpNode::output_name(size_t i) const {
@@ -70,8 +70,8 @@ Type ComputeOpNode::output_dtype(size_t i) const {
 Array<Expr> ComputeOpNode::output_shape(size_t i) const {
   CHECK_EQ(i, 0U);
   std::vector<Expr> shape;
-  for (size_t i = 0; i < dim_var.size(); ++i) {
-    const Range& r = dim_var[i]->dom;
+  for (size_t i = 0; i < axis.size(); ++i) {
+    const Range& r = axis[i]->dom;
     shape.push_back(r->extent);
   }
   return Array<Expr>(shape);
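
A usage sketch under this commit's conventions: tvm.compute now stores one IterVar per output dimension on op.axis, named ax0, ax1, ... (previously dim_var0, dim_var1, ...). The comment on the printout is illustrative, not verified output:

import tvm

m = tvm.Var('m')
n = tvm.Var('n')
A = tvm.placeholder((m, n), name='A')
T = tvm.compute((m, n), lambda i, j: A[i, j] + 1, name='T')
print(T.op.axis)   # two itervars, named ax0 and ax1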

tests/python/test_basic.py
@@ -30,7 +30,15 @@ def test_attr():
     stmt = tvm.make.AttrStmt(
         y, "stride", 10, tvm.make.Evaluate(x + 1));
     assert stmt.node == y
+    print(stmt)
+    a = tvm.convert(1)
+    assert a.value == 1
+    try:
+        a.no_field
+        assert False
+    except AttributeError:
+        pass

 def test_basic():
     a = tvm.Var('a')
@@ -48,7 +56,6 @@ def test_stmt():
 if __name__ == "__main__":
     test_attr()
     test_const()
     test_make()
     test_ir()

tests/python/test_bound_inference.py
@@ -8,11 +8,11 @@ def test_bound1():
     A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
     sA1 = tvm.Schedule(A1.op)
     sA2 = tvm.Schedule(A2.op)
-    xo, xi = sA2.split(A2.op.dim_var[0], 8)
+    xo, xi = sA2.split(A2.op.axis[0], 8)
     sA1.compute_at(sA2, xo)
     bounds = tvm.schedule.InferBound(sA2)
     assert isinstance(bounds, tvm.collections.Map)
-    assert(bounds[A1.op.dim_var[0]].extent.value == 8)
+    assert(bounds[A1.op.axis[0]].extent.value == 8)

 def test_bound2():
     m = tvm.Var('m')
@@ -22,12 +22,12 @@ def test_bound2():
     A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
     sA1 = tvm.Schedule(A1.op)
     sA2 = tvm.Schedule(A2.op)
-    xo, yo, xi, yi = sA2.tile(A2.op.dim_var[0], A2.op.dim_var[1], 8, 8)
+    xo, yo, xi, yi = sA2.tile(A2.op.axis[0], A2.op.axis[1], 8, 8)
     sA1.compute_at(sA2, yo)
     bounds = tvm.schedule.InferBound(sA2)
     assert isinstance(bounds, tvm.collections.Map)
-    assert(bounds[A1.op.dim_var[0]].extent.value == 8)
-    assert(bounds[A1.op.dim_var[1]].extent.value == 8)
+    assert(bounds[A1.op.axis[0]].extent.value == 8)
+    assert(bounds[A1.op.axis[1]].extent.value == 8)

 def test_bound3():
     m = tvm.Var('m')
@@ -38,16 +38,16 @@ def test_bound3():
     sA1 = tvm.Schedule(A1.op, scope="shared")
     sA2 = tvm.Schedule(A2.op)
     thread_x = tvm.IterVar((0, 16), thread_tag="threadIdx.x")
-    xo, xi = sA2.split(A2.op.dim_var[0], 32)
+    xo, xi = sA2.split(A2.op.axis[0], 32)
     xi0, xi1 = sA2.split(xi, outer=thread_x)
-    yo, yi = sA2.split(A2.op.dim_var[1], 16)
+    yo, yi = sA2.split(A2.op.axis[1], 16)
     sA2.reorder(xo, xi0, yo, xi1, yi)
     sA1.compute_at(sA2, yo)
     bounds = tvm.schedule.InferBound(sA2)
     assert isinstance(bounds, tvm.collections.Map)
-    assert(bounds[A1.op.dim_var[0]].extent.value == 32)
-    assert(bounds[A1.op.dim_var[1]].extent.value == 16)
+    assert(bounds[A1.op.axis[0]].extent.value == 32)
+    assert(bounds[A1.op.axis[1]].extent.value == 16)

 def test_create_read_graph():

tests/python/test_inline.py
@@ -3,11 +3,10 @@ import tvm
 def test_inline():
     m = tvm.Var('m')
     A = tvm.placeholder((m,), name='A')
-    T = tvm.compute((m,), lambda i,: A(i) + 10, name='T')
-    X = T(100)
-    stmt = tvm.make.Evaluate(T[10] + 11 * T[100])
+    T = tvm.compute((m,), lambda i,: A[i] + 10, name='T')
+    stmt = tvm.make.Evaluate(T(10) + 11 * T(100))
     stmt = tvm.ir_pass.Inline(
-        T, T.op.dim_var, T.op.body, stmt)
+        T, [x.var for x in T.op.axis], T.op.body, stmt)
     print(stmt)
     assert(tvm.ir_pass.VerifySSA(stmt))

tests/python/test_schedule.py
@@ -12,14 +12,14 @@ def test_schedule_create():
     sch_T = tvm.Schedule(T.op, scope="shared")
     sch_A = tvm.Schedule(AA.op, scope="global")
-    xo, xi = sch_T.split(T.op.dim_var[0], factor=10)
+    xo, xi = sch_T.split(T.op.axis[0], factor=10)
     xi1, xi2 = sch_T.split(xi, factor=2)
     sch_A.compute_at(sch_T, xi1)
-    xo, xi = sch_A.split(AA.op.dim_var[0], factor=10)
+    xo, xi = sch_A.split(AA.op.axis[0], factor=10)
     sch_T.reorder(xi2, xi1)
-    assert T.op.dim_var[1] in sch_T.leaf_iter_vars
+    assert T.op.axis[1] in sch_T.leaf_iter_vars

 def test_reorder():
     m = tvm.Var('m')
@@ -27,7 +27,7 @@ def test_reorder():
     T = tvm.compute(m, lambda i: A[i + 1])
     sch_T = tvm.Schedule(T.op, scope="shared")
-    xo, xi = sch_T.split(T.op.dim_var[0], factor=10)
+    xo, xi = sch_T.split(T.op.axis[0], factor=10)
     xi1, xi2 = sch_T.split(xi, factor=2)
     order = (xi2, xi1, xo)
     assert tuple(sch_T.leaf_iter_vars) != order
@@ -40,7 +40,7 @@ def test_split():
     T = tvm.compute((m,), lambda i: A[i])
     sT = tvm.Schedule(T.op)
-    xo, xi = sT.split(T.op.dim_var[0], factor=10)
+    xo, xi = sT.split(T.op.axis[0], factor=10)
     assert tuple(sT.leaf_iter_vars) == (xo, xi)
@@ -51,7 +51,7 @@ def test_tile():
     T = tvm.compute((m, n), lambda i, j: A[i, j])
     sch_T = tvm.Schedule(T.op, scope="shared")
-    xo, yo, xi, yi = sch_T.tile(T.op.dim_var[0], T.op.dim_var[1], x_factor=10, y_factor=5)
+    xo, yo, xi, yi = sch_T.tile(T.op.axis[0], T.op.axis[1], x_factor=10, y_factor=5)
     assert tuple(sch_T.leaf_iter_vars) == (xo, yo, xi, yi)

 if __name__ == "__main__":

tests/python/test_tensor.py
@@ -10,7 +10,7 @@ def test_tensor():
     print(T)
     print(T.op.body)
     assert(tuple(T.shape) == (m, n, l))
-    assert(A.source is None)
+    assert(A.op is None)

 def test_tensor_reduce():
     m = tvm.Var('m')
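
The last assertion tracks a companion rename: the tensor attribute checked here is now .op rather than .source. Assuming the semantics implied by the test, a placeholder has no producing operation while a computed tensor carries one:

import tvm

m = tvm.Var('m')
A = tvm.placeholder((m,), name='A')
T = tvm.compute((m,), lambda i: A[i], name='T')
assert A.op is None        # placeholder: not produced by any operation
assert T.op is not None    # computed tensor: carries its ComputeOp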