Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
e0d286a1
Commit
e0d286a1
authored
Oct 21, 2019
by
Haichen Shen
Committed by
Zhi
Oct 21, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay][Pass] Count MAC for BatchMatMul (#4157)
* count MAC for BatchMatMul * update doc
parent
d660e514
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
29 additions
and
10 deletions
+29
-10
src/relay/pass/mac_count.cc
+29
-10
No files found.
src/relay/pass/mac_count.cc
View file @
e0d286a1
...
...
@@ -66,7 +66,7 @@ int64_t ConvMacCount(const Call& call_node) {
return
0
;
}
Array
<
Expr
>
args
=
call_node
->
args
;
CHECK
(
args
.
size
()
==
2
)
CHECK
_EQ
(
args
.
size
(),
2
)
<<
"The number of input arguments of a CONV 2D node should be 2."
;
const
auto
*
conv_2d_attr
=
call_node
->
attrs
.
as
<
Conv2DAttrs
>
();
const
auto
*
data_type
=
args
[
0
]
->
checked_type
().
as
<
TensorTypeNode
>
();
...
...
@@ -74,13 +74,13 @@ int64_t ConvMacCount(const Call& call_node) {
std
::
string
data_layout
=
conv_2d_attr
->
data_layout
;
int32_t
C_ind
=
Layout
(
data_layout
).
IndexOf
(
LayoutAxis
::
Get
(
'C'
));
int32_t
c_ind
=
Layout
(
data_layout
).
IndexOf
(
LayoutAxis
::
Get
(
'c'
));
CHECK
(
C_ind
!=
-
1
)
CHECK
_NE
(
C_ind
,
-
1
)
<<
"There is no input channel dimension."
;
int64_t
input_channel
=
static_cast
<
int64_t
>
(
data_shape
[
C_ind
].
as
<
IntImm
>
()
->
value
);
if
(
c_ind
!=
-
1
)
input_channel
*=
static_cast
<
int64_t
>
(
data_shape
[
c_ind
].
as
<
IntImm
>
()
->
value
);
Array
<
IndexExpr
>
kernel_size
=
conv_2d_attr
->
kernel_size
;
CHECK
(
kernel_size
.
size
()
==
2
)
CHECK
_EQ
(
kernel_size
.
size
(),
2
)
<<
"The dimension of the kernel in Conv 2D should be 2."
;
const
auto
*
expr
=
call_node
->
checked_type
().
as
<
TensorTypeNode
>
();
Array
<
IndexExpr
>
output_tensor
=
expr
->
shape
;
...
...
@@ -99,7 +99,7 @@ int64_t Conv2dTransposeMacCount(const Call& call_node) {
return
0
;
}
Array
<
Expr
>
args
=
call_node
->
args
;
CHECK
(
args
.
size
()
==
2
)
CHECK
_EQ
(
args
.
size
(),
2
)
<<
"The number of input arguments of a CONV 2D Transpose node should be 2."
;
const
auto
*
conv_2d_transpose_attr
=
call_node
->
attrs
.
as
<
Conv2DTransposeAttrs
>
();
const
auto
*
data_type
=
args
[
0
]
->
checked_type
().
as
<
TensorTypeNode
>
();
...
...
@@ -107,13 +107,13 @@ int64_t Conv2dTransposeMacCount(const Call& call_node) {
std
::
string
data_layout
=
conv_2d_transpose_attr
->
data_layout
;
int32_t
C_ind
=
Layout
(
data_layout
).
IndexOf
(
LayoutAxis
::
Get
(
'C'
));
int32_t
c_ind
=
Layout
(
data_layout
).
IndexOf
(
LayoutAxis
::
Get
(
'c'
));
CHECK
(
C_ind
!=
-
1
)
CHECK
_NE
(
C_ind
,
-
1
)
<<
"There is no input channel dimension."
;
int64_t
input_channel
=
static_cast
<
int64_t
>
(
data_shape
[
C_ind
].
as
<
IntImm
>
()
->
value
);
if
(
c_ind
!=
-
1
)
input_channel
*=
static_cast
<
int64_t
>
(
data_shape
[
c_ind
].
as
<
IntImm
>
()
->
value
);
Array
<
IndexExpr
>
kernel_size
=
conv_2d_transpose_attr
->
kernel_size
;
CHECK
(
kernel_size
.
size
()
==
2
)
CHECK
_EQ
(
kernel_size
.
size
(),
2
)
<<
"The dimension of the kernel in Conv 2D Transpose should be 2."
;
const
auto
*
expr
=
call_node
->
checked_type
().
as
<
TensorTypeNode
>
();
Array
<
IndexExpr
>
output_tensor
=
expr
->
shape
;
...
...
@@ -132,7 +132,7 @@ int64_t DenseMacCount(const Call& call_node) {
return
0
;
}
Array
<
Expr
>
args
=
call_node
->
args
;
CHECK
(
args
.
size
()
==
2
)
CHECK
_EQ
(
args
.
size
(),
2
)
<<
"The number of input arguments of a Dense node should be 2."
;
const
auto
*
data_type
=
args
[
0
]
->
checked_type
().
as
<
TensorTypeNode
>
();
const
auto
*
weight_type
=
args
[
1
]
->
checked_type
().
as
<
TensorTypeNode
>
();
...
...
@@ -144,12 +144,28 @@ int64_t DenseMacCount(const Call& call_node) {
int64_t
d2
=
static_cast
<
int64_t
>
(
data_shape
[
1
].
as
<
IntImm
>
()
->
value
);
int64_t
d3
=
static_cast
<
int64_t
>
(
weight_shape
[
0
].
as
<
IntImm
>
()
->
value
);
int64_t
d4
=
static_cast
<
int64_t
>
(
weight_shape
[
1
].
as
<
IntImm
>
()
->
value
);
CHECK
(
d2
==
d4
)
CHECK
_EQ
(
d2
,
d4
)
<<
"The dimensions of input arguments do not match."
;
int64_t
count
=
d1
*
d2
*
d3
;
return
count
;
}
int64_t
BatchMatmulMacCount
(
const
Call
&
call_node
)
{
if
(
!
call_node
->
checked_type_
.
defined
())
{
LOG
(
WARNING
)
<<
"The infer type pass should be called before the mac count pass"
;
return
0
;
}
Array
<
Expr
>
args
=
call_node
->
args
;
CHECK_EQ
(
args
.
size
(),
2
);
Array
<
IndexExpr
>
x_shape
=
args
[
0
]
->
checked_type
().
as
<
TensorTypeNode
>
()
->
shape
;
Array
<
IndexExpr
>
y_shape
=
args
[
1
]
->
checked_type
().
as
<
TensorTypeNode
>
()
->
shape
;
int64_t
batch
=
x_shape
[
0
].
as
<
IntImm
>
()
->
value
;
int64_t
m
=
x_shape
[
1
].
as
<
IntImm
>
()
->
value
;
int64_t
k
=
x_shape
[
2
].
as
<
IntImm
>
()
->
value
;
int64_t
n
=
y_shape
[
1
].
as
<
IntImm
>
()
->
value
;
return
batch
*
m
*
k
*
n
;
}
RELAY_REGISTER_OP
(
"nn.conv2d"
)
.
set_attr
<
FMacCount
>
(
"FMacCount"
,
ConvMacCount
);
...
...
@@ -159,14 +175,17 @@ RELAY_REGISTER_OP("nn.conv2d_transpose")
RELAY_REGISTER_OP
(
"nn.dense"
)
.
set_attr
<
FMacCount
>
(
"FMacCount"
,
DenseMacCount
);
RELAY_REGISTER_OP
(
"nn.batch_matmul"
)
.
set_attr
<
FMacCount
>
(
"FMacCount"
,
BatchMatmulMacCount
);
class
MacCounter
:
private
ExprVisitor
{
public
:
MacCounter
()
{
count_
=
0
;
}
static
int64_t
GetTotalMacNumber
(
const
Expr
&
expr
)
{
LOG
(
INFO
)
<<
"This pass only counts MACs in direct
CONV 2D
, "
<<
"
CONV 2D Transpose and Dense
ops"
;
LOG
(
INFO
)
<<
"This pass only counts MACs in direct
conv2d
, "
<<
"
conv2d_transpose, dense, and batch_matmul
ops"
;
MacCounter
counter
;
counter
(
expr
);
return
counter
.
count_
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment