Commit b90620ea, authored Oct 05, 2018 by ziheng, committed by Tianqi Chen on Oct 05, 2018
[LANG] Generalize compute to tensor region (#1476)
parent 3d62cf7c
Showing 17 changed files with 1059 additions and 131 deletions
3rdparty/dmlc-core                                   +1    -1
include/tvm/expr.h                                   +2    -0
include/tvm/operation.h                              +70   -2
include/tvm/tensor_intrin.h                          +53   -0
python/tvm/api.py                                    +29   -10
python/tvm/tensor.py                                 +12   -0
python/tvm/tensor_intrin.py                          +26   -2
src/api/api_lang.cc                                  +20   -0
src/lang/tensor.cc                                   +36   -8
src/op/compute_op.cc                                 +35   -0
src/op/compute_op.h                                  +16   -1
src/op/tensor_compute_op.cc                          +361  -0
src/op/tensorize.cc                                  +0    -45
src/pass/arg_binder.cc                               +3    -1
src/schedule/schedule_dataflow_rewrite.cc            +191  -61
tests/python/unittest/test_lang_tensor.py            +74   -0
tests/python/unittest/test_schedule_schedule_ops.py  +130  -0
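The user-facing effect of this change is that tvm.compute can now accept an fcompute whose body is a call to a tensor intrinsic over tensor regions (slices), not only a scalar expression. A minimal sketch of that usage, modeled on the new test_tensor_compute1 added later in this commit; the vadd intrinsic, the shapes, and the names are illustrative only:

import tvm

# Declare a small vector-add tensor intrinsic (illustrative; mirrors the new unit test).
def intrin_vadd(n):
    x = tvm.placeholder((n,), name='x')
    y = tvm.placeholder((n,), name='y')
    z = tvm.compute(x.shape, lambda i: x[i] + y[i], name='z')
    def intrin_func(ins, outs):
        ib = tvm.ir_builder.create()
        ib.emit(tvm.call_extern(outs[0].dtype, 'vadd',
                                ins[0].access_ptr("r"),
                                ins[1].access_ptr("r"),
                                outs[0].access_ptr("w")))
        return ib.get()
    with tvm.build_config(offset_factor=n):
        return tvm.decl_tensor_intrin(z.op, intrin_func)

factor = 16
m = 1024
vadd = intrin_vadd(factor)
A = tvm.placeholder((m // factor, factor), name="A")
B = tvm.placeholder((m // factor, factor), name="B")
# The lambda indexes only the outer (schedulable) dimension; the inner dimension
# is covered by the intrinsic through the 0:factor regions.
C = tvm.compute((m // factor, factor),
                lambda i: vadd(A[i, 0:factor], B[i, 0:factor]), name="C")
s = tvm.create_schedule(C.op)
print(tvm.lower(s, [A, B, C], simple_mode=True))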
3rdparty/dmlc-core @ 946a5401
-Subproject commit 4f0564ec769477c66d480dd966088f172050c874
+Subproject commit 946a54012d0c390675ab5b46cd990838d4183d6f
include/tvm/expr.h
@@ -108,6 +108,8 @@ class Range : public HalideIR::IR::Range {
   TVM_DLL static Range make_by_min_extent(Expr min, Expr extent);
 };

+using Region = Array<Range>;
+
 /*!
  * \brief Type of iteration variable.
  *  Each IterVar have a specific type.
include/tvm/operation.h
@@ -49,7 +49,7 @@ class OperationNode : public FunctionBaseNode {
   }
   /*!
    * \return The list of iteration variable at root
-   * \note root_iter_vars dedides the shape of the outputs.
+   * \note root_iter_vars decides the shape of the outputs.
    */
   virtual Array<IterVar> root_iter_vars() const = 0;
   /*!
@@ -240,6 +240,74 @@ class TVM_DLL ComputeOpNode : public OperationNode {
 };

 /*!
  * \brief A TensorCompute op that computes a tensor with a tensor intrinsic.
  */
 class TensorComputeOpNode : public OperationNode {
  public:
   /*! \brief IterVar on each axis */
   Array<IterVar> axis;
   /*! \brief IterVar on each reduction axis, if the intrin will use the reduce axis */
   Array<IterVar> reduce_axis;
   /*! \brief number of axes that can be scheduled */
   int schedulable_ndim;
   /*! \brief TensorIntrin used to compute */
   TensorIntrin intrin;
   /*! \brief input tensors of intrin */
   Array<Tensor> inputs;
   /*! \brief region of input tensors */
   Array<Region> input_regions;
   /*! \brief constructor */
   TensorComputeOpNode() {}
   // override functions
   int num_outputs() const final;
   Array<IterVar> root_iter_vars() const final;
   Type output_dtype(size_t i) const final;
   Array<Expr> output_shape(size_t i) const final;
   Array<Tensor> InputTensors() const final;
   Operation ReplaceInputs(
       const Operation& self,
       const std::unordered_map<Tensor, Tensor>& rmap) const final;
   void PropBoundToInputs(
       const Operation& self,
       const std::unordered_map<const Variable*, IntSet>& dom_map,
       std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
   void GatherBound(
       const Operation& self,
       const std::unordered_map<Tensor, TensorDom>& tensor_dom,
       std::unordered_map<IterVar, Range>* out_dom_map) const final;
   Stmt BuildRealize(
       const Stage& stage,
       const std::unordered_map<IterVar, Range>& realize_map,
       const Stmt& body) const final;
   Stmt BuildProvide(
       const Stage& stage,
       const std::unordered_map<IterVar, Range>& dom_map,
       bool debug_keep_trivial_loop) const final;

   void VisitAttrs(AttrVisitor* v) final {
     v->Visit("name", &name);
     v->Visit("tag", &tag);
     v->Visit("axis", &axis);
     v->Visit("reduce_axis", &reduce_axis);
     v->Visit("schedulable_ndim", &schedulable_ndim);
     v->Visit("intrin", &intrin);
     v->Visit("inputs", &inputs);
     v->Visit("input_regions", &input_regions);
   }
   static Operation make(std::string name,
                         std::string tag,
                         Array<IterVar> axis,
                         Array<IterVar> reduce_axis,
                         int schedulable_ndim,
                         TensorIntrin intrin,
                         Array<Tensor> tensors,
                         Array<Region> regions);

   static constexpr const char* _type_key = "TensorComputeOp";
   TVM_DECLARE_NODE_TYPE_INFO(TensorComputeOpNode, OperationNode);
 };

 /*!
  * \brief Symbolic scan.
  */
 class ScanOpNode : public OperationNode {
@@ -326,7 +394,7 @@ class ExternOpNode : public OperationNode {
  public:
   /*! \brief The input tensors */
   Array<Tensor> inputs;
-  /*! \brief Symbolic placeholder representationinputs */
+  /*! \brief Symbolic placeholder representation of inputs */
   Array<Buffer> input_placeholders;
   /*! \brief Symbolic placeholder representation of outputs */
   Array<Buffer> output_placeholders;
include/tvm/tensor_intrin.h
@@ -89,5 +89,58 @@ class TensorIntrinNode : public Node {
 inline const TensorIntrinNode* TensorIntrin::operator->() const {
   return static_cast<const TensorIntrinNode*>(node_.get());
 }

 // Internal node container of tensor intrinsic calling.
 class TensorIntrinCallNode;

 /*! \brief Tensor intrinsic calling node. */
 class TensorIntrinCall : public NodeRef {
  public:
   TensorIntrinCall() {}
   explicit TensorIntrinCall(NodePtr<Node> n) : NodeRef(n) {}
   /*!
    * \brief access the internal node container
    * \return the pointer to the internal node container
    */
   inline const TensorIntrinCallNode* operator->() const;
   /*! \brief specify container node */
   using ContainerType = TensorIntrinCallNode;
 };

 class TensorIntrinCallNode : public Node {
  public:
   /*! \brief the tensor intrinsic */
   TensorIntrin intrin;
   /*! \brief input tensors of the intrinsic */
   Array<Tensor> tensors;
   /*! \brief regions of input tensors */
   Array<Region> regions;
   /*!
    * \brief IterVar on each reduction axis, if the
    * intrin will use the reduce axis
    */
   Array<IterVar> reduce_axis;

   void VisitAttrs(AttrVisitor* v) final {
     v->Visit("intrin", &intrin);
     v->Visit("tensors", &tensors);
     v->Visit("regions", &regions);
     v->Visit("reduce_axis", &reduce_axis);
   }
   static TensorIntrinCall make(TensorIntrin intrin,
                                Array<Tensor> tensors,
                                Array<Region> regions,
                                Array<IterVar> reduce_axis);

   static constexpr const char* _type_key = "TensorIntrinCall";
   TVM_DECLARE_NODE_TYPE_INFO(TensorIntrinCallNode, Node);
 };

 inline const TensorIntrinCallNode* TensorIntrinCall::operator->() const {
   return static_cast<const TensorIntrinCallNode*>(node_.get());
 }

 }  // namespace tvm
 #endif  // TVM_TENSOR_INTRIN_H_
python/tvm/api.py
@@ -243,24 +243,43 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None):
             raise ValueError("nested tag is not allowed for now")
         tag = _tag.TagScope.get_current().tag
     shape = (shape,) if isinstance(shape, _expr.Expr) else shape
     # for python3
     shape = tuple([int(s) if isinstance(s, float) else s for s in shape])
     ndim = len(shape)
     code = fcompute.__code__

-    if fcompute.__code__.co_argcount == 0:
+    out_ndim = ndim
+    if code.co_argcount == 0:
         arg_names = ["i%d" % i for i in range(ndim)]
     else:
         arg_names = code.co_varnames[:code.co_argcount]
+        out_ndim = code.co_argcount

-    if ndim != len(arg_names):
+    if out_ndim != len(arg_names):
         raise ValueError("fcompute do not match dimension, ndim=%d" % ndim)

-    dim_var = [_IterVar((0, s), x, 0) for x, s in zip(arg_names, shape)]
+    dim_var = [_IterVar((0, s), x, 0) for x, s in zip(arg_names, shape[:out_ndim])]
     body = fcompute(*[v.var for v in dim_var])
-    if not isinstance(body, (list, tuple)):
-        body = [body]
-    body = convert(body)
-    op_node = _api_internal._ComputeOp(
-        name, tag, attrs, dim_var, body)
+
+    if isinstance(body, _tensor.TensorIntrinCall):
+        for i, s in enumerate(shape[out_ndim:]):
+            var_name = "ax" + str(i)
+            dim_var.append(_IterVar((0, s), var_name, 4))
+        op_node = _api_internal._TensorComputeOp(name,
+                                                 tag,
+                                                 dim_var,
+                                                 body.reduce_axis,
+                                                 out_ndim,
+                                                 body.intrin,
+                                                 body.tensors,
+                                                 body.regions)
+    else:
+        if not isinstance(body, (list, tuple)):
+            body = [body]
+        body = convert(body)
+        op_node = _api_internal._ComputeOp(
+            name, tag, attrs, dim_var, body)
+
     num = op_node.num_outputs
     outputs = tuple(op_node.output(i) for i in range(num))
     return outputs[0] if num == 1 else outputs
@@ -529,14 +548,14 @@ def decl_buffer(shape,
     dtype = float32 if dtype is None else dtype
     strides = () if strides is None else strides
     if offset_factor != 0 and elem_offset is None:
-        elem_offset = var('%s_elem_offset' % name, shape[0].dtype)
+        shape_dtype = shape[0].dtype if hasattr(shape[0], "dtype") else "int32"
+        elem_offset = var('%s_elem_offset' % name, shape_dtype)
     if data is None:
         data = var(name, "handle")
     return _api_internal._Buffer(
         data, dtype, shape, strides, elem_offset, name, scope,
         data_alignment, offset_factor)

 def _IterVar(dom, name, iter_type, thread_tag=''):
     """Internal function to create IterVar
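As the hunk above shows, compute() now counts the lambda's arguments as out_ndim, builds IterVars only for shape[:out_ndim], and, when the body is a TensorIntrinCall, appends extra "ax*" IterVars for the trailing dimensions and creates a _TensorComputeOp instead of a _ComputeOp. A small sketch of what that looks like from user code, continuing the C from the earlier vadd example; attribute access on the op assumes the visited node fields are exposed through the FFI as usual:

# C was built from a TensorIntrinCall, so its op is a TensorComputeOp.
assert isinstance(C.op, tvm.tensor.TensorComputeOp)
# The lambda took one argument, so one dimension is schedulable; the trailing
# dimension (an "ax0" IterVar internally) is handled by the intrinsic.
print(C.op.schedulable_ndim)                  # expected: 1
print([iv.var.name for iv in C.op.axis])      # outer index plus the tensorized axis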
python/tvm/tensor.py
@@ -30,6 +30,11 @@ class TensorSlice(NodeGeneric, _expr.ExprOp):
         """Data content of the tensor."""
         return self.tensor.dtype

+@register_node
+class TensorIntrinCall(NodeBase):
+    """Intermediate structure for calling a tensor intrinsic."""
+    pass
+
 itervar_cls = None
@@ -106,6 +111,7 @@ class Tensor(NodeBase, _expr.ExprOp):
         return "%s.v%d" % (op.name, self.value_index)

+
 class Operation(NodeBase):
     """Represent an operation that generate a tensor"""
@@ -156,6 +162,12 @@ class ComputeOp(Operation):

+@register_node
+class TensorComputeOp(Operation):
+    """Tensor operation."""
+    pass
+
 @register_node
 class ScanOp(Operation):
     """Scan operation."""
     @property
python/tvm/tensor_intrin.py
@@ -6,9 +6,25 @@ from . import expr as _expr
 from . import stmt as _stmt
 from . import make as _make
 from . import tensor as _tensor
 from . import schedule as _schedule
 from .build_module import current_build_config
 from ._ffi.node import NodeBase, register_node

 def _get_region(tslice):
     region = []
     for idx in tslice.indices:
         if isinstance(idx, slice):
             assert idx.step is None
             region.append(_api.Range(idx.start, idx.stop))
         else:
             if isinstance(idx, _schedule.IterVar):
                 begin = idx.var
             else:
                 begin = idx
             region.append(_make.range_by_min_extent(begin, 1))
     return region

 @register_node
 class TensorIntrin(NodeBase):
     """Tensor intrinsic functions for certain computation.
@@ -17,8 +33,16 @@ class TensorIntrin(NodeBase):
     --------
     decl_tensor_intrin: Construct a TensorIntrin
     """
-    pass
+    def __call__(self, *args, **kwargs):
+        tensors = [x.tensor for x in args]
+        regions = [_get_region(x) for x in args]
+        reduce_axis = []
+        if "reduce_axis" in kwargs:
+            reduce_axis = kwargs["reduce_axis"]
+            if not isinstance(reduce_axis, (list, tuple)):
+                reduce_axis = [reduce_axis]
+            reduce_axis = _api.convert(reduce_axis)
+        return _api_internal._TensorIntrinCall(self, tensors, regions, reduce_axis)

 def decl_tensor_intrin(op,
                        fcompute,
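Calling a declared TensorIntrin on tensor slices therefore yields a TensorIntrinCall: each slice becomes a Range(start, stop), a plain index or IterVar becomes a unit-extent range, and an optional reduce_axis keyword records the reduction variables. A short sketch under those assumptions, reusing vadd, A, and B from the earlier example (slice steps must be None, and attribute access on the call assumes the visited fields are FFI-visible):

# One region per argument: index 3 becomes a unit-extent range, 0:16 becomes Range(0, 16).
call = vadd(A[3, 0:16], B[3, 0:16])
assert isinstance(call, tvm.tensor.TensorIntrinCall)
print(call.tensors)     # the underlying tensors, A and B
print(call.regions)     # the per-argument access regions

# A reduction intrinsic can name its reduction axis explicitly, e.g. (illustrative):
# k = tvm.reduce_axis((0, K), name='k')
# call = vgemm(X[i, k, 0:16, 0:16], Y[j, k, 0:16, 0:16], reduce_axis=k)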
src/api/api_lang.cc
@@ -239,6 +239,14 @@ TVM_REGISTER_API("_TensorIntrin")
                                   args[6]);
   });

 TVM_REGISTER_API("_TensorIntrinCall")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
     *ret = TensorIntrinCallNode::make(args[0],
                                       args[1],
                                       args[2],
                                       args[3]);
   });

 TVM_REGISTER_API("_TensorEqual")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
     *ret = args[0].operator Tensor() == args[1].operator Tensor();
@@ -278,6 +286,18 @@ TVM_REGISTER_API("_ScanOp")
                              args[7]);
   });

 TVM_REGISTER_API("_TensorComputeOp")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
     *ret = TensorComputeOpNode::make(args[0],
                                      args[1],
                                      args[2],
                                      args[3],
                                      args[4],
                                      args[5],
                                      args[6],
                                      args[7]);
   });

 TVM_REGISTER_API("_ExternOp")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
     *ret = ExternOpNode::make(args[0],
src/lang/tensor.cc
@@ -10,6 +10,8 @@
 namespace tvm {

 // Tensor
 Expr Tensor::operator()(Array<Var> indices) const {
   Array<Expr> arr(indices.begin(), indices.end());
   return operator()(arr);
@@ -26,6 +28,15 @@ Expr Tensor::operator()(Array<Expr> indices) const {
   return n;
 }

+Tensor Operation::output(size_t i) const {
+  auto node = make_node<TensorNode>();
+  node->op = *this;
+  node->value_index = i;
+  node->dtype = (*this)->output_dtype(i);
+  node->shape = (*this)->output_shape(i);
+  return Tensor(node);
+}
+
 Tensor TensorNode::make(Array<Expr> shape,
                         Type dtype,
                         Operation op,
@@ -46,14 +57,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 TVM_REGISTER_NODE_TYPE(TensorNode);

-Tensor Operation::output(size_t i) const {
-  auto node = make_node<TensorNode>();
-  node->op = *this;
-  node->value_index = i;
-  node->dtype = (*this)->output_dtype(i);
-  node->shape = (*this)->output_shape(i);
-  return Tensor(node);
-}

 // TensorIntrin
 TensorIntrin TensorIntrinNode::make(std::string name,
                                     Operation op,
@@ -79,4 +84,27 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
   });

 TVM_REGISTER_NODE_TYPE(TensorIntrinNode);

 // TensorIntrinCall
 TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin,
                                             Array<Tensor> tensors,
                                             Array<Region> regions,
                                             Array<IterVar> reduce_axis) {
   auto n = make_node<TensorIntrinCallNode>();
   n->intrin = std::move(intrin);
   n->tensors = std::move(tensors);
   n->regions = std::move(regions);
   n->reduce_axis = std::move(reduce_axis);
   return TensorIntrinCall(n);
 }

 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 .set_dispatch<TensorIntrinCallNode>([](const TensorIntrinCallNode* n, IRPrinter* p) {
     p->stream << "TensorIntrinCall(intrin=" << n->intrin << ", " << n << ")";
   });

 TVM_REGISTER_NODE_TYPE(TensorIntrinCallNode);

 }  // namespace tvm
src/op/compute_op.cc
@@ -13,6 +13,7 @@
 #include "compute_op.h"
 #include "op_util.h"
 #include "../schedule/message_passing.h"
+#include "../arithmetic/compute_expr.h"

 namespace tvm {
@@ -545,4 +546,38 @@ static void VerifyComputeOp(const ComputeOpNode* op) {
   v.Run();
 }

 Stmt TransformUpdate(const Stage& stage,
                      const std::unordered_map<IterVar, Range>& dom_map,
                      const ComputeLoopNest& n,
                      Stmt body,
                      Stmt update) {
   Array<Expr> conds;
   std::unordered_set<const Variable*> banned;
   for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
     IterVar iv = stage->leaf_iter_vars[i];
     auto iit = stage->iter_var_attrs.find(iv);
     if (iit != stage->iter_var_attrs.end()) {
       const IterVarAttr& attr = (*iit).second;
       if (attr->iter_type == kTensorized) {
         break;
       }
     }
     if (iv->iter_type == kCommReduce) {
       auto vit = dom_map.find(iv);
       CHECK(vit != dom_map.end());
       const Range& vrange = vit->second;
       conds.push_back(likely(iv->var > vrange->min));
       banned.insert(iv->var.get());
     }
   }
   for (const Expr& pred : n.main_predicates) {
     if (ir::ExprUseVar(pred, banned)) {
       LOG(FATAL) << "Tensorize update transform failed, the condition "
                  << pred << " has a conflict with the reset condition";
     }
   }

   return IfThenElse::make(arith::ComputeReduce<ir::Or>(conds, const_true(1)),
                           update, body);
 }

 }  // namespace tvm
src/op/compute_op.h
@@ -14,7 +14,7 @@
 namespace tvm {
 // loop nest structure for general compute
-// This the the loop nest structured used in compute.
+// This the loop nest structured used in compute.
 // Does not include the loop body.
 struct ComputeLoopNest {
   // The common number of loops between init and main
@@ -73,6 +73,21 @@ Stmt MakeTensorize(const ComputeOpNode* self,
                    const Stage& stage,
                    const std::unordered_map<IterVar, Range>& dom_map,
                    bool debug_keep_trivial_loop);

 /*!
  * \brief Transform the update part when there is no init func in tensorizing
  * \param stage The stage for tensorizing.
  * \param dom_map The range of each iter var.
  * \param n The loop nest structured used in compute.
  * \param body The body func in tensorize intrin
  * \param update The update func in tensorize intrin
  * \return Transformed result.
  */
 Stmt TransformUpdate(const Stage& stage,
                      const std::unordered_map<IterVar, Range>& dom_map,
                      const ComputeLoopNest& n,
                      Stmt body,
                      Stmt update);

 }  // namespace tvm
 #endif  // TVM_OP_COMPUTE_OP_H_
src/op/tensor_compute_op.cc (new file, 0 → 100644)
/*!
 *  Copyright (c) 2017 by Contributors
 * \brief Tensor Compute Op.
 * \file tensor_compute_op.cc
 */
#include <tvm/operation.h>
#include <tvm/arithmetic.h>
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_pass.h>
#include <unordered_set>
#include "./op_util.h"
#include "./compute_op.h"
#include "../arithmetic/compute_expr.h"

namespace tvm {
using namespace ir;
// TensorComputeOpNode
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TensorComputeOpNode>([](const TensorComputeOpNode* op, IRPrinter* p) {
    p->stream << "tensor_compute_op(" << op->name << ", " << op << ")";
  });

TVM_REGISTER_NODE_TYPE(TensorComputeOpNode);

int TensorComputeOpNode::num_outputs() const {
  return static_cast<int>(this->intrin->buffers.size() - this->inputs.size());
}

Array<IterVar> TensorComputeOpNode::root_iter_vars() const {
  Array<IterVar> ret = axis;
  for (IterVar iv : reduce_axis) {
    ret.push_back(iv);
  }
  return ret;
}

Type TensorComputeOpNode::output_dtype(size_t i) const {
  return this->intrin->buffers[this->inputs.size() + i]->dtype;
}

Array<Expr> TensorComputeOpNode::output_shape(size_t i) const {
  Array<Expr> shape;
  for (const auto& ivar : this->axis) {
    shape.push_back(ivar->dom->extent);
  }
  return shape;
}

Operation TensorComputeOpNode::make(std::string name,
                                    std::string tag,
                                    Array<IterVar> axis,
                                    Array<IterVar> reduce_axis,
                                    int schedulable_ndim,
                                    TensorIntrin intrin,
                                    Array<Tensor> tensors,
                                    Array<Region> regions) {
  auto n = make_node<TensorComputeOpNode>();
  n->name = std::move(name);
  n->tag = std::move(tag);
  n->axis = std::move(axis);
  n->reduce_axis = std::move(reduce_axis);
  n->schedulable_ndim = std::move(schedulable_ndim);
  n->intrin = std::move(intrin);
  n->inputs = std::move(tensors);
  n->input_regions = std::move(regions);
  return Operation(n);
}

Array<Tensor> TensorComputeOpNode::InputTensors() const {
  return inputs;
}

Operation TensorComputeOpNode::ReplaceInputs(
    const Operation& self,
    const std::unordered_map<Tensor, Tensor>& rmap) const {
  CHECK_EQ(self.operator->(), this);
  auto n = make_node<TensorComputeOpNode>(*this);
  auto intrin = make_node<TensorIntrinNode>(*(this->intrin.operator->()));
  intrin->body = op::ReplaceTensor(this->intrin->body, rmap);
  if (intrin->reduce_init.defined()) {
    intrin->reduce_init = op::ReplaceTensor(this->intrin->reduce_init, rmap);
  }
  if (intrin->reduce_update.defined()) {
    intrin->reduce_update = op::ReplaceTensor(this->intrin->reduce_update, rmap);
  }
  for (size_t i = 0; i < n->inputs.size(); ++i) {
    Tensor t = n->inputs[i];
    if (rmap.count(t)) {
      n->inputs.Set(i, rmap.at(t));
    }
  }

  if (intrin->body.same_as(n->intrin->body) &&
      intrin->reduce_init.same_as(n->intrin->reduce_init) &&
      intrin->reduce_update.same_as(n->intrin->reduce_update) &&
      inputs.same_as(n->inputs)) {
    return self;
  } else {
    n->intrin = TensorIntrin(intrin);
    return Operation(n);
  }
}

void TensorComputeOpNode::PropBoundToInputs(
    const Operation& self,
    const std::unordered_map<const Variable*, IntSet>& dom_map,
    std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
  for (size_t i = 0; i < this->inputs.size(); ++i) {
    Tensor t = this->inputs[i];
    Region region = input_regions[i];

    auto it = out_dom_map->find(t);
    if (it == out_dom_map->end()) continue;
    TensorDom& dom = it->second;
    for (size_t j = 0; j < t.ndim(); ++j) {
      dom.data[j].emplace_back(EvalSet(region[j], dom_map));
    }
  }
}

void TensorComputeOpNode::GatherBound(
    const Operation& self,
    const std::unordered_map<Tensor, TensorDom>& tensor_dom,
    std::unordered_map<IterVar, Range>* out_dom_map) const {
  const TensorDom& tdom = tensor_dom.at(self.output(0));
  for (size_t i = 0; i < this->axis.size(); ++i) {
    Range r = arith::Union(tdom.data.at(i)).cover_range(this->axis[i]->dom);
    CHECK(!out_dom_map->count(this->axis[i]));
    (*out_dom_map)[this->axis[i]] = r;
  }
  for (size_t i = 0; i < this->reduce_axis.size(); ++i) {
    CHECK(!out_dom_map->count(this->reduce_axis[i]));
    (*out_dom_map)[this->reduce_axis[i]] = this->reduce_axis[i]->dom;
  }
}

Stmt TensorComputeOpNode::BuildRealize(
    const Stage& stage,
    const std::unordered_map<IterVar, Range>& realize_map,
    const Stmt& body) const {
  CHECK_EQ(stage->op.get(), this);
  HalideIR::Internal::Region bounds;
  for (IterVar iv : this->axis) {
    bounds.push_back(realize_map.at(iv));
  }
  Stmt realize = body;
  for (int i = this->num_outputs(); i > 0; --i) {
    Tensor t = stage->op.output(i-1);
    realize = ir::Realize::make(t->op, t->value_index,
      t->dtype, bounds, const_true(), realize);
    // alignment requirement, only useful for compute
    for (int i = 0; i < schedulable_ndim; ++i) {
      auto it = stage->iter_var_attrs.find(this->axis[i]);
      if (it != stage->iter_var_attrs.end()) {
        IterVarAttr attr = (*it).second;
        if (attr->dim_align_factor != 0) {
          Array<Expr> tuple = {static_cast<int>(i),
                               attr->dim_align_factor,
                               attr->dim_align_offset};
          realize = ir::AttrStmt::make(
              t, ir::attr::buffer_dim_align,
              Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic),
              realize);
        }
      }
    }
  }
  return realize;
}

ComputeLoopNest MakeLoopNest(
    const TensorComputeOpNode* self,
    const Stage& stage,
    const std::unordered_map<IterVar, Range>& dom_map,
    bool debug_keep_trivial_loop) {
  CHECK_EQ(stage->op.operator->(), self);
  ComputeLoopNest ret;
  // make main loop nest
  ret.main_nest = op::MakeLoopNest(
      stage, dom_map, 0, false, std::unordered_set<IterVar>(), &ret.main_vmap,
      debug_keep_trivial_loop);
  ret.main_predicates = schedule::MakeBoundCheck(
      stage, dom_map, ret.main_vmap, false,
      std::unordered_set<IterVar>());
  for (auto& e : ret.main_predicates) {
    e = likely(e);
  }
  if (stage->store_predicate.defined()) {
    ret.main_predicates.push_back(stage->store_predicate);
  }
  if (self->reduce_axis.size() != 0) {
    // try to find the location to insert the initialization.
    // Fuse the initialization and provide loop when possible.
    std::unordered_map<IterVar, int> update_state;
    for (IterVar iv : self->reduce_axis) {
      update_state[iv] = 2;
    }
    for (int i = 0; i < self->schedulable_ndim; ++i) {
      update_state[self->axis[i]] = 1;
    }
    // find which iter var is related to reduction and which is related to axis.
    schedule::PassDownBitMaskOr(stage, &update_state);
    auto leaf_iter_vars = stage->leaf_iter_vars;
    // first first loop that is related to reduction.
    size_t begin_loop = leaf_iter_vars.size();
    for (size_t i = 0; i < leaf_iter_vars.size(); ++i) {
      auto iv = leaf_iter_vars[i];
      int flag = update_state.at(iv);
      if ((flag & 2) != 0) {
        begin_loop = i; break;
      }
      ret.init_vmap[iv] = ret.main_vmap.at(iv);
    }
    ret.num_common_loop = begin_loop;
    // skip loops that does not relates to axis.
    std::unordered_set<IterVar> skip_iter;
    for (auto kv : update_state) {
      int flag = kv.second;
      if ((flag & 1) == 0) skip_iter.insert(kv.first);
    }
    ret.init_nest = op::MakeLoopNest(
        stage, dom_map, begin_loop, true,
        skip_iter, &(ret.init_vmap), debug_keep_trivial_loop);
    ret.init_predicates = schedule::MakeBoundCheck(
        stage, dom_map, ret.init_vmap, true, skip_iter);
    for (auto& e : ret.init_predicates) {
      e = likely(e);
    }
  } else {
    CHECK_EQ(ret.main_nest.size(), stage->leaf_iter_vars.size() + 1);
    ret.num_common_loop = stage->leaf_iter_vars.size();
  }
  // copy elison here.
  return ret;
}

Stmt TensorComputeOpNode::BuildProvide(
    const Stage& stage,
    const std::unordered_map<IterVar, Range>& dom_map,
    bool debug_keep_trivial_loop) const {
  CHECK_EQ(stage->op.operator->(), this);

  // Start bind data.
  Stmt nop = Evaluate::make(0);
  std::vector<Stmt> input_bind_nest, output_bind_nest;
  Array<Tensor> inputs = this->InputTensors();

  // input binding
  size_t num_inputs = inputs.size();
  for (size_t i = 0; i < num_inputs; ++i) {
    Tensor tensor = inputs[i];
    Region region = this->input_regions[i];
    Buffer buffer = this->intrin->buffers[i];
    Array<NodeRef> bind_spec{buffer, tensor};

    Array<Expr> tuple;
    for (size_t i = 0; i < region.size(); ++i) {
      tuple.push_back(region[i]->min);
      tuple.push_back(region[i]->extent);
    }
    input_bind_nest.emplace_back(AttrStmt::make(
        bind_spec, ir::attr::buffer_bind_scope,
        Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop));
  }

  // output binding
  for (int i = 0; i < this->num_outputs(); ++i) {
    Tensor tensor = stage->op.output(i);
    Buffer buffer = this->intrin->buffers[num_inputs + i];
    Array<NodeRef> bind_spec{buffer, tensor};

    Array<Expr> tuple;
    for (size_t i = 0; i < this->axis.size(); ++i) {
      auto ivar = this->axis[i];
      if (i < static_cast<size_t>(this->schedulable_ndim)) {
        tuple.push_back(ivar->var);
        tuple.push_back(1);
      } else {
        Range dom = ivar->dom;
        tuple.push_back(dom->min);
        tuple.push_back(dom->extent);
      }
    }

    output_bind_nest.emplace_back(AttrStmt::make(
        bind_spec, ir::attr::buffer_bind_scope,
        Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop));
  }

  // Check variable remap
  std::unordered_map<const Variable*, Expr> vmap;
  ir::ArgBinder binder(&vmap);

  size_t tloc = stage->leaf_iter_vars.size();
  ComputeLoopNest n = MakeLoopNest(this, stage, dom_map, debug_keep_trivial_loop);

  if (this->reduce_axis.size() == 0) {
    std::vector<std::vector<Stmt> > nest(
        n.main_nest.begin(), n.main_nest.begin() + tloc + 1);
    nest.emplace_back(op::MakeIfNest(n.main_predicates));
    CHECK_EQ(n.init_predicates.size(), 0U);
    CHECK(this->intrin->body.defined())
        << "Normal store op for intrin " << this << " is not defined";
    Stmt body = MergeNest(output_bind_nest, this->intrin->body);
    body = MergeNest(input_bind_nest, body);
    body = ir::Substitute(body, vmap);
    body = MergeNest(binder.asserts(), body);
    body = op::Substitute(body, n.main_vmap);
    Stmt ret = MergeNest(nest, body);
    return ret;
  } else {
    // Need to split reduction
    CHECK(this->intrin->reduce_update.defined())
        << "Reduction update op is not defined";
    // Need init and update steps
    CHECK_NE(this->reduce_axis.size(), 0U);
    std::vector<std::vector<Stmt> > common(
        n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1);
    std::vector<std::vector<Stmt> > update_nest(
        n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.begin() + tloc + 1);
    update_nest.emplace_back(op::MakeIfNest(n.main_predicates));

    if (this->intrin->reduce_init.defined()) {
      // init nest
      std::vector<std::vector<Stmt> > init_nest(
          n.init_nest.begin(), n.init_nest.begin() + tloc + 1);
      init_nest.emplace_back(op::MakeIfNest(n.init_predicates));
      Stmt init = MergeNest(output_bind_nest, this->intrin->reduce_init);
      init = op::Substitute(init, n.init_vmap);
      init = MergeNest(init_nest, init);
      // The update
      Stmt update = MergeNest(output_bind_nest, this->intrin->reduce_update);
      update = MergeNest(input_bind_nest, update);
      update = ir::Substitute(update, vmap);
      update = MergeNest(binder.asserts(), update);
      update = op::Substitute(update, n.main_vmap);
      update = MergeNest(update_nest, update);
      return MergeNest(common, Block::make(init, update));
    } else {
      // When init op is not available, use body op for reset in the first iter.
      CHECK(this->intrin->body.defined())
          << "Normal body op is not defined";
      Stmt update = TransformUpdate(stage, dom_map, n,
                                    this->intrin->body,
                                    this->intrin->reduce_update);
      update = MergeNest(output_bind_nest, update);
      update = MergeNest(input_bind_nest, update);
      update = ir::Substitute(update, vmap);
      update = MergeNest(binder.asserts(), update);
      update = op::Substitute(update, n.main_vmap);
      update = MergeNest(update_nest, update);
      return MergeNest(common, update);
    }
  }
}
}  // namespace tvm
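Seen from Python, the new op's output shape is the extents of all of its axis IterVars and its output dtype comes from the intrinsic's output buffer (the buffers beyond the inputs). A quick check, continuing the earlier C from the vadd sketch and assuming the default float32 placeholder dtype in the intrinsic:

# Output covers both the schedulable outer dimension and the tensorized inner one;
# the dtype is taken from the intrinsic's output (z) buffer.
assert tuple(int(x) for x in C.shape) == (m // factor, factor)
assert C.dtype == 'float32'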
src/op/tensorize.cc
@@ -10,7 +10,6 @@
 #include "op_util.h"
 #include "compute_op.h"
 #include "../schedule/message_passing.h"
-#include "../arithmetic/compute_expr.h"

 namespace tvm {
@@ -323,50 +322,6 @@ void VerifyTensorizeBody(
   }
 }

-/*!
- * \brief Transform the update part when there is no init func in tensorizing
- * \param stage The stage for tensorizing.
- * \param dom_map The range of each iter var.
- * \param n The loop nest structured used in compute.
- * \param body The body func in tensorize intrin
- * \param update The update func in tensorize intrin
- * \return Transformed result.
- */
-Stmt TransformUpdate(const Stage& stage,
-                     const std::unordered_map<IterVar, Range>& dom_map,
-                     const ComputeLoopNest& n,
-                     Stmt body,
-                     Stmt update) {
-  Array<Expr> conds;
-  std::unordered_set<const Variable*> banned;
-  for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
-    IterVar iv = stage->leaf_iter_vars[i];
-    auto iit = stage->iter_var_attrs.find(iv);
-    if (iit != stage->iter_var_attrs.end()) {
-      const IterVarAttr& attr = (*iit).second;
-      if (attr->iter_type == kTensorized) {
-        break;
-      }
-    }
-    if (iv->iter_type == kCommReduce) {
-      auto vit = dom_map.find(iv);
-      CHECK(vit != dom_map.end());
-      const Range& vrange = vit->second;
-      conds.push_back(likely(iv->var > vrange->min));
-      banned.insert(iv->var.get());
-    }
-  }
-  for (const Expr& pred : n.main_predicates) {
-    if (ir::ExprUseVar(pred, banned)) {
-      LOG(FATAL) << "Tensorize update transform failed, the condition "
-                 << pred << " has a conflict with the reset condition";
-    }
-  }
-
-  return IfThenElse::make(arith::ComputeReduce<ir::Or>(conds, const_true(1)),
-                          update, body);
-}
-
 Stmt MakeTensorize(const ComputeOpNode* self,
                    const Stage& stage,
                    const std::unordered_map<IterVar, Range>& dom_map,
src/pass/arg_binder.cc
@@ -91,7 +91,9 @@ void ArgBinder::BindBuffer(const Buffer& arg,
   // bind pointer and offset.
   if (is_zero(arg->elem_offset)) {
     CHECK(is_zero(value->elem_offset))
-        << "Trying to bind a Buffer with offset into one without offset";
+        << "Trying to bind a Buffer with offset into one without offset "
+        << " required elem_offset=" << arg->elem_offset
+        << ", provided elem_offset=" << value->elem_offset;
   }

   this->Bind(arg->data, value->data, arg_name + ".data");
src/schedule/schedule_dataflow_rewrite.cc
@@ -135,29 +135,29 @@ Tensor Schedule::cache_read(const Tensor& tensor,
   return cache;
 }

-// Cache write and relayout the data according to loop pattern
-Array<Tensor> CacheWriteWithReLayout(Schedule sch,
-                                     const Array<Tensor>& tensor_array,
-                                     const std::string& scope) {
-  size_t tensor_size = tensor_array.size();
-  sch->InvalidateCache();
-  Tensor tensor = tensor_array[0];
-  Stage orig_stage = sch[tensor->op];
-  const ComputeOpNode* compute = orig_stage->op.as<ComputeOpNode>();
-  std::unordered_set<IterVar> red_axis;
-  for (IterVar iv : compute->reduce_axis) {
+template <typename OpType>
+void PrepareAxisMapping(Stage orig_stage,
+                        OpType* op,
+                        std::unordered_set<IterVar>* p_red_axis,
+                        Array<IterVar>* p_new_axis,
+                        std::unordered_map<IterVar, Range>* p_dom_map,
+                        std::unordered_map<const Variable*, Expr>* p_vsub,
+                        std::unordered_map<const Variable*, Expr>* p_vsub2newvar,
+                        std::vector<Expr>* p_predicates) {
+  auto& red_axis = *p_red_axis;
+  auto& new_axis = *p_new_axis;
+  auto& dom_map = *p_dom_map;
+  auto& vsub = *p_vsub;
+  auto& vsub2newvar = *p_vsub2newvar;
+  auto& predicates = *p_predicates;
+
+  for (IterVar iv : op->reduce_axis) {
     red_axis.insert(iv);
   }
-  std::unordered_map<IterVar, Range> dom_map;
-  Array<IterVar> new_axis;
-  for (IterVar iv : compute->axis) {
+  for (IterVar iv : op->axis) {
     dom_map[iv] = iv->dom;
   }
   schedule::PassDownDomain(orig_stage, &dom_map, true);
-  std::unordered_map<const Variable*, Expr> vsub;
-  std::unordered_map<const Variable*, Expr> vsub2newvar;
-  std::vector<Expr> predicates;
   {
     // The source->cache
     std::unordered_map<IterVar, Expr> value_map;
@@ -178,17 +178,85 @@ Array<Tensor> CacheWriteWithReLayout(Schedule sch,
   }
   // skip reduction iteration.
   std::unordered_set<IterVar> skip_bound_check;
-  for (IterVar iv : compute->reduce_axis) {
+  for (IterVar iv : op->reduce_axis) {
     skip_bound_check.insert(iv);
   }
   schedule::PassUpIndex(orig_stage, dom_map, &value_map, true);
   predicates = schedule::MakeBoundCheck(
       orig_stage, dom_map, value_map, true, skip_bound_check);
   // The root axis
-  for (IterVar iv : compute->axis) {
-    vsub[iv->var.get()] = value_map.at(iv);
+  for (IterVar iv : op->axis) {
+    if (value_map.count(iv)) {
+      vsub[iv->var.get()] = value_map.at(iv);
+    }  // to handle tensor axis
   }
 }

+Array<Tensor> ReplaceOriginalOp(Schedule sch,
+                                Stage orig_stage,
+                                const std::string& scope,
+                                Operation cache_op,
+                                Operation orig_new_op,
+                                size_t tensor_size) {
+  Array<Tensor> cache_tensor_list;
+  for (size_t i = 0; i < tensor_size; i++) {
+    Tensor cache_tensor = cache_op.output(i);
+    cache_tensor_list.push_back(cache_tensor);
+  }
+  // The replace of the dataflow
+  std::unordered_map<Tensor, Tensor> vmap;
+  std::unordered_map<Tensor, Tensor> rvmap;
+  vmap[orig_stage->op.output(0)] = orig_new_op.output(0);
+  rvmap[orig_new_op.output(0)] = orig_stage->op.output(0);
+  for (size_t i = 0; i < tensor_size; i++) {
+    vmap[orig_stage->op.output(0)] = orig_new_op.output(0);
+    rvmap[orig_new_op.output(0)] = orig_stage->op.output(0);
+  }
+  ReplaceDataFlow(sch->stages, &vmap, &rvmap);
+  // mutate orig stage
+  orig_stage->op = orig_new_op;
+  orig_stage->all_iter_vars = orig_stage->op->root_iter_vars();
+  orig_stage->leaf_iter_vars = orig_stage->all_iter_vars;
+  orig_stage->relations = Array<IterVarRelation>();
+  // create schedule for new cached stage.
+  ArrayNode* stages = sch->stages.CopyOnWrite();
+  size_t pos = FindNodeRef(stages, orig_stage);
+  Stage cache_stage = Stage(cache_op);
+  cache_stage.set_scope(scope);
+  CHECK_LT(pos, stages->data.size());
+  stages->data.insert(stages->data.begin() + pos, cache_stage.node_);
+  sch->stage_map.Set(cache_op, cache_stage);
+  // Update group
+  cache_stage->group = orig_stage->group;
+  if (cache_stage->group.defined()) {
+    ++cache_stage->group->num_child_stages;
+  }
+  return cache_tensor_list;
+}
+
+// Cache write and relayout the data according to loop pattern
+Array<Tensor> CacheWriteWithReLayout(Schedule sch,
+                                     const Array<Tensor>& tensor_array,
+                                     const std::string& scope) {
+  size_t tensor_size = tensor_array.size();
+  sch->InvalidateCache();
+  Tensor tensor = tensor_array[0];
+  Stage orig_stage = sch[tensor->op];
+  const ComputeOpNode* compute = orig_stage->op.as<ComputeOpNode>();
+
+  std::unordered_set<IterVar> red_axis;
+  Array<IterVar> new_axis;
+  std::unordered_map<IterVar, Range> dom_map;
+
+  std::unordered_map<const Variable*, Expr> vsub;
+  std::unordered_map<const Variable*, Expr> vsub2newvar;
+  std::vector<Expr> predicates;
+
+  PrepareAxisMapping(orig_stage, compute,
+    &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates);
+
   Expr body;
   Array<Expr> body_list;
@@ -198,7 +266,7 @@ Array<Tensor> CacheWriteWithReLayout(Schedule sch,
   body = InjectPredicate(predicates, body);
   body = VarReplacer(vsub2newvar).Mutate(body);
   // Reduce nodes in ONE computeOp must be the same except value_index
-  // This is right only if the oringinal body ensures Reduce nodes are the same
+  // This is right only if the original body ensures Reduce nodes are the same
   if (body->is_type<ir::Reduce>()) {
     const ir::Reduce* reduce_body = body.as<ir::Reduce>();
     if (first_reduce != nullptr) {
@@ -234,48 +302,107 @@ Array<Tensor> CacheWriteWithReLayout(Schedule sch,
   Operation cache_op = ComputeOpNode::make(
       compute->name + "." + scope, compute->tag, compute->attrs,
       new_axis, body_list);
-  Array<Tensor> cache_tensor_list;
+
   Array<Expr> cache_expr_list;
   for (size_t i = 0; i < tensor_size; i++) {
     Tensor cache_tensor = cache_op.output(i);
-    cache_tensor_list.push_back(cache_tensor);
     cache_expr_list.push_back(cache_tensor(args));
   }
   Operation orig_new_op = ComputeOpNode::make(
       compute->name, compute->tag, compute->attrs,
       compute->axis, cache_expr_list);
-  // The replace of the dataflow
-  std::unordered_map<Tensor, Tensor> vmap;
-  std::unordered_map<Tensor, Tensor> rvmap;
-  vmap[orig_stage->op.output(0)] = orig_new_op.output(0);
-  rvmap[orig_new_op.output(0)] = orig_stage->op.output(0);
-  for (size_t i = 0; i < tensor_size; i++) {
-    vmap[orig_stage->op.output(0)] = orig_new_op.output(0);
-    rvmap[orig_new_op.output(0)] = orig_stage->op.output(0);
-  }
-  ReplaceDataFlow(sch->stages, &vmap, &rvmap);
-  // mutate orig stage
-  orig_stage->op = orig_new_op;
-  orig_stage->all_iter_vars = orig_stage->op->root_iter_vars();
-  orig_stage->leaf_iter_vars = orig_stage->all_iter_vars;
-  orig_stage->relations = Array<IterVarRelation>();
-  // create schedule for new cached stage.
-  ArrayNode* stages = sch->stages.CopyOnWrite();
-  size_t pos = FindNodeRef(stages, orig_stage);
-  Stage cache_stage = Stage(cache_op);
-  cache_stage.set_scope(scope);
-  CHECK_LT(pos, stages->data.size());
-  stages->data.insert(stages->data.begin() + pos, cache_stage.node_);
-  sch->stage_map.Set(cache_op, cache_stage);
-  // Update group
-  cache_stage->group = orig_stage->group;
-  if (cache_stage->group.defined()) {
-    ++cache_stage->group->num_child_stages;
-  }
-  return cache_tensor_list;
+  return ReplaceOriginalOp(sch, orig_stage, scope,
+    cache_op, orig_new_op, tensor_size);
+}
+
+// for tensor compute op
+Array<Tensor> CacheWriteWithReLayoutTensor(Schedule sch,
+                                           const Array<Tensor>& tensor_array,
+                                           const std::string& scope) {
+  size_t tensor_size = tensor_array.size();
+  sch->InvalidateCache();
+  Tensor tensor = tensor_array[0];
+  Stage orig_stage = sch[tensor->op];
+  const TensorComputeOpNode* tensor_op = orig_stage->op.as<TensorComputeOpNode>();
+  CHECK_EQ(tensor_op->num_outputs(), 1)
+      << "cache write only support single output tensor_compute_op";
+
+  std::unordered_set<IterVar> red_axis;
+  Array<IterVar> new_axis;
+  std::unordered_map<IterVar, Range> dom_map;
+
+  std::unordered_map<const Variable*, Expr> vsub;
+  std::unordered_map<const Variable*, Expr> vsub2newvar;
+  std::vector<Expr> predicates;
+
+  PrepareAxisMapping(orig_stage, tensor_op,
+    &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates);
+
+  for (int i = tensor_op->schedulable_ndim; i < static_cast<int>(tensor_op->axis.size()); ++i) {
+    IterVar iv = tensor_op->axis[i];
+    IterVar new_iv = IterVarNode::make(
+      iv->dom, iv->var.copy_with_suffix(".c"), iv->iter_type);
+    new_axis.push_back(new_iv);
+  }
+  Array<Region> new_regions;
+  for (Region old_region : tensor_op->input_regions) {
+    Region region;
+    for (Range r : old_region) {
+      Expr min = VarReplacer(vsub2newvar).Mutate(r->min);
+      Expr extent = VarReplacer(vsub2newvar).Mutate(r->extent);
+      region.push_back(Range::make_by_min_extent(min, extent));
+    }
+    new_regions.push_back(region);
+  }
+
+  Operation cache_op = TensorComputeOpNode::make(
+      tensor_op->name + "." + scope, tensor_op->tag, new_axis,
+      tensor_op->reduce_axis, tensor_op->schedulable_ndim,
+      tensor_op->intrin, tensor_op->inputs, new_regions);
+
+  // axis will be used in generating compute op
+  Array<IterVar> compute_axis = tensor_op->axis;
+  for (size_t i = tensor_op->schedulable_ndim; i < tensor_op->axis.size(); ++i) {
+    IterVar iv = tensor_op->axis[i];
+    IterVar aiv = IterVarNode::make(iv->dom, iv->var, kDataPar);
+    compute_axis.Set(i, aiv);
+  }
+
+  // The reader args
+  Array<Expr> args;
+  {
+    // cache->compute
+    std::unordered_map<IterVar, Expr> value_map;
+    for (IterVar iv : compute_axis) {
+      value_map[iv] = iv->var;
+    }
+    schedule::PassDownIndex(orig_stage, dom_map, &value_map, true);
+    for (IterVar iv : orig_stage->leaf_iter_vars) {
+      if (red_axis.count(iv)) continue;
+      args.push_back(value_map.at(iv));
+    }
+    // tensorized region axis
+    for (size_t i = tensor_op->schedulable_ndim; i < tensor_op->axis.size(); ++i) {
+      IterVar iv = compute_axis[i];
+      args.push_back(value_map.at(iv));
+    }
+  }
+
+  Array<Expr> cache_expr_list;
+  for (size_t i = 0; i < tensor_size; i++) {
+    Tensor cache_tensor = cache_op.output(i);
+    cache_expr_list.push_back(cache_tensor(args));
+  }
+  Operation orig_new_op = ComputeOpNode::make(
+      tensor_op->name, tensor_op->tag, {},
+      compute_axis, cache_expr_list);
+  return ReplaceOriginalOp(sch, orig_stage, scope,
+    cache_op, orig_new_op, tensor_size);
 }

 Array<Tensor> Schedule::cache_write(const Array<Tensor>& tensor_array,
                                     const std::string& scope) {
   (*this)->InvalidateCache();
@@ -291,23 +418,26 @@ Array<Tensor> Schedule::cache_write(const Array<Tensor>& tensor_array,
     CHECK(orig_stage.same_as(tmp_stage))
         << "Input tensor list must be generated by ONE computeOp";
   }
   return CacheWriteWithReLayout(*this, tensor_array, scope);
 }

 Tensor Schedule::cache_write(const Tensor& tensor,
                              const std::string& scope) {
+  // support original compute and tensor compute both
   (*this)->InvalidateCache();
-  Stage orig_stage = operator[](tensor->op);
-  const ComputeOpNode* compute = tensor->op.as<ComputeOpNode>();
-  CHECK(compute)
-      << "cache write only take ComputeOp as writers";
-  CHECK_EQ(compute->num_outputs(), 1)
-      << "cache write only support single output ComputeOp";
-  return (CacheWriteWithReLayout(*this, {tensor}, scope))[0];
+  const char* type_key = tensor->op->type_key();
+  if (!strcmp(type_key, "ComputeOp")) {
+    return (CacheWriteWithReLayout(*this, {tensor}, scope))[0];
+  } else if (!strcmp(type_key, "TensorComputeOp")) {
+    return (CacheWriteWithReLayoutTensor(*this, {tensor}, scope))[0];
+  } else {
+    LOG(FATAL) << "cache write only take ComputeOp or TensorComputeOp as writers";
+    return Tensor();
+  }
 }

 void RebaseNonZeroMinLoop(const Schedule& sch) {
   std::unordered_map<IterVar, IterVar> rebase_map;
   for (Stage s : sch->stages) {
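Because cache_write now dispatches on the writer's op type, a stage produced by a tensor-region compute can also receive a write cache. A short sketch of the user-facing effect, mirroring the new test_schedule_tensor_compute2 below; the 'local' cache scope and the reuse of C from the earlier vadd example are illustrative:

# C.op is a TensorComputeOp (built from vadd over regions, as in the earlier sketch).
s = tvm.create_schedule(C.op)
CL = s.cache_write(C, 'local')    # previously only plain ComputeOp writers were accepted
s = s.normalize()
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)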
tests/python/unittest/test_lang_tensor.py
@@ -85,6 +85,78 @@ def test_tensor_reduce():
     assert(isinstance(C_loaded, tvm.tensor.Tensor))
     assert(str(C_loaded) == str(C))

def test_tensor_compute1():
    m = 1024
    factor = 16
    dtype = 'float32'
    def intrin_vadd(n):
        x = tvm.placeholder((n,))
        y = tvm.placeholder((n,))
        z = tvm.compute(x.shape, lambda i: x[i] + y[i])
        def intrin_func(ins, outs):
            ib = tvm.ir_builder.create()
            ib.emit(tvm.call_extern(outs[0].dtype, 'vadd',
                                    ins[0].access_ptr("r"), ins[1].access_ptr('r'),
                                    outs[0].access_ptr('wr')))
            return ib.get()
        with tvm.build_config(offset_factor=n):
            return tvm.decl_tensor_intrin(z.op, intrin_func)

    vadd = intrin_vadd(factor)

    A = tvm.placeholder((m//factor, factor), name="A", dtype=dtype)
    B = tvm.placeholder((m//factor, factor), name="B", dtype=dtype)
    C = tvm.compute((m//factor, factor),
          lambda i: vadd(A[i, 0:factor], B[i, 0:factor]))

    s = tvm.create_schedule(C.op)
    stmt = tvm.lower(s, [A, B, C], simple_mode=True)
    assert isinstance(stmt.body.body, tvm.stmt.Evaluate)

def test_tensor_compute2():
    M = 2048
    N = 1024
    L = 1024
    factor = 16
    factor1 = 32
    factor2 = 32
    dtype = 'float32'
    def intrin_gemm(m, n, l):
        k = tvm.reduce_axis((0, l))
        x = tvm.placeholder((m, l))
        y = tvm.placeholder((n, l))
        # in theory, no relation
        z = tvm.compute((m, n), lambda i, j: tvm.sum(x[i][k] * y[j][k], axis=k))

        def intrin_func(ins, outs):
            x_ptr = ins[0].access_ptr("r")
            y_ptr = ins[1].access_ptr("r")
            z_ptr = outs[0].access_ptr("w")
            body = tvm.call_packed("gemv", x_ptr, y_ptr, z_ptr, m, n, l)
            reset = tvm.call_packed("fill_zero", z_ptr, m, n)
            update = tvm.call_packed("gemv_add", x_ptr, y_ptr, z_ptr, m, n, l)
            return body, reset, update

        with tvm.build_config(offset_factor=n):
            return tvm.decl_tensor_intrin(z.op, intrin_func)

    vgemm = intrin_gemm(factor1, factor2, factor)

    A = tvm.placeholder((M//factor1, L//factor, factor1, factor), name="A", dtype=dtype)
    B = tvm.placeholder((N//factor2, L//factor, factor2, factor), name="B", dtype=dtype)
    k = tvm.reduce_axis((0, L//factor), name='k')
    C = tvm.compute((M//factor1, N//factor2, factor1, factor2),
          lambda i, j: vgemm(A[i, k, 0:factor1, 0:factor],
                             B[j, k, 0:factor2, 0:factor], reduce_axis=k))

    s = tvm.create_schedule(C.op)
    stmt = tvm.lower(s, [A, B, C], simple_mode=True)
    assert isinstance(stmt.body.body.body.first, tvm.stmt.Evaluate)
    assert isinstance(stmt.body.body.body.rest.body, tvm.stmt.Evaluate)

def test_tensor_scan():
    m = tvm.var("m")
@@ -221,6 +293,8 @@ if __name__ == "__main__":
    test_conv1d()
    test_tensor_slice()
    test_tensor()
    test_tensor_compute1()
    test_tensor_compute2()
    test_tensor_reduce()
    test_tensor_scan()
    test_scan_multi_out()
tests/python/unittest/test_schedule_schedule_ops.py
@@ -276,6 +276,133 @@ def test_schedule_bound_condition():
    stmt = tvm.ir_pass.Simplify(stmt)
    assert (isinstance(stmt.body.body.first.body.body.then_case, tvm.stmt.IfThenElse))

def intrin_gemv(m, n):
    w = tvm.placeholder((m, n), name='w')
    x = tvm.placeholder((n,), name='x')
    k = tvm.reduce_axis((0, n), name='k')
    z = tvm.compute((m,), lambda i: tvm.sum(w[i, k] * x[k], axis=k), name='z')
    Wb = tvm.decl_buffer(w.shape, w.dtype, name="W", offset_factor=16,
                         strides=[tvm.var('ldw'), 1])
    def intrin_func(ins, outs):
        ww, xx = ins
        zz = outs[0]
        ww_ptr = ww.access_ptr("r")
        xx_ptr = xx.access_ptr("r")
        zz_ptr = zz.access_ptr("w")
        body = tvm.call_packed("gemm", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0])
        reset = tvm.call_packed("fill_zero", zz_ptr, n)
        update = tvm.call_packed("gemv_add", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0])
        return body, reset, update

    with tvm.build_config(data_alignment=16, offset_factor=16):
        return tvm.decl_tensor_intrin(z.op, intrin_func, binds={w: Wb})

def test_schedule_tensor_compute1():
    # basic: split, reorder, tile
    M, N, L = 2048, 1024, 512
    factor, rfactor = 16, 16
    A = tvm.placeholder((N//factor, L//rfactor, factor, rfactor), name='A')
    B = tvm.placeholder((M, L//rfactor, rfactor), name='B')
    k = tvm.reduce_axis((0, L//rfactor), name='k')

    gemv = intrin_gemv(factor, rfactor)
    C = tvm.compute((N, M//factor, factor),
        lambda i, j: gemv(A[i, k, 0:factor, 0:factor], B[j, k, 0:rfactor], reduce_axis=k),
        name='C')

    s = tvm.create_schedule(C.op)
    ai, aj, ax = s[C].op.axis
    aio, aii = s[C].split(ai, 16)
    s[C].reorder(aio, aj, aii)
    aioo, ajo, aioi, aji = s[C].tile(aio, aj, 16, 4)

    s = s.normalize()
    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)

def intrin_vadd(n, cache_read=False, cache_write=False):
    scope_ubuf = 'local'
    dtype = 'float32'
    x = tvm.placeholder((n,), dtype=dtype, name='vx')
    y = tvm.placeholder((n,), dtype=dtype, name='vy')
    z = tvm.compute(x.shape, lambda i: x[i] + y[i], name='z')
    s = tvm.create_schedule(z.op)

    def create_buffer(t):
        return tvm.decl_buffer(t.shape, t.dtype,
                               name='W'+t.name,
                               scope=scope_ubuf,
                               offset_factor=16)

    binds = {}
    if cache_read:
        binds[x] = create_buffer(x)
        binds[y] = create_buffer(y)
    if cache_write:
        binds[z] = create_buffer(z)

    def intrin_func(ins, outs):
        ib = tvm.ir_builder.create()
        ib.emit(tvm.call_extern(outs[0].dtype, 'vadd',
                                ins[0].access_ptr("r"), ins[1].access_ptr('r'),
                                outs[0].access_ptr('wr')))
        return ib.get()

    with tvm.build_config(offset_factor=16):
        return tvm.decl_tensor_intrin(z.op, intrin_func, binds=binds)

def test_schedule_tensor_compute2():
    # cache_read, cache_write
    M = 1024
    factor = 16
    dtype = 'float32'
    scope_ubuf = 'local'

    A = tvm.placeholder((M//factor, factor), name="A", dtype=dtype)
    B = tvm.placeholder((M//factor, factor), name="B", dtype=dtype)

    vadd = intrin_vadd(factor, True, True)
    C = tvm.compute((M//factor, factor),
        lambda i: vadd(A[i, 0:factor], B[i, 0:factor]), name='C')

    s = tvm.create_schedule(C.op)
    AL = s.cache_read(A, scope_ubuf, C)
    BL = s.cache_read(B, scope_ubuf, C)
    CL = s.cache_write(C, scope_ubuf)
    s = s.normalize()
    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)

def test_schedule_tensor_compute3():
    # compute_at
    M = 1024
    factor = 16
    dtype = 'float32'
    A = tvm.placeholder((M//factor, factor), name="A", dtype=dtype)
    B = tvm.placeholder((M//factor, factor), name="B", dtype=dtype)
    Bi = tvm.compute((M//factor, factor), lambda i, j: B[i, j] + 5, name="Bi")

    vadd = intrin_vadd(factor)
    C = tvm.compute((M//factor, factor),
        lambda i: vadd(A[i, 0:factor], Bi[i, 0:factor]), name='C')
    s = tvm.create_schedule(C.op)
    s[Bi].compute_at(s[C], C.op.axis[0])
    s = s.normalize()
    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)

if __name__ == "__main__":
    test_schedule_middle_cache()
    test_inline_multi_reduce()
@@ -294,3 +421,6 @@ if __name__ == "__main__":
    test_schedule2()
    test_schedule_cache()
    test_schedule_bound_condition()
    test_schedule_tensor_compute1()
    test_schedule_tensor_compute2()
    test_schedule_tensor_compute3()