Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
e8899285
Commit
e8899285
authored
Oct 29, 2019
by
Wuwei Lin
Committed by
Tianqi Chen
Oct 29, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay][Quantize] Use fixed point mulplications (#4160)
parent
8b1fb4d5
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
27 additions
and
17 deletions
+27
-17
python/tvm/relay/quantize/quantize.py
+4
-0
src/relay/pass/quantize/quantize.cc
+2
-1
src/relay/pass/quantize/quantize.h
+2
-0
src/relay/pass/quantize/realize.cc
+13
-10
src/relay/qnn/op/requantize.cc
+2
-4
src/relay/qnn/util.cc
+3
-1
src/relay/qnn/util.h
+1
-1
No files found.
python/tvm/relay/quantize/quantize.py
View file @
e8899285
...
@@ -83,6 +83,7 @@ class QConfig(NodeBase):
...
@@ -83,6 +83,7 @@ class QConfig(NodeBase):
"do_simulation"
:
False
,
"do_simulation"
:
False
,
"round_for_shift"
:
True
,
"round_for_shift"
:
True
,
"debug_enabled_ops"
:
None
,
"debug_enabled_ops"
:
None
,
"rounding"
:
"UPWARD"
}
}
# pylint: disable=no-member
# pylint: disable=no-member
...
@@ -160,6 +161,9 @@ def qconfig(**kwargs):
...
@@ -160,6 +161,9 @@ def qconfig(**kwargs):
is None, which means will try to call all operartors' annotate rewrite
is None, which means will try to call all operartors' annotate rewrite
function.
function.
rounding: "UPWARD" or "TONEAREST"
Rounding direction for fixed point multiplications.
Returns
Returns
-------
-------
config: QConfig
config: QConfig
...
...
src/relay/pass/quantize/quantize.cc
View file @
e8899285
...
@@ -126,7 +126,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
...
@@ -126,7 +126,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p
->
stream
<<
"skip_conv_layers=="
<<
op
->
skip_conv_layers
<<
", "
;
p
->
stream
<<
"skip_conv_layers=="
<<
op
->
skip_conv_layers
<<
", "
;
p
->
stream
<<
"do_simulation=="
<<
op
->
do_simulation
<<
", "
;
p
->
stream
<<
"do_simulation=="
<<
op
->
do_simulation
<<
", "
;
p
->
stream
<<
"round_for_shift=="
<<
op
->
round_for_shift
<<
", "
;
p
->
stream
<<
"round_for_shift=="
<<
op
->
round_for_shift
<<
", "
;
p
->
stream
<<
"debug_enabled_ops=="
<<
op
->
debug_enabled_ops
;
p
->
stream
<<
"debug_enabled_ops=="
<<
op
->
debug_enabled_ops
<<
", "
;
p
->
stream
<<
"rounding=="
<<
op
->
rounding
;
p
->
stream
<<
")"
;
p
->
stream
<<
")"
;
});
});
...
...
src/relay/pass/quantize/quantize.h
View file @
e8899285
...
@@ -75,6 +75,7 @@ class QConfigNode : public Node {
...
@@ -75,6 +75,7 @@ class QConfigNode : public Node {
bool
do_simulation
=
false
;
bool
do_simulation
=
false
;
bool
round_for_shift
=
true
;
bool
round_for_shift
=
true
;
Array
<
Expr
>
debug_enabled_ops
=
Array
<
Expr
>
(
NodePtr
<
Node
>
(
nullptr
));
Array
<
Expr
>
debug_enabled_ops
=
Array
<
Expr
>
(
NodePtr
<
Node
>
(
nullptr
));
std
::
string
rounding
=
"UPWARD"
;
void
VisitAttrs
(
AttrVisitor
*
v
)
{
void
VisitAttrs
(
AttrVisitor
*
v
)
{
v
->
Visit
(
"nbit_input"
,
&
nbit_input
);
v
->
Visit
(
"nbit_input"
,
&
nbit_input
);
...
@@ -88,6 +89,7 @@ class QConfigNode : public Node {
...
@@ -88,6 +89,7 @@ class QConfigNode : public Node {
v
->
Visit
(
"do_simulation"
,
&
do_simulation
);
v
->
Visit
(
"do_simulation"
,
&
do_simulation
);
v
->
Visit
(
"round_for_shift"
,
&
round_for_shift
);
v
->
Visit
(
"round_for_shift"
,
&
round_for_shift
);
v
->
Visit
(
"debug_enabled_ops"
,
&
debug_enabled_ops
);
v
->
Visit
(
"debug_enabled_ops"
,
&
debug_enabled_ops
);
v
->
Visit
(
"rounding"
,
&
rounding
);
}
}
static
constexpr
const
char
*
_type_key
=
"relay.quantize.QConfig"
;
static
constexpr
const
char
*
_type_key
=
"relay.quantize.QConfig"
;
...
...
src/relay/pass/quantize/realize.cc
View file @
e8899285
...
@@ -31,6 +31,7 @@
...
@@ -31,6 +31,7 @@
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/attrs/annotation.h>
#include "./quantize.h"
#include "./quantize.h"
#include "../pattern_util.h"
#include "../pattern_util.h"
#include "../../qnn/util.h"
namespace
tvm
{
namespace
tvm
{
namespace
relay
{
namespace
relay
{
...
@@ -97,7 +98,9 @@ inline Expr ForwardOp(const Call& ref_call, const Array<Expr>& args) {
...
@@ -97,7 +98,9 @@ inline Expr ForwardOp(const Call& ref_call, const Array<Expr>& args) {
/* calculate `data * s1 / s2`, use shift if possible */
/* calculate `data * s1 / s2`, use shift if possible */
inline
Expr
MulAndDiv
(
Expr
data
,
float
s1
,
float
s2
,
DataType
dtype
)
{
inline
Expr
MulAndDiv
(
Expr
data
,
float
s1
,
float
s2
,
DataType
dtype
,
const
Array
<
IndexExpr
>
&
data_shape
)
{
const
QConfig
&
cfg
=
QConfig
::
Current
();
// here we assume the dtype of data is dtype activation
// here we assume the dtype of data is dtype activation
if
(
s1
==
s2
)
return
data
;
if
(
s1
==
s2
)
return
data
;
...
@@ -110,9 +113,8 @@ inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype) {
...
@@ -110,9 +113,8 @@ inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype) {
}
else
if
(
static_cast
<
int
>
(
factor
)
==
factor
)
{
}
else
if
(
static_cast
<
int
>
(
factor
)
==
factor
)
{
return
Multiply
(
data
,
MakeConstantScalar
(
dtype
,
factor
));
return
Multiply
(
data
,
MakeConstantScalar
(
dtype
,
factor
));
}
else
{
}
else
{
data
=
Cast
(
data
,
Float
(
32
));
data
=
qnn
::
FixedPointMultiply
(
Cast
(
data
,
Int
(
64
)),
factor
,
data_shape
,
cfg
->
rounding
);
data
=
Multiply
(
data
,
MakeConstantScalar
(
Float
(
32
),
factor
));
return
Cast
(
data
,
dtype
);
return
Cast
(
Round
(
data
),
dtype
);
}
}
}
}
...
@@ -164,11 +166,12 @@ Expr QuantizeRealize(const Call& ref_call,
...
@@ -164,11 +166,12 @@ Expr QuantizeRealize(const Call& ref_call,
data
=
Clip
(
data
,
clip_min_imm
,
clip_max_imm
);
data
=
Clip
(
data
,
clip_min_imm
,
clip_max_imm
);
return
QRealizeIntExprNode
::
make
(
data
,
dom_scale
,
n
->
dtype
);
return
QRealizeIntExprNode
::
make
(
data
,
dom_scale
,
n
->
dtype
);
}
else
{
}
else
{
// float computation
data
=
Cast
(
data
,
Int
(
64
));
data
=
Cast
(
data
,
Float
(
32
));
data
=
qnn
::
FixedPointMultiply
(
data
,
idom_scale_imm
/
odom_scale_imm
,
Expr
scaled_data
=
Multiply
(
data
,
Divide
(
n
->
dom_scale
,
dom_scale
));
ref_call
->
type_as
<
TensorTypeNode
>
()
->
shape
,
Expr
round_data
=
Clip
(
Round
(
scaled_data
),
clip_min_imm
,
clip_max_imm
);
cfg
->
rounding
);
return
QRealizeIntExprNode
::
make
(
round_data
,
dom_scale
,
Float
(
32
));
data
=
Cast
(
Clip
(
data
,
clip_min_imm
,
clip_max_imm
),
n
->
dtype
);
return
QRealizeIntExprNode
::
make
(
data
,
dom_scale
,
n
->
dtype
);
}
}
}
}
...
@@ -355,7 +358,7 @@ Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args, const Array<Expr>& args
...
@@ -355,7 +358,7 @@ Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args, const Array<Expr>& args
Expr
dom_scale
=
MakeConstantScalar
(
Float
(
32
),
s
);
Expr
dom_scale
=
MakeConstantScalar
(
Float
(
32
),
s
);
for
(
size_t
i
=
0
;
i
<
ret
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
ret
.
size
();
++
i
)
{
float
cur_s
=
GetScalarFromConstant
<
float
>
(
nptrs
[
i
]
->
dom_scale
);
float
cur_s
=
GetScalarFromConstant
<
float
>
(
nptrs
[
i
]
->
dom_scale
);
ret
.
Set
(
i
,
MulAndDiv
(
ret
[
i
],
cur_s
,
s
,
dtype
));
ret
.
Set
(
i
,
MulAndDiv
(
ret
[
i
],
cur_s
,
s
,
dtype
,
ref_args
[
i
]
->
type_as
<
TensorTypeNode
>
()
->
shape
));
}
}
*
dtype_ptr
=
dtype
;
*
dtype_ptr
=
dtype
;
...
...
src/relay/qnn/op/requantize.cc
View file @
e8899285
...
@@ -37,8 +37,6 @@ TVM_REGISTER_NODE_TYPE(RequantizeAttrs);
...
@@ -37,8 +37,6 @@ TVM_REGISTER_NODE_TYPE(RequantizeAttrs);
// Lowering of qnn.requantize op
// Lowering of qnn.requantize op
/*
/*
* \brief Lower requantize to a sequence of ops.
* \brief Lower requantize to a sequence of ops.
* \param input_tensor The input tensor to requantize op.
* \param input_tensor The input tensor to requantize op.
...
@@ -73,8 +71,8 @@ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
...
@@ -73,8 +71,8 @@ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
// 2) If the input and output scales are same, we can skip the fixed point multiplication.
// 2) If the input and output scales are same, we can skip the fixed point multiplication.
auto
scaled_int64_t
=
tensor
;
auto
scaled_int64_t
=
tensor
;
if
(
param
->
input_scale
!=
param
->
output_scale
)
{
if
(
param
->
input_scale
!=
param
->
output_scale
)
{
scaled_int64_t
=
FixedPointMuliply
(
scaled_int64_t
,
double_multiplier
,
input_shape
,
scaled_int64_t
=
param
->
rounding
);
FixedPointMultiply
(
scaled_int64_t
,
double_multiplier
,
input_shape
,
param
->
rounding
);
}
}
// 3) Add the output zero point.
// 3) Add the output zero point.
...
...
src/relay/qnn/util.cc
View file @
e8899285
...
@@ -76,7 +76,7 @@ std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(
...
@@ -76,7 +76,7 @@ std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(
return
std
::
make_pair
(
significand
,
exponent
);
return
std
::
make_pair
(
significand
,
exponent
);
}
}
Expr
FixedPointMuliply
(
Expr
tensor
,
double
multiplier
,
Expr
FixedPointMul
t
iply
(
Expr
tensor
,
double
multiplier
,
const
Array
<
IndexExpr
>&
input_shape
,
const
std
::
string
&
rounding
)
{
const
Array
<
IndexExpr
>&
input_shape
,
const
std
::
string
&
rounding
)
{
// Choose high precision datatype to be int64. This is for avoiding overflow
// Choose high precision datatype to be int64. This is for avoiding overflow
// in multiplication of two int32 values.
// in multiplication of two int32 values.
...
@@ -121,6 +121,8 @@ Expr FixedPointMuliply(Expr tensor, double multiplier,
...
@@ -121,6 +121,8 @@ Expr FixedPointMuliply(Expr tensor, double multiplier,
auto
zero_t
=
Zeros
(
input_shape
,
hp_dtype
);
auto
zero_t
=
Zeros
(
input_shape
,
hp_dtype
);
round_scalar
=
round_scalar
=
Where
(
GreaterEqual
(
tensor
,
zero_t
),
pos_rounder_t
,
neg_rounder_t
);
Where
(
GreaterEqual
(
tensor
,
zero_t
),
pos_rounder_t
,
neg_rounder_t
);
}
else
{
LOG
(
FATAL
)
<<
"Rounding mode "
<<
rounding
<<
" not supported."
;
}
}
// Add the rounding scalar.
// Add the rounding scalar.
tensor
=
Add
(
tensor
,
round_scalar
);
tensor
=
Add
(
tensor
,
round_scalar
);
...
...
src/relay/qnn/util.h
View file @
e8899285
...
@@ -115,7 +115,7 @@ static inline int64_t get_const_int(const tvm::Expr& x) {
...
@@ -115,7 +115,7 @@ static inline int64_t get_const_int(const tvm::Expr& x) {
* 2) Round the result.
* 2) Round the result.
* 3) Right shift the result
* 3) Right shift the result
*/
*/
Expr
FixedPointMuliply
(
Expr
tensor
,
double
multiplier
,
Expr
FixedPointMul
t
iply
(
Expr
tensor
,
double
multiplier
,
const
Array
<
IndexExpr
>&
input_shape
,
const
Array
<
IndexExpr
>&
input_shape
,
const
std
::
string
&
rounding
);
const
std
::
string
&
rounding
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment