wenyuanbo / tic · Commits · 4e5c5843

Unverified commit 4e5c5843, authored Apr 03, 2020 by Haozheng Fan, committed by GitHub Apr 02, 2020
parent 4195b2e2

[TIR][PASS] dtype rewrite for indexing variables (#5092)

Showing 16 changed files with 703 additions and 8 deletions:
include/tvm/arith/analyzer.h                                  +9    -0
include/tvm/tir/ir_pass.h                                     +9    -0
include/tvm/tir/transform.h                                   +10   -0
python/tvm/driver/build_module.py                             +1    -0
python/tvm/tir/expr.py                                        +2    -1
python/tvm/tir/ir_builder.py                                  +4    -2
python/tvm/tir/transform/transform.py                         +15   -0
src/arith/const_int_bound.cc                                  +32   -0
src/target/llvm/codegen_cpu.cc                                +1    -1
src/target/llvm/codegen_llvm.cc                               +2    -1
src/tir/ir/buffer.cc                                          +1    -1
src/tir/pass/ffi_api.cc                                       +1    -0
src/tir/pass/loop_partition.cc                                +1    -1
src/tir/pass/unroll_loop.cc                                   +3    -1
src/tir/transforms/narrow_datatype.cc                         +418  -0
tests/python/unittest/test_tir_transform_narrow_datatype.py   +194  -0
include/tvm/arith/analyzer.h

```diff
@@ -115,6 +115,15 @@ class ConstIntBoundAnalyzer {
   ConstIntBound operator()(const PrimExpr& expr);
+  /*!
+   * \brief Analyze the expr with intermediate results memoized to avoid redundant computation.
+   * \param expr The expression of interest.
+   * \param bound The lookup table to store the intermediate results.
+   * \return The result of the analysis.
+   */
+  ConstIntBound operator()(const PrimExpr& expr,
+                           std::unordered_map<const PrimExprNode*, ConstIntBound>* bound);
+
   /*!
    * \brief Update constant int bound information of var.
    *
    * \param var The variable of interest.
```
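The memoized overload lets a caller keep one lookup table across many queries, so shared subexpressions are analyzed only once. A toy Python sketch of the same pattern; the names here are illustrative stand-ins, not TVM APIs:

```python
# Toy memoized interval analysis over a tiny expression tree. The `table`
# argument plays the role of the `bound` lookup table in the C++ overload.
class Add:
    def __init__(self, a, b):
        self.a, self.b = a, b

def const_int_bound(e, table):
    if id(e) in table:           # analogous to keying on PrimExprNode*
        return table[id(e)]      # memoized result, no recomputation
    if isinstance(e, int):
        res = (e, e)             # a constant's bound is itself
    else:
        amin, amax = const_int_bound(e.a, table)
        bmin, bmax = const_int_bound(e.b, table)
        res = (amin + bmin, amax + bmax)  # bounds of a sum add componentwise
    table[id(e)] = res
    return res

shared = Add(1, 2)
print(const_int_bound(Add(shared, shared), {}))  # (6, 6); `shared` analyzed once
```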
include/tvm/tir/ir_pass.h

```diff
@@ -358,6 +358,15 @@ Stmt DecorateDeviceScope(Stmt stmt);
 Stmt HoistIfThenElse(Stmt stmt);
 
+/*!
+ * \brief Narrow down PrimExpr datatype in stmt to target_bits.
+ * \note Run this pass after StorageFlatten.
+ * \param stmt The stmt to apply the datatype rewrite to.
+ * \param target_bits The bit width of the target datatype.
+ * \return Transformed stmt.
+ */
+Stmt NarrowDataType(Stmt stmt, int target_bits);
+
 /*!
  * \brief Make a user callable API LoweredFunc.
  *
  * The main task of this function is to create code to :
```
include/tvm/tir/transform.h

```diff
@@ -87,6 +87,16 @@ TVM_DLL Pass LowerDeviceStorageAccessInfo();
  */
 TVM_DLL Pass LowerWarpMemory();
 
+/*!
+ * \brief Narrow down PrimExpr datatype in stmt to target_bits.
+ *
+ * \note Run this pass after StorageFlatten.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass NarrowDataType();
+
 }  // namespace transform
 }  // namespace tir
 }  // namespace tvm
```
python/tvm/driver/build_module.py

```diff
@@ -159,6 +159,7 @@ def lower(sch,
     # Phase 1
     stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds)
     stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers)
+    stmt = ir_pass.NarrowDataType(stmt, 32)
     stmt = ir_pass.CanonicalSimplify(stmt)
     for f in lower_phase1:
         stmt = f(stmt)
```
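With this single line, every schedule lowered through the default pipeline has its indices narrowed to 32 bits right after storage flattening. A quick way to observe the effect, assuming a TVM build that contains this commit:

```python
import tvm
from tvm import te

# Constant-extent loops like this one now come out of tvm.lower with
# int32 loop variables and buffer indices, even on 64-bit hosts.
n = 1024
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
s = te.create_schedule(B.op)
print(tvm.lower(s, [A, B], simple_mode=True))
```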
python/tvm/tir/expr.py

```diff
@@ -370,7 +370,8 @@ class IterVar(Object, ExprOp):
             raise TypeError("dom need to be Range")
 
         name = var if var is not None else "iter"
-        var = Var(name, dtype="int32") if not isinstance(var, Var) else var
+        dtype = "int32" if dom is None else dom.extent.dtype
+        var = Var(name, dtype=dtype) if not isinstance(var, Var) else var
         self.__init_handle_by_constructor__(
             _ffi_api.IterVar, dom, var, iter_type, thread_tag)
```
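The practical effect is that an IterVar built over an int64 domain now carries an int64 internal Var instead of a hard-coded int32 one, for example:

```python
from tvm import te
from tvm.tir import const

# Assuming a build with this patch: the reduction axis below is created from
# an int64 domain, so its internal Var inherits dtype int64 (previously int32).
k = te.reduce_axis((const(0, dtype='int64'), const(64, dtype='int64')), "k")
print(k.var.dtype)  # int64
```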
python/tvm/tir/ir_builder.py

```diff
@@ -76,7 +76,8 @@ class BufferVar(ObjectGeneric):
     def __getitem__(self, index):
         t = DataType(self._content_type)
         if t.lanes > 1:
-            index = _expr.Ramp(index * t.lanes, 1, t.lanes)
+            base = index * t.lanes
+            index = _expr.Ramp(base, const(1, base.dtype), t.lanes)
         return _expr.Load(self._content_type, self._buffer_var, index)
 
     def __setitem__(self, index, value):
@@ -87,7 +88,8 @@ class BufferVar(ObjectGeneric):
                                  value.dtype, self._content_type))
         t = DataType(self._content_type)
         if t.lanes > 1:
-            index = _expr.Ramp(index * t.lanes, 1, t.lanes)
+            base = index * t.lanes
+            index = _expr.Ramp(base, const(1, base.dtype), t.lanes)
         self._builder.emit(_stmt.Store(self._buffer_var, value, index))
```
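Building the Ramp stride with const(1, base.dtype) keeps vectorized indices type-consistent when the scalar index is not int32. A small ir_builder sketch in the spirit of the new test_multilanes, assuming a build with this patch:

```python
import tvm
from tvm.tir import const

# For a float32x4 buffer, __getitem__/__setitem__ emit
# Ramp(i * 4, const(1, i.dtype), 4); with an int64 loop var the stride is
# now int64 as well, instead of a mismatched int32 literal 1.
ib = tvm.tir.ir_builder.create()
Ab = tvm.tir.decl_buffer((16,), dtype='float32x4', name='A')
A = ib.buffer_ptr(Ab)
with ib.for_range(0, const(16, dtype='int64'), name='i', dtype='int64') as i:
    A[i] = A[i] + 1
print(ib.get())
```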
python/tvm/tir/transform/transform.py

```diff
@@ -66,3 +66,18 @@ def LowerWarpMemory():
         The result pass
     """
     return _ffi_api.LowerWarpMemory()
+
+
+def NarrowDataType():
+    """Narrow down PrimExpr datatype in stmt to target_bits.
+
+    Returns
+    -------
+    fpass : tvm.ir.transform.Pass
+        The result pass
+
+    Note
+    ----
+    Run this pass after StorageFlatten.
+    """
+    return _ffi_api.NarrowDataType()
```
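Because the pass reads its bit width from a "target_bits" function attribute, module-level usage looks like the lower_stmt helper in the new test file:

```python
import tvm

def narrow(params, stmt, target_bits=32):
    # Wrap the stmt in a PrimFunc carrying the "target_bits" attribute that
    # the pass requires, run NarrowDataType, and return the rewritten body.
    func = tvm.tir.PrimFunc(params, stmt).with_attr("target_bits", target_bits)
    mod = tvm.tir.transform.NarrowDataType()(tvm.IRModule.from_expr(func))
    return mod["main"].body
```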
src/arith/const_int_bound.cc

```diff
@@ -146,9 +146,30 @@ class ConstIntBoundAnalyzer::Impl :
         res = Intersect(res, info.bound);
       }
     }
+    if (bound_) {
+      const PrimExprNode* op = expr.as<PrimExprNode>();
+      auto val = bound_->find(op);
+      if (val != bound_->end()) {
+        CHECK(val->second->min_value == res.min_value &&
+              val->second->max_value == res.max_value)
+          << "Detected bound for " << expr << " conflicts with memoization";
+      }
+      (*bound_)[op] = ConstIntBound(res.min_value, res.max_value);
+    }
     return res;
   }
 
+  Entry VisitExpr_(const RampNode* op) final {
+    // op = {base + i * stride | 0 <= i < lanes}
+    // Entry(op) = Union(Entry(base + i * stride) | 0 <= i < lanes)
+    // Note that `base + i * stride` is linear w.r.t. `i`, so
+    // Entry(op) = Union(Entry(base + i * stride) | i = 0, i = lanes-1)
+    Entry a = VisitExpr(op->base);
+    Entry b = VisitExpr(op->base + (op->lanes - 1) * op->stride);
+    return Union(a, b);
+  }
+
   Entry VisitExpr_(const CastNode* op) final {
     Entry a = VisitExpr(op->value);
     Entry b = Everything(op->dtype);
@@ -340,10 +361,13 @@ class ConstIntBoundAnalyzer::Impl :
   }
 
  private:
+  friend class ConstIntBoundAnalyzer;
   // internal variable map
   std::unordered_map<Var, Entry, ObjectHash, ObjectEqual> var_map_;
   // additional bound info
   std::vector<BoundInfo> additional_info_;
+  // lookup table for memoization
+  std::unordered_map<const PrimExprNode*, ConstIntBound>* bound_{nullptr};
   // constants: the limit value means unlimited
   // NOTE: kNegInf/kPosInf are used to represent infinity.
   static const constexpr int64_t kNegInf = ConstIntBound::kNegInf;
@@ -536,6 +560,14 @@ ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr) {
   return ConstIntBound(ret.min_value, ret.max_value);
 }
 
+ConstIntBound ConstIntBoundAnalyzer::operator()(
+    const PrimExpr& expr,
+    std::unordered_map<const PrimExprNode*, ConstIntBound>* bound) {
+  impl_->bound_ = bound;
+  Entry ret = impl_->VisitExpr(expr);
+  impl_->bound_ = nullptr;
+  return ConstIntBound(ret.min_value, ret.max_value);
+}
+
 void ConstIntBoundAnalyzer::Update(const Var& var,
                                    const ConstIntBound& info,
                                    bool override) {
```
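The RampNode rule relies on base + i * stride being linear in i: its extrema over 0 <= i < lanes sit at the endpoints, so the bound is the union of the bounds at i = 0 and i = lanes - 1. A standalone sketch of that reasoning for concrete integers:

```python
# Bound of {base + i*stride | 0 <= i < lanes}: a linear function attains its
# min/max at the endpoints, so two evaluations suffice.
def ramp_bound(base, stride, lanes):
    first = base                         # value at i = 0
    last = base + (lanes - 1) * stride   # value at i = lanes - 1
    return (min(first, last), max(first, last))

print(ramp_bound(base=10, stride=-2, lanes=4))  # (4, 10)
```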
src/target/llvm/codegen_cpu.cc

```diff
@@ -943,7 +943,7 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) {
   PrimExpr end = MinNode::make((task_id + make_const(t, 1)) * step, op->extent);
   CreateSerialFor(MakeValue(begin), MakeValue(end),
-                  ConstInt32(1),
+                  llvm::ConstantInt::getSigned(GetLLVMType(end), 1),
                   op->loop_var, op->body);
 }
```
src/target/llvm/codegen_llvm.cc

```diff
@@ -1121,7 +1121,8 @@ void CodeGenLLVM::VisitStmt_(const ForNode* op) {
     CHECK(op->for_type == ForType::Serial);
   }
   CreateSerialFor(MakeValue(op->min), MakeValue(op->extent),
-                  ConstInt32(1), op->loop_var, op->body);
+                  llvm::ConstantInt::getSigned(GetLLVMType(op->extent), 1),
+                  op->loop_var, op->body);
 }
```
src/tir/ir/buffer.cc

```diff
@@ -452,7 +452,7 @@ Buffer BufferNode::make(Var data,
   n->buffer_type = buffer_type;
   if (n->buffer_type == kAutoBroadcast && n->shape.size() > 0 && n->strides.empty()) {
     for (size_t i = 0; i < n->shape.size(); ++i) {
-      n->strides.push_back(Var("stride"));
+      n->strides.push_back(Var("stride", n->shape[i].dtype()));
     }
   }
   return Buffer(n);
```
src/tir/pass/ffi_api.cc

```diff
@@ -156,5 +156,6 @@ REGISTER_PASS(InstrumentBoundCheckers);
 REGISTER_PASS(VerifyCompactBuffer);
 REGISTER_PASS(HoistIfThenElse);
 REGISTER_PASS(InferFragment)
+REGISTER_PASS(NarrowDataType);
 }  // namespace tir
 }  // namespace tvm
```
src/tir/pass/loop_partition.cc

```diff
@@ -587,7 +587,7 @@ inline Stmt LoopPartitioner::MakeFor(const Object *node, PrimExpr extent, Stmt b
     // If the loop extent is 1, do not create the loop anymore
     return Substitute(body, {{Var{for_node->loop_var}, make_const(DataType::Int(32), 0)}});
   } else {
-    return ForNode::make(for_node->loop_var, 0, extent,
+    return ForNode::make(for_node->loop_var, IntImm(for_node->min.dtype(), 0), extent,
                          for_node->for_type, for_node->device_api, body);
   }
 }
```
src/tir/pass/unroll_loop.cc

```diff
@@ -160,7 +160,9 @@ class LoopUnroller : public StmtExprMutator {
     PrimExpr extent = tir::Simplify(op->extent);
     const IntImmNode* v1 = extent.as<IntImmNode>();
     int value = -1;
-    if (v1 != nullptr) {
+    // integers that do not fit in int32_t are treated as symbolic,
+    // as it's impossible to unroll such large loops
+    if (v1 != nullptr && v1->value <= std::numeric_limits<int>::max()) {
       value = static_cast<int>(v1->value);
     }
     return value;
```
src/tir/transforms/narrow_datatype.cc (new file, mode 100644)

```cpp
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file narrow_datatype.cc
 * \brief narrow the datatype of indexing vars
 */

#include <tvm/tir/ir_pass.h>
#include <tvm/tir/op.h>
#include <tvm/tir/transform.h>
#include <tvm/runtime/registry.h>
#include "../../arith/ir_mutator_with_analyzer.h"
#include "../../arith/ir_visitor_with_analyzer.h"

namespace tvm {
namespace tir {

// This pass narrows indexing expressions (like StoreNode::index)
// that trivially fit into i32/i16 (denoted by `target_bits_`) to
// i32/i16. Considering that i32/i16 indices may be more efficient
// on some backends (while i64 may be more efficient on others,
// like llvm), we may want this pass when i32/i16 indices are
// more efficient.
//
// For Var v, we determine its dtype by examining all the PrimExpr
// that contain v, denoted by E = {e_0 = v, e_1, e_2, ..., e_k}.
// If all expressions in E fit into i32/i16, then we think v can be
// narrowed to i32/i16.
//
// To make an indexing expression i32/i16, we must make sure that every
// component of that expression is of dtype i32/i16. So besides Var, we
// rewrite the following inside an indexing expression
// - Var
// - IntImm
// - Cast
//
// Algorithm:
// - Use DataTypeVisitor to determine whether a Var can be narrowed or not.
// - Use DataTypeRewriter to rewrite the components of an indexing expression.

using arith::Analyzer;
using arith::IRMutatorWithAnalyzer;
using arith::ConstIntBound;

// Determine the result dtype for Var, IntImm and Cast,
// which will be stored in `vmap` eventually.
//
// Algorithm:
// We propagate the dtypes of all the Exprs that contain Var `var` into `vmap[var]`.
// To be more specific, if for each Expr `e` which contains `var`
// (`var` is a child node of `e` in the AST), `e` fits into `target_bits_`,
// then we narrow `var` into `target_bits_`. That is,
// `vmap[var] = min(target_bits_, var.dtype.bits())`
// Otherwise, `var` is not narrowed, that is, `vmap[var] = var.dtype.bits()`
class DataTypeVisitor final : public StmtExprVisitor {
 public:
  explicit DataTypeVisitor(int target_bits)
    : bits_(target_bits), target_bits_(target_bits) {}

  void VisitExpr(const PrimExpr& e) {
    if (e.dtype().is_int()) {
      int bits = max_bits_;
      const PrimExprNode* op = e.as<PrimExprNode>();
      if (bound_.find(op) == bound_.end()) {
        analyzer_.const_int_bound(e, &bound_);
      }
      ConstIntBound bound = bound_[op];
      int64_t ubound = Downcast<IntImm>(max_value(DataType::Int(target_bits_)))->value;
      int64_t lbound = Downcast<IntImm>(min_value(DataType::Int(target_bits_)))->value;
      if (e.dtype().bits() <= target_bits_ ||
          (bound->max_value <= ubound && bound->min_value >= lbound)) {
        bits = target_bits_;
      }
      int tmp = bits > bits_ ? bits : bits_;
      std::swap(bits_, tmp);
      StmtExprVisitor::VisitExpr(e);
      std::swap(bits_, tmp);
    } else {
      StmtExprVisitor::VisitExpr(e);
    }
  }

  void VisitStmt_(const ForNode* op) {
    analyzer_.Bind(op->loop_var,
                   Range::make_by_min_extent(op->min, op->extent));
    vextent_[op->loop_var.as<VarNode>()] = op->extent.dtype();
    return StmtExprVisitor::VisitStmt_(op);
  }

  void VisitStmt_(const AttrStmtNode* op) {
    if (op->attr_key == attr::thread_extent ||
        op->attr_key == attr::virtual_thread) {
      IterVar iv = Downcast<IterVar>(op->node);
      CHECK_NE(iv->thread_tag.length(), 0U);
      analyzer_.Bind(iv->var,
                     Range::make_by_min_extent(0, op->value));
      vextent_[iv->var.as<VarNode>()] = op->value.dtype();
      StmtExprVisitor::VisitStmt_(op);
    } else {
      StmtExprVisitor::VisitStmt_(op);
    }
  }

  void VisitExpr_(const ReduceNode* op) {
    // Setup the domain information before simplification.
    for (const IterVar& iv : op->axis) {
      analyzer_.Bind(iv->var, iv->dom);
      vextent_[iv->var.as<VarNode>()] = iv->dom->extent.dtype();
    }
    // Recursively call simplification when necessary.
    StmtExprVisitor::VisitExpr_(op);
  }

  void VisitExpr_(const VarNode* op) {
    if (vextent_.find(op) != vextent_.end()) {
      // We only narrow and never promote, so the result dtype
      // is upper-bounded by its original dtype before rewrite.
      int bits = std::min(vextent_[op].bits(), bits_);
      if (vmap.find(op) == vmap.end()) {
        vmap[op] = op->dtype.with_bits(bits);
      } else {
        // We take the maximum bits over all the possible Exprs where a var occurs
        vmap[op] = op->dtype.with_bits(std::max(vmap[op].bits(), bits));
      }
    }
    StmtExprVisitor::VisitExpr_(op);
  }

  void VisitExpr_(const IntImmNode* op) {
    if (op->dtype.is_int()) {
      // We only narrow and never promote, so the result dtype
      // is upper-bounded by its original dtype before rewrite.
      int bits = std::min(op->dtype.bits(), bits_);
      if (vmap.find(op) == vmap.end()) {
        vmap[op] = op->dtype.with_bits(bits);
      } else {
        vmap[op] = op->dtype.with_bits(std::max(vmap[op].bits(), bits));
      }
    }
    StmtExprVisitor::VisitExpr_(op);
  }

  void VisitExpr_(const CastNode* op) {
    if (op->dtype.is_int()) {
      // We only narrow and never promote, so the result dtype
      // is upper-bounded by its original dtype before rewrite.
      int bits = std::min(op->dtype.bits(), bits_);
      if (vmap.find(op) == vmap.end()) {
        vmap[op] = op->dtype.with_bits(bits);
      } else {
        vmap[op] = op->dtype.with_bits(std::max(vmap[op].bits(), bits));
      }
    }
    StmtExprVisitor::VisitExpr_(op);
  }

  // the narrowed datatype of Var and IntImm
  std::unordered_map<const PrimExprNode*, DataType> vmap;

 protected:
  // internal analyzer
  arith::Analyzer analyzer_;

 private:
  // the maximum possible bits, which serves as an init value
  static constexpr const int max_bits_ = 64;
  // the maximum possible bits of the current expression's return dtype
  int bits_;
  // the target bits
  int target_bits_;
  // the extent of vars to be rewritten
  std::unordered_map<const VarNode*, DataType> vextent_;
  // the memoized bound generated by ConstIntBoundAnalyzer
  std::unordered_map<const PrimExprNode*, ConstIntBound> bound_;
};

class DataTypeRewriter : public StmtExprMutator {
 public:
  explicit DataTypeRewriter(int target_bits): visitor_(target_bits) {}

  Stmt operator()(Stmt s) {
    visitor_(s);
    for (auto i = visitor_.vmap.begin(), last = visitor_.vmap.end(); i != last;) {
      PrimExpr e = GetRef<PrimExpr>(i->first);
      if (e.dtype() == i->second) {
        i = visitor_.vmap.erase(i);
      } else {
        ++i;
      }
    }
    return VisitStmt(s);
  }

  Stmt VisitStmt_(const StoreNode* op) final {
    PrimExpr value = this->VisitExpr(op->value);
    is_index_ = true;
    PrimExpr index = this->VisitExpr(op->index);
    is_index_ = false;
    Stmt s = StoreNode::make(op->buffer_var,
                             op->value,
                             index,
                             op->predicate);
    return StmtExprMutator::VisitStmt_(s.as<StoreNode>());
  }

  Stmt VisitStmt_(const ForNode* op) final {
    Stmt s = StmtExprMutator::VisitStmt_(op);
    op = s.as<ForNode>();
    CHECK(op != nullptr)
      << "Expected type to be ForNode"
      << ", but get " << s->GetTypeKey();
    PrimExpr e = VisitExpr(op->loop_var);
    Var var = Downcast<Var>(e);
    return ForNode::make(var,
                         cast(var.dtype(), op->min),
                         cast(var.dtype(), op->extent),
                         op->for_type,
                         op->device_api,
                         op->body);
  }

  Stmt VisitStmt_(const AttrStmtNode* op) final {
    if (op->attr_key == attr::thread_extent ||
        op->attr_key == attr::virtual_thread) {
      Stmt s = StmtExprMutator::VisitStmt_(op);
      op = s.as<AttrStmtNode>();
      CHECK(op != nullptr)
        << "Expected type to be AttrStmtNode"
        << ", but get " << s->GetTypeKey();
      const IterVarNode* iv = op->node.as<IterVarNode>();
      CHECK(iv != nullptr)
        << "Expected type to be IterVarNode"
        << ", but get " << op->node->GetTypeKey();
      PrimExpr e = VisitExpr(iv->var);
      Var var = Downcast<Var>(e);
      if (ivmap_.find(iv) == ivmap_.end()) {
        ivmap_[iv] = IterVarNode::make(iv->dom, var, iv->iter_type, iv->thread_tag);
      }
      return AttrStmtNode::make(
        ivmap_[iv],
        op->attr_key,
        cast(var.dtype(), op->value),
        op->body);
    }
    return StmtExprMutator::VisitStmt_(op);
  }

  PrimExpr VisitExpr_(const VarNode* op) final {
    if (visitor_.vmap.find(op) != visitor_.vmap.end()) {
      if (vmap_.find(op) == vmap_.end()) {
        vmap_[op] = Var(op->name_hint, visitor_.vmap[op]);
      }
      return vmap_[op];
    }
    return StmtExprMutator::VisitExpr_(op);
  }

  PrimExpr VisitExpr_(const SizeVarNode* op) final {
    if (visitor_.vmap.find(op) != visitor_.vmap.end()) {
      if (vmap_.find(op) == vmap_.end()) {
        vmap_[op] = SizeVar(op->name_hint, visitor_.vmap[op]);
      }
      return vmap_[op];
    }
    return StmtExprMutator::VisitExpr_(op);
  }

  PrimExpr VisitExpr_(const LoadNode* op) final {
    is_index_ = true;
    PrimExpr index = this->VisitExpr(op->index);
    is_index_ = false;
    PrimExpr e = LoadNode::make(op->dtype, op->buffer_var, index, op->predicate);
    return StmtExprMutator::VisitExpr_(e.as<LoadNode>());
  }

  PrimExpr VisitExpr_(const IntImmNode* op) final {
    if (is_index_) {
      if (visitor_.vmap.find(op) != visitor_.vmap.end()) {
        return IntImm(visitor_.vmap[op], op->value);
      }
    }
    return StmtExprMutator::VisitExpr_(op);
  }

  PrimExpr VisitExpr_(const CastNode* op) final {
    if (is_index_ && visitor_.vmap.find(op) != visitor_.vmap.end()) {
      PrimExpr e = StmtExprMutator::VisitExpr_(op);
      const CastNode* new_op = e.as<CastNode>();
      CHECK(new_op != nullptr)
        << "Expected type to be CastNode"
        << ", but get " << e->GetTypeKey();
      return CastNode::make(visitor_.vmap[op], new_op->value);
    }
    return StmtExprMutator::VisitExpr_(op);
  }

  PrimExpr VisitExpr_(const AddNode* op) final;
  PrimExpr VisitExpr_(const SubNode* op) final;
  PrimExpr VisitExpr_(const MulNode* op) final;
  PrimExpr VisitExpr_(const DivNode* op) final;
  PrimExpr VisitExpr_(const ModNode* op) final;
  PrimExpr VisitExpr_(const FloorDivNode* op) final;
  PrimExpr VisitExpr_(const FloorModNode* op) final;
  PrimExpr VisitExpr_(const MinNode* op) final;
  PrimExpr VisitExpr_(const MaxNode* op) final;
  PrimExpr VisitExpr_(const EQNode* op) final;
  PrimExpr VisitExpr_(const NENode* op) final;
  PrimExpr VisitExpr_(const LTNode* op) final;
  PrimExpr VisitExpr_(const LENode* op) final;
  PrimExpr VisitExpr_(const GTNode* op) final;
  PrimExpr VisitExpr_(const GENode* op) final;
  PrimExpr VisitExpr_(const CallNode* op) final;

 private:
  // the internal visitor to deduce the narrowed dtype
  DataTypeVisitor visitor_;
  // a map from Var before rewrite to that after rewrite,
  // ensures one old Var maps to exactly one new Var
  std::unordered_map<const VarNode*, Var> vmap_;
  // a map from IterVar before rewrite to that after rewrite,
  // ensures one old IterVar maps to exactly one new IterVar
  std::unordered_map<const IterVarNode*, IterVar> ivmap_;
  // indicator of LoadNode::index and StoreNode::index
  bool is_index_{false};
};

#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC)     \
  PrimExpr DataTypeRewriter::VisitExpr_(const OP* op) {       \
    PrimExpr a = this->VisitExpr(op->a);                      \
    PrimExpr b = this->VisitExpr(op->b);                      \
    if (a.same_as(op->a) &&                                   \
        b.same_as(op->b)) {                                   \
      return GetRef<PrimExpr>(op);                            \
    } else {                                                  \
      return FUNC(a, b);                                      \
    }                                                         \
  }

DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(AddNode, operator+)
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(SubNode, operator-)
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MulNode, operator*)
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(DivNode, div)
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(ModNode, truncmod)
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorDivNode, floordiv)
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorModNode, floormod)
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MinNode, min)
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MaxNode, max)
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(EQNode, operator==)
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(NENode, operator!=)
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LTNode, operator<)
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=)
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GTNode, operator>)
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=)

PrimExpr DataTypeRewriter::VisitExpr_(const CallNode* op) {
  PrimExpr e = StmtExprMutator::VisitExpr_(op);
  op = e.as<CallNode>();
  CHECK(op != nullptr)
    << "Expected type to be CallNode"
    << ", but get " << e->GetTypeKey();

  if (op->call_type == CallNode::PureIntrinsic) {
    if (op->name == intrinsic::tvm_if_then_else) {
      return if_then_else(op->args[0], op->args[1], op->args[2]);
    } else if (op->name == CallNode::shift_right) {
      return op->args[0] >> op->args[1];
    } else if (op->name == CallNode::shift_left) {
      return op->args[0] << op->args[1];
    } else if (op->name == CallNode::bitwise_and) {
      return op->args[0] & op->args[1];
    } else if (op->name == CallNode::bitwise_or) {
      return op->args[0] | op->args[1];
    } else if (op->name == CallNode::bitwise_xor) {
      return op->args[0] ^ op->args[1];
    } else if (op->name == "pow") {
      return pow(op->args[0], op->args[1]);
    }
  }
  return e;
}

Stmt NarrowDataType(Stmt stmt, int target_bits) {
  return DataTypeRewriter(target_bits)(stmt);
}

namespace transform {

Pass NarrowDataType() {
  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
    auto* n = f.CopyOnWrite();
    IntImm target_bits = f->GetAttr<IntImm>("target_bits");
    CHECK(target_bits.defined())
      << "NarrowDataType: Require the target_bits";
    n->body = DataTypeRewriter(target_bits->value)(std::move(n->body));
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tir.LowerDeviceStorageAccessInfo", {});
}

TVM_REGISTER_GLOBAL("tir.transform.NarrowDataType")
.set_body_typed(NarrowDataType);

}  // namespace transform
}  // namespace tir
}  // namespace tvm
```
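For intuition, the decision rule the DataTypeVisitor/DataTypeRewriter pair implements can be modeled outside TVM: a variable is narrowed to target_bits only if every expression containing it provably fits in the target range, and it is never promoted. A toy Python sketch; the bound tuples stand in for ConstIntBoundAnalyzer results:

```python
# Toy model of the narrowing rule: vmap[var] is min(target_bits,
# var.dtype.bits()) iff every containing expr fits the target range,
# otherwise the original bit width is kept.
def narrowed_bits(var_bits, containing_expr_bounds, target_bits=32):
    lo_t, hi_t = -2**(target_bits - 1), 2**(target_bits - 1) - 1
    for lo, hi in containing_expr_bounds:  # (min, max) of each expr with var inside
        if lo < lo_t or hi > hi_t:
            return var_bits                # some use overflows: keep the dtype
    return min(target_bits, var_bits)      # narrow, never promote

print(narrowed_bits(64, [(0, 2**20)]))  # 32: all uses fit in int32
print(narrowed_bits(64, [(0, 2**31)]))  # 64: one use overflows int32
```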
tests/python/unittest/test_tir_transform_narrow_datatype.py (new file, mode 100644)

```python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te
from tvm.tir import const


def lower_stmt(params, stmt, target_bits):
    func = tvm.tir.PrimFunc(params, stmt).with_attr("target_bits", target_bits)
    func = tvm.tir.transform.NarrowDataType()(tvm.IRModule.from_expr(func))["main"]
    stmt = func.body
    return stmt


def lower_sch(sch, args, target_bits):
    binds = {}
    arg_list = []
    for x in args:
        if isinstance(x, te.tensor.Tensor):
            buf = tvm.tir.decl_buffer(x.shape, dtype=x.dtype, name=x.name)
            assert x not in binds
            binds[x] = buf
            arg_list.append(buf)
        else:
            raise ValueError("args must be Tensor, Buffer or Var")
    bounds = te.schedule.InferBound(sch)
    stmt = te.schedule.ScheduleOps(sch, bounds)
    stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64, False)
    return lower_stmt(arg_list, stmt, target_bits)


def test_basic():
    def check(m, n, target_bits, target_dtype):
        ib = tvm.tir.ir_builder.create()
        Ab = tvm.tir.decl_buffer((m, n), name='A')
        A = ib.buffer_ptr(Ab)
        Bb = tvm.tir.decl_buffer((m, n), name='B')
        B = ib.buffer_ptr(Bb)
        with ib.for_range(0, m, name='i') as i:
            with ib.for_range(0, n, name='j') as j:
                B[i * n + j] = A[i * n + j] + 1
        stmt = ib.get()
        stmt = lower_stmt([Ab, Bb], stmt, target_bits)
        assert stmt.loop_var.dtype == target_dtype
        assert stmt.body.loop_var.dtype == target_dtype

    # const shape
    # i32 -> i32
    check(2, 2, 32, "int32")
    # i32 + i32 is not promoted to i64 even if it overflows
    check(2**16, 2**16, 32, "int32")
    # i64 -> i32
    check(const(2, dtype='int64'), const(2, dtype='int64'), 32, "int32")
    check(const(2**16, dtype='int64'), const(2**16, dtype='int64'), 32, "int64")
    # i32 -> i16
    check(2, 2, 16, "int16")
    check(2**10, 2**10, 16, "int32")

    # symbolic shape
    check(te.size_var(name='m', dtype='int32'),
          te.size_var(name='n', dtype='int32'), 32, "int32")
    check(te.size_var(name='m', dtype='int64'),
          te.size_var(name='n', dtype='int64'), 32, "int64")


def test_thread_axis():
    def check(m, n, target_bits, target_dtype):
        ib = tvm.tir.ir_builder.create()
        Ab = tvm.tir.decl_buffer((m, n), name='A')
        A = ib.buffer_ptr(Ab)
        Bb = tvm.tir.decl_buffer((m, n), name='B')
        B = ib.buffer_ptr(Bb)
        bx = te.thread_axis("blockIdx.x")
        tx = te.thread_axis("threadIdx.x")
        ib.scope_attr(bx, "thread_extent", m)
        ib.scope_attr(tx, "thread_extent", n)
        B[bx * n + tx] = A[bx * n + tx] + 1
        stmt = ib.get()
        stmt = lower_stmt([Ab, Bb], stmt, target_bits)
        assert stmt.node.var.dtype == target_dtype
        assert stmt.body.node.var.dtype == target_dtype

    # i32 -> i32
    check(2, 32,
          target_bits=32, target_dtype='int32')
    # i32 + i32 is not promoted to i64 even in the case of overflow
    check(2**30, 32,
          target_bits=32, target_dtype='int32')
    # i64 -> i32
    check(const(2, dtype='int64'), const(32, dtype='int64'),
          target_bits=32, target_dtype='int32')
    check(const(2**30, dtype='int64'), const(32, dtype='int64'),
          target_bits=32, target_dtype='int64')
    # i32 -> i16
    check(2, 32,
          target_bits=16, target_dtype='int16')
    check(2**14, 32,
          target_bits=16, target_dtype='int32')


def test_multilanes():
    def check(m, lanes, target_bits, target_dtype):
        ib = tvm.tir.ir_builder.create()
        Ab = tvm.tir.decl_buffer((m,), dtype='float32x{}'.format(lanes), name='A')
        A = ib.buffer_ptr(Ab)
        Bb = tvm.tir.decl_buffer((m,), dtype='float32x{}'.format(lanes), name='B')
        B = ib.buffer_ptr(Bb)
        with ib.for_range(0, m, name='i', dtype=m.dtype) as i:
            B[i] = A[i] + 1
        stmt = ib.get()
        stmt = lower_stmt([Ab, Bb], stmt, target_bits)
        assert stmt.loop_var.dtype == target_dtype

    # i32 -> i32
    check(const(2**10, dtype='int32'), 2, target_bits=32, target_dtype='int32')
    check(const(2**32, dtype='int32'), 2, target_bits=32, target_dtype='int32')
    # i64 -> i32
    check(const(2**10, dtype='int64'), 2, target_bits=32, target_dtype='int32')
    check(const(2**32, dtype='int64'), 2, target_bits=32, target_dtype='int64')
    # i32 -> i16
    check(const(2**10, dtype='int32'), 2, target_bits=16, target_dtype='int16')
    check(const(2**16, dtype='int32'), 2, target_bits=16, target_dtype='int32')


def test_reduce():
    def check(m, target_bits, target_dtype):
        A = te.placeholder((m,), name='A', dtype='float32')
        k = te.reduce_axis((0, m), "k")
        B = te.compute((), lambda *idx: te.sum(A[k], axis=k), name='B')
        s = te.create_schedule(B.op)
        stmt = lower_sch(s, [A, B], target_bits)
        assert stmt.body[1].loop_var.dtype == target_dtype

    # i32 -> i32
    check(const(64, dtype='int32'), 32, 'int32')
    # i64 -> i32
    check(const(64, dtype='int64'), 32, 'int32')
    # i32 -> i16
    check(const(64, dtype='int32'), 16, 'int16')
    check(const(2**16, dtype='int32'), 16, 'int32')
    # symbolic
    check(te.var('n', dtype='int32'), 32, 'int32')
    check(te.var('n', dtype='int64'), 32, 'int64')


def test_slice():
    def check(m, n, target_bits, target_dtype):
        # The index may overflow in B, while not in A
        ib = tvm.tir.ir_builder.create()
        Ab = tvm.tir.decl_buffer((m, n), name='A')
        A = ib.buffer_ptr(Ab)
        Bb = tvm.tir.decl_buffer((m, n * 2), name='B')
        B = ib.buffer_ptr(Bb)
        with ib.for_range(0, m, name='i') as i:
            with ib.for_range(0, n, name='j') as j:
                A[i * n + j] = B[i * 2 * n + 2 * j] + 1
        stmt = ib.get()
        stmt = lower_stmt([Ab, Bb], stmt, target_bits)
        assert stmt.loop_var.dtype == target_dtype
        assert stmt.body.loop_var.dtype == target_dtype

    # The maximum index is (2**15 * 2**15 - 1) * 2 <= 2**31 - 1
    check(const(2**15, 'int64'), const(2**15, 'int64'),
          target_bits=32, target_dtype='int32')
    # The maximum index is (2**15 * 2**15 - 1 + 2**15) * 2 > 2**31 - 1
    check(const(2**15, 'int64'), const(2**15 + 1, 'int64'),
          target_bits=32, target_dtype='int64')


if __name__ == "__main__":
    test_basic()
    test_thread_axis()
    test_multilanes()
    test_reduce()
    test_slice()
```