Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
f5f2feea
Unverified
Commit
f5f2feea
authored
Sep 29, 2019
by
Tianqi Chen
Committed by
GitHub
Sep 29, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[ARITH] migrate indexdiv/mod to floordiv/mod (#4008)
parent
2dac17d8
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
38 additions
and
19 deletions
+38
-19
python/tvm/expr.py
+3
-6
src/lang/attr_functor.h
+11
-2
src/lang/attrs.cc
+6
-2
src/lang/buffer.cc
+2
-2
src/lang/expr_operator.cc
+2
-2
src/pass/lower_intrin.cc
+10
-4
tests/python/unittest/test_codegen_device.py
+2
-0
tests/python/unittest/test_codegen_vm_basic.py
+1
-0
topi/python/topi/cuda/nms.py
+1
-1
No files found.
python/tvm/expr.py
View file @
f5f2feea
...
...
@@ -92,16 +92,13 @@ class ExprOp(object):
return
_generic
.
divide
(
other
,
self
)
def
__floordiv__
(
self
,
other
):
# return _generic.floordiv(self, other)
return
_generic
.
divide
(
self
,
other
)
return
_generic
.
floordiv
(
self
,
other
)
def
__rfloordiv__
(
self
,
other
):
# return _generic.floordiv(other, self)
return
_generic
.
divide
(
other
,
self
)
return
_generic
.
floordiv
(
other
,
self
)
def
__mod__
(
self
,
other
):
raise
div_ambiguity_error
()
# return _make._OpMod(self, other)
return
_make
.
_OpFloorMod
(
self
,
other
)
def
__neg__
(
self
):
neg_one
=
_api_internal
.
_const
(
-
1
,
self
.
dtype
)
...
...
src/lang/attr_functor.h
View file @
f5f2feea
...
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
...
@@ -87,6 +87,8 @@ class AttrFunctor<R(const NodeRef& n, Args...)> {
virtual
R
VisitAttr_
(
const
ir
::
Mul
*
op
,
Args
...
args
)
ATTR_FUNCTOR_DEFAULT
;
virtual
R
VisitAttr_
(
const
ir
::
Div
*
op
,
Args
...
args
)
ATTR_FUNCTOR_DEFAULT
;
virtual
R
VisitAttr_
(
const
ir
::
Mod
*
op
,
Args
...
args
)
ATTR_FUNCTOR_DEFAULT
;
virtual
R
VisitAttr_
(
const
ir
::
FloorDiv
*
op
,
Args
...
args
)
ATTR_FUNCTOR_DEFAULT
;
virtual
R
VisitAttr_
(
const
ir
::
FloorMod
*
op
,
Args
...
args
)
ATTR_FUNCTOR_DEFAULT
;
virtual
R
VisitAttr_
(
const
ir
::
Min
*
op
,
Args
...
args
)
ATTR_FUNCTOR_DEFAULT
;
virtual
R
VisitAttr_
(
const
ir
::
Max
*
op
,
Args
...
args
)
ATTR_FUNCTOR_DEFAULT
;
virtual
R
VisitAttr_
(
const
ir
::
GE
*
op
,
Args
...
args
)
ATTR_FUNCTOR_DEFAULT
;
...
...
@@ -119,6 +121,9 @@ class AttrFunctor<R(const NodeRef& n, Args...)> {
ATTR_FUNCTOR_DISPATCH
(
Sub
);
ATTR_FUNCTOR_DISPATCH
(
Mul
);
ATTR_FUNCTOR_DISPATCH
(
Div
);
ATTR_FUNCTOR_DISPATCH
(
Mod
);
ATTR_FUNCTOR_DISPATCH
(
FloorDiv
);
ATTR_FUNCTOR_DISPATCH
(
FloorMod
);
ATTR_FUNCTOR_DISPATCH
(
Min
);
ATTR_FUNCTOR_DISPATCH
(
Max
);
ATTR_FUNCTOR_DISPATCH
(
GE
);
...
...
@@ -160,6 +165,8 @@ class AttrsEqualHandler :
bool
VisitAttr_
(
const
ir
::
Mul
*
lhs
,
const
NodeRef
&
other
)
final
;
bool
VisitAttr_
(
const
ir
::
Div
*
lhs
,
const
NodeRef
&
other
)
final
;
bool
VisitAttr_
(
const
ir
::
Mod
*
lhs
,
const
NodeRef
&
other
)
final
;
bool
VisitAttr_
(
const
ir
::
FloorDiv
*
lhs
,
const
NodeRef
&
other
)
final
;
bool
VisitAttr_
(
const
ir
::
FloorMod
*
lhs
,
const
NodeRef
&
other
)
final
;
bool
VisitAttr_
(
const
ir
::
Min
*
lhs
,
const
NodeRef
&
other
)
final
;
bool
VisitAttr_
(
const
ir
::
Max
*
lhs
,
const
NodeRef
&
other
)
final
;
bool
VisitAttr_
(
const
ir
::
GE
*
lhs
,
const
NodeRef
&
other
)
final
;
...
...
@@ -201,6 +208,8 @@ class AttrsHashHandler :
size_t
VisitAttr_
(
const
ir
::
Mul
*
op
)
final
;
size_t
VisitAttr_
(
const
ir
::
Div
*
op
)
final
;
size_t
VisitAttr_
(
const
ir
::
Mod
*
op
)
final
;
size_t
VisitAttr_
(
const
ir
::
FloorDiv
*
op
)
final
;
size_t
VisitAttr_
(
const
ir
::
FloorMod
*
op
)
final
;
size_t
VisitAttr_
(
const
ir
::
Min
*
op
)
final
;
size_t
VisitAttr_
(
const
ir
::
Max
*
op
)
final
;
size_t
VisitAttr_
(
const
ir
::
GE
*
op
)
final
;
...
...
src/lang/attrs.cc
View file @
f5f2feea
...
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
...
@@ -154,6 +154,8 @@ TVM_DEFINE_ATTRS_BINOP_EQUAL(Sub);
TVM_DEFINE_ATTRS_BINOP_EQUAL
(
Mul
);
TVM_DEFINE_ATTRS_BINOP_EQUAL
(
Div
);
TVM_DEFINE_ATTRS_BINOP_EQUAL
(
Mod
);
TVM_DEFINE_ATTRS_BINOP_EQUAL
(
FloorDiv
);
TVM_DEFINE_ATTRS_BINOP_EQUAL
(
FloorMod
);
TVM_DEFINE_ATTRS_BINOP_EQUAL
(
Max
);
TVM_DEFINE_ATTRS_BINOP_EQUAL
(
Min
);
TVM_DEFINE_ATTRS_BINOP_EQUAL
(
GE
);
...
...
@@ -266,6 +268,8 @@ TVM_DEFINE_ATTRS_BINOP_HASH(Sub);
TVM_DEFINE_ATTRS_BINOP_HASH
(
Mul
);
TVM_DEFINE_ATTRS_BINOP_HASH
(
Div
);
TVM_DEFINE_ATTRS_BINOP_HASH
(
Mod
);
TVM_DEFINE_ATTRS_BINOP_HASH
(
FloorDiv
);
TVM_DEFINE_ATTRS_BINOP_HASH
(
FloorMod
);
TVM_DEFINE_ATTRS_BINOP_HASH
(
Max
);
TVM_DEFINE_ATTRS_BINOP_HASH
(
Min
);
TVM_DEFINE_ATTRS_BINOP_HASH
(
GE
);
...
...
src/lang/buffer.cc
View file @
f5f2feea
...
...
@@ -32,8 +32,8 @@
namespace
tvm
{
// TODO(tqchen): change to floormod/div
using
IndexMod
=
ir
::
Mod
;
using
IndexDiv
=
ir
::
Div
;
using
IndexMod
=
ir
::
Floor
Mod
;
using
IndexDiv
=
ir
::
Floor
Div
;
Array
<
Expr
>
SimplifyArray
(
Array
<
Expr
>
array
)
{
for
(
size_t
i
=
0
;
i
<
array
.
size
();
++
i
)
{
...
...
src/lang/expr_operator.cc
View file @
f5f2feea
...
...
@@ -208,11 +208,11 @@ Expr operator%(Expr a, Expr b) {
// TODO(tqchen): switch to floordiv
Expr
indexdiv
(
Expr
a
,
Expr
b
)
{
return
trunc
div
(
a
,
b
);
return
floor
div
(
a
,
b
);
}
Expr
indexmod
(
Expr
a
,
Expr
b
)
{
return
trunc
mod
(
a
,
b
);
return
floor
mod
(
a
,
b
);
}
Expr
floordiv
(
Expr
a
,
Expr
b
)
{
...
...
src/pass/lower_intrin.cc
View file @
f5f2feea
...
...
@@ -46,6 +46,9 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
patterns_
.
push_back
(
"tvm.intrin.rule."
+
starget
+
"."
);
patterns_
.
push_back
(
"tvm.intrin.rule.default."
);
fma_
=
runtime
::
Registry
::
Get
(
patterns_
[
0
]
+
"fma"
);
if
(
target
==
"stackvm"
)
{
support_bitwise_op_
=
false
;
}
}
Expr
Mutate_
(
const
Call
*
op
,
const
Expr
&
e
)
final
{
...
...
@@ -76,7 +79,8 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
const
DataType
&
dtype
=
op
->
type
;
CHECK
(
dtype
.
is_int
()
||
!
dtype
.
is_uint
());
if
(
is_const_power_of_two_integer
(
op
->
b
,
&
shift
))
{
if
(
support_bitwise_op_
&&
is_const_power_of_two_integer
(
op
->
b
,
&
shift
))
{
// lower to right shift if possible.
return
op
->
a
>>
make_const
(
dtype
,
shift
);
}
...
...
@@ -93,7 +97,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
// condition on b >= 0.
// truncmod(a, b) < 0 will implies ceildiv,
// So we need to correct these cases.
if
(
dtype
==
Int
(
32
)
||
dtype
==
Int
(
64
)
)
{
if
(
(
dtype
==
Int
(
32
)
||
dtype
==
Int
(
64
))
&&
support_bitwise_op_
)
{
// equivalent to rdiv + (rmod >= 0 ? 0: -1);
return
rdiv
+
(
rmod
>>
make_const
(
dtype
,
dtype
.
bits
()
-
1
));
}
else
{
...
...
@@ -122,7 +126,8 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
const
DataType
&
dtype
=
op
->
type
;
CHECK
(
dtype
.
is_int
()
||
!
dtype
.
is_uint
());
if
(
is_const_power_of_two_integer
(
op
->
b
,
&
shift
))
{
if
(
support_bitwise_op_
&&
is_const_power_of_two_integer
(
op
->
b
,
&
shift
))
{
// lower to masking if possible.
int64_t
mask
=
(
static_cast
<
int64_t
>
(
1
)
<<
static_cast
<
int64_t
>
(
shift
))
-
1
;
...
...
@@ -140,7 +145,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
// mod(a, b) < 0 will imply we are doing ceildiv,
// So we need to correct these cases.
Expr
rmod
=
truncmod
(
op
->
a
,
op
->
b
);
if
(
dtype
==
Int
(
32
)
||
dtype
==
Int
(
64
)
)
{
if
(
(
dtype
==
Int
(
32
)
||
dtype
==
Int
(
64
))
&&
support_bitwise_op_
)
{
// (rmod >> shift) & b
// -> (rmod >= 0 ? 0: -1) & b
// -> rmod >= 0 ? 0 : b
...
...
@@ -268,6 +273,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
// patterns
std
::
vector
<
std
::
string
>
patterns_
;
const
PackedFunc
*
fma_
{
nullptr
};
bool
support_bitwise_op_
{
true
};
};
Stmt
LowerIntrinStmt
(
Stmt
stmt
,
const
std
::
string
&
target
)
{
...
...
tests/python/unittest/test_codegen_device.py
View file @
f5f2feea
...
...
@@ -48,6 +48,8 @@ def test_add_pipeline():
stmt
=
tvm
.
ir_pass
.
Simplify
(
stmt
)
fapi
=
tvm
.
ir_pass
.
MakeAPI
(
stmt
,
"myadd"
,
[
Ab
,
Bb
,
Db
],
0
,
True
)
fsplits
=
[
x
for
x
in
tvm
.
ir_pass
.
SplitHostDevice
(
fapi
)]
# lower the floordiv(use stackvm rules so it works for all targets)
fsplits
=
[
tvm
.
ir_pass
.
LowerIntrin
(
x
,
"stackvm"
)
for
x
in
fsplits
]
fsplits
[
0
]
=
tvm
.
ir_pass
.
LowerTVMBuiltin
(
fsplits
[
0
])
def
check_target
(
device
,
host
=
"stackvm"
):
...
...
tests/python/unittest/test_codegen_vm_basic.py
View file @
f5f2feea
...
...
@@ -37,6 +37,7 @@ def test_stack_vm_basic():
stmt
=
tvm
.
make
.
Evaluate
(
tvm
.
call_packed
(
"tvm_call_back_get_shape"
,
Ab
.
shape
[
0
]))
fapi
=
tvm
.
ir_pass
.
MakeAPI
(
stmt
,
"print_shape"
,
[
Ab
],
0
,
True
)
fapi
=
tvm
.
ir_pass
.
LowerTVMBuiltin
(
fapi
)
fapi
=
tvm
.
ir_pass
.
LowerIntrin
(
fapi
,
"stackvm"
)
run_jit
(
fapi
,
lambda
f
:
f
(
a
))
...
...
topi/python/topi/cuda/nms.py
View file @
f5f2feea
...
...
@@ -185,7 +185,7 @@ def get_valid_counts_scan(data, partial_in, partial):
ib
.
scope_attr
(
bx
,
"thread_extent"
,
nthread_bx
)
var
=
tvm
.
make
.
node
(
"FloatImm"
,
dtype
=
"float32"
,
value
=
2
)
new_range
=
num_anchors
//
elem_per_thread
+
1
iteration
=
log
(
cast
(
new_range
,
"float32"
))
//
math
.
log
(
2
)
iteration
=
cast
(
log
(
cast
(
new_range
,
"float32"
))
/
math
.
log
(
2
),
"int32"
)
# Scan: Kogge-Stone adder
with
ib
.
if_scope
(
tvm
.
all
(
bx
<
batch_size
,
tx
<
tvm
.
min
(
new_range
,
num_anchors
))):
with
ib
.
for_range
(
0
,
iteration
)
as
k
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment