Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
6bc0ae12
Commit
6bc0ae12
authored
Aug 02, 2017
by
Tianqi Chen
Committed by
ziheng
Aug 02, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[ARITH] Refactor intset eval with functor (#295)
parent
10bc2fdf
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
137 additions
and
66 deletions
+137
-66
include/tvm/arithmetic.h
+6
-0
src/api/api_arith.cc
+5
-0
src/arithmetic/int_set.cc
+117
-66
tests/python/unittest/test_arith_intset.py
+9
-0
No files found.
include/tvm/arithmetic.h
View file @
6bc0ae12
...
...
@@ -94,6 +94,12 @@ class IntSet : public NodeRef {
*/
static
IntSet
single_point
(
Expr
point
);
/*!
* \brief construct a integer set from vector expression.
* \param vec The vector expression, can also be single point.
* \return The result set containing the indices in the vector.
*/
static
IntSet
vector
(
Expr
vec
);
/*!
* \brief Construct a set representing a range.
* \param r The range
* \return constructed set.
...
...
src/api/api_arith.cc
View file @
6bc0ae12
...
...
@@ -16,6 +16,11 @@ TVM_REGISTER_API("arith.intset_single_point")
*
ret
=
IntSet
::
single_point
(
args
[
0
]);
});
TVM_REGISTER_API
(
"arith.intset_vector"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
IntSet
::
vector
(
args
[
0
]);
});
TVM_REGISTER_API
(
"arith.intset_interval"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
IntSet
::
interval
(
args
[
0
],
args
[
1
]);
...
...
src/arithmetic/int_set.cc
View file @
6bc0ae12
...
...
@@ -6,6 +6,7 @@
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/arithmetic.h>
#include <tvm/ir_functor_ext.h>
#include <arithmetic/Interval.h>
#include <unordered_map>
#include "./compute_expr.h"
...
...
@@ -423,80 +424,129 @@ inline IntSet Combine(const IntSet& a, const IntSet &b) {
return
CombineSets
<
OP
>
(
a
,
b
);
}
// Evaluator to evalute the epxression.
class
IntSetEvaluator
{
class
IntSetEvaluator
:
public
ExprFunctor
<
IntSet
(
const
Expr
&
,
const
Expr
&
)
>
{
public
:
explicit
IntSetEvaluator
(
const
std
::
unordered_map
<
const
Variable
*
,
IntSet
>&
dom_map
)
:
dom_map
(
dom_map
)
{}
inline
virtual
IntSet
Eval
(
Expr
expr
)
{
static
const
FType
&
f
=
vtable
();
if
(
f
.
can_dispatch
(
expr
))
{
return
f
(
expr
,
expr
,
this
);
}
else
{
LOG
(
WARNING
)
<<
"cannot evaluate set type "
<<
expr
->
type_key
();
return
IntSet
::
nothing
();
}
explicit
IntSetEvaluator
(
const
std
::
unordered_map
<
const
Variable
*
,
IntSet
>&
dom_map
,
bool
eval_vec
=
false
)
:
dom_map_
(
dom_map
),
eval_vec_
(
eval_vec
)
{}
// Evaluate.
IntSet
Eval
(
const
Expr
&
e
)
{
return
this
->
VisitExpr
(
e
,
e
);
}
IntSet
VisitExpr_
(
const
IntImm
*
op
,
const
Expr
&
e
)
final
{
return
IntSet
::
single_point
(
e
);
}
using
FType
=
tvm
::
IRFunctor
<
IntSet
(
const
NodeRef
&
,
const
Expr
&
,
IntSetEvaluator
*
)
>
;
static
FType
&
vtable
()
{
// NOLINT(*)
static
FType
inst
;
return
inst
;
IntSet
VisitExpr_
(
const
UIntImm
*
op
,
const
Expr
&
e
)
final
{
return
IntSet
::
single_point
(
e
);
}
const
std
::
unordered_map
<
const
Variable
*
,
IntSet
>&
dom_map
;
};
inline
IntSet
ConstOp
(
const
NodeRef
&
,
const
Expr
&
e
,
IntSetEvaluator
*
m
)
{
return
IntSet
::
single_point
(
e
);
}
TVM_STATIC_IR_FUNCTOR
(
IntSetEvaluator
,
vtable
)
.
set_dispatch
<
IntImm
>
(
ConstOp
)
.
set_dispatch
<
UIntImm
>
(
ConstOp
)
.
set_dispatch
<
FloatImm
>
(
ConstOp
);
TVM_STATIC_IR_FUNCTOR
(
IntSetEvaluator
,
vtable
)
.
set_dispatch
<
Variable
>
([](
const
Variable
*
op
,
const
Expr
&
e
,
IntSetEvaluator
*
m
)
{
auto
it
=
m
->
dom_map
.
find
(
op
);
if
(
it
!=
m
->
dom_map
.
end
())
{
IntSet
VisitExpr_
(
const
Variable
*
op
,
const
Expr
&
e
)
final
{
auto
it
=
dom_map_
.
find
(
op
);
if
(
it
!=
dom_map_
.
end
())
{
return
it
->
second
;
}
else
{
return
IntSet
::
single_point
(
e
);
}
});
}
IntSet
VisitExpr_
(
const
Add
*
op
,
const
Expr
&
e
)
final
{
return
Binary
(
op
,
e
);
}
IntSet
VisitExpr_
(
const
Sub
*
op
,
const
Expr
&
e
)
final
{
return
Binary
(
op
,
e
);
}
IntSet
VisitExpr_
(
const
Mul
*
op
,
const
Expr
&
e
)
final
{
return
Binary
(
op
,
e
);
}
IntSet
VisitExpr_
(
const
Div
*
op
,
const
Expr
&
e
)
final
{
return
Binary
(
op
,
e
);
}
IntSet
VisitExpr_
(
const
Mod
*
op
,
const
Expr
&
e
)
final
{
return
Binary
(
op
,
e
);
}
IntSet
VisitExpr_
(
const
Min
*
op
,
const
Expr
&
e
)
final
{
return
Binary
(
op
,
e
);
}
IntSet
VisitExpr_
(
const
Max
*
op
,
const
Expr
&
e
)
final
{
return
Binary
(
op
,
e
);
}
IntSet
VisitExpr_
(
const
EQ
*
op
,
const
Expr
&
e
)
final
{
return
Binary
(
op
,
e
);
}
IntSet
VisitExpr_
(
const
NE
*
op
,
const
Expr
&
e
)
final
{
return
Binary
(
op
,
e
);
}
IntSet
VisitExpr_
(
const
LT
*
op
,
const
Expr
&
e
)
final
{
return
Binary
(
op
,
e
);
}
IntSet
VisitExpr_
(
const
LE
*
op
,
const
Expr
&
e
)
final
{
return
Binary
(
op
,
e
);
}
IntSet
VisitExpr_
(
const
GT
*
op
,
const
Expr
&
e
)
final
{
return
Binary
(
op
,
e
);
}
IntSet
VisitExpr_
(
const
GE
*
op
,
const
Expr
&
e
)
final
{
return
Binary
(
op
,
e
);
}
IntSet
VisitExpr_
(
const
And
*
op
,
const
Expr
&
e
)
final
{
return
Binary
(
op
,
e
);
}
IntSet
VisitExpr_
(
const
Or
*
op
,
const
Expr
&
e
)
final
{
return
Binary
(
op
,
e
);
}
IntSet
VisitExpr_
(
const
Ramp
*
op
,
const
Expr
&
e
)
final
{
CHECK
(
eval_vec_
);
IntSet
base
=
Eval
(
op
->
base
);
int
vstride
;
if
(
GetConstInt
(
op
->
stride
,
&
vstride
))
{
Type
t
=
op
->
base
.
type
();
if
(
vstride
>
0
)
{
return
Combine
<
Add
>
(
base
,
IntSet
::
interval
(
make_zero
(
t
),
make_const
(
t
,
vstride
*
op
->
lanes
-
1
)));
}
else
{
return
Combine
<
Add
>
(
base
,
IntSet
::
interval
(
make_const
(
t
,
vstride
*
op
->
lanes
+
1
),
make_zero
(
t
)));
}
}
LOG
(
WARNING
)
<<
"cannot evaluate set on expression "
<<
e
;
return
IntSet
::
everything
();
}
IntSet
VisitExpr_
(
const
Broadcast
*
op
,
const
Expr
&
e
)
final
{
CHECK
(
eval_vec_
);
return
Eval
(
op
->
value
);
}
IntSet
VisitExprDefault_
(
const
Node
*
op
,
const
Expr
&
e
)
final
{
LOG
(
WARNING
)
<<
"cannot evaluate set type "
<<
e
->
type_key
();
return
IntSet
::
everything
();
}
// binary operator
template
<
typename
T
>
inline
IntSet
Binary
(
const
T
*
op
,
const
Expr
&
e
,
IntSetEvaluator
*
m
)
{
IntSet
a
=
m
->
Eval
(
op
->
a
);
IntSet
b
=
m
->
Eval
(
op
->
b
);
if
(
MatchPoint
(
a
,
op
->
a
)
&&
MatchPoint
(
b
,
op
->
b
))
{
return
IntSet
::
single_point
(
e
);
private
:
template
<
typename
T
>
inline
IntSet
Binary
(
const
T
*
op
,
const
Expr
&
e
)
{
IntSet
a
=
this
->
Eval
(
op
->
a
);
IntSet
b
=
this
->
Eval
(
op
->
b
);
if
(
MatchPoint
(
a
,
op
->
a
)
&&
MatchPoint
(
b
,
op
->
b
))
{
return
IntSet
::
single_point
(
e
);
}
return
Combine
<
T
>
(
a
,
b
);
}
return
Combine
<
T
>
(
a
,
b
);
}
TVM_STATIC_IR_FUNCTOR
(
IntSetEvaluator
,
vtable
)
.
set_dispatch
<
Add
>
(
Binary
<
Add
>
)
.
set_dispatch
<
Sub
>
(
Binary
<
Sub
>
)
.
set_dispatch
<
Mul
>
(
Binary
<
Mul
>
)
.
set_dispatch
<
Div
>
(
Binary
<
Div
>
)
.
set_dispatch
<
Mod
>
(
Binary
<
Mod
>
)
.
set_dispatch
<
Min
>
(
Binary
<
Min
>
)
.
set_dispatch
<
Max
>
(
Binary
<
Max
>
)
.
set_dispatch
<
EQ
>
(
Binary
<
EQ
>
)
.
set_dispatch
<
NE
>
(
Binary
<
NE
>
)
.
set_dispatch
<
LT
>
(
Binary
<
LT
>
)
.
set_dispatch
<
LE
>
(
Binary
<
LE
>
)
.
set_dispatch
<
GT
>
(
Binary
<
GT
>
)
.
set_dispatch
<
GE
>
(
Binary
<
GE
>
)
.
set_dispatch
<
And
>
(
Binary
<
And
>
)
.
set_dispatch
<
Or
>
(
Binary
<
Or
>
);
const
std
::
unordered_map
<
const
Variable
*
,
IntSet
>&
dom_map_
;
bool
eval_vec_
{
false
};
};
IntSet
EvalSet
(
Expr
e
,
const
std
::
unordered_map
<
const
Variable
*
,
IntSet
>&
dom_map
)
{
return
IntSetEvaluator
(
dom_map
).
Eval
(
e
);
return
IntSetEvaluator
(
dom_map
,
false
).
Eval
(
e
);
}
IntSet
IntSet
::
vector
(
Expr
x
)
{
std
::
unordered_map
<
const
Variable
*
,
IntSet
>
dmap
;
return
IntSetEvaluator
(
dmap
,
true
).
Eval
(
x
);
}
IntSet
EvalSet
(
Expr
e
,
...
...
@@ -521,12 +571,13 @@ IntSet EvalSet(Range r,
class
SubExprIntSetEvaluator
:
public
IntSetEvaluator
{
public
:
explicit
SubExprIntSetEvaluator
(
const
std
::
unordered_map
<
const
Variable
*
,
IntSet
>&
dom_map
)
explicit
SubExprIntSetEvaluator
(
const
std
::
unordered_map
<
const
Variable
*
,
IntSet
>&
dom_map
)
:
IntSetEvaluator
(
dom_map
)
{}
inline
IntSet
Eval
(
Expr
expr
)
override
{
IntSet
ret
=
IntSetEvaluator
::
Eval
(
expr
);
expr_map
[
expr
]
=
ret
;
IntSet
VisitExpr
(
const
Expr
&
n
,
const
Expr
&
e
)
final
{
IntSet
ret
=
IntSetEvaluator
::
VisitExpr
(
n
,
e
);
expr_map
[
n
]
=
ret
;
return
ret
;
}
...
...
tests/python/unittest/test_arith_intset.py
View file @
6bc0ae12
...
...
@@ -5,6 +5,14 @@ def test_basic():
assert
s
.
min
()
.
value
==
2
assert
s
.
max
()
.
value
==
3
def
test_vector
():
base
=
10
stride
=
3
lanes
=
2
s
=
tvm
.
arith
.
intset_vector
(
tvm
.
make
.
Ramp
(
base
,
stride
,
lanes
))
assert
s
.
min
()
.
value
==
base
assert
s
.
max
()
.
value
==
base
+
stride
*
lanes
-
1
def
test_deduce
():
a
=
tvm
.
var
(
'a'
)
b
=
tvm
.
var
(
'b'
)
...
...
@@ -59,5 +67,6 @@ def test_check():
if
__name__
==
"__main__"
:
test_basic
()
test_vector
()
test_deduce
()
test_check
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment