Commit 34d2aae3, authored Jul 03, 2017 by Tianqi Chen, committed by GitHub on Jul 03, 2017
[BUFFER/REFACTOR] Buffer byte_offset-> elem_offset, add buffer_bind_scope (#209)
Parent: b0e41b9a
Showing 24 changed files with 444 additions and 141 deletions:
  HalideIR                                              +1    -1
  include/tvm/buffer.h                                  +21   -8
  include/tvm/expr.h                                    +8    -0
  include/tvm/ir.h                                      +20   -4
  include/tvm/operation.h                               +1    -1
  python/tvm/api.py                                     +9    -7
  python/tvm/collections.py                             +0    -5
  python/tvm/expr.py                                    +4    -0
  python/tvm/intrin.py                                  +2    -2
  python/tvm/stmt.py                                    +4    -0
  python/tvm/tensor.py                                  +6    -0
  src/api/api_lang.cc                                   +5    -0
  src/contrib/cblas/cblas.cc                            +6    -3
  src/lang/buffer.cc                                    +77   -16
  src/lang/ir.cc                                        +1    -0
  src/lang/tensor.cc                                    +1    -1
  src/op/compute_op.cc                                  +89   -52
  src/op/extern_op.cc                                   +21   -3
  src/pass/lower_packed_call.cc                         +7    -1
  src/pass/make_api.cc                                  +14   -3
  src/pass/storage_flatten.cc                           +121  -33
  tests/python/unittest/test_codegen_llvm.py            +7    -1
  tests/python/unittest/test_lang_tensor.py             +7    -0
  tests/python/unittest/test_schedule_schedule_ops.py   +12   -0
HalideIR @ 860199ee

-Subproject commit e42653d7c3a604eb9f6ee1b5f989ddadd1cea69c
+Subproject commit 860199eea031a4ea694b8fce03ad0bf8127910ac
include/tvm/buffer.h

@@ -16,10 +16,11 @@ namespace tvm {
 // Internal node container Buffer
 class BufferNode;
 /*!
  * \brief Buffer is a symbolic n-darray structure.
  *  It is a composition of primitive symbolic types,
- *  used to specify input/output strcuture of the program.
+ *  used to specify the memory layout of the Tensor used in program input.
  */
 class Buffer : public NodeRef {
  public:

@@ -39,6 +40,21 @@ class Buffer : public NodeRef {
    */
   Stmt MakeStore(Array<Expr> index, Expr value) const;
+  /*!
+   * \brief Return a new buffer that is equivalent with current one
+   *  but always add stride field.
+   * \return The strided version of the buffer.
+   */
+  Buffer MakeStrideView() const;
+  /*!
+   * \brief Make a new symbolic buffer representing a slice of the buffer.
+   * \param begins The beginning position of each dimension.
+   * \param extents The extent of each dimension.
+   * \note This function will make target buffer as compact as possible.
+   *  If stride is not needed in the slice, it won't be presented
+   * \return the result buffer.
+   */
+  Buffer MakeSlice(Array<Expr> begins, Array<Expr> extents) const;
   /*!
    * \brief access the internal node container
    * \return the pointer to the internal node container
    */

@@ -63,17 +79,14 @@ class BufferNode : public Node {
   *  This can be an empty array, indicating array is contiguous
   */
  Array<Expr> strides;
- /*!
-  * \brief The offset in bytes to the beginning pointer to data
-  *  Can be undefined, indicating this must be zero.
-  */
- Expr byte_offset;
+ /*! \brief The offset in terms of number of dtype elements (including lanes) */
+ Expr elem_offset;
  // Meta data
  /*! \brief optional name of the buffer */
  std::string name;
  /*! \brief storage scope of the buffer, if other than global */
  std::string scope;
- /*! \brief Alignment bytes size of byte_offset */
+ /*! \brief Alignment multiple in terms of dtype elements (including lanes) */
  int offset_alignment;
  /*! \brief constructor */
  BufferNode() {}

@@ -83,7 +96,7 @@ class BufferNode : public Node {
    v->Visit("dtype", &dtype);
    v->Visit("shape", &shape);
    v->Visit("strides", &strides);
-   v->Visit("byte_offset", &byte_offset);
+   v->Visit("elem_offset", &elem_offset);
    v->Visit("name", &name);
    v->Visit("scope", &scope);
    v->Visit("offset_alignment", &offset_alignment);
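To see the new convention from the frontend, a minimal sketch (assuming the 2017-era Python API that this same diff touches, with the new elem_offset keyword exercised in test_codegen_llvm.py below); the offset counts elements of dtype, not bytes:

    import tvm

    # Declare a buffer whose start is a symbolic number of *elements*
    # (not bytes) past the data pointer; mirrors the test added below.
    A = tvm.placeholder((1024,), dtype="float32", name="A")
    Ab = tvm.decl_buffer(A.shape, A.dtype,
                         elem_offset=tvm.var("Aoffset"),
                         name="A")
    # For float32, the runtime byte offset becomes Aoffset * 4.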
include/tvm/expr.h

@@ -61,6 +61,14 @@ inline TVMType Type2TVMType(Type t) {
   return ret;
 }
+// Get number of bytes considering vector type.
+inline int GetVectorBytes(Type dtype) {
+  int data_bits = dtype.bits() * dtype.lanes();
+  CHECK_EQ(data_bits % 8, 0U)
+      << "Need to load/store by multiple of bytes";
+  return data_bits / 8;
+}
 /*! \brief a named variable in TVM */
 class Var : public Halide::VarExpr {
  public:
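The helper's arithmetic, restated as a runnable Python sketch (the function below is a stand-in illustrating the C++ inline above, not part of this commit):

    def get_vector_bytes(bits, lanes):
        # Total bits of a (possibly vectorized) type must be byte-aligned.
        data_bits = bits * lanes
        assert data_bits % 8 == 0, "Need to load/store by multiple of bytes"
        return data_bits // 8

    assert get_vector_bytes(32, 1) == 4    # float32
    assert get_vector_bytes(32, 4) == 16   # float32x4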
include/tvm/ir.h

@@ -167,8 +167,16 @@ constexpr const char* prefetch_scope = "prefetch_scope";
 constexpr const char* scan_update_scope = "scan_update_scope";
 /*! \brief Mark of scan init scope */
 constexpr const char* scan_init_scope = "scan_init_scope";
-/*! \brief extern operator scope */
-constexpr const char* extern_op_scope = "extern_op_scope";
+/*!
+ * \brief Bind the buffer specification to the region of the op
+ *  When this scope occurs, the stmt.node is a Array<NodeRef> = [buffer, tensor]
+ *  stmt.value is a tvm_tuple(min0, extent0, min1, extent1, ...).
+ *  The scope represents that we need to bind the storage region of tensor to buffer.
+ *  This will affect replacement of some variables inside the scope that
+ *  corresponds to field of buffer to be the actual expressions of tensor during
+ *  storage flattening phase.
+ */
+constexpr const char* buffer_bind_scope = "buffer_bind_scope";
 // Pipeline related attributes
 /*! \brief channel read scope */
 constexpr const char* channel_read_scope = "channel_read_scope";

@@ -195,6 +203,14 @@ namespace intrinsic {
  */
 constexpr const char* tvm_address_of = "tvm_address_of";
 /*!
+ * \brief tvm_tuple is not an actual function and cannot codegen.
+ *  It is used to represent tuple structure in value field of AttrStmt,
+ *  for the sake of giving hint to optimization.
+ *
+ *  Handle tvm_tuple(value0, value1, ..., value_n);
+ */
+constexpr const char* tvm_tuple = "tvm_tuple";
+/*!
  * \brief See pesudo code
  *
  *  Type tvm_struct_get(StructType* arr, int index, int field_id) {

@@ -250,14 +266,14 @@ constexpr const char* tvm_stack_make_shape = "tvm_stack_make_shape";
  *    Expr strides,
  *    Expr ndim,
  *    Expr dtype,
- *    Expr byte_offset) {
+ *    Expr elem_offset) {
  *      ret = alloca stack DLTensor();
  *      ret->data = data;
  *      ret->shape = shape;
  *      ret->strides = strides != 0 ? strides : nullptr;
  *      ret->ndim = ndim;
  *      ret->dtype = dtype.type();
- *      ret->byte_offset = byte_offset;
+ *      ret->byte_offset = elem_offset * sizeof(dtype);
  *      return ret;
  *  }
  */
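The shape of the new buffer_bind_scope annotation, sketched with plain Python stand-ins rather than real IR nodes (all values here are hypothetical, for a 16x32 region):

    # stmt.node is [buffer, tensor]; stmt.value is
    # tvm_tuple(min0, extent0, min1, extent1, ...).
    tuple_args = [0, 16, 0, 32]
    region = [(tuple_args[2 * i], tuple_args[2 * i + 1])
              for i in range(len(tuple_args) // 2)]
    assert region == [(0, 16), (0, 32)]   # one (min, extent) pair per dim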
include/tvm/operation.h

@@ -62,7 +62,7 @@ class OperationNode : public FunctionBaseNode {
   virtual Array<Expr> output_shape(size_t i) const = 0;
   /*!
    * \brief List all the input Tensors.
-   * \return List if input tensors.
+   * \return List of input tensors.
    */
   virtual Array<Tensor> InputTensors() const = 0;
   /*!
python/tvm/api.py

@@ -287,6 +287,7 @@ def scan(init, update, state_placeholder, inputs=None, name="scan"):
     res = [op.output(i) for i in range(len(update))]
     return res[0] if len(res) == 1 else res
 
+
 def extern(shape, inputs, fcompute,
            name="extern", dtype=None):
     """Compute several tensor via extern function.

@@ -374,7 +375,7 @@ def decl_buffer(shape,
                 name="buffer",
                 data=None,
                 strides=None,
-                byte_offset=None,
+                elem_offset=None,
                 scope="",
                 offset_alignment=0):
     """Decleare a new symbolic buffer.

@@ -401,8 +402,9 @@ def decl_buffer(shape,
     strides: array of Expr
         The stride of the buffer.
 
-    byte_offset: Expr, optional
-        The offset in bytes to data pointer.
+    elem_offset: Expr, optional
+        The beginning offset of the array to data.
+        In terms of number of elements of dtype.
 
     scope: str, optional
         The storage scope of the buffer, if not global.

@@ -423,7 +425,7 @@ def decl_buffer(shape,
     to create function that only handles specific case of data structure
     and make compiled function benefit from it.
 
-    If user pass strides and byte_offset is passed as None
+    If user pass strides and elem_offset is passed as None
     when constructing the function, then the function will be specialized
     for the DLTensor that is compact and aligned.
     If user pass a fully generic symbolic array to the strides,

@@ -436,7 +438,7 @@ def decl_buffer(shape,
         data = var(name, "handle")
     return _api_internal._Buffer(
-        data, dtype, shape, strides, byte_offset, name, scope, offset_alignment)
+        data, dtype, shape, strides, elem_offset, name, scope, offset_alignment)
 
 def _IterVar(dom, name, iter_type, thread_tag=''):

@@ -464,11 +466,11 @@ def _IterVar(dom, name, iter_type, thread_tag=''):
     if dom is not None:
         if isinstance(dom, (list, tuple)):
             if len(dom) != 2:
-                raise ValueError("need to list of ranges")
+                raise TypeError("need to be list of ranges")
             dom = Range(dom[0], dom[1])
 
         if not isinstance(dom, _collections.Range):
-            raise ValueError("dom need to be Range")
+            raise TypeError("dom need to be Range")
 
     name = name if name else 'iter'
     v = var(name)
     return _api_internal._IterVar(dom, v, iter_type, thread_tag)
python/tvm/collections.py

@@ -26,8 +26,6 @@ class Array(NodeBase):
     def __len__(self):
         return _api_internal._ArraySize(self)
 
-    def __repr__(self):
-        return '[' + (','.join(str(x) for x in self)) + ']'
 
 @register_node
 class Map(NodeBase):

@@ -52,9 +50,6 @@ class Map(NodeBase):
     def __len__(self):
         return _api_internal._MapSize(self)
 
-    def __repr__(self):
-        return '{' + (", ".join(str(x[0]) + ": " + str(x[1])
-                                for x in self.items())) + '}'
 
 @register_node
 class Range(NodeBase):
python/tvm/expr.py

@@ -237,6 +237,10 @@ class Broadcast(Expr):
     pass
 
 @register_node
+class Shuffle(Expr):
+    pass
+
+@register_node
 class Call(Expr):
     Extern = 0
     ExternCPlusPlus = 1
python/tvm/intrin.py

@@ -1,4 +1,4 @@
-"""Intrinsics and math functions in TVM."""
+"""Expression Intrinsics and math functions in TVM."""
 from __future__ import absolute_import as _abs
 from ._ffi.function import register_func as _register_func

@@ -20,7 +20,7 @@ def _pack_buffer(buf):
                  strides,
                  len(buf.shape),
                  const(0, dtype=buf.dtype),
-                 buf.byte_offset]
+                 buf.elem_offset]
     return _make.Call("handle", "tvm_stack_make_array",
                       pack_args, _Call.Intrinsic, None, 0)
python/tvm/stmt.py

@@ -73,3 +73,7 @@ class IfThenElse(Stmt):
 @register_node
 class Evaluate(Stmt):
     pass
+
+@register_node
+class Prefetch(Stmt):
+    pass
python/tvm/tensor.py

@@ -118,6 +118,12 @@ class Operation(NodeBase):
         """Number of outputs of this op."""
         return _api_internal._OpNumOutputs(self)
 
+    @property
+    def input_tensors(self):
+        """List of input tensors to this op."""
+        return _api_internal._OpInputTensors(self)
+
 
 @register_node
 class PlaceholderOp(Operation):
     """Placeholder operation."""
src/api/api_lang.cc

@@ -218,6 +218,11 @@ TVM_REGISTER_API("_OpNumOutputs")
     *ret = args[0].operator Operation()->num_outputs();
   });
 
+TVM_REGISTER_API("_OpInputTensors")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    *ret = args[0].operator Operation()->InputTensors();
+  });
+
 TVM_REGISTER_API("_IterVar")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
     *ret = IterVarNode::make(
src/contrib/cblas/cblas.cc

@@ -40,10 +40,13 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul")
                 transa ? A->shape[1] : A->shape[0],
                 transa ? B->shape[1] : B->shape[0],
                 1.0f,
-                static_cast<float*>(B->data), B->shape[1],
-                static_cast<float*>(A->data), A->shape[1],
+                reinterpret_cast<float*>(
+                    static_cast<char*>(B->data) + B->byte_offset), B->shape[1],
+                reinterpret_cast<float*>(
+                    static_cast<char*>(A->data) + A->byte_offset), A->shape[1],
                 0.0f,
-                static_cast<float*>(C->data), C->shape[1]);
+                reinterpret_cast<float*>(
+                    static_cast<char*>(C->data) + C->byte_offset), C->shape[1]);
   });
 }  // namespace contrib
 }  // namespace tvm
src/lang/buffer.cc

@@ -4,6 +4,7 @@
  */
 #include <tvm/buffer.h>
 #include <tvm/ir.h>
+#include <tvm/ir_pass.h>
 
 namespace tvm {

@@ -28,27 +29,43 @@ Buffer decl_buffer(Array<Expr> shape,
                      name, "", 0);
 }
 
-inline Expr BufferOffset(const BufferNode* n, Array<Expr> index) {
-  Expr base;
+// The buffer offset in convention of number of elements of
+// original data ignoring number of lanes.
+inline Expr ElemOffset(const BufferNode* n, Array<Expr> index) {
+  Expr base = n->elem_offset;
   if (n->strides.size() == 0) {
     CHECK_EQ(n->shape.size(), index.size());
-    base = index[0];
+    if (is_zero(base)) {
+      base = index[0];
+    } else {
+      base = base + index[0];
+    }
     for (size_t i = 1; i < index.size(); ++i) {
       base = base * n->shape[i] + index[i];
     }
   } else {
     CHECK_EQ(n->strides.size(), index.size());
-    base = index[0] * n->strides[0];
+    if (is_zero(base)) {
+      base = index[0] * n->strides[0];
+    } else {
+      base = base + index[0] * n->strides[0];
+    }
     for (size_t i = 1; i < index.size(); ++i) {
       base = base + index[i] * n->strides[i];
     }
   }
-  if (!is_zero(n->byte_offset)) {
-    base = base + (n->byte_offset / n->dtype.bytes());
-  }
   return base;
 }
 
+// Buffer access offset.
+inline Expr BufferOffset(const BufferNode* n, Array<Expr> index) {
+  Expr offset = ElemOffset(n, index);
+  if (n->dtype.lanes() != 1) {
+    offset = offset * make_const(offset.type(), n->dtype.lanes());
+  }
+  return offset;
+}
+
 Expr Buffer::MakeLoad(Array<Expr> index) const {
   const BufferNode* n = operator->();
   return ir::Load::make(

@@ -63,11 +80,58 @@ Stmt Buffer::MakeStore(Array<Expr> index, Expr value) const {
       const_true(n->dtype.lanes()));
 }
 
+Buffer Buffer::MakeStrideView() const {
+  if ((*this)->strides.size() != 0) return *this;
+  std::vector<Expr> temp;
+  auto n = std::make_shared<BufferNode>(*operator->());
+  Expr acc = make_const(n->shape[0].type(), 1);
+  for (size_t i = n->shape.size(); i != 0; --i) {
+    temp.push_back(acc);
+    acc = acc * n->shape[i - 1];
+  }
+  for (size_t i = temp.size(); i != 0; --i) {
+    n->strides.push_back(temp[i - 1]);
+  }
+  return Buffer(n);
+}
+
+Buffer Buffer::MakeSlice(Array<Expr> begins, Array<Expr> extents) const {
+  const BufferNode* n = operator->();
+  Expr elem_offset = ElemOffset(n, begins);
+  Array<Expr> strides = n->strides;
+  if (strides.size() == 0) {
+    bool can_relax = true;
+    bool need_stride = false;
+    // check if stride is needed.
+    for (size_t i = 0; i < extents.size(); ++i) {
+      if (!can_relax) {
+        if (!is_zero(begins[i]) ||
+            !is_zero(ir::Simplify(extents[i] - n->shape[i]))) {
+          need_stride = true;
+        }
+      }
+      if (!is_one(extents[i])) can_relax = false;
+    }
+    // make stride.
+    if (need_stride) {
+      return MakeStrideView().MakeSlice(begins, extents);
+    }
+  }
+  return BufferNode::make(n->data,
+                          n->dtype,
+                          extents,
+                          strides,
+                          elem_offset,
+                          n->name + "_slice",
+                          n->scope,
+                          0);
+}
+
 Buffer BufferNode::make(Var data,
                         Type dtype,
                         Array<Expr> shape,
                         Array<Expr> strides,
-                        Expr byte_offset,
+                        Expr elem_offset,
                         std::string name,
                         std::string scope,
                         int offset_alignment) {

@@ -78,16 +142,13 @@ Buffer BufferNode::make(Var data,
   n->strides = std::move(strides);
   n->name = std::move(name);
   n->scope = std::move(scope);
-  if (!byte_offset.defined()) {
-    byte_offset = make_const(n->shape[0].type(), 0);
+  if (!elem_offset.defined()) {
+    elem_offset = make_const(n->shape[0].type(), 0);
   }
-  if (offset_alignment != 0) {
-    CHECK_EQ(offset_alignment % dtype.bytes(), 0)
-        << "Offset alignments must be at least " << dtype.bytes();
-  } else {
-    offset_alignment = dtype.bytes();
+  if (offset_alignment == 0) {
+    offset_alignment = 1;
   }
-  n->byte_offset = byte_offset;
+  n->elem_offset = elem_offset;
   n->offset_alignment = offset_alignment;
   return Buffer(n);
 }
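The refactored ElemOffset folds the buffer's own elem_offset into the Horner-style row-major index computation. A Python sketch of the same arithmetic (the stride-less, compact case only; illustrative, not part of this commit):

    def elem_offset(base, shape, index):
        # Row-major flattening with a starting element offset, mirroring
        # the C++ ElemOffset above when strides.size() == 0.
        off = base + index[0]
        for i in range(1, len(index)):
            off = off * shape[i] + index[i]
        return off

    # A (4, 8) buffer starting 8 elements into its storage:
    assert elem_offset(8, (4, 8), (2, 3)) == (8 + 2) * 8 + 3  # == 83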
src/lang/ir.cc

@@ -42,6 +42,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
        << ", identity_element=" << op->identity_element
        << ")";
   });
+
 }  // namespace Internal
 }  // namespace Halide
src/lang/tensor.cc

@@ -30,7 +30,7 @@ Tensor TensorNode::make(Array<Expr> shape,
                         Operation op,
                         int value_index) {
   auto n = std::make_shared<TensorNode>();
-  n->shape = shape;
+  n->shape = std::move(shape);
   n->dtype = dtype;
   n->op = op;
   n->value_index = value_index;
src/op/compute_op.cc

@@ -251,7 +251,7 @@ Stmt Substitute(Stmt s,
   return ir::Substitute(s, temp);
 }
 
-// Cross Thread reduction
+// Cross Thread reduction marker.
 bool IsCrossThreadReduction(const ComputeOpNode* self,
                             const Stage& stage) {
   // Verify correctness of leaf nest.

@@ -360,6 +360,7 @@ Stmt MakeCrossThreadReduction(
   return MergeNest(nest, body);
 }
 
+// Normal computation.
 Stmt MakeProvide(const ComputeOpNode* op,
                  const Tensor& t) {
   Array<Expr> args;

@@ -369,60 +370,56 @@ Stmt MakeProvide(const ComputeOpNode* op,
   return Provide::make(t->op, t->value_index, op->body[t->value_index], args);
 }
 
-Stmt ComputeOpNode::BuildProvide(
+// loop nest structure for general compute
+// This the the loop nest structured used in compute.
+// Does not include the loop body.
+struct ComputeLoopNest {
+  // The common number of loops between init and main
+  size_t num_common_loop;
+  // predicates for the initialize loop
+  std::vector<Expr> init_predicates;
+  // Initialization nest involved.
+  std::vector<std::vector<Stmt> > init_nest;
+  // Value map for the init code
+  std::unordered_map<IterVar, Expr> init_vmap;
+  // Predicates for the main update loop
+  std::vector<Expr> main_predicates;
+  // The general loop nest
+  std::vector<std::vector<Stmt> > main_nest;
+  // Value map for the IterVar.
+  std::unordered_map<IterVar, Expr> main_vmap;
+};
+
+ComputeLoopNest MakeComputeLoopNest(
+    const ComputeOpNode* self,
     const Stage& stage,
-    const std::unordered_map<IterVar, Range>& dom_map) const {
-  CHECK_EQ(stage->op.operator->(), this);
-  if (IsCrossThreadReduction(this, stage)) {
-    // specially handle cross thread reduction.
-    return MakeCrossThreadReduction(this, stage, dom_map);
-  }
-  size_t size = this->body.size();
-  Stmt init;
-  Stmt provide;
-  if (this->reduce_axis.size() == 0) {
-    std::vector<Stmt> provides;
-    for (size_t i = 0; i < size; ++i) {
-      provides.emplace_back(MakeProvide(this, stage->op.output(i)));
-    }
-    provide = Block::make(provides);
-  } else {
-    Array<Tensor> source;
-    for (size_t i = 0; i < size; ++i) {
-      source.push_back(stage->op.output(i));
-    }
-    MakeReduction(this, source, &init, &provide);
-  }
-  // make loop nest
-  std::unordered_map<IterVar, Expr> value_map;
-  auto nest = op::MakeLoopNest(
-      stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map);
-  auto preds = op::MakeBoundCheck(
-      stage, dom_map, false, std::unordered_set<IterVar>(), value_map);
-  for (auto& e : preds) e = likely(e);
-  nest.push_back(op::MakeIfNest(preds));
+    const std::unordered_map<IterVar, Range>& dom_map) {
+  CHECK_EQ(stage->op.operator->(), self);
+  ComputeLoopNest ret;
+  // make main loop nest
+  ret.main_nest = op::MakeLoopNest(
+      stage, dom_map, 0, false, std::unordered_set<IterVar>(), &ret.main_vmap);
+  ret.main_predicates = op::MakeBoundCheck(
+      stage, dom_map, false, std::unordered_set<IterVar>(), ret.main_vmap);
+  for (auto& e : ret.main_predicates) {
+    e = likely(e);
+  }
   if (stage->store_predicate.defined()) {
-    nest.emplace_back(op::MakeIfNest({stage->store_predicate}));
+    ret.main_predicates.push_back(stage->store_predicate);
   }
-  provide = Substitute(provide, value_map);
-  if (init.defined()) {
+  if (self->reduce_axis.size() != 0) {
     // try to find the location to insert the initialization.
     // Fuse the initialization and provide loop when possible.
     std::unordered_map<IterVar, int> update_state;
-    for (IterVar iv : this->reduce_axis) {
+    for (IterVar iv : self->reduce_axis) {
       update_state[iv] = 2;
     }
-    for (IterVar iv : this->axis) {
+    for (IterVar iv : self->axis) {
       update_state[iv] = 1;
     }
     // find which iter var is related to reduction and which is related to axis.
     schedule::PassDownBitMaskOr(stage, &update_state);
     auto leaf_iter_vars = stage->leaf_iter_vars;
-    std::unordered_map<IterVar, Expr> init_value_map;
     // first first loop that is related to reduction.
     size_t begin_loop = leaf_iter_vars.size();
     for (size_t i = 0; i < leaf_iter_vars.size(); ++i) {

@@ -431,29 +428,69 @@ Stmt ComputeOpNode::BuildProvide(
       if ((flag & 2) != 0) {
         begin_loop = i; break;
       }
-      init_value_map[iv] = value_map.at(iv);
+      ret.init_vmap[iv] = ret.main_vmap.at(iv);
     }
+    ret.num_common_loop = begin_loop;
     // skip loops that does not relates to axis.
     std::unordered_set<IterVar> skip_iter;
     for (auto kv : update_state) {
       int flag = kv.second;
       if ((flag & 1) == 0) skip_iter.insert(kv.first);
     }
-    auto init_nest = op::MakeLoopNest(
+    ret.init_nest = op::MakeLoopNest(
         stage, dom_map, begin_loop, true,
-        skip_iter, &init_value_map);
-    auto preds = op::MakeBoundCheck(
-        stage, dom_map, true, skip_iter, init_value_map);
-    for (auto& e : preds) e = likely(e);
-    init_nest.push_back(op::MakeIfNest(preds));
-    init = Substitute(init, init_value_map);
-    init = MergeNest(init_nest, init);
+        skip_iter, &(ret.init_vmap));
+    ret.init_predicates = op::MakeBoundCheck(
+        stage, dom_map, true, skip_iter, ret.init_vmap);
+    for (auto& e : ret.init_predicates) {
+      e = likely(e);
+    }
+  } else {
+    ret.num_common_loop = ret.main_nest.size() - 1;
+  }
+  // copy elison here.
+  return ret;
+}
+
+// implement the provide utility.
+Stmt ComputeOpNode::BuildProvide(
+    const Stage& stage,
+    const std::unordered_map<IterVar, Range>& dom_map) const {
+  CHECK_EQ(stage->op.operator->(), this);
+  if (IsCrossThreadReduction(this, stage)) {
+    // specially handle cross thread reduction.
+    return MakeCrossThreadReduction(this, stage, dom_map);
+  }
+  // grab the nest structure
+  ComputeLoopNest n = MakeComputeLoopNest(this, stage, dom_map);
+  // Normal loop structure
+  n.init_nest.emplace_back(op::MakeIfNest(n.init_predicates));
+  n.main_nest.emplace_back(op::MakeIfNest(n.main_predicates));
+  if (this->reduce_axis.size() != 0) {
+    // make reduction.
+    Stmt init, provide;
+    Array<Tensor> source;
+    for (size_t i = 0; i < this->body.size(); ++i) {
+      source.push_back(stage->op.output(i));
+    }
+    MakeReduction(this, source, &init, &provide);
+    init = Substitute(init, n.init_vmap);
+    init = MergeNest(n.init_nest, init);
     // common nest
     std::vector<std::vector<Stmt> > common(
-        nest.begin(), nest.begin() + begin_loop + 1);
+        n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1);
     std::vector<std::vector<Stmt> > reduce(
-        nest.begin() + begin_loop + 1, nest.end());
+        n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.end());
+    provide = Substitute(provide, n.main_vmap);
     provide = MergeNest(reduce, provide);
     return MergeNest(common, Block::make(init, provide));
   } else {
-    return MergeNest(nest, provide);
+    std::vector<Stmt> provides;
+    for (size_t i = 0; i < this->body.size(); ++i) {
+      provides.emplace_back(MakeProvide(this, stage->op.output(i)));
+    }
+    Stmt provide = Substitute(Block::make(provides), n.main_vmap);
+    return MergeNest(n.main_nest, provide);
   }
 }
 }  // namespace tvm
src/op/extern_op.cc

@@ -128,8 +128,26 @@ Stmt ExternOpNode::BuildProvide(
     const Stage& stage,
     const std::unordered_map<IterVar, Range>& dom_map) const {
   CHECK_EQ(stage->op.operator->(), this);
-  return AttrStmt::make(
-      stage->op, ir::attr::extern_op_scope,
-      StringImm::make(name), body);
+  Stmt ret = this->body;
+  auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) {
+    Array<NodeRef> bind_spec;
+    Array<Expr> tuple;
+    bind_spec.push_back(buffer);
+    bind_spec.push_back(tensor);
+    for (size_t k = 0; k < buffer->shape.size(); ++k) {
+      tuple.push_back(make_const(buffer->shape[k].type(), 0));
+      tuple.push_back(buffer->shape[k]);
+    }
+    ret = AttrStmt::make(
+        bind_spec, attr::buffer_bind_scope,
+        Call::make(Handle(), intrinsic::tvm_tuple, tuple, Call::Intrinsic), ret);
+  };
+  for (size_t i = output_placeholders.size(); i != 0; --i) {
+    f_push_bind(output_placeholders[i - 1], stage->op.output(i - 1));
+  }
+  for (size_t i = inputs.size(); i != 0; --i) {
+    f_push_bind(input_placeholders[i - 1], inputs[i - 1]);
+  }
+  return ret;
 }
 }  // namespace tvm
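The tuple that f_push_bind builds for an extern op always binds the whole tensor: one (min, extent) pair of (0, shape[k]) per dimension. A quick Python sketch of that construction (illustrative only):

    shape = (16, 32)
    tuple_args = []
    for extent in shape:
        tuple_args += [0, extent]
    assert tuple_args == [0, 16, 0, 32]   # tvm_tuple(min0, extent0, min1, extent1)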
src/pass/lower_packed_call.cc

@@ -131,9 +131,15 @@ class PackedCallBuilder : public IRMutator {
     prep_seq_.emplace_back(
         TVMStructSet(stack_array_, idx, intrinsic::kArrTypeLanes,
                      make_const(UInt(16), dtype.lanes())));
+    // set byte offset
+    int data_bytes = GetVectorBytes(dtype);
+    Expr byte_offset = op->args[5];
+    if (!is_zero(byte_offset)) {
+      byte_offset = byte_offset * make_const(byte_offset.type(), data_bytes);
+    }
     prep_seq_.emplace_back(
         TVMStructSet(stack_array_, idx, intrinsic::kArrByteOffset,
-                     Convert(Int(64), op->args[5])));
+                     Convert(UInt(64), byte_offset)));
     CHECK(device_type_.defined()) << "Unknown device type in current IR";
     CHECK(device_id_.defined()) << "Unknown device id in current IR";
     prep_seq_.emplace_back(
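The lowering now carries elem_offset through the IR and multiplies by the vector byte width only when filling the DLTensor's byte_offset field, skipping the multiply for a zero offset. A worked Python sketch of that conversion (a stand-in for the pass, not its API):

    def to_byte_offset(elem_offset, bits, lanes):
        data_bytes = (bits * lanes) // 8   # what GetVectorBytes(dtype) returns
        return 0 if elem_offset == 0 else elem_offset * data_bytes

    assert to_byte_offset(8, 32, 1) == 32   # float32, 8 elements -> 32 bytes
    assert to_byte_offset(0, 32, 4) == 0    # zero offset left untouched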
src/pass/make_api.cc

@@ -12,6 +12,7 @@
 #include <unordered_set>
 #include "./ir_util.h"
+#include "../arithmetic/compute_expr.h"
 
 namespace tvm {
 namespace ir {

@@ -222,9 +223,19 @@ LoweredFunc MakeAPI(Stmt body,
       }
     }
     // Byte_offset field.
-    f_push(buf->byte_offset,
-           TVMArrayGet(UInt(64), v_arg, intrinsic::kArrByteOffset),
-           v_arg->name_hint + ".byte_offset");
+    int data_bytes = GetVectorBytes(buf->dtype);
+    int64_t const_offset;
+    if (arith::GetConst(buf->elem_offset, &const_offset)) {
+      f_push(make_const(buf->elem_offset.type(), const_offset * data_bytes),
+             TVMArrayGet(UInt(64), v_arg, intrinsic::kArrByteOffset),
+             v_arg->name_hint + ".byte_offset");
+    } else {
+      f_push(buf->elem_offset,
+             cast(buf->elem_offset.type(),
+                  (TVMArrayGet(UInt(64), v_arg, intrinsic::kArrByteOffset) /
+                   make_const(UInt(64), data_bytes))),
+             v_arg->name_hint + ".elem_offset");
+    }
     // device info.
     f_push(device_id,
            TVMArrayGet(Int(32), v_arg, intrinsic::kArrDeviceId),
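The intent of the two branches, in a small Python analogy (stand-in names; the real code emits IR checks, not Python): a constant elem_offset is pre-multiplied into a byte count and checked against the incoming DLTensor's byte_offset, while a symbolic one is recovered from it by division.

    def bind_offset(runtime_byte_offset, elem_offset, data_bytes):
        if isinstance(elem_offset, int):    # arith::GetConst succeeded
            # constant case: assert the incoming byte offset matches
            assert runtime_byte_offset == elem_offset * data_bytes
            return elem_offset
        # symbolic case: derive elem_offset from the runtime byte offset
        return runtime_byte_offset // data_bytes

    assert bind_offset(32, 8, 4) == 8        # constant float32 offset
    assert bind_offset(32, "sym", 4) == 8    # symbolic: 32 bytes / 4 = 8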
src/pass/storage_flatten.cc

@@ -7,8 +7,8 @@
 #include <tvm/ir_mutator.h>
 #include <tvm/ir_pass.h>
 #include <tvm/buffer.h>
-#include <tvm/operation.h>
 #include <unordered_map>
+#include "../arithmetic/compute_expr.h"
 #include "../runtime/thread_storage_scope.h"
 
 namespace tvm {

@@ -31,9 +31,12 @@ class StorageFlattener : public IRMutator {
   Stmt Mutate_(const Store* op, const Stmt& s) final {
     Stmt stmt = IRMutator::Mutate_(op, s);
     op = stmt.as<Store>();
-    auto it = extern_buf_remap_.find(op->buffer_var.get());
-    if (it != extern_buf_remap_.end()) {
-      return Store::make(it->second, op->value, op->index, op->predicate);
+    auto it = var_remap_.find(op->buffer_var.get());
+    if (it != var_remap_.end() &&
+        !it->second.same_as(op->buffer_var)) {
+      CHECK(it->second.as<Variable>());
+      VarExpr buf_var(it->second.node_);
+      return Store::make(buf_var, op->value, op->index, op->predicate);
     } else {
       return stmt;
     }

@@ -50,8 +53,8 @@ class StorageFlattener : public IRMutator {
       Stmt stmt = IRMutator::Mutate_(op, s);
       curr_thread_scope_.pop_back();
       return stmt;
-    } else if (op->attr_key == attr::extern_op_scope) {
-      return HandleExternOp(op);
+    } else if (op->attr_key == attr::buffer_bind_scope) {
+      return HandleBufferBindScope(op);
     }
     return IRMutator::Mutate_(op, s);
   }

@@ -115,17 +118,20 @@ class StorageFlattener : public IRMutator {
   Expr Mutate_(const Load* op, const Expr& e) final {
     Expr expr = IRMutator::Mutate_(op, e);
     op = expr.as<Load>();
-    auto it = extern_buf_remap_.find(op->buffer_var.get());
-    if (it != extern_buf_remap_.end()) {
-      return Load::make(op->type, it->second, op->index, op->predicate);
+    auto it = var_remap_.find(op->buffer_var.get());
+    if (it != var_remap_.end() &&
+        !it->second.same_as(op->buffer_var)) {
+      CHECK(it->second.as<Variable>());
+      VarExpr buf_var(it->second.node_);
+      return Load::make(op->type, buf_var, op->index, op->predicate);
     } else {
       return expr;
     }
   }
 
   Expr Mutate_(const Variable* op, const Expr& e) final {
-    auto it = extern_buf_remap_.find(op);
-    if (it != extern_buf_remap_.end()) {
+    auto it = var_remap_.find(op);
+    if (it != var_remap_.end()) {
       return it->second;
     } else {
       return e;

@@ -150,35 +156,115 @@ class StorageFlattener : public IRMutator {
   }
 
  private:
-  Stmt HandleExternOp(const AttrStmt* op) {
-    const ExternOpNode* ext_op = op->node.as<ExternOpNode>();
-    CHECK(ext_op);
-    Operation func(op->node.node_);
-    CHECK_EQ(extern_buf_remap_.size(), 0U);
-    for (size_t i = 0; i < ext_op->output_placeholders.size(); ++i) {
-      TensorKey key{func, static_cast<int>(i)};
-      CHECK(buf_map_.count(key))
-          << "Cannot find allocated buffer for " << key.f
-          << "(" << key.value_index << ")";
-      extern_buf_remap_[ext_op->output_placeholders[i]->data.get()] =
-          buf_map_.at(key).buffer->data;
-    }
-    for (size_t i = 0; i < ext_op->inputs.size(); ++i) {
-      TensorKey key{ext_op->inputs[i]->op, ext_op->inputs[i]->value_index};
-      CHECK(buf_map_.count(key));
-      extern_buf_remap_[ext_op->input_placeholders[i]->data.get()] =
-          buf_map_.at(key).buffer->data;
-    }
-    Stmt ret = Mutate(op->body);
-    extern_buf_remap_.clear();
-    return ret;
-  }
+  // Bind the symbol sym to value if it is a Variable
+  // send a sequence of asserts if it is a constant constrant.
+  // hint_name: used for error message
+  // add_keys: a list of newly binded keys
+  // add_asserts: a list of asserts during the bind
+  void BindSymbol(Expr sym,
+                  Expr value,
+                  std::string hint_name,
+                  std::vector<const Variable*>* add_keys,
+                  std::vector<Stmt>* add_asserts) {
+    if (const Variable* v = sym.as<Variable>()) {
+      auto it = var_remap_.find(v);
+      if (it == var_remap_.end()) {
+        add_keys->push_back(v);
+        var_remap_[v] = value;
+        return;
+      }
+    }
+    // add assertions
+    std::ostringstream os;
+    os << "BufferBind constaint fail " << hint_name;
+    add_asserts->emplace_back(
+        AssertStmt::make(sym == value, os.str()));
+  }
+  // Start bind
+  Stmt HandleBufferBindScope(const AttrStmt* op) {
+    Array<NodeRef> arr(op->node.node_);
+    CHECK_EQ(arr.size(), 2U);
+    const BufferNode* buffer = arr[0].as<BufferNode>();
+    const TensorNode* tensor = arr[1].as<TensorNode>();
+    const Call* tuple = op->value.as<Call>();
+    CHECK(buffer && tensor);
+    CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple));
+    TensorKey key{tensor->op, tensor->value_index};
+    CHECK(buf_map_.count(key));
+    const BufferEntry& be = buf_map_.at(key);
+    CHECK(!be.released);
+    CHECK_EQ(tuple->args.size(), be.buffer->shape.size() * 2);
+    Array<Expr> begins, extents;
+    if (be.bounds.size() != 0) {
+      CHECK_EQ(tuple->args.size(), be.bounds.size() * 2);
+      for (size_t i = 0; i < be.buffer->shape.size(); ++i) {
+        begins.push_back(
+            arith::ComputeExpr<Sub>(tuple->args[2 * i], be.bounds[i]->min));
+        extents.push_back(tuple->args[2 * i + 1]);
+      }
+    } else {
+      for (size_t i = 0; i < tuple->args.size(); i += 2) {
+        begins.push_back(tuple->args[i]);
+        extents.push_back(tuple->args[i + 1]);
+      }
+    }
+    Buffer slice = be.buffer.MakeSlice(begins, extents);
+    if (buffer->strides.size() == 0) {
+      CHECK_EQ(slice->strides.size(), 0U)
+          << "Trying to bind compact buffer to strided one";
+    } else {
+      slice = slice.MakeStrideView();
+    }
+    CHECK_EQ(slice->strides.size(), buffer->strides.size());
+    // start binding
+    std::vector<const Variable*> keys;
+    std::vector<Stmt> asserts;
+    BindSymbol(buffer->data, slice->data,
+               buffer->name + ".data",
+               &keys, &asserts);
+    for (size_t i = 0; i < buffer->shape.size(); ++i) {
+      std::ostringstream field_name;
+      field_name << buffer->name << ".shape[" << i << ']';
+      BindSymbol(buffer->shape[i], slice->shape[i],
+                 field_name.str(),
+                 &keys, &asserts);
+    }
+    for (size_t i = 0; i < buffer->strides.size(); ++i) {
+      std::ostringstream field_name;
+      field_name << buffer->name << ".strides[" << i << ']';
+      BindSymbol(buffer->strides[i], slice->strides[i],
+                 field_name.str(),
+                 &keys, &asserts);
+    }
+    BindSymbol(buffer->elem_offset, slice->elem_offset,
+               buffer->name + ".elem_offset",
+               &keys, &asserts);
+    CHECK_EQ(buffer->scope, slice->scope)
+        << "Buffer bind scope mismatch";
+    // Apply the remaps
+    Stmt body = this->Mutate(op->body);
+    for (size_t i = 0; i < asserts.size(); ++i) {
+      Stmt ret = Simplify(this->Mutate(asserts[i]));
+      if (const AssertStmt* assert_op = ret.as<AssertStmt>()) {
+        if (!is_zero(assert_op->condition)) {
+          body = Block::make(ret, body);
+        } else {
+          LOG(FATAL) << "BindBuffer have unmet assertion: " << ret;
+        }
+      }
+    }
+    // remove the binds
+    for (const Variable* op : keys) {
+      var_remap_.erase(op);
+    }
+    return body;
+  }
   // The buffer entry in the flatten map
   struct BufferEntry {
     // the buffer of storage
     Buffer buffer;
-    // the bounds of realization, can be null
+    // the bounds of realization, can be null, means everything
     Region bounds;
     // Whether the buffer is external
     bool external{false};

@@ -200,7 +286,9 @@ class StorageFlattener : public IRMutator {
     }
   };
   // The buffer assignment map
-  std::unordered_map<const Variable*, Var> extern_buf_remap_;
+  // Variable remap
+  std::unordered_map<const Variable*, Expr> var_remap_;
+  // Buffer map
   std::unordered_map<TensorKey, BufferEntry> buf_map_;
   std::unordered_map<const Node*, std::string> storage_scope_;
   // The current thread scope.
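BindSymbol's fallback behavior, reduced to a runnable Python miniature (strings stand in for Variable nodes; everything here is illustrative, not the pass's API): a fresh variable binds directly into the remap table, while anything else becomes an equality assertion checked at the bind site.

    var_remap, asserts = {}, []

    def bind_symbol(sym, value, hint):
        # a fresh Variable binds directly; anything else turns into
        # an equality assertion emitted at the bind site
        if isinstance(sym, str) and sym not in var_remap:
            var_remap[sym] = value
            return
        asserts.append((sym, value, hint))

    bind_symbol("n", 128, "A.shape[0]")   # variable: remapped
    bind_symbol(64, 64, "A.strides[0]")   # constant: asserted
    assert var_remap == {"n": 128} and len(asserts) == 1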
tests/python/unittest/test_codegen_llvm.py

@@ -14,8 +14,13 @@ def test_llvm_add_pipeline():
     def check_llvm():
         if not tvm.module.enabled("llvm"):
             return
+        # Specifically allow offset to test codepath when offset is available
+        Ab = tvm.decl_buffer(
+            A.shape, A.dtype,
+            elem_offset=tvm.var('Aoffset'),
+            name='A')
+        binds = {A : Ab}
         # build and invoke the kernel.
-        f = tvm.build(s, [A, B, C], "llvm")
+        f = tvm.build(s, [Ab, B, C], "llvm", binds=binds)
         ctx = tvm.cpu(0)
         # launch the kernel.
         n = nn

@@ -25,6 +30,7 @@ def test_llvm_add_pipeline():
         f(a, b, c)
         np.testing.assert_allclose(
             c.asnumpy(), a.asnumpy() + b.asnumpy())
+
     check_llvm()
tests/python/unittest/test_lang_tensor.py

@@ -168,7 +168,14 @@ def test_tuple_with_different_deps():
     assert stmt.node == C.op and len(ret) == 1
 
+
+def test_tensor_inputs():
+    x = tvm.placeholder((1,), name='x')
+    y = tvm.compute(x.shape, lambda i: x[i] + x[i])
+    assert tuple(y.op.input_tensors) == (x,)
+
+
 if __name__ == "__main__":
+    test_tensor_inputs()
     test_tensor_reduce_multi_axis()
     test_conv1d()
     test_tensor_slice()
tests/python/unittest/test_schedule_schedule_ops.py

@@ -72,6 +72,17 @@ def test_auto_inline():
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
 
+
+def test_schedule_const_bound():
+    n = 128
+    A = tvm.placeholder((n,), name='A')
+    A1 = tvm.compute((n,), lambda i: A[i] + 1, name='A1')
+    s = tvm.create_schedule(A1.op)
+    xo, xi = s[A1].split(A1.op.axis[0], 8)
+    bounds = tvm.schedule.InferBound(s)
+    assert isinstance(bounds, tvm.collections.Map)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
+
 def test_inline_mixed():
     n = tvm.var('n')
     A = tvm.placeholder((n, ), name='A')

@@ -150,6 +161,7 @@ def test_schedule_cache():
 if __name__ == "__main__":
+    test_schedule_const_bound()
     test_scan_inline1()
     test_scan_inline2()
     test_inline_mixed()