Commit 4cebb1c7 (Unverified)
Authored Apr 20, 2020 by Tianqi Chen; committed by GitHub, Apr 20, 2020

[ARITH] Remove legacy const pattern functions (#5387)
parent d9cecdf5
Showing 14 changed files with 80 additions and 105 deletions.

  src/arith/compute_expr.h                       +0   -21
  src/arith/pattern_match.h                      +11  -0
  src/target/llvm/codegen_llvm.cc                +13  -12
  src/target/source/codegen_c.cc                 +10  -9
  src/target/spirv/codegen_spirv.cc              +3   -3
  src/tir/pass/arg_binder.cc                     +3   -3
  src/tir/pass/ir_util.h                         +0   -16
  src/tir/transforms/inject_virtual_thread.cc    +4   -4
  src/tir/transforms/lower_thread_allreduce.cc   +3   -1
  src/tir/transforms/lower_tvm_builtin.cc        +2   -4
  src/tir/transforms/lower_warp_memory.cc        +20  -18
  src/tir/transforms/storage_flatten.cc          +6   -5
  src/tir/transforms/unroll_loop.cc              +2   -5
  src/tir/transforms/vectorize_loop.cc           +3   -4

src/arith/compute_expr.h

@@ -56,27 +56,6 @@ template<typename Op>
 inline PrimExpr ComputeReduce(
     const Array<PrimExpr>& values, PrimExpr empty_value);
 
-inline bool GetConst(PrimExpr e, int64_t* out) {
-  if (e.dtype().is_vector()) return false;
-  const int64_t* v = tir::as_const_int(e);
-  if (v) {
-    *out = *v;
-    return true;
-  } else {
-    return false;
-  }
-}
-
-// get a small constant int
-inline bool GetConstInt(PrimExpr e, int* out) {
-  int64_t v1 = 0;
-  if (GetConst(e, &v1)) {
-    if (v1 > static_cast<int64_t>(std::numeric_limits<int>::max())) return false;
-    *out = static_cast<int>(v1);
-    return true;
-  }
-  return false;
-}
-
 template<>
 inline PrimExpr Compute<tir::AddNode>(PrimExpr a, PrimExpr b) {
   return a + b;

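For context, the call sites updated later in this commit replace the two helpers deleted above with a direct as<IntImmNode>() query on the expression. A minimal sketch of that idiom, assuming a TVM build environment; the helper name ExtractConstInt is purely illustrative and is not part of this commit:

// Sketch only: ExtractConstInt is a hypothetical stand-in that mirrors how the
// updated call sites in this commit read a constant value, replacing the
// deleted GetConst/GetConstInt helpers.
#include <cstdint>
#include <tvm/tir/expr.h>  // tvm::PrimExpr, tvm::IntImmNode

inline bool ExtractConstInt(const tvm::PrimExpr& e, int64_t* out) {
  // Query the node type directly instead of going through a legacy helper.
  if (const auto* imm = e.as<tvm::IntImmNode>()) {
    *out = imm->value;
    return true;
  }
  return false;
}
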
src/arith/pattern_match.h

@@ -574,6 +574,17 @@ ramp(const Pattern<TBase>& base,
       base.derived(), stride.derived(), lanes.derived());
 }
 
+template<typename TBase>
+inline PRampExpr<TBase, PConstWithTypeLike<TBase>, PConst<int>>
+ramp(const Pattern<TBase>& base, int stride, int lanes) {
+  return PRampExpr<TBase, PConstWithTypeLike<TBase>, PConst<int>>(
+      base.derived(),
+      PConstWithTypeLike<TBase>(base.derived(), stride),
+      PConst<int>(lanes));
+}
+
 /*!
  * \brief Pattern broadcast expression.
  * \tparam TA The pattern type of the value.

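The overload added above is what the rest of this commit uses in place of GetRamp1Base: callers bind a PVar to the ramp base and match a stride-1 ramp directly, as in the codegen_c.cc and lower_warp_memory.cc changes below. A minimal sketch of that usage, assuming TVM's internal arith/pattern_match.h is on the include path; the wrapper name MatchRamp1Base is hypothetical:

// Sketch only: MatchRamp1Base is a hypothetical wrapper showing the matching
// idiom used by the updated callers below (e.g. CodeGenC::VisitStmt_ and
// WarpAccessRewriter::SplitIndexByGroup).
#include "arith/pattern_match.h"  // tvm::arith::PVar, tvm::arith::ramp (internal header, path assumed)

inline bool MatchRamp1Base(const tvm::PrimExpr& index, int lanes, tvm::PrimExpr* out_base) {
  tvm::arith::PVar<tvm::PrimExpr> base;                 // pattern variable for the ramp base
  if (tvm::arith::ramp(base, 1, lanes).Match(index)) {  // matches Ramp(base, stride=1, lanes)
    *out_base = base.Eval();                            // extract the matched base expression
    return true;
  }
  return false;
}
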
src/target/llvm/codegen_llvm.cc

@@ -30,6 +30,7 @@
 #include "codegen_llvm.h"
 #include "codegen_cpu.h"
+#include "../../arith/pattern_match.h"
 #include "../build_common.h"
 
 namespace tvm {
 namespace codegen {

@@ -363,16 +364,16 @@ void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst,
         md_builder_->createTBAAStructTagNode(meta, meta, 0));
     return;
   }
-  int base = 0, width = 0;
+  int64_t base = 0, width = 0;
+  arith::PVar<IntImm> pbase, pstride;
+  arith::PVar<int> planes;
   // create meta-data for alias analysis
   // Use a group of binary tree ranges of memory banks.
   if (index.defined()) {
-    const RampNode* ramp = index.as<RampNode>();
-    if (ramp) {
-      int base, stride;
-      if (arith::GetConstInt(ramp->base, &base) &&
-          arith::GetConstInt(ramp->stride, &stride)) {
-        int xwith = ramp->lanes * stride;
-        width = 1;
-        while (width < xwith) {
-          width *= 2;
+    if (arith::ramp(pbase, pstride, planes).Match(index)) {
+      base = pbase.Eval()->value;
+      int64_t xwith = planes.Eval() * pstride.Eval()->value;
+      width = 1;
+      while (width < xwith) {
+        width *= 2;

@@ -381,9 +382,9 @@ void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst,
         base -= base % width;
         width *= 2;
       }
-      }
-    } else {
-      if (arith::GetConstInt(index, &base)) width = 1;
+    } else if (auto* ptr = index.as<tir::IntImmNode>()) {
+      width = 1;
+      base = ptr->value;
     }
   }
   llvm::MDNode* meta = md_tbaa_root_;

@@ -394,8 +395,8 @@ void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst,
   meta = md_builder_->createTBAAScalarTypeNode(buffer_type.str(), meta);
   // create a tree-shape access structure.
   if (width != 0) {
-    for (int w = 1024; w >= width; w /= 2) {
-      int b = (base / w) * w;
+    for (int64_t w = 1024; w >= width; w /= 2) {
+      int64_t b = (base / w) * w;
       std::stringstream os;
       os << buffer << ".w" << w << ".b" << b;
       meta = md_builder_->createTBAAScalarTypeNode(os.str(), meta);

src/target/source/codegen_c.cc

@@ -23,8 +23,8 @@
 #include <iomanip>
 #include <cctype>
 #include "codegen_c.h"
+#include "../../arith/pattern_match.h"
 #include "../../arith/compute_expr.h"
-#include "../../tir/pass/ir_util.h"
 
 namespace tvm {
 namespace codegen {

@@ -198,8 +198,8 @@ std::string CodeGenC::GetBufferRef(
     // optimize for case where it is in register,
     if (HandleTypeMatch(buffer, t) && !is_vol) {
       // optimize for constant access
-      int offset;
-      if (arith::GetConstInt(index, &offset)) {
+      if (auto* ptr = index.as<tir::IntImmNode>()) {
+        int64_t offset = ptr->value;
         CHECK_EQ(offset % t.lanes(), 0)
             << "Find unaligned vector load to a vector type";
         os << vid << '[' << (offset / t.lanes()) << ']';

@@ -663,9 +663,10 @@ void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLINT(*)
   } else {
     CHECK(is_one(op->predicate))
         << "predicated load is not supported";
-    PrimExpr base;
-    if (GetRamp1Base(op->index, op->dtype.lanes(), &base)) {
-      std::string ref = GetVecLoad(op->dtype, op->buffer_var.get(), base);
+
+    arith::PVar<PrimExpr> base;
+    if (arith::ramp(base, 1, op->dtype.lanes()).Match(op->index)) {
+      std::string ref = GetVecLoad(op->dtype, op->buffer_var.get(), base.Eval());
       HandleVolatileLoads(ref, op, os);
     } else {
       std::ostringstream svalue_expr;

@@ -708,10 +709,10 @@ void CodeGenC::VisitStmt_(const StoreNode* op) {
   } else {
     CHECK(is_one(op->predicate))
         << "Predicated store is not supported";
-    PrimExpr base;
-    if (GetRamp1Base(op->index, t.lanes(), &base)) {
+    arith::PVar<PrimExpr> base;
+    if (arith::ramp(base, 1, t.lanes()).Match(op->index)) {
       std::string value = this->PrintExpr(op->value);
-      this->PrintVecStore(op->buffer_var.get(), t, base, value);
+      this->PrintVecStore(op->buffer_var.get(), t, base.Eval(), value);
     } else {
       // The assignment below introduces side-effect, and the resulting value cannot
       // be reused across multiple expression, thus a new scope is needed

src/target/spirv/codegen_spirv.cc

@@ -103,11 +103,11 @@ spirv::Value CodeGenSPIRV::GetThreadIndex(
   spirv::Value v;
   if (ts.rank == 1) {
     v = builder_->GetLocalID(ts.dim_index);
-    int size = 0;
-    CHECK(arith::GetConstInt(extent, &size))
+    auto* sizeptr = extent.as<tir::IntImmNode>();
+    CHECK(sizeptr)
         << "SPIRV only allows constant thread group size " << " get " << extent;
     CHECK_LT(ts.dim_index, 3);
-    workgroup_size_[ts.dim_index] = static_cast<uint32_t>(size);
+    workgroup_size_[ts.dim_index] = static_cast<uint32_t>(sizeptr->value);
   } else {
     v = builder_->GetWorkgroupID(ts.dim_index);
   }

src/tir/pass/arg_binder.cc

@@ -291,9 +291,9 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
   }
   // Byte_offset field.
   int data_bytes = GetVectorBytes(buffer->dtype);
-  int64_t const_offset;
-  if (arith::GetConst(buffer->elem_offset, &const_offset)) {
-    Bind_(make_const(DataType::UInt(64), const_offset * data_bytes),
+
+  if (const auto* const_offset = buffer->elem_offset.as<IntImmNode>()) {
+    Bind_(make_const(DataType::UInt(64), const_offset->value * data_bytes),
           TVMArrayGet(DataType::UInt(64), handle, intrinsic::kArrByteOffset),
           arg_name + ".byte_offset", true);
   } else {

src/tir/pass/ir_util.h

@@ -174,22 +174,6 @@ inline int GetTempAllocaAlignment(DataType type, int32_t const_size) {
   return align;
 }
 
-/*!
- * \brief Pattern match index to Ramp with stride=1
- *        This is a common pattern in continuous memory load.
- * \param index The index formula
- * \param lanes number of lanes in the ramp
- * \param base The result base.
- * \return true if pattern match success and store the base to base.
- */
-inline bool GetRamp1Base(PrimExpr index, int lanes, PrimExpr* base) {
-  const RampNode* r = index.as<RampNode>();
-  if (!r) return false;
-  if (!is_one(r->stride)) return false;
-  CHECK_EQ(r->lanes, lanes);
-  *base = r->base;
-  return true;
-}
-
 }  // namespace tir
 }  // namespace tvm
 #endif  // TVM_TIR_PASS_IR_UTIL_H_

src/tir/transforms/inject_virtual_thread.cc

@@ -57,15 +57,15 @@ class ExprTouched final : public StmtExprVisitor {
   }
   void VisitExpr_(const CallNode* op) final {
     if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
-      int rw_mask = 0;
-      CHECK(arith::GetConstInt(op->args[4], &rw_mask));
+      const auto* rw_mask = op->args[4].as<IntImmNode>();
       const VarNode* buffer_var = op->args[1].as<VarNode>();
       CHECK(buffer_var);
+      CHECK(rw_mask);
       // read
-      if (rw_mask & 1) {
+      if (rw_mask->value & 1) {
         HandleUseVar(buffer_var);
       }
-      if (rw_mask & 2) {
+      if (rw_mask->value & 2) {
         HandleWriteVar(buffer_var);
       }
       this->VisitExpr(op->args[2]);

src/tir/transforms/lower_thread_allreduce.cc

@@ -163,8 +163,10 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       CHECK_GE(e.scope.dim_index, 0)
           << "vthread do not work with cross thread reduction";
       if (e.scope.rank == 1) {
-        CHECK(arith::GetConstInt(attr->value, &(e.extent)))
+        const auto* ptr = attr->value.as<IntImmNode>();
+        CHECK(ptr)
             << "Need constant extent for reduce set " << iv;
+        e.extent = static_cast<int>(ptr->value);
         if (reduce_set.count(iv->var.get())) {
           vred.push_back(e);
           ++nmatch;

src/tir/transforms/lower_tvm_builtin.cc

@@ -30,7 +30,6 @@
 #include <unordered_set>
 #include "../pass/ir_util.h"
-#include "../../arith/compute_expr.h"
 
 namespace tvm {
 namespace tir {

@@ -94,11 +93,10 @@ class BuiltinLower : public StmtExprMutator {
     Stmt stmt = StmtExprMutator::VisitStmt_(op);
     op = stmt.as<AllocateNode>();
     // Get constant allocation bound.
-    int64_t dev_type;
     int64_t nbytes = GetVectorBytes(op->dtype);
     if (device_type_.defined()) {
-      if (arith::GetConst(device_type_, &dev_type)) {
-        if (dev_type == kDLCPU) {
+      if (const auto* dev_type = device_type_.as<IntImmNode>()) {
+        if (dev_type->value == kDLCPU) {
           int32_t constant_size = op->constant_allocation_size();
           if (constant_size > 0 && constant_size * nbytes < runtime::kMaxStackAlloca) {
             return stmt;

src/tir/transforms/lower_warp_memory.cc

@@ -37,7 +37,7 @@
 #include <unordered_set>
-#include "../pass/ir_util.h"
+#include "../../arith/pattern_match.h"
 #include "../../arith/compute_expr.h"
 #include "../../runtime/thread_storage_scope.h"

@@ -121,11 +121,11 @@ class WarpStoreCoeffFinder : private StmtVisitor {
     if (op->value.dtype().lanes() == 1) {
       UpdatePattern(op->index);
     } else {
-      PrimExpr base;
-      CHECK(GetRamp1Base(op->index, op->value.dtype().lanes(), &base))
+      arith::PVar<PrimExpr> base;
+      CHECK(arith::ramp(base, 1, op->value.dtype().lanes()).Match(op->index))
           << "LowerWarpMemory failed due to store index=" << op->index
           << ", can only handle continuous store";
-      UpdatePattern(base);
+      UpdatePattern(base.Eval());
     }
   } else {
     StmtVisitor::VisitStmt_(op);

@@ -137,19 +137,18 @@ class WarpStoreCoeffFinder : private StmtVisitor {
         arith::DetectLinearEquation(index, {warp_index_});
     CHECK_EQ(m.size(), 2U)
         << "LowerWarpMemory failed due to store index=" << index;
-    int coeff = 0;
     PrimExpr mcoeff = analyzer_->canonical_simplify(m[0]);
-    CHECK(arith::GetConstInt(mcoeff, &coeff) && coeff > 0)
+    const auto* mcoeff_as_int = mcoeff.as<IntImmNode>();
+    CHECK(mcoeff_as_int && mcoeff_as_int->value > 0)
         << "LowerWarpMemory failed due to store index=" << index
         << ", require positive constant coefficient on warp index " << warp_index_
        << " but get " << mcoeff;
 
     if (warp_coeff_ != 0) {
-      CHECK_EQ(warp_coeff_, coeff)
+      CHECK_EQ(warp_coeff_, mcoeff_as_int->value)
           << "LowerWarpMemory failed due to two different store coefficient to warp index";
     } else {
-      warp_coeff_ = coeff;
+      warp_coeff_ = mcoeff_as_int->value;
     }
   }

@@ -158,7 +157,7 @@ class WarpStoreCoeffFinder : private StmtVisitor {
   // the warp index
   Var warp_index_;
   // the coefficient
-  int warp_coeff_{0};
+  int64_t warp_coeff_{0};
   // analyzer.
   arith::Analyzer* analyzer_;
 };

@@ -184,10 +183,10 @@ class WarpIndexFinder : private StmtVisitor {
     if (op->attr_key == attr::thread_extent) {
       IterVar iv = Downcast<IterVar>(op->node);
       if (iv->thread_tag == "threadIdx.x") {
-        int value = 0;
-        CHECK(arith::GetConstInt(op->value, &value) &&
-              value <= warp_size_ &&
-              warp_size_ % value == 0)
+        auto* value_as_int = op->value.as<IntImmNode>();
+        CHECK(value_as_int &&
+              value_as_int->value <= warp_size_ &&
+              warp_size_ % value_as_int->value == 0)
             << "Expect threadIdx.x 's size to be no larger than, and a factor of"
             << " warp size(" << warp_size_ << ")" << " to enable warp memory"
             << " but get " << op->value << " instead";

@@ -198,7 +197,7 @@ class WarpIndexFinder : private StmtVisitor {
               << "Please create it using thread_axis once and reuse the axis "
               << "across multiple binds in the same kernel";
         } else {
-          width_ = value;
+          width_ = value_as_int->value;
           warp_index_ = iv;
         }
       }

@@ -281,9 +280,12 @@ class WarpAccessRewriter : protected StmtExprMutator {
   // in this access pattern.
   std::pair<PrimExpr, PrimExpr> SplitIndexByGroup(const PrimExpr& index) {
     if (index.dtype().lanes() != 1) {
-      PrimExpr base, local_index, group;
-      CHECK(GetRamp1Base(index, index.dtype().lanes(), &base));
-      std::tie(local_index, group) = SplitIndexByGroup(base);
+      PrimExpr local_index, group;
+
+      arith::PVar<PrimExpr> base;
+      CHECK(arith::ramp(base, 1, index.dtype().lanes()).Match(index));
+
+      std::tie(local_index, group) = SplitIndexByGroup(base.Eval());
       local_index =
           RampNode::make(local_index, make_const(local_index.dtype(), 1), index.dtype().lanes());
       return std::make_pair(local_index, group);

src/tir/transforms/storage_flatten.cc

@@ -326,13 +326,14 @@ class StorageFlattener : public StmtExprMutator {
           << "Prefetch dim should be the same as buffer dim";
       int block_size = 1,
-          elem_cnt = cache_line_size_ / e.buffer->dtype.bytes(),
-          shape = 0;
+          elem_cnt = cache_line_size_ / e.buffer->dtype.bytes();
+
       int starts = op->bounds.size() - 1;
-      while (starts > 0 && arith::GetConstInt(e.buffer->shape[starts], &shape)
-             && elem_cnt >= block_size * shape) {
-        block_size *= shape;
+
+      while (starts > 0) {
+        auto* shape_as_int = e.buffer->shape[starts].as<IntImmNode>();
+        if (shape_as_int == nullptr || block_size * shape_as_int->value > elem_cnt)
+          break;
+        block_size *= static_cast<int>(shape_as_int->value);
         starts--;
       }
       PrimExpr stride(elem_cnt / block_size);

src/tir/transforms/unroll_loop.cc

@@ -51,16 +51,13 @@ class LoopUnroller : public StmtExprMutator {
   Stmt VisitStmt_(const AttrStmtNode* op) final {
     if (op->attr_key == "pragma_auto_unroll_max_step") {
-      int value = 0;
-      CHECK(arith::GetConstInt(op->value, &value));
+      int value = static_cast<int>(Downcast<Integer>(op->value)->value);
       std::swap(value, auto_max_step_);
       Stmt ret = this->VisitStmt(op->body);
       std::swap(value, auto_max_step_);
       return ret;
     } else if (op->attr_key == "pragma_unroll_explicit") {
-      int value = 0;
-      CHECK(arith::GetConstInt(op->value, &value));
-      bool explicit_unroll = value;
+      bool explicit_unroll = Downcast<Integer>(op->value)->value;
       std::swap(explicit_unroll, explicit_unroll_);
       Stmt ret = this->VisitStmt(op->body);
       std::swap(explicit_unroll, explicit_unroll_);

src/tir/transforms/vectorize_loop.cc

@@ -519,12 +519,11 @@ class LoopVectorizer : public StmtMutator {
   Stmt VisitStmt_(const ForNode* op) final {
     if (op->for_type == ForType::Vectorized) {
       CHECK(is_zero(op->min));
-      int lanes = 0;
-      bool succ = arith::GetConstInt(op->extent, &lanes);
-      if (!succ || lanes < 1) {
+      auto* extent_as_int = op->extent.as<IntImmNode>();
+      if (!extent_as_int || extent_as_int->value < 1) {
         LOG(FATAL) << "Failed to vectorize loop with extent " << op->extent;
       }
-      return Vectorizer(op->loop_var, lanes)(op->body);
+      return Vectorizer(op->loop_var, static_cast<int>(extent_as_int->value))(op->body);
     } else {
       return StmtMutator::VisitStmt_(op);
     }
