Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
0f693212
Commit
0f693212
authored
Jan 05, 2017
by
tqchen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Pass first basic case of bound inference
parent
c5395a1f
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
222 additions
and
77 deletions
+222
-77
HalideIR
+1
-1
src/c_api/c_api_schedule.cc
+10
-1
src/lang/operation.cc
+5
-0
src/schedule/bound.cc
+115
-31
src/schedule/graph.cc
+19
-16
src/schedule/graph.h
+2
-2
src/schedule/int_set.cc
+35
-16
src/schedule/int_set.h
+12
-3
tests/python/test_bound_inference.py
+23
-7
No files found.
HalideIR
@
5d1bd103
Subproject commit
adaea9e85bc0a213d4eb63edfa4762f2147c73ec
Subproject commit
5d1bd103c2abe19392b4d8def7e3ff1c854e8683
src/c_api/c_api_schedule.cc
View file @
0f693212
...
...
@@ -6,8 +6,9 @@
#include <tvm/expr.h>
#include <tvm/tensor.h>
#include <tvm/schedule.h>
#include "../schedule/bound.h"
#include "./c_api_registry.h"
#include "../schedule/bound.h"
#include "../schedule/graph.h"
namespace
tvm
{
namespace
schedule
{
...
...
@@ -20,8 +21,16 @@ using RetValue = APIVariantValue;
*ret = PassName(args.at(0)); \
}) \
#define REGISTER_SCHEDULE_PASS2(PassName) \
TVM_REGISTER_API(_schedule_## PassName) \
.set_body([](const ArgStack& args, RetValue *ret) { \
*ret = PassName(args.at(0), args.at(1)); \
}) \
REGISTER_SCHEDULE_PASS1
(
InferBound
);
REGISTER_SCHEDULE_PASS1
(
CreateReadGraph
);
REGISTER_SCHEDULE_PASS2
(
PostDFSOrder
);
}
// namespace schedule
}
// namespace tvm
src/lang/operation.cc
View file @
0f693212
...
...
@@ -9,6 +9,11 @@
namespace
tvm
{
TVM_STATIC_IR_FUNCTOR
(
IRPrinter
,
vtable
)
.
set_dispatch
<
ComputeOpNode
>
([](
const
ComputeOpNode
*
op
,
IRPrinter
*
p
)
{
p
->
stream
<<
"op("
<<
op
<<
")"
;
});
Tensor
Compute
(
Array
<
Expr
>
shape
,
FCompute
fcompute
,
std
::
string
name
)
{
auto
op_node
=
std
::
make_shared
<
ComputeOpNode
>
();
// compute dimension.
...
...
src/schedule/bound.cc
View file @
0f693212
...
...
@@ -7,6 +7,7 @@
#include <tvm/ir_visitor.h>
#include "./int_set.h"
#include "./bound.h"
#include "./graph.h"
namespace
tvm
{
namespace
schedule
{
...
...
@@ -62,7 +63,7 @@ void PassDown(const Schedule& s,
// pass the integer set on each leave loop up to the root
// dom_map is the result of PassDown, it records the domain of each IterVar.
// dom_map can be used to get cached result in reverse construction.
void
PassUp
(
const
Schedule
&
s
,
void
PassUp
(
const
Schedule
Node
*
s
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
std
::
unordered_map
<
IterVar
,
IntSet
>*
p_state
)
{
auto
&
state
=
*
p_state
;
...
...
@@ -89,62 +90,145 @@ void PassUp(const Schedule& s,
}
}
void
PassBound
(
/*!
* \brief Pass the bound of tensor read
* to the corresponding bound of the IterVar of operation
* \param tensor The tensor to be passed.
* \param dim_bounds The read index set on each dimension.
* \param The result IterVar bound .
*/
void
PassToOperation
(
const
Tensor
&
tensor
,
const
std
::
vector
<
IntSet
>&
arg
_bounds
,
const
std
::
vector
<
IntSet
>&
dim
_bounds
,
std
::
unordered_map
<
IterVar
,
std
::
vector
<
IntSet
>
>*
result
)
{
if
(
tensor
->
op
.
as
<
ComputeOpNode
>
())
{
auto
root_iter_vars
=
tensor
->
op
->
root_iter_vars
();
CHECK_EQ
(
tensor
.
ndim
(),
root_iter_vars
.
size
());
for
(
size_t
i
=
0
;
i
<
tensor
.
ndim
();
++
i
)
{
(
*
result
)[
root_iter_vars
[
i
]].
push_back
(
arg
_bounds
[
i
]);
(
*
result
)[
root_iter_vars
[
i
]].
push_back
(
dim
_bounds
[
i
]);
}
}
else
{
LOG
(
FATAL
)
<<
"unknown operation mode"
;
}
}
void
PassBound
(
Operation
op
,
std
::
unordered_map
<
IterVar
,
IntSet
>*
ebound
)
{
if
(
op
.
as
<
ComputeOpNode
>
())
{
auto
fvisit
=
[
ebound
](
const
NodeRef
&
n
)
{
auto
*
call
=
n
.
as
<
ir
::
Call
>
();
if
(
call
!=
nullptr
&&
call
->
func
.
defined
())
{
Tensor
t
(
call
->
func
.
node_
);
std
::
vector
<
IntSet
>
arg_bounds
;
for
(
size_t
i
=
0
;
i
<
t
.
ndim
();
++
i
)
{
arg_bounds
.
push_back
(
Eval
(
call
->
args
[
i
],
*
ebound
));
}
/*!
* \brief Recursively propagate bound
* \param post_order The propagation order.
* \param dom_map The domain map to be propagated
* \return The result bound
*/
std
::
unordered_map
<
IterVar
,
IntSet
>
BoundProp
(
const
Array
<
Operation
>&
post_order
,
std
::
unordered_map
<
IterVar
,
std
::
vector
<
IntSet
>
>
*
p_state
)
{
std
::
unordered_map
<
IterVar
,
IntSet
>
result
;
for
(
size_t
i
=
post_order
.
size
();
i
!=
0
;
--
i
)
{
Operation
op
=
post_order
[
i
-
1
];
if
(
op
.
as
<
ComputeOpNode
>
())
{
for
(
auto
iv
:
op
->
root_iter_vars
())
{
CHECK
(
p_state
->
count
(
iv
))
<<
"Bound of root operator must exists"
;
CHECK
(
!
result
.
count
(
iv
));
result
[
iv
]
=
Union
(
p_state
->
at
(
iv
));
}
};
ir
::
PostOrderVisit
(
op
.
as
<
ComputeOpNode
>
()
->
body
,
fvisit
);
}
else
{
LOG
(
FATAL
)
<<
"unknown operation mode"
;
auto
fvisit
=
[
p_state
,
&
result
](
const
NodeRef
&
n
)
{
auto
*
call
=
n
.
as
<
ir
::
Call
>
();
if
(
call
!=
nullptr
&&
call
->
func
.
defined
())
{
Tensor
t
(
call
->
func
.
node_
);
if
(
t
->
op
.
defined
())
{
std
::
vector
<
IntSet
>
arg_bounds
;
for
(
size_t
i
=
0
;
i
<
t
.
ndim
();
++
i
)
{
arg_bounds
.
push_back
(
EvalSet
(
call
->
args
[
i
],
result
));
}
PassToOperation
(
t
,
arg_bounds
,
p_state
);
}
}
};
ir
::
PostOrderVisit
(
op
.
as
<
ComputeOpNode
>
()
->
body
,
fvisit
);
}
else
{
LOG
(
FATAL
)
<<
"unknown operation mode"
;
}
}
return
result
;
}
void
InferBound
(
const
Schedule
&
sch
,
std
::
unordered_map
<
IterVar
,
Range
>*
rmap
)
{
CHECK_NE
(
sch
->
attach_type
,
kNone
);
// check if scope
bool
ScopeRelax
(
const
IterVar
&
iv
,
const
std
::
string
&
scope
)
{
if
(
iv
->
thread_tag
.
length
()
==
0
)
return
false
;
if
(
scope
.
length
()
==
0
)
return
false
;
static
std
::
unordered_map
<
std
::
string
,
int
>
scope_rank
{
{
"global"
,
0
},
{
"shared"
,
1
},
{
"local"
,
2
}
};
return
scope_rank
.
at
(
scope
)
<=
scope_rank
.
at
(
iv
->
thread_tag
);
}
void
InferBound
(
const
ScheduleNode
*
parent
,
const
Schedule
&
sch
,
std
::
unordered_map
<
IterVar
,
Range
>*
rmap
)
{
if
(
sch
->
attach_type
==
kInline
)
return
;
if
(
sch
->
attach_type
==
kRoot
)
{
if
(
sch
->
attach_type
==
kRoot
||
sch
->
attach_type
==
kNone
)
{
auto
root_iter_vars
=
sch
->
op
->
root_iter_vars
();
for
(
size_t
i
=
0
;
i
<
root_iter_vars
.
size
();
++
i
)
{
auto
v
=
root_iter_vars
[
i
];
CHECK
(
v
->
dom
.
defined
());
CHECK
(
!
rmap
->
count
(
v
));
(
*
rmap
)[
v
]
=
v
->
dom
;
for
(
auto
iv
:
root_iter_vars
)
{
CHECK
(
iv
->
dom
.
defined
());
CHECK
(
!
rmap
->
count
(
iv
));
(
*
rmap
)[
iv
]
=
iv
->
dom
;
}
}
// get range of all child iter vars.
PassDown
(
sch
,
rmap
);
// pass iteration variable to children
if
(
sch
->
attach_type
==
kScope
)
{
CHECK
(
parent
!=
nullptr
);
auto
g
=
CreateReadGraph
(
parent
->
op
);
auto
post_order
=
PostDFSOrder
(
parent
->
op
,
g
);
std
::
unordered_map
<
IterVar
,
IntSet
>
up_state
;
bool
fix_value
=
true
;
for
(
auto
iv
:
parent
->
leaf_iter_vars
)
{
if
(
fix_value
&&
!
ScopeRelax
(
iv
,
sch
->
scope
))
{
up_state
[
iv
]
=
IntSet
::
make_point
(
iv
->
var
);
}
else
{
up_state
[
iv
]
=
IntSet
::
make_range
(
rmap
->
at
(
iv
));
}
if
(
sch
->
attach_parent
==
iv
)
{
fix_value
=
false
;
}
}
// get the bound of the root IterVars given the current condition
PassUp
(
parent
,
*
rmap
,
&
up_state
);
std
::
unordered_map
<
IterVar
,
std
::
vector
<
IntSet
>
>
bp_state
;
for
(
auto
iv
:
parent
->
op
->
root_iter_vars
())
{
CHECK
(
up_state
.
count
(
iv
));
bp_state
[
iv
]
=
{
up_state
.
at
(
iv
)};
}
auto
result
=
BoundProp
(
post_order
,
&
bp_state
);
for
(
auto
iv
:
sch
->
op
->
root_iter_vars
())
{
CHECK
(
result
.
count
(
iv
));
CHECK
(
!
rmap
->
count
(
iv
));
(
*
rmap
)[
iv
]
=
result
.
at
(
iv
).
GetCoverRange
();
}
}
// also call infer bound on children
for
(
Schedule
child
:
sch
->
children
)
{
InferBound
(
sch
.
operator
->
(),
child
,
rmap
);
}
}
Map
<
IterVar
,
Range
>
InferBound
(
Schedule
sch
)
{
return
{};
std
::
unordered_map
<
IterVar
,
Range
>
ret
;
CHECK
(
sch
->
attach_type
!=
kInline
&&
sch
->
attach_type
!=
kScope
)
<<
"the Schedule is not a root Schedule"
;
InferBound
(
nullptr
,
sch
,
&
ret
);
return
Map
<
IterVar
,
Range
>
(
ret
.
begin
(),
ret
.
end
());
}
}
// namespace schedule
...
...
src/schedule/graph.cc
View file @
0f693212
...
...
@@ -14,26 +14,29 @@ namespace schedule {
// construct a read graph that gives readers of each operation
// that the root depend on
ReadGraph
CreateReadGraph
(
Operation
root
)
{
std
::
unordered_map
<
Operation
,
std
::
vector
<
Tensor
>
>
rmap
;
rmap
[
root
]
=
{};
ReadGraph
CreateReadGraph
(
const
Operation
&
root
)
{
ReadGraph
rmap
;
std
::
vector
<
Operation
>
stack
{
root
};
std
::
unordered_set
<
const
Node
*>
visited
{
root
.
get
()};
while
(
!
stack
.
empty
())
{
Operation
r
=
stack
.
back
();
Operation
op
=
stack
.
back
();
stack
.
pop_back
();
auto
&
vec
=
rmap
.
at
(
r
)
;
if
(
r
.
as
<
ComputeOpNode
>
())
{
auto
fvisit
=
[
&
vec
,
&
rmap
,
&
stack
](
const
NodeRef
&
n
)
{
Array
<
Tensor
>
deps
;
if
(
op
.
as
<
ComputeOpNode
>
())
{
auto
fvisit
=
[
&
deps
,
&
visited
,
&
stack
](
const
NodeRef
&
n
)
{
auto
*
call
=
n
.
as
<
ir
::
Call
>
();
if
(
call
!=
nullptr
&&
call
->
func
.
defined
())
{
Tensor
t
(
call
->
func
.
node_
);
vec
.
push_back
(
t
);
if
(
t
->
op
.
defined
()
&&
rmap
.
count
(
t
->
op
)
==
0
)
{
rmap
[
t
->
op
]
=
{};
stack
.
push_back
(
t
->
op
);
deps
.
push_back
(
t
);
if
(
t
->
op
.
defined
()
&&
visited
.
count
(
t
->
op
.
get
())
==
0
)
{
visited
.
insert
(
t
->
op
.
get
());
stack
.
push_back
(
t
->
op
);
}
}
};
ir
::
PostOrderVisit
(
r
.
as
<
ComputeOpNode
>
()
->
body
,
fvisit
);
ir
::
PostOrderVisit
(
op
.
as
<
ComputeOpNode
>
()
->
body
,
fvisit
);
rmap
.
Set
(
op
,
deps
);
}
else
{
LOG
(
FATAL
)
<<
"unknown operation mode"
;
}
...
...
@@ -43,9 +46,9 @@ ReadGraph CreateReadGraph(Operation root) {
void
PostDFSOrder
(
const
Operation
&
op
,
const
ReadGraph
&
g
,
std
::
unordered_set
<
Operation
>*
visited
,
std
::
vector
<
Operation
>*
post_order
)
{
const
ReadGraph
&
g
,
std
::
unordered_set
<
Operation
>*
visited
,
Array
<
Operation
>*
post_order
)
{
visited
->
insert
(
op
);
for
(
const
auto
&
t
:
g
.
at
(
op
))
{
if
(
t
->
op
.
defined
()
&&
!
visited
->
count
(
t
->
op
))
{
...
...
@@ -55,10 +58,10 @@ void PostDFSOrder(const Operation& op,
post_order
->
push_back
(
op
);
}
std
::
vector
<
Operation
>
PostDFSOrder
(
Array
<
Operation
>
PostDFSOrder
(
const
Operation
&
root
,
const
ReadGraph
&
g
)
{
std
::
unordered_set
<
Operation
>
visited
;
std
::
vector
<
Operation
>
post_order
;
Array
<
Operation
>
post_order
;
PostDFSOrder
(
root
,
g
,
&
visited
,
&
post_order
);
return
post_order
;
}
...
...
src/schedule/graph.h
View file @
0f693212
...
...
@@ -17,7 +17,7 @@ namespace schedule {
/*!
* \brief data structure of Operation->Tensors it reads
*/
using
ReadGraph
=
std
::
unordered_map
<
Operation
,
std
::
vector
<
Tensor
>
>
;
using
ReadGraph
=
Map
<
Operation
,
Array
<
Tensor
>
>
;
/*!
* \brief Get read graph of each operation to all the
...
...
@@ -38,7 +38,7 @@ ReadGraph CreateReadGraph(const Operation& root);
* \note PostDFSOrder is a special case of Topoligical order,
* and can be used when topoligical order is needed.
*/
std
::
vector
<
Operation
>
PostDFSOrder
(
Array
<
Operation
>
PostDFSOrder
(
const
Operation
&
root
,
const
ReadGraph
&
g
);
}
// namespace schedule
...
...
src/schedule/int_set.cc
View file @
0f693212
...
...
@@ -176,17 +176,37 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p
->
stream
<<
')'
;
});
IntSet
IntSet
::
make
(
Range
dom
)
{
IntSet
IntSet
::
make
_range
(
Range
dom
)
{
auto
n
=
std
::
make_shared
<
IntSetNode
>
();
n
->
base
=
dom
;
return
IntSet
(
n
);
}
Range
IntSet
::
GetCoverRange
()
const
{
const
IntSetNode
*
s
=
operator
->
();
CHECK
(
s
!=
nullptr
)
<<
"empty set"
;
if
(
s
->
domain
.
size
()
==
0
&&
s
->
concrete
.
size
()
==
0
)
{
return
s
->
base
;
}
LOG
(
FATAL
)
<<
"not yet implemented"
;
return
Range
();
}
IntSet
IntSet
::
make_point
(
Expr
point
)
{
return
IntSet
::
make_range
(
Range
::
make_with_min_extent
(
point
,
1
));
}
IntSet
IntSet
::
make_all_set
()
{
LOG
(
FATAL
)
<<
"TODO"
;
return
IntSet
();
}
IntSet
Union
(
const
Array
<
IntSet
>&
set
)
{
if
(
set
.
size
()
==
1
)
return
set
[
0
];
LOG
(
FATAL
)
<<
"TODO"
;
return
IntSet
();
}
void
PassUp
(
const
SplitNode
*
s
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
IntSet
&
outer
,
...
...
@@ -197,7 +217,7 @@ void PassUp(const SplitNode* s,
dom_map
.
count
(
s
->
parent
)
&&
Match
(
outer
,
dom_map
.
at
(
s
->
outer
))
&&
Match
(
inner
,
dom_map
.
at
(
s
->
inner
)))
{
*
parent
=
IntSet
::
make
(
dom_map
.
at
(
s
->
parent
));
*
parent
=
IntSet
::
make
_range
(
dom_map
.
at
(
s
->
parent
));
return
;
}
// copy construct
...
...
@@ -230,21 +250,21 @@ void PassUp(const FuseNode* s,
CHECK
(
dom_map
.
count
(
s
->
fused
));
if
(
Match
(
fused
,
dom_map
.
at
(
s
->
fused
)))
{
*
outer
=
IntSet
::
make
(
dom_map
.
at
(
s
->
outer
));
*
inner
=
IntSet
::
make
(
dom_map
.
at
(
s
->
inner
));
*
outer
=
IntSet
::
make
_range
(
dom_map
.
at
(
s
->
outer
));
*
inner
=
IntSet
::
make
_range
(
dom_map
.
at
(
s
->
inner
));
return
;
}
if
(
IsNumber
(
fused
))
{
Expr
value
=
AsNumber
(
fused
);
Expr
factor
=
dom_map
.
at
(
s
->
outer
)
->
extent
;
*
outer
=
IntSet
::
make
(
Range
::
make_with_min_extent
(
value
/
factor
,
1
)
);
*
inner
=
IntSet
::
make
(
Range
::
make_with_min_extent
(
value
%
factor
,
1
)
);
*
outer
=
IntSet
::
make
_point
(
value
/
factor
);
*
inner
=
IntSet
::
make
_point
(
value
%
factor
);
}
else
{
LOG
(
WARNING
)
<<
"use fallback inference rule in fuse"
;
// simply use the entire set, this rule can be enhanced.
*
outer
=
IntSet
::
make
(
dom_map
.
at
(
s
->
outer
));
*
inner
=
IntSet
::
make
(
dom_map
.
at
(
s
->
inner
));
*
outer
=
IntSet
::
make
_range
(
dom_map
.
at
(
s
->
outer
));
*
inner
=
IntSet
::
make
_range
(
dom_map
.
at
(
s
->
inner
));
return
;
}
}
...
...
@@ -272,7 +292,7 @@ class IRSetEvaluator {
};
inline
IntSet
ConstOp
(
const
NodeRef
&
,
const
Expr
&
e
,
IRSetEvaluator
*
)
{
return
IntSet
::
make
(
Range
::
make_with_min_extent
(
e
,
1
)
);
return
IntSet
::
make
_point
(
e
);
}
TVM_STATIC_IR_FUNCTOR
(
IRSetEvaluator
,
vtable
)
...
...
@@ -286,7 +306,7 @@ TVM_STATIC_IR_FUNCTOR(IRSetEvaluator, vtable)
if
(
it
!=
m
->
dom_map
.
end
())
{
return
it
->
second
;
}
else
{
return
IntSet
::
make
(
Range
::
make_with_min_extent
(
e
,
1
)
);
return
IntSet
::
make
_point
(
e
);
}
});
...
...
@@ -298,10 +318,9 @@ inline IntSet Binary(const T* op, const Expr& e, IRSetEvaluator* m) {
if
(
IsNumber
(
a
)
&&
IsNumber
(
b
))
{
if
(
Match
(
a
,
op
->
a
)
&&
Match
(
b
,
op
->
b
))
{
return
IntSet
::
make
(
Range
::
make_with_min_extent
(
e
,
1
)
);
return
IntSet
::
make
_point
(
e
);
}
else
{
return
IntSet
::
make
(
Range
::
make_with_min_extent
(
T
::
make
(
AsNumber
(
a
),
AsNumber
(
b
)),
1
));
return
IntSet
::
make_point
(
T
::
make
(
AsNumber
(
a
),
AsNumber
(
b
)));
}
}
else
{
return
BinaryCombine
<
T
>
(
a
,
b
);
...
...
@@ -319,7 +338,7 @@ TVM_STATIC_IR_FUNCTOR(IRSetEvaluator, vtable)
// use simply bound for logical expressions for now.
inline
IntSet
Logical
(
const
NodeRef
&
,
const
Expr
&
e
,
IRSetEvaluator
*
)
{
return
IntSet
::
make
(
Range
::
make_with_min_extent
(
0
,
2
));
return
IntSet
::
make
_range
(
Range
::
make_with_min_extent
(
0
,
2
));
}
TVM_STATIC_IR_FUNCTOR
(
IRSetEvaluator
,
vtable
)
...
...
@@ -334,8 +353,8 @@ TVM_STATIC_IR_FUNCTOR(IRSetEvaluator, vtable)
}
// namespace
IntSet
Eval
(
Expr
e
,
const
std
::
unordered_m
ap
<
IterVar
,
IntSet
>&
dom_map
)
{
IntSet
Eval
Set
(
Expr
e
,
const
M
ap
<
IterVar
,
IntSet
>&
dom_map
)
{
IRSetEvaluator
m
;
for
(
auto
kv
:
dom_map
)
{
m
.
dom_map
[
kv
.
first
->
var
.
as
<
Variable
>
()]
=
kv
.
second
;
...
...
src/schedule/int_set.h
View file @
0f693212
...
...
@@ -29,6 +29,10 @@ class IntSet : public NodeRef {
return
!
defined
();
}
/*!
* \return a range that covers the IntSet
*/
Range
GetCoverRange
()
const
;
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
...
...
@@ -37,7 +41,12 @@ class IntSet : public NodeRef {
* \param dom The domain to be created.
* \return create integer set from existing domain
*/
static
IntSet
make
(
Range
dom
);
static
IntSet
make_range
(
Range
dom
);
/*!
* \param point
* \return create integer set that only contains one point
*/
static
IntSet
make_point
(
Expr
point
);
/*!
* \return create integer set that represents everything
*/
...
...
@@ -52,8 +61,8 @@ class IntSet : public NodeRef {
* \param dom_map The domain of each variable.
* \return An integer set that can cover all the possible values of e.
*/
IntSet
Eval
(
Expr
e
,
const
std
::
unordered_m
ap
<
IterVar
,
IntSet
>&
dom_map
);
IntSet
Eval
Set
(
Expr
e
,
const
M
ap
<
IterVar
,
IntSet
>&
dom_map
);
/*!
* \brief Conditional upward message passing.
*
...
...
tests/python/test_bound_inference.py
View file @
0f693212
...
...
@@ -4,16 +4,32 @@ def test_bound_inference():
m
=
tvm
.
Var
(
'm'
)
l
=
tvm
.
Var
(
'l'
)
A
=
tvm
.
placeholder
((
m
,
l
),
name
=
'A'
)
A1
=
tvm
.
compute
((
m
,
l
),
lambda
i
,
j
:
A
[
i
,
j
])
A2
=
tvm
.
compute
((
m
,
l
),
lambda
i
,
j
:
A1
[
i
,
j
]
+
3
)
A1
=
tvm
.
compute
((
m
,
l
),
lambda
i
,
j
:
A
[
i
,
j
]
,
name
=
'A1'
)
A2
=
tvm
.
compute
((
m
,
l
),
lambda
i
,
j
:
A1
[
i
,
j
]
+
3
,
name
=
'A2'
)
sA1
=
tvm
.
Schedule
(
A1
.
op
)
sA2
=
tvm
.
Schedule
(
A2
.
op
)
xo
,
xi
=
sA1
.
split
(
A1
.
op
.
dim_var
[
0
],
factor
=
8
)
sA2
.
compute_at
(
sA1
,
xi
)
bounds
=
tvm
.
schedule
.
InferBound
(
sA1
)
xo
,
xi
=
sA2
.
split
(
A2
.
op
.
dim_var
[
0
],
8
)
sA1
.
compute_at
(
sA2
,
xo
)
bounds
=
tvm
.
schedule
.
InferBound
(
sA2
)
assert
isinstance
(
bounds
,
tvm
.
collections
.
Map
)
print
(
bounds
)
print
(
bounds
[
A1
.
op
.
dim_var
[
0
]])
print
(
bounds
[
A1
.
op
.
dim_var
[
1
]])
def
test_create_read_graph
():
m
=
tvm
.
Var
(
'm'
)
l
=
tvm
.
Var
(
'l'
)
A
=
tvm
.
placeholder
((
m
,
l
),
name
=
'A'
)
A1
=
tvm
.
compute
((
m
,
l
),
lambda
i
,
j
:
A
[
i
,
j
])
A2
=
tvm
.
compute
((
m
,
l
),
lambda
i
,
j
:
A1
[
i
,
j
]
+
3
)
g
=
tvm
.
schedule
.
CreateReadGraph
(
A2
.
op
)
assert
g
[
A2
.
op
][
0
]
==
A1
assert
g
[
A1
.
op
][
0
]
==
A
post_order
=
tvm
.
schedule
.
PostDFSOrder
(
A2
.
op
,
g
)
assert
(
post_order
[
0
]
==
A1
.
op
)
assert
(
post_order
[
1
]
==
A2
.
op
)
if
__name__
==
"__main__"
:
test_bound_inference
()
test_create_read_graph
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment