Commit 2b3d2e21, authored Sep 22, 2017 by Tianqi Chen

[PASS] Improve GraphFuse to include five patterns (#26)

parent 2e9b6b99

Showing 9 changed files with 162 additions and 46 deletions
nnvm/docs/top.rst                              +10   -3
nnvm/include/nnvm/compiler/op_attr_types.h     +13   -6
nnvm/python/nnvm/compiler/registry.py          +17   -5
nnvm/python/nnvm/top/nn.py                     +13   -5
nnvm/python/nnvm/top/tensor.py                 +22  -18
nnvm/python/nnvm/top/transform.py               +1   -1
nnvm/src/compiler/graph_fuse.cc                +24   -7
nnvm/src/compiler/layout_transform.cc           +1   -1
nnvm/tests/python/compiler/test_op_fusion.py   +61   -0
nnvm/docs/top.rst

-NNVM Core Primitives
-====================
+NNVM Core Tensor Operators
+==========================

-**Level 1: Basic Ops**
+**Level 1: Basic Operators**
+This level enables fully connected multi-layer perceptron.

 .. autosummary::
    :nosignatures:
...
@@ -12,12 +13,14 @@ NNVM Core Primitives
    nnvm.symbol.sigmoid
    nnvm.symbol.exp
    nnvm.symbol.log
+   nnvm.symbol.sqrt
    nnvm.symbol.elemwise_add
    nnvm.symbol.elemwise_sub
    nnvm.symbol.elemwise_mul
    nnvm.symbol.elemwise_div
    nnvm.symbol.flatten
    nnvm.symbol.concatenate
    nnvm.symbol.expand_dims
+   nnvm.symbol.split
    nnvm.symbol.dropout
    nnvm.symbol.batch_norm
...
@@ -27,6 +30,8 @@ NNVM Core Primitives
 **Level 2: Convolutions**
+This level enables typical convnet models.

 .. autosummary::
    :nosignatures:
...
@@ -78,12 +83,14 @@ NNVM Core Primitives
 .. autofunction:: nnvm.symbol.sigmoid
 .. autofunction:: nnvm.symbol.exp
 .. autofunction:: nnvm.symbol.log
+.. autofunction:: nnvm.symbol.sqrt
 .. autofunction:: nnvm.symbol.elemwise_add
 .. autofunction:: nnvm.symbol.elemwise_sub
 .. autofunction:: nnvm.symbol.elemwise_mul
 .. autofunction:: nnvm.symbol.elemwise_div
 .. autofunction:: nnvm.symbol.flatten
 .. autofunction:: nnvm.symbol.concatenate
 .. autofunction:: nnvm.symbol.expand_dims
+.. autofunction:: nnvm.symbol.split
 .. autofunction:: nnvm.symbol.dropout
 .. autofunction:: nnvm.symbol.batch_norm
...
nnvm/include/nnvm/compiler/op_attr_types.h

...
@@ -25,16 +25,23 @@ using ::tvm::Tensor;
 using ::tvm::Schedule;

 /*! \brief operator pattern used in graph fusion */
-enum OpPatternKind : int {
+enum OpPatternKind {
   // Elementwise operation
   kElemWise = 0,
-  // Broadcast operation
+  // Broadcasting operator, can always map output axis to the input in order.
+  // for example :code:`out[i, ax1, j, ax2] = input[i, j]`.
+  // Note that the axis need to be in order so transpose is not a bcast operator.
   kBroadcast = 1,
-  // Complex operation, can fuse bcast in input/outputs
+  // Injective operator, can always injectively map output axis to a single input axis.
+  // All injective operator can still be safely fused to injective and reduction.
+  kInjective = 2,
+  // Communicative reduction operator.
+  kCommReduce = 3,
+  // Complex operation, can still fuse elemwise operations into its output.
   // but cannot chain another complex op
-  kComplex = 2,
-  // Extern operation, cannot fuse anything.
-  kExtern = 3
+  kOutEWiseFusable = 4,
+  // Opaque operation, cannot fuse anything.
+  kOpaque = 8
 };

 /*! \brief the operator pattern */
...
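To make the five pattern kinds concrete, here is a small illustrative sketch (not part of the commit) that pairs each new Python-side OpPattern constant with an example operator. relu, flatten, conv2d, and softmax take exactly these patterns in the registrations later in this diff; "broadcast_add" and "sum" are assumed examples for the two kinds whose registrations are not visible here.

# Illustrative mapping of the five pattern kinds to example operators.
from nnvm.compiler import OpPattern

examples = {
    "relu":          OpPattern.ELEMWISE,              # out[i] = f(in[i])
    "broadcast_add": OpPattern.BROADCAST,             # output axes map to input axes in order (assumed example)
    "flatten":       OpPattern.INJECTIVE,             # each output element reads exactly one input element
    "sum":           OpPattern.COMM_REDUCE,           # commutative reduction over axes (assumed example)
    "conv2d":        OpPattern.OUT_ELEMWISE_FUSABLE,  # elemwise ops can fuse into its output
    "softmax":       OpPattern.OPAQUE,                # nothing fuses across it
}

for name, pattern in sorted(examples.items(), key=lambda kv: kv[1]):
    print("%-14s pattern = %d" % (name, pattern))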
nnvm/python/nnvm/compiler/registry.py

...
@@ -3,12 +3,24 @@
 import tvm

 class OpPattern(object):
-    ELEM_WISE = 0
+    """Operator generic patterns
+
+    See Also
+    --------
+    top.tag : Contains explaination of the tag type.
+    """
+    # Elementwise operator
+    ELEMWISE = 0
+    # Broadcast operator
     BROADCAST = 1
-    # Complex means we can fuse elemwise to it
-    COMPLEX = 2
-    # Extern means the op is not fusable
-    EXTERN = 3
+    # Injective mapping
+    INJECTIVE = 2
+    # Comunication
+    COMM_REDUCE = 3
+    # Complex op, can still fuse ewise into it
+    OUT_ELEMWISE_FUSABLE = 4
+    # Not fusable opaque op
+    OPAQUE = 8

 _register_compute = tvm.get_global_func("nnvm._register_compute")
 _register_schedule = tvm.get_global_func("nnvm._register_schedule")
...
nnvm/python/nnvm/top/nn.py

...
@@ -16,8 +16,16 @@ def compute_relu(attrs, inputs, _):
     return topi.nn.relu(inputs[0])

 reg.register_schedule("relu", _fschedule_broadcast)
-reg.register_pattern("relu", OpPattern.ELEM_WISE)
+reg.register_pattern("relu", OpPattern.ELEMWISE)
+
+# leaky_relu
+@reg.register_compute("leaky_relu")
+def compute_relu(attrs, inputs, _):
+    """Compute definition of relu"""
+    return topi.nn.leaky_relu(inputs[0])
+
+reg.register_schedule("leaky_relu", _fschedule_broadcast)
+reg.register_pattern("leaky_relu", OpPattern.ELEMWISE)

 # flatten
 @reg.register_compute("flatten")
...
@@ -26,7 +34,7 @@ def compute_flatten(attrs, inputs, _):
     return topi.nn.flatten(inputs[0])

 reg.register_schedule("flatten", _fschedule_broadcast)
-reg.register_pattern("flatten", OpPattern.COMPLEX)
+reg.register_pattern("flatten", OpPattern.INJECTIVE)

 # softmax
...
@@ -46,7 +54,7 @@ def schedule_softmax(_, outs, target):
     return tvm.create_schedule([x.op for x in outs])

 # Mark softmax as extern as we do not fuse it in call cases
-reg.register_pattern("softmax", OpPattern.EXTERN)
+reg.register_pattern("softmax", OpPattern.OPAQUE)

 # dense
...
@@ -67,7 +75,7 @@ def schedule_dense(_, outs, target):
     return tvm.create_schedule([x.op for x in outs])

 # register extern for now, change me when fusion is enabled.
-reg.register_pattern("dense", OpPattern.EXTERN)
+reg.register_pattern("dense", OpPattern.OPAQUE)

 # conv
...
@@ -105,4 +113,4 @@ def schedule_conv2d(attrs, outs, target):
     # naive schedule
     return tvm.create_schedule([x.op for x in outs])

-reg.register_pattern("conv2d", OpPattern.COMPLEX)
+reg.register_pattern("conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
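Taken together, the three calls per operator above (register_compute, register_schedule, register_pattern) form the whole per-operator contract. A minimal sketch of how a new operator could be wired up under the new pattern scheme, using only the registration API shown in this diff. The operator name my_scale_shift, its compute, and its schedule are made up for illustration and assume a corresponding nnvm operator of that name has been declared on the C++ side.

import tvm
from nnvm.compiler import registry as reg
from nnvm.compiler import OpPattern

@reg.register_compute("my_scale_shift")
def compute_my_scale_shift(attrs, inputs, _):
    """Elementwise y = 2 * x + 1 over the single input tensor."""
    x = inputs[0]
    return tvm.compute(x.shape, lambda *i: x(*i) * 2.0 + 1.0)

def schedule_my_scale_shift(_, outs, target):
    """Naive schedule; a real registration would reuse the generic injective schedule."""
    return tvm.create_schedule([x.op for x in outs])

reg.register_schedule("my_scale_shift", schedule_my_scale_shift)

# Elementwise ops take the cheapest pattern, so GraphFuse can merge them
# freely with their neighbours.
reg.register_pattern("my_scale_shift", OpPattern.ELEMWISE)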
nnvm/python/nnvm/top/tensor.py

...
@@ -8,13 +8,15 @@ import topi.cuda
 from ..compiler import registry as reg
 from ..compiler import OpPattern

-def _schedule_broadcast(_, outs, target):
+def _schedule_injective(_, outs, target):
     """Generic schedule for binary bcast"""
     if target == "cuda":
-        return topi.cuda.schedule_elemwise(outs)
+        return topi.cuda.schedule_injective(outs)
     assert target.startswith("llvm")
     s = tvm.create_schedule([x.op for x in outs])
     x = outs[0]
+    tvm.schedule.AutoInlineInjective(s)
+    s[x].fuse(s[x].op.axis)
     return s

 def _compute_binary_scalar(f):
...
@@ -42,89 +44,91 @@ def _compute_binary(f):
     return _compute

-_fschedule_broadcast = tvm.convert(_schedule_broadcast)
+_fschedule_injective = tvm.convert(_schedule_injective)
+_fschedule_broadcast = _fschedule_injective
+_fschedule_elemwise = _fschedule_injective

 # copy
 reg.register_compute("copy", _compute_unary(topi.identity))
-reg.register_pattern("copy", OpPattern.ELEM_WISE)
+reg.register_pattern("copy", OpPattern.ELEMWISE)
 reg.register_schedule("copy", _fschedule_broadcast)

 # exp
 reg.register_compute("exp", _compute_unary(topi.exp))
-reg.register_pattern("exp", OpPattern.ELEM_WISE)
+reg.register_pattern("exp", OpPattern.ELEMWISE)
 reg.register_schedule("exp", _fschedule_broadcast)

 # sqrt
 reg.register_compute("sqrt", _compute_unary(topi.sqrt))
-reg.register_pattern("sqrt", OpPattern.ELEM_WISE)
+reg.register_pattern("sqrt", OpPattern.ELEMWISE)
 reg.register_schedule("sqrt", _fschedule_broadcast)

 # log
 reg.register_compute("log", _compute_unary(topi.log))
-reg.register_pattern("log", OpPattern.ELEM_WISE)
+reg.register_pattern("log", OpPattern.ELEMWISE)
 reg.register_schedule("log", _fschedule_broadcast)

 # tanh
 reg.register_compute("tanh", _compute_unary(topi.tanh))
-reg.register_pattern("tanh", OpPattern.ELEM_WISE)
+reg.register_pattern("tanh", OpPattern.ELEMWISE)
 reg.register_schedule("tanh", _fschedule_broadcast)

 # negative
 reg.register_compute("negative", _compute_unary(topi.negative))
-reg.register_pattern("negative", OpPattern.ELEM_WISE)
+reg.register_pattern("negative", OpPattern.ELEMWISE)
 reg.register_schedule("negative", _fschedule_broadcast)

 # sigmoid
 reg.register_compute("sigmoid", _compute_unary(topi.sigmoid))
-reg.register_pattern("sigmoid", OpPattern.ELEM_WISE)
+reg.register_pattern("sigmoid", OpPattern.ELEMWISE)
 reg.register_schedule("sigmoid", _fschedule_broadcast)

 # add_scalar
 reg.register_compute("__add_scalar__", _compute_binary_scalar(lambda x, y: x + y))
-reg.register_pattern("__add_scalar__", OpPattern.ELEM_WISE)
+reg.register_pattern("__add_scalar__", OpPattern.ELEMWISE)
 reg.register_schedule("__add_scalar__", _fschedule_broadcast)

 # sub_calar
 reg.register_compute("__sub_scalar__", _compute_binary_scalar(lambda x, y: x - y))
-reg.register_pattern("__sub_scalar__", OpPattern.ELEM_WISE)
+reg.register_pattern("__sub_scalar__", OpPattern.ELEMWISE)
 reg.register_schedule("__sub_scalar__", _fschedule_broadcast)

 # rsub_scalar
 reg.register_compute("__rsub_scalar__", _compute_binary_scalar(lambda x, y: y - x))
-reg.register_pattern("__rsub_scalar__", OpPattern.ELEM_WISE)
+reg.register_pattern("__rsub_scalar__", OpPattern.ELEMWISE)
 reg.register_schedule("__rsub_scalar__", _fschedule_broadcast)

 # mul_scalar
 reg.register_compute("__mul_scalar__", _compute_binary_scalar(lambda x, y: x * y))
-reg.register_pattern("__mul_scalar__", OpPattern.ELEM_WISE)
+reg.register_pattern("__mul_scalar__", OpPattern.ELEMWISE)
 reg.register_schedule("__mul_scalar__", _fschedule_broadcast)

 # div_scalar
 reg.register_compute("__div_scalar__", _compute_binary_scalar(lambda x, y: x / y))
-reg.register_pattern("__div_scalar__", OpPattern.ELEM_WISE)
+reg.register_pattern("__div_scalar__", OpPattern.ELEMWISE)
 reg.register_schedule("__div_scalar__", _fschedule_broadcast)

 # rdiv_scalar
 reg.register_compute("__rdiv_scalar__", _compute_binary_scalar(lambda x, y: y / x))
-reg.register_pattern("__rdiv_scalar__", OpPattern.ELEM_WISE)
+reg.register_pattern("__rdiv_scalar__", OpPattern.ELEMWISE)
 reg.register_schedule("__rdiv_scalar__", _fschedule_broadcast)

 # pow_scalar
 reg.register_compute("__pow_scalar__", _compute_binary_scalar(tvm.power))
-reg.register_pattern("__pow_scalar__", OpPattern.ELEM_WISE)
+reg.register_pattern("__pow_scalar__", OpPattern.ELEMWISE)
 reg.register_schedule("__pow_scalar__", _fschedule_broadcast)

 # rpow_scalar
 reg.register_compute("__rpow_scalar__", _compute_binary_scalar(lambda x, y: tvm.power(y, x)))
-reg.register_pattern("__rpow_scalar__", OpPattern.ELEM_WISE)
+reg.register_pattern("__rpow_scalar__", OpPattern.ELEMWISE)
 reg.register_schedule("__rpow_scalar__", _fschedule_broadcast)

 # elemwise_add
...
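The new generic injective schedule is the main CPU-side change here: after creating the schedule it inlines every injective stage and fuses the output axes into one loop. A minimal standalone sketch follows; the shapes and compute definitions are assumed for illustration, and it simply mirrors the tvm calls used in _schedule_injective above.

import tvm

# Hypothetical injective pipeline: two chained elementwise stages on a 2-D tensor.
n, m = 8, 16
A = tvm.placeholder((n, m), name="A")
B = tvm.compute((n, m), lambda i, j: A[i, j] * 2.0, name="B")
C = tvm.compute((n, m), lambda i, j: B[i, j] + 1.0, name="C")

s = tvm.create_schedule([C.op])
tvm.schedule.AutoInlineInjective(s)   # inline B into C, as _schedule_injective does
s[C].fuse(*s[C].op.axis)              # collapse C's loops into a single fused loop
print(tvm.lower(s, [A, C], simple_mode=True))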
nnvm/python/nnvm/top/transform.py

...
@@ -37,5 +37,5 @@ def compute_reshape(attrs, inputs, out_info):
     oshape = out_info[0].shape
     x = inputs[0]
     return tvm.compute(oshape, lambda *i: x(_flatten_index(i, oshape)))
-reg.register_pattern("reshape", OpPattern.COMPLEX)
+reg.register_pattern("reshape", OpPattern.INJECTIVE)
 reg.register_schedule("reshape", _fschedule_broadcast)
nnvm/src/compiler/graph_fuse.cc

...
@@ -71,7 +71,7 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) {
     ref_count[e.node_id] += 2;
   }
   // Pattern for the subgraph
-  std::vector<TOpPattern> pattern_vec(idx.num_nodes(), kExtern);
+  std::vector<TOpPattern> pattern_vec(idx.num_nodes(), kOpaque);
   // Whether node can be fused to parent.
   std::vector<FuseRule> fuse_vec(idx.num_nodes(), FuseRule::kUknown);
   // Master node id of fusion segment.
...
@@ -84,19 +84,21 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) {
     if (inode.source->is_variable()) {
       fuse_vec[nid] = FuseRule::kRealize; continue;
     }
-    TOpPattern pt = op_pattern.get(inode.source->op(), kExtern);
+    TOpPattern pt = op_pattern.get(inode.source->op(), kOpaque);
+
     if (pt <= kBroadcast) {
+      // Try to check if we can fuse to the master.
       int chosen_master = -1;
       bool ewise = inode.source->num_outputs() == 1;
       for (const auto& e : inode.inputs) {
         if (fuse_vec[e.node_id] == FuseRule::kUknown) {
           TOpPattern ipt = pattern_vec[e.node_id];
           if (ipt != kElemWise) ewise = false;
-          if (ipt <= kBroadcast) {
+          if (ipt <= kInjective) {
             fuse_vec[e.node_id] = FuseRule::kFuseToMaster;
-          } else if (ipt == kComplex && chosen_master == -1 &&
-                     shape_vec[idx.entry_id(nid, 0)] == shape_vec[idx.entry_id(e)]) {
+          } else if (ipt == kOutEWiseFusable && chosen_master == -1 &&
+                     shape_vec[idx.entry_id(nid, 0)] == shape_vec[idx.entry_id(e)]) {
             chosen_master = master_vec[e.node_id];
             fuse_vec[e.node_id] = FuseRule::kFuseToMaster;
           } else {
...
@@ -111,11 +113,27 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) {
         }
       }
       master_vec[nid] = chosen_master;
       if (chosen_master != -1) {
-        pt = kComplex;
+        pt = kOutEWiseFusable;
       } else {
         pt = ewise ? kElemWise : kBroadcast;
       }
+    } else if (pt == kInjective || pt == kCommReduce) {
+      // fuse to the comm reduce or injective
+      for (const auto& e : inode.inputs) {
+        if (fuse_vec[e.node_id] == FuseRule::kUknown) {
+          TOpPattern ipt = pattern_vec[e.node_id];
+          if (ipt <= kInjective) {
+            fuse_vec[e.node_id] = FuseRule::kFuseToMaster;
+          } else {
+            fuse_vec[e.node_id] = FuseRule::kRealize;
+          }
+        }
+      }
+      if (pt == kCommReduce) {
+        master_vec[nid] = nid;
+      }
     } else {
       // realize
       master_vec[nid] = nid;
       for (const auto& e : inode.inputs) {
         if (fuse_vec[e.node_id] == FuseRule::kUknown) {
...
@@ -136,7 +154,6 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) {
     }
   }
-
   // point to the group root id of each node
   std::vector<int> group_vec(idx.num_nodes(), -1);
   for (uint32_t i = idx.num_nodes(); i != 0; --i) {
...
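The partitioning logic above reduces to a small decision table over the consumer's pattern (pt) and each producer's pattern (ipt). The following is a rough Python paraphrase written only to summarize the C++ rules in graph_fuse.cc; FUSE and REALIZE stand in for FuseRule::kFuseToMaster and FuseRule::kRealize, and the constant names mirror OpPatternKind.

# Pattern kinds with the same ordering and values as OpPatternKind above.
ELEMWISE, BROADCAST, INJECTIVE, COMM_REDUCE, OUT_EWISE_FUSABLE, OPAQUE = 0, 1, 2, 3, 4, 8

def fuse_rule(pt, ipt, shapes_match, have_master):
    """Rough paraphrase of GraphFusePartition's per-input decision.

    pt           -- pattern of the consumer node being visited
    ipt          -- pattern of the producer feeding one of its inputs
    shapes_match -- consumer output shape equals the input entry shape
    have_master  -- a conv2d-like master has already been chosen
    Returns "FUSE" or "REALIZE" for that edge.
    """
    if pt <= BROADCAST:
        if ipt <= INJECTIVE:
            return "FUSE"
        if ipt == OUT_EWISE_FUSABLE and not have_master and shapes_match:
            return "FUSE"      # the elemwise/broadcast chain fuses into the conv-like master
        return "REALIZE"
    if pt in (INJECTIVE, COMM_REDUCE):
        return "FUSE" if ipt <= INJECTIVE else "REALIZE"
    return "REALIZE"           # opaque consumers realize all of their inputs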
nnvm/src/compiler/layout_transform.cc

...
@@ -52,7 +52,7 @@ nnvm::Graph LayoutTransform(nnvm::Graph src) {
   // use op pattern to decide whether an op is map
   auto is_map_op = [&](size_t nid) {
-    TOpPattern pt = op_pattern.get(idx[nid].source->op(), kExtern);
+    TOpPattern pt = op_pattern.get(idx[nid].source->op(), kOpaque);
     bool is_map = (pt <= kBroadcast);
     if (pt == kBroadcast) {
       for (const auto& e : idx[nid].inputs) {
...
nnvm/tests/python/compiler/test_op_fusion.py  (new file, mode 100644)

import nnvm
import numpy as np
import tvm
import topi
from nnvm import symbol as sym
from nnvm.compiler import graph_util, graph_attr
from nnvm.testing.config import test_ctx_list

def test_ewise_injective():
    x = sym.Variable("x")
    y = x * 2
    y = sym.flatten(y) + 1
    dshape = (10, 2, 3)
    shape_dict = {"x": dshape}
    dtype = "float32"
    target = "llvm"
    for target, ctx in test_ctx_list():
        graph, lib, _ = nnvm.compiler.build(y, target, shape_dict)
        assert graph.index.num_nodes == 2
        m = nnvm.runtime.create(graph, lib, ctx)
        x_np = np.random.uniform(size=dshape).astype(dtype)
        m.run(x=x_np)
        out = m.get_output(0, tvm.nd.empty((10, 6)))
        np.testing.assert_allclose(
            out.asnumpy(), x_np.reshape(out.shape) * 2 + 1,
            atol=1e-5, rtol=1e-5)


def test_conv_ewise_injective():
    x = sym.Variable("x")
    y = sym.conv2d(x, channels=32, kernel_size=(3, 3), groups=32,
                   name="y", padding=(1, 1))
    y = sym.flatten(y + 1) + 1
    dtype = "float32"
    dshape = (1, 32, 18, 18)
    kshape = (32, 1, 3, 3)
    oshape = (1, 32 * 18 * 18)
    shape_dict = {"x": dshape}
    for target, ctx in test_ctx_list():
        graph, lib, _ = nnvm.compiler.build(y, target, shape_dict)
        m = nnvm.runtime.create(graph, lib, ctx)
        # print(graph.ir(join_entry_attrs=["shape"]))
        assert graph.index.num_nodes == 5
        # set input
        data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
        kernel = tvm.nd.array(np.random.uniform(size=kshape).astype(dtype))
        bias = tvm.nd.array(np.random.uniform(size=kshape[0]).astype(dtype))
        m.run(x=data, y_weight=kernel, y_bias=bias)
        # get output
        out = m.get_output(0, tvm.nd.empty(oshape, dtype))
        c_np = topi.testing.depthwise_conv2d_python_nchw(
            data.asnumpy(), kernel.asnumpy(), (1, 1), 'SAME')
        c_np = c_np + bias.asnumpy().reshape(kshape[0], 1, 1) + 1
        c_np = c_np.reshape(c_np.shape[0], np.prod(c_np.shape[1:])) + 1
        np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)


if __name__ == "__main__":
    test_ewise_injective()
    test_conv_ewise_injective()
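The node-count assertions are how these tests verify that fusion actually happened: after GraphFuse, graph.index.num_nodes counts fused operators plus variables, so the whole elementwise/injective chain collapses to a single fused node next to the input. A small sketch of inspecting this interactively, using the same APIs as the test above (the graph.ir call mirrors the commented-out line in test_conv_ewise_injective) and assuming an LLVM-enabled TVM build:

import nnvm
from nnvm import symbol as sym

# Build the same small elementwise + flatten chain as test_ewise_injective
# and look at the fused graph.
x = sym.Variable("x")
y = sym.flatten(x * 2) + 1

graph, lib, _ = nnvm.compiler.build(y, "llvm", {"x": (10, 2, 3)})
print(graph.index.num_nodes)                 # 2: the variable x plus one fused op
print(graph.ir(join_entry_attrs=["shape"]))  # textual IR with per-entry shapes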