Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
45597d00
Commit
45597d00
authored
Feb 09, 2017
by
Tianqi Chen
Committed by
GitHub
Feb 09, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[LANG/PASS] Support Vectorize (#37)
parent
6a62beb2
Show whitespace changes
Inline
Side-by-side
Showing
24 changed files
with
1105 additions
and
125 deletions
+1105
-125
include/tvm/ir_mutator.h
+1
-0
include/tvm/ir_pass.h
+6
-0
include/tvm/schedule.h
+56
-1
python/tvm/build.py
+1
-0
python/tvm/schedule.py
+20
-0
src/api/api_lang.cc
+12
-0
src/api/api_pass.cc
+1
-0
src/arithmetic/compute_expr.h
+18
-0
src/codegen/codegen_c.cc
+213
-55
src/codegen/codegen_c.h
+45
-22
src/codegen/codegen_cuda.cc
+103
-2
src/codegen/codegen_cuda.h
+9
-0
src/codegen/codegen_opencl.cc
+82
-5
src/codegen/codegen_opencl.h
+10
-0
src/pass/ir_mutator.cc
+17
-15
src/pass/split_host_device.cc
+7
-1
src/pass/unroll_loop.cc
+6
-5
src/pass/vectorize_loop.cc
+385
-0
src/schedule/schedule_lang.cc
+39
-0
src/schedule/schedule_ops.cc
+9
-2
tests/python/integration/test_ewise.py
+21
-15
tests/python/unittest/test_lang_schedule.py
+16
-0
tests/python/unittest/test_pass_unroll.py
+4
-2
tests/python/unittest/test_pass_vectorize.py
+24
-0
No files found.
include/tvm/ir_mutator.h
View file @
45597d00
...
...
@@ -62,6 +62,7 @@ class IRMutator {
virtual
Stmt
Mutate_
(
const
Realize
*
op
,
const
Stmt
&
s
);
virtual
Stmt
Mutate_
(
const
Store
*
op
,
const
Stmt
&
s
);
virtual
Stmt
Mutate_
(
const
Free
*
op
,
const
Stmt
&
s
);
virtual
Stmt
Mutate_
(
const
IfThenElse
*
op
,
const
Stmt
&
s
);
virtual
Expr
Mutate_
(
const
Call
*
op
,
const
Expr
&
e
);
virtual
Expr
Mutate_
(
const
Load
*
op
,
const
Expr
&
s
);
virtual
Expr
Mutate_
(
const
Variable
*
op
,
const
Expr
&
e
);
...
...
include/tvm/ir_pass.h
View file @
45597d00
...
...
@@ -112,6 +112,12 @@ Stmt StorageFlatten(Stmt stmt,
Stmt
UnrollLoop
(
Stmt
stmt
,
int
max_auto_step
);
/*!
* \brief vectorize the constant loops
* \param stmt The statment to be vectorized.
*/
Stmt
VectorizeLoop
(
Stmt
stmt
);
/*!
* \brief Make an user callable API LoweredFunc.
*
* The main task of this function is to create code to :
...
...
include/tvm/schedule.h
View file @
45597d00
...
...
@@ -18,6 +18,8 @@ class StageNode;
class
ScheduleNode
;
// Node container for IterVarRelation
class
IterVarRelationNode
;
// Attribute of itervar.
class
IterVarAttrNode
;
/*! \brief the attachment type */
enum
AttachType
:
int
{
...
...
@@ -27,6 +29,12 @@ enum AttachType : int {
kScope
=
3
};
/*! \brief IterVar type */
enum
IterVarType
:
int
{
kUnrolled
=
1
,
kVectorized
=
2
};
/*! \brief Stage, contains scheduling for a stage of computation. */
class
Stage
:
public
NodeRef
{
public
:
...
...
@@ -124,11 +132,22 @@ class Stage : public NodeRef {
IterVar
*
p_x_inner
,
IterVar
*
p_y_inner
,
Expr
x_factor
,
Expr
y_factor
);
/*!
* \brief Vectorize iteration.
* \param var The axis to be vectorized.
* \return reference to self.
*/
Stage
&
vectorize
(
IterVar
var
);
// NOLINT(*)
/*!
* \brief Unroll iteration.
* \param var The axis to be vectorized.
* \return reference to self.
*/
Stage
&
unroll
(
IterVar
var
);
// NOLINT(*)
/*!
* \brief whether the stage has been scheduled.
* \return whether the stage has been scheduled.
*/
inline
bool
is_scheduled
()
const
;
// declare container type
using
ContainerType
=
StageNode
;
};
...
...
@@ -193,6 +212,21 @@ class IterVarRelation : public NodeRef {
inline
const
IterVarRelationNode
*
operator
->
()
const
;
};
/*!
* \brief Additional scheduable attributes about IterVar.
*/
class
IterVarAttr
:
public
NodeRef
{
public
:
IterVarAttr
()
{}
explicit
IterVarAttr
(
IterVarType
t
);
explicit
IterVarAttr
(
std
::
shared_ptr
<
Node
>
n
)
:
NodeRef
(
n
)
{}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline
const
IterVarAttrNode
*
operator
->
()
const
;
};
// defintion of node containers
/*!
* \brief represents the schedule of the tensor
...
...
@@ -223,6 +257,8 @@ class StageNode : public Node {
Array
<
IterVar
>
leaf_iter_vars
;
/*! \brief The relation bwteen of IterVars */
Array
<
IterVarRelation
>
relations
;
/*! \brief additional attributes about iter var. */
Map
<
IterVar
,
IterVarAttr
>
iter_var_attrs
;
/*! \brief The attachment type of the schedule */
AttachType
attach_type
{
kNone
};
/*! \brief The attach point of this schedule. */
...
...
@@ -236,6 +272,7 @@ class StageNode : public Node {
v
->
Visit
(
"all_iter_vars"
,
&
all_iter_vars
);
v
->
Visit
(
"leaf_iter_vars"
,
&
leaf_iter_vars
);
v
->
Visit
(
"relations"
,
&
relations
);
v
->
Visit
(
"iter_var_attrs"
,
&
iter_var_attrs
);
v
->
Visit
(
"attach_type"
,
&
attach_type
);
v
->
Visit
(
"attach_ivar"
,
&
attach_ivar
);
v
->
Visit
(
"attach_stage"
,
&
attach_stage
);
...
...
@@ -268,6 +305,20 @@ class ScheduleNode : public Node {
TVM_DECLARE_NODE_TYPE_INFO
(
ScheduleNode
);
};
/*! \brief node container for IterVar attr */
class
IterVarAttrNode
:
public
Node
{
public
:
/*! \brief The iteration type. */
IterVarType
iter_type
;
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"iter_type"
,
&
iter_type
);
}
static
constexpr
const
char
*
_type_key
=
"IterVarAttr"
;
TVM_DECLARE_NODE_TYPE_INFO
(
IterVarAttrNode
);
};
/*! \brief base node of iteration var */
class
IterVarRelationNode
:
public
Node
{
};
...
...
@@ -372,5 +423,9 @@ inline const IterVarRelationNode* IterVarRelation::operator->() const {
return
static_cast
<
const
IterVarRelationNode
*>
(
node_
.
get
());
}
inline
const
IterVarAttrNode
*
IterVarAttr
::
operator
->
()
const
{
return
static_cast
<
const
IterVarAttrNode
*>
(
node_
.
get
());
}
}
// namespace tvm
#endif // TVM_SCHEDULE_H_
python/tvm/build.py
View file @
45597d00
...
...
@@ -69,6 +69,7 @@ def build(sch,
stmt
=
schedule
.
ScheduleOps
(
sch
,
bounds
)
stmt
=
ir_pass
.
StorageFlatten
(
stmt
,
binds
)
stmt
=
ir_pass
.
CanonicalSimplify
(
stmt
)
stmt
=
ir_pass
.
VectorizeLoop
(
stmt
)
stmt
=
ir_pass
.
UnrollLoop
(
stmt
,
max_auto_unroll_step
)
stmt
=
ir_pass
.
Simplify
(
stmt
)
fapi
=
ir_pass
.
MakeAPI
(
stmt
,
name
,
arg_list
,
len
(
arg_list
))
...
...
python/tvm/schedule.py
View file @
45597d00
...
...
@@ -177,3 +177,23 @@ class Stage(NodeBase):
x_outer
,
y_outer
,
x_inner
,
y_inner
=
_api_internal
.
_StageTile
(
self
,
x_parent
,
y_parent
,
x_factor
,
y_factor
)
return
x_outer
,
y_outer
,
x_inner
,
y_inner
def
vectorize
(
self
,
var
):
"""Vectorize the iteration.
Parameters
----------
var : IterVar
The iteration to be vectorize
"""
_api_internal
.
_StageVectorize
(
self
,
var
)
def
unroll
(
self
,
var
):
"""Unroll the iteration.
Parameters
----------
var : IterVar
The iteration to be unrolled.
"""
_api_internal
.
_StageUnroll
(
self
,
var
)
src/api/api_lang.cc
View file @
45597d00
...
...
@@ -253,6 +253,18 @@ TVM_REGISTER_API(_StageTile)
*
ret
=
Array
<
IterVar
>
({
x_outer
,
y_outer
,
x_inner
,
y_inner
});
});
TVM_REGISTER_API
(
_StageUnroll
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
args
[
0
].
operator
Stage
()
.
unroll
(
args
[
1
]);
});
TVM_REGISTER_API
(
_StageVectorize
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
args
[
0
].
operator
Stage
()
.
vectorize
(
args
[
1
]);
});
TVM_REGISTER_API
(
_ScheduleNormalize
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
args
[
0
].
operator
Schedule
()
...
...
src/api/api_pass.cc
View file @
45597d00
...
...
@@ -62,6 +62,7 @@ REGISTER_PASS1(VerifySSA);
REGISTER_PASS1
(
CanonicalSimplify
);
REGISTER_PASS4
(
Inline
);
REGISTER_PASS2
(
StorageFlatten
);
REGISTER_PASS1
(
VectorizeLoop
);
REGISTER_PASS2
(
UnrollLoop
);
REGISTER_PASS2
(
StorageSync
);
REGISTER_PASS4
(
MakeAPI
);
...
...
src/arithmetic/compute_expr.h
View file @
45597d00
...
...
@@ -9,6 +9,7 @@
#include <tvm/ir.h>
#include <pass/Interval.h>
#include <limits>
namespace
tvm
{
namespace
arith
{
...
...
@@ -52,6 +53,23 @@ inline bool GetConst<uint64_t>(Expr e, uint64_t *out) {
}
}
// get a small constant int
inline
bool
GetConstInt
(
Expr
e
,
int
*
out
)
{
int64_t
v1
=
0
;
uint64_t
v2
=
0
;
if
(
GetConst
(
e
,
&
v1
))
{
if
(
v1
>
static_cast
<
int64_t
>
(
std
::
numeric_limits
<
int
>::
max
()))
return
false
;
*
out
=
static_cast
<
int
>
(
v1
);
return
true
;
}
if
(
GetConst
(
e
,
&
v2
))
{
if
(
v2
>
static_cast
<
uint64_t
>
(
std
::
numeric_limits
<
int
>::
max
()))
return
false
;
*
out
=
static_cast
<
int
>
(
v2
);
return
true
;
}
return
false
;
}
#define TVM_CONST_PROPAGATION(OP_NAME, OP) \
int64_t ia = 0, ib = 0; \
if (GetConst(a, &ia) && GetConst(b, &ib)) { \
...
...
src/codegen/codegen_c.cc
View file @
45597d00
...
...
@@ -3,7 +3,9 @@
* \file codegen_c.cc
*/
#include <iomanip>
#include <cctype>
#include "./codegen_c.h"
#include "../arithmetic/compute_expr.h"
namespace
tvm
{
namespace
codegen
{
...
...
@@ -14,10 +16,10 @@ std::string CodeGenC::Compile(LoweredFunc f,
bool
output_ssa
)
{
print_ssa_form_
=
output_ssa
;
// skip the first underscore, so SSA variable starts from _1
if
(
print_ssa_form_
)
GetUniqueName
(
"_"
);
GetUniqueName
(
"_"
);
// add to alloc buffer type.
for
(
const
auto
&
kv
:
f
->
handle_data_type
)
{
HandleTypeRegister
(
kv
.
first
.
get
(),
kv
.
second
.
type
());
RegisterHandleType
(
kv
.
first
.
get
(),
kv
.
second
.
type
());
}
this
->
stream
<<
"void "
<<
f
->
name
<<
"("
;
...
...
@@ -26,7 +28,11 @@ std::string CodeGenC::Compile(LoweredFunc f,
std
::
string
vid
=
AllocVarID
(
v
.
get
());
if
(
i
!=
0
)
stream
<<
", "
;
if
(
v
.
type
().
is_handle
())
{
stream
<<
arg_addr_space_
;
auto
it
=
alloc_storage_scope_
.
find
(
v
.
get
());
if
(
it
!=
alloc_storage_scope_
.
end
())
{
PrintStorageScope
(
it
->
second
,
stream
);
}
stream
<<
' '
;
}
if
(
handle_data_type_
.
count
(
v
.
get
()))
{
PrintType
(
handle_data_type_
.
at
(
v
.
get
()),
stream
);
...
...
@@ -126,7 +132,7 @@ bool CodeGenC::HandleTypeMatch(const Variable* buf_var, Type t) const {
return
it
->
second
==
t
;
}
void
CodeGenC
::
HandleTypeRegister
(
const
Variable
*
buf_var
,
Type
t
)
{
void
CodeGenC
::
RegisterHandleType
(
const
Variable
*
buf_var
,
Type
t
)
{
auto
it
=
handle_data_type_
.
find
(
buf_var
);
if
(
it
==
handle_data_type_
.
end
())
{
handle_data_type_
[
buf_var
]
=
t
;
...
...
@@ -259,23 +265,39 @@ inline void PrintBinaryExpr(const T* op,
const
char
*
opstr
,
std
::
ostream
&
os
,
// NOLINT(*)
CodeGenC
*
p
)
{
if
(
op
->
type
.
lanes
()
==
1
)
{
if
(
isalpha
(
opstr
[
0
]))
{
os
<<
opstr
<<
'('
;
p
->
PrintExpr
(
op
->
a
,
os
);
os
<<
", "
;
p
->
PrintExpr
(
op
->
b
,
os
);
os
<<
')'
;
}
else
{
os
<<
'('
;
p
->
PrintExpr
(
op
->
a
,
os
);
os
<<
opstr
;
os
<<
' '
<<
opstr
<<
' '
;
p
->
PrintExpr
(
op
->
b
,
os
);
os
<<
')'
;
}
}
else
{
p
->
PrintVecBinaryOp
(
opstr
,
op
->
type
,
op
->
a
,
op
->
b
,
os
);
}
}
inline
void
PrintBinaryIntrinsitc
(
const
Call
*
op
,
const
char
*
opstr
,
std
::
ostream
&
os
,
// NOLINT(*)
CodeGenC
*
p
)
{
if
(
op
->
type
.
lanes
()
==
1
)
{
CHECK_EQ
(
op
->
args
.
size
(),
2U
);
os
<<
'('
;
p
->
PrintExpr
(
op
->
args
[
0
],
os
);
os
<<
opstr
;
p
->
PrintExpr
(
op
->
args
[
1
],
os
);
os
<<
')'
;
}
else
{
p
->
PrintVecBinaryOp
(
opstr
,
op
->
type
,
op
->
args
[
0
],
op
->
args
[
1
],
os
);
}
}
TVM_STATIC_IR_FUNCTOR
(
CodeGenC
,
vtable_print_expr
)
...
...
@@ -289,57 +311,49 @@ TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_expr)
os
<<
p
->
GetVarID
(
op
);
})
.
set_dispatch
<
Add
>
([](
const
Add
*
op
,
std
::
ostream
&
os
,
CodeGenC
*
p
)
{
// NOLINT(*)
PrintBinaryExpr
(
op
,
"
+
"
,
os
,
p
);
PrintBinaryExpr
(
op
,
"
+
"
,
os
,
p
);
})
.
set_dispatch
<
Sub
>
([](
const
Sub
*
op
,
std
::
ostream
&
os
,
CodeGenC
*
p
)
{
// NOLINT(*)
PrintBinaryExpr
(
op
,
"
-
"
,
os
,
p
);
PrintBinaryExpr
(
op
,
"
-
"
,
os
,
p
);
})
.
set_dispatch
<
Mul
>
([](
const
Mul
*
op
,
std
::
ostream
&
os
,
CodeGenC
*
p
)
{
// NOLINT(*)
PrintBinaryExpr
(
op
,
"
*
"
,
os
,
p
);
PrintBinaryExpr
(
op
,
"
*
"
,
os
,
p
);
})
.
set_dispatch
<
Div
>
([](
const
Div
*
op
,
std
::
ostream
&
os
,
CodeGenC
*
p
)
{
// NOLINT(*)
PrintBinaryExpr
(
op
,
"
/
"
,
os
,
p
);
PrintBinaryExpr
(
op
,
"
/
"
,
os
,
p
);
})
.
set_dispatch
<
Mod
>
([](
const
Mod
*
op
,
std
::
ostream
&
os
,
CodeGenC
*
p
)
{
// NOLINT(*)
PrintBinaryExpr
(
op
,
"
%
"
,
os
,
p
);
PrintBinaryExpr
(
op
,
"
%
"
,
os
,
p
);
})
.
set_dispatch
<
Min
>
([](
const
Min
*
op
,
std
::
ostream
&
os
,
CodeGenC
*
p
)
{
// NOLINT(*)
os
<<
"min("
;
p
->
PrintExpr
(
op
->
a
,
os
);
os
<<
", "
;
p
->
PrintExpr
(
op
->
b
,
os
);
os
<<
")"
;
PrintBinaryExpr
(
op
,
"min"
,
os
,
p
);
})
.
set_dispatch
<
Max
>
([](
const
Max
*
op
,
std
::
ostream
&
os
,
CodeGenC
*
p
)
{
// NOLINT(*)
os
<<
"max("
;
p
->
PrintExpr
(
op
->
a
,
os
);
os
<<
", "
;
p
->
PrintExpr
(
op
->
b
,
os
);
os
<<
")"
;
PrintBinaryExpr
(
op
,
"max"
,
os
,
p
);
})
.
set_dispatch
<
EQ
>
([](
const
EQ
*
op
,
std
::
ostream
&
os
,
CodeGenC
*
p
)
{
// NOLINT(*)
PrintBinaryExpr
(
op
,
"
==
"
,
os
,
p
);
PrintBinaryExpr
(
op
,
"
==
"
,
os
,
p
);
})
.
set_dispatch
<
NE
>
([](
const
NE
*
op
,
std
::
ostream
&
os
,
CodeGenC
*
p
)
{
// NOLINT(*)
PrintBinaryExpr
(
op
,
"
!=
"
,
os
,
p
);
PrintBinaryExpr
(
op
,
"
!=
"
,
os
,
p
);
})
.
set_dispatch
<
LT
>
([](
const
LT
*
op
,
std
::
ostream
&
os
,
CodeGenC
*
p
)
{
// NOLINT(*)
PrintBinaryExpr
(
op
,
"
<
"
,
os
,
p
);
PrintBinaryExpr
(
op
,
"
<
"
,
os
,
p
);
})
.
set_dispatch
<
LE
>
([](
const
LE
*
op
,
std
::
ostream
&
os
,
CodeGenC
*
p
)
{
// NOLINT(*)
PrintBinaryExpr
(
op
,
"
<=
"
,
os
,
p
);
PrintBinaryExpr
(
op
,
"
<=
"
,
os
,
p
);
})
.
set_dispatch
<
GT
>
([](
const
GT
*
op
,
std
::
ostream
&
os
,
CodeGenC
*
p
)
{
// NOLINT(*)
PrintBinaryExpr
(
op
,
"
>
"
,
os
,
p
);
PrintBinaryExpr
(
op
,
"
>
"
,
os
,
p
);
})
.
set_dispatch
<
GE
>
([](
const
GE
*
op
,
std
::
ostream
&
os
,
CodeGenC
*
p
)
{
// NOLINT(*)
PrintBinaryExpr
(
op
,
"
>=
"
,
os
,
p
);
PrintBinaryExpr
(
op
,
"
>=
"
,
os
,
p
);
})
.
set_dispatch
<
And
>
([](
const
And
*
op
,
std
::
ostream
&
os
,
CodeGenC
*
p
)
{
// NOLINT(*)
PrintBinaryExpr
(
op
,
"
&&
"
,
os
,
p
);
PrintBinaryExpr
(
op
,
"
&&
"
,
os
,
p
);
})
.
set_dispatch
<
Or
>
([](
const
Or
*
op
,
std
::
ostream
&
os
,
CodeGenC
*
p
)
{
// NOLINT(*)
PrintBinaryExpr
(
op
,
"
||
"
,
os
,
p
);
PrintBinaryExpr
(
op
,
"
||
"
,
os
,
p
);
})
.
set_dispatch
<
Not
>
([](
const
Not
*
op
,
std
::
ostream
&
os
,
CodeGenC
*
p
)
{
// NOLINT(*)
os
<<
'!'
;
...
...
@@ -460,18 +474,179 @@ void CodeGenC::PrintExpr(const Call *op, std::ostream& os) { // NOLINT(*)
}
}
void
CodeGenC
::
PrintExpr
(
const
Load
*
op
,
std
::
ostream
&
os
)
{
// NOLINT(*)
std
::
string
vid
=
GetVarID
(
op
->
buffer_var
.
get
());
if
(
!
HandleTypeMatch
(
op
->
buffer_var
.
get
(),
op
->
type
))
{
os
<<
"((const "
;
PrintType
(
op
->
type
,
os
);
void
CodeGenC
::
PrintVecBinaryOp
(
const
std
::
string
&
op
,
Type
t
,
Expr
lhs
,
Expr
rhs
,
std
::
ostream
&
os
)
{
// NOLINT(*)
if
(
isalpha
(
op
[
0
]))
{
os
<<
op
<<
"("
;
this
->
PrintExpr
(
lhs
,
os
);
os
<<
", "
;
this
->
PrintExpr
(
rhs
,
os
);
os
<<
")"
;
}
else
{
os
<<
"("
;
this
->
PrintExpr
(
lhs
,
os
);
os
<<
' '
<<
op
<<
' '
;
this
->
PrintExpr
(
rhs
,
os
);
os
<<
")"
;
}
}
inline
bool
TryGetRamp1Base
(
Expr
index
,
int
lanes
,
Expr
*
base
)
{
const
Ramp
*
r
=
index
.
as
<
Ramp
>
();
if
(
!
r
)
return
false
;
if
(
!
is_one
(
r
->
stride
))
return
false
;
CHECK_EQ
(
r
->
lanes
,
lanes
);
*
base
=
r
->
base
;
return
true
;
}
// Print a reference expression to a buffer.
void
CodeGenC
::
PrintBufferRef
(
const
Variable
*
buffer
,
Type
t
,
Expr
index
,
std
::
ostream
&
os
)
{
// NOLINT(*)
std
::
string
vid
=
GetVarID
(
buffer
);
if
(
t
.
lanes
()
==
1
)
{
if
(
!
HandleTypeMatch
(
buffer
,
t
))
{
os
<<
"(("
;
PrintType
(
t
,
os
);
os
<<
"*)"
<<
vid
<<
')'
;
}
else
{
os
<<
vid
;
}
os
<<
'['
;
PrintExpr
(
op
->
index
,
os
);
PrintExpr
(
index
,
os
);
os
<<
']'
;
}
else
{
// Buffer declared as vector type.
// optimize for case where it is in register,
if
(
HandleTypeMatch
(
buffer
,
t
))
{
// optimize for constant access
int
offset
;
if
(
arith
::
GetConstInt
(
index
,
&
offset
))
{
CHECK_EQ
(
offset
%
t
.
lanes
(),
0
)
<<
"Find unaligned vector load to a vector type"
;
os
<<
vid
<<
'['
<<
(
offset
/
t
.
lanes
())
<<
']'
;
return
;
}
}
os
<<
"(("
;
PrintType
(
t
,
os
);
os
<<
"*)("
;
if
(
!
HandleTypeMatch
(
buffer
,
t
.
element_of
()))
{
os
<<
'('
;
PrintType
(
t
.
element_of
(),
os
);
os
<<
"*)"
;
}
os
<<
vid
<<
" + "
;
PrintExpr
(
index
,
os
);
os
<<
"))[0]"
;
}
}
void
CodeGenC
::
PrintExpr
(
const
Load
*
op
,
std
::
ostream
&
os
)
{
// NOLINT(*)
int
lanes
=
op
->
type
.
lanes
();
if
(
op
->
type
.
lanes
()
==
1
)
{
this
->
PrintBufferRef
(
op
->
buffer_var
.
get
(),
op
->
type
,
op
->
index
,
os
);
}
else
{
Expr
base
;
if
(
TryGetRamp1Base
(
op
->
index
,
op
->
type
.
lanes
(),
&
base
))
{
this
->
PrintVecLoad
(
op
->
buffer_var
.
get
(),
op
->
type
,
base
,
os
);
}
else
{
// Load elements seperately
std
::
string
sindex
=
SSAGetID
(
PrintExpr
(
op
->
index
),
op
->
index
.
type
());
std
::
string
svalue
=
GetUniqueName
(
"_"
);
{
// delcare type.
this
->
PrintIndent
();
this
->
PrintType
(
op
->
type
,
stream
);
stream
<<
' '
<<
svalue
<<
";
\n
"
;
}
std
::
string
vid
=
GetVarID
(
op
->
buffer_var
.
get
());
Type
elem_type
=
op
->
type
.
element_of
();
for
(
int
i
=
0
;
i
<
lanes
;
++
i
)
{
std
::
ostringstream
value_temp
;
if
(
!
HandleTypeMatch
(
op
->
buffer_var
.
get
(),
elem_type
))
{
value_temp
<<
"(("
;
PrintType
(
elem_type
,
os
);
value_temp
<<
"*)"
<<
vid
<<
')'
;
}
else
{
value_temp
<<
vid
;
}
value_temp
<<
'['
;
PrintVecElemLoad
(
sindex
,
op
->
index
.
type
(),
i
,
value_temp
);
value_temp
<<
']'
;
PrintVecElemStore
(
svalue
,
op
->
type
,
i
,
value_temp
.
str
());
}
os
<<
svalue
;
}
}
}
void
CodeGenC
::
PrintStmt
(
const
Store
*
op
)
{
Type
t
=
op
->
value
.
type
();
if
(
t
.
lanes
()
==
1
)
{
this
->
PrintIndent
();
std
::
string
value
=
this
->
PrintExpr
(
op
->
value
);
this
->
PrintBufferRef
(
op
->
buffer_var
.
get
(),
t
,
op
->
index
,
stream
);
stream
<<
" = "
<<
value
<<
";
\n
"
;
}
else
{
Expr
base
;
if
(
TryGetRamp1Base
(
op
->
index
,
t
.
lanes
(),
&
base
))
{
std
::
string
value
=
this
->
PrintExpr
(
op
->
value
);
this
->
PrintVecStore
(
op
->
buffer_var
.
get
(),
t
,
base
,
value
);
}
else
{
// store elements seperately
std
::
string
index
=
SSAGetID
(
PrintExpr
(
op
->
index
),
op
->
index
.
type
());
std
::
string
value
=
SSAGetID
(
PrintExpr
(
op
->
value
),
op
->
value
.
type
());
std
::
string
vid
=
GetVarID
(
op
->
buffer_var
.
get
());
for
(
int
i
=
0
;
i
<
t
.
lanes
();
++
i
)
{
this
->
PrintIndent
();
Type
elem_type
=
t
.
element_of
();
if
(
!
HandleTypeMatch
(
op
->
buffer_var
.
get
(),
elem_type
))
{
stream
<<
"(("
;
PrintType
(
elem_type
,
stream
);
stream
<<
"*)"
<<
vid
<<
')'
;
}
else
{
stream
<<
vid
;
}
stream
<<
'['
;
PrintVecElemLoad
(
index
,
op
->
index
.
type
(),
i
,
stream
);
stream
<<
"] = "
;
PrintVecElemLoad
(
value
,
op
->
value
.
type
(),
i
,
stream
);
stream
<<
";
\n
"
;
}
}
}
}
void
CodeGenC
::
PrintVecElemLoad
(
const
std
::
string
&
vec
,
Type
t
,
int
i
,
std
::
ostream
&
os
)
{
// NOLINT(*)
os
<<
vec
<<
".s"
<<
std
::
hex
<<
i
;
}
void
CodeGenC
::
PrintVecElemStore
(
const
std
::
string
&
vec
,
Type
t
,
int
i
,
const
std
::
string
&
value
)
{
this
->
PrintIndent
();
stream
<<
vec
<<
".s"
<<
std
::
hex
<<
i
<<
" = "
<<
value
<<
";
\n
"
;
}
void
CodeGenC
::
PrintVecLoad
(
const
Variable
*
buffer
,
Type
t
,
Expr
base
,
std
::
ostream
&
os
)
{
PrintBufferRef
(
buffer
,
t
,
base
,
os
);
}
void
CodeGenC
::
PrintVecStore
(
const
Variable
*
buffer
,
Type
t
,
Expr
base
,
const
std
::
string
&
value
)
{
this
->
PrintIndent
();
PrintBufferRef
(
buffer
,
t
,
base
,
stream
);
stream
<<
" = "
<<
value
<<
";
\n
"
;
}
void
CodeGenC
::
PrintExpr
(
const
Let
*
op
,
std
::
ostream
&
os
)
{
// NOLINT(*)
...
...
@@ -483,15 +658,15 @@ void CodeGenC::PrintExpr(const Let* op, std::ostream& os) { // NOLINT(*)
}
void
CodeGenC
::
PrintExpr
(
const
Ramp
*
op
,
std
::
ostream
&
os
)
{
// NOLINT(*)
LOG
(
FATAL
)
<<
"not supported "
;
LOG
(
FATAL
)
<<
"
Ramp:
not supported "
;
}
void
CodeGenC
::
PrintExpr
(
const
Broadcast
*
op
,
std
::
ostream
&
os
)
{
// NOLINT(*)
LOG
(
FATAL
)
<<
"not supported "
;
LOG
(
FATAL
)
<<
"
Broadcast:
not supported "
;
}
void
CodeGenC
::
PrintExpr
(
const
Select
*
op
,
std
::
ostream
&
os
)
{
// NOLINT(*)
LOG
(
FATAL
)
<<
"not supported "
;
LOG
(
FATAL
)
<<
"
Select:
not supported "
;
}
// Disoatch back to member functions
...
...
@@ -541,23 +716,6 @@ void CodeGenC::PrintStmt(const LetStmt* op) {
PrintStmt
(
op
->
body
);
}
void
CodeGenC
::
PrintStmt
(
const
Store
*
op
)
{
std
::
string
index
=
this
->
PrintExpr
(
op
->
index
);
std
::
string
value
=
this
->
PrintExpr
(
op
->
value
);
this
->
PrintIndent
();
std
::
string
vid
=
GetVarID
(
op
->
buffer_var
.
get
());
if
(
!
HandleTypeMatch
(
op
->
buffer_var
.
get
(),
op
->
value
.
type
()))
{
this
->
stream
<<
"(("
;
PrintType
(
op
->
value
.
type
(),
this
->
stream
);
this
->
stream
<<
"*)"
<<
vid
<<
')'
;
}
else
{
this
->
stream
<<
vid
;
}
this
->
stream
<<
'['
<<
index
<<
"] = "
<<
value
<<
";
\n
"
;
}
void
CodeGenC
::
PrintStmt
(
const
Allocate
*
op
)
{
CHECK
(
!
is_zero
(
op
->
condition
));
std
::
string
vid
=
AllocVarID
(
op
->
buffer_var
.
get
());
...
...
@@ -580,7 +738,7 @@ void CodeGenC::PrintStmt(const Allocate* op) {
stream
<<
' '
<<
vid
<<
'['
<<
constant_size
<<
"];
\n
"
;
}
HandleTypeRegister
(
op
->
buffer_var
.
get
(),
op
->
type
);
RegisterHandleType
(
op
->
buffer_var
.
get
(),
op
->
type
);
this
->
PrintStmt
(
op
->
body
);
}
...
...
src/codegen/codegen_c.h
View file @
45597d00
...
...
@@ -102,6 +102,20 @@ class CodeGenC {
virtual
void
PrintExpr
(
const
ir
::
Ramp
*
op
,
std
::
ostream
&
os
);
// NOLINT(*)
virtual
void
PrintExpr
(
const
ir
::
Broadcast
*
op
,
std
::
ostream
&
os
);
// NOLINT(*)
virtual
void
PrintExpr
(
const
ir
::
Select
*
op
,
std
::
ostream
&
os
);
// NOLINT(*)
// Binary vector op.
virtual
void
PrintVecBinaryOp
(
const
std
::
string
&
op
,
Type
op_type
,
Expr
lhs
,
Expr
rhs
,
std
::
ostream
&
os
);
// NOLINT(*)
virtual
void
PrintVecLoad
(
const
Variable
*
buffer
,
Type
t
,
Expr
base
,
std
::
ostream
&
os
);
// NOLINT(*)
virtual
void
PrintVecStore
(
const
Variable
*
buffer
,
Type
t
,
Expr
base
,
const
std
::
string
&
value
);
// NOLINT(*)
virtual
void
PrintVecElemLoad
(
const
std
::
string
&
vec
,
Type
t
,
int
i
,
std
::
ostream
&
os
);
// NOLINT(*)
virtual
void
PrintVecElemStore
(
const
std
::
string
&
vec
,
Type
t
,
int
i
,
const
std
::
string
&
value
);
/*! \brief function print into the ostream */
using
FPrintExpr
=
IRFunctor
<
void
(
const
NodeRef
&
,
std
::
ostream
&
os
,
CodeGenC
*
)
>
;
// NOLINT(*)
/*! \brief function to to print normal code */
...
...
@@ -116,17 +130,10 @@ class CodeGenC {
std
::
ostringstream
stream
;
protected
:
// additional string for arg addr_space.
std
::
string
arg_addr_space_
;
private
:
/*! \brief entry in ssa assign map */
struct
SSAEntry
{
/*! \brief The value id */
std
::
string
vid
;
/*! \brief The scope id */
int
scope_id
;
};
// print reference to a buffer as type t in index.
void
PrintBufferRef
(
const
Variable
*
buffer
,
Type
t
,
Expr
index
,
std
::
ostream
&
os
);
// NOLINT(*)
/*!
* \brief Get the SSA ID corresponds to src
* If necessary, generate new assignment
...
...
@@ -135,6 +142,19 @@ class CodeGenC {
*/
std
::
string
SSAGetID
(
std
::
string
src
,
Type
t
);
/*!
* \brief get a unique name with the corresponding prefix
* \param prefix The prefix of the name
* \return The returned name.
*/
std
::
string
GetUniqueName
(
std
::
string
prefix
);
/*! \brief entry in ssa assign map */
struct
SSAEntry
{
/*! \brief The value id */
std
::
string
vid
;
/*! \brief The scope id */
int
scope_id
;
};
/*!
* \brief mark the beginning of a new scope
* \return The scope id.
*/
...
...
@@ -155,25 +175,28 @@ class CodeGenC {
* \param buf_var The buffer variable.
* \param t The type to be checked.
*/
void
HandleTypeRegister
(
const
Variable
*
buf_var
,
Type
t
);
void
RegisterHandleType
(
const
Variable
*
buf_var
,
Type
t
);
/*!
* \brief
get a unique name with the corresponding prefix
* \param
prefix The prefix of the name
* \return The
returned nam
e.
* \brief
Get the storage scope of buf_var.
* \param
buf_var The buf_var to be queryed.
* \return The
storage scop
e.
*/
std
::
string
GetUniqueName
(
std
::
string
prefix
);
/*! \brief whether to print in SSA form */
bool
print_ssa_form_
{
true
};
/*! \brief name of each variable */
std
::
unordered_map
<
const
Variable
*
,
std
::
string
>
var_idmap_
;
/*! \brief the data type of allocated buffers */
std
::
unordered_map
<
const
Variable
*
,
Type
>
handle_data_type_
;
std
::
string
GetStorageScope
(
const
Variable
*
buf_var
)
const
;
/*! \brief the storage scope of allocation */
std
::
unordered_map
<
const
Variable
*
,
std
::
string
>
alloc_storage_scope_
;
private
:
/*! \brief whether to print in SSA form */
bool
print_ssa_form_
{
true
};
/*! \brief name allocation map */
std
::
unordered_map
<
std
::
string
,
int
>
name_alloc_map_
;
/*! \brief assignment map of ssa */
std
::
unordered_map
<
std
::
string
,
SSAEntry
>
ssa_assign_map_
;
/*! \brief name of each variable */
std
::
unordered_map
<
const
Variable
*
,
std
::
string
>
var_idmap_
;
/*! \brief the data type of allocated buffers */
std
::
unordered_map
<
const
Variable
*
,
Type
>
handle_data_type_
;
/*! \brief array to check whether we are inside certain scope */
std
::
vector
<
bool
>
scope_mark_
;
};
...
...
src/codegen/codegen_cuda.cc
View file @
45597d00
...
...
@@ -22,6 +22,108 @@ std::string CodeGenCUDA::Compile(
return
CodeGenC
::
Compile
(
f
,
output_ssa
);
}
void
CodeGenCUDA
::
PrintType
(
Type
t
,
std
::
ostream
&
os
)
const
{
// NOLINT(*)
int
lanes
=
t
.
lanes
();
if
(
t
.
is_handle
())
{
CHECK_EQ
(
lanes
,
1
)
<<
"do not yet support vector types"
;
os
<<
"void*"
;
return
;
}
bool
fail
=
false
;
if
(
t
.
is_float
())
{
switch
(
t
.
bits
())
{
case
16
:
os
<<
"half"
;
break
;
case
32
:
os
<<
"float"
;
break
;
case
64
:
os
<<
"double"
;
break
;
default
:
fail
=
true
;
break
;
}
if
(
!
fail
&&
lanes
==
1
)
return
;
if
(
!
fail
&&
(
lanes
>=
2
&&
lanes
<=
4
))
{
os
<<
lanes
;
return
;
}
}
else
if
(
t
.
is_uint
()
||
t
.
is_int
())
{
if
(
t
.
is_uint
())
{
os
<<
'u'
;
}
if
(
t
.
bits
()
==
8
&&
t
.
lanes
()
==
4
)
{
// directly 4 8 bit int in integer.
os
<<
"int"
;
return
;
}
switch
(
t
.
bits
())
{
case
8
:
os
<<
"char"
;
break
;
case
16
:
os
<<
"short"
;
break
;
case
32
:
os
<<
"int"
;
break
;
case
64
:
{
if
(
lanes
!=
1
&&
sizeof
(
long
)
==
64
)
{
// NOLINT(*)
os
<<
"long"
;
break
;
}
else
{
os
<<
"int64_t"
;
break
;
}
}
case
1
:
os
<<
"int"
;
break
;
default
:
fail
=
true
;
break
;
}
if
(
!
fail
&&
lanes
==
1
)
return
;
if
(
!
fail
&&
(
lanes
>=
2
&&
lanes
<=
4
))
{
os
<<
lanes
;
return
;
}
}
LOG
(
FATAL
)
<<
"Cannot convert type "
<<
t
<<
" to CUDA type"
;
}
void
CodeGenCUDA
::
PrintVecBinaryOp
(
const
std
::
string
&
op
,
Type
t
,
Expr
lhs
,
Expr
rhs
,
std
::
ostream
&
os
)
{
// NOLINT(*)
// unpacking operations.
int
lanes
=
t
.
lanes
();
{
// default: unpack into individual ops.
std
::
string
vlhs
=
SSAGetID
(
PrintExpr
(
lhs
),
lhs
.
type
());
std
::
string
vrhs
=
SSAGetID
(
PrintExpr
(
rhs
),
rhs
.
type
());
std
::
string
sret
=
GetUniqueName
(
"_"
);
{
// delcare type.
this
->
PrintIndent
();
this
->
PrintType
(
t
,
stream
);
stream
<<
' '
<<
sret
<<
";
\n
"
;
}
for
(
int
i
=
0
;
i
<
lanes
;
++
i
)
{
std
::
ostringstream
value_temp
;
if
(
isalpha
(
op
[
0
]))
{
value_temp
<<
op
<<
"("
;
PrintVecElemLoad
(
vlhs
,
lhs
.
type
(),
i
,
value_temp
);
value_temp
<<
", "
;
PrintVecElemLoad
(
vrhs
,
rhs
.
type
(),
i
,
value_temp
);
value_temp
<<
")"
;
}
else
{
value_temp
<<
"("
;
PrintVecElemLoad
(
vlhs
,
lhs
.
type
(),
i
,
value_temp
);
value_temp
<<
op
;
PrintVecElemLoad
(
vrhs
,
rhs
.
type
(),
i
,
value_temp
);
value_temp
<<
")"
;
}
PrintVecElemStore
(
sret
,
t
,
i
,
value_temp
.
str
());
}
os
<<
sret
;
}
}
void
CodeGenCUDA
::
PrintVecElemLoad
(
const
std
::
string
&
vec
,
Type
t
,
int
i
,
std
::
ostream
&
os
)
{
// NOLINT(*)
const
char
access
[]
=
{
'x'
,
'y'
,
'z'
,
'w'
};
CHECK
(
i
>=
0
&&
i
<
4
);
os
<<
vec
<<
"."
<<
access
[
i
];
}
void
CodeGenCUDA
::
PrintVecElemStore
(
const
std
::
string
&
vec
,
Type
t
,
int
i
,
const
std
::
string
&
value
)
{
this
->
PrintIndent
();
const
char
access
[]
=
{
'x'
,
'y'
,
'z'
,
'w'
};
CHECK
(
i
>=
0
&&
i
<
4
);
stream
<<
vec
<<
"."
<<
access
[
i
]
<<
" = "
<<
value
<<
";
\n
"
;
}
void
CodeGenCUDA
::
PrintStorageSync
(
const
std
::
string
&
sync
)
{
if
(
sync
==
"shared"
)
{
this
->
PrintIndent
();
...
...
@@ -43,8 +145,6 @@ void CodeGenCUDA::PrintStorageScope(
std
::
unordered_map
<
LoweredFunc
,
PackedFunc
>
MakeNVRTC
(
Array
<
LoweredFunc
>
funcs
)
{
std
::
ostringstream
os
;
os
<<
"typedef int int32_t;
\n
"
<<
"typedef unsigned unt32_t;
\n
"
;
bool
output_ssa
=
false
;
for
(
LoweredFunc
f
:
funcs
)
{
os
<<
CodeGenCUDA
().
Compile
(
f
,
output_ssa
);
...
...
@@ -56,6 +156,7 @@ MakeNVRTC(Array<LoweredFunc> funcs) {
const
auto
&
f
=
PackedFunc
::
GetGlobal
(
"tvm_callback_cuda_postproc"
);
code
=
f
(
code
).
operator
std
::
string
();
}
LOG
(
INFO
)
<<
code
;
std
::
string
ptx
;
if
(
PackedFunc
::
GlobalExist
(
"tvm_callback_cuda_compile"
))
{
const
auto
&
f
=
PackedFunc
::
GetGlobal
(
"tvm_callback_cuda_compile"
);
...
...
src/codegen/codegen_cuda.h
View file @
45597d00
...
...
@@ -25,9 +25,18 @@ class CodeGenCUDA : public CodeGenC {
*/
std
::
string
Compile
(
LoweredFunc
f
,
bool
output_ssa
);
// override behavior
void
PrintStorageSync
(
const
std
::
string
&
sync
)
final
;
void
PrintStorageScope
(
const
std
::
string
&
scope
,
std
::
ostream
&
os
)
final
;
// NOLINT(*)
void
PrintVecBinaryOp
(
const
std
::
string
&
op
,
Type
t
,
Expr
lhs
,
Expr
rhs
,
std
::
ostream
&
os
)
final
;
// NOLINT(*)
void
PrintType
(
Type
t
,
std
::
ostream
&
os
)
const
final
;
// NOLINT(*)
void
PrintVecElemLoad
(
const
std
::
string
&
vec
,
Type
t
,
int
i
,
std
::
ostream
&
os
)
final
;
// NOLINT(*)
void
PrintVecElemStore
(
const
std
::
string
&
vec
,
Type
t
,
int
i
,
const
std
::
string
&
value
)
final
;
};
}
// namespace codegen
...
...
src/codegen/codegen_opencl.cc
View file @
45597d00
...
...
@@ -19,7 +19,11 @@ std::string CodeGenOpenCL::Compile(
LoweredFunc
f
,
bool
output_ssa
)
{
this
->
stream
<<
" __kernel "
;
this
->
arg_addr_space_
=
"__global "
;
for
(
Var
arg
:
f
->
args
)
{
if
(
arg
.
type
().
is_handle
())
{
alloc_storage_scope_
[
arg
.
get
()]
=
"global"
;
}
}
return
CodeGenC
::
Compile
(
f
,
output_ssa
);
}
...
...
@@ -34,6 +38,80 @@ void CodeGenOpenCL::PrintThreadIndexExpr(
}
}
void
CodeGenOpenCL
::
PrintType
(
Type
t
,
std
::
ostream
&
os
)
const
{
// NOLINT(*)
int
lanes
=
t
.
lanes
();
if
(
t
.
is_handle
())
{
CHECK_EQ
(
lanes
,
1
)
<<
"do not yet support vector types"
;
os
<<
"void*"
;
return
;
}
bool
fail
=
false
;
if
(
t
.
is_float
())
{
switch
(
t
.
bits
())
{
case
16
:
os
<<
"half"
;
break
;
case
32
:
os
<<
"float"
;
break
;
case
64
:
os
<<
"double"
;
break
;
default
:
fail
=
true
;
break
;
}
if
(
!
fail
&&
lanes
==
1
)
return
;
if
(
!
fail
&&
(
lanes
>=
2
&&
lanes
<=
16
))
{
os
<<
lanes
;
return
;
}
}
else
if
(
t
.
is_uint
()
||
t
.
is_int
())
{
if
(
t
.
is_uint
())
{
os
<<
'u'
;
}
if
(
t
.
bits
()
==
8
&&
t
.
lanes
()
==
4
)
{
// directly 4 8 bit int in integer.
os
<<
"int"
;
return
;
}
switch
(
t
.
bits
())
{
case
8
:
os
<<
"char"
;
break
;
case
16
:
os
<<
"short"
;
break
;
case
32
:
os
<<
"int"
;
break
;
case
64
:
os
<<
"long"
;
break
;
case
1
:
os
<<
"int"
;
break
;
default
:
fail
=
true
;
break
;
}
if
(
!
fail
&&
lanes
==
1
)
return
;
if
(
!
fail
&&
(
lanes
>=
2
&&
lanes
<=
16
))
{
os
<<
lanes
;
return
;
}
}
LOG
(
FATAL
)
<<
"Cannot convert type "
<<
t
<<
" to OpenCL type"
;
}
void
CodeGenOpenCL
::
PrintVecAddr
(
const
Variable
*
buffer
,
Type
t
,
Expr
base
,
std
::
ostream
&
os
)
{
// NOLINT(*)
if
(
!
HandleTypeMatch
(
buffer
,
t
.
element_of
()))
{
os
<<
'('
;
auto
it
=
alloc_storage_scope_
.
find
(
buffer
);
if
(
it
!=
alloc_storage_scope_
.
end
())
{
PrintStorageScope
(
it
->
second
,
os
);
}
os
<<
' '
;
PrintType
(
t
.
element_of
(),
os
);
os
<<
"*)"
;
}
os
<<
GetVarID
(
buffer
)
<<
" + "
;
PrintExpr
(
base
,
os
);
}
void
CodeGenOpenCL
::
PrintVecLoad
(
const
Variable
*
buffer
,
Type
t
,
Expr
base
,
std
::
ostream
&
os
)
{
os
<<
"vload"
<<
t
.
lanes
()
<<
"(0, "
;
PrintVecAddr
(
buffer
,
t
,
base
,
os
);
os
<<
")"
;
}
void
CodeGenOpenCL
::
PrintVecStore
(
const
Variable
*
buffer
,
Type
t
,
Expr
base
,
const
std
::
string
&
value
)
{
this
->
PrintIndent
();
stream
<<
"vstore"
<<
t
.
lanes
()
<<
"("
<<
value
<<
", 0, "
;
PrintVecAddr
(
buffer
,
t
,
base
,
stream
);
stream
<<
");
\n
"
;
}
void
CodeGenOpenCL
::
PrintStorageSync
(
const
std
::
string
&
sync
)
{
if
(
sync
==
"shared"
)
{
...
...
@@ -45,8 +123,9 @@ void CodeGenOpenCL::PrintStorageSync(const std::string& sync) {
}
void
CodeGenOpenCL
::
PrintStorageScope
(
const
std
::
string
&
scope
,
std
::
ostream
&
os
)
{
// NOLINT(*)
CHECK_NE
(
scope
,
"global"
);
if
(
scope
==
"shared"
)
{
if
(
scope
==
"global"
)
{
os
<<
"__global"
;
}
else
if
(
scope
==
"shared"
)
{
os
<<
"__local "
;
}
}
...
...
@@ -55,8 +134,6 @@ void CodeGenOpenCL::PrintStorageScope(const std::string& scope, std::ostream& os
std
::
unordered_map
<
LoweredFunc
,
PackedFunc
>
MakeOpenCL
(
Array
<
LoweredFunc
>
funcs
)
{
std
::
ostringstream
os
;
os
<<
"typedef int int32_t;
\n
"
<<
"typedef unsigned unt32_t;
\n
"
;
bool
output_ssa
=
false
;
for
(
LoweredFunc
f
:
funcs
)
{
os
<<
CodeGenOpenCL
().
Compile
(
f
,
output_ssa
);
...
...
src/codegen/codegen_opencl.h
View file @
45597d00
...
...
@@ -30,6 +30,16 @@ class CodeGenOpenCL : public CodeGenC {
std
::
string
tag
,
std
::
ostream
&
os
)
final
;
// NOLINT(*)
void
PrintStorageScope
(
const
std
::
string
&
scope
,
std
::
ostream
&
os
)
final
;
// NOLINT(*)
void
PrintStorageSync
(
const
std
::
string
&
scope
)
final
;
// NOLINT(*)
void
PrintType
(
Type
t
,
std
::
ostream
&
os
)
const
final
;
// NOLINT(*)
void
PrintVecLoad
(
const
Variable
*
buffer
,
Type
t
,
Expr
base
,
std
::
ostream
&
os
)
final
;
// NOLINT(*)
void
PrintVecStore
(
const
Variable
*
buffer
,
Type
t
,
Expr
base
,
const
std
::
string
&
value
)
final
;
// NOLINT(*)
// the address of load/store
void
PrintVecAddr
(
const
Variable
*
buffer
,
Type
t
,
Expr
base
,
std
::
ostream
&
os
);
// NOLINT(*)
};
}
// namespace codegen
...
...
src/pass/ir_mutator.cc
View file @
45597d00
...
...
@@ -74,6 +74,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
.
DISPATCH_TO_MUTATE_STMT
(
Provide
)
.
DISPATCH_TO_MUTATE_STMT
(
Realize
)
.
DISPATCH_TO_MUTATE_STMT
(
Store
)
.
DISPATCH_TO_MUTATE_STMT
(
IfThenElse
)
.
DISPATCH_TO_MUTATE_STMT
(
For
)
.
DISPATCH_TO_MUTATE_STMT
(
Allocate
)
.
DISPATCH_TO_MUTATE_STMT
(
Free
);
...
...
@@ -195,6 +196,22 @@ Stmt IRMutator::Mutate_(const Free *op, const Stmt& s) {
return
s
;
}
Stmt
IRMutator
::
Mutate_
(
const
IfThenElse
*
op
,
const
Stmt
&
s
)
{
Expr
condition
=
this
->
Mutate
(
op
->
condition
);
Stmt
then_case
=
this
->
Mutate
(
op
->
then_case
);
Stmt
else_case
;
if
(
else_case
.
defined
())
{
else_case
=
this
->
Mutate
(
op
->
else_case
);
}
if
(
condition
.
same_as
(
op
->
condition
)
&&
then_case
.
same_as
(
op
->
then_case
)
&&
else_case
.
same_as
(
op
->
else_case
))
{
return
s
;
}
else
{
return
IfThenElse
::
make
(
condition
,
then_case
,
else_case
);
}
}
TVM_STATIC_IR_FUNCTOR
(
IRMutator
,
vtable_expr
)
.
DISPATCH_TO_MUTATE_EXPR
(
Call
)
.
DISPATCH_TO_MUTATE_EXPR
(
Let
)
...
...
@@ -363,21 +380,6 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
return
Block
::
make
(
first
,
rest
);
}
})
.
set_dispatch
<
IfThenElse
>
([](
const
IfThenElse
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
Expr
condition
=
m
->
Mutate
(
op
->
condition
);
Stmt
then_case
=
m
->
Mutate
(
op
->
then_case
);
Stmt
else_case
;
if
(
else_case
.
defined
())
{
else_case
=
m
->
Mutate
(
op
->
else_case
);
}
if
(
condition
.
same_as
(
op
->
condition
)
&&
then_case
.
same_as
(
op
->
then_case
)
&&
else_case
.
same_as
(
op
->
else_case
))
{
return
s
;
}
else
{
return
IfThenElse
::
make
(
condition
,
then_case
,
else_case
);
}
})
.
set_dispatch
<
Evaluate
>
([](
const
Evaluate
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
Expr
v
=
m
->
Mutate
(
op
->
value
);
if
(
v
.
same_as
(
op
->
value
))
{
...
...
src/pass/split_host_device.cc
View file @
45597d00
...
...
@@ -101,9 +101,14 @@ class IRUseDefAnalysis : public IRMutator {
}
void
HandleDef
(
const
Variable
*
v
)
{
CHECK
(
!
def_count_
.
count
(
v
))
<<
"variable "
<<
v
->
name_hint
<<
" has already been defined, the Stmt is not SSA"
;
CHECK
(
!
use_count_
.
count
(
v
))
<<
"variable is already defined"
;
<<
"variable "
<<
v
->
name_hint
<<
" has been used before definition!"
;
use_count_
[
v
]
=
0
;
def_count_
[
v
]
=
1
;
}
void
HandleUse
(
const
Expr
&
v
)
{
...
...
@@ -127,6 +132,7 @@ class IRUseDefAnalysis : public IRMutator {
Array
<
IterVar
>
thread_axis_
;
Array
<
Expr
>
thread_extent_
;
std
::
unordered_map
<
const
Variable
*
,
int
>
use_count_
;
std
::
unordered_map
<
const
Variable
*
,
int
>
def_count_
;
};
class
HostDeviceSplitter
:
public
IRMutator
{
...
...
src/pass/unroll_loop.cc
View file @
45597d00
/*!
* Copyright (c) 201
6
by Contributors
*
SSA related checks and pass
.
* \file
ssa
.cc
* Copyright (c) 201
7
by Contributors
*
Loop unrolling
.
* \file
unroll_loop
.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
...
...
@@ -9,7 +9,7 @@
#include <unordered_set>
#include <unordered_map>
#include <vector>
#include "../arithmetic/
/
compute_expr.h"
#include "../arithmetic/compute_expr.h"
namespace
tvm
{
namespace
ir
{
...
...
@@ -33,7 +33,8 @@ class LoopUnroller : public IRMutator {
if
(
v2
!=
nullptr
)
{
value
=
static_cast
<
int
>
(
v2
->
value
);
}
bool
allow_unroll
=
value
>=
0
&&
value
<=
max_auto_step_
;
bool
allow_unroll
=
(
op
->
for_type
==
ForType
::
Serial
&&
value
>=
0
&&
value
<=
max_auto_step_
);
if
(
op
->
for_type
==
ForType
::
Unrolled
)
{
CHECK_GE
(
value
,
0
)
<<
"Cannot unroll non-constant loop"
;
...
...
src/pass/vectorize_loop.cc
0 → 100644
View file @
45597d00
/*!
* Copyright (c) 2017 by Contributors
* Vectorize the loop
* \file vectorize_loop.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <unordered_set>
#include <unordered_map>
#include <vector>
#include "../arithmetic/compute_expr.h"
namespace
tvm
{
namespace
ir
{
inline
Expr
BroadcastTo
(
Expr
e
,
int
lanes
)
{
if
(
e
.
type
().
lanes
()
==
lanes
)
return
e
;
CHECK_EQ
(
e
.
type
().
lanes
(),
1
)
<<
"Cannot broadcast lane="
<<
e
.
type
().
lanes
()
<<
" to "
<<
lanes
;
return
Broadcast
::
make
(
e
,
lanes
);
}
// Rewrite vectorized allocation access
// s[i] = s[i * lanes + var]
class
VecAllocAccess
:
public
IRMutator
{
public
:
VecAllocAccess
(
const
Variable
*
buf
,
Var
var
,
int
var_lanes
)
:
buf_
(
buf
),
var_
(
var
),
var_lanes_
(
var_lanes
)
{}
// Load
Expr
Mutate_
(
const
Load
*
op
,
const
Expr
&
e
)
final
{
Expr
expr
=
IRMutator
::
Mutate_
(
op
,
e
);
op
=
expr
.
as
<
Load
>
();
if
(
op
->
buffer_var
.
get
()
==
buf_
)
{
return
Load
::
make
(
op
->
type
,
op
->
buffer_var
,
op
->
index
*
var_lanes_
+
var_
);
}
else
{
return
expr
;
}
}
// Store
Stmt
Mutate_
(
const
Store
*
op
,
const
Stmt
&
s
)
final
{
Stmt
stmt
=
IRMutator
::
Mutate_
(
op
,
s
);
op
=
stmt
.
as
<
Store
>
();
if
(
op
->
buffer_var
.
get
()
==
buf_
)
{
return
Store
::
make
(
op
->
buffer_var
,
op
->
value
,
op
->
index
*
var_lanes_
+
var_
);
}
else
{
return
stmt
;
}
}
private
:
// buffer var
const
Variable
*
buf_
;
// variable to be replaced
Var
var_
;
// the lanes.
int
var_lanes_
;
};
class
Vectorizer
:
public
IRMutator
{
public
:
Vectorizer
(
Var
var
,
int
var_lanes
)
:
var_
(
var
),
var_lanes_
(
var_lanes
)
{
ramp_
=
Ramp
::
make
(
0
,
1
,
var_lanes
);
}
// user mutate from parent.
using
IRMutator
::
Mutate
;
// override mutate
Expr
Mutate
(
Expr
expr
)
final
{
static
const
FMutateExpr
&
f
=
Vectorizer
::
vtable_expr
();
return
(
f
.
can_dispatch
(
expr
)
?
f
(
expr
,
expr
,
this
)
:
IRMutator
::
Mutate
(
expr
));
}
// Variable
Expr
Mutate_
(
const
Variable
*
v
,
const
Expr
&
e
)
final
{
if
(
v
==
var_
.
get
())
{
return
ramp_
;
}
else
if
(
lets_
.
count
(
v
))
{
return
lets_
[
v
];
}
else
{
return
e
;
}
}
// Call
Expr
Mutate_
(
const
Call
*
op
,
const
Expr
&
e
)
final
{
int
lane
=
0
;
Array
<
Expr
>
new_args
=
MutateArray
(
op
->
args
,
&
lane
);
if
(
op
->
args
.
same_as
(
new_args
))
{
return
e
;
}
else
{
return
Call
::
make
(
op
->
type
.
with_lanes
(
lane
),
op
->
name
,
new_args
,
op
->
call_type
,
op
->
func
,
op
->
value_index
);
}
}
// Load
Expr
Mutate_
(
const
Load
*
op
,
const
Expr
&
e
)
final
{
Expr
index
=
this
->
Mutate
(
op
->
index
);
if
(
index
.
same_as
(
op
->
index
))
{
return
e
;
}
else
{
return
Load
::
make
(
op
->
type
.
with_lanes
(
index
.
type
().
lanes
()),
op
->
buffer_var
,
index
);
}
}
// Let
Expr
Mutate_
(
const
Let
*
op
,
const
Expr
&
e
)
final
{
Expr
value
=
this
->
Mutate
(
op
->
value
);
CHECK
(
!
lets_
.
count
(
op
->
var
.
get
()))
<<
"not SSA"
;
if
(
value
.
type
().
lanes
()
!=
op
->
value
.
type
().
lanes
())
{
Var
v
(
op
->
var
->
name_hint
,
value
.
type
());
lets_
[
op
->
var
.
get
()]
=
v
;
return
Let
::
make
(
v
,
value
,
Mutate
(
op
->
body
));
}
else
{
Expr
body
=
this
->
Mutate
(
op
->
body
);
if
(
value
.
same_as
(
op
->
value
)
&&
body
.
same_as
(
op
->
body
))
{
return
e
;
}
else
{
return
Let
::
make
(
op
->
var
,
value
,
body
);
}
}
}
// Provide
Stmt
Mutate_
(
const
Provide
*
op
,
const
Stmt
&
s
)
final
{
Expr
new_value
=
this
->
Mutate
(
op
->
value
);
int
lane
=
new_value
.
type
().
lanes
();
Array
<
Expr
>
new_args
=
MutateArray
(
op
->
args
,
&
lane
);
if
(
op
->
args
.
same_as
(
new_args
)
&&
op
->
value
.
same_as
(
new_value
))
{
return
s
;
}
else
{
new_value
=
BroadcastTo
(
new_value
,
lane
);
return
Provide
::
make
(
op
->
func
,
op
->
value_index
,
new_value
,
new_args
);
}
}
// Store
Stmt
Mutate_
(
const
Store
*
op
,
const
Stmt
&
s
)
final
{
Expr
value
=
this
->
Mutate
(
op
->
value
);
Expr
index
=
this
->
Mutate
(
op
->
index
);
if
(
value
.
same_as
(
op
->
value
)
&&
index
.
same_as
(
op
->
index
))
{
return
s
;
}
else
{
int
lanes
=
std
::
max
(
value
.
type
().
lanes
(),
index
.
type
().
lanes
());
return
Store
::
make
(
op
->
buffer_var
,
BroadcastTo
(
value
,
lanes
),
BroadcastTo
(
index
,
lanes
));
}
}
// For
Stmt
Mutate_
(
const
For
*
op
,
const
Stmt
&
s
)
final
{
if
(
op
->
for_type
==
ForType
::
Vectorized
)
{
LOG
(
WARNING
)
<<
"Detect vectorize inside vectorized loop, ignoring..."
;
}
CHECK
(
is_zero
(
op
->
min
));
CHECK
(
!
op
->
extent
.
type
().
is_vector
());
Expr
extent
=
Mutate
(
op
->
extent
);
if
(
extent
.
type
().
is_vector
())
{
LOG
(
WARNING
)
<<
"Detect vectorized extent type, scalarizing..."
;
return
Scalarize
(
s
);
}
Stmt
body
=
Mutate
(
op
->
body
);
if
(
extent
.
same_as
(
op
->
extent
)
&&
body
.
same_as
(
op
->
body
))
{
return
s
;
}
else
{
return
For
::
make
(
op
->
loop_var
,
op
->
min
,
extent
,
op
->
for_type
,
op
->
device_api
,
body
);
}
}
// IfThenElse
Stmt
Mutate_
(
const
IfThenElse
*
op
,
const
Stmt
&
s
)
final
{
CHECK
(
!
op
->
condition
.
type
().
is_vector
());
Expr
condition
=
this
->
Mutate
(
op
->
condition
);
if
(
condition
.
type
().
is_vector
())
{
LOG
(
WARNING
)
<<
"Detect vector condition in Vectorized Loop, scalarizing..."
;
return
Scalarize
(
s
);
}
Stmt
then_case
=
this
->
Mutate
(
op
->
then_case
);
Stmt
else_case
;
if
(
else_case
.
defined
())
{
else_case
=
this
->
Mutate
(
op
->
else_case
);
}
if
(
condition
.
same_as
(
op
->
condition
)
&&
then_case
.
same_as
(
op
->
then_case
)
&&
else_case
.
same_as
(
op
->
else_case
))
{
return
s
;
}
else
{
return
IfThenElse
::
make
(
condition
,
then_case
,
else_case
);
}
}
// LetStmt
Stmt
Mutate_
(
const
LetStmt
*
op
,
const
Stmt
&
s
)
final
{
LOG
(
WARNING
)
<<
"Cannot vectorize with LetStmt, remove it with Simplify Before Vectorize"
;
return
Scalarize
(
s
);
}
// Allocate
Stmt
Mutate_
(
const
Allocate
*
op
,
const
Stmt
&
s
)
final
{
if
(
op
->
new_expr
.
defined
())
{
LOG
(
WARNING
)
<<
"Cannot vectorize with new expr"
;
return
Scalarize
(
s
);
}
Expr
condition
=
Mutate
(
op
->
condition
);
if
(
condition
.
type
().
is_vector
())
{
LOG
(
WARNING
)
<<
"Cannot handle vector extent in alloc "
;
return
Scalarize
(
s
);
}
Array
<
Expr
>
extents
;
for
(
size_t
i
=
0
;
i
<
op
->
extents
.
size
();
i
++
)
{
Expr
new_ext
=
Mutate
(
op
->
extents
[
i
]);
if
(
new_ext
.
type
().
is_vector
())
{
LOG
(
WARNING
)
<<
"Cannot handle vector extent in alloc "
;
return
Scalarize
(
s
);
}
extents
.
push_back
(
new_ext
);
}
// place the vector lanes in least significant dimension.
extents
.
push_back
(
var_lanes_
);
// rewrite access to buffer internally.
Stmt
body
=
VecAllocAccess
(
op
->
buffer_var
.
get
(),
var_
,
var_lanes_
).
Mutate
(
op
->
body
);
body
=
Mutate
(
body
);
return
Allocate
::
make
(
op
->
buffer_var
,
op
->
type
,
extents
,
condition
,
body
,
op
->
new_expr
,
op
->
free_function
);
}
// scalarize the statment
Stmt
Scalarize
(
Stmt
stmt
)
{
Var
idx
(
var_
->
name_hint
+
".s"
,
var_
->
type
);
stmt
=
Substitute
(
stmt
,
{{
var_
,
idx
}});
return
For
::
make
(
idx
,
0
,
var_lanes_
,
ForType
::
Serial
,
DeviceAPI
::
None
,
stmt
);
}
// The overloads for vectorize.
static
FMutateExpr
&
vtable_expr
()
{
// NOLINT(*)
static
FMutateExpr
inst
;
return
inst
;
}
private
:
// variable to be replaced
Var
var_
;
// the lanes.
int
var_lanes_
;
// ramp representing the var.
Expr
ramp_
;
// The lets
std
::
unordered_map
<
const
Variable
*
,
Expr
>
lets_
;
// mutate array, with given lane requirement
// when finished, p_lane updates the lane requirement.
Array
<
Expr
>
MutateArray
(
Array
<
Expr
>
arr
,
int
*
p_lanes
)
{
if
(
arr
.
size
()
==
0
)
return
arr
;
int
&
lanes
=
*
p_lanes
;
bool
changed
=
false
;
std
::
vector
<
Expr
>
new_arr
(
arr
.
size
());
for
(
size_t
i
=
0
;
i
<
arr
.
size
();
i
++
)
{
Expr
old_elem
=
arr
[
i
];
Expr
new_elem
=
this
->
Mutate
(
old_elem
);
if
(
!
new_elem
.
same_as
(
old_elem
))
changed
=
true
;
new_arr
[
i
]
=
new_elem
;
lanes
=
std
::
max
(
lanes
,
new_elem
.
type
().
lanes
());
}
for
(
size_t
i
=
0
;
i
<
arr
.
size
();
++
i
)
{
if
(
new_arr
[
i
].
type
().
lanes
()
!=
lanes
)
{
new_arr
[
i
]
=
BroadcastTo
(
new_arr
[
i
],
lanes
);
changed
=
true
;
}
}
if
(
!
changed
)
return
arr
;
return
Array
<
Expr
>
(
new_arr
);
}
};
// binary vectorize
template
<
typename
T
>
inline
Expr
BinaryVec
(
const
T
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
Expr
a
=
m
->
Mutate
(
op
->
a
);
Expr
b
=
m
->
Mutate
(
op
->
b
);
if
(
a
.
same_as
(
op
->
a
)
&&
b
.
same_as
(
op
->
b
))
{
return
e
;
}
else
{
int
lanes
=
std
::
max
(
a
.
type
().
lanes
(),
b
.
type
().
lanes
());
return
T
::
make
(
BroadcastTo
(
a
,
lanes
),
BroadcastTo
(
b
,
lanes
));
}
}
template
<
typename
T
>
inline
Expr
AddSubVec
(
const
T
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
Expr
a
=
m
->
Mutate
(
op
->
a
);
Expr
b
=
m
->
Mutate
(
op
->
b
);
if
(
a
.
same_as
(
op
->
a
)
&&
b
.
same_as
(
op
->
b
))
{
return
e
;
}
else
{
int
lanes
=
std
::
max
(
a
.
type
().
lanes
(),
b
.
type
().
lanes
());
if
(
lanes
!=
1
)
{
const
Ramp
*
b_ramp
=
b
.
as
<
Ramp
>
();
const
Ramp
*
a_ramp
=
a
.
as
<
Ramp
>
();
if
(
a
.
type
().
lanes
()
==
1
&&
b_ramp
)
{
return
Ramp
::
make
(
arith
::
ComputeExpr
<
T
>
(
a
,
b_ramp
->
base
),
b_ramp
->
stride
,
b_ramp
->
lanes
);
}
if
(
b
.
type
().
lanes
()
==
1
&&
a_ramp
)
{
return
Ramp
::
make
(
arith
::
ComputeExpr
<
T
>
(
a_ramp
->
base
,
b
),
a_ramp
->
stride
,
a_ramp
->
lanes
);
}
}
return
T
::
make
(
BroadcastTo
(
a
,
lanes
),
BroadcastTo
(
b
,
lanes
));
}
}
TVM_STATIC_IR_FUNCTOR
(
Vectorizer
,
vtable_expr
)
.
set_dispatch
<
Add
>
(
AddSubVec
<
Add
>
)
.
set_dispatch
<
Sub
>
(
AddSubVec
<
Sub
>
)
.
set_dispatch
<
Mul
>
(
BinaryVec
<
Mul
>
)
.
set_dispatch
<
Div
>
(
BinaryVec
<
Div
>
)
.
set_dispatch
<
Mod
>
(
BinaryVec
<
Mod
>
)
.
set_dispatch
<
Min
>
(
BinaryVec
<
Min
>
)
.
set_dispatch
<
Max
>
(
BinaryVec
<
Max
>
)
.
set_dispatch
<
EQ
>
(
BinaryVec
<
EQ
>
)
.
set_dispatch
<
NE
>
(
BinaryVec
<
NE
>
)
.
set_dispatch
<
LT
>
(
BinaryVec
<
LT
>
)
.
set_dispatch
<
LE
>
(
BinaryVec
<
LE
>
)
.
set_dispatch
<
GT
>
(
BinaryVec
<
GT
>
)
.
set_dispatch
<
GE
>
(
BinaryVec
<
GE
>
)
.
set_dispatch
<
And
>
(
BinaryVec
<
And
>
)
.
set_dispatch
<
Or
>
(
BinaryVec
<
Or
>
);
TVM_STATIC_IR_FUNCTOR
(
Vectorizer
,
vtable_expr
)
.
set_dispatch
<
Select
>
([](
const
Select
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
Expr
cond
=
m
->
Mutate
(
op
->
condition
);
Expr
t
=
m
->
Mutate
(
op
->
true_value
);
Expr
f
=
m
->
Mutate
(
op
->
false_value
);
if
(
cond
.
same_as
(
op
->
condition
)
&&
t
.
same_as
(
op
->
true_value
)
&&
f
.
same_as
(
op
->
false_value
))
{
return
e
;
}
else
{
int
lanes
=
std
::
max
(
std
::
max
(
cond
.
type
().
lanes
(),
t
.
type
().
lanes
()),
f
.
type
().
lanes
());
return
Select
::
make
(
cond
,
BroadcastTo
(
t
,
lanes
),
BroadcastTo
(
f
,
lanes
));
}
})
.
set_dispatch
<
Cast
>
([](
const
Cast
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
Expr
value
=
m
->
Mutate
(
op
->
value
);
if
(
value
.
same_as
(
op
->
value
))
{
return
e
;
}
else
{
return
Cast
::
make
(
op
->
type
.
with_lanes
(
value
.
type
().
lanes
()),
value
);
}
});
class
LoopVectorizer
:
public
IRMutator
{
public
:
Stmt
Mutate_
(
const
For
*
op
,
const
Stmt
&
s
)
final
{
if
(
op
->
for_type
==
ForType
::
Vectorized
)
{
CHECK
(
is_zero
(
op
->
min
));
CHECK
(
is_positive_const
(
op
->
extent
));
int
lanes
=
0
;
bool
succ
=
arith
::
GetConstInt
(
op
->
extent
,
&
lanes
);
if
(
!
succ
||
lanes
<
1
)
{
LOG
(
FATAL
)
<<
"Failed to vectorize loop with extent "
<<
op
->
extent
;
}
Var
var
(
op
->
loop_var
.
node_
);
return
Vectorizer
(
var
,
lanes
).
Mutate
(
op
->
body
);
}
else
{
return
IRMutator
::
Mutate_
(
op
,
s
);
}
}
};
Stmt
VectorizeLoop
(
Stmt
stmt
)
{
return
LoopVectorizer
().
Mutate
(
stmt
);
}
}
// namespace ir
}
// namespace tvm
src/schedule/schedule_lang.cc
View file @
45597d00
...
...
@@ -57,6 +57,14 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
<<
")"
;
});
TVM_STATIC_IR_FUNCTOR
(
IRPrinter
,
vtable
)
.
set_dispatch
<
IterVarAttrNode
>
([](
const
IterVarAttrNode
*
op
,
IRPrinter
*
p
)
{
switch
(
op
->
iter_type
)
{
case
kUnrolled
:
p
->
stream
<<
"unroll"
;
break
;
case
kVectorized
:
p
->
stream
<<
"vectorize"
;
break
;
}
});
Stage
::
Stage
(
Operation
op
)
{
auto
n
=
std
::
make_shared
<
StageNode
>
();
n
->
op
=
op
;
...
...
@@ -246,7 +254,38 @@ void Schedule::normalize() {
}
}
IterVarAttr
::
IterVarAttr
(
IterVarType
t
)
{
std
::
shared_ptr
<
IterVarAttrNode
>
n
=
std
::
make_shared
<
IterVarAttrNode
>
();
n
->
iter_type
=
t
;
node_
=
n
;
}
inline
void
SetAttr
(
StageNode
*
self
,
IterVar
var
,
IterVarAttr
attr
)
{
ArrayNode
*
all_vars
=
self
->
all_iter_vars
.
CopyOnWrite
();
ArrayNode
*
leaf_vars
=
self
->
leaf_iter_vars
.
CopyOnWrite
();
FindLeafVar
(
all_vars
,
leaf_vars
,
var
);
auto
it
=
self
->
iter_var_attrs
.
find
(
var
);
if
(
it
!=
self
->
iter_var_attrs
.
end
())
{
CHECK_EQ
((
*
it
).
second
->
iter_type
,
attr
->
iter_type
)
<<
"IterVar's is already set to "
<<
(
*
it
).
second
<<
" instead of "
<<
attr
;
}
else
{
self
->
iter_var_attrs
.
Set
(
var
,
attr
);
}
}
Stage
&
Stage
::
vectorize
(
IterVar
var
)
{
// NOLINT(*)
SetAttr
(
operator
->
(),
var
,
IterVarAttr
(
kVectorized
));
return
*
this
;
}
Stage
&
Stage
::
unroll
(
IterVar
var
)
{
// NOLINT(*)
SetAttr
(
operator
->
(),
var
,
IterVarAttr
(
kUnrolled
));
return
*
this
;
}
TVM_REGISTER_NODE_TYPE
(
StageNode
);
TVM_REGISTER_NODE_TYPE
(
IterVarAttrNode
);
TVM_REGISTER_NODE_TYPE
(
SplitNode
);
TVM_REGISTER_NODE_TYPE
(
FuseNode
);
TVM_REGISTER_NODE_TYPE
(
RebaseNode
);
...
...
src/schedule/schedule_ops.cc
View file @
45597d00
...
...
@@ -177,6 +177,13 @@ MakeLoopNest(const Stage& sch,
}
// Mark the iter var in the IR, to remember the point
if
(
iv
->
thread_tag
.
length
()
==
0
)
{
ForType
for_type
=
ForType
::
Serial
;
if
(
sch
->
iter_var_attrs
.
count
(
iv
))
{
switch
(
sch
->
iter_var_attrs
[
iv
]
->
iter_type
)
{
case
kUnrolled
:
for_type
=
ForType
::
Unrolled
;
break
;
case
kVectorized
:
for_type
=
ForType
::
Vectorized
;
break
;
}
}
if
(
is_one
(
dom
->
extent
))
{
nest
[
i
+
1
].
emplace_back
(
LetStmt
::
make
(
var
,
dom
->
min
,
no_op
));
...
...
@@ -184,13 +191,13 @@ MakeLoopNest(const Stage& sch,
}
else
if
(
is_zero
(
dom
->
min
))
{
nest
[
i
+
1
].
emplace_back
(
For
::
make
(
var
,
0
,
dom
->
extent
,
ForType
::
Serial
,
DeviceAPI
::
None
,
no_op
));
for_type
,
DeviceAPI
::
None
,
no_op
));
value_map
[
iv
]
=
var
;
}
else
{
Var
idx
(
iv
->
var
->
name_hint
+
".idx"
,
iv
->
var
.
type
());
nest
[
i
+
1
].
emplace_back
(
For
::
make
(
idx
,
0
,
dom
->
extent
,
ForType
::
Serial
,
DeviceAPI
::
None
,
no_op
));
for_type
,
DeviceAPI
::
None
,
no_op
));
Expr
new_value
=
dom
->
min
+
idx
;
value_map
[
iv
]
=
new_value
;
nest
[
i
+
1
].
emplace_back
(
...
...
tests/python/integration/test_ewise.py
View file @
45597d00
...
...
@@ -3,7 +3,7 @@ import numpy as np
def
test_add
():
# graph
n
=
tvm
.
Var
(
'n'
)
n
=
tvm
.
convert
(
1024
)
A
=
tvm
.
placeholder
((
n
,),
name
=
'A'
)
B
=
tvm
.
placeholder
((
n
,),
name
=
'B'
)
C
=
tvm
.
compute
(
A
.
shape
,
lambda
*
i
:
A
(
*
i
)
+
B
(
*
i
),
name
=
'C'
)
...
...
@@ -13,26 +13,28 @@ def test_add():
num_thread
=
256
block_x
=
tvm
.
IterVar
(
thread_tag
=
"blockIdx.x"
)
thread_x
=
tvm
.
IterVar
((
0
,
num_thread
),
thread_tag
=
"threadIdx.x"
)
_
,
x
=
s
[
C
]
.
split
(
C
.
op
.
axis
[
0
],
factor
=
num_thread
,
outer
=
block_x
)
_
,
x
=
s
[
C
]
.
split
(
C
.
op
.
axis
[
0
],
factor
=
num_thread
*
4
,
outer
=
block_x
)
_
,
x
=
s
[
C
]
.
split
(
x
,
outer
=
thread_x
)
_
,
x
=
s
[
C
]
.
split
(
x
,
factor
=
4
)
s
[
C
]
.
vectorize
(
x
)
# one line to build the function.
def
check_device
(
target
):
codes
=
[]
fadd
=
tvm
.
build
(
s
,
args
=
[
A
,
B
,
C
],
target
=
"cuda"
,
name
=
"myadd"
,
record_codes
=
codes
)
for
c
in
codes
:
print
(
c
)
# call the function
num_device
=
1
for
i
in
range
(
num_device
):
ctx
=
tvm
.
gpu
(
i
)
fadd
=
tvm
.
build
(
s
,
[
A
,
B
,
C
],
target
,
record_codes
=
codes
,
name
=
"myadd"
)
if
target
==
"cuda"
:
ctx
=
tvm
.
gpu
(
0
)
else
:
ctx
=
tvm
.
cl
(
0
)
if
not
ctx
.
enabled
:
continue
return
for
c
in
codes
[
1
:]:
print
(
c
)
# launch the kernel.
n
=
102
7
n
=
102
4
a
=
tvm
.
nd
.
array
(
np
.
random
.
uniform
(
size
=
n
)
.
astype
(
A
.
dtype
),
ctx
)
b
=
tvm
.
nd
.
array
(
np
.
random
.
uniform
(
size
=
n
)
.
astype
(
B
.
dtype
),
ctx
)
c
=
tvm
.
nd
.
array
(
np
.
zeros
(
n
,
dtype
=
C
.
dtype
),
ctx
)
...
...
@@ -40,6 +42,10 @@ def test_add():
np
.
testing
.
assert_allclose
(
c
.
asnumpy
(),
a
.
asnumpy
()
+
b
.
asnumpy
())
tvm
.
init_opencl
()
check_device
(
"cuda"
)
check_device
(
"opencl"
)
if
__name__
==
"__main__"
:
test_add
()
tests/python/unittest/test_lang_schedule.py
View file @
45597d00
...
...
@@ -76,6 +76,21 @@ def test_fuse():
assert
any
(
isinstance
(
x
,
tvm
.
schedule
.
Fuse
)
for
x
in
s
[
T
]
.
relations
)
assert
tuple
(
s
[
T
]
.
leaf_iter_vars
)
==
(
fused
,
xi
,
yi
)
def
test_vectorize
():
m
=
tvm
.
Var
(
'm'
)
n
=
tvm
.
Var
(
'n'
)
A
=
tvm
.
placeholder
((
m
,
n
),
name
=
'A'
)
T
=
tvm
.
compute
((
m
,
n
),
lambda
i
,
j
:
A
[
i
,
j
])
s
=
tvm
.
Schedule
(
T
.
op
)
xo
,
yo
,
xi
,
yi
=
s
[
T
]
.
tile
(
T
.
op
.
axis
[
0
],
T
.
op
.
axis
[
1
],
x_factor
=
10
,
y_factor
=
5
)
s
[
T
]
.
vectorize
(
yi
)
s
[
T
]
.
unroll
(
xi
)
UNROLL
=
1
VECTORIZE
=
2
assert
s
[
T
]
.
iter_var_attrs
[
xi
]
.
iter_type
==
UNROLL
assert
s
[
T
]
.
iter_var_attrs
[
yi
]
.
iter_type
==
VECTORIZE
if
__name__
==
"__main__"
:
test_schedule_create
()
...
...
@@ -83,3 +98,4 @@ if __name__ == "__main__":
test_tile
()
test_split
()
test_fuse
()
test_vectorize
()
tests/python/unittest/test_pass_unroll.py
View file @
45597d00
...
...
@@ -9,11 +9,13 @@ def test_unroll_loop():
# for i in 0 to n-1:
stmt
=
tvm
.
make
.
For
(
i
,
n
,
2
,
0
,
0
,
tvm
.
make
.
For
(
j
,
0
,
n
,
0
,
0
,
tvm
.
make
.
For
(
j
,
0
,
8
,
3
,
0
,
tvm
.
make
.
Store
(
Ab
.
data
,
tvm
.
make
.
Load
(
dtype
,
Ab
.
data
,
i
)
+
1
,
j
+
1
)))
stmt
=
tvm
.
ir_pass
.
UnrollLoop
(
stmt
,
8
)
assert
isinstance
(
stmt
,
tvm
.
stmt
.
For
)
stmt
=
tvm
.
ir_pass
.
UnrollLoop
(
stmt
,
4
)
assert
not
isinstance
(
stmt
,
tvm
.
stmt
.
For
)
print
(
stmt
)
if
__name__
==
"__main__"
:
...
...
tests/python/unittest/test_pass_vectorize.py
0 → 100644
View file @
45597d00
import
tvm
def
test_vectorize_loop
():
dtype
=
'int64'
n
=
tvm
.
Var
(
'n'
)
Ab
=
tvm
.
Buffer
((
n
,
),
dtype
)
i
=
tvm
.
Var
(
'i'
)
j
=
tvm
.
Var
(
'j'
)
VECTORIZE
=
2
# for i in 0 to n-1:
stmt
=
tvm
.
make
.
For
(
i
,
n
,
2
,
0
,
0
,
tvm
.
make
.
For
(
j
,
0
,
4
,
VECTORIZE
,
0
,
tvm
.
make
.
Store
(
Ab
.
data
,
tvm
.
make
.
Load
(
dtype
,
Ab
.
data
,
i
)
+
1
,
j
+
1
)))
assert
isinstance
(
stmt
.
body
,
tvm
.
stmt
.
For
)
stmt
=
tvm
.
ir_pass
.
VectorizeLoop
(
stmt
)
assert
isinstance
(
stmt
,
tvm
.
stmt
.
For
)
assert
not
isinstance
(
stmt
.
body
,
tvm
.
stmt
.
For
)
print
(
stmt
)
if
__name__
==
"__main__"
:
test_vectorize_loop
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment