Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
4e5c5843
Unverified
Commit
4e5c5843
authored
Apr 03, 2020
by
Haozheng Fan
Committed by
GitHub
Apr 02, 2020
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[TIR][PASS] dtype rewrite for indexing variables (#5092)
parent
4195b2e2
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
285 additions
and
8 deletions
+285
-8
include/tvm/arith/analyzer.h
+9
-0
include/tvm/tir/ir_pass.h
+9
-0
include/tvm/tir/transform.h
+10
-0
python/tvm/driver/build_module.py
+1
-0
python/tvm/tir/expr.py
+2
-1
python/tvm/tir/ir_builder.py
+4
-2
python/tvm/tir/transform/transform.py
+15
-0
src/arith/const_int_bound.cc
+32
-0
src/target/llvm/codegen_cpu.cc
+1
-1
src/target/llvm/codegen_llvm.cc
+2
-1
src/tir/ir/buffer.cc
+1
-1
src/tir/pass/ffi_api.cc
+1
-0
src/tir/pass/loop_partition.cc
+1
-1
src/tir/pass/unroll_loop.cc
+3
-1
src/tir/transforms/narrow_datatype.cc
+0
-0
tests/python/unittest/test_tir_transform_narrow_datatype.py
+194
-0
No files found.
include/tvm/arith/analyzer.h
View file @
4e5c5843
...
...
@@ -115,6 +115,15 @@ class ConstIntBoundAnalyzer {
ConstIntBound
operator
()(
const
PrimExpr
&
expr
);
/*!
* \brief analyze the expr with the intermediate memorized to avoid redundant computation
* \param expr The expression of interest.
* \param bound The lookup table to store the intermediate results
* \return the result of the analysis.
*/
ConstIntBound
operator
()(
const
PrimExpr
&
expr
,
std
::
unordered_map
<
const
PrimExprNode
*
,
ConstIntBound
>*
bound
);
/*!
* \brief Update constant int bound information of var.
*
* \param var The variable of interest.
...
...
include/tvm/tir/ir_pass.h
View file @
4e5c5843
...
...
@@ -358,6 +358,15 @@ Stmt DecorateDeviceScope(Stmt stmt);
Stmt
HoistIfThenElse
(
Stmt
stmt
);
/*!
* \brief Narrow down PrimExpr datatype in stmt to target_bits.
* \note Run this pass after StorageFlatten.
* \param stmt The stmt to do datatype rewrite
* \param target_bits the bit of target datatype
* \return Transformed stmt.
*/
Stmt
NarrowDataType
(
Stmt
stmt
,
int
target_bits
);
/*!
* \brief Make an user callable API LoweredFunc.
*
* The main task of this function is to create code to :
...
...
include/tvm/tir/transform.h
View file @
4e5c5843
...
...
@@ -87,6 +87,16 @@ TVM_DLL Pass LowerDeviceStorageAccessInfo();
*/
TVM_DLL
Pass
LowerWarpMemory
();
/*!
* \brief Narrow down PrimExpr datatype in stmt to target_bits.
*
* \note Run this pass after StorageFlatten.
*
* \return The pass.
*/
TVM_DLL
Pass
NarrowDataType
();
}
// namespace transform
}
// namespace tir
}
// namespace tvm
...
...
python/tvm/driver/build_module.py
View file @
4e5c5843
...
...
@@ -159,6 +159,7 @@ def lower(sch,
# Phase 1
stmt
=
ir_pass
.
RewriteForTensorCore
(
stmt
,
sch
,
binds
)
stmt
=
ir_pass
.
StorageFlatten
(
stmt
,
binds
,
64
,
cfg
.
instrument_bound_checkers
)
stmt
=
ir_pass
.
NarrowDataType
(
stmt
,
32
)
stmt
=
ir_pass
.
CanonicalSimplify
(
stmt
)
for
f
in
lower_phase1
:
stmt
=
f
(
stmt
)
...
...
python/tvm/tir/expr.py
View file @
4e5c5843
...
...
@@ -370,7 +370,8 @@ class IterVar(Object, ExprOp):
raise
TypeError
(
"dom need to be Range"
)
name
=
var
if
var
is
not
None
else
"iter"
var
=
Var
(
name
,
dtype
=
"int32"
)
if
not
isinstance
(
var
,
Var
)
else
var
dtype
=
"int32"
if
dom
is
None
else
dom
.
extent
.
dtype
var
=
Var
(
name
,
dtype
=
dtype
)
if
not
isinstance
(
var
,
Var
)
else
var
self
.
__init_handle_by_constructor__
(
_ffi_api
.
IterVar
,
dom
,
var
,
iter_type
,
thread_tag
)
...
...
python/tvm/tir/ir_builder.py
View file @
4e5c5843
...
...
@@ -76,7 +76,8 @@ class BufferVar(ObjectGeneric):
def
__getitem__
(
self
,
index
):
t
=
DataType
(
self
.
_content_type
)
if
t
.
lanes
>
1
:
index
=
_expr
.
Ramp
(
index
*
t
.
lanes
,
1
,
t
.
lanes
)
base
=
index
*
t
.
lanes
index
=
_expr
.
Ramp
(
base
,
const
(
1
,
base
.
dtype
),
t
.
lanes
)
return
_expr
.
Load
(
self
.
_content_type
,
self
.
_buffer_var
,
index
)
def
__setitem__
(
self
,
index
,
value
):
...
...
@@ -87,7 +88,8 @@ class BufferVar(ObjectGeneric):
value
.
dtype
,
self
.
_content_type
))
t
=
DataType
(
self
.
_content_type
)
if
t
.
lanes
>
1
:
index
=
_expr
.
Ramp
(
index
*
t
.
lanes
,
1
,
t
.
lanes
)
base
=
index
*
t
.
lanes
index
=
_expr
.
Ramp
(
base
,
const
(
1
,
base
.
dtype
),
t
.
lanes
)
self
.
_builder
.
emit
(
_stmt
.
Store
(
self
.
_buffer_var
,
value
,
index
))
...
...
python/tvm/tir/transform/transform.py
View file @
4e5c5843
...
...
@@ -66,3 +66,18 @@ def LowerWarpMemory():
The result pass
"""
return
_ffi_api
.
LowerWarpMemory
()
def
NarrowDataType
():
"""Narrow down PrimExpr datatype in stmt to target_bits.
Returns
-------
fpass : tvm.ir.transform.Pass
The result pass
Note
----
Run this pass after StorageFlatten.
"""
return
_ffi_api
.
NarrowDataType
()
src/arith/const_int_bound.cc
View file @
4e5c5843
...
...
@@ -146,9 +146,30 @@ class ConstIntBoundAnalyzer::Impl :
res
=
Intersect
(
res
,
info
.
bound
);
}
}
if
(
bound_
)
{
const
PrimExprNode
*
op
=
expr
.
as
<
PrimExprNode
>
();
auto
val
=
bound_
->
find
(
op
);
if
(
val
!=
bound_
->
end
())
{
CHECK
(
val
->
second
->
min_value
==
res
.
min_value
&&
val
->
second
->
max_value
==
res
.
max_value
)
<<
"Detected bound for "
<<
expr
<<
"conflicts with memorization"
;
}
(
*
bound_
)[
op
]
=
ConstIntBound
(
res
.
min_value
,
res
.
max_value
);
}
return
res
;
}
Entry
VisitExpr_
(
const
RampNode
*
op
)
final
{
// op = {base + i * stride | 0 <= i < lanes}
// Entry(op) = Union(Entry(base + i * stride) | 0 <= i < lanes)
// Note that `base + i * stride` is linear w.r.t. `i`
// Entry(op) = Union(Entry(base + i * stride) | i = 0, i = lanes-1)
Entry
a
=
VisitExpr
(
op
->
base
);
Entry
b
=
VisitExpr
(
op
->
base
+
(
op
->
lanes
-
1
)
*
op
->
stride
);
return
Union
(
a
,
b
);
}
Entry
VisitExpr_
(
const
CastNode
*
op
)
final
{
Entry
a
=
VisitExpr
(
op
->
value
);
Entry
b
=
Everything
(
op
->
dtype
);
...
...
@@ -340,10 +361,13 @@ class ConstIntBoundAnalyzer::Impl :
}
private
:
friend
class
ConstIntBoundAnalyzer
;
// internal variable map
std
::
unordered_map
<
Var
,
Entry
,
ObjectHash
,
ObjectEqual
>
var_map_
;
// additional bound info
std
::
vector
<
BoundInfo
>
additional_info_
;
// look up table for memorization
std
::
unordered_map
<
const
PrimExprNode
*
,
ConstIntBound
>*
bound_
{
nullptr
};
// constants: the limit value means umlimited
// NOTE: kNegInf/kPosInf are used to represent infinity.
static
const
constexpr
int64_t
kNegInf
=
ConstIntBound
::
kNegInf
;
...
...
@@ -536,6 +560,14 @@ ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr) {
return
ConstIntBound
(
ret
.
min_value
,
ret
.
max_value
);
}
ConstIntBound
ConstIntBoundAnalyzer
::
operator
()(
const
PrimExpr
&
expr
,
std
::
unordered_map
<
const
PrimExprNode
*
,
ConstIntBound
>*
bound
)
{
impl_
->
bound_
=
bound
;
Entry
ret
=
impl_
->
VisitExpr
(
expr
);
impl_
->
bound_
=
nullptr
;
return
ConstIntBound
(
ret
.
min_value
,
ret
.
max_value
);
}
void
ConstIntBoundAnalyzer
::
Update
(
const
Var
&
var
,
const
ConstIntBound
&
info
,
bool
override
)
{
...
...
src/target/llvm/codegen_cpu.cc
View file @
4e5c5843
...
...
@@ -943,7 +943,7 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) {
PrimExpr
end
=
MinNode
::
make
((
task_id
+
make_const
(
t
,
1
))
*
step
,
op
->
extent
);
CreateSerialFor
(
MakeValue
(
begin
),
MakeValue
(
end
),
ConstInt32
(
1
),
llvm
::
ConstantInt
::
getSigned
(
GetLLVMType
(
end
),
1
),
op
->
loop_var
,
op
->
body
);
}
...
...
src/target/llvm/codegen_llvm.cc
View file @
4e5c5843
...
...
@@ -1121,7 +1121,8 @@ void CodeGenLLVM::VisitStmt_(const ForNode* op) {
CHECK
(
op
->
for_type
==
ForType
::
Serial
);
}
CreateSerialFor
(
MakeValue
(
op
->
min
),
MakeValue
(
op
->
extent
),
ConstInt32
(
1
),
op
->
loop_var
,
op
->
body
);
llvm
::
ConstantInt
::
getSigned
(
GetLLVMType
(
op
->
extent
),
1
),
op
->
loop_var
,
op
->
body
);
}
...
...
src/tir/ir/buffer.cc
View file @
4e5c5843
...
...
@@ -452,7 +452,7 @@ Buffer BufferNode::make(Var data,
n
->
buffer_type
=
buffer_type
;
if
(
n
->
buffer_type
==
kAutoBroadcast
&&
n
->
shape
.
size
()
>
0
&&
n
->
strides
.
empty
())
{
for
(
size_t
i
=
0
;
i
<
n
->
shape
.
size
();
++
i
)
{
n
->
strides
.
push_back
(
Var
(
"stride"
));
n
->
strides
.
push_back
(
Var
(
"stride"
,
n
->
shape
[
i
].
dtype
()
));
}
}
return
Buffer
(
n
);
...
...
src/tir/pass/ffi_api.cc
View file @
4e5c5843
...
...
@@ -156,5 +156,6 @@ REGISTER_PASS(InstrumentBoundCheckers);
REGISTER_PASS
(
VerifyCompactBuffer
);
REGISTER_PASS
(
HoistIfThenElse
);
REGISTER_PASS
(
InferFragment
)
REGISTER_PASS
(
NarrowDataType
);
}
// namespace tir
}
// namespace tvm
src/tir/pass/loop_partition.cc
View file @
4e5c5843
...
...
@@ -587,7 +587,7 @@ inline Stmt LoopPartitioner::MakeFor(const Object *node, PrimExpr extent, Stmt b
// If the loop extent is 1, do not create the loop anymore
return
Substitute
(
body
,
{{
Var
{
for_node
->
loop_var
},
make_const
(
DataType
::
Int
(
32
),
0
)}});
}
else
{
return
ForNode
::
make
(
for_node
->
loop_var
,
0
,
extent
,
return
ForNode
::
make
(
for_node
->
loop_var
,
IntImm
(
for_node
->
min
.
dtype
(),
0
)
,
extent
,
for_node
->
for_type
,
for_node
->
device_api
,
body
);
}
}
...
...
src/tir/pass/unroll_loop.cc
View file @
4e5c5843
...
...
@@ -160,7 +160,9 @@ class LoopUnroller : public StmtExprMutator {
PrimExpr
extent
=
tir
::
Simplify
(
op
->
extent
);
const
IntImmNode
*
v1
=
extent
.
as
<
IntImmNode
>
();
int
value
=
-
1
;
if
(
v1
!=
nullptr
)
{
// integers that do not fit in int32_t are treated as symbolic,
// as it's impossible to unroll such large loops
if
(
v1
!=
nullptr
&&
v1
->
value
<=
std
::
numeric_limits
<
int
>::
max
())
{
value
=
static_cast
<
int
>
(
v1
->
value
);
}
return
value
;
...
...
src/tir/transforms/narrow_datatype.cc
0 → 100644
View file @
4e5c5843
This diff is collapsed.
Click to expand it.
tests/python/unittest/test_tir_transform_narrow_datatype.py
0 → 100644
View file @
4e5c5843
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import
tvm
from
tvm
import
te
from
tvm.tir
import
const
def
lower_stmt
(
params
,
stmt
,
target_bits
):
func
=
tvm
.
tir
.
PrimFunc
(
params
,
stmt
)
.
with_attr
(
"target_bits"
,
target_bits
)
func
=
tvm
.
tir
.
transform
.
NarrowDataType
()(
tvm
.
IRModule
.
from_expr
(
func
))[
"main"
]
stmt
=
func
.
body
return
stmt
def
lower_sch
(
sch
,
args
,
target_bits
):
binds
=
{}
arg_list
=
[]
for
x
in
args
:
if
isinstance
(
x
,
te
.
tensor
.
Tensor
):
buf
=
tvm
.
tir
.
decl_buffer
(
x
.
shape
,
dtype
=
x
.
dtype
,
name
=
x
.
name
)
assert
x
not
in
binds
binds
[
x
]
=
buf
arg_list
.
append
(
buf
)
else
:
raise
ValueError
(
"args must be Tensor, Buffer or Var"
)
bounds
=
te
.
schedule
.
InferBound
(
sch
)
stmt
=
te
.
schedule
.
ScheduleOps
(
sch
,
bounds
)
stmt
=
tvm
.
tir
.
ir_pass
.
StorageFlatten
(
stmt
,
binds
,
64
,
False
)
return
lower_stmt
(
arg_list
,
stmt
,
target_bits
)
def
test_basic
():
def
check
(
m
,
n
,
target_bits
,
target_dtype
):
ib
=
tvm
.
tir
.
ir_builder
.
create
()
Ab
=
tvm
.
tir
.
decl_buffer
((
m
,
n
),
name
=
'A'
)
A
=
ib
.
buffer_ptr
(
Ab
)
Bb
=
tvm
.
tir
.
decl_buffer
((
m
,
n
),
name
=
'B'
)
B
=
ib
.
buffer_ptr
(
Bb
)
with
ib
.
for_range
(
0
,
m
,
name
=
'i'
)
as
i
:
with
ib
.
for_range
(
0
,
n
,
name
=
'j'
)
as
j
:
B
[
i
*
n
+
j
]
=
A
[
i
*
n
+
j
]
+
1
stmt
=
ib
.
get
()
stmt
=
lower_stmt
([
Ab
,
Bb
],
stmt
,
target_bits
)
assert
stmt
.
loop_var
.
dtype
==
target_dtype
assert
stmt
.
body
.
loop_var
.
dtype
==
target_dtype
# const shape
# i32 -> i32
check
(
2
,
2
,
32
,
"int32"
)
check
(
2
**
16
,
2
**
16
,
32
,
"int32"
)
# i32 + i32 is not promoted to i64 even if overflow
# i64 -> i32
check
(
const
(
2
,
dtype
=
'int64'
),
const
(
2
,
dtype
=
'int64'
),
32
,
"int32"
)
check
(
const
(
2
**
16
,
dtype
=
'int64'
),
const
(
2
**
16
,
dtype
=
'int64'
),
32
,
"int64"
)
# i32 -> i16
check
(
2
,
2
,
16
,
"int16"
)
check
(
2
**
10
,
2
**
10
,
16
,
"int32"
)
# symbolic shape
check
(
te
.
size_var
(
name
=
'm'
,
dtype
=
'int32'
),
te
.
size_var
(
name
=
'n'
,
dtype
=
'int32'
),
32
,
"int32"
)
check
(
te
.
size_var
(
name
=
'm'
,
dtype
=
'int64'
),
te
.
size_var
(
name
=
'n'
,
dtype
=
'int64'
),
32
,
"int64"
)
def
test_thread_axis
():
def
check
(
m
,
n
,
target_bits
,
target_dtype
):
ib
=
tvm
.
tir
.
ir_builder
.
create
()
Ab
=
tvm
.
tir
.
decl_buffer
((
m
,
n
),
name
=
'A'
)
A
=
ib
.
buffer_ptr
(
Ab
)
Bb
=
tvm
.
tir
.
decl_buffer
((
m
,
n
),
name
=
'B'
)
B
=
ib
.
buffer_ptr
(
Bb
)
bx
=
te
.
thread_axis
(
"blockIdx.x"
)
tx
=
te
.
thread_axis
(
"threadIdx.x"
)
ib
.
scope_attr
(
bx
,
"thread_extent"
,
m
)
ib
.
scope_attr
(
tx
,
"thread_extent"
,
n
)
B
[
bx
*
n
+
tx
]
=
A
[
bx
*
n
+
tx
]
+
1
stmt
=
ib
.
get
()
stmt
=
lower_stmt
([
Ab
,
Bb
],
stmt
,
target_bits
)
assert
stmt
.
node
.
var
.
dtype
==
target_dtype
assert
stmt
.
body
.
node
.
var
.
dtype
==
target_dtype
# i32 -> i32
check
(
2
,
32
,
target_bits
=
32
,
target_dtype
=
'int32'
)
check
(
2
**
30
,
32
,
# i32 + i32 is not promoted to i64 even in the case of overflow
target_bits
=
32
,
target_dtype
=
'int32'
)
# i64 -> i32
check
(
const
(
2
,
dtype
=
'int64'
),
const
(
32
,
dtype
=
'int64'
),
target_bits
=
32
,
target_dtype
=
'int32'
)
check
(
const
(
2
**
30
,
dtype
=
'int64'
),
const
(
32
,
dtype
=
'int64'
),
target_bits
=
32
,
target_dtype
=
'int64'
)
# i32 -> i16
check
(
2
,
32
,
target_bits
=
16
,
target_dtype
=
'int16'
)
check
(
2
**
14
,
32
,
target_bits
=
16
,
target_dtype
=
'int32'
)
def
test_multilanes
():
def
check
(
m
,
lanes
,
target_bits
,
target_dtype
):
ib
=
tvm
.
tir
.
ir_builder
.
create
()
Ab
=
tvm
.
tir
.
decl_buffer
((
m
,),
dtype
=
'float32x{}'
.
format
(
lanes
),
name
=
'A'
)
A
=
ib
.
buffer_ptr
(
Ab
)
Bb
=
tvm
.
tir
.
decl_buffer
((
m
,),
dtype
=
'float32x{}'
.
format
(
lanes
),
name
=
'B'
)
B
=
ib
.
buffer_ptr
(
Bb
)
with
ib
.
for_range
(
0
,
m
,
name
=
'i'
,
dtype
=
m
.
dtype
)
as
i
:
B
[
i
]
=
A
[
i
]
+
1
stmt
=
ib
.
get
()
stmt
=
lower_stmt
([
Ab
,
Bb
],
stmt
,
target_bits
)
assert
stmt
.
loop_var
.
dtype
==
target_dtype
# i32 -> i32
check
(
const
(
2
**
10
,
dtype
=
'int32'
),
2
,
target_bits
=
32
,
target_dtype
=
'int32'
)
check
(
const
(
2
**
32
,
dtype
=
'int32'
),
2
,
target_bits
=
32
,
target_dtype
=
'int32'
)
# i64 -> i32
check
(
const
(
2
**
10
,
dtype
=
'int64'
),
2
,
target_bits
=
32
,
target_dtype
=
'int32'
)
check
(
const
(
2
**
32
,
dtype
=
'int64'
),
2
,
target_bits
=
32
,
target_dtype
=
'int64'
)
# i32 -> i16
check
(
const
(
2
**
10
,
dtype
=
'int32'
),
2
,
target_bits
=
16
,
target_dtype
=
'int16'
)
check
(
const
(
2
**
16
,
dtype
=
'int32'
),
2
,
target_bits
=
16
,
target_dtype
=
'int32'
)
def
test_reduce
():
def
check
(
m
,
target_bits
,
target_dtype
):
A
=
te
.
placeholder
((
m
,),
name
=
'A'
,
dtype
=
'float32'
)
k
=
te
.
reduce_axis
((
0
,
m
),
"k"
)
B
=
te
.
compute
((),
lambda
*
idx
:
te
.
sum
(
A
[
k
],
axis
=
k
),
name
=
'B'
)
s
=
te
.
create_schedule
(
B
.
op
)
stmt
=
lower_sch
(
s
,
[
A
,
B
],
target_bits
)
assert
stmt
.
body
[
1
]
.
loop_var
.
dtype
==
target_dtype
# i32 -> i32
check
(
const
(
64
,
dtype
=
'int32'
),
32
,
'int32'
)
# i64 -> i32
check
(
const
(
64
,
dtype
=
'int64'
),
32
,
'int32'
)
# i32 -> i16
check
(
const
(
64
,
dtype
=
'int32'
),
16
,
'int16'
)
check
(
const
(
2
**
16
,
dtype
=
'int32'
),
16
,
'int32'
)
# symbolic
check
(
te
.
var
(
'n'
,
dtype
=
'int32'
),
32
,
'int32'
)
check
(
te
.
var
(
'n'
,
dtype
=
'int64'
),
32
,
'int64'
)
def
test_slice
():
def
check
(
m
,
n
,
target_bits
,
target_dtype
):
# The index may overflow in B, while not in A
ib
=
tvm
.
tir
.
ir_builder
.
create
()
Ab
=
tvm
.
tir
.
decl_buffer
((
m
,
n
),
name
=
'A'
)
A
=
ib
.
buffer_ptr
(
Ab
)
Bb
=
tvm
.
tir
.
decl_buffer
((
m
,
n
*
2
),
name
=
'B'
)
B
=
ib
.
buffer_ptr
(
Bb
)
with
ib
.
for_range
(
0
,
m
,
name
=
'i'
)
as
i
:
with
ib
.
for_range
(
0
,
n
,
name
=
'j'
)
as
j
:
A
[
i
*
n
+
j
]
=
B
[
i
*
2
*
n
+
2
*
j
]
+
1
stmt
=
ib
.
get
()
stmt
=
lower_stmt
([
Ab
,
Bb
],
stmt
,
target_bits
)
assert
stmt
.
loop_var
.
dtype
==
target_dtype
assert
stmt
.
body
.
loop_var
.
dtype
==
target_dtype
# The maximum index is (2**15 * 2**15 - 1) * 2 <= 2**31 - 1
check
(
const
(
2
**
15
,
'int64'
),
const
(
2
**
15
,
'int64'
),
target_bits
=
32
,
target_dtype
=
'int32'
)
# The maximum index is (2**15 * 2**15 - 1 + 2**15) * 2 > 2**31 - 1
check
(
const
(
2
**
15
,
'int64'
),
const
((
2
**
15
+
1
),
'int64'
),
target_bits
=
32
,
target_dtype
=
'int64'
)
if
__name__
==
"__main__"
:
test_basic
()
test_thread_axis
()
test_multilanes
()
test_reduce
()
test_slice
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment