Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
00506a62
Commit
00506a62
authored
Jul 05, 2017
by
Tianqi Chen
Committed by
GitHub
Jul 05, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[IR] Add body to AssertStmt (#220)
* [IR] Add body to AssertStmt * fix lint
parent
c9da7254
Hide whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
135 additions
and
44 deletions
+135
-44
HalideIR
+1
-1
include/tvm/arithmetic.h
+2
-2
include/tvm/buffer.h
+17
-5
include/tvm/ir.h
+1
-1
python/tvm/api.py
+15
-4
python/tvm/build.py
+9
-2
src/api/api_ir.cc
+1
-1
src/api/api_lang.cc
+2
-1
src/codegen/codegen_c.cc
+1
-0
src/codegen/llvm/codegen_llvm.cc
+25
-0
src/codegen/stack_vm/codegen_stack_vm.cc
+1
-0
src/lang/buffer.cc
+14
-5
src/pass/arg_binder.cc
+25
-11
src/pass/ir_deep_compare.cc
+1
-1
src/pass/ir_mutator.cc
+5
-2
src/pass/ir_util.cc
+4
-1
src/pass/ir_visitor.cc
+1
-0
src/pass/make_api.cc
+4
-4
tests/python/unittest/test_codegen_llvm.py
+6
-3
No files found.
HalideIR
@
36ecc1ee
Subproject commit
860199eea031a4ea694b8fce03ad0bf8127910ac
Subproject commit
36ecc1eec0898411ae70e98c315b03247d5fb4a0
include/tvm/arithmetic.h
View file @
00506a62
...
@@ -119,9 +119,9 @@ class IntSet : public NodeRef {
...
@@ -119,9 +119,9 @@ class IntSet : public NodeRef {
*/
*/
struct
ModularEntry
{
struct
ModularEntry
{
/*! \brief The base */
/*! \brief The base */
int
base
;
int
base
{
0
}
;
/*! \brief linear co-efficient */
/*! \brief linear co-efficient */
int
coeff
;
int
coeff
{
1
}
;
/*! \return entry represent everything */
/*! \return entry represent everything */
static
ModularEntry
everything
()
{
static
ModularEntry
everything
()
{
...
...
include/tvm/buffer.h
View file @
00506a62
...
@@ -68,7 +68,10 @@ class Buffer : public NodeRef {
...
@@ -68,7 +68,10 @@ class Buffer : public NodeRef {
class
BufferNode
:
public
Node
{
class
BufferNode
:
public
Node
{
public
:
public
:
// Data fields.
// Data fields.
/*! \brief The pointer to the head of the data */
/*!
* \brief The pointer to the head of the data
* \sa data_alignment The alignment of data in bytes.
*/
Var
data
;
Var
data
;
/*! \brief data type in the content of the tensor */
/*! \brief data type in the content of the tensor */
Type
dtype
;
Type
dtype
;
...
@@ -86,8 +89,13 @@ class BufferNode : public Node {
...
@@ -86,8 +89,13 @@ class BufferNode : public Node {
std
::
string
name
;
std
::
string
name
;
/*! \brief storage scope of the buffer, if other than global */
/*! \brief storage scope of the buffer, if other than global */
std
::
string
scope
;
std
::
string
scope
;
/*! \brief Alignment multiple in terms of dtype elements (including lanes) */
/*! \brief Alignment requirement of data pointer in bytes. */
int
offset_alignment
;
int
data_alignment
;
/*!
* \brief Factor of elem_offset field,
* elem_offset is guaranteed to be multiple of offset_factor.
*/
int
offset_factor
;
/*! \brief constructor */
/*! \brief constructor */
BufferNode
()
{}
BufferNode
()
{}
...
@@ -99,9 +107,12 @@ class BufferNode : public Node {
...
@@ -99,9 +107,12 @@ class BufferNode : public Node {
v
->
Visit
(
"elem_offset"
,
&
elem_offset
);
v
->
Visit
(
"elem_offset"
,
&
elem_offset
);
v
->
Visit
(
"name"
,
&
name
);
v
->
Visit
(
"name"
,
&
name
);
v
->
Visit
(
"scope"
,
&
scope
);
v
->
Visit
(
"scope"
,
&
scope
);
v
->
Visit
(
"offset_alignment"
,
&
offset_alignment
);
v
->
Visit
(
"data_alignment"
,
&
data_alignment
);
v
->
Visit
(
"offset_factor"
,
&
offset_factor
);
}
}
// User can specify data_alignment and offset_factor to be 0
// A default value will be picked.
static
Buffer
make
(
Var
ptr
,
static
Buffer
make
(
Var
ptr
,
Type
dtype
,
Type
dtype
,
Array
<
Expr
>
shape
,
Array
<
Expr
>
shape
,
...
@@ -109,7 +120,8 @@ class BufferNode : public Node {
...
@@ -109,7 +120,8 @@ class BufferNode : public Node {
Expr
byte_offset
,
Expr
byte_offset
,
std
::
string
name
,
std
::
string
name
,
std
::
string
scope
,
std
::
string
scope
,
int
offset_alignment
);
int
data_alignment
,
int
offset_factor
);
static
constexpr
const
char
*
_type_key
=
"Buffer"
;
static
constexpr
const
char
*
_type_key
=
"Buffer"
;
TVM_DECLARE_NODE_TYPE_INFO
(
BufferNode
,
Node
);
TVM_DECLARE_NODE_TYPE_INFO
(
BufferNode
,
Node
);
...
...
include/tvm/ir.h
View file @
00506a62
...
@@ -135,7 +135,7 @@ struct TensorKey {
...
@@ -135,7 +135,7 @@ struct TensorKey {
}
}
};
};
/*! \brief namespace of possible attribute sin AttrStmt.
type
_key */
/*! \brief namespace of possible attribute sin AttrStmt.
attr
_key */
namespace
attr
{
namespace
attr
{
// The above attr does not pass to ir stage.
// The above attr does not pass to ir stage.
/*! \brief Mark launching extent of thread, used by device API. */
/*! \brief Mark launching extent of thread, used by device API. */
...
...
python/tvm/api.py
View file @
00506a62
...
@@ -390,7 +390,8 @@ def decl_buffer(shape,
...
@@ -390,7 +390,8 @@ def decl_buffer(shape,
strides
=
None
,
strides
=
None
,
elem_offset
=
None
,
elem_offset
=
None
,
scope
=
""
,
scope
=
""
,
offset_alignment
=
0
):
data_alignment
=
0
,
offset_factor
=
0
):
"""Decleare a new symbolic buffer.
"""Decleare a new symbolic buffer.
Normally buffer is created automatically during lower and build.
Normally buffer is created automatically during lower and build.
...
@@ -423,8 +424,15 @@ def decl_buffer(shape,
...
@@ -423,8 +424,15 @@ def decl_buffer(shape,
The storage scope of the buffer, if not global.
The storage scope of the buffer, if not global.
If scope equals empty string, it means it is global memory.
If scope equals empty string, it means it is global memory.
offset_alignment: int, optional
data_alignment: int, optional
The alignment of offset
The alignment of data pointer in bytes.
If 0 is passed, the alignment will be set to TVM's internal default.
offset_factor: int, optional
The factor of elem_offset field, when set,
elem_offset is required to be multiple of offset_factor.
If 0 is pssed, the alignment will be set to 1.
if non-zero is passed, we will created a Var for elem_offset if elem_offset is not None.
Returns
Returns
-------
-------
...
@@ -447,11 +455,14 @@ def decl_buffer(shape,
...
@@ -447,11 +455,14 @@ def decl_buffer(shape,
shape
=
(
shape
,)
if
isinstance
(
shape
,
(
_expr
.
Expr
,
_Integral
))
else
shape
shape
=
(
shape
,)
if
isinstance
(
shape
,
(
_expr
.
Expr
,
_Integral
))
else
shape
dtype
=
float32
if
dtype
is
None
else
dtype
dtype
=
float32
if
dtype
is
None
else
dtype
strides
=
()
if
strides
is
None
else
strides
strides
=
()
if
strides
is
None
else
strides
if
offset_factor
!=
0
and
elem_offset
is
None
:
elem_offset
=
var
(
'
%
s_elem_offset'
%
name
,
shape
[
0
]
.
dtype
)
if
data
is
None
:
if
data
is
None
:
data
=
var
(
name
,
"handle"
)
data
=
var
(
name
,
"handle"
)
return
_api_internal
.
_Buffer
(
return
_api_internal
.
_Buffer
(
data
,
dtype
,
shape
,
strides
,
elem_offset
,
name
,
scope
,
offset_alignment
)
data
,
dtype
,
shape
,
strides
,
elem_offset
,
name
,
scope
,
data_alignment
,
offset_factor
)
def
_IterVar
(
dom
,
name
,
iter_type
,
thread_tag
=
''
):
def
_IterVar
(
dom
,
name
,
iter_type
,
thread_tag
=
''
):
...
...
python/tvm/build.py
View file @
00506a62
...
@@ -26,7 +26,8 @@ class BuildConfig(object):
...
@@ -26,7 +26,8 @@ class BuildConfig(object):
'auto_unroll_max_step'
:
0
,
'auto_unroll_max_step'
:
0
,
'auto_unroll_min_depth'
:
1
,
'auto_unroll_min_depth'
:
1
,
'unroll_explicit'
:
True
,
'unroll_explicit'
:
True
,
'detect_global_barrier'
:
False
'detect_global_barrier'
:
False
,
'offset_factor'
:
0
}
}
def
__init__
(
self
,
**
kwargs
):
def
__init__
(
self
,
**
kwargs
):
self
.
_old_scope
=
None
self
.
_old_scope
=
None
...
@@ -76,6 +77,10 @@ def build_config(**kwargs):
...
@@ -76,6 +77,10 @@ def build_config(**kwargs):
detect_global_barrier: bool, default=True
detect_global_barrier: bool, default=True
Whether detect global barrier.
Whether detect global barrier.
offset_factor: int, default=0
The factor used in default buffer declaration.
If specified as 0, offset field is not used.
Returns
Returns
-------
-------
config: BuildConfig
config: BuildConfig
...
@@ -105,10 +110,12 @@ def get_binds(args, binds=None):
...
@@ -105,10 +110,12 @@ def get_binds(args, binds=None):
The list of symbolic buffers of arguments.
The list of symbolic buffers of arguments.
"""
"""
binds
=
{}
if
binds
is
None
else
binds
.
copy
()
binds
=
{}
if
binds
is
None
else
binds
.
copy
()
offset_factor
=
BuildConfig
.
current
.
offset_factor
arg_list
=
[]
arg_list
=
[]
for
x
in
args
:
for
x
in
args
:
if
isinstance
(
x
,
tensor
.
Tensor
):
if
isinstance
(
x
,
tensor
.
Tensor
):
buf
=
api
.
decl_buffer
(
x
.
shape
,
dtype
=
x
.
dtype
,
name
=
x
.
name
)
buf
=
api
.
decl_buffer
(
x
.
shape
,
dtype
=
x
.
dtype
,
name
=
x
.
name
,
offset_factor
=
offset_factor
)
assert
x
not
in
binds
assert
x
not
in
binds
binds
[
x
]
=
buf
binds
[
x
]
=
buf
arg_list
.
append
(
buf
)
arg_list
.
append
(
buf
)
...
...
src/api/api_ir.cc
View file @
00506a62
...
@@ -143,7 +143,7 @@ REGISTER_MAKE2(Cast);
...
@@ -143,7 +143,7 @@ REGISTER_MAKE2(Cast);
REGISTER_MAKE2
(
Broadcast
);
REGISTER_MAKE2
(
Broadcast
);
REGISTER_MAKE3
(
Let
);
REGISTER_MAKE3
(
Let
);
REGISTER_MAKE3
(
LetStmt
);
REGISTER_MAKE3
(
LetStmt
);
REGISTER_MAKE
2
(
AssertStmt
);
REGISTER_MAKE
3
(
AssertStmt
);
REGISTER_MAKE3
(
ProducerConsumer
);
REGISTER_MAKE3
(
ProducerConsumer
);
REGISTER_MAKE5
(
Allocate
);
REGISTER_MAKE5
(
Allocate
);
REGISTER_MAKE4
(
Provide
);
REGISTER_MAKE4
(
Provide
);
...
...
src/api/api_lang.cc
View file @
00506a62
...
@@ -152,7 +152,8 @@ TVM_REGISTER_API("_Buffer")
...
@@ -152,7 +152,8 @@ TVM_REGISTER_API("_Buffer")
args
[
4
],
args
[
4
],
args
[
5
],
args
[
5
],
args
[
6
],
args
[
6
],
args
[
7
]);
args
[
7
],
args
[
8
]);
});
});
TVM_REGISTER_API
(
"_Tensor"
)
TVM_REGISTER_API
(
"_Tensor"
)
...
...
src/codegen/codegen_c.cc
View file @
00506a62
...
@@ -724,6 +724,7 @@ void CodeGenC::VisitStmt_(const AssertStmt* op) {
...
@@ -724,6 +724,7 @@ void CodeGenC::VisitStmt_(const AssertStmt* op) {
}
else
{
}
else
{
stream
<<
"assert("
<<
cond
<<
");
\n
"
;
stream
<<
"assert("
<<
cond
<<
");
\n
"
;
}
}
this
->
PrintStmt
(
op
->
body
);
}
}
void
CodeGenC
::
VisitStmt_
(
const
For
*
op
)
{
void
CodeGenC
::
VisitStmt_
(
const
For
*
op
)
{
...
...
src/codegen/llvm/codegen_llvm.cc
View file @
00506a62
...
@@ -1377,6 +1377,31 @@ void CodeGenLLVM::VisitStmt_(const AssertStmt* op) {
...
@@ -1377,6 +1377,31 @@ void CodeGenLLVM::VisitStmt_(const AssertStmt* op) {
builder_
->
CreateRet
(
llvm
::
ConstantInt
::
getSigned
(
t_int32_
,
-
1
));
builder_
->
CreateRet
(
llvm
::
ConstantInt
::
getSigned
(
t_int32_
,
-
1
));
// otherwise set it to be new end.
// otherwise set it to be new end.
builder_
->
SetInsertPoint
(
end_block
);
builder_
->
SetInsertPoint
(
end_block
);
// Detect useful invariant pattern and use them to visit child.
// Pattern: Var % const == 0
// TODO(tqchen) move these pattern to a generic scope info visitor.
if
(
const
EQ
*
eq
=
op
->
condition
.
as
<
EQ
>
())
{
const
Mod
*
mod
=
eq
->
a
.
as
<
Mod
>
();
int64_t
factor
,
offset
;
if
(
mod
&&
arith
::
GetConst
(
eq
->
b
,
&
offset
))
{
const
Variable
*
var
=
mod
->
a
.
as
<
Variable
>
();
if
(
var
&&
arith
::
GetConst
(
mod
->
b
,
&
factor
))
{
arith
::
ModularEntry
old
=
align_map_
[
var
];
if
(
factor
>
old
.
coeff
)
{
arith
::
ModularEntry
e
;
e
.
coeff
=
static_cast
<
int
>
(
factor
);
e
.
base
=
static_cast
<
int
>
(
offset
);
// new alignment info,
align_map_
[
var
]
=
e
;
this
->
VisitStmt
(
op
->
body
);
// restore old info
align_map_
[
var
]
=
old
;
return
;
}
}
}
}
this
->
VisitStmt
(
op
->
body
);
}
}
void
CodeGenLLVM
::
VisitStmt_
(
const
LetStmt
*
op
)
{
void
CodeGenLLVM
::
VisitStmt_
(
const
LetStmt
*
op
)
{
...
...
src/codegen/stack_vm/codegen_stack_vm.cc
View file @
00506a62
...
@@ -456,6 +456,7 @@ void CodeGenStackVM::VisitStmt_(const AssertStmt *op) {
...
@@ -456,6 +456,7 @@ void CodeGenStackVM::VisitStmt_(const AssertStmt *op) {
this
->
Push
(
op
->
condition
);
this
->
Push
(
op
->
condition
);
this
->
PushOp
(
StackVM
::
ASSERT
,
sid
);
this
->
PushOp
(
StackVM
::
ASSERT
,
sid
);
}
}
this
->
Push
(
op
->
body
);
}
}
void
CodeGenStackVM
::
VisitStmt_
(
const
AttrStmt
*
op
)
{
void
CodeGenStackVM
::
VisitStmt_
(
const
AttrStmt
*
op
)
{
...
...
src/lang/buffer.cc
View file @
00506a62
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
* \file buffer.cc
* \file buffer.cc
*/
*/
#include <tvm/buffer.h>
#include <tvm/buffer.h>
#include <tvm/runtime/device_api.h>
#include <tvm/ir.h>
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_pass.h>
...
@@ -26,7 +27,9 @@ Buffer decl_buffer(Array<Expr> shape,
...
@@ -26,7 +27,9 @@ Buffer decl_buffer(Array<Expr> shape,
shape
,
shape
,
Array
<
Expr
>
(),
Array
<
Expr
>
(),
Expr
(),
Expr
(),
name
,
""
,
0
);
name
,
""
,
0
,
0
);
}
}
// The buffer offset in convention of number of elements of
// The buffer offset in convention of number of elements of
...
@@ -124,6 +127,7 @@ Buffer Buffer::MakeSlice(Array<Expr> begins, Array<Expr> extents) const {
...
@@ -124,6 +127,7 @@ Buffer Buffer::MakeSlice(Array<Expr> begins, Array<Expr> extents) const {
elem_offset
,
elem_offset
,
n
->
name
+
"_slice"
,
n
->
name
+
"_slice"
,
n
->
scope
,
n
->
scope
,
n
->
data_alignment
,
0
);
0
);
}
}
...
@@ -134,7 +138,8 @@ Buffer BufferNode::make(Var data,
...
@@ -134,7 +138,8 @@ Buffer BufferNode::make(Var data,
Expr
elem_offset
,
Expr
elem_offset
,
std
::
string
name
,
std
::
string
name
,
std
::
string
scope
,
std
::
string
scope
,
int
offset_alignment
)
{
int
data_alignment
,
int
offset_factor
)
{
auto
n
=
std
::
make_shared
<
BufferNode
>
();
auto
n
=
std
::
make_shared
<
BufferNode
>
();
n
->
data
=
std
::
move
(
data
);
n
->
data
=
std
::
move
(
data
);
n
->
dtype
=
dtype
;
n
->
dtype
=
dtype
;
...
@@ -145,11 +150,15 @@ Buffer BufferNode::make(Var data,
...
@@ -145,11 +150,15 @@ Buffer BufferNode::make(Var data,
if
(
!
elem_offset
.
defined
())
{
if
(
!
elem_offset
.
defined
())
{
elem_offset
=
make_const
(
n
->
shape
[
0
].
type
(),
0
);
elem_offset
=
make_const
(
n
->
shape
[
0
].
type
(),
0
);
}
}
if
(
offset_alignment
==
0
)
{
if
(
data_alignment
==
0
)
{
offset_alignment
=
1
;
data_alignment
=
runtime
::
kAllocAlignment
;
}
if
(
offset_factor
==
0
)
{
offset_factor
=
1
;
}
}
n
->
elem_offset
=
elem_offset
;
n
->
elem_offset
=
elem_offset
;
n
->
offset_alignment
=
offset_alignment
;
n
->
data_alignment
=
data_alignment
;
n
->
offset_factor
=
offset_factor
;
return
Buffer
(
n
);
return
Buffer
(
n
);
}
}
...
...
src/pass/arg_binder.cc
View file @
00506a62
...
@@ -24,7 +24,7 @@ void BinderAddAssert(Expr cond,
...
@@ -24,7 +24,7 @@ void BinderAddAssert(Expr cond,
if
(
!
is_one
(
cond
))
{
if
(
!
is_one
(
cond
))
{
std
::
ostringstream
os
;
std
::
ostringstream
os
;
os
<<
"Argument "
<<
arg_name
<<
" has an unsatisfied constraint"
;
os
<<
"Argument "
<<
arg_name
<<
" has an unsatisfied constraint"
;
asserts
->
emplace_back
(
AssertStmt
::
make
(
cond
,
os
.
str
()));
asserts
->
emplace_back
(
AssertStmt
::
make
(
cond
,
os
.
str
()
,
Evaluate
::
make
(
0
)
));
}
}
}
}
...
@@ -107,7 +107,14 @@ void ArgBinder::BindBuffer(const Buffer& arg,
...
@@ -107,7 +107,14 @@ void ArgBinder::BindBuffer(const Buffer& arg,
this
->
BindArray
(
arg
->
shape
,
value
->
shape
,
arg_name
+
".shape"
);
this
->
BindArray
(
arg
->
shape
,
value
->
shape
,
arg_name
+
".shape"
);
this
->
BindArray
(
arg
->
strides
,
value
->
strides
,
arg_name
+
".strides"
);
this
->
BindArray
(
arg
->
strides
,
value
->
strides
,
arg_name
+
".strides"
);
}
}
this
->
Bind
(
arg
->
elem_offset
,
value
->
elem_offset
,
arg_name
+
".elem_offset"
);
if
(
Bind_
(
arg
->
elem_offset
,
value
->
elem_offset
,
arg_name
+
".elem_offset"
,
false
))
{
if
(
arg
->
offset_factor
>
1
)
{
Expr
offset
=
value
->
elem_offset
;
Expr
factor
=
make_const
(
offset
.
type
(),
arg
->
offset_factor
);
Expr
zero
=
make_zero
(
offset
.
type
());
BinderAddAssert
(
offset
%
factor
==
zero
,
arg_name
+
".elem_offset"
,
&
asserts_
);
}
}
}
}
inline
Expr
TVMArrayGet
(
Type
t
,
Var
arr
,
intrinsic
::
TVMStructFieldKind
kind
)
{
inline
Expr
TVMArrayGet
(
Type
t
,
Var
arr
,
intrinsic
::
TVMStructFieldKind
kind
)
{
...
@@ -117,7 +124,7 @@ inline Expr TVMArrayGet(Type t, Var arr, intrinsic::TVMStructFieldKind kind) {
...
@@ -117,7 +124,7 @@ inline Expr TVMArrayGet(Type t, Var arr, intrinsic::TVMStructFieldKind kind) {
inline
Stmt
AssertNull
(
Var
handle
,
std
::
string
msg
)
{
inline
Stmt
AssertNull
(
Var
handle
,
std
::
string
msg
)
{
return
AssertStmt
::
make
(
Call
::
make
(
return
AssertStmt
::
make
(
Call
::
make
(
Bool
(
1
),
intrinsic
::
tvm_handle_is_null
,
Bool
(
1
),
intrinsic
::
tvm_handle_is_null
,
{
handle
},
Call
::
PureIntrinsic
),
msg
);
{
handle
},
Call
::
PureIntrinsic
),
msg
,
Evaluate
::
make
(
0
)
);
}
}
void
ArgBinder
::
BindDLTensor
(
const
Buffer
&
buffer
,
void
ArgBinder
::
BindDLTensor
(
const
Buffer
&
buffer
,
...
@@ -136,7 +143,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
...
@@ -136,7 +143,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
ndim_err_msg
<<
arg_name
ndim_err_msg
<<
arg_name
<<
".ndim is expected to equal "
<<
".ndim is expected to equal "
<<
buffer
->
shape
.
size
();
<<
buffer
->
shape
.
size
();
asserts_
.
emplace_back
(
AssertStmt
::
make
(
a_ndim
==
v_ndim
,
ndim_err_msg
.
str
()));
asserts_
.
emplace_back
(
AssertStmt
::
make
(
a_ndim
==
v_ndim
,
ndim_err_msg
.
str
()
,
nop
));
// type checks
// type checks
Type
dtype
=
buffer
->
dtype
;
Type
dtype
=
buffer
->
dtype
;
std
::
ostringstream
type_err_msg
;
std
::
ostringstream
type_err_msg
;
...
@@ -147,7 +154,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
...
@@ -147,7 +154,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
UIntImm
::
make
(
UInt
(
8
),
dtype
.
bits
())
&&
UIntImm
::
make
(
UInt
(
8
),
dtype
.
bits
())
&&
TVMArrayGet
(
UInt
(
16
),
handle
,
intrinsic
::
kArrTypeLanes
)
==
TVMArrayGet
(
UInt
(
16
),
handle
,
intrinsic
::
kArrTypeLanes
)
==
UIntImm
::
make
(
UInt
(
16
),
dtype
.
lanes
()));
UIntImm
::
make
(
UInt
(
16
),
dtype
.
lanes
()));
asserts_
.
emplace_back
(
AssertStmt
::
make
(
cond
,
type_err_msg
.
str
()));
asserts_
.
emplace_back
(
AssertStmt
::
make
(
cond
,
type_err_msg
.
str
()
,
nop
));
// data field
// data field
if
(
Bind_
(
buffer
->
data
,
TVMArrayGet
(
Handle
(),
handle
,
intrinsic
::
kArrData
),
if
(
Bind_
(
buffer
->
data
,
TVMArrayGet
(
Handle
(),
handle
,
intrinsic
::
kArrData
),
arg_name
+
".data"
,
true
))
{
arg_name
+
".data"
,
true
))
{
...
@@ -156,7 +163,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
...
@@ -156,7 +163,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
// mark alignment of external bufs
// mark alignment of external bufs
init_nest_
.
emplace_back
(
AttrStmt
::
make
(
init_nest_
.
emplace_back
(
AttrStmt
::
make
(
vptr
,
ir
::
attr
::
storage_alignment
,
vptr
,
ir
::
attr
::
storage_alignment
,
IntImm
::
make
(
Int
(
32
),
runtime
::
kAllocA
lignment
),
nop
));
IntImm
::
make
(
Int
(
32
),
buffer
->
data_a
lignment
),
nop
));
}
}
Var
v_shape
(
arg_name
+
".shape"
,
Handle
());
Var
v_shape
(
arg_name
+
".shape"
,
Handle
());
...
@@ -202,11 +209,18 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
...
@@ -202,11 +209,18 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
TVMArrayGet
(
UInt
(
64
),
handle
,
intrinsic
::
kArrByteOffset
),
TVMArrayGet
(
UInt
(
64
),
handle
,
intrinsic
::
kArrByteOffset
),
arg_name
+
".byte_offset"
,
true
);
arg_name
+
".byte_offset"
,
true
);
}
else
{
}
else
{
Bind_
(
buffer
->
elem_offset
,
if
(
Bind_
(
buffer
->
elem_offset
,
cast
(
buffer
->
elem_offset
.
type
(),
cast
(
buffer
->
elem_offset
.
type
(),
(
TVMArrayGet
(
UInt
(
64
),
handle
,
intrinsic
::
kArrByteOffset
)
/
(
TVMArrayGet
(
UInt
(
64
),
handle
,
intrinsic
::
kArrByteOffset
)
/
make_const
(
UInt
(
64
),
data_bytes
))),
make_const
(
UInt
(
64
),
data_bytes
))),
arg_name
+
".elem_offset"
,
true
);
arg_name
+
".elem_offset"
,
true
))
{
if
(
buffer
->
offset_factor
>
1
)
{
Expr
offset
=
buffer
->
elem_offset
;
Expr
factor
=
make_const
(
offset
.
type
(),
buffer
->
offset_factor
);
Expr
zero
=
make_zero
(
offset
.
type
());
BinderAddAssert
(
offset
%
factor
==
zero
,
arg_name
+
".elem_offset"
,
&
asserts_
);
}
}
}
}
// device info.
// device info.
Bind_
(
device_type
,
Bind_
(
device_type
,
...
...
src/pass/ir_deep_compare.cc
View file @
00506a62
...
@@ -118,6 +118,7 @@ class IRDeepCompare :
...
@@ -118,6 +118,7 @@ class IRDeepCompare :
const
AssertStmt
*
rhs
=
other
.
as
<
AssertStmt
>
();
const
AssertStmt
*
rhs
=
other
.
as
<
AssertStmt
>
();
if
(
CompareExpr
(
op
->
condition
,
rhs
->
condition
)
!=
0
)
return
;
if
(
CompareExpr
(
op
->
condition
,
rhs
->
condition
)
!=
0
)
return
;
if
(
CompareExpr
(
op
->
message
,
rhs
->
message
)
!=
0
)
return
;
if
(
CompareExpr
(
op
->
message
,
rhs
->
message
)
!=
0
)
return
;
if
(
CompareStmt
(
op
->
body
,
rhs
->
body
)
!=
0
)
return
;
}
}
void
VisitStmt_
(
const
ProducerConsumer
*
op
,
const
Stmt
&
other
)
final
{
void
VisitStmt_
(
const
ProducerConsumer
*
op
,
const
Stmt
&
other
)
final
{
...
@@ -127,7 +128,6 @@ class IRDeepCompare :
...
@@ -127,7 +128,6 @@ class IRDeepCompare :
if
(
CompareStmt
(
op
->
body
,
rhs
->
body
)
!=
0
)
return
;
if
(
CompareStmt
(
op
->
body
,
rhs
->
body
)
!=
0
)
return
;
}
}
void
VisitStmt_
(
const
Provide
*
op
,
const
Stmt
&
other
)
final
{
void
VisitStmt_
(
const
Provide
*
op
,
const
Stmt
&
other
)
final
{
const
Provide
*
rhs
=
other
.
as
<
Provide
>
();
const
Provide
*
rhs
=
other
.
as
<
Provide
>
();
if
(
CompareNodeRef
(
op
->
func
,
rhs
->
func
)
!=
0
)
return
;
if
(
CompareNodeRef
(
op
->
func
,
rhs
->
func
)
!=
0
)
return
;
...
...
src/pass/ir_mutator.cc
View file @
00506a62
...
@@ -219,11 +219,14 @@ Stmt IRMutator::Mutate_(const Block* op, const Stmt& s) {
...
@@ -219,11 +219,14 @@ Stmt IRMutator::Mutate_(const Block* op, const Stmt& s) {
Stmt
IRMutator
::
Mutate_
(
const
AssertStmt
*
op
,
const
Stmt
&
s
)
{
Stmt
IRMutator
::
Mutate_
(
const
AssertStmt
*
op
,
const
Stmt
&
s
)
{
Expr
condition
=
this
->
Mutate
(
op
->
condition
);
Expr
condition
=
this
->
Mutate
(
op
->
condition
);
Expr
message
=
this
->
Mutate
(
op
->
message
);
Expr
message
=
this
->
Mutate
(
op
->
message
);
Stmt
body
=
this
->
Mutate
(
op
->
body
);
if
(
condition
.
same_as
(
op
->
condition
)
&&
message
.
same_as
(
op
->
message
))
{
if
(
condition
.
same_as
(
op
->
condition
)
&&
message
.
same_as
(
op
->
message
)
&&
body
.
same_as
(
op
->
body
))
{
return
s
;
return
s
;
}
else
{
}
else
{
return
AssertStmt
::
make
(
condition
,
message
);
return
AssertStmt
::
make
(
condition
,
message
,
body
);
}
}
}
}
...
...
src/pass/ir_util.cc
View file @
00506a62
...
@@ -34,7 +34,10 @@ Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body) {
...
@@ -34,7 +34,10 @@ Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body) {
n
->
then_case
=
body
;
n
->
then_case
=
body
;
body
=
Stmt
(
n
);
body
=
Stmt
(
n
);
}
else
if
(
s
.
as
<
AssertStmt
>
())
{
}
else
if
(
s
.
as
<
AssertStmt
>
())
{
body
=
Block
::
make
(
s
,
body
);
auto
n
=
std
::
make_shared
<
AssertStmt
>
(
*
s
.
as
<
AssertStmt
>
());
CHECK
(
is_no_op
(
n
->
body
));
n
->
body
=
body
;
body
=
Stmt
(
n
);
}
else
if
(
s
.
as
<
Allocate
>
())
{
}
else
if
(
s
.
as
<
Allocate
>
())
{
auto
n
=
std
::
make_shared
<
Allocate
>
(
*
s
.
as
<
Allocate
>
());
auto
n
=
std
::
make_shared
<
Allocate
>
(
*
s
.
as
<
Allocate
>
());
CHECK
(
is_no_op
(
n
->
body
));
CHECK
(
is_no_op
(
n
->
body
));
...
...
src/pass/ir_visitor.cc
View file @
00506a62
...
@@ -162,6 +162,7 @@ void IRVisitor::Visit_(const Broadcast *op) {
...
@@ -162,6 +162,7 @@ void IRVisitor::Visit_(const Broadcast *op) {
void
IRVisitor
::
Visit_
(
const
AssertStmt
*
op
)
{
void
IRVisitor
::
Visit_
(
const
AssertStmt
*
op
)
{
this
->
Visit
(
op
->
condition
);
this
->
Visit
(
op
->
condition
);
this
->
Visit
(
op
->
message
);
this
->
Visit
(
op
->
message
);
this
->
Visit
(
op
->
body
);
}
}
void
IRVisitor
::
Visit_
(
const
ProducerConsumer
*
op
)
{
void
IRVisitor
::
Visit_
(
const
ProducerConsumer
*
op
)
{
...
...
src/pass/make_api.cc
View file @
00506a62
...
@@ -19,7 +19,7 @@ namespace tvm {
...
@@ -19,7 +19,7 @@ namespace tvm {
namespace
ir
{
namespace
ir
{
inline
Stmt
MakeAssertEQ
(
Expr
lhs
,
Expr
rhs
,
std
::
string
msg
)
{
inline
Stmt
MakeAssertEQ
(
Expr
lhs
,
Expr
rhs
,
std
::
string
msg
)
{
return
AssertStmt
::
make
(
lhs
==
rhs
,
msg
);
return
AssertStmt
::
make
(
lhs
==
rhs
,
msg
,
Evaluate
::
make
(
0
)
);
}
}
LoweredFunc
MakeAPI
(
Stmt
body
,
LoweredFunc
MakeAPI
(
Stmt
body
,
...
@@ -100,16 +100,16 @@ LoweredFunc MakeAPI(Stmt body,
...
@@ -100,16 +100,16 @@ LoweredFunc MakeAPI(Stmt body,
seq_check
.
emplace_back
(
seq_check
.
emplace_back
(
AssertStmt
::
make
(
tcode
==
kHandle
||
AssertStmt
::
make
(
tcode
==
kHandle
||
tcode
==
kArrayHandle
||
tcode
==
kArrayHandle
||
tcode
==
kNull
,
msg
.
str
()));
tcode
==
kNull
,
msg
.
str
()
,
nop
));
}
else
if
(
t
.
is_int
()
||
t
.
is_uint
())
{
}
else
if
(
t
.
is_int
()
||
t
.
is_uint
())
{
std
::
ostringstream
msg
;
std
::
ostringstream
msg
;
msg
<<
"Expect argument "
<<
i
<<
" to be int"
;
msg
<<
"Expect argument "
<<
i
<<
" to be int"
;
seq_check
.
emplace_back
(
AssertStmt
::
make
(
tcode
==
kInt
,
msg
.
str
()));
seq_check
.
emplace_back
(
AssertStmt
::
make
(
tcode
==
kInt
,
msg
.
str
()
,
nop
));
}
else
{
}
else
{
CHECK
(
t
.
is_float
());
CHECK
(
t
.
is_float
());
std
::
ostringstream
msg
;
std
::
ostringstream
msg
;
msg
<<
"Expect argument "
<<
i
<<
" to be float"
;
msg
<<
"Expect argument "
<<
i
<<
" to be float"
;
seq_check
.
emplace_back
(
AssertStmt
::
make
(
tcode
==
kFloat
,
msg
.
str
()));
seq_check
.
emplace_back
(
AssertStmt
::
make
(
tcode
==
kFloat
,
msg
.
str
()
,
nop
));
}
}
}
else
{
}
else
{
args
.
push_back
(
v_arg
);
args
.
push_back
(
v_arg
);
...
...
tests/python/unittest/test_codegen_llvm.py
View file @
00506a62
...
@@ -16,10 +16,12 @@ def test_llvm_add_pipeline():
...
@@ -16,10 +16,12 @@ def test_llvm_add_pipeline():
return
return
# Specifically allow offset to test codepath when offset is available
# Specifically allow offset to test codepath when offset is available
Ab
=
tvm
.
decl_buffer
(
Ab
=
tvm
.
decl_buffer
(
A
.
shape
,
A
.
dtype
,
elem_offset
=
tvm
.
var
(
'Aoffset'
),
A
.
shape
,
A
.
dtype
,
elem_offset
=
tvm
.
var
(
'Aoffset'
),
offset_factor
=
8
,
name
=
'A'
)
name
=
'A'
)
binds
=
{
A
:
Ab
}
binds
=
{
A
:
Ab
}
#
build
and invoke the kernel.
#
BUILD
and invoke the kernel.
f
=
tvm
.
build
(
s
,
[
Ab
,
B
,
C
],
"llvm"
,
binds
=
binds
)
f
=
tvm
.
build
(
s
,
[
Ab
,
B
,
C
],
"llvm"
,
binds
=
binds
)
ctx
=
tvm
.
cpu
(
0
)
ctx
=
tvm
.
cpu
(
0
)
# launch the kernel.
# launch the kernel.
...
@@ -31,7 +33,8 @@ def test_llvm_add_pipeline():
...
@@ -31,7 +33,8 @@ def test_llvm_add_pipeline():
np
.
testing
.
assert_allclose
(
np
.
testing
.
assert_allclose
(
c
.
asnumpy
(),
a
.
asnumpy
()
+
b
.
asnumpy
())
c
.
asnumpy
(),
a
.
asnumpy
()
+
b
.
asnumpy
())
check_llvm
()
with
tvm
.
build_config
(
offset_factor
=
4
):
check_llvm
()
def
test_llvm_flip_pipeline
():
def
test_llvm_flip_pipeline
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment