Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
820a8597
Commit
820a8597
authored
Feb 13, 2017
by
Tianqi Chen
Committed by
GitHub
Feb 13, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[LANG] Introduce Scan, Bugfix Canonical (#43)
parent
f8f02829
Show whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
753 additions
and
94 deletions
+753
-94
include/tvm/ir.h
+31
-4
include/tvm/operation.h
+64
-0
python/tvm/api.py
+52
-3
python/tvm/tensor.py
+7
-2
src/api/api_lang.cc
+9
-0
src/arithmetic/canonical.cc
+9
-2
src/codegen/codegen_cuda.cc
+1
-1
src/codegen/codegen_cuda.h
+1
-1
src/lang/operation.cc
+87
-0
src/pass/inject_virtual_thread.cc
+0
-4
src/pass/storage_flatten.cc
+0
-34
src/schedule/bound.cc
+107
-10
src/schedule/graph.cc
+15
-7
src/schedule/schedule_lang.cc
+2
-0
src/schedule/schedule_ops.cc
+270
-20
tests/python/integration/test_scan.py
+54
-0
tests/python/unittest/test_lang_tensor.py
+14
-0
tests/python/unittest/test_pass_simplify.py
+8
-2
tests/python/unittest/test_schedule_schedule_ops.py
+22
-4
No files found.
include/tvm/ir.h
View file @
820a8597
...
@@ -49,12 +49,27 @@ struct Reduce : public ExprNode<Reduce> {
...
@@ -49,12 +49,27 @@ struct Reduce : public ExprNode<Reduce> {
static
constexpr
const
char
*
Min
=
"Min"
;
static
constexpr
const
char
*
Min
=
"Min"
;
};
};
/*! \brief namespace of possible attribute sin AttrStmt.type_key */
namespace
attr
{
/*!
/*!
* \brief
Mark scope of iteration variable, used by Schedule
.
* \brief
Auxiliary data structure used in IR Pass to indicate a tensor
.
*/
*/
constexpr
const
char
*
scope
=
"scope"
;
struct
TensorKey
{
FunctionRef
f
;
int
value_index
;
inline
bool
operator
==
(
const
TensorKey
&
other
)
const
{
return
f
==
other
.
f
&&
value_index
==
other
.
value_index
;
}
inline
std
::
string
GetName
()
const
{
if
(
f
->
num_outputs
()
==
1
)
return
f
->
func_name
();
std
::
ostringstream
os
;
os
<<
f
->
func_name
()
<<
".v"
<<
value_index
;
return
os
.
str
();
}
};
/*! \brief namespace of possible attribute sin AttrStmt.type_key */
namespace
attr
{
// The above attr does not pass to ir stage.
/*!
/*!
* \brief Mark launching extent of thread, used by device API.
* \brief Mark launching extent of thread, used by device API.
*/
*/
...
@@ -189,4 +204,16 @@ using Halide::Internal::Evaluate;
...
@@ -189,4 +204,16 @@ using Halide::Internal::Evaluate;
}
// namespace ir
}
// namespace ir
}
// namespace tvm
}
// namespace tvm
namespace
std
{
template
<>
struct
hash
<::
tvm
::
ir
::
TensorKey
>
{
std
::
size_t
operator
()(
const
::
tvm
::
ir
::
TensorKey
&
k
)
const
{
size_t
lhs
=
k
.
f
.
hash
();
size_t
rhs
=
static_cast
<
size_t
>
(
k
.
value_index
);
lhs
^=
rhs
+
0x9e3779b9
+
(
lhs
<<
6
)
+
(
lhs
>>
2
);
return
lhs
;
}
};
}
// namespace std
#endif // TVM_IR_H_
#endif // TVM_IR_H_
include/tvm/operation.h
View file @
820a8597
...
@@ -77,6 +77,55 @@ class ComputeOpNode : public OperationNode {
...
@@ -77,6 +77,55 @@ class ComputeOpNode : public OperationNode {
TVM_DECLARE_NODE_TYPE_INFO
(
ComputeOpNode
);
TVM_DECLARE_NODE_TYPE_INFO
(
ComputeOpNode
);
};
};
/*!
* \brief Symbolic scan.
*/
class
ScanOpNode
:
public
OperationNode
{
public
:
/*! \brief IterVar to scan over */
IterVar
scan_axis
;
/*! \brief the initialization tensors */
Array
<
Tensor
>
init
;
/*! \brief the update function represented by tensor */
Array
<
Tensor
>
update
;
/*! \brief The placeholder to refer as states in update. */
Array
<
Tensor
>
state_placeholder
;
/*!
* \brief Spatial axis to indicate spatial dimension of each output.
* They corresponds to flattened spatial axis of the outputs.
*
* [output[0].axis[1], output[0].axis[2]... output[k].axis[j]...]
* These are auxiliary data structure for storing result of bound inference.
* They do not corresponds to splittable iterations, thus the name comes
* with underscore.
*/
Array
<
IterVar
>
spatial_axis_
;
/*! \brief constructor */
ScanOpNode
()
{}
// override behavior.
int
num_outputs
()
const
final
;
Array
<
IterVar
>
root_iter_vars
()
const
final
;
Type
output_dtype
(
size_t
i
)
const
final
;
Array
<
Expr
>
output_shape
(
size_t
i
)
const
final
;
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"name"
,
&
name
);
v
->
Visit
(
"scan_axis"
,
&
scan_axis
);
v
->
Visit
(
"init"
,
&
init
);
v
->
Visit
(
"update"
,
&
update
);
v
->
Visit
(
"state_placeholder"
,
&
state_placeholder
);
v
->
Visit
(
"spatial_axis_"
,
&
spatial_axis_
);
}
static
Operation
make
(
std
::
string
name
,
IterVar
axis
,
Array
<
Tensor
>
init
,
Array
<
Tensor
>
update
,
Array
<
Tensor
>
state_placeholder
);
static
constexpr
const
char
*
_type_key
=
"ScanOp"
;
TVM_DECLARE_NODE_TYPE_INFO
(
ScanOpNode
);
};
/*! \brief The compute function to specify the input source of a Tensor */
/*! \brief The compute function to specify the input source of a Tensor */
using
FCompute
=
std
::
function
<
Expr
(
const
Array
<
Var
>&
i
)
>
;
using
FCompute
=
std
::
function
<
Expr
(
const
Array
<
Var
>&
i
)
>
;
...
@@ -100,6 +149,21 @@ Tensor Placeholder(Array<Expr> shape,
...
@@ -100,6 +149,21 @@ Tensor Placeholder(Array<Expr> shape,
*/
*/
Tensor
Compute
(
Array
<
Expr
>
shape
,
FCompute
fcompute
,
std
::
string
name
=
"tensor"
);
Tensor
Compute
(
Array
<
Expr
>
shape
,
FCompute
fcompute
,
std
::
string
name
=
"tensor"
);
/*!
* \brief Construct new tensors by scan over scan_axis.
*
* \param scan_axis The iteration representing the scan.
* \param init The intialize tensor of first K steps.
* \param update The update tensor indicated the updated result after each timestamp.
* \param state_placeholder The placeholder for the states.
* \param name The optional name of the tensor.
*/
Array
<
Tensor
>
Scan
(
IterVar
scan_axis
,
Array
<
Tensor
>
init
,
Array
<
Tensor
>
update
,
Array
<
Tensor
>
state_placeholder
,
std
::
string
name
=
"scan"
);
// same as compute, specialized for different fcompute function
// same as compute, specialized for different fcompute function
inline
Tensor
Compute
(
Array
<
Expr
>
shape
,
inline
Tensor
Compute
(
Array
<
Expr
>
shape
,
std
::
function
<
Expr
(
Var
)
>
f
,
std
::
function
<
Expr
(
Var
)
>
f
,
...
...
python/tvm/api.py
View file @
820a8597
...
@@ -14,6 +14,7 @@ from ._ctypes._function import convert_to_tvm_func as _convert_tvm_func
...
@@ -14,6 +14,7 @@ from ._ctypes._function import convert_to_tvm_func as _convert_tvm_func
from
.
import
_api_internal
from
.
import
_api_internal
from
.
import
make
as
_make
from
.
import
make
as
_make
from
.
import
expr
as
_expr
from
.
import
expr
as
_expr
from
.
import
tensor
as
_tensor
from
.
import
collections
as
_collections
from
.
import
collections
as
_collections
int32
=
"int32"
int32
=
"int32"
...
@@ -111,7 +112,6 @@ def compute(shape, fcompute, name="compute"):
...
@@ -111,7 +112,6 @@ def compute(shape, fcompute, name="compute"):
shape: Tuple of Expr
shape: Tuple of Expr
The shape of the tensor
The shape of the tensor
fcompute: lambda function of *indices-> value
fcompute: lambda function of *indices-> value
Specifies the input source expression
Specifies the input source expression
...
@@ -137,8 +137,57 @@ def compute(shape, fcompute, name="compute"):
...
@@ -137,8 +137,57 @@ def compute(shape, fcompute, name="compute"):
body
=
convert
(
body
)
body
=
convert
(
body
)
op_node
=
_api_internal
.
_ComputeOp
(
op_node
=
_api_internal
.
_ComputeOp
(
name
,
dim_var
,
body
)
name
,
dim_var
,
body
)
return
_api_internal
.
_Tensor
(
return
op_node
.
output
(
0
)
shape
,
body
.
dtype
,
op_node
,
0
)
def
scan
(
axis
,
init
,
update
,
state_placeholder
,
name
=
"scan"
):
"""Construct new tensors by scanning over axis.
Parameters
----------
axis: IterVar
The scanning axis.
init: Tensor or list of Tensor
The initial condition of first init.shape[0] timestamps
update: Tensor or list of Tensor
The update rule of the scan given by symbolic tensor.
state_placeholder: Tensor or list of Tensor
The placeholder variables used by update.
name: str, optional
The name hint of the tensor
Returns
-------
tensor: tensor.Tensor
The created tensor
Example
-------
# The following code is equivalent to numpy.cumsum
m = tvm.Var("m")
n = tvm.Var("n")
t = tvm.IterVar((1, m), name="t")
X = tvm.placeholder((m, n), name="X")
s_state = tvm.placeholder((m, n))
s_init = tvm.compute((1, n), lambda _, i: X[0, i])
s_update = tvm.compute((n,), lambda i: s_state[t-1, i] + X[t, i])
res = tvm.scan(t, s_init, s_update, s_state)
"""
if
isinstance
(
init
,
_tensor
.
Tensor
):
init
=
[
init
]
if
isinstance
(
update
,
_tensor
.
Tensor
):
update
=
[
update
]
if
isinstance
(
state_placeholder
,
_tensor
.
Tensor
):
state_placeholder
=
[
state_placeholder
]
if
len
(
init
)
!=
len
(
update
)
or
len
(
init
)
!=
len
(
state_placeholder
):
raise
ValueError
(
"init, update, state_placeholder must have same length"
)
op
=
_api_internal
.
_ScanOp
(
name
,
axis
,
init
,
update
,
state_placeholder
)
res
=
[
op
.
output
(
i
)
for
i
in
range
(
len
(
update
))]
return
(
res
[
0
]
if
len
(
res
)
==
1
else
res
)
def
Buffer
(
shape
,
dtype
=
None
,
def
Buffer
(
shape
,
dtype
=
None
,
...
...
python/tvm/tensor.py
View file @
820a8597
...
@@ -75,11 +75,16 @@ class Operation(NodeBase):
...
@@ -75,11 +75,16 @@ class Operation(NodeBase):
return
_api_internal
.
_OpGetOutput
(
self
,
index
)
return
_api_internal
.
_OpGetOutput
(
self
,
index
)
@register_node
@register_node
class
PlaceholderOp
(
Operation
):
"""Placeholder operation."""
pass
@register_node
class
ComputeOp
(
Operation
):
class
ComputeOp
(
Operation
):
"""Compute operation."""
"""Compute operation."""
pass
pass
@register_node
@register_node
class
Placeholder
Op
(
Operation
):
class
Scan
Op
(
Operation
):
"""
Placeholder
operation."""
"""
Scan
operation."""
pass
pass
src/api/api_lang.cc
View file @
820a8597
...
@@ -173,6 +173,15 @@ TVM_REGISTER_API(_ComputeOp)
...
@@ -173,6 +173,15 @@ TVM_REGISTER_API(_ComputeOp)
args
[
2
]);
args
[
2
]);
});
});
TVM_REGISTER_API
(
_ScanOp
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
ScanOpNode
::
make
(
args
[
0
],
args
[
1
],
args
[
2
],
args
[
3
],
args
[
4
]);
});
TVM_REGISTER_API
(
_OpGetOutput
)
TVM_REGISTER_API
(
_OpGetOutput
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
args
[
0
].
operator
Operation
().
output
(
*
ret
=
args
[
0
].
operator
Operation
().
output
(
...
...
src/arithmetic/canonical.cc
View file @
820a8597
...
@@ -365,7 +365,7 @@ class Canonical::Internal : public IRMutator {
...
@@ -365,7 +365,7 @@ class Canonical::Internal : public IRMutator {
const
ComExpr
&
sumb
,
const
ComExpr
&
sumb
,
int
bscale
)
{
int
bscale
)
{
std
::
shared_ptr
<
ComExprNode
>
n
=
std
::
make_shared
<
ComExprNode
>
();
std
::
shared_ptr
<
ComExprNode
>
n
=
std
::
make_shared
<
ComExprNode
>
();
n
->
base
=
suma
->
base
+
sumb
->
base
;
n
->
base
=
suma
->
base
+
sumb
->
base
*
bscale
;
// merge of suma and sumb;
// merge of suma and sumb;
size_t
i
=
0
,
j
=
0
;
size_t
i
=
0
,
j
=
0
;
while
(
i
<
suma
->
elem
.
size
()
&&
j
<
sumb
->
elem
.
size
())
{
while
(
i
<
suma
->
elem
.
size
()
&&
j
<
sumb
->
elem
.
size
())
{
...
@@ -417,7 +417,7 @@ class Canonical::Internal : public IRMutator {
...
@@ -417,7 +417,7 @@ class Canonical::Internal : public IRMutator {
// convert sum to expr
// convert sum to expr
Expr
Sum2Expr
(
const
ComExpr
&
com
,
Type
t
)
{
Expr
Sum2Expr
(
const
ComExpr
&
com
,
Type
t
)
{
Expr
vsum
;
Expr
vsum
;
if
(
com
->
base
!=
0
)
{
if
(
com
->
base
>
0
)
{
vsum
=
make_const
(
t
,
com
->
base
);
vsum
=
make_const
(
t
,
com
->
base
);
}
}
for
(
const
ComExprEntry
&
e
:
com
->
elem
)
{
for
(
const
ComExprEntry
&
e
:
com
->
elem
)
{
...
@@ -433,6 +433,13 @@ class Canonical::Internal : public IRMutator {
...
@@ -433,6 +433,13 @@ class Canonical::Internal : public IRMutator {
}
}
}
}
}
}
if
(
com
->
base
<
0
)
{
if
(
vsum
.
defined
())
{
vsum
=
Sub
::
make
(
vsum
,
make_const
(
t
,
-
com
->
base
));
}
else
{
vsum
=
make_const
(
t
,
com
->
base
);
}
}
for
(
const
ComExprEntry
&
e
:
com
->
elem
)
{
for
(
const
ComExprEntry
&
e
:
com
->
elem
)
{
if
(
e
.
scale
<
0
)
{
if
(
e
.
scale
<
0
)
{
Expr
v
=
e
.
value
;
Expr
v
=
e
.
value
;
...
...
src/codegen/codegen_cuda.cc
View file @
820a8597
...
@@ -168,7 +168,7 @@ MakeNVRTC(Array<LoweredFunc> funcs) {
...
@@ -168,7 +168,7 @@ MakeNVRTC(Array<LoweredFunc> funcs) {
const
auto
&
f
=
PackedFunc
::
GetGlobal
(
"tvm_callback_cuda_postproc"
);
const
auto
&
f
=
PackedFunc
::
GetGlobal
(
"tvm_callback_cuda_postproc"
);
code
=
f
(
code
).
operator
std
::
string
();
code
=
f
(
code
).
operator
std
::
string
();
}
}
LOG
(
INFO
)
<<
code
;
std
::
string
ptx
;
std
::
string
ptx
;
if
(
PackedFunc
::
GlobalExist
(
"tvm_callback_cuda_compile"
))
{
if
(
PackedFunc
::
GlobalExist
(
"tvm_callback_cuda_compile"
))
{
const
auto
&
f
=
PackedFunc
::
GetGlobal
(
"tvm_callback_cuda_compile"
);
const
auto
&
f
=
PackedFunc
::
GetGlobal
(
"tvm_callback_cuda_compile"
);
...
...
src/codegen/codegen_cuda.h
View file @
820a8597
...
@@ -42,7 +42,7 @@ class CodeGenCUDA : public CodeGenC {
...
@@ -42,7 +42,7 @@ class CodeGenCUDA : public CodeGenC {
private
:
private
:
// magic number to add pragma unroll to it.
// magic number to add pragma unroll to it.
// used to generate code that is compact but still unrolls.
// used to generate code that is compact but still unrolls.
int
max_auto_unroll_
{
8
};
int
max_auto_unroll_
{
1025
};
};
};
}
// namespace codegen
}
// namespace codegen
...
...
src/lang/operation.cc
View file @
820a8597
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
#include <tvm/operation.h>
#include <tvm/operation.h>
#include <tvm/tensor.h>
#include <tvm/tensor.h>
#include <tvm/ir.h>
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <memory>
#include <memory>
namespace
tvm
{
namespace
tvm
{
...
@@ -120,4 +121,90 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
...
@@ -120,4 +121,90 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE
(
ComputeOpNode
);
TVM_REGISTER_NODE_TYPE
(
ComputeOpNode
);
// Scan
inline
bool
prove_equal
(
Expr
lhs
,
Expr
rhs
)
{
return
is_zero
(
ir
::
Simplify
(
lhs
-
rhs
));
}
int
ScanOpNode
::
num_outputs
()
const
{
return
update
.
size
();
}
Array
<
IterVar
>
ScanOpNode
::
root_iter_vars
()
const
{
return
Array
<
IterVar
>
{
scan_axis
};
}
Type
ScanOpNode
::
output_dtype
(
size_t
i
)
const
{
return
update
[
i
]
->
dtype
;
}
Array
<
Expr
>
ScanOpNode
::
output_shape
(
size_t
i
)
const
{
CHECK_LT
(
i
,
state_placeholder
.
size
());
return
state_placeholder
[
i
]
->
shape
;
}
Operation
ScanOpNode
::
make
(
std
::
string
name
,
IterVar
axis
,
Array
<
Tensor
>
init
,
Array
<
Tensor
>
update
,
Array
<
Tensor
>
state_placeholder
)
{
auto
n
=
std
::
make_shared
<
ScanOpNode
>
();
CHECK_EQ
(
init
.
size
(),
update
.
size
());
CHECK_EQ
(
init
.
size
(),
state_placeholder
.
size
());
for
(
size_t
i
=
0
;
i
<
init
.
size
();
++
i
)
{
CHECK_EQ
(
init
[
i
]
->
dtype
,
state_placeholder
[
i
]
->
dtype
);
CHECK_EQ
(
init
[
i
]
->
dtype
,
update
[
i
]
->
dtype
);
CHECK
(
can_prove
(
init
[
i
]
->
shape
[
0
]
==
axis
->
dom
->
min
))
<<
"init.shape[0] need to match scan_axis.dom.min"
;
CHECK
(
prove_equal
(
state_placeholder
[
i
]
->
shape
[
0
],
axis
->
dom
->
min
+
axis
->
dom
->
extent
))
<<
"shate_placeholder.shape[0] need to match"
<<
" scan_axis.dom.min + scan_axis.dom.extent"
;
CHECK_EQ
(
state_placeholder
[
i
].
ndim
(),
init
[
i
].
ndim
())
<<
"The dimension of init need to match state_placeholder"
;
CHECK_EQ
(
update
[
i
].
ndim
()
+
1
,
state_placeholder
[
i
].
ndim
())
<<
"The update.ndim need to be state_placeholder.ndim - 1"
;
for
(
size_t
k
=
0
;
k
<
update
[
i
].
ndim
();
++
k
)
{
CHECK
(
prove_equal
(
update
[
i
]
->
shape
[
k
],
state_placeholder
[
i
]
->
shape
[
k
+
1
]));
// setup spatial axis
std
::
ostringstream
spatial_name
;
spatial_name
<<
name
<<
".out"
<<
i
<<
".i"
<<
k
+
1
;
n
->
spatial_axis_
.
push_back
(
IterVar
(
Range
::
make_with_min_extent
(
0
,
update
[
i
]
->
shape
[
k
]),
spatial_name
.
str
()));
}
for
(
size_t
k
=
1
;
k
<
init
[
i
].
ndim
();
++
k
)
{
CHECK
(
prove_equal
(
init
[
i
]
->
shape
[
k
],
state_placeholder
[
i
]
->
shape
[
k
]));
}
}
n
->
name
=
name
;
n
->
scan_axis
=
axis
;
n
->
init
=
init
;
n
->
update
=
update
;
n
->
state_placeholder
=
state_placeholder
;
return
Operation
(
n
);
}
Array
<
Tensor
>
Scan
(
IterVar
scan_axis
,
Array
<
Tensor
>
init
,
Array
<
Tensor
>
update
,
Array
<
Tensor
>
state_placeholder
,
std
::
string
name
)
{
Operation
op
=
ScanOpNode
::
make
(
name
,
scan_axis
,
init
,
update
,
state_placeholder
);
Array
<
Tensor
>
res
;
for
(
int
i
=
0
;
i
<
op
->
num_outputs
();
++
i
)
{
res
.
push_back
(
op
.
output
(
i
));
}
return
res
;
}
TVM_STATIC_IR_FUNCTOR
(
IRPrinter
,
vtable
)
.
set_dispatch
<
ScanOpNode
>
([](
const
ScanOpNode
*
op
,
IRPrinter
*
p
)
{
p
->
stream
<<
"scan("
<<
op
->
name
<<
", "
<<
op
<<
")"
;
});
}
// namespace tvm
}
// namespace tvm
src/pass/inject_virtual_thread.cc
View file @
820a8597
...
@@ -191,9 +191,6 @@ class VTInjector : public IRMutator {
...
@@ -191,9 +191,6 @@ class VTInjector : public IRMutator {
}
}
// Attribute
// Attribute
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
final
{
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
final
{
if
(
op
->
type_key
==
attr
::
scope
)
{
return
Mutate
(
op
->
body
);
}
else
{
Expr
value
=
Mutate
(
op
->
value
);
Expr
value
=
Mutate
(
op
->
value
);
if
(
visit_touched_var_
)
{
if
(
visit_touched_var_
)
{
return
InjectVTLoop
(
s
,
true
);
return
InjectVTLoop
(
s
,
true
);
...
@@ -207,7 +204,6 @@ class VTInjector : public IRMutator {
...
@@ -207,7 +204,6 @@ class VTInjector : public IRMutator {
}
}
}
}
}
}
}
// LetStmt
// LetStmt
Stmt
Mutate_
(
const
LetStmt
*
op
,
const
Stmt
&
s
)
final
{
Stmt
Mutate_
(
const
LetStmt
*
op
,
const
Stmt
&
s
)
final
{
Expr
value
=
this
->
Mutate
(
op
->
value
);
Expr
value
=
this
->
Mutate
(
op
->
value
);
...
...
src/pass/storage_flatten.cc
View file @
820a8597
...
@@ -11,40 +11,6 @@
...
@@ -11,40 +11,6 @@
namespace
tvm
{
namespace
tvm
{
namespace
ir
{
namespace
ir
{
// key of function buffer
struct
TensorKey
{
FunctionRef
f
;
int
value_index
;
inline
bool
operator
==
(
const
TensorKey
&
other
)
const
{
return
f
==
other
.
f
&&
value_index
==
other
.
value_index
;
}
inline
std
::
string
GetName
()
const
{
if
(
f
->
num_outputs
()
==
1
)
return
f
->
func_name
();
std
::
ostringstream
os
;
os
<<
f
->
func_name
()
<<
".v"
<<
value_index
;
return
os
.
str
();
}
};
}
// namespace ir
}
// namespace tvm
namespace
std
{
template
<>
struct
hash
<::
tvm
::
ir
::
TensorKey
>
{
std
::
size_t
operator
()(
const
::
tvm
::
ir
::
TensorKey
&
k
)
const
{
size_t
lhs
=
k
.
f
.
hash
();
size_t
rhs
=
static_cast
<
size_t
>
(
k
.
value_index
);
lhs
^=
rhs
+
0x9e3779b9
+
(
lhs
<<
6
)
+
(
lhs
>>
2
);
return
lhs
;
}
};
}
// namespace std
namespace
tvm
{
namespace
ir
{
using
Halide
::
Internal
::
Region
;
using
Halide
::
Internal
::
Region
;
using
runtime
::
StorageScope
;
using
runtime
::
StorageScope
;
using
runtime
::
ThreadScope
;
using
runtime
::
ThreadScope
;
...
...
src/schedule/bound.cc
View file @
820a8597
...
@@ -23,6 +23,10 @@ inline Expr DivCeil(Expr a, Expr b) {
...
@@ -23,6 +23,10 @@ inline Expr DivCeil(Expr a, Expr b) {
return
ir
::
Simplify
((
a
+
b
-
1
)
/
b
);
return
ir
::
Simplify
((
a
+
b
-
1
)
/
b
);
}
}
inline
bool
prove_equal
(
Expr
lhs
,
Expr
rhs
)
{
return
is_zero
(
ir
::
Simplify
(
lhs
-
rhs
));
}
// Downward message passing algorithm on stage schedule s,
// Downward message passing algorithm on stage schedule s,
// pass the range state down from the root to the leaves
// pass the range state down from the root to the leaves
// after this pass, every IterVar in the stage hyper graph will have a range(domain)
// after this pass, every IterVar in the stage hyper graph will have a range(domain)
...
@@ -41,9 +45,18 @@ void PassDown(const Stage& s,
...
@@ -41,9 +45,18 @@ void PassDown(const Stage& s,
if
(
r
->
outer
->
dom
.
defined
())
{
if
(
r
->
outer
->
dom
.
defined
())
{
state
[
r
->
outer
]
=
r
->
outer
->
dom
;
state
[
r
->
outer
]
=
r
->
outer
->
dom
;
}
else
{
}
else
{
CHECK
(
!
state
.
count
(
r
->
outer
));
if
(
!
state
.
count
(
r
->
outer
))
{
state
[
r
->
outer
]
=
Range
::
make_with_min_extent
(
state
[
r
->
outer
]
=
Range
::
make_with_min_extent
(
0
,
DivCeil
(
range_parent
->
extent
,
r
->
factor
));
0
,
DivCeil
(
range_parent
->
extent
,
r
->
factor
));
}
else
{
Expr
outer_ext
=
DivCeil
(
range_parent
->
extent
,
r
->
factor
);
Range
outer_rng
=
state
.
at
(
r
->
outer
);
bool
match
=
is_zero
(
outer_rng
->
min
);
if
(
!
prove_equal
(
outer_ext
,
outer_rng
->
extent
))
match
=
false
;
CHECK
(
match
)
<<
"IterVar is used in two places as outer scope,"
<<
" cannot prove their extents are the same"
;
}
}
}
}
else
{
}
else
{
CHECK
(
r
->
outer
->
dom
.
defined
());
CHECK
(
r
->
outer
->
dom
.
defined
());
...
@@ -181,6 +194,21 @@ void PassUp(const Stage& s,
...
@@ -181,6 +194,21 @@ void PassUp(const Stage& s,
}
}
}
}
// All the itervars that are needed to output bound of op.
// For most op, it is root_iter_vars
// For Scan, it also contains the additional spatial axis.
Array
<
IterVar
>
OutputRelatedIterVars
(
const
Operation
&
op
)
{
if
(
op
.
as
<
ScanOpNode
>
())
{
const
ScanOpNode
*
scan
=
op
.
as
<
ScanOpNode
>
();
Array
<
IterVar
>
ret
{
scan
->
scan_axis
};
for
(
IterVar
iv
:
scan
->
spatial_axis_
)
{
ret
.
push_back
(
iv
);
}
return
ret
;
}
else
{
return
op
->
root_iter_vars
();
}
}
/*! \brief temporary data structure to store Tensor domain */
/*! \brief temporary data structure to store Tensor domain */
struct
TensorDom
{
struct
TensorDom
{
...
@@ -214,6 +242,34 @@ void BoundProp(const Operation& op,
...
@@ -214,6 +242,34 @@ void BoundProp(const Operation& op,
}
}
};
};
ir
::
PostOrderVisit
(
op
.
as
<
ComputeOpNode
>
()
->
body
,
fvisit
);
ir
::
PostOrderVisit
(
op
.
as
<
ComputeOpNode
>
()
->
body
,
fvisit
);
}
else
if
(
op
.
as
<
ScanOpNode
>
())
{
const
ScanOpNode
*
scan
=
op
.
as
<
ScanOpNode
>
();
size_t
sp_idx
=
0
;
for
(
size_t
i
=
0
;
i
<
scan
->
init
.
size
();
++
i
)
{
TensorDom
*
init_dom
=
nullptr
;
TensorDom
*
update_dom
=
nullptr
;
if
(
out
->
count
(
scan
->
init
[
i
]))
{
init_dom
=
&
out
->
at
(
scan
->
init
[
i
]);
}
if
(
out
->
count
(
scan
->
update
[
i
]))
{
update_dom
=
&
out
->
at
(
scan
->
update
[
i
]);
}
// first dimension, always needed.
if
(
init_dom
)
{
init_dom
->
data
[
0
].
push_back
(
IntSet
::
range
(
Range
::
make_with_min_extent
(
0
,
scan
->
init
[
i
]
->
shape
[
0
])));
}
// The update dimensions
for
(
size_t
k
=
0
;
k
<
scan
->
update
[
i
]
->
shape
.
size
();
++
k
,
++
sp_idx
)
{
IterVar
sp_ax
=
scan
->
spatial_axis_
[
sp_idx
];
if
(
init_dom
)
{
init_dom
->
data
[
k
+
1
].
push_back
(
dom_map
.
at
(
sp_ax
->
var
.
get
()));
}
if
(
update_dom
)
{
update_dom
->
data
[
k
].
push_back
(
dom_map
.
at
(
sp_ax
->
var
.
get
()));
}
}
}
}
else
if
(
op
.
as
<
PlaceholderOpNode
>
())
{
}
else
if
(
op
.
as
<
PlaceholderOpNode
>
())
{
// do nothing
// do nothing
}
else
{
}
else
{
...
@@ -221,14 +277,49 @@ void BoundProp(const Operation& op,
...
@@ -221,14 +277,49 @@ void BoundProp(const Operation& op,
}
}
}
}
void
InferOpBound
(
const
Operation
&
op
,
// Given the bound of output of op
// Pass the bound to the related axis in op.
void
GatherOpBound
(
const
ScanOpNode
*
scan
,
const
Operation
&
op
,
const
std
::
unordered_map
<
Tensor
,
TensorDom
>&
tmap
,
std
::
unordered_map
<
IterVar
,
Range
>*
rmap
)
{
CHECK
(
!
rmap
->
count
(
scan
->
scan_axis
));
std
::
vector
<
Tensor
>
output
(
op
->
num_outputs
());
for
(
size_t
i
=
0
;
i
<
output
.
size
();
++
i
)
{
output
[
i
]
=
op
.
output
(
i
);
}
// Update for time axis.
std
::
vector
<
IntSet
>
time_dom
;
for
(
size_t
i
=
0
;
i
<
output
.
size
();
++
i
)
{
const
TensorDom
&
d
=
tmap
.
at
(
output
[
i
]);
time_dom
.
insert
(
time_dom
.
end
(),
d
.
data
[
0
].
begin
(),
d
.
data
[
0
].
end
());
}
LOG
(
INFO
)
<<
time_dom
.
size
();
CHECK
(
!
rmap
->
count
(
scan
->
scan_axis
));
Range
sdom
=
scan
->
scan_axis
->
dom
;
Range
r
=
arith
::
Union
(
time_dom
).
cover_range
(
sdom
);
(
*
rmap
)[
scan
->
scan_axis
]
=
Range
::
make_with_min_extent
(
sdom
->
min
,
ir
::
Simplify
(
r
->
extent
+
r
->
min
-
sdom
->
min
));
// Update for spatial axis.
size_t
sp_idx
=
0
;
for
(
size_t
i
=
0
;
i
<
output
.
size
();
++
i
)
{
for
(
size_t
k
=
0
;
k
<
scan
->
update
[
i
]
->
shape
.
size
();
++
k
,
++
sp_idx
)
{
IterVar
sp_ax
=
scan
->
spatial_axis_
[
sp_idx
];
CHECK
(
!
rmap
->
count
(
sp_ax
));
// In default, we always need all spatial axis
// Unless that axis only refers back to itself as a fixed point.
// TODO(tqchen): Add fix point detection.
(
*
rmap
)[
sp_ax
]
=
sp_ax
->
dom
;
}
}
}
void
GatherOpBound
(
const
Operation
&
op
,
const
std
::
unordered_map
<
Tensor
,
TensorDom
>&
tmap
,
const
std
::
unordered_map
<
Tensor
,
TensorDom
>&
tmap
,
std
::
unordered_map
<
IterVar
,
Range
>*
rmap
)
{
std
::
unordered_map
<
IterVar
,
Range
>*
rmap
)
{
if
(
op
.
as
<
ComputeOpNode
>
())
{
if
(
op
.
as
<
ComputeOpNode
>
())
{
auto
root_iter_vars
=
op
->
root_iter_vars
();
const
ComputeOpNode
*
compute
=
op
.
as
<
ComputeOpNode
>
();
const
ComputeOpNode
*
compute
=
op
.
as
<
ComputeOpNode
>
();
const
TensorDom
&
tdom
=
tmap
.
at
(
op
.
output
(
0
));
const
TensorDom
&
tdom
=
tmap
.
at
(
op
.
output
(
0
));
for
(
size_t
i
=
0
;
i
<
compute
->
axis
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
compute
->
axis
.
size
();
++
i
)
{
Range
r
=
arith
::
Union
(
tdom
.
data
[
i
]).
cover_range
(
compute
->
axis
[
i
]
->
dom
);
Range
r
=
arith
::
Union
(
tdom
.
data
[
i
]).
cover_range
(
compute
->
axis
[
i
]
->
dom
);
CHECK
(
!
rmap
->
count
(
compute
->
axis
[
i
]));
CHECK
(
!
rmap
->
count
(
compute
->
axis
[
i
]));
...
@@ -238,6 +329,8 @@ void InferOpBound(const Operation& op,
...
@@ -238,6 +329,8 @@ void InferOpBound(const Operation& op,
CHECK
(
!
rmap
->
count
(
compute
->
reduce_axis
[
i
]));
CHECK
(
!
rmap
->
count
(
compute
->
reduce_axis
[
i
]));
(
*
rmap
)[
compute
->
reduce_axis
[
i
]]
=
compute
->
reduce_axis
[
i
]
->
dom
;
(
*
rmap
)[
compute
->
reduce_axis
[
i
]]
=
compute
->
reduce_axis
[
i
]
->
dom
;
}
}
}
else
if
(
op
.
as
<
ScanOpNode
>
())
{
GatherOpBound
(
op
.
as
<
ScanOpNode
>
(),
op
,
tmap
,
rmap
);
}
else
if
(
op
.
as
<
PlaceholderOpNode
>
())
{
}
else
if
(
op
.
as
<
PlaceholderOpNode
>
())
{
// dp nothing
// dp nothing
}
else
{
}
else
{
...
@@ -269,8 +362,7 @@ void InferRootBound(const Stage& stage,
...
@@ -269,8 +362,7 @@ void InferRootBound(const Stage& stage,
std
::
unordered_map
<
IterVar
,
Range
>*
rmap
)
{
std
::
unordered_map
<
IterVar
,
Range
>*
rmap
)
{
if
(
stage
->
attach_type
==
kInline
)
return
;
if
(
stage
->
attach_type
==
kInline
)
return
;
if
(
stage
->
attach_type
==
kRoot
||
stage
->
attach_type
==
kNone
)
{
if
(
stage
->
attach_type
==
kRoot
||
stage
->
attach_type
==
kNone
)
{
auto
root_iter_vars
=
stage
->
op
->
root_iter_vars
();
for
(
auto
iv
:
OutputRelatedIterVars
(
stage
->
op
))
{
for
(
auto
iv
:
root_iter_vars
)
{
CHECK
(
iv
->
dom
.
defined
());
CHECK
(
iv
->
dom
.
defined
());
CHECK
(
!
rmap
->
count
(
iv
));
CHECK
(
!
rmap
->
count
(
iv
));
(
*
rmap
)[
iv
]
=
iv
->
dom
;
(
*
rmap
)[
iv
]
=
iv
->
dom
;
...
@@ -338,8 +430,13 @@ void InferRootBound(const Stage& stage,
...
@@ -338,8 +430,13 @@ void InferRootBound(const Stage& stage,
PassUp
(
parent
,
*
rmap
,
&
up_state
);
PassUp
(
parent
,
*
rmap
,
&
up_state
);
std
::
unordered_map
<
const
Variable
*
,
IntSet
>
dom_map
;
std
::
unordered_map
<
const
Variable
*
,
IntSet
>
dom_map
;
for
(
auto
iv
:
parent
->
op
->
root_iter_vars
())
{
for
(
auto
iv
:
OutputRelatedIterVars
(
parent
->
op
))
{
Range
r
=
up_state
.
at
(
iv
).
cover_range
(
iv
->
dom
);
Range
r
;
if
(
up_state
.
count
(
iv
))
{
r
=
up_state
.
at
(
iv
).
cover_range
(
iv
->
dom
);
}
else
{
r
=
iv
->
dom
;
}
if
(
relax_set
.
size
()
!=
0
)
{
if
(
relax_set
.
size
()
!=
0
)
{
dom_map
[
iv
->
var
.
get
()]
=
EvalSet
(
r
,
relax_set
);
dom_map
[
iv
->
var
.
get
()]
=
EvalSet
(
r
,
relax_set
);
}
else
{
}
else
{
...
@@ -379,13 +476,13 @@ void InferRootBound(const Stage& stage,
...
@@ -379,13 +476,13 @@ void InferRootBound(const Stage& stage,
CHECK
(
found
)
CHECK
(
found
)
<<
"Invalid Schedule, cannot find the producer "
<<
stage
->
op
<<
"Invalid Schedule, cannot find the producer "
<<
stage
->
op
<<
" along the loop nest specified by compute_at of consumer "
<<
op
;
<<
" along the loop nest specified by compute_at of consumer "
<<
op
;
for
(
auto
iv
:
op
->
root_iter_vars
(
))
{
for
(
auto
iv
:
OutputRelatedIterVars
(
op
))
{
Range
r
=
rmap
->
at
(
iv
);
Range
r
=
rmap
->
at
(
iv
);
dom_map
[
iv
->
var
.
get
()]
=
EvalSet
(
r
,
relax_set
);
dom_map
[
iv
->
var
.
get
()]
=
EvalSet
(
r
,
relax_set
);
}
}
BoundProp
(
op
,
dom_map
,
&
tmap
);
BoundProp
(
op
,
dom_map
,
&
tmap
);
}
}
Inf
erOpBound
(
stage
->
op
,
tmap
,
rmap
);
Gath
erOpBound
(
stage
->
op
,
tmap
,
rmap
);
}
}
FeedGraph
CreateFeedGraph
(
const
Schedule
&
sch
)
{
FeedGraph
CreateFeedGraph
(
const
Schedule
&
sch
)
{
...
...
src/schedule/graph.cc
View file @
820a8597
...
@@ -33,20 +33,28 @@ ReadGraph CreateReadGraph(const Array<Operation>& roots) {
...
@@ -33,20 +33,28 @@ ReadGraph CreateReadGraph(const Array<Operation>& roots) {
if
(
call
!=
nullptr
&&
call
->
func
.
defined
())
{
if
(
call
!=
nullptr
&&
call
->
func
.
defined
())
{
Operation
call_op
(
call
->
func
.
node_
);
Operation
call_op
(
call
->
func
.
node_
);
deps
.
push_back
(
call_op
.
output
(
call
->
value_index
));
deps
.
push_back
(
call_op
.
output
(
call
->
value_index
));
if
(
call_op
.
defined
()
&&
visited
.
count
(
call_op
.
get
())
==
0
)
{
visited
.
insert
(
call_op
.
get
());
stack
.
push_back
(
call_op
);
}
}
}
};
};
ir
::
PostOrderVisit
(
op
.
as
<
ComputeOpNode
>
()
->
body
,
fvisit
);
ir
::
PostOrderVisit
(
op
.
as
<
ComputeOpNode
>
()
->
body
,
fvisit
);
rmap
.
Set
(
op
,
deps
);
}
else
if
(
op
.
as
<
ScanOpNode
>
())
{
const
ScanOpNode
*
scan
=
op
.
as
<
ScanOpNode
>
();
for
(
Tensor
t
:
scan
->
init
)
{
deps
.
push_back
(
t
);
}
for
(
Tensor
t
:
scan
->
update
)
{
deps
.
push_back
(
t
);
}
}
else
if
(
op
.
as
<
PlaceholderOpNode
>
())
{
}
else
if
(
op
.
as
<
PlaceholderOpNode
>
())
{
// empty set of deps
rmap
.
Set
(
op
,
deps
);
}
else
{
}
else
{
LOG
(
FATAL
)
<<
"unknown Operation"
<<
op
->
type_key
();
LOG
(
FATAL
)
<<
"unknown Operation"
<<
op
->
type_key
();
}
}
rmap
.
Set
(
op
,
deps
);
for
(
Tensor
t
:
deps
)
{
if
(
t
->
op
.
defined
()
&&
visited
.
count
(
t
->
op
.
get
())
==
0
)
{
visited
.
insert
(
t
->
op
.
get
());
stack
.
push_back
(
t
->
op
);
}
}
}
}
return
rmap
;
return
rmap
;
}
}
...
...
src/schedule/schedule_lang.cc
View file @
820a8597
...
@@ -146,6 +146,8 @@ Stage& Stage::fuse(IterVar inner, IterVar outer, IterVar* p_target) { // NOLINT
...
@@ -146,6 +146,8 @@ Stage& Stage::fuse(IterVar inner, IterVar outer, IterVar* p_target) { // NOLINT
Stage
&
Stage
::
reorder
(
const
Array
<
IterVar
>&
order
)
{
// NOLINT(*)
Stage
&
Stage
::
reorder
(
const
Array
<
IterVar
>&
order
)
{
// NOLINT(*)
StageNode
*
self
=
operator
->
();
StageNode
*
self
=
operator
->
();
CHECK
(
!
self
->
op
.
as
<
ScanOpNode
>
())
<<
"Cannot reorder axis of scan"
;
ArrayNode
*
all_vars
=
self
->
all_iter_vars
.
CopyOnWrite
();
ArrayNode
*
all_vars
=
self
->
all_iter_vars
.
CopyOnWrite
();
ArrayNode
*
leaf_vars
=
self
->
leaf_iter_vars
.
CopyOnWrite
();
ArrayNode
*
leaf_vars
=
self
->
leaf_iter_vars
.
CopyOnWrite
();
std
::
vector
<
size_t
>
pos
;
std
::
vector
<
size_t
>
pos
;
...
...
src/schedule/schedule_ops.cc
View file @
820a8597
...
@@ -7,7 +7,9 @@
...
@@ -7,7 +7,9 @@
#include <tvm/ir_pass.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_visitor.h>
#include <tvm/schedule_pass.h>
#include <tvm/schedule_pass.h>
#include <utility>
#include <unordered_map>
#include <unordered_set>
#include "../pass/ir_util.h"
#include "../pass/ir_util.h"
#include "../arithmetic/compute_expr.h"
#include "../arithmetic/compute_expr.h"
#include "./graph.h"
#include "./graph.h"
...
@@ -18,6 +20,12 @@ namespace schedule {
...
@@ -18,6 +20,12 @@ namespace schedule {
using
namespace
arith
;
using
namespace
arith
;
using
namespace
ir
;
using
namespace
ir
;
// Two private scope marks
namespace
attr
{
constexpr
const
char
*
loop_scope
=
"loop_scope"
;
constexpr
const
char
*
scan_scope
=
"scan_scope"
;
}
// namespace attr
/*!
/*!
* \brief message passing to find if IterVar is related to reduction.
* \brief message passing to find if IterVar is related to reduction.
* \param s The stage to be used.
* \param s The stage to be used.
...
@@ -168,7 +176,6 @@ MakeLoopNest(const Stage& sch,
...
@@ -168,7 +176,6 @@ MakeLoopNest(const Stage& sch,
value_map
[
iv
]
=
iv
->
var
;
value_map
[
iv
]
=
iv
->
var
;
continue
;
continue
;
}
}
Range
dom
=
dom_map
.
at
(
iv
);
Range
dom
=
dom_map
.
at
(
iv
);
// initialize the offset and loop_level
// initialize the offset and loop_level
Var
var
=
iv
->
var
;
Var
var
=
iv
->
var
;
...
@@ -223,7 +230,7 @@ MakeLoopNest(const Stage& sch,
...
@@ -223,7 +230,7 @@ MakeLoopNest(const Stage& sch,
if
(
!
reduce_init_loop
)
{
if
(
!
reduce_init_loop
)
{
// annotate the extent of the IterVar
// annotate the extent of the IterVar
nest
[
i
+
1
].
emplace_back
(
nest
[
i
+
1
].
emplace_back
(
AttrStmt
::
make
(
iv
,
ir
::
attr
::
scope
,
iv
->
var
,
no_op
));
AttrStmt
::
make
(
iv
,
attr
::
loop_
scope
,
iv
->
var
,
no_op
));
}
}
}
}
// message passing to get offset of root iter vars.
// message passing to get offset of root iter vars.
...
@@ -307,8 +314,8 @@ Stmt MakeLoop(const Stage& s,
...
@@ -307,8 +314,8 @@ Stmt MakeLoop(const Stage& s,
init
=
Substitute
(
init
,
init_value_map
);
init
=
Substitute
(
init
,
init_value_map
);
init
=
MergeNest
(
init_nest
,
init
);
init
=
MergeNest
(
init_nest
,
init
);
// common nest
// common nest
std
::
vector
<
std
::
vector
<
Stmt
>
>
common
(
nest
.
begin
(),
nest
.
begin
()
+
begin_loop
);
std
::
vector
<
std
::
vector
<
Stmt
>
>
common
(
nest
.
begin
(),
nest
.
begin
()
+
begin_loop
+
1
);
std
::
vector
<
std
::
vector
<
Stmt
>
>
reduce
(
nest
.
begin
()
+
begin_loop
,
nest
.
end
());
std
::
vector
<
std
::
vector
<
Stmt
>
>
reduce
(
nest
.
begin
()
+
begin_loop
+
1
,
nest
.
end
());
provide
=
MergeNest
(
reduce
,
provide
);
provide
=
MergeNest
(
reduce
,
provide
);
return
MergeNest
(
return
MergeNest
(
common
,
Block
::
make
(
init
,
provide
));
common
,
Block
::
make
(
init
,
provide
));
...
@@ -340,6 +347,29 @@ Stmt MakeRealize(const ComputeOpNode* op,
...
@@ -340,6 +347,29 @@ Stmt MakeRealize(const ComputeOpNode* op,
bounds
,
make_const
(
Bool
(
1
),
true
),
body
);
bounds
,
make_const
(
Bool
(
1
),
true
),
body
);
}
}
Stmt
MakeRealize
(
const
ScanOpNode
*
op
,
const
Map
<
IterVar
,
Range
>&
dom_map
,
const
std
::
vector
<
Tensor
>&
tensors
,
Stmt
body
)
{
Range
sdom
=
dom_map
.
at
(
op
->
scan_axis
);
Range
tdom
=
Range
::
make_with_min_extent
(
0
,
ir
::
Simplify
(
sdom
->
extent
+
sdom
->
min
));
size_t
sp_idx
=
0
;
for
(
size_t
i
=
0
;
i
<
tensors
.
size
();
++
i
)
{
const
Tensor
&
t
=
tensors
[
i
];
CHECK_EQ
(
static_cast
<
size_t
>
(
t
->
value_index
),
i
);
Halide
::
Internal
::
Region
bounds
;
bounds
.
push_back
(
tdom
);
for
(
size_t
k
=
0
;
k
<
op
->
update
[
i
]
->
shape
.
size
();
++
k
,
++
sp_idx
)
{
IterVar
sp_ax
=
op
->
spatial_axis_
[
sp_idx
];
bounds
.
push_back
(
dom_map
.
at
(
sp_ax
));
}
body
=
Realize
::
make
(
t
->
op
,
t
->
value_index
,
t
->
dtype
,
bounds
,
make_const
(
Bool
(
1
),
true
),
body
);
}
return
body
;
}
void
MakeReduction
(
const
ComputeOpNode
*
op
,
void
MakeReduction
(
const
ComputeOpNode
*
op
,
const
std
::
vector
<
Tensor
>&
tensors
,
const
std
::
vector
<
Tensor
>&
tensors
,
...
@@ -382,12 +412,18 @@ Stmt MakePipeline(const Stage& s,
...
@@ -382,12 +412,18 @@ Stmt MakePipeline(const Stage& s,
Stmt
init
,
provide
;
Stmt
init
,
provide
;
const
ComputeOpNode
*
compute
=
s
->
op
.
as
<
ComputeOpNode
>
();
const
ComputeOpNode
*
compute
=
s
->
op
.
as
<
ComputeOpNode
>
();
const
ScanOpNode
*
scan
=
s
->
op
.
as
<
ScanOpNode
>
();
if
(
compute
)
{
if
(
compute
)
{
if
(
compute
->
reduce_axis
.
size
()
==
0
)
{
if
(
compute
->
reduce_axis
.
size
()
==
0
)
{
provide
=
MakeProvide
(
compute
,
tensors
);
provide
=
MakeProvide
(
compute
,
tensors
);
}
else
{
}
else
{
MakeReduction
(
compute
,
tensors
,
&
init
,
&
provide
);
MakeReduction
(
compute
,
tensors
,
&
init
,
&
provide
);
}
}
}
else
if
(
scan
)
{
// Provide is done by the sub operations.
provide
=
AttrStmt
::
make
(
s
->
op
,
attr
::
scan_scope
,
scan
->
scan_axis
->
var
,
Evaluate
::
make
(
0
));
}
else
{
}
else
{
LOG
(
FATAL
)
<<
"not supported op "
<<
s
->
op
->
type_key
();
LOG
(
FATAL
)
<<
"not supported op "
<<
s
->
op
->
type_key
();
}
}
...
@@ -396,7 +432,12 @@ Stmt MakePipeline(const Stage& s,
...
@@ -396,7 +432,12 @@ Stmt MakePipeline(const Stage& s,
producer
=
ProducerConsumer
::
make
(
s
->
op
,
true
,
producer
);
producer
=
ProducerConsumer
::
make
(
s
->
op
,
true
,
producer
);
Stmt
pipeline
=
producer
;
Stmt
pipeline
=
producer
;
if
(
consumer
.
defined
())
{
// check if consumer is nop.
bool
is_no_op
{
false
};
const
Evaluate
*
ev
=
consumer
.
as
<
Evaluate
>
();
if
(
ev
&&
ev
->
value
.
as
<
IntImm
>
())
is_no_op
=
true
;
if
(
consumer
.
defined
()
&&
!
is_no_op
)
{
consumer
=
ProducerConsumer
::
make
(
s
->
op
,
false
,
consumer
);
consumer
=
ProducerConsumer
::
make
(
s
->
op
,
false
,
consumer
);
pipeline
=
Block
::
make
(
producer
,
consumer
);
pipeline
=
Block
::
make
(
producer
,
consumer
);
}
}
...
@@ -404,47 +445,103 @@ Stmt MakePipeline(const Stage& s,
...
@@ -404,47 +445,103 @@ Stmt MakePipeline(const Stage& s,
if
(
s
->
op
.
as
<
ComputeOpNode
>
())
{
if
(
s
->
op
.
as
<
ComputeOpNode
>
())
{
pipeline
=
MakeRealize
(
s
->
op
.
as
<
ComputeOpNode
>
(),
pipeline
=
MakeRealize
(
s
->
op
.
as
<
ComputeOpNode
>
(),
dom_map
,
tensors
,
pipeline
);
dom_map
,
tensors
,
pipeline
);
}
else
if
(
s
->
op
.
as
<
ScanOpNode
>
())
{
pipeline
=
MakeRealize
(
s
->
op
.
as
<
ScanOpNode
>
(),
dom_map
,
tensors
,
pipeline
);
}
else
{
}
else
{
LOG
(
FATAL
)
<<
"not supported op"
;
LOG
(
FATAL
)
<<
"not supported op"
;
return
Stmt
();
}
}
// use attribute to mark scope of the operation.
// use attribute to mark scope of the operation.
pipeline
=
AttrStmt
::
make
(
pipeline
=
AttrStmt
::
make
(
s
->
op
,
"realize_scope"
,
s
->
op
,
ir
::
attr
::
realize_scope
,
StringImm
::
make
(
s
->
scope
),
StringImm
::
make
(
s
->
scope
),
pipeline
);
pipeline
);
return
pipeline
;
return
pipeline
;
}
}
// inject the operator's realization on the stmt.
// inject the operator's realization on the stmt.
class
Inject
Realize
:
public
IRMutator
{
class
Inject
Attach
:
public
IRMutator
{
public
:
public
:
InjectRealize
(
Stage
schedule
,
Map
<
IterVar
,
Range
>
dom_map
)
InjectAttach
(
const
Stage
&
stage
,
:
schedule
(
schedule
),
dom_map
(
dom_map
)
{}
const
Map
<
IterVar
,
Range
>&
dom_map
)
:
stage_
(
stage
),
dom_map_
(
dom_map
)
{}
Stmt
Mutate
(
Stmt
stmt
)
final
{
Stmt
Mutate
(
Stmt
stmt
)
final
{
CHECK
(
stmt
.
defined
());
CHECK
(
stmt
.
defined
());
stmt
=
IRMutator
::
Mutate
(
stmt
);
stmt
=
IRMutator
::
Mutate
(
stmt
);
const
AttrStmt
*
op
=
stmt
.
as
<
AttrStmt
>
();
const
AttrStmt
*
op
=
stmt
.
as
<
AttrStmt
>
();
if
(
op
!=
nullptr
&&
if
(
op
!=
nullptr
&&
op
->
type_key
==
"scope"
)
{
op
->
type_key
==
attr
::
loop_scope
)
{
if
(
op
->
node
==
s
chedule
->
attach_ivar
)
{
if
(
op
->
node
==
s
tage_
->
attach_ivar
)
{
CHECK
(
!
found_attach
);
CHECK
(
!
found_attach
);
found_attach
=
true
;
found_attach
=
true
;
stmt
=
AttrStmt
::
make
(
stmt
=
AttrStmt
::
make
(
op
->
node
,
op
->
type_key
,
op
->
value
,
op
->
node
,
op
->
type_key
,
op
->
value
,
MakePipeline
(
schedule
,
dom_map
,
MakePipeline
(
stage_
,
dom_map_
,
op
->
body
));
IRMutator
::
Mutate
(
op
->
body
)));
}
}
}
}
return
stmt
;
return
stmt
;
}
}
// whether attach point is found
bool
found_attach
{
false
};
private
:
// the operations to be carried
// the operations to be carried
Stage
schedule
;
const
Stage
&
stage_
;
// domain map
// domain map
Map
<
IterVar
,
Range
>
dom_map
;
const
Map
<
IterVar
,
Range
>&
dom_map_
;
};
// inject the operator's realization on the stmt.
class
InjectScanStep
:
public
IRMutator
{
public
:
InjectScanStep
(
const
Stage
&
stage
,
const
Operation
&
scan_op
,
const
Map
<
IterVar
,
Range
>&
dom_map
,
bool
is_init
)
:
stage_
(
stage
),
scan_op_
(
scan_op
),
dom_map_
(
dom_map
),
is_init_
(
is_init
)
{}
Stmt
Mutate
(
Stmt
stmt
)
final
{
CHECK
(
stmt
.
defined
());
stmt
=
IRMutator
::
Mutate
(
stmt
);
if
(
is_init_
)
{
const
ProducerConsumer
*
op
=
stmt
.
as
<
ProducerConsumer
>
();
if
(
op
!=
nullptr
&&
op
->
is_producer
&&
op
->
func
.
same_as
(
scan_op_
))
{
stmt
=
ProducerConsumer
::
make
(
op
->
func
,
true
,
MakePipeline
(
stage_
,
dom_map_
,
op
->
body
));
found_attach
=
true
;
}
}
else
{
// update
const
AttrStmt
*
op
=
stmt
.
as
<
AttrStmt
>
();
if
(
op
!=
nullptr
&&
op
->
type_key
==
attr
::
scan_scope
)
{
if
(
op
->
node
.
same_as
(
scan_op_
))
{
found_attach
=
true
;
stmt
=
AttrStmt
::
make
(
op
->
node
,
op
->
type_key
,
op
->
value
,
MakePipeline
(
stage_
,
dom_map_
,
op
->
body
));
}
}
}
return
stmt
;
}
// whether attach point is found
// whether attach point is found
bool
found_attach
{
false
};
bool
found_attach
{
false
};
private
:
// the operations to be carried
const
Stage
&
stage_
;
const
Operation
&
scan_op_
;
// domain map
const
Map
<
IterVar
,
Range
>&
dom_map_
;
// whether it is init.
bool
is_init_
;
};
};
Stmt
InjectInline
(
const
Operation
op
,
Stmt
body
)
{
Stmt
InjectInline
(
const
Operation
op
,
Stmt
body
)
{
...
@@ -459,27 +556,180 @@ Stmt InjectInline(const Operation op, Stmt body) {
...
@@ -459,27 +556,180 @@ Stmt InjectInline(const Operation op, Stmt body) {
return
Inline
(
body
,
op
,
args
,
compute
->
body
);
return
Inline
(
body
,
op
,
args
,
compute
->
body
);
}
}
// Postprocessing of schedule op
// Replace the init and update's expression by scan's buffer.
class
SchedulePostProc
:
public
IRMutator
{
public
:
Stmt
Mutate_
(
const
ProducerConsumer
*
op
,
const
Stmt
&
s
)
final
{
if
(
to_remove_
.
count
(
op
->
func
.
get
()))
{
return
this
->
Mutate
(
op
->
body
);
}
else
{
return
IRMutator
::
Mutate_
(
op
,
s
);
}
}
Stmt
Mutate_
(
const
LetStmt
*
op
,
const
Stmt
&
s
)
final
{
if
(
!
HasSideEffect
(
op
->
value
))
{
var_value_
[
op
->
var
.
get
()]
=
Mutate
(
op
->
value
);
return
this
->
Mutate
(
op
->
body
);
}
else
{
return
IRMutator
::
Mutate_
(
op
,
s
);
}
}
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
final
{
if
(
op
->
type_key
==
attr
::
loop_scope
)
{
return
this
->
Mutate
(
op
->
body
);
}
else
if
(
op
->
type_key
==
attr
::
scan_scope
)
{
const
ScanOpNode
*
scan
=
op
->
node
.
as
<
ScanOpNode
>
();
CHECK
(
scan
);
var_value_
[
scan
->
scan_axis
->
var
.
get
()]
=
op
->
value
;
return
this
->
Mutate
(
op
->
body
);
}
else
if
(
op
->
type_key
==
ir
::
attr
::
realize_scope
)
{
if
(
to_remove_
.
count
(
op
->
node
.
get
()))
{
return
this
->
Mutate
(
op
->
body
);
}
}
return
IRMutator
::
Mutate_
(
op
,
s
);
}
Stmt
Mutate_
(
const
Realize
*
op
,
const
Stmt
&
s
)
final
{
TensorKey
key
{
op
->
func
,
op
->
value_index
};
if
(
replace_
.
count
(
key
))
{
return
this
->
Mutate
(
op
->
body
);
}
else
{
return
IRMutator
::
Mutate_
(
op
,
s
);
}
}
Stmt
Mutate_
(
const
Provide
*
op
,
const
Stmt
&
s
)
final
{
TensorKey
key
{
op
->
func
,
op
->
value_index
};
auto
it
=
replace_
.
find
(
key
);
if
(
it
!=
replace_
.
end
())
{
const
Tensor
&
dst
=
it
->
second
.
first
;
Stmt
ret
=
Provide
::
make
(
dst
->
op
,
dst
->
value_index
,
op
->
value
,
RewriteArgs
(
it
->
second
.
second
,
op
->
args
));
return
IRMutator
::
Mutate_
(
ret
.
as
<
Provide
>
(),
ret
);
}
else
{
return
IRMutator
::
Mutate_
(
op
,
s
);
}
}
Expr
Mutate_
(
const
Call
*
op
,
const
Expr
&
e
)
final
{
if
(
op
!=
nullptr
&&
op
->
call_type
==
Call
::
Halide
)
{
TensorKey
key
{
op
->
func
,
op
->
value_index
};
auto
it
=
replace_
.
find
(
key
);
if
(
it
!=
replace_
.
end
())
{
const
Tensor
&
dst
=
it
->
second
.
first
;
Expr
ret
=
Call
::
make
(
op
->
type
,
dst
->
op
->
name
,
RewriteArgs
(
it
->
second
.
second
,
op
->
args
),
op
->
call_type
,
dst
->
op
,
dst
->
value_index
);
return
IRMutator
::
Mutate_
(
ret
.
as
<
Call
>
(),
ret
);
}
}
return
IRMutator
::
Mutate_
(
op
,
e
);
}
Expr
Mutate_
(
const
Variable
*
op
,
const
Expr
&
e
)
final
{
auto
it
=
var_value_
.
find
(
op
);
if
(
it
!=
var_value_
.
end
())
{
return
it
->
second
;
}
else
{
return
e
;
}
}
void
Init
(
const
Schedule
&
sch
)
{
for
(
Stage
s
:
sch
->
stages
)
{
const
ScanOpNode
*
scan
=
s
->
op
.
as
<
ScanOpNode
>
();
if
(
!
scan
)
continue
;
for
(
size_t
i
=
0
;
i
<
scan
->
update
.
size
();
++
i
)
{
Tensor
t
=
s
->
op
.
output
(
i
);
AddReplace
(
scan
->
init
[
i
],
t
,
Expr
());
AddReplace
(
scan
->
update
[
i
],
t
,
scan
->
scan_axis
->
var
);
AddReplace
(
scan
->
state_placeholder
[
i
],
t
,
Expr
());
}
}
}
private
:
void
AddReplace
(
Tensor
src
,
Tensor
dst
,
Expr
head_idx
)
{
replace_
[
TensorKey
{
src
->
op
,
src
->
value_index
}]
=
std
::
make_pair
(
dst
,
head_idx
);
to_remove_
.
insert
(
src
->
op
.
get
());
}
Array
<
Expr
>
RewriteArgs
(
Expr
head
,
Array
<
Expr
>
args
)
{
if
(
!
head
.
defined
())
return
args
;
Array
<
Expr
>
new_args
{
head
};
for
(
Expr
e
:
args
)
{
new_args
.
push_back
(
e
);
}
return
new_args
;
}
// The scan value
std
::
unordered_map
<
const
Variable
*
,
Expr
>
var_value_
;
// buffer replacement
std
::
unordered_map
<
TensorKey
,
std
::
pair
<
Tensor
,
Expr
>
>
replace_
;
// replaced functions
std
::
unordered_set
<
const
Node
*>
to_remove_
;
};
Stmt
ScheduleOps
(
Stmt
ScheduleOps
(
Schedule
sch
,
Map
<
IterVar
,
Range
>
dom_map
)
{
Schedule
sch
,
Map
<
IterVar
,
Range
>
dom_map
)
{
Stmt
body
=
Stmt
();
Stmt
body
=
Stmt
();
// scan init and scan updates
std
::
unordered_map
<
Operation
,
std
::
pair
<
Operation
,
bool
>
>
scan_attach
;
for
(
Stage
s
:
sch
->
stages
)
{
const
ScanOpNode
*
scan
=
s
->
op
.
as
<
ScanOpNode
>
();
if
(
!
scan
)
continue
;
for
(
Tensor
t
:
scan
->
init
)
{
if
(
scan_attach
.
count
(
t
->
op
))
{
CHECK
(
scan_attach
.
at
(
t
->
op
).
first
.
same_as
(
s
->
op
))
<<
"Scan init tensor can only belong to one scan"
;
}
else
{
scan_attach
[
t
->
op
]
=
std
::
make_pair
(
s
->
op
,
true
);
}
}
for
(
Tensor
t
:
scan
->
update
)
{
if
(
scan_attach
.
count
(
t
->
op
))
{
CHECK
(
scan_attach
.
at
(
t
->
op
).
first
.
same_as
(
s
->
op
))
<<
"Scan update tensor can only belong to one scan"
;
}
else
{
scan_attach
[
t
->
op
]
=
std
::
make_pair
(
s
->
op
,
false
);
}
}
}
// reverse the post DFS order.
// reverse the post DFS order.
for
(
size_t
i
=
sch
->
stages
.
size
();
i
!=
0
;
--
i
)
{
for
(
size_t
i
=
sch
->
stages
.
size
();
i
!=
0
;
--
i
)
{
Stage
s
=
sch
->
stages
[
i
-
1
];
Stage
s
=
sch
->
stages
[
i
-
1
];
// no need to specify place holder op.
// no need to specify place holder op.
if
(
s
->
op
.
as
<
PlaceholderOpNode
>
())
continue
;
if
(
s
->
op
.
as
<
PlaceholderOpNode
>
())
continue
;
if
(
s
->
attach_type
==
kInline
)
{
if
(
scan_attach
.
count
(
s
->
op
))
{
CHECK
(
s
->
attach_type
==
kNone
||
s
->
attach_type
==
kInline
)
<<
"Cannot specify compute_at for scan's init/update"
;
CHECK
(
body
.
defined
());
const
auto
&
p
=
scan_attach
.
at
(
s
->
op
);
InjectScanStep
mu
(
s
,
p
.
first
,
dom_map
,
p
.
second
);
body
=
mu
.
Mutate
(
body
);
CHECK
(
mu
.
found_attach
)
<<
"did not find attachment point for scan.init/update"
;
}
else
if
(
s
->
attach_type
==
kInline
)
{
body
=
InjectInline
(
s
->
op
,
body
);
body
=
InjectInline
(
s
->
op
,
body
);
}
else
if
(
s
->
attach_type
==
kRoot
||
s
->
attach_type
==
kNone
)
{
}
else
if
(
s
->
attach_type
==
kRoot
||
s
->
attach_type
==
kNone
)
{
body
=
MakePipeline
(
s
,
dom_map
,
body
);
body
=
MakePipeline
(
s
,
dom_map
,
body
);
}
else
if
(
s
->
attach_type
==
kScope
)
{
}
else
if
(
s
->
attach_type
==
kScope
)
{
CHECK
(
body
.
defined
());
CHECK
(
body
.
defined
());
Inject
Realize
mutator
(
s
,
dom_map
);
Inject
Attach
mutator
(
s
,
dom_map
);
body
=
mutator
.
Mutate
(
body
);
body
=
mutator
.
Mutate
(
body
);
CHECK
(
mutator
.
found_attach
)
CHECK
(
mutator
.
found_attach
)
<<
"did not find attachment point"
;
<<
"did not find attachment point"
;
}
}
}
}
return
body
;
SchedulePostProc
post_proc
;
post_proc
.
Init
(
sch
);
return
post_proc
.
Mutate
(
body
);
}
}
}
// namespace schedule
}
// namespace schedule
...
...
tests/python/integration/test_scan.py
0 → 100644
View file @
820a8597
import
tvm
import
numpy
as
np
def
test_scan
():
m
=
tvm
.
Var
(
"m"
)
n
=
tvm
.
Var
(
"n"
)
t
=
tvm
.
IterVar
((
1
,
m
),
name
=
"t"
)
X
=
tvm
.
placeholder
((
m
,
n
),
name
=
"X"
)
s_state
=
tvm
.
placeholder
((
m
,
n
))
s_init
=
tvm
.
compute
((
1
,
n
),
lambda
_
,
i
:
X
[
0
,
i
])
s_update
=
tvm
.
compute
((
n
,),
lambda
i
:
s_state
[
t
-
1
,
i
]
+
X
[
t
,
i
])
res
=
tvm
.
scan
(
t
,
s_init
,
s_update
,
s_state
)
# schedule
s
=
tvm
.
Schedule
(
res
.
op
)
num_thread
=
256
block_x
=
tvm
.
IterVar
(
thread_tag
=
"blockIdx.x"
)
thread_x
=
tvm
.
IterVar
((
0
,
num_thread
),
thread_tag
=
"threadIdx.x"
)
_
,
x
=
s
[
s_init
]
.
split
(
s_init
.
op
.
axis
[
1
],
factor
=
num_thread
,
outer
=
block_x
)
_
,
x
=
s
[
s_init
]
.
split
(
x
,
outer
=
thread_x
)
_
,
x
=
s
[
s_update
]
.
split
(
s_update
.
op
.
axis
[
0
],
factor
=
num_thread
,
outer
=
block_x
)
_
,
x
=
s
[
s_update
]
.
split
(
x
,
outer
=
thread_x
)
# one line to build the function.
def
check_device
(
target
):
codes
=
[]
fscan
=
tvm
.
build
(
s
,
[
X
,
res
],
target
,
record_codes
=
codes
,
name
=
"myscan"
)
if
target
==
"cuda"
:
ctx
=
tvm
.
gpu
(
0
)
else
:
ctx
=
tvm
.
cl
(
0
)
if
not
ctx
.
enabled
:
return
for
c
in
codes
[
1
:]:
print
(
c
)
# launch the kernel.
n
=
1024
m
=
10
a_np
=
np
.
random
.
uniform
(
size
=
(
m
,
n
))
.
astype
(
res
.
dtype
)
a
=
tvm
.
nd
.
array
(
a_np
,
ctx
)
b
=
tvm
.
nd
.
array
(
np
.
zeros
((
m
,
n
),
dtype
=
res
.
dtype
),
ctx
)
fscan
(
a
,
b
)
np
.
testing
.
assert_allclose
(
b
.
asnumpy
(),
np
.
cumsum
(
a_np
,
axis
=
0
))
tvm
.
init_opencl
()
check_device
(
"cuda"
)
if
__name__
==
"__main__"
:
test_scan
()
tests/python/unittest/test_lang_tensor.py
View file @
820a8597
...
@@ -34,6 +34,20 @@ def test_tensor_reduce():
...
@@ -34,6 +34,20 @@ def test_tensor_reduce():
assert
(
str
(
C_loaded
)
==
str
(
C
))
assert
(
str
(
C_loaded
)
==
str
(
C
))
def
test_tensor_scan
():
m
=
tvm
.
Var
(
"m"
)
n
=
tvm
.
Var
(
"n"
)
t
=
tvm
.
IterVar
((
1
,
m
),
"t"
)
x
=
tvm
.
placeholder
((
m
,
n
))
s
=
tvm
.
placeholder
((
m
,
n
))
res
=
tvm
.
scan
(
t
,
tvm
.
compute
((
1
,
n
),
lambda
_
,
i
:
x
[
0
,
i
]),
tvm
.
compute
((
n
,),
lambda
i
:
s
[
t
-
1
,
i
]
+
x
[
t
,
i
]),
s
)
assert
tuple
(
res
.
shape
)
==
(
m
,
n
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_tensor
()
test_tensor
()
test_tensor_reduce
()
test_tensor_reduce
()
test_tensor_scan
()
tests/python/unittest/test_pass_simplify.py
View file @
820a8597
...
@@ -18,9 +18,15 @@ def test_simplify():
...
@@ -18,9 +18,15 @@ def test_simplify():
tvm
.
make
.
Load
(
dtype
,
Ab
.
data
,
i
+
4
)
+
1
,
tvm
.
make
.
Load
(
dtype
,
Ab
.
data
,
i
+
4
)
+
1
,
(
j
+
1
)
*
4
-
4
*
j
+
i
),
(
j
+
1
)
*
4
-
4
*
j
+
i
),
None
)))
None
)))
print
(
stmt
)
stmt
=
tvm
.
ir_pass
.
CanonicalSimplify
(
stmt
)
stmt
=
tvm
.
ir_pass
.
CanonicalSimplify
(
stmt
)
print
(
stmt
)
def
test_basic
():
m
=
tvm
.
Var
(
'm'
)
ret
=
tvm
.
ir_pass
.
CanonicalSimplify
(
tvm
.
make
.
Evaluate
(
m
-
1
))
assert
str
(
ret
.
value
)
==
"(m - 1)"
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_basic
()
test_simplify
()
test_simplify
()
tests/python/unittest/test_schedule_schedule_ops.py
View file @
820a8597
...
@@ -6,13 +6,11 @@ def test_schedule0():
...
@@ -6,13 +6,11 @@ def test_schedule0():
l
=
tvm
.
Var
(
'l'
)
l
=
tvm
.
Var
(
'l'
)
A
=
tvm
.
placeholder
((
m
,
l
),
name
=
'A'
)
A
=
tvm
.
placeholder
((
m
,
l
),
name
=
'A'
)
A1
=
tvm
.
compute
((
m
,
l
),
lambda
i
,
j
:
A
[
i
,
j
],
name
=
'A1'
)
A1
=
tvm
.
compute
((
m
,
l
),
lambda
i
,
j
:
A
[
i
,
j
],
name
=
'A1'
)
s
=
tvm
.
Schedule
(
A1
.
op
)
s
=
tvm
.
Schedule
(
A1
.
op
)
bounds
=
tvm
.
schedule
.
InferBound
(
s
)
bounds
=
tvm
.
schedule
.
InferBound
(
s
)
assert
isinstance
(
bounds
,
tvm
.
collections
.
Map
)
assert
isinstance
(
bounds
,
tvm
.
collections
.
Map
)
stmt
=
tvm
.
schedule
.
ScheduleOps
(
s
,
bounds
)
stmt
=
tvm
.
schedule
.
ScheduleOps
(
s
,
bounds
)
print
(
stmt
)
def
test_schedule1
():
def
test_schedule1
():
m
=
tvm
.
Var
(
'm'
)
m
=
tvm
.
Var
(
'm'
)
...
@@ -25,7 +23,7 @@ def test_schedule1():
...
@@ -25,7 +23,7 @@ def test_schedule1():
bounds
=
tvm
.
schedule
.
InferBound
(
s
)
bounds
=
tvm
.
schedule
.
InferBound
(
s
)
assert
isinstance
(
bounds
,
tvm
.
collections
.
Map
)
assert
isinstance
(
bounds
,
tvm
.
collections
.
Map
)
stmt
=
tvm
.
schedule
.
ScheduleOps
(
s
,
bounds
)
stmt
=
tvm
.
schedule
.
ScheduleOps
(
s
,
bounds
)
print
(
stmt
)
def
test_schedule2
():
def
test_schedule2
():
m
=
tvm
.
Var
(
'm'
)
m
=
tvm
.
Var
(
'm'
)
...
@@ -40,8 +38,28 @@ def test_schedule2():
...
@@ -40,8 +38,28 @@ def test_schedule2():
bounds
=
tvm
.
schedule
.
InferBound
(
s
)
bounds
=
tvm
.
schedule
.
InferBound
(
s
)
assert
isinstance
(
bounds
,
tvm
.
collections
.
Map
)
assert
isinstance
(
bounds
,
tvm
.
collections
.
Map
)
stmt
=
tvm
.
schedule
.
ScheduleOps
(
s
,
bounds
)
stmt
=
tvm
.
schedule
.
ScheduleOps
(
s
,
bounds
)
def
test_schedule_scan
():
m
=
tvm
.
Var
(
"m"
)
n
=
tvm
.
Var
(
"n"
)
l
=
tvm
.
Var
(
"l"
)
t
=
tvm
.
IterVar
((
1
,
m
),
name
=
"t"
)
x
=
tvm
.
compute
((
m
,
n
),
lambda
i
,
j
:
tvm
.
const
(
1
,
"float32"
),
name
=
"x"
)
s_state
=
tvm
.
placeholder
((
m
,
n
))
s_init
=
tvm
.
compute
((
1
,
n
),
lambda
_
,
i
:
x
[
0
,
i
])
s_update
=
tvm
.
compute
((
n
,),
lambda
i
:
s_state
[
t
-
1
,
i
]
+
x
[
t
,
i
])
res
=
tvm
.
scan
(
t
,
s_init
,
s_update
,
s_state
)
assert
tuple
(
res
.
shape
)
==
(
m
,
n
)
s
=
tvm
.
Schedule
(
res
.
op
)
s
.
normalize
()
bounds
=
tvm
.
schedule
.
InferBound
(
s
)
assert
(
bounds
[
res
.
op
.
scan_axis
]
.
min
.
value
==
1
)
stmt
=
tvm
.
schedule
.
ScheduleOps
(
s
,
bounds
)
print
(
stmt
)
print
(
stmt
)
def
test_auto_inline
():
def
test_auto_inline
():
m
=
tvm
.
Var
(
'm'
)
m
=
tvm
.
Var
(
'm'
)
n
=
tvm
.
Var
(
'n'
)
n
=
tvm
.
Var
(
'n'
)
...
@@ -55,10 +73,10 @@ def test_auto_inline():
...
@@ -55,10 +73,10 @@ def test_auto_inline():
tvm
.
schedule
.
AutoInlineElemWise
(
s
)
tvm
.
schedule
.
AutoInlineElemWise
(
s
)
bounds
=
tvm
.
schedule
.
InferBound
(
s
)
bounds
=
tvm
.
schedule
.
InferBound
(
s
)
stmt
=
tvm
.
schedule
.
ScheduleOps
(
s
,
bounds
)
stmt
=
tvm
.
schedule
.
ScheduleOps
(
s
,
bounds
)
print
(
stmt
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_schedule_scan
()
test_schedule0
()
test_schedule0
()
test_schedule1
()
test_schedule1
()
test_schedule2
()
test_schedule2
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment