Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
00506a62
Commit
00506a62
authored
Jul 05, 2017
by
Tianqi Chen
Committed by
GitHub
Jul 05, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[IR] Add body to AssertStmt (#220)
* [IR] Add body to AssertStmt * fix lint
parent
c9da7254
Hide whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
135 additions
and
44 deletions
+135
-44
HalideIR
+1
-1
include/tvm/arithmetic.h
+2
-2
include/tvm/buffer.h
+17
-5
include/tvm/ir.h
+1
-1
python/tvm/api.py
+15
-4
python/tvm/build.py
+9
-2
src/api/api_ir.cc
+1
-1
src/api/api_lang.cc
+2
-1
src/codegen/codegen_c.cc
+1
-0
src/codegen/llvm/codegen_llvm.cc
+25
-0
src/codegen/stack_vm/codegen_stack_vm.cc
+1
-0
src/lang/buffer.cc
+14
-5
src/pass/arg_binder.cc
+25
-11
src/pass/ir_deep_compare.cc
+1
-1
src/pass/ir_mutator.cc
+5
-2
src/pass/ir_util.cc
+4
-1
src/pass/ir_visitor.cc
+1
-0
src/pass/make_api.cc
+4
-4
tests/python/unittest/test_codegen_llvm.py
+6
-3
No files found.
HalideIR
@
36ecc1ee
Subproject commit
860199eea031a4ea694b8fce03ad0bf8127910ac
Subproject commit
36ecc1eec0898411ae70e98c315b03247d5fb4a0
include/tvm/arithmetic.h
View file @
00506a62
...
...
@@ -119,9 +119,9 @@ class IntSet : public NodeRef {
*/
struct
ModularEntry
{
/*! \brief The base */
int
base
;
int
base
{
0
}
;
/*! \brief linear co-efficient */
int
coeff
;
int
coeff
{
1
}
;
/*! \return entry represent everything */
static
ModularEntry
everything
()
{
...
...
include/tvm/buffer.h
View file @
00506a62
...
...
@@ -68,7 +68,10 @@ class Buffer : public NodeRef {
class
BufferNode
:
public
Node
{
public
:
// Data fields.
/*! \brief The pointer to the head of the data */
/*!
* \brief The pointer to the head of the data
* \sa data_alignment The alignment of data in bytes.
*/
Var
data
;
/*! \brief data type in the content of the tensor */
Type
dtype
;
...
...
@@ -86,8 +89,13 @@ class BufferNode : public Node {
std
::
string
name
;
/*! \brief storage scope of the buffer, if other than global */
std
::
string
scope
;
/*! \brief Alignment multiple in terms of dtype elements (including lanes) */
int
offset_alignment
;
/*! \brief Alignment requirement of data pointer in bytes. */
int
data_alignment
;
/*!
* \brief Factor of elem_offset field,
* elem_offset is guaranteed to be multiple of offset_factor.
*/
int
offset_factor
;
/*! \brief constructor */
BufferNode
()
{}
...
...
@@ -99,9 +107,12 @@ class BufferNode : public Node {
v
->
Visit
(
"elem_offset"
,
&
elem_offset
);
v
->
Visit
(
"name"
,
&
name
);
v
->
Visit
(
"scope"
,
&
scope
);
v
->
Visit
(
"offset_alignment"
,
&
offset_alignment
);
v
->
Visit
(
"data_alignment"
,
&
data_alignment
);
v
->
Visit
(
"offset_factor"
,
&
offset_factor
);
}
// User can specify data_alignment and offset_factor to be 0
// A default value will be picked.
static
Buffer
make
(
Var
ptr
,
Type
dtype
,
Array
<
Expr
>
shape
,
...
...
@@ -109,7 +120,8 @@ class BufferNode : public Node {
Expr
byte_offset
,
std
::
string
name
,
std
::
string
scope
,
int
offset_alignment
);
int
data_alignment
,
int
offset_factor
);
static
constexpr
const
char
*
_type_key
=
"Buffer"
;
TVM_DECLARE_NODE_TYPE_INFO
(
BufferNode
,
Node
);
...
...
include/tvm/ir.h
View file @
00506a62
...
...
@@ -135,7 +135,7 @@ struct TensorKey {
}
};
/*! \brief namespace of possible attribute sin AttrStmt.
type
_key */
/*! \brief namespace of possible attribute sin AttrStmt.
attr
_key */
namespace
attr
{
// The above attr does not pass to ir stage.
/*! \brief Mark launching extent of thread, used by device API. */
...
...
python/tvm/api.py
View file @
00506a62
...
...
@@ -390,7 +390,8 @@ def decl_buffer(shape,
strides
=
None
,
elem_offset
=
None
,
scope
=
""
,
offset_alignment
=
0
):
data_alignment
=
0
,
offset_factor
=
0
):
"""Decleare a new symbolic buffer.
Normally buffer is created automatically during lower and build.
...
...
@@ -423,8 +424,15 @@ def decl_buffer(shape,
The storage scope of the buffer, if not global.
If scope equals empty string, it means it is global memory.
offset_alignment: int, optional
The alignment of offset
data_alignment: int, optional
The alignment of data pointer in bytes.
If 0 is passed, the alignment will be set to TVM's internal default.
offset_factor: int, optional
The factor of elem_offset field, when set,
elem_offset is required to be multiple of offset_factor.
If 0 is pssed, the alignment will be set to 1.
if non-zero is passed, we will created a Var for elem_offset if elem_offset is not None.
Returns
-------
...
...
@@ -447,11 +455,14 @@ def decl_buffer(shape,
shape
=
(
shape
,)
if
isinstance
(
shape
,
(
_expr
.
Expr
,
_Integral
))
else
shape
dtype
=
float32
if
dtype
is
None
else
dtype
strides
=
()
if
strides
is
None
else
strides
if
offset_factor
!=
0
and
elem_offset
is
None
:
elem_offset
=
var
(
'
%
s_elem_offset'
%
name
,
shape
[
0
]
.
dtype
)
if
data
is
None
:
data
=
var
(
name
,
"handle"
)
return
_api_internal
.
_Buffer
(
data
,
dtype
,
shape
,
strides
,
elem_offset
,
name
,
scope
,
offset_alignment
)
data
,
dtype
,
shape
,
strides
,
elem_offset
,
name
,
scope
,
data_alignment
,
offset_factor
)
def
_IterVar
(
dom
,
name
,
iter_type
,
thread_tag
=
''
):
...
...
python/tvm/build.py
View file @
00506a62
...
...
@@ -26,7 +26,8 @@ class BuildConfig(object):
'auto_unroll_max_step'
:
0
,
'auto_unroll_min_depth'
:
1
,
'unroll_explicit'
:
True
,
'detect_global_barrier'
:
False
'detect_global_barrier'
:
False
,
'offset_factor'
:
0
}
def
__init__
(
self
,
**
kwargs
):
self
.
_old_scope
=
None
...
...
@@ -76,6 +77,10 @@ def build_config(**kwargs):
detect_global_barrier: bool, default=True
Whether detect global barrier.
offset_factor: int, default=0
The factor used in default buffer declaration.
If specified as 0, offset field is not used.
Returns
-------
config: BuildConfig
...
...
@@ -105,10 +110,12 @@ def get_binds(args, binds=None):
The list of symbolic buffers of arguments.
"""
binds
=
{}
if
binds
is
None
else
binds
.
copy
()
offset_factor
=
BuildConfig
.
current
.
offset_factor
arg_list
=
[]
for
x
in
args
:
if
isinstance
(
x
,
tensor
.
Tensor
):
buf
=
api
.
decl_buffer
(
x
.
shape
,
dtype
=
x
.
dtype
,
name
=
x
.
name
)
buf
=
api
.
decl_buffer
(
x
.
shape
,
dtype
=
x
.
dtype
,
name
=
x
.
name
,
offset_factor
=
offset_factor
)
assert
x
not
in
binds
binds
[
x
]
=
buf
arg_list
.
append
(
buf
)
...
...
src/api/api_ir.cc
View file @
00506a62
...
...
@@ -143,7 +143,7 @@ REGISTER_MAKE2(Cast);
REGISTER_MAKE2
(
Broadcast
);
REGISTER_MAKE3
(
Let
);
REGISTER_MAKE3
(
LetStmt
);
REGISTER_MAKE
2
(
AssertStmt
);
REGISTER_MAKE
3
(
AssertStmt
);
REGISTER_MAKE3
(
ProducerConsumer
);
REGISTER_MAKE5
(
Allocate
);
REGISTER_MAKE4
(
Provide
);
...
...
src/api/api_lang.cc
View file @
00506a62
...
...
@@ -152,7 +152,8 @@ TVM_REGISTER_API("_Buffer")
args
[
4
],
args
[
5
],
args
[
6
],
args
[
7
]);
args
[
7
],
args
[
8
]);
});
TVM_REGISTER_API
(
"_Tensor"
)
...
...
src/codegen/codegen_c.cc
View file @
00506a62
...
...
@@ -724,6 +724,7 @@ void CodeGenC::VisitStmt_(const AssertStmt* op) {
}
else
{
stream
<<
"assert("
<<
cond
<<
");
\n
"
;
}
this
->
PrintStmt
(
op
->
body
);
}
void
CodeGenC
::
VisitStmt_
(
const
For
*
op
)
{
...
...
src/codegen/llvm/codegen_llvm.cc
View file @
00506a62
...
...
@@ -1377,6 +1377,31 @@ void CodeGenLLVM::VisitStmt_(const AssertStmt* op) {
builder_
->
CreateRet
(
llvm
::
ConstantInt
::
getSigned
(
t_int32_
,
-
1
));
// otherwise set it to be new end.
builder_
->
SetInsertPoint
(
end_block
);
// Detect useful invariant pattern and use them to visit child.
// Pattern: Var % const == 0
// TODO(tqchen) move these pattern to a generic scope info visitor.
if
(
const
EQ
*
eq
=
op
->
condition
.
as
<
EQ
>
())
{
const
Mod
*
mod
=
eq
->
a
.
as
<
Mod
>
();
int64_t
factor
,
offset
;
if
(
mod
&&
arith
::
GetConst
(
eq
->
b
,
&
offset
))
{
const
Variable
*
var
=
mod
->
a
.
as
<
Variable
>
();
if
(
var
&&
arith
::
GetConst
(
mod
->
b
,
&
factor
))
{
arith
::
ModularEntry
old
=
align_map_
[
var
];
if
(
factor
>
old
.
coeff
)
{
arith
::
ModularEntry
e
;
e
.
coeff
=
static_cast
<
int
>
(
factor
);
e
.
base
=
static_cast
<
int
>
(
offset
);
// new alignment info,
align_map_
[
var
]
=
e
;
this
->
VisitStmt
(
op
->
body
);
// restore old info
align_map_
[
var
]
=
old
;
return
;
}
}
}
}
this
->
VisitStmt
(
op
->
body
);
}
void
CodeGenLLVM
::
VisitStmt_
(
const
LetStmt
*
op
)
{
...
...
src/codegen/stack_vm/codegen_stack_vm.cc
View file @
00506a62
...
...
@@ -456,6 +456,7 @@ void CodeGenStackVM::VisitStmt_(const AssertStmt *op) {
this
->
Push
(
op
->
condition
);
this
->
PushOp
(
StackVM
::
ASSERT
,
sid
);
}
this
->
Push
(
op
->
body
);
}
void
CodeGenStackVM
::
VisitStmt_
(
const
AttrStmt
*
op
)
{
...
...
src/lang/buffer.cc
View file @
00506a62
...
...
@@ -3,6 +3,7 @@
* \file buffer.cc
*/
#include <tvm/buffer.h>
#include <tvm/runtime/device_api.h>
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
...
...
@@ -26,7 +27,9 @@ Buffer decl_buffer(Array<Expr> shape,
shape
,
Array
<
Expr
>
(),
Expr
(),
name
,
""
,
0
);
name
,
""
,
0
,
0
);
}
// The buffer offset in convention of number of elements of
...
...
@@ -124,6 +127,7 @@ Buffer Buffer::MakeSlice(Array<Expr> begins, Array<Expr> extents) const {
elem_offset
,
n
->
name
+
"_slice"
,
n
->
scope
,
n
->
data_alignment
,
0
);
}
...
...
@@ -134,7 +138,8 @@ Buffer BufferNode::make(Var data,
Expr
elem_offset
,
std
::
string
name
,
std
::
string
scope
,
int
offset_alignment
)
{
int
data_alignment
,
int
offset_factor
)
{
auto
n
=
std
::
make_shared
<
BufferNode
>
();
n
->
data
=
std
::
move
(
data
);
n
->
dtype
=
dtype
;
...
...
@@ -145,11 +150,15 @@ Buffer BufferNode::make(Var data,
if
(
!
elem_offset
.
defined
())
{
elem_offset
=
make_const
(
n
->
shape
[
0
].
type
(),
0
);
}
if
(
offset_alignment
==
0
)
{
offset_alignment
=
1
;
if
(
data_alignment
==
0
)
{
data_alignment
=
runtime
::
kAllocAlignment
;
}
if
(
offset_factor
==
0
)
{
offset_factor
=
1
;
}
n
->
elem_offset
=
elem_offset
;
n
->
offset_alignment
=
offset_alignment
;
n
->
data_alignment
=
data_alignment
;
n
->
offset_factor
=
offset_factor
;
return
Buffer
(
n
);
}
...
...
src/pass/arg_binder.cc
View file @
00506a62
...
...
@@ -24,7 +24,7 @@ void BinderAddAssert(Expr cond,
if
(
!
is_one
(
cond
))
{
std
::
ostringstream
os
;
os
<<
"Argument "
<<
arg_name
<<
" has an unsatisfied constraint"
;
asserts
->
emplace_back
(
AssertStmt
::
make
(
cond
,
os
.
str
()));
asserts
->
emplace_back
(
AssertStmt
::
make
(
cond
,
os
.
str
()
,
Evaluate
::
make
(
0
)
));
}
}
...
...
@@ -107,7 +107,14 @@ void ArgBinder::BindBuffer(const Buffer& arg,
this
->
BindArray
(
arg
->
shape
,
value
->
shape
,
arg_name
+
".shape"
);
this
->
BindArray
(
arg
->
strides
,
value
->
strides
,
arg_name
+
".strides"
);
}
this
->
Bind
(
arg
->
elem_offset
,
value
->
elem_offset
,
arg_name
+
".elem_offset"
);
if
(
Bind_
(
arg
->
elem_offset
,
value
->
elem_offset
,
arg_name
+
".elem_offset"
,
false
))
{
if
(
arg
->
offset_factor
>
1
)
{
Expr
offset
=
value
->
elem_offset
;
Expr
factor
=
make_const
(
offset
.
type
(),
arg
->
offset_factor
);
Expr
zero
=
make_zero
(
offset
.
type
());
BinderAddAssert
(
offset
%
factor
==
zero
,
arg_name
+
".elem_offset"
,
&
asserts_
);
}
}
}
inline
Expr
TVMArrayGet
(
Type
t
,
Var
arr
,
intrinsic
::
TVMStructFieldKind
kind
)
{
...
...
@@ -117,7 +124,7 @@ inline Expr TVMArrayGet(Type t, Var arr, intrinsic::TVMStructFieldKind kind) {
inline
Stmt
AssertNull
(
Var
handle
,
std
::
string
msg
)
{
return
AssertStmt
::
make
(
Call
::
make
(
Bool
(
1
),
intrinsic
::
tvm_handle_is_null
,
{
handle
},
Call
::
PureIntrinsic
),
msg
);
{
handle
},
Call
::
PureIntrinsic
),
msg
,
Evaluate
::
make
(
0
)
);
}
void
ArgBinder
::
BindDLTensor
(
const
Buffer
&
buffer
,
...
...
@@ -136,7 +143,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
ndim_err_msg
<<
arg_name
<<
".ndim is expected to equal "
<<
buffer
->
shape
.
size
();
asserts_
.
emplace_back
(
AssertStmt
::
make
(
a_ndim
==
v_ndim
,
ndim_err_msg
.
str
()));
asserts_
.
emplace_back
(
AssertStmt
::
make
(
a_ndim
==
v_ndim
,
ndim_err_msg
.
str
()
,
nop
));
// type checks
Type
dtype
=
buffer
->
dtype
;
std
::
ostringstream
type_err_msg
;
...
...
@@ -147,7 +154,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
UIntImm
::
make
(
UInt
(
8
),
dtype
.
bits
())
&&
TVMArrayGet
(
UInt
(
16
),
handle
,
intrinsic
::
kArrTypeLanes
)
==
UIntImm
::
make
(
UInt
(
16
),
dtype
.
lanes
()));
asserts_
.
emplace_back
(
AssertStmt
::
make
(
cond
,
type_err_msg
.
str
()));
asserts_
.
emplace_back
(
AssertStmt
::
make
(
cond
,
type_err_msg
.
str
()
,
nop
));
// data field
if
(
Bind_
(
buffer
->
data
,
TVMArrayGet
(
Handle
(),
handle
,
intrinsic
::
kArrData
),
arg_name
+
".data"
,
true
))
{
...
...
@@ -156,7 +163,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
// mark alignment of external bufs
init_nest_
.
emplace_back
(
AttrStmt
::
make
(
vptr
,
ir
::
attr
::
storage_alignment
,
IntImm
::
make
(
Int
(
32
),
runtime
::
kAllocA
lignment
),
nop
));
IntImm
::
make
(
Int
(
32
),
buffer
->
data_a
lignment
),
nop
));
}
Var
v_shape
(
arg_name
+
".shape"
,
Handle
());
...
...
@@ -202,11 +209,18 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
TVMArrayGet
(
UInt
(
64
),
handle
,
intrinsic
::
kArrByteOffset
),
arg_name
+
".byte_offset"
,
true
);
}
else
{
Bind_
(
buffer
->
elem_offset
,
cast
(
buffer
->
elem_offset
.
type
(),
(
TVMArrayGet
(
UInt
(
64
),
handle
,
intrinsic
::
kArrByteOffset
)
/
make_const
(
UInt
(
64
),
data_bytes
))),
arg_name
+
".elem_offset"
,
true
);
if
(
Bind_
(
buffer
->
elem_offset
,
cast
(
buffer
->
elem_offset
.
type
(),
(
TVMArrayGet
(
UInt
(
64
),
handle
,
intrinsic
::
kArrByteOffset
)
/
make_const
(
UInt
(
64
),
data_bytes
))),
arg_name
+
".elem_offset"
,
true
))
{
if
(
buffer
->
offset_factor
>
1
)
{
Expr
offset
=
buffer
->
elem_offset
;
Expr
factor
=
make_const
(
offset
.
type
(),
buffer
->
offset_factor
);
Expr
zero
=
make_zero
(
offset
.
type
());
BinderAddAssert
(
offset
%
factor
==
zero
,
arg_name
+
".elem_offset"
,
&
asserts_
);
}
}
}
// device info.
Bind_
(
device_type
,
...
...
src/pass/ir_deep_compare.cc
View file @
00506a62
...
...
@@ -118,6 +118,7 @@ class IRDeepCompare :
const
AssertStmt
*
rhs
=
other
.
as
<
AssertStmt
>
();
if
(
CompareExpr
(
op
->
condition
,
rhs
->
condition
)
!=
0
)
return
;
if
(
CompareExpr
(
op
->
message
,
rhs
->
message
)
!=
0
)
return
;
if
(
CompareStmt
(
op
->
body
,
rhs
->
body
)
!=
0
)
return
;
}
void
VisitStmt_
(
const
ProducerConsumer
*
op
,
const
Stmt
&
other
)
final
{
...
...
@@ -127,7 +128,6 @@ class IRDeepCompare :
if
(
CompareStmt
(
op
->
body
,
rhs
->
body
)
!=
0
)
return
;
}
void
VisitStmt_
(
const
Provide
*
op
,
const
Stmt
&
other
)
final
{
const
Provide
*
rhs
=
other
.
as
<
Provide
>
();
if
(
CompareNodeRef
(
op
->
func
,
rhs
->
func
)
!=
0
)
return
;
...
...
src/pass/ir_mutator.cc
View file @
00506a62
...
...
@@ -219,11 +219,14 @@ Stmt IRMutator::Mutate_(const Block* op, const Stmt& s) {
Stmt
IRMutator
::
Mutate_
(
const
AssertStmt
*
op
,
const
Stmt
&
s
)
{
Expr
condition
=
this
->
Mutate
(
op
->
condition
);
Expr
message
=
this
->
Mutate
(
op
->
message
);
Stmt
body
=
this
->
Mutate
(
op
->
body
);
if
(
condition
.
same_as
(
op
->
condition
)
&&
message
.
same_as
(
op
->
message
))
{
if
(
condition
.
same_as
(
op
->
condition
)
&&
message
.
same_as
(
op
->
message
)
&&
body
.
same_as
(
op
->
body
))
{
return
s
;
}
else
{
return
AssertStmt
::
make
(
condition
,
message
);
return
AssertStmt
::
make
(
condition
,
message
,
body
);
}
}
...
...
src/pass/ir_util.cc
View file @
00506a62
...
...
@@ -34,7 +34,10 @@ Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body) {
n
->
then_case
=
body
;
body
=
Stmt
(
n
);
}
else
if
(
s
.
as
<
AssertStmt
>
())
{
body
=
Block
::
make
(
s
,
body
);
auto
n
=
std
::
make_shared
<
AssertStmt
>
(
*
s
.
as
<
AssertStmt
>
());
CHECK
(
is_no_op
(
n
->
body
));
n
->
body
=
body
;
body
=
Stmt
(
n
);
}
else
if
(
s
.
as
<
Allocate
>
())
{
auto
n
=
std
::
make_shared
<
Allocate
>
(
*
s
.
as
<
Allocate
>
());
CHECK
(
is_no_op
(
n
->
body
));
...
...
src/pass/ir_visitor.cc
View file @
00506a62
...
...
@@ -162,6 +162,7 @@ void IRVisitor::Visit_(const Broadcast *op) {
void
IRVisitor
::
Visit_
(
const
AssertStmt
*
op
)
{
this
->
Visit
(
op
->
condition
);
this
->
Visit
(
op
->
message
);
this
->
Visit
(
op
->
body
);
}
void
IRVisitor
::
Visit_
(
const
ProducerConsumer
*
op
)
{
...
...
src/pass/make_api.cc
View file @
00506a62
...
...
@@ -19,7 +19,7 @@ namespace tvm {
namespace
ir
{
inline
Stmt
MakeAssertEQ
(
Expr
lhs
,
Expr
rhs
,
std
::
string
msg
)
{
return
AssertStmt
::
make
(
lhs
==
rhs
,
msg
);
return
AssertStmt
::
make
(
lhs
==
rhs
,
msg
,
Evaluate
::
make
(
0
)
);
}
LoweredFunc
MakeAPI
(
Stmt
body
,
...
...
@@ -100,16 +100,16 @@ LoweredFunc MakeAPI(Stmt body,
seq_check
.
emplace_back
(
AssertStmt
::
make
(
tcode
==
kHandle
||
tcode
==
kArrayHandle
||
tcode
==
kNull
,
msg
.
str
()));
tcode
==
kNull
,
msg
.
str
()
,
nop
));
}
else
if
(
t
.
is_int
()
||
t
.
is_uint
())
{
std
::
ostringstream
msg
;
msg
<<
"Expect argument "
<<
i
<<
" to be int"
;
seq_check
.
emplace_back
(
AssertStmt
::
make
(
tcode
==
kInt
,
msg
.
str
()));
seq_check
.
emplace_back
(
AssertStmt
::
make
(
tcode
==
kInt
,
msg
.
str
()
,
nop
));
}
else
{
CHECK
(
t
.
is_float
());
std
::
ostringstream
msg
;
msg
<<
"Expect argument "
<<
i
<<
" to be float"
;
seq_check
.
emplace_back
(
AssertStmt
::
make
(
tcode
==
kFloat
,
msg
.
str
()));
seq_check
.
emplace_back
(
AssertStmt
::
make
(
tcode
==
kFloat
,
msg
.
str
()
,
nop
));
}
}
else
{
args
.
push_back
(
v_arg
);
...
...
tests/python/unittest/test_codegen_llvm.py
View file @
00506a62
...
...
@@ -16,10 +16,12 @@ def test_llvm_add_pipeline():
return
# Specifically allow offset to test codepath when offset is available
Ab
=
tvm
.
decl_buffer
(
A
.
shape
,
A
.
dtype
,
elem_offset
=
tvm
.
var
(
'Aoffset'
),
A
.
shape
,
A
.
dtype
,
elem_offset
=
tvm
.
var
(
'Aoffset'
),
offset_factor
=
8
,
name
=
'A'
)
binds
=
{
A
:
Ab
}
#
build
and invoke the kernel.
#
BUILD
and invoke the kernel.
f
=
tvm
.
build
(
s
,
[
Ab
,
B
,
C
],
"llvm"
,
binds
=
binds
)
ctx
=
tvm
.
cpu
(
0
)
# launch the kernel.
...
...
@@ -31,7 +33,8 @@ def test_llvm_add_pipeline():
np
.
testing
.
assert_allclose
(
c
.
asnumpy
(),
a
.
asnumpy
()
+
b
.
asnumpy
())
check_llvm
()
with
tvm
.
build_config
(
offset_factor
=
4
):
check_llvm
()
def
test_llvm_flip_pipeline
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment