Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
e20ef0d4
Commit
e20ef0d4
authored
Feb 21, 2019
by
Marcus Shawcroft
Committed by
Tianqi Chen
Feb 21, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Fix pylint 2.2.2 gripes. (#2642)
parent
81334be3
Show whitespace changes
Inline
Side-by-side
Showing
70 changed files
with
186 additions
and
264 deletions
+186
-264
nnvm/python/nnvm/_base.py
+1
-1
nnvm/python/nnvm/attribute.py
+0
-1
nnvm/python/nnvm/compiler/compile_engine.py
+0
-2
nnvm/python/nnvm/frontend/caffe2.py
+1
-3
nnvm/python/nnvm/frontend/coreml.py
+14
-17
nnvm/python/nnvm/frontend/darknet.py
+0
-2
nnvm/python/nnvm/frontend/keras.py
+16
-20
nnvm/python/nnvm/frontend/mxnet.py
+1
-1
nnvm/python/nnvm/frontend/onnx_caffe2_utils.py
+0
-1
nnvm/python/nnvm/frontend/tensorflow.py
+1
-3
nnvm/python/nnvm/frontend/util/tensorflow_parser.py
+2
-0
nnvm/python/nnvm/symbol.py
+3
-14
nnvm/python/nnvm/testing/inception_v3.py
+1
-2
nnvm/python/nnvm/testing/yolo_detection.py
+0
-2
nnvm/python/nnvm/top/attr_dict.py
+3
-4
python/tvm/_ffi/base.py
+0
-1
python/tvm/_ffi/function.py
+2
-3
python/tvm/_ffi/node_generic.py
+8
-8
python/tvm/arith.py
+2
-2
python/tvm/autotvm/measure/executor.py
+2
-3
python/tvm/autotvm/record.py
+4
-4
python/tvm/autotvm/task/space.py
+2
-3
python/tvm/autotvm/task/task.py
+8
-8
python/tvm/autotvm/tuner/tuner.py
+1
-1
python/tvm/container.py
+1
-1
python/tvm/contrib/nvcc.py
+2
-2
python/tvm/contrib/verilog.py
+0
-1
python/tvm/hybrid/parser.py
+5
-5
python/tvm/hybrid/util.py
+1
-1
python/tvm/intrin.py
+1
-1
python/tvm/make.py
+1
-1
python/tvm/ndarray.py
+0
-1
python/tvm/relay/_parser.py
+8
-10
python/tvm/relay/adt.py
+1
-1
python/tvm/relay/backend/compile_engine.py
+0
-2
python/tvm/relay/backend/interpreter.py
+0
-1
python/tvm/relay/build_module.py
+1
-2
python/tvm/relay/frontend/caffe2.py
+1
-4
python/tvm/relay/frontend/common.py
+0
-1
python/tvm/relay/frontend/coreml.py
+13
-16
python/tvm/relay/frontend/keras.py
+18
-20
python/tvm/relay/frontend/mxnet.py
+8
-8
python/tvm/relay/frontend/onnx.py
+0
-1
python/tvm/relay/frontend/tensorflow.py
+1
-3
python/tvm/relay/frontend/tflite.py
+12
-15
python/tvm/relay/op/nn/_nn.py
+6
-6
python/tvm/relay/op/op_attrs.py
+3
-4
python/tvm/relay/testing/inception_v3.py
+1
-2
python/tvm/relay/ty.py
+0
-1
python/tvm/rpc/proxy.py
+1
-1
python/tvm/rpc/tornado_util.py
+0
-1
python/tvm/rpc/tracker.py
+1
-1
python/tvm/schedule.py
+0
-3
python/tvm/stmt.py
+1
-1
python/tvm/tensor.py
+1
-4
topi/python/topi/arm_cpu/bitserial_conv2d.py
+1
-1
topi/python/topi/arm_cpu/conv2d.py
+1
-1
topi/python/topi/cuda/conv2d.py
+1
-2
topi/python/topi/cuda/conv2d_winograd.py
+2
-2
topi/python/topi/cuda/reduction.py
+1
-1
topi/python/topi/nn/bitserial_conv2d.py
+2
-2
topi/python/topi/nn/conv2d.py
+2
-3
topi/python/topi/testing/upsampling_python.py
+1
-2
topi/python/topi/x86/bitserial_conv2d.py
+1
-1
topi/python/topi/x86/conv2d.py
+2
-3
vta/python/vta/environment.py
+1
-2
vta/python/vta/graph.py
+1
-1
vta/python/vta/intrin.py
+1
-1
vta/python/vta/ir_pass.py
+7
-12
vta/python/vta/top/vta_conv2d.py
+2
-3
No files found.
nnvm/python/nnvm/_base.py
View file @
e20ef0d4
...
@@ -31,7 +31,7 @@ else:
...
@@ -31,7 +31,7 @@ else:
class
NNVMError
(
Exception
):
class
NNVMError
(
Exception
):
"""Error that will be throwed by all nnvm functions"""
"""Error that will be throwed by all nnvm functions"""
pass
def
_load_lib
():
def
_load_lib
():
"""Load libary by searching possible path."""
"""Load libary by searching possible path."""
...
...
nnvm/python/nnvm/attribute.py
View file @
e20ef0d4
...
@@ -42,7 +42,6 @@ class AttrScope(object):
...
@@ -42,7 +42,6 @@ class AttrScope(object):
if
attr
:
if
attr
:
ret
.
update
(
attr
)
ret
.
update
(
attr
)
return
ret
return
ret
else
:
return
attr
return
attr
def
__enter__
(
self
):
def
__enter__
(
self
):
...
...
nnvm/python/nnvm/compiler/compile_engine.py
View file @
e20ef0d4
...
@@ -23,13 +23,11 @@ class GraphKey(tvm.node.NodeBase):
...
@@ -23,13 +23,11 @@ class GraphKey(tvm.node.NodeBase):
@tvm.register_node
@tvm.register_node
class
GraphCacheEntry
(
tvm
.
node
.
NodeBase
):
class
GraphCacheEntry
(
tvm
.
node
.
NodeBase
):
"""CacheEntry of compilation into a TVM Function"""
"""CacheEntry of compilation into a TVM Function"""
pass
@tvm.register_node
@tvm.register_node
class
GraphFunc
(
tvm
.
node
.
NodeBase
):
class
GraphFunc
(
tvm
.
node
.
NodeBase
):
"""Compiled result of a graph into a TVM Function"""
"""Compiled result of a graph into a TVM Function"""
pass
class
Engine
(
object
):
class
Engine
(
object
):
...
...
nnvm/python/nnvm/frontend/caffe2.py
View file @
e20ef0d4
...
@@ -73,7 +73,6 @@ class Caffe2OpConverter(object):
...
@@ -73,7 +73,6 @@ class Caffe2OpConverter(object):
if
hasattr
(
cls
,
'_impl'
):
if
hasattr
(
cls
,
'_impl'
):
return
getattr
(
cls
,
'_impl'
)
return
getattr
(
cls
,
'_impl'
)
else
:
raise
NotImplementedError
(
'{} not implemented'
.
format
(
raise
NotImplementedError
(
'{} not implemented'
.
format
(
cls
.
__name__
))
cls
.
__name__
))
...
@@ -175,9 +174,8 @@ class Concat(Caffe2OpConverter):
...
@@ -175,9 +174,8 @@ class Concat(Caffe2OpConverter):
order
=
order
if
isinstance
(
order
,
str
)
else
order
.
decode
(
'UTF-8'
)
order
=
order
if
isinstance
(
order
,
str
)
else
order
.
decode
(
'UTF-8'
)
if
order
==
'NCHW'
:
if
order
==
'NCHW'
:
return
1
return
1
el
if
order
==
'NHWC'
:
if
order
==
'NHWC'
:
return
3
return
3
else
:
raise
RuntimeError
(
raise
RuntimeError
(
"Unsupported storage order: {} in caffe2"
.
format
(
order
))
"Unsupported storage order: {} in caffe2"
.
format
(
order
))
...
...
nnvm/python/nnvm/frontend/coreml.py
View file @
e20ef0d4
...
@@ -98,33 +98,33 @@ def ActivationParams(op, insym, symtab):
...
@@ -98,33 +98,33 @@ def ActivationParams(op, insym, symtab):
par
=
getattr
(
op
,
whichActivation
)
par
=
getattr
(
op
,
whichActivation
)
if
whichActivation
==
'linear'
:
if
whichActivation
==
'linear'
:
return
_sym
.
__add_scalar__
(
_sym
.
__mul_scalar__
(
insym
,
scalar
=
par
.
alpha
),
scalar
=
par
.
beta
)
return
_sym
.
__add_scalar__
(
_sym
.
__mul_scalar__
(
insym
,
scalar
=
par
.
alpha
),
scalar
=
par
.
beta
)
el
if
whichActivation
==
'ReLU'
:
if
whichActivation
==
'ReLU'
:
return
_sym
.
relu
(
insym
)
return
_sym
.
relu
(
insym
)
el
if
whichActivation
==
'leakyReLU'
:
if
whichActivation
==
'leakyReLU'
:
return
_sym
.
leaky_relu
(
insym
,
alpha
=
par
.
alpha
)
return
_sym
.
leaky_relu
(
insym
,
alpha
=
par
.
alpha
)
el
if
whichActivation
==
'thresholdedReLU'
:
if
whichActivation
==
'thresholdedReLU'
:
alpha_tensor
=
_sym
.
full_like
(
insym
,
fill_value
=
float
(
par
.
alpha
))
alpha_tensor
=
_sym
.
full_like
(
insym
,
fill_value
=
float
(
par
.
alpha
))
return
_sym
.
elemwise_mul
(
insym
,
_sym
.
greater
(
insym
,
alpha_tensor
))
return
_sym
.
elemwise_mul
(
insym
,
_sym
.
greater
(
insym
,
alpha_tensor
))
el
if
whichActivation
==
'PReLU'
:
if
whichActivation
==
'PReLU'
:
return
_sym
.
prelu
(
insym
,
alpha
=
par
.
alpha
)
return
_sym
.
prelu
(
insym
,
alpha
=
par
.
alpha
)
el
if
whichActivation
==
'tanh'
:
if
whichActivation
==
'tanh'
:
return
_sym
.
tanh
(
insym
)
return
_sym
.
tanh
(
insym
)
el
if
whichActivation
==
'scaledTanh'
:
if
whichActivation
==
'scaledTanh'
:
return
_sym
.
__mul_scalar__
(
_sym
.
tanh
(
_sym
.
__mul_scalar__
(
return
_sym
.
__mul_scalar__
(
_sym
.
tanh
(
_sym
.
__mul_scalar__
(
insym
,
scalar
=
par
.
beta
)),
scalar
=
par
.
alpha
)
insym
,
scalar
=
par
.
beta
)),
scalar
=
par
.
alpha
)
el
if
whichActivation
==
'sigmoid'
:
if
whichActivation
==
'sigmoid'
:
return
_sym
.
sigmoid
(
insym
)
return
_sym
.
sigmoid
(
insym
)
el
if
whichActivation
==
'sigmoidHard'
:
if
whichActivation
==
'sigmoidHard'
:
transformX
=
(
par
.
alpha
*
insym
)
+
par
.
beta
transformX
=
(
par
.
alpha
*
insym
)
+
par
.
beta
return
_sym
.
clip
(
transformX
,
a_min
=
0
,
a_max
=
1
)
return
_sym
.
clip
(
transformX
,
a_min
=
0
,
a_max
=
1
)
el
if
whichActivation
==
'ELU'
:
if
whichActivation
==
'ELU'
:
return
_sym
.
__mul_scalar__
(
_sym
.
__add_scalar__
(
return
_sym
.
__mul_scalar__
(
_sym
.
__add_scalar__
(
_sym
.
exp
(
insym
),
scalar
=-
1
),
scalar
=
par
.
alpha
)
_sym
.
exp
(
insym
),
scalar
=-
1
),
scalar
=
par
.
alpha
)
el
if
whichActivation
==
'softsign'
:
if
whichActivation
==
'softsign'
:
return
insym
/
(
1
+
(
_sym
.
relu
(
insym
)
+
_sym
.
relu
(
_sym
.
negative
(
insym
))))
return
insym
/
(
1
+
(
_sym
.
relu
(
insym
)
+
_sym
.
relu
(
_sym
.
negative
(
insym
))))
el
if
whichActivation
==
'softplus'
:
if
whichActivation
==
'softplus'
:
return
_sym
.
log
(
_sym
.
__add_scalar__
(
_sym
.
exp
(
insym
),
scalar
=
1
))
return
_sym
.
log
(
_sym
.
__add_scalar__
(
_sym
.
exp
(
insym
),
scalar
=
1
))
el
if
whichActivation
==
'parametricSoftplus'
:
if
whichActivation
==
'parametricSoftplus'
:
alpha
=
list
(
par
.
alpha
.
floatValue
)
alpha
=
list
(
par
.
alpha
.
floatValue
)
beta
=
list
(
par
.
alpha
.
floatValue
)
beta
=
list
(
par
.
alpha
.
floatValue
)
if
len
(
alpha
)
==
1
:
if
len
(
alpha
)
==
1
:
...
@@ -136,7 +136,6 @@ def ActivationParams(op, insym, symtab):
...
@@ -136,7 +136,6 @@ def ActivationParams(op, insym, symtab):
betasym
=
symtab
.
new_const
(
beta
)
betasym
=
symtab
.
new_const
(
beta
)
return
_sym
.
broadcast_mul
(
_sym
.
log
(
_sym
.
broadcast_add
(
return
_sym
.
broadcast_mul
(
_sym
.
log
(
_sym
.
broadcast_add
(
_sym
.
exp
(
insym
),
betasym
)),
alphasym
)
_sym
.
exp
(
insym
),
betasym
)),
alphasym
)
else
:
raise
NotImplementedError
(
'
%
s not implemented'
%
whichActivation
)
raise
NotImplementedError
(
'
%
s not implemented'
%
whichActivation
)
def
ScaleLayerParams
(
op
,
insym
,
symtab
):
def
ScaleLayerParams
(
op
,
insym
,
symtab
):
...
@@ -157,9 +156,8 @@ def PoolingLayerParams(op, insym, symtab):
...
@@ -157,9 +156,8 @@ def PoolingLayerParams(op, insym, symtab):
if
op
.
globalPooling
:
if
op
.
globalPooling
:
if
op
.
type
==
0
:
if
op
.
type
==
0
:
return
_sym
.
global_max_pool2d
(
insym
)
return
_sym
.
global_max_pool2d
(
insym
)
el
if
op
.
type
==
1
:
if
op
.
type
==
1
:
return
_sym
.
global_avg_pool2d
(
insym
)
return
_sym
.
global_avg_pool2d
(
insym
)
else
:
raise
NotImplementedError
(
"Only max and average pooling implemented"
)
raise
NotImplementedError
(
"Only max and average pooling implemented"
)
else
:
else
:
...
@@ -190,9 +188,8 @@ def PoolingLayerParams(op, insym, symtab):
...
@@ -190,9 +188,8 @@ def PoolingLayerParams(op, insym, symtab):
if
op
.
type
==
0
:
if
op
.
type
==
0
:
return
_sym
.
max_pool2d
(
insym
,
**
params
)
return
_sym
.
max_pool2d
(
insym
,
**
params
)
el
if
op
.
type
==
1
:
if
op
.
type
==
1
:
return
_sym
.
avg_pool2d
(
insym
,
**
params
)
return
_sym
.
avg_pool2d
(
insym
,
**
params
)
else
:
raise
NotImplementedError
(
"Only max and average pooling implemented"
)
raise
NotImplementedError
(
"Only max and average pooling implemented"
)
def
SoftmaxLayerParams
(
op
,
insym
,
symtab
):
def
SoftmaxLayerParams
(
op
,
insym
,
symtab
):
...
...
nnvm/python/nnvm/frontend/darknet.py
View file @
e20ef0d4
...
@@ -921,8 +921,6 @@ class GraphProto(object):
...
@@ -921,8 +921,6 @@ class GraphProto(object):
if
layer_num
!=
self
.
net
.
n
-
1
:
if
layer_num
!=
self
.
net
.
n
-
1
:
self
.
_outs
.
insert
(
0
,
sym
)
self
.
_outs
.
insert
(
0
,
sym
)
return
def
from_darknet
(
self
):
def
from_darknet
(
self
):
"""To convert the darknet symbol to nnvm symbols."""
"""To convert the darknet symbol to nnvm symbols."""
for
i
in
range
(
self
.
net
.
n
):
for
i
in
range
(
self
.
net
.
n
):
...
...
nnvm/python/nnvm/frontend/keras.py
View file @
e20ef0d4
...
@@ -47,34 +47,33 @@ def _convert_activation(insym, keras_layer, _):
...
@@ -47,34 +47,33 @@ def _convert_activation(insym, keras_layer, _):
beta
=
keras_layer
.
beta
if
hasattr
(
keras_layer
,
"beta"
)
else
0
beta
=
keras_layer
.
beta
if
hasattr
(
keras_layer
,
"beta"
)
else
0
return
_sym
.
__add_scalar__
(
_sym
.
__mul_scalar__
(
insym
,
\
return
_sym
.
__add_scalar__
(
_sym
.
__mul_scalar__
(
insym
,
\
scalar
=
alpha
),
scalar
=
beta
)
scalar
=
alpha
),
scalar
=
beta
)
el
if
act_type
==
'softmax'
:
if
act_type
==
'softmax'
:
return
_sym
.
softmax
(
insym
,
axis
=
1
)
return
_sym
.
softmax
(
insym
,
axis
=
1
)
el
if
act_type
==
'sigmoid'
:
if
act_type
==
'sigmoid'
:
return
_sym
.
sigmoid
(
insym
)
return
_sym
.
sigmoid
(
insym
)
el
if
act_type
==
'tanh'
:
if
act_type
==
'tanh'
:
return
_sym
.
tanh
(
insym
)
return
_sym
.
tanh
(
insym
)
el
if
act_type
==
'relu'
:
if
act_type
==
'relu'
:
return
_sym
.
relu
(
insym
)
return
_sym
.
relu
(
insym
)
el
if
act_type
==
'softplus'
:
if
act_type
==
'softplus'
:
return
_sym
.
log
(
_sym
.
__add_scalar__
(
_sym
.
exp
(
insym
),
scalar
=
1
))
return
_sym
.
log
(
_sym
.
__add_scalar__
(
_sym
.
exp
(
insym
),
scalar
=
1
))
el
if
act_type
==
'elu'
:
if
act_type
==
'elu'
:
alpha
=
keras_layer
.
alpha
if
hasattr
(
keras_layer
,
"alpha"
)
else
1
alpha
=
keras_layer
.
alpha
if
hasattr
(
keras_layer
,
"alpha"
)
else
1
return
_get_elu
(
insym
,
alpha
)
return
_get_elu
(
insym
,
alpha
)
el
if
act_type
==
'selu'
:
if
act_type
==
'selu'
:
# Alpha, Gamma values, obtained from https://arxiv.org/abs/1706.02515
# Alpha, Gamma values, obtained from https://arxiv.org/abs/1706.02515
alpha
=
keras_layer
.
alpha
if
hasattr
(
keras_layer
,
"alpha"
)
\
alpha
=
keras_layer
.
alpha
if
hasattr
(
keras_layer
,
"alpha"
)
\
else
1.6732632423543772848170429916717
else
1.6732632423543772848170429916717
gamma
=
keras_layer
.
gamma
if
hasattr
(
keras_layer
,
"gamma"
)
\
gamma
=
keras_layer
.
gamma
if
hasattr
(
keras_layer
,
"gamma"
)
\
else
1.0507009873554804934193349852946
else
1.0507009873554804934193349852946
return
gamma
*
_get_elu
(
insym
,
alpha
)
return
gamma
*
_get_elu
(
insym
,
alpha
)
el
if
act_type
==
'relu6'
:
if
act_type
==
'relu6'
:
return
_sym
.
clip
(
insym
,
a_min
=
0
,
a_max
=
6
)
return
_sym
.
clip
(
insym
,
a_min
=
0
,
a_max
=
6
)
el
if
act_type
==
'softsign'
:
if
act_type
==
'softsign'
:
return
insym
/
(
1
+
(
_sym
.
relu
(
insym
)
+
_sym
.
relu
(
_sym
.
negative
(
insym
))))
return
insym
/
(
1
+
(
_sym
.
relu
(
insym
)
+
_sym
.
relu
(
_sym
.
negative
(
insym
))))
el
if
act_type
==
'hard_sigmoid'
:
if
act_type
==
'hard_sigmoid'
:
transformX
=
(
0.2
*
insym
)
+
0.5
transformX
=
(
0.2
*
insym
)
+
0.5
return
_sym
.
clip
(
transformX
,
a_min
=
0
,
a_max
=
1
)
return
_sym
.
clip
(
transformX
,
a_min
=
0
,
a_max
=
1
)
else
:
raise
TypeError
(
"Unsupported activation type : {}"
.
format
(
act_type
))
raise
TypeError
(
"Unsupported activation type : {}"
.
format
(
act_type
))
...
@@ -84,12 +83,12 @@ def _convert_advanced_activation(insym, keras_layer, symtab):
...
@@ -84,12 +83,12 @@ def _convert_advanced_activation(insym, keras_layer, symtab):
if
keras_layer
.
max_value
:
if
keras_layer
.
max_value
:
return
_sym
.
clip
(
insym
,
a_min
=
0
,
a_max
=
keras_layer
.
max_value
)
return
_sym
.
clip
(
insym
,
a_min
=
0
,
a_max
=
keras_layer
.
max_value
)
return
_sym
.
relu
(
insym
)
return
_sym
.
relu
(
insym
)
el
if
act_type
==
'LeakyReLU'
:
if
act_type
==
'LeakyReLU'
:
return
_sym
.
leaky_relu
(
insym
,
alpha
=
keras_layer
.
alpha
)
return
_sym
.
leaky_relu
(
insym
,
alpha
=
keras_layer
.
alpha
)
el
if
act_type
==
'ELU'
:
if
act_type
==
'ELU'
:
alpha
=
keras_layer
.
alpha
if
hasattr
(
keras_layer
,
"alpha"
)
else
1
alpha
=
keras_layer
.
alpha
if
hasattr
(
keras_layer
,
"alpha"
)
else
1
return
_get_elu
(
insym
,
alpha
)
return
_get_elu
(
insym
,
alpha
)
el
if
act_type
==
'PReLU'
:
if
act_type
==
'PReLU'
:
assert
hasattr
(
keras_layer
,
"alpha"
),
\
assert
hasattr
(
keras_layer
,
"alpha"
),
\
"alpha required for PReLU."
"alpha required for PReLU."
_check_data_format
(
keras_layer
)
_check_data_format
(
keras_layer
)
...
@@ -97,11 +96,10 @@ def _convert_advanced_activation(insym, keras_layer, symtab):
...
@@ -97,11 +96,10 @@ def _convert_advanced_activation(insym, keras_layer, symtab):
return
-
symtab
.
new_const
(
keras_layer
.
get_weights
()[
0
]
\
return
-
symtab
.
new_const
(
keras_layer
.
get_weights
()[
0
]
\
.
transpose
(
np
.
roll
(
range
(
size
),
1
)))
\
.
transpose
(
np
.
roll
(
range
(
size
),
1
)))
\
*
_sym
.
relu
(
-
insym
)
+
_sym
.
relu
(
insym
)
*
_sym
.
relu
(
-
insym
)
+
_sym
.
relu
(
insym
)
el
if
act_type
==
'ThresholdedReLU'
:
if
act_type
==
'ThresholdedReLU'
:
theta
=
keras_layer
.
theta
if
hasattr
(
keras_layer
,
"theta"
)
else
1.0
theta
=
keras_layer
.
theta
if
hasattr
(
keras_layer
,
"theta"
)
else
1.0
theta_tensor
=
_sym
.
full_like
(
insym
[
0
],
fill_value
=
float
(
theta
))
theta_tensor
=
_sym
.
full_like
(
insym
[
0
],
fill_value
=
float
(
theta
))
return
_sym
.
elemwise_mul
(
insym
[
0
],
_sym
.
greater
(
insym
[
0
],
theta_tensor
,
out_type
=
"float32"
))
return
_sym
.
elemwise_mul
(
insym
[
0
],
_sym
.
greater
(
insym
[
0
],
theta_tensor
,
out_type
=
"float32"
))
else
:
raise
TypeError
(
"Unsupported advanced activation type : {}"
.
format
(
act_type
))
raise
TypeError
(
"Unsupported advanced activation type : {}"
.
format
(
act_type
))
...
@@ -280,9 +278,8 @@ def _convert_pooling(insym, keras_layer, symtab):
...
@@ -280,9 +278,8 @@ def _convert_pooling(insym, keras_layer, symtab):
# global pool in keras = global pool + flatten in nnvm
# global pool in keras = global pool + flatten in nnvm
if
pool_type
==
'GlobalMaxPooling2D'
:
if
pool_type
==
'GlobalMaxPooling2D'
:
return
_convert_flatten
(
_sym
.
global_max_pool2d
(
insym
),
keras_layer
,
symtab
)
return
_convert_flatten
(
_sym
.
global_max_pool2d
(
insym
),
keras_layer
,
symtab
)
el
if
pool_type
==
'GlobalAveragePooling2D'
:
if
pool_type
==
'GlobalAveragePooling2D'
:
return
_convert_flatten
(
_sym
.
global_avg_pool2d
(
insym
),
keras_layer
,
symtab
)
return
_convert_flatten
(
_sym
.
global_avg_pool2d
(
insym
),
keras_layer
,
symtab
)
else
:
pool_h
,
pool_w
=
keras_layer
.
pool_size
pool_h
,
pool_w
=
keras_layer
.
pool_size
stride_h
,
stride_w
=
keras_layer
.
strides
stride_h
,
stride_w
=
keras_layer
.
strides
params
=
{
'pool_size'
:
[
pool_h
,
pool_w
],
params
=
{
'pool_size'
:
[
pool_h
,
pool_w
],
...
@@ -300,10 +297,9 @@ def _convert_pooling(insym, keras_layer, symtab):
...
@@ -300,10 +297,9 @@ def _convert_pooling(insym, keras_layer, symtab):
raise
TypeError
(
"Unsupported padding type : {}"
.
format
(
keras_layer
.
padding
))
raise
TypeError
(
"Unsupported padding type : {}"
.
format
(
keras_layer
.
padding
))
if
pool_type
==
'MaxPooling2D'
:
if
pool_type
==
'MaxPooling2D'
:
return
_sym
.
max_pool2d
(
insym
,
**
params
)
return
_sym
.
max_pool2d
(
insym
,
**
params
)
el
if
pool_type
==
'AveragePooling2D'
:
if
pool_type
==
'AveragePooling2D'
:
# TODO: in keras, padded zeros are not calculated
# TODO: in keras, padded zeros are not calculated
return
_sym
.
avg_pool2d
(
insym
,
**
params
)
return
_sym
.
avg_pool2d
(
insym
,
**
params
)
else
:
raise
TypeError
(
"Unsupported pooling type : {}"
.
format
(
keras_layer
))
raise
TypeError
(
"Unsupported pooling type : {}"
.
format
(
keras_layer
))
...
...
nnvm/python/nnvm/frontend/mxnet.py
View file @
e20ef0d4
...
@@ -424,7 +424,7 @@ def _topo_sort(symbol):
...
@@ -424,7 +424,7 @@ def _topo_sort(symbol):
if
childs
is
None
:
if
childs
is
None
:
dep_cnts
[
name
]
=
0
dep_cnts
[
name
]
=
0
else
:
else
:
dep_cnts
[
name
]
=
len
(
set
([
c
.
attr
(
'name'
)
for
c
in
childs
])
)
dep_cnts
[
name
]
=
len
(
{
c
.
attr
(
'name'
)
for
c
in
childs
}
)
for
child
in
childs
:
for
child
in
childs
:
child_name
=
child
.
attr
(
'name'
)
child_name
=
child
.
attr
(
'name'
)
if
child_name
not
in
deps
:
if
child_name
not
in
deps
:
...
...
nnvm/python/nnvm/frontend/onnx_caffe2_utils.py
View file @
e20ef0d4
...
@@ -9,7 +9,6 @@ def dimension_picker(prefix, surfix=''):
...
@@ -9,7 +9,6 @@ def dimension_picker(prefix, surfix=''):
kernel
=
attr
[
'kernel_shape'
]
kernel
=
attr
[
'kernel_shape'
]
if
len
(
kernel
)
==
2
:
if
len
(
kernel
)
==
2
:
return
prefix
+
'2d'
+
surfix
return
prefix
+
'2d'
+
surfix
else
:
raise
NotImplementedError
(
"Only 2d kernel supported."
)
raise
NotImplementedError
(
"Only 2d kernel supported."
)
return
_impl
return
_impl
...
...
nnvm/python/nnvm/frontend/tensorflow.py
View file @
e20ef0d4
...
@@ -68,7 +68,6 @@ def _dimension_picker(prefix, surfix=''):
...
@@ -68,7 +68,6 @@ def _dimension_picker(prefix, surfix=''):
kernel
=
attr
[
'kernel_shape'
]
kernel
=
attr
[
'kernel_shape'
]
if
len
(
kernel
)
==
2
:
if
len
(
kernel
)
==
2
:
return
prefix
+
'2d'
+
surfix
return
prefix
+
'2d'
+
surfix
else
:
raise
NotImplementedError
(
"Only 2d kernel supported."
)
raise
NotImplementedError
(
"Only 2d kernel supported."
)
return
_impl
return
_impl
...
@@ -433,7 +432,6 @@ def _reshape():
...
@@ -433,7 +432,6 @@ def _reshape():
op_name
=
"reshape"
,
op_name
=
"reshape"
,
extras
=
{
'shape'
:
tuple
(
params_new
[
0
]
.
asnumpy
()
.
flatten
())},
extras
=
{
'shape'
:
tuple
(
params_new
[
0
]
.
asnumpy
()
.
flatten
())},
ignores
=
[
'Tshape'
])(
inputs
,
attr
)
ignores
=
[
'Tshape'
])(
inputs
,
attr
)
else
:
raise
RuntimeError
(
"Reshape with dynamic shape input not supported yet."
)
raise
RuntimeError
(
"Reshape with dynamic shape input not supported yet."
)
return
_impl
return
_impl
...
@@ -1394,7 +1392,7 @@ class GraphProto(object):
...
@@ -1394,7 +1392,7 @@ class GraphProto(object):
self
.
_nodes
[
name
]
=
_sym
.
Variable
(
name
=
name
,
self
.
_nodes
[
name
]
=
_sym
.
Variable
(
name
=
name
,
shape
=
self
.
_params
[
name
]
.
shape
)
shape
=
self
.
_params
[
name
]
.
shape
)
else
:
else
:
if
key
!=
'dtype'
and
key
!=
'_output_shapes'
and
key
!=
'_class'
:
if
key
not
in
(
'dtype'
,
'_output_shapes'
,
'_class'
)
:
raise
NotImplementedError
\
raise
NotImplementedError
\
(
"Other attributes for a Const(param) Node {} ? ."
.
format
(
key
))
(
"Other attributes for a Const(param) Node {} ? ."
.
format
(
key
))
...
...
nnvm/python/nnvm/frontend/util/tensorflow_parser.py
View file @
e20ef0d4
...
@@ -115,6 +115,8 @@ class TFParser(object):
...
@@ -115,6 +115,8 @@ class TFParser(object):
"""TODO: Load checkpoint model."""
"""TODO: Load checkpoint model."""
raise
RuntimeError
(
"InputConfiguration: Loading tf checkpoint model is "
raise
RuntimeError
(
"InputConfiguration: Loading tf checkpoint model is "
"not supported yet."
)
"not supported yet."
)
# pylint: disable=unreachable
return
0
def
parse
(
self
):
def
parse
(
self
):
"""Parse tensorflow models: checkpoints, saved models, and single pb
"""Parse tensorflow models: checkpoints, saved models, and single pb
...
...
nnvm/python/nnvm/symbol.py
View file @
e20ef0d4
...
@@ -50,9 +50,8 @@ class Symbol(SymbolBase):
...
@@ -50,9 +50,8 @@ class Symbol(SymbolBase):
"""x.__add__(y) <=> x+y"""
"""x.__add__(y) <=> x+y"""
if
isinstance
(
other
,
Symbol
):
if
isinstance
(
other
,
Symbol
):
return
__add_symbol__
(
self
,
other
)
return
__add_symbol__
(
self
,
other
)
el
if
isinstance
(
other
,
_Number
):
if
isinstance
(
other
,
_Number
):
return
__add_scalar__
(
self
,
scalar
=
other
)
return
__add_scalar__
(
self
,
scalar
=
other
)
else
:
raise
TypeError
(
"type
%
s not supported"
%
str
(
type
(
other
)))
raise
TypeError
(
"type
%
s not supported"
%
str
(
type
(
other
)))
def
__radd__
(
self
,
other
):
def
__radd__
(
self
,
other
):
...
@@ -64,13 +63,11 @@ class Symbol(SymbolBase):
...
@@ -64,13 +63,11 @@ class Symbol(SymbolBase):
return
__sub_symbol__
(
self
,
other
)
return
__sub_symbol__
(
self
,
other
)
if
isinstance
(
other
,
_Number
):
if
isinstance
(
other
,
_Number
):
return
__sub_scalar__
(
self
,
scalar
=
other
)
return
__sub_scalar__
(
self
,
scalar
=
other
)
else
:
raise
TypeError
(
'type
%
s not supported'
%
str
(
type
(
other
)))
raise
TypeError
(
'type
%
s not supported'
%
str
(
type
(
other
)))
def
__rsub__
(
self
,
other
):
def
__rsub__
(
self
,
other
):
if
isinstance
(
other
,
_Number
):
if
isinstance
(
other
,
_Number
):
return
__rsub_scalar__
(
self
,
scalar
=
other
)
return
__rsub_scalar__
(
self
,
scalar
=
other
)
else
:
raise
TypeError
(
'type
%
s not supported'
%
str
(
type
(
other
)))
raise
TypeError
(
'type
%
s not supported'
%
str
(
type
(
other
)))
def
__mul__
(
self
,
other
):
def
__mul__
(
self
,
other
):
...
@@ -79,7 +76,6 @@ class Symbol(SymbolBase):
...
@@ -79,7 +76,6 @@ class Symbol(SymbolBase):
return
__mul_symbol__
(
self
,
other
)
return
__mul_symbol__
(
self
,
other
)
if
isinstance
(
other
,
_Number
):
if
isinstance
(
other
,
_Number
):
return
__mul_scalar__
(
self
,
scalar
=
other
)
return
__mul_scalar__
(
self
,
scalar
=
other
)
else
:
raise
TypeError
(
'type
%
s not supported'
%
str
(
type
(
other
)))
raise
TypeError
(
'type
%
s not supported'
%
str
(
type
(
other
)))
def
__rmul__
(
self
,
other
):
def
__rmul__
(
self
,
other
):
...
@@ -91,27 +87,23 @@ class Symbol(SymbolBase):
...
@@ -91,27 +87,23 @@ class Symbol(SymbolBase):
return
__div_symbol__
(
self
,
other
)
return
__div_symbol__
(
self
,
other
)
if
isinstance
(
other
,
_Number
):
if
isinstance
(
other
,
_Number
):
return
__div_scalar__
(
self
,
scalar
=
other
)
return
__div_scalar__
(
self
,
scalar
=
other
)
else
:
raise
TypeError
(
'type
%
s not supported'
%
str
(
type
(
other
)))
raise
TypeError
(
'type
%
s not supported'
%
str
(
type
(
other
)))
def
__rdiv__
(
self
,
other
):
def
__rdiv__
(
self
,
other
):
if
isinstance
(
other
,
_Number
):
if
isinstance
(
other
,
_Number
):
return
__rdiv_scalar__
(
self
,
scalar
=
other
)
return
__rdiv_scalar__
(
self
,
scalar
=
other
)
else
:
raise
TypeError
(
'type
%
s not supported'
%
str
(
type
(
other
)))
raise
TypeError
(
'type
%
s not supported'
%
str
(
type
(
other
)))
def
__lshift__
(
self
,
other
):
def
__lshift__
(
self
,
other
):
"""x.__lshift__(y) <=> x << y"""
"""x.__lshift__(y) <=> x << y"""
if
isinstance
(
other
,
_Number
):
if
isinstance
(
other
,
_Number
):
return
__lshift_scalar__
(
self
,
scalar
=
other
)
return
__lshift_scalar__
(
self
,
scalar
=
other
)
else
:
raise
TypeError
(
'type
%
s not supported'
%
str
(
type
(
other
)))
raise
TypeError
(
'type
%
s not supported'
%
str
(
type
(
other
)))
def
__rshift__
(
self
,
other
):
def
__rshift__
(
self
,
other
):
"""x.__rshift__(y) <=> x >> y"""
"""x.__rshift__(y) <=> x >> y"""
if
isinstance
(
other
,
_Number
):
if
isinstance
(
other
,
_Number
):
return
__rshift_scalar__
(
self
,
scalar
=
other
)
return
__rshift_scalar__
(
self
,
scalar
=
other
)
else
:
raise
TypeError
(
'type
%
s not supported'
%
str
(
type
(
other
)))
raise
TypeError
(
'type
%
s not supported'
%
str
(
type
(
other
)))
def
__truediv__
(
self
,
other
):
def
__truediv__
(
self
,
other
):
...
@@ -126,13 +118,11 @@ class Symbol(SymbolBase):
...
@@ -126,13 +118,11 @@ class Symbol(SymbolBase):
return
__pow_symbol__
(
self
,
other
)
return
__pow_symbol__
(
self
,
other
)
if
isinstance
(
other
,
_Number
):
if
isinstance
(
other
,
_Number
):
return
__pow_scalar__
(
self
,
scalar
=
other
)
return
__pow_scalar__
(
self
,
scalar
=
other
)
else
:
raise
TypeError
(
'type
%
s not supported'
%
str
(
type
(
other
)))
raise
TypeError
(
'type
%
s not supported'
%
str
(
type
(
other
)))
def
__rpow__
(
self
,
other
):
def
__rpow__
(
self
,
other
):
if
isinstance
(
other
,
_Number
):
if
isinstance
(
other
,
_Number
):
return
__rpow_scalar__
(
self
,
scalar
=
other
)
return
__rpow_scalar__
(
self
,
scalar
=
other
)
else
:
raise
TypeError
(
'type
%
s not supported'
%
str
(
type
(
other
)))
raise
TypeError
(
'type
%
s not supported'
%
str
(
type
(
other
)))
def
__neg__
(
self
):
def
__neg__
(
self
):
...
@@ -238,11 +228,10 @@ class Symbol(SymbolBase):
...
@@ -238,11 +228,10 @@ class Symbol(SymbolBase):
"""internal function to get list option"""
"""internal function to get list option"""
if
option
==
'all'
:
if
option
==
'all'
:
return
_ctypes
.
c_int
(
0
)
return
_ctypes
.
c_int
(
0
)
el
if
option
==
'read_only'
:
if
option
==
'read_only'
:
return
_ctypes
.
c_int
(
1
)
return
_ctypes
.
c_int
(
1
)
el
if
option
==
'aux_state'
:
if
option
==
'aux_state'
:
return
_ctypes
.
c_int
(
2
)
return
_ctypes
.
c_int
(
2
)
else
:
raise
ValueError
(
"option need to be in {'all', 'read_only, 'aux_state'}"
)
raise
ValueError
(
"option need to be in {'all', 'read_only, 'aux_state'}"
)
def
list_input_variables
(
self
,
option
=
'all'
):
def
list_input_variables
(
self
,
option
=
'all'
):
...
...
nnvm/python/nnvm/testing/inception_v3.py
View file @
e20ef0d4
...
@@ -23,10 +23,9 @@ def Conv(data, num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), name=None,
...
@@ -23,10 +23,9 @@ def Conv(data, num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), name=None,
def
Pooling
(
data
,
kernel
,
stride
,
pad
,
pool_type
,
name
):
def
Pooling
(
data
,
kernel
,
stride
,
pad
,
pool_type
,
name
):
if
pool_type
==
'max'
:
if
pool_type
==
'max'
:
return
sym
.
max_pool2d
(
data
=
data
,
pool_size
=
kernel
,
strides
=
stride
,
padding
=
pad
,
name
=
name
)
return
sym
.
max_pool2d
(
data
=
data
,
pool_size
=
kernel
,
strides
=
stride
,
padding
=
pad
,
name
=
name
)
el
if
pool_type
==
'avg'
:
if
pool_type
==
'avg'
:
return
sym
.
avg_pool2d
(
data
=
data
,
pool_size
=
kernel
,
strides
=
stride
,
padding
=
pad
,
name
=
name
,
return
sym
.
avg_pool2d
(
data
=
data
,
pool_size
=
kernel
,
strides
=
stride
,
padding
=
pad
,
name
=
name
,
count_include_pad
=
True
)
count_include_pad
=
True
)
else
:
raise
ValueError
(
"Invalid pooling type: "
+
pool_type
)
raise
ValueError
(
"Invalid pooling type: "
+
pool_type
)
def
Inception7A
(
data
,
def
Inception7A
(
data
,
...
...
nnvm/python/nnvm/testing/yolo_detection.py
View file @
e20ef0d4
...
@@ -88,7 +88,6 @@ def _get_yolo_detections(l, im_shape, net_shape, thresh, relative, dets):
...
@@ -88,7 +88,6 @@ def _get_yolo_detections(l, im_shape, net_shape, thresh, relative, dets):
before_correct_dets
.
append
(
detection
)
before_correct_dets
.
append
(
detection
)
dets
.
extend
(
_correct_boxes
(
before_correct_dets
,
im_shape
[
0
],
im_shape
[
1
],
dets
.
extend
(
_correct_boxes
(
before_correct_dets
,
im_shape
[
0
],
im_shape
[
1
],
net_shape
[
0
],
net_shape
[
1
],
relative
))
net_shape
[
0
],
net_shape
[
1
],
relative
))
return
def
_get_region_detections
(
l
,
im_shape
,
net_shape
,
thresh
,
relative
,
dets
):
def
_get_region_detections
(
l
,
im_shape
,
net_shape
,
thresh
,
relative
,
dets
):
data
=
l
[
'output'
]
data
=
l
[
'output'
]
...
@@ -114,7 +113,6 @@ def _get_region_detections(l, im_shape, net_shape, thresh, relative, dets):
...
@@ -114,7 +113,6 @@ def _get_region_detections(l, im_shape, net_shape, thresh, relative, dets):
_correct_boxes
(
before_correct_dets
,
im_shape
[
0
],
im_shape
[
1
],
_correct_boxes
(
before_correct_dets
,
im_shape
[
0
],
im_shape
[
1
],
net_shape
[
0
],
net_shape
[
1
],
relative
)
net_shape
[
0
],
net_shape
[
1
],
relative
)
dets
.
extend
(
before_correct_dets
)
dets
.
extend
(
before_correct_dets
)
return
def
fill_network_boxes
(
net_shape
,
im_shape
,
def
fill_network_boxes
(
net_shape
,
im_shape
,
thresh
,
relative
,
tvm_out
):
thresh
,
relative
,
tvm_out
):
...
...
nnvm/python/nnvm/top/attr_dict.py
View file @
e20ef0d4
...
@@ -129,13 +129,12 @@ class AttrDict(object):
...
@@ -129,13 +129,12 @@ class AttrDict(object):
lowercase
=
self
[
key
]
.
lower
()
lowercase
=
self
[
key
]
.
lower
()
if
lowercase
==
"1"
:
if
lowercase
==
"1"
:
return
True
return
True
el
if
lowercase
==
"0"
:
if
lowercase
==
"0"
:
return
False
return
False
el
if
lowercase
==
"true"
:
if
lowercase
==
"true"
:
return
True
return
True
el
if
lowercase
==
"false"
:
if
lowercase
==
"false"
:
return
False
return
False
else
:
raise
ValueError
(
"Wrong bool format for key
%
s"
%
key
)
raise
ValueError
(
"Wrong bool format for key
%
s"
%
key
)
def
get_str
(
self
,
key
):
def
get_str
(
self
,
key
):
...
...
python/tvm/_ffi/base.py
View file @
e20ef0d4
...
@@ -32,7 +32,6 @@ else:
...
@@ -32,7 +32,6 @@ else:
class
TVMError
(
Exception
):
class
TVMError
(
Exception
):
"""Error thrown by TVM function"""
"""Error thrown by TVM function"""
pass
def
_load_lib
():
def
_load_lib
():
...
...
python/tvm/_ffi/function.py
View file @
e20ef0d4
...
@@ -51,7 +51,6 @@ class Function(_FunctionBase):
...
@@ -51,7 +51,6 @@ class Function(_FunctionBase):
tvm.register_func: How to register global function.
tvm.register_func: How to register global function.
tvm.get_global_func: How to get global function.
tvm.get_global_func: How to get global function.
"""
"""
pass
class
ModuleBase
(
object
):
class
ModuleBase
(
object
):
...
@@ -207,10 +206,10 @@ def get_global_func(name, allow_missing=False):
...
@@ -207,10 +206,10 @@ def get_global_func(name, allow_missing=False):
check_call
(
_LIB
.
TVMFuncGetGlobal
(
c_str
(
name
),
ctypes
.
byref
(
handle
)))
check_call
(
_LIB
.
TVMFuncGetGlobal
(
c_str
(
name
),
ctypes
.
byref
(
handle
)))
if
handle
.
value
:
if
handle
.
value
:
return
Function
(
handle
,
False
)
return
Function
(
handle
,
False
)
else
:
if
allow_missing
:
if
allow_missing
:
return
None
return
None
else
:
raise
ValueError
(
"Cannot find global function
%
s"
%
name
)
raise
ValueError
(
"Cannot find global function
%
s"
%
name
)
...
...
python/tvm/_ffi/node_generic.py
View file @
e20ef0d4
...
@@ -36,16 +36,16 @@ def convert_to_node(value):
...
@@ -36,16 +36,16 @@ def convert_to_node(value):
"""
"""
if
isinstance
(
value
,
_CLASS_NODE_BASE
):
if
isinstance
(
value
,
_CLASS_NODE_BASE
):
return
value
return
value
el
if
isinstance
(
value
,
bool
):
if
isinstance
(
value
,
bool
):
return
const
(
value
,
'uint1x1'
)
return
const
(
value
,
'uint1x1'
)
el
if
isinstance
(
value
,
Number
):
if
isinstance
(
value
,
Number
):
return
const
(
value
)
return
const
(
value
)
el
if
isinstance
(
value
,
string_types
):
if
isinstance
(
value
,
string_types
):
return
_api_internal
.
_str
(
value
)
return
_api_internal
.
_str
(
value
)
el
if
isinstance
(
value
,
(
list
,
tuple
)):
if
isinstance
(
value
,
(
list
,
tuple
)):
value
=
[
convert_to_node
(
x
)
for
x
in
value
]
value
=
[
convert_to_node
(
x
)
for
x
in
value
]
return
_api_internal
.
_Array
(
*
value
)
return
_api_internal
.
_Array
(
*
value
)
el
if
isinstance
(
value
,
dict
):
if
isinstance
(
value
,
dict
):
vlist
=
[]
vlist
=
[]
for
item
in
value
.
items
():
for
item
in
value
.
items
():
if
(
not
isinstance
(
item
[
0
],
_CLASS_NODE_BASE
)
and
if
(
not
isinstance
(
item
[
0
],
_CLASS_NODE_BASE
)
and
...
@@ -54,11 +54,11 @@ def convert_to_node(value):
...
@@ -54,11 +54,11 @@ def convert_to_node(value):
vlist
.
append
(
item
[
0
])
vlist
.
append
(
item
[
0
])
vlist
.
append
(
convert_to_node
(
item
[
1
]))
vlist
.
append
(
convert_to_node
(
item
[
1
]))
return
_api_internal
.
_Map
(
*
vlist
)
return
_api_internal
.
_Map
(
*
vlist
)
el
if
isinstance
(
value
,
NodeGeneric
):
if
isinstance
(
value
,
NodeGeneric
):
return
value
.
asnode
()
return
value
.
asnode
()
el
if
value
is
None
:
if
value
is
None
:
return
None
return
None
else
:
raise
ValueError
(
"don't know how to convert type
%
s to node"
%
type
(
value
))
raise
ValueError
(
"don't know how to convert type
%
s to node"
%
type
(
value
))
...
...
python/tvm/arith.py
View file @
e20ef0d4
...
@@ -31,11 +31,11 @@ class IntervalSet(IntSet):
...
@@ -31,11 +31,11 @@ class IntervalSet(IntSet):
@register_node
@register_node
class
StrideSet
(
IntSet
):
class
StrideSet
(
IntSet
):
"""Represent set of strided integers"""
"""Represent set of strided integers"""
pass
@register_node
@register_node
class
ModularSet
(
IntSet
):
class
ModularSet
(
IntSet
):
"""Represent range of (coeff * x + base) for x in Z """
"""Represent range of (coeff * x + base) for x in Z """
pass
_init_api
(
"tvm.arith"
)
_init_api
(
"tvm.arith"
)
python/tvm/autotvm/measure/executor.py
View file @
e20ef0d4
...
@@ -69,15 +69,14 @@ class Future(object):
...
@@ -69,15 +69,14 @@ class Future(object):
class
FutureError
(
RuntimeError
):
class
FutureError
(
RuntimeError
):
"""Base error class of all future events"""
"""Base error class of all future events"""
pass
# pylint:disable=redefined-builtin
# pylint:disable=redefined-builtin
class
TimeoutError
(
FutureError
):
class
TimeoutError
(
FutureError
):
"""Error raised when a task is timeout."""
"""Error raised when a task is timeout."""
pass
class
ExecutionError
(
FutureError
):
class
ExecutionError
(
FutureError
):
"""
"""
Error raised when future execution crashes or failed.
Error raised when future execution crashes or failed.
"""
"""
pass
python/tvm/autotvm/record.py
View file @
e20ef0d4
...
@@ -83,7 +83,7 @@ def encode(inp, result, protocol='json'):
...
@@ -83,7 +83,7 @@ def encode(inp, result, protocol='json'):
"v"
:
AUTOTVM_LOG_VERSION
"v"
:
AUTOTVM_LOG_VERSION
}
}
return
json
.
dumps
(
json_dict
)
return
json
.
dumps
(
json_dict
)
el
if
protocol
==
'pickle'
:
if
protocol
==
'pickle'
:
row
=
(
str
(
inp
.
target
),
row
=
(
str
(
inp
.
target
),
str
(
base64
.
b64encode
(
pickle
.
dumps
([
inp
.
task
.
name
,
str
(
base64
.
b64encode
(
pickle
.
dumps
([
inp
.
task
.
name
,
inp
.
task
.
args
,
inp
.
task
.
args
,
...
@@ -92,7 +92,7 @@ def encode(inp, result, protocol='json'):
...
@@ -92,7 +92,7 @@ def encode(inp, result, protocol='json'):
str
(
base64
.
b64encode
(
pickle
.
dumps
(
inp
.
config
))
.
decode
()),
str
(
base64
.
b64encode
(
pickle
.
dumps
(
inp
.
config
))
.
decode
()),
str
(
base64
.
b64encode
(
pickle
.
dumps
(
tuple
(
result
)))
.
decode
()))
str
(
base64
.
b64encode
(
pickle
.
dumps
(
tuple
(
result
)))
.
decode
()))
return
'
\t
'
.
join
(
row
)
return
'
\t
'
.
join
(
row
)
else
:
raise
RuntimeError
(
"Invalid log protocol: "
+
protocol
)
raise
RuntimeError
(
"Invalid log protocol: "
+
protocol
)
...
@@ -136,7 +136,7 @@ def decode(row, protocol='json'):
...
@@ -136,7 +136,7 @@ def decode(row, protocol='json'):
result
=
MeasureResult
(
*
[
tuple
(
x
)
if
isinstance
(
x
,
list
)
else
x
for
x
in
row
[
"r"
]])
result
=
MeasureResult
(
*
[
tuple
(
x
)
if
isinstance
(
x
,
list
)
else
x
for
x
in
row
[
"r"
]])
return
inp
,
result
return
inp
,
result
el
if
protocol
==
'pickle'
:
if
protocol
==
'pickle'
:
items
=
row
.
split
(
"
\t
"
)
items
=
row
.
split
(
"
\t
"
)
tgt
=
_target
.
create
(
items
[
0
])
tgt
=
_target
.
create
(
items
[
0
])
task_tuple
=
pickle
.
loads
(
base64
.
b64decode
(
items
[
1
]
.
encode
()))
task_tuple
=
pickle
.
loads
(
base64
.
b64decode
(
items
[
1
]
.
encode
()))
...
@@ -146,7 +146,7 @@ def decode(row, protocol='json'):
...
@@ -146,7 +146,7 @@ def decode(row, protocol='json'):
tsk
=
task
.
Task
(
task_tuple
[
0
],
task_tuple
[
1
])
tsk
=
task
.
Task
(
task_tuple
[
0
],
task_tuple
[
1
])
tsk
.
workload
=
task_tuple
[
3
]
tsk
.
workload
=
task_tuple
[
3
]
return
MeasureInput
(
tgt
,
tsk
,
config
),
MeasureResult
(
*
result
)
return
MeasureInput
(
tgt
,
tsk
,
config
),
MeasureResult
(
*
result
)
else
:
raise
RuntimeError
(
"Invalid log protocol: "
+
protocol
)
raise
RuntimeError
(
"Invalid log protocol: "
+
protocol
)
...
...
python/tvm/autotvm/task/space.py
View file @
e20ef0d4
...
@@ -32,7 +32,6 @@ class InstantiationError(ValueError):
...
@@ -32,7 +32,6 @@ class InstantiationError(ValueError):
raised by cfg.raise_error
raised by cfg.raise_error
e.g. too many unrolling, too many threads in a block
e.g. too many unrolling, too many threads in a block
"""
"""
pass
class
TransformSpace
(
object
):
class
TransformSpace
(
object
):
...
@@ -321,7 +320,7 @@ class ReorderSpace(TransformSpace):
...
@@ -321,7 +320,7 @@ class ReorderSpace(TransformSpace):
if
np
.
sum
(
tmp_pt
)
==
size
:
if
np
.
sum
(
tmp_pt
)
==
size
:
merged
.
append
(
list
(
tmp_stack
))
merged
.
append
(
list
(
tmp_stack
))
return
return
else
:
for
i
in
range
(
len
(
chains
)):
for
i
in
range
(
len
(
chains
)):
# use i == np.argmax(....) here to take spatial order into consideration
# use i == np.argmax(....) here to take spatial order into consideration
# if we don't want to consider spatial order, we can use tmp_pt[i] == np.max(....)
# if we don't want to consider spatial order, we can use tmp_pt[i] == np.max(....)
...
@@ -441,7 +440,7 @@ class AnnotateSpace(TransformSpace):
...
@@ -441,7 +440,7 @@ class AnnotateSpace(TransformSpace):
if
now
==
self
.
num_axis
:
if
now
==
self
.
num_axis
:
# only vectorize inner most dimension
# only vectorize inner most dimension
vec_ct
=
tmp_stack
.
count
(
'vec'
)
vec_ct
=
tmp_stack
.
count
(
'vec'
)
if
vec_ct
==
0
or
vec_ct
==
1
:
if
vec_ct
in
(
0
,
1
)
:
self
.
entities
.
append
(
AnnotateEntity
(
list
(
tmp_stack
)))
self
.
entities
.
append
(
AnnotateEntity
(
list
(
tmp_stack
)))
else
:
else
:
for
ann
in
self
.
anns
[
now
]:
for
ann
in
self
.
anns
[
now
]:
...
...
python/tvm/autotvm/task/task.py
View file @
e20ef0d4
...
@@ -294,7 +294,7 @@ def get_config():
...
@@ -294,7 +294,7 @@ def get_config():
class
FlopCalculationError
(
RuntimeError
):
class
FlopCalculationError
(
RuntimeError
):
"""Error happens when estimating FLOP for a compute op"""
"""Error happens when estimating FLOP for a compute op"""
pass
def
compute_flop
(
sch
):
def
compute_flop
(
sch
):
"""Calculate number of FLOP (floating number operations) of the compute ops in a schedule
"""Calculate number of FLOP (floating number operations) of the compute ops in a schedule
...
@@ -328,13 +328,13 @@ def compute_flop(sch):
...
@@ -328,13 +328,13 @@ def compute_flop(sch):
if
len
(
source
)
!=
1
:
if
len
(
source
)
!=
1
:
raise
FlopCalculationError
(
"Found multiple output in the source of reduce op"
)
raise
FlopCalculationError
(
"Found multiple output in the source of reduce op"
)
return
num_iter
*
(
_count_flop
(
combiner
[
0
])
+
_count_flop
(
source
[
0
]))
return
num_iter
*
(
_count_flop
(
combiner
[
0
])
+
_count_flop
(
source
[
0
]))
el
if
isinstance
(
exp
,
(
expr
.
FloatImm
,
expr
.
IntImm
,
expr
.
UIntImm
)):
if
isinstance
(
exp
,
(
expr
.
FloatImm
,
expr
.
IntImm
,
expr
.
UIntImm
)):
return
0
return
0
el
if
isinstance
(
exp
,
expr
.
Cast
):
if
isinstance
(
exp
,
expr
.
Cast
):
return
_count_flop
(
exp
.
value
)
return
_count_flop
(
exp
.
value
)
el
if
isinstance
(
exp
,
expr
.
Var
):
if
isinstance
(
exp
,
expr
.
Var
):
return
0
return
0
el
if
isinstance
(
exp
,
(
expr
.
Add
,
expr
.
Sub
,
expr
.
Mul
,
expr
.
Div
,
expr
.
Mod
,
if
isinstance
(
exp
,
(
expr
.
Add
,
expr
.
Sub
,
expr
.
Mul
,
expr
.
Div
,
expr
.
Mod
,
expr
.
Max
,
expr
.
Min
,
expr
.
Max
,
expr
.
Min
,
expr
.
EQ
,
expr
.
NE
,
expr
.
LT
,
expr
.
LE
,
expr
.
GT
,
expr
.
GE
,
expr
.
EQ
,
expr
.
NE
,
expr
.
LT
,
expr
.
LE
,
expr
.
GT
,
expr
.
GE
,
expr
.
And
,
expr
.
Or
,
expr
.
Not
)):
expr
.
And
,
expr
.
Or
,
expr
.
Not
)):
...
@@ -344,12 +344,12 @@ def compute_flop(sch):
...
@@ -344,12 +344,12 @@ def compute_flop(sch):
return
base
+
_count_flop
(
exp
.
a
)
return
base
+
_count_flop
(
exp
.
a
)
return
base
+
_count_flop
(
exp
.
a
)
+
_count_flop
(
exp
.
b
)
return
base
+
_count_flop
(
exp
.
a
)
+
_count_flop
(
exp
.
b
)
el
if
isinstance
(
exp
,
expr
.
Select
):
if
isinstance
(
exp
,
expr
.
Select
):
return
_count_flop
(
exp
.
condition
)
+
max
(
_count_flop
(
exp
.
true_value
),
return
_count_flop
(
exp
.
condition
)
+
max
(
_count_flop
(
exp
.
true_value
),
_count_flop
(
exp
.
false_value
))
_count_flop
(
exp
.
false_value
))
el
if
isinstance
(
exp
,
expr
.
Call
):
if
isinstance
(
exp
,
expr
.
Call
):
return
sum
([
_count_flop
(
x
)
for
x
in
exp
.
args
])
return
sum
([
_count_flop
(
x
)
for
x
in
exp
.
args
])
else
:
raise
FlopCalculationError
(
"Found unsupported operator in the compute expr"
)
raise
FlopCalculationError
(
"Found unsupported operator in the compute expr"
)
def
traverse
(
ops
):
def
traverse
(
ops
):
...
...
python/tvm/autotvm/tuner/tuner.py
View file @
e20ef0d4
...
@@ -69,7 +69,7 @@ class Tuner(object):
...
@@ -69,7 +69,7 @@ class Tuner(object):
results: Array of autotvm.measure.MeasureResult
results: Array of autotvm.measure.MeasureResult
result for measurement
result for measurement
"""
"""
pass
def
tune
(
self
,
n_trial
,
measure_option
,
early_stopping
=
None
,
callbacks
=
()):
def
tune
(
self
,
n_trial
,
measure_option
,
early_stopping
=
None
,
callbacks
=
()):
"""Begin tuning
"""Begin tuning
...
...
python/tvm/container.py
View file @
e20ef0d4
...
@@ -90,7 +90,7 @@ class Range(NodeBase):
...
@@ -90,7 +90,7 @@ class Range(NodeBase):
You do not need to create Range explicitly.
You do not need to create Range explicitly.
Python list and tuple will be converted automatically to Range in api functions.
Python list and tuple will be converted automatically to Range in api functions.
"""
"""
pass
@register_node
@register_node
class
LoweredFunc
(
NodeBase
):
class
LoweredFunc
(
NodeBase
):
...
...
python/tvm/contrib/nvcc.py
View file @
e20ef0d4
...
@@ -151,14 +151,14 @@ def find_libdevice_path(arch):
...
@@ -151,14 +151,14 @@ def find_libdevice_path(arch):
selected_ver
=
0
selected_ver
=
0
selected_path
=
None
selected_path
=
None
cuda_ver
=
get_cuda_version
(
cuda_path
)
cuda_ver
=
get_cuda_version
(
cuda_path
)
if
cuda_ver
==
9.0
or
cuda_ver
==
9.1
:
if
cuda_ver
in
(
9.0
,
9.1
)
:
path
=
os
.
path
.
join
(
lib_path
,
"libdevice.10.bc"
)
path
=
os
.
path
.
join
(
lib_path
,
"libdevice.10.bc"
)
else
:
else
:
for
fn
in
os
.
listdir
(
lib_path
):
for
fn
in
os
.
listdir
(
lib_path
):
if
not
fn
.
startswith
(
"libdevice"
):
if
not
fn
.
startswith
(
"libdevice"
):
continue
continue
ver
=
int
(
fn
.
split
(
"."
)[
-
3
]
.
split
(
"_"
)[
-
1
])
ver
=
int
(
fn
.
split
(
"."
)[
-
3
]
.
split
(
"_"
)[
-
1
])
if
ver
>
selected_ver
and
ver
<=
arch
:
if
selected_ver
<
ver
<=
arch
:
selected_ver
=
ver
selected_ver
=
ver
selected_path
=
fn
selected_path
=
fn
if
selected_path
is
None
:
if
selected_path
is
None
:
...
...
python/tvm/contrib/verilog.py
View file @
e20ef0d4
...
@@ -118,7 +118,6 @@ def _find_vpi_path():
...
@@ -118,7 +118,6 @@ def _find_vpi_path():
vpi_found
=
[
p
for
p
in
vpi_path
if
os
.
path
.
exists
(
p
)
and
os
.
path
.
isfile
(
p
)]
vpi_found
=
[
p
for
p
in
vpi_path
if
os
.
path
.
exists
(
p
)
and
os
.
path
.
isfile
(
p
)]
if
vpi_found
:
if
vpi_found
:
return
os
.
path
.
dirname
(
vpi_found
[
0
])
return
os
.
path
.
dirname
(
vpi_found
[
0
])
else
:
raise
ValueError
(
"Cannot find tvm_vpi.vpi, make sure you did `make verilog`"
)
raise
ValueError
(
"Cannot find tvm_vpi.vpi, make sure you did `make verilog`"
)
def
search_path
():
def
search_path
():
...
...
python/tvm/hybrid/parser.py
View file @
e20ef0d4
...
@@ -189,9 +189,9 @@ class HybridParser(ast.NodeVisitor):
...
@@ -189,9 +189,9 @@ class HybridParser(ast.NodeVisitor):
_internal_assert
(
name
in
self
.
symbols
,
"Unknown symbol
%
s!"
%
name
)
_internal_assert
(
name
in
self
.
symbols
,
"Unknown symbol
%
s!"
%
name
)
if
ty
in
[
Symbol
.
LoopVar
,
Symbol
.
Input
,
Symbol
.
ConstLoopVar
]:
if
ty
in
[
Symbol
.
LoopVar
,
Symbol
.
Input
,
Symbol
.
ConstLoopVar
]:
return
entry
return
entry
el
if
ty
is
Symbol
.
ConstVar
:
if
ty
is
Symbol
.
ConstVar
:
return
entry
if
isinstance
(
node
.
ctx
,
ast
.
Load
)
else
None
return
entry
if
isinstance
(
node
.
ctx
,
ast
.
Load
)
else
None
el
if
ty
is
Symbol
.
BufferVar
:
if
ty
is
Symbol
.
BufferVar
:
if
isinstance
(
node
.
ctx
,
ast
.
Load
):
if
isinstance
(
node
.
ctx
,
ast
.
Load
):
return
_make
.
Call
(
entry
.
dtype
,
entry
.
name
,
[
_api
.
const
(
0
,
'int32'
)],
\
return
_make
.
Call
(
entry
.
dtype
,
entry
.
name
,
[
_api
.
const
(
0
,
'int32'
)],
\
_expr
.
Call
.
Halide
,
entry
.
op
,
entry
.
value_index
)
_expr
.
Call
.
Halide
,
entry
.
op
,
entry
.
value_index
)
...
@@ -274,7 +274,7 @@ class HybridParser(ast.NodeVisitor):
...
@@ -274,7 +274,7 @@ class HybridParser(ast.NodeVisitor):
buf
,
args
=
lhs
buf
,
args
=
lhs
return
_make
.
Provide
(
buf
.
op
,
0
,
rhs
,
args
)
return
_make
.
Provide
(
buf
.
op
,
0
,
rhs
,
args
)
return
util
.
make_nop
()
return
util
.
make_nop
()
else
:
lhs
,
args
=
self
.
visit
(
lhs
)
lhs
,
args
=
self
.
visit
(
lhs
)
_internal_assert
(
isinstance
(
lhs
,
Tensor
),
\
_internal_assert
(
isinstance
(
lhs
,
Tensor
),
\
"An array access's LHS is expected to be a expr.Call!"
)
"An array access's LHS is expected to be a expr.Call!"
)
...
@@ -347,7 +347,7 @@ class HybridParser(ast.NodeVisitor):
...
@@ -347,7 +347,7 @@ class HybridParser(ast.NodeVisitor):
if
isinstance
(
cond
,
_expr
.
UIntImm
):
if
isinstance
(
cond
,
_expr
.
UIntImm
):
if
cond
.
value
:
if
cond
.
value
:
return
visit_list_to_block
(
self
.
visit
,
node
.
body
)
return
visit_list_to_block
(
self
.
visit
,
node
.
body
)
el
if
node
.
orelse
:
if
node
.
orelse
:
return
visit_list_to_block
(
self
.
visit
,
node
.
orelse
)
return
visit_list_to_block
(
self
.
visit
,
node
.
orelse
)
return
util
.
make_nop
()
return
util
.
make_nop
()
...
@@ -451,7 +451,7 @@ class HybridParser(ast.NodeVisitor):
...
@@ -451,7 +451,7 @@ class HybridParser(ast.NodeVisitor):
bodies
.
append
(
body
)
bodies
.
append
(
body
)
return
concat_list_to_block
(
bodies
)
return
concat_list_to_block
(
bodies
)
el
if
iter_var
is
None
:
if
iter_var
is
None
:
_internal_assert
(
for_type
is
not
None
,
"The loop bind function parse error!"
)
_internal_assert
(
for_type
is
not
None
,
"The loop bind function parse error!"
)
offset
=
iter_var
=
_api
.
var
(
_name
)
offset
=
iter_var
=
_api
.
var
(
_name
)
if
not
_ir_pass
.
Equal
(
low
,
_api
.
const
(
0
,
'int32'
)):
if
not
_ir_pass
.
Equal
(
low
,
_api
.
const
(
0
,
'int32'
)):
...
...
python/tvm/hybrid/util.py
View file @
e20ef0d4
...
@@ -60,7 +60,7 @@ def replace_io(body, rmap):
...
@@ -60,7 +60,7 @@ def replace_io(body, rmap):
if
isinstance
(
op
,
_stmt
.
Provide
)
and
op
.
func
in
rmap
.
keys
():
if
isinstance
(
op
,
_stmt
.
Provide
)
and
op
.
func
in
rmap
.
keys
():
buf
=
rmap
[
op
.
func
]
buf
=
rmap
[
op
.
func
]
return
_make
.
Provide
(
buf
.
op
,
op
.
value_index
,
op
.
value
,
op
.
args
)
return
_make
.
Provide
(
buf
.
op
,
op
.
value_index
,
op
.
value
,
op
.
args
)
el
if
isinstance
(
op
,
_expr
.
Call
)
and
op
.
func
in
rmap
.
keys
():
if
isinstance
(
op
,
_expr
.
Call
)
and
op
.
func
in
rmap
.
keys
():
buf
=
rmap
[
op
.
func
]
buf
=
rmap
[
op
.
func
]
return
_make
.
Call
(
buf
.
dtype
,
buf
.
name
,
op
.
args
,
\
return
_make
.
Call
(
buf
.
dtype
,
buf
.
name
,
op
.
args
,
\
_expr
.
Call
.
Halide
,
buf
.
op
,
buf
.
value_index
)
_expr
.
Call
.
Halide
,
buf
.
op
,
buf
.
value_index
)
...
...
python/tvm/intrin.py
View file @
e20ef0d4
...
@@ -495,7 +495,7 @@ def _rule_float_suffix(op):
...
@@ -495,7 +495,7 @@ def _rule_float_suffix(op):
"""
"""
if
op
.
dtype
==
"float32"
:
if
op
.
dtype
==
"float32"
:
return
call_pure_extern
(
op
.
dtype
,
"
%
sf"
%
op
.
name
,
*
op
.
args
)
return
call_pure_extern
(
op
.
dtype
,
"
%
sf"
%
op
.
name
,
*
op
.
args
)
el
if
op
.
dtype
==
"float64"
:
if
op
.
dtype
==
"float64"
:
return
call_pure_extern
(
op
.
dtype
,
op
.
name
,
*
op
.
args
)
return
call_pure_extern
(
op
.
dtype
,
op
.
name
,
*
op
.
args
)
return
op
return
op
...
...
python/tvm/make.py
View file @
e20ef0d4
...
@@ -56,7 +56,7 @@ def static_cast(dtype, expr):
...
@@ -56,7 +56,7 @@ def static_cast(dtype, expr):
if
target_type
.
type_code
==
src_type
.
type_code
and
src_type
.
bits
==
target_type
.
bits
:
if
target_type
.
type_code
==
src_type
.
type_code
and
src_type
.
bits
==
target_type
.
bits
:
if
src_type
.
lanes
==
target_type
.
lanes
:
if
src_type
.
lanes
==
target_type
.
lanes
:
return
expr
return
expr
el
if
src_type
.
lanes
==
1
and
target_type
.
lanes
>
1
:
if
src_type
.
lanes
==
1
and
target_type
.
lanes
>
1
:
return
Broadcast
(
expr
,
target_type
.
lanes
)
return
Broadcast
(
expr
,
target_type
.
lanes
)
return
Cast
(
dtype
,
expr
)
return
Cast
(
dtype
,
expr
)
...
...
python/tvm/ndarray.py
View file @
e20ef0d4
...
@@ -23,7 +23,6 @@ class NDArray(NDArrayBase):
...
@@ -23,7 +23,6 @@ class NDArray(NDArrayBase):
Instead, this is a minimal data structure to demonstrate
Instead, this is a minimal data structure to demonstrate
how can we use TVM in existing project which might have their own array containers.
how can we use TVM in existing project which might have their own array containers.
"""
"""
pass
def
cpu
(
dev_id
=
0
):
def
cpu
(
dev_id
=
0
):
...
...
python/tvm/relay/_parser.py
View file @
e20ef0d4
...
@@ -43,8 +43,8 @@ try:
...
@@ -43,8 +43,8 @@ try:
from
antlr4.tree.Tree
import
TerminalNode
from
antlr4.tree.Tree
import
TerminalNode
except
ImportError
:
except
ImportError
:
raise
ParseError
(
"Couldn't find ANTLR runtime."
+
raise
ParseError
(
"Couldn't find ANTLR runtime."
+
"Try running `pip{
} install antlr4-python{
}-runtime`."
"Try running `pip{
version} install antlr4-python{version
}-runtime`."
.
format
(
PYTHON_VERSION
,
PYTHON_VERSION
))
.
format
(
version
=
PYTHON_VERSION
))
BINARY_OPS
=
{
BINARY_OPS
=
{
RelayParser
.
MUL
:
op
.
multiply
,
RelayParser
.
MUL
:
op
.
multiply
,
...
@@ -179,32 +179,30 @@ class ParseTreeToRelayIR(RelayVisitor):
...
@@ -179,32 +179,30 @@ class ParseTreeToRelayIR(RelayVisitor):
# variables
# variables
if
node_type
==
RelayLexer
.
GLOBAL_VAR
:
if
node_type
==
RelayLexer
.
GLOBAL_VAR
:
return
lookup
(
deque
([
self
.
global_var_scope
]),
node_text
[
1
:])
return
lookup
(
deque
([
self
.
global_var_scope
]),
node_text
[
1
:])
el
if
node_type
==
RelayLexer
.
LOCAL_VAR
:
if
node_type
==
RelayLexer
.
LOCAL_VAR
:
# Remove the leading '%' and lookup the name.
# Remove the leading '%' and lookup the name.
var
=
lookup
(
self
.
var_scopes
,
name
)
var
=
lookup
(
self
.
var_scopes
,
name
)
if
var
is
None
:
if
var
is
None
:
raise
ParseError
(
"Couldn't resolve `{}`."
.
format
(
name
))
raise
ParseError
(
"Couldn't resolve `{}`."
.
format
(
name
))
return
var
return
var
el
if
node_type
==
RelayLexer
.
GRAPH_VAR
:
if
node_type
==
RelayLexer
.
GRAPH_VAR
:
try
:
try
:
return
self
.
graph_expr
[
int
(
name
)]
return
self
.
graph_expr
[
int
(
name
)]
except
IndexError
:
except
IndexError
:
raise
ParseError
(
"Couldn't resolve `{}`"
.
format
(
name
))
raise
ParseError
(
"Couldn't resolve `{}`"
.
format
(
name
))
# data types
# data types
el
if
node_type
==
RelayLexer
.
NAT
:
if
node_type
==
RelayLexer
.
NAT
:
return
int
(
node_text
)
return
int
(
node_text
)
el
if
node_type
==
RelayLexer
.
FLOAT
:
if
node_type
==
RelayLexer
.
FLOAT
:
return
float
(
node_text
)
return
float
(
node_text
)
el
if
node_type
==
RelayLexer
.
BOOL_LIT
:
if
node_type
==
RelayLexer
.
BOOL_LIT
:
if
node_text
==
"True"
:
if
node_text
==
"True"
:
return
True
return
True
el
if
node_text
==
"False"
:
if
node_text
==
"False"
:
return
False
return
False
else
:
raise
ParseError
(
"Unrecognized BOOL_LIT: `{}`"
.
format
(
node_text
))
raise
ParseError
(
"Unrecognized BOOL_LIT: `{}`"
.
format
(
node_text
))
else
:
raise
ParseError
(
"todo: {}"
.
format
(
node_text
))
raise
ParseError
(
"todo: {}"
.
format
(
node_text
))
def
visit_list
(
self
,
ctx_list
):
def
visit_list
(
self
,
ctx_list
):
...
...
python/tvm/relay/adt.py
View file @
e20ef0d4
...
@@ -8,7 +8,7 @@ from .expr import Expr, Call
...
@@ -8,7 +8,7 @@ from .expr import Expr, Call
class
Pattern
(
RelayNode
):
class
Pattern
(
RelayNode
):
"""Base type for pattern matching constructs."""
"""Base type for pattern matching constructs."""
pass
@register_relay_node
@register_relay_node
class
PatternWildcard
(
Pattern
):
class
PatternWildcard
(
Pattern
):
...
...
python/tvm/relay/backend/compile_engine.py
View file @
e20ef0d4
...
@@ -10,7 +10,6 @@ from . import _backend
...
@@ -10,7 +10,6 @@ from . import _backend
class
CachedFunc
(
NodeBase
):
class
CachedFunc
(
NodeBase
):
"""Low-level tensor function to back a relay primitive function.
"""Low-level tensor function to back a relay primitive function.
"""
"""
pass
@register_relay_node
@register_relay_node
...
@@ -34,7 +33,6 @@ class CCacheKey(NodeBase):
...
@@ -34,7 +33,6 @@ class CCacheKey(NodeBase):
class
CCacheValue
(
NodeBase
):
class
CCacheValue
(
NodeBase
):
"""Value in the CompileEngine, including usage statistics.
"""Value in the CompileEngine, including usage statistics.
"""
"""
pass
def
_get_cache_key
(
source_func
,
target
):
def
_get_cache_key
(
source_func
,
target
):
...
...
python/tvm/relay/backend/interpreter.py
View file @
e20ef0d4
...
@@ -49,7 +49,6 @@ class TupleValue(Value):
...
@@ -49,7 +49,6 @@ class TupleValue(Value):
@register_relay_node
@register_relay_node
class
Closure
(
Value
):
class
Closure
(
Value
):
"""A closure produced by the interpreter."""
"""A closure produced by the interpreter."""
pass
@register_relay_node
@register_relay_node
...
...
python/tvm/relay/build_module.py
View file @
e20ef0d4
...
@@ -444,7 +444,6 @@ def create_executor(kind="debug",
...
@@ -444,7 +444,6 @@ def create_executor(kind="debug",
target
=
_target
.
create
(
target
)
target
=
_target
.
create
(
target
)
if
kind
==
"debug"
:
if
kind
==
"debug"
:
return
_interpreter
.
Interpreter
(
mod
,
ctx
,
target
)
return
_interpreter
.
Interpreter
(
mod
,
ctx
,
target
)
el
if
kind
==
"graph"
:
if
kind
==
"graph"
:
return
GraphExecutor
(
mod
,
ctx
,
target
)
return
GraphExecutor
(
mod
,
ctx
,
target
)
else
:
raise
RuntimeError
(
"unknown mode {0}"
.
format
(
mode
))
raise
RuntimeError
(
"unknown mode {0}"
.
format
(
mode
))
python/tvm/relay/frontend/caffe2.py
View file @
e20ef0d4
...
@@ -15,7 +15,6 @@ def dimension_picker(prefix, surfix=''):
...
@@ -15,7 +15,6 @@ def dimension_picker(prefix, surfix=''):
kernel
=
attr
[
'kernel_shape'
]
kernel
=
attr
[
'kernel_shape'
]
if
len
(
kernel
)
==
2
:
if
len
(
kernel
)
==
2
:
return
prefix
+
'2d'
+
surfix
return
prefix
+
'2d'
+
surfix
else
:
raise
NotImplementedError
(
"Only 2d kernel supported."
)
raise
NotImplementedError
(
"Only 2d kernel supported."
)
return
_impl
return
_impl
...
@@ -104,7 +103,6 @@ class Caffe2OpConverter(object):
...
@@ -104,7 +103,6 @@ class Caffe2OpConverter(object):
if
hasattr
(
cls
,
'_impl'
):
if
hasattr
(
cls
,
'_impl'
):
return
getattr
(
cls
,
'_impl'
)
return
getattr
(
cls
,
'_impl'
)
else
:
raise
NotImplementedError
(
'{} not implemented'
.
format
(
raise
NotImplementedError
(
'{} not implemented'
.
format
(
cls
.
__name__
))
cls
.
__name__
))
...
@@ -234,9 +232,8 @@ class Concat(Caffe2OpConverter):
...
@@ -234,9 +232,8 @@ class Concat(Caffe2OpConverter):
order
=
order
if
isinstance
(
order
,
str
)
else
order
.
decode
(
'UTF-8'
)
order
=
order
if
isinstance
(
order
,
str
)
else
order
.
decode
(
'UTF-8'
)
if
order
==
'NCHW'
:
if
order
==
'NCHW'
:
return
1
return
1
el
if
order
==
'NHWC'
:
if
order
==
'NHWC'
:
return
3
return
3
else
:
raise
RuntimeError
(
raise
RuntimeError
(
"Unsupported storage order: {} in caffe2"
.
format
(
order
))
"Unsupported storage order: {} in caffe2"
.
format
(
order
))
...
...
python/tvm/relay/frontend/common.py
View file @
e20ef0d4
...
@@ -10,7 +10,6 @@ from .. import op as _op
...
@@ -10,7 +10,6 @@ from .. import op as _op
class
RequiredAttr
(
object
):
class
RequiredAttr
(
object
):
"""Dummpy class to represent required attr"""
"""Dummpy class to represent required attr"""
pass
class
StrAttrsDict
(
object
):
class
StrAttrsDict
(
object
):
...
...
python/tvm/relay/frontend/coreml.py
View file @
e20ef0d4
...
@@ -100,37 +100,37 @@ def _ActivationParams(op, inexpr, etab):
...
@@ -100,37 +100,37 @@ def _ActivationParams(op, inexpr, etab):
alpha
=
_expr
.
const
(
par
.
alpha
,
dtype
=
'float32'
)
alpha
=
_expr
.
const
(
par
.
alpha
,
dtype
=
'float32'
)
beta
=
_expr
.
const
(
par
.
beta
,
dtype
=
'float32'
)
beta
=
_expr
.
const
(
par
.
beta
,
dtype
=
'float32'
)
return
_op
.
add
(
_op
.
multiply
(
inexpr
,
alpha
),
beta
)
return
_op
.
add
(
_op
.
multiply
(
inexpr
,
alpha
),
beta
)
el
if
whichActivation
==
'ReLU'
:
if
whichActivation
==
'ReLU'
:
return
_op
.
nn
.
relu
(
inexpr
)
return
_op
.
nn
.
relu
(
inexpr
)
el
if
whichActivation
==
'leakyReLU'
:
if
whichActivation
==
'leakyReLU'
:
_op
.
nn
.
leaky_relu
(
inexpr
,
alpha
=
_expr
.
const
(
par
.
alpha
,
dtype
=
'float32'
))
_op
.
nn
.
leaky_relu
(
inexpr
,
alpha
=
_expr
.
const
(
par
.
alpha
,
dtype
=
'float32'
))
elif
whichActivation
==
'thresholdedReLU'
:
elif
whichActivation
==
'thresholdedReLU'
:
alpha_tensor
=
_op
.
full_like
(
inexpr
,
fill_value
=
_expr
.
const
(
par
.
alpha
,
dtype
=
'float32'
))
alpha_tensor
=
_op
.
full_like
(
inexpr
,
fill_value
=
_expr
.
const
(
par
.
alpha
,
dtype
=
'float32'
))
return
_op
.
multiply
(
inexpr
,
_op
.
greater
(
inexpr
,
alpha_tensor
)
.
as_type
(
'float32'
))
return
_op
.
multiply
(
inexpr
,
_op
.
greater
(
inexpr
,
alpha_tensor
)
.
as_type
(
'float32'
))
el
if
whichActivation
==
'PReLU'
:
if
whichActivation
==
'PReLU'
:
return
_op
.
nn
.
prelu
(
inexpr
,
alpha
=
_expr
.
const
(
par
.
alpha
,
dtype
=
'float32'
))
return
_op
.
nn
.
prelu
(
inexpr
,
alpha
=
_expr
.
const
(
par
.
alpha
,
dtype
=
'float32'
))
el
if
whichActivation
==
'tanh'
:
if
whichActivation
==
'tanh'
:
return
_op
.
tanh
(
inexpr
)
return
_op
.
tanh
(
inexpr
)
el
if
whichActivation
==
'scaledTanh'
:
if
whichActivation
==
'scaledTanh'
:
alpha
=
_expr
.
const
(
par
.
alpha
,
dtype
=
'float32'
)
alpha
=
_expr
.
const
(
par
.
alpha
,
dtype
=
'float32'
)
beta
=
_expr
.
const
(
par
.
beta
,
dtype
=
'float32'
)
beta
=
_expr
.
const
(
par
.
beta
,
dtype
=
'float32'
)
return
_op
.
multiply
(
_op
.
tanh
(
_op
.
multiply
(
inexpr
,
beta
)),
alpha
)
return
_op
.
multiply
(
_op
.
tanh
(
_op
.
multiply
(
inexpr
,
beta
)),
alpha
)
el
if
whichActivation
==
'sigmoid'
:
if
whichActivation
==
'sigmoid'
:
return
_op
.
sigmoid
(
inexpr
)
return
_op
.
sigmoid
(
inexpr
)
el
if
whichActivation
==
'sigmoidHard'
:
if
whichActivation
==
'sigmoidHard'
:
alpha
=
_expr
.
const
(
par
.
alpha
,
dtype
=
'float32'
)
alpha
=
_expr
.
const
(
par
.
alpha
,
dtype
=
'float32'
)
beta
=
_expr
.
const
(
par
.
beta
,
dtype
=
'float32'
)
beta
=
_expr
.
const
(
par
.
beta
,
dtype
=
'float32'
)
transformX
=
(
alpha
*
inexpr
)
+
beta
transformX
=
(
alpha
*
inexpr
)
+
beta
return
_op
.
clip
(
transformX
,
a_min
=
0.
,
a_max
=
1.
)
return
_op
.
clip
(
transformX
,
a_min
=
0.
,
a_max
=
1.
)
el
if
whichActivation
==
'ELU'
:
if
whichActivation
==
'ELU'
:
return
_op
.
multiply
(
_op
.
add
(
_op
.
exp
(
inexpr
),
_expr
.
const
(
-
1
,
dtype
=
'float32'
)),
return
_op
.
multiply
(
_op
.
add
(
_op
.
exp
(
inexpr
),
_expr
.
const
(
-
1
,
dtype
=
'float32'
)),
_expr
.
const
(
par
.
alpha
,
dtype
=
'float32'
))
_expr
.
const
(
par
.
alpha
,
dtype
=
'float32'
))
el
if
whichActivation
==
'softsign'
:
if
whichActivation
==
'softsign'
:
return
inexpr
/
(
_expr
.
const
(
1
,
dtype
=
'float32'
)
+
(
return
inexpr
/
(
_expr
.
const
(
1
,
dtype
=
'float32'
)
+
(
op
.
nn
.
relu
(
inexpr
)
+
_op
.
nn
.
relu
(
_op
.
negative
(
inexpr
))))
op
.
nn
.
relu
(
inexpr
)
+
_op
.
nn
.
relu
(
_op
.
negative
(
inexpr
))))
el
if
whichActivation
==
'softplus'
:
if
whichActivation
==
'softplus'
:
return
_op
.
log
(
_op
.
add
(
_op
.
exp
(
inexpr
),
_expr
.
const
(
1
,
dtype
=
'float32'
)))
return
_op
.
log
(
_op
.
add
(
_op
.
exp
(
inexpr
),
_expr
.
const
(
1
,
dtype
=
'float32'
)))
el
if
whichActivation
==
'parametricSoftplus'
:
if
whichActivation
==
'parametricSoftplus'
:
alpha
=
list
(
par
.
alpha
.
floatValue
)
alpha
=
list
(
par
.
alpha
.
floatValue
)
beta
=
list
(
par
.
alpha
.
floatValue
)
beta
=
list
(
par
.
alpha
.
floatValue
)
if
len
(
alpha
)
==
1
:
if
len
(
alpha
)
==
1
:
...
@@ -142,7 +142,6 @@ def _ActivationParams(op, inexpr, etab):
...
@@ -142,7 +142,6 @@ def _ActivationParams(op, inexpr, etab):
alpha_expr
=
etab
.
new_const
(
alpha
)
alpha_expr
=
etab
.
new_const
(
alpha
)
beta_expr
=
etab
.
new_const
(
beta
)
beta_expr
=
etab
.
new_const
(
beta
)
return
_op
.
multiply
(
_op
.
log
(
_op
.
add
(
_op
.
exp
(
inexpr
),
beta_expr
)),
alpha_expr
)
return
_op
.
multiply
(
_op
.
log
(
_op
.
add
(
_op
.
exp
(
inexpr
),
beta_expr
)),
alpha_expr
)
else
:
raise
NotImplementedError
(
'
%
s not implemented'
%
whichActivation
)
raise
NotImplementedError
(
'
%
s not implemented'
%
whichActivation
)
...
@@ -163,9 +162,8 @@ def _PoolingLayerParams(op, inexpr, etab):
...
@@ -163,9 +162,8 @@ def _PoolingLayerParams(op, inexpr, etab):
if
op
.
globalPooling
:
if
op
.
globalPooling
:
if
op
.
type
==
0
:
if
op
.
type
==
0
:
return
_op
.
nn
.
global_max_pool2d
(
inexpr
)
return
_op
.
nn
.
global_max_pool2d
(
inexpr
)
el
if
op
.
type
==
1
:
if
op
.
type
==
1
:
return
_op
.
nn
.
global_avg_pool2d
(
inexpr
)
return
_op
.
nn
.
global_avg_pool2d
(
inexpr
)
else
:
raise
NotImplementedError
(
"Only max and average pooling implemented"
)
raise
NotImplementedError
(
"Only max and average pooling implemented"
)
else
:
else
:
...
@@ -196,9 +194,8 @@ def _PoolingLayerParams(op, inexpr, etab):
...
@@ -196,9 +194,8 @@ def _PoolingLayerParams(op, inexpr, etab):
if
op
.
type
==
0
:
if
op
.
type
==
0
:
return
_op
.
nn
.
max_pool2d
(
inexpr
,
**
params
)
return
_op
.
nn
.
max_pool2d
(
inexpr
,
**
params
)
el
if
op
.
type
==
1
:
if
op
.
type
==
1
:
return
_op
.
nn
.
avg_pool2d
(
inexpr
,
**
params
)
return
_op
.
nn
.
avg_pool2d
(
inexpr
,
**
params
)
else
:
raise
NotImplementedError
(
"Only max and average pooling implemented"
)
raise
NotImplementedError
(
"Only max and average pooling implemented"
)
...
...
python/tvm/relay/frontend/keras.py
View file @
e20ef0d4
...
@@ -60,21 +60,21 @@ def _convert_activation(inexpr, keras_layer, _):
...
@@ -60,21 +60,21 @@ def _convert_activation(inexpr, keras_layer, _):
alpha
=
_expr
.
const
(
alpha
,
dtype
=
'float32'
)
alpha
=
_expr
.
const
(
alpha
,
dtype
=
'float32'
)
beta
=
_expr
.
const
(
beta
,
dtype
=
'float32'
)
beta
=
_expr
.
const
(
beta
,
dtype
=
'float32'
)
return
_op
.
add
(
_op
.
multiply
(
inexpr
,
alpha
),
beta
)
return
_op
.
add
(
_op
.
multiply
(
inexpr
,
alpha
),
beta
)
el
if
act_type
==
'softmax'
:
if
act_type
==
'softmax'
:
return
_op
.
nn
.
softmax
(
inexpr
,
axis
=
1
)
return
_op
.
nn
.
softmax
(
inexpr
,
axis
=
1
)
el
if
act_type
==
'sigmoid'
:
if
act_type
==
'sigmoid'
:
return
_op
.
sigmoid
(
inexpr
)
return
_op
.
sigmoid
(
inexpr
)
el
if
act_type
==
'tanh'
:
if
act_type
==
'tanh'
:
return
_op
.
tanh
(
inexpr
)
return
_op
.
tanh
(
inexpr
)
el
if
act_type
==
'relu'
:
if
act_type
==
'relu'
:
return
_op
.
nn
.
relu
(
inexpr
)
return
_op
.
nn
.
relu
(
inexpr
)
el
if
act_type
==
'softplus'
:
if
act_type
==
'softplus'
:
return
_op
.
log
(
_op
.
add
(
_op
.
exp
(
inexpr
),
_expr
.
const
(
1.
,
dtype
=
'float32'
)))
return
_op
.
log
(
_op
.
add
(
_op
.
exp
(
inexpr
),
_expr
.
const
(
1.
,
dtype
=
'float32'
)))
el
if
act_type
==
'elu'
:
if
act_type
==
'elu'
:
alpha
=
keras_layer
.
alpha
if
hasattr
(
keras_layer
,
'alpha'
)
else
1.
alpha
=
keras_layer
.
alpha
if
hasattr
(
keras_layer
,
'alpha'
)
else
1.
alpha
=
_expr
.
const
(
alpha
,
dtype
=
'float32'
)
alpha
=
_expr
.
const
(
alpha
,
dtype
=
'float32'
)
return
_get_elu
(
inexpr
,
alpha
)
return
_get_elu
(
inexpr
,
alpha
)
el
if
act_type
==
'selu'
:
if
act_type
==
'selu'
:
# Alpha, Gamma values obtained from https://arxiv.org/abs/1706.02515
# Alpha, Gamma values obtained from https://arxiv.org/abs/1706.02515
alpha
=
keras_layer
.
alpha
if
hasattr
(
keras_layer
,
'alpha'
)
\
alpha
=
keras_layer
.
alpha
if
hasattr
(
keras_layer
,
'alpha'
)
\
else
1.6732632423543772848170429916717
else
1.6732632423543772848170429916717
...
@@ -83,14 +83,14 @@ def _convert_activation(inexpr, keras_layer, _):
...
@@ -83,14 +83,14 @@ def _convert_activation(inexpr, keras_layer, _):
alpha
=
_expr
.
const
(
alpha
,
dtype
=
'float32'
)
alpha
=
_expr
.
const
(
alpha
,
dtype
=
'float32'
)
gamma
=
_expr
.
const
(
gamma
,
dtype
=
'float32'
)
gamma
=
_expr
.
const
(
gamma
,
dtype
=
'float32'
)
return
gamma
*
_get_elu
(
inexpr
,
alpha
)
return
gamma
*
_get_elu
(
inexpr
,
alpha
)
el
if
act_type
==
'relu6'
:
if
act_type
==
'relu6'
:
return
_op
.
clip
(
inexpr
,
a_min
=
0.
,
a_max
=
6.
)
return
_op
.
clip
(
inexpr
,
a_min
=
0.
,
a_max
=
6.
)
el
if
act_type
==
'softsign'
:
if
act_type
==
'softsign'
:
return
inexpr
/
(
_expr
.
const
(
1.
,
dtype
=
'float32'
)
+
_op
.
abs
(
inexpr
))
return
inexpr
/
(
_expr
.
const
(
1.
,
dtype
=
'float32'
)
+
_op
.
abs
(
inexpr
))
el
if
act_type
==
'hard_sigmoid'
:
if
act_type
==
'hard_sigmoid'
:
x
=
(
_expr
.
const
(
0.2
,
dtype
=
'float32'
)
*
inexpr
)
+
_expr
.
const
(
0.5
,
dtype
=
'float32'
)
x
=
(
_expr
.
const
(
0.2
,
dtype
=
'float32'
)
*
inexpr
)
+
_expr
.
const
(
0.5
,
dtype
=
'float32'
)
return
_op
.
clip
(
x
,
a_min
=
0.
,
a_max
=
1.
)
return
_op
.
clip
(
x
,
a_min
=
0.
,
a_max
=
1.
)
else
:
raise
TypeError
(
"Unsupported activation type : {}"
.
format
(
act_type
))
raise
TypeError
(
"Unsupported activation type : {}"
.
format
(
act_type
))
...
@@ -100,24 +100,24 @@ def _convert_advanced_activation(inexpr, keras_layer, etab):
...
@@ -100,24 +100,24 @@ def _convert_advanced_activation(inexpr, keras_layer, etab):
if
keras_layer
.
max_value
:
if
keras_layer
.
max_value
:
return
_op
.
clip
(
inexpr
,
a_min
=
0.
,
a_max
=
float
(
keras_layer
.
max_value
))
return
_op
.
clip
(
inexpr
,
a_min
=
0.
,
a_max
=
float
(
keras_layer
.
max_value
))
return
_op
.
nn
.
relu
(
inexpr
)
return
_op
.
nn
.
relu
(
inexpr
)
el
if
act_type
==
'LeakyReLU'
:
if
act_type
==
'LeakyReLU'
:
return
_op
.
nn
.
leaky_relu
(
inexpr
,
alpha
=
float
(
keras_layer
.
alpha
))
return
_op
.
nn
.
leaky_relu
(
inexpr
,
alpha
=
float
(
keras_layer
.
alpha
))
el
if
act_type
==
'ELU'
:
if
act_type
==
'ELU'
:
alpha
=
keras_layer
.
alpha
if
hasattr
(
keras_layer
,
'alpha'
)
else
1.
alpha
=
keras_layer
.
alpha
if
hasattr
(
keras_layer
,
'alpha'
)
else
1.
alpha
=
_expr
.
const
(
alpha
,
dtype
=
'float32'
)
alpha
=
_expr
.
const
(
alpha
,
dtype
=
'float32'
)
return
_get_elu
(
inexpr
,
alpha
)
return
_get_elu
(
inexpr
,
alpha
)
el
if
act_type
==
'PReLU'
:
if
act_type
==
'PReLU'
:
assert
hasattr
(
keras_layer
,
'alpha'
),
"alpha required for PReLU."
assert
hasattr
(
keras_layer
,
'alpha'
),
"alpha required for PReLU."
_check_data_format
(
keras_layer
)
_check_data_format
(
keras_layer
)
size
=
len
(
keras_layer
.
alpha
.
shape
)
size
=
len
(
keras_layer
.
alpha
.
shape
)
alpha
=
etab
.
new_const
(
keras_layer
.
get_weights
()[
0
]
\
alpha
=
etab
.
new_const
(
keras_layer
.
get_weights
()[
0
]
\
.
transpose
(
np
.
roll
(
range
(
size
),
1
)))
.
transpose
(
np
.
roll
(
range
(
size
),
1
)))
return
_op
.
negative
(
alpha
)
*
_op
.
nn
.
relu
(
_op
.
negative
(
inexpr
))
+
_op
.
nn
.
relu
(
inexpr
)
return
_op
.
negative
(
alpha
)
*
_op
.
nn
.
relu
(
_op
.
negative
(
inexpr
))
+
_op
.
nn
.
relu
(
inexpr
)
el
if
act_type
==
'ThresholdedReLU'
:
if
act_type
==
'ThresholdedReLU'
:
theta
=
keras_layer
.
theta
if
hasattr
(
keras_layer
,
'theta'
)
else
1.
theta
=
keras_layer
.
theta
if
hasattr
(
keras_layer
,
'theta'
)
else
1.
return
_op
.
multiply
(
inexpr
,
_op
.
greater
(
inexpr
,
\
return
_op
.
multiply
(
inexpr
,
_op
.
greater
(
inexpr
,
\
_expr
.
const
(
theta
,
dtype
=
'float32'
))
.
astype
(
'float32'
))
_expr
.
const
(
theta
,
dtype
=
'float32'
))
.
astype
(
'float32'
))
else
:
raise
TypeError
(
"Unsupported advanced activation type : {}"
.
format
(
act_type
))
raise
TypeError
(
"Unsupported advanced activation type : {}"
.
format
(
act_type
))
...
@@ -297,9 +297,8 @@ def _convert_pooling(inexpr, keras_layer, etab):
...
@@ -297,9 +297,8 @@ def _convert_pooling(inexpr, keras_layer, etab):
# global pool in keras = global pool + flatten in nnvm/relay
# global pool in keras = global pool + flatten in nnvm/relay
if
pool_type
==
'GlobalMaxPooling2D'
:
if
pool_type
==
'GlobalMaxPooling2D'
:
return
_convert_flatten
(
_op
.
nn
.
global_max_pool2d
(
inexpr
),
keras_layer
,
etab
)
return
_convert_flatten
(
_op
.
nn
.
global_max_pool2d
(
inexpr
),
keras_layer
,
etab
)
el
if
pool_type
==
'GlobalAveragePooling2D'
:
if
pool_type
==
'GlobalAveragePooling2D'
:
return
_convert_flatten
(
_op
.
nn
.
global_avg_pool2d
(
inexpr
),
keras_layer
,
etab
)
return
_convert_flatten
(
_op
.
nn
.
global_avg_pool2d
(
inexpr
),
keras_layer
,
etab
)
else
:
pool_h
,
pool_w
=
keras_layer
.
pool_size
pool_h
,
pool_w
=
keras_layer
.
pool_size
stride_h
,
stride_w
=
keras_layer
.
strides
stride_h
,
stride_w
=
keras_layer
.
strides
params
=
{
'pool_size'
:
[
pool_h
,
pool_w
],
params
=
{
'pool_size'
:
[
pool_h
,
pool_w
],
...
@@ -317,10 +316,9 @@ def _convert_pooling(inexpr, keras_layer, etab):
...
@@ -317,10 +316,9 @@ def _convert_pooling(inexpr, keras_layer, etab):
raise
TypeError
(
"Unsupported padding type : {}"
.
format
(
keras_layer
.
padding
))
raise
TypeError
(
"Unsupported padding type : {}"
.
format
(
keras_layer
.
padding
))
if
pool_type
==
'MaxPooling2D'
:
if
pool_type
==
'MaxPooling2D'
:
return
_op
.
nn
.
max_pool2d
(
inexpr
,
**
params
)
return
_op
.
nn
.
max_pool2d
(
inexpr
,
**
params
)
el
if
pool_type
==
'AveragePooling2D'
:
if
pool_type
==
'AveragePooling2D'
:
params
[
'count_include_pad'
]
=
False
params
[
'count_include_pad'
]
=
False
return
_op
.
nn
.
avg_pool2d
(
inexpr
,
**
params
)
return
_op
.
nn
.
avg_pool2d
(
inexpr
,
**
params
)
else
:
raise
TypeError
(
"Unsupported pooling type : {}"
.
format
(
keras_layer
))
raise
TypeError
(
"Unsupported pooling type : {}"
.
format
(
keras_layer
))
...
...
python/tvm/relay/frontend/mxnet.py
View file @
e20ef0d4
...
@@ -39,7 +39,7 @@ def _mx_fully_connected(inputs, attrs):
...
@@ -39,7 +39,7 @@ def _mx_fully_connected(inputs, attrs):
def
_get_channel_axis
(
layout
,
op_name
):
def
_get_channel_axis
(
layout
,
op_name
):
if
layout
==
"NCHW"
:
if
layout
==
"NCHW"
:
return
1
return
1
el
if
layout
==
"NHWC"
:
if
layout
==
"NHWC"
:
return
3
return
3
raise
RuntimeError
(
"layout: {} is not supported in {}"
.
format
(
layout
,
op_name
))
raise
RuntimeError
(
"layout: {} is not supported in {}"
.
format
(
layout
,
op_name
))
...
@@ -49,11 +49,11 @@ def _mx_activations(inputs, attrs):
...
@@ -49,11 +49,11 @@ def _mx_activations(inputs, attrs):
assert
len
(
inputs
)
==
1
assert
len
(
inputs
)
==
1
if
act_type
==
"sigmoid"
:
if
act_type
==
"sigmoid"
:
return
_op
.
sigmoid
(
inputs
[
0
])
return
_op
.
sigmoid
(
inputs
[
0
])
el
if
act_type
==
"tanh"
:
if
act_type
==
"tanh"
:
return
_op
.
tanh
(
inputs
[
0
])
return
_op
.
tanh
(
inputs
[
0
])
el
if
act_type
==
"relu"
:
if
act_type
==
"relu"
:
return
_op
.
nn
.
relu
(
inputs
[
0
])
return
_op
.
nn
.
relu
(
inputs
[
0
])
el
if
act_type
==
"softrelu"
:
if
act_type
==
"softrelu"
:
def
_stable_softrelu
(
x
):
def
_stable_softrelu
(
x
):
# log(1 + exp(-abs(x))) + relu(x)
# log(1 + exp(-abs(x))) + relu(x)
one
=
_expr
.
const
(
1
,
dtype
=
"float32"
)
one
=
_expr
.
const
(
1
,
dtype
=
"float32"
)
...
@@ -147,7 +147,7 @@ def _mx_pooling(inputs, attrs):
...
@@ -147,7 +147,7 @@ def _mx_pooling(inputs, attrs):
if
global_pool
:
if
global_pool
:
return
_op
.
nn
.
global_max_pool2d
(
inputs
[
0
])
return
_op
.
nn
.
global_max_pool2d
(
inputs
[
0
])
return
_pool2d
(
_op
.
nn
.
max_pool2d
,
False
)
return
_pool2d
(
_op
.
nn
.
max_pool2d
,
False
)
el
if
pool_type
==
"avg"
:
if
pool_type
==
"avg"
:
if
global_pool
:
if
global_pool
:
return
_op
.
nn
.
global_avg_pool2d
(
inputs
[
0
])
return
_op
.
nn
.
global_avg_pool2d
(
inputs
[
0
])
return
_pool2d
(
_op
.
nn
.
avg_pool2d
,
True
)
return
_pool2d
(
_op
.
nn
.
avg_pool2d
,
True
)
...
@@ -209,10 +209,10 @@ def _mx_leaky_relu(inputs, attrs):
...
@@ -209,10 +209,10 @@ def _mx_leaky_relu(inputs, attrs):
act_type
=
attrs
.
get_str
(
"act_type"
)
act_type
=
attrs
.
get_str
(
"act_type"
)
if
act_type
==
"leaky"
:
if
act_type
==
"leaky"
:
return
_op
.
nn
.
leaky_relu
(
inputs
[
0
],
alpha
=
attrs
.
get_float
(
"slope"
,
0.25
))
return
_op
.
nn
.
leaky_relu
(
inputs
[
0
],
alpha
=
attrs
.
get_float
(
"slope"
,
0.25
))
el
if
act_type
==
"prelu"
:
if
act_type
==
"prelu"
:
assert
len
(
inputs
)
==
2
assert
len
(
inputs
)
==
2
return
_op
.
nn
.
prelu
(
*
inputs
)
return
_op
.
nn
.
prelu
(
*
inputs
)
el
if
act_type
==
"elu"
:
if
act_type
==
"elu"
:
# -slope * relu(1-exp(x)) + relu(x)
# -slope * relu(1-exp(x)) + relu(x)
slope
=
attrs
.
get_float
(
"slope"
,
0.25
)
slope
=
attrs
.
get_float
(
"slope"
,
0.25
)
one
=
_expr
.
const
(
1
,
dtype
=
"float32"
)
one
=
_expr
.
const
(
1
,
dtype
=
"float32"
)
...
@@ -220,7 +220,7 @@ def _mx_leaky_relu(inputs, attrs):
...
@@ -220,7 +220,7 @@ def _mx_leaky_relu(inputs, attrs):
mslope
=
_op
.
nn
.
relu
(
_op
.
subtract
(
one
,
_op
.
exp
(
x
)))
mslope
=
_op
.
nn
.
relu
(
_op
.
subtract
(
one
,
_op
.
exp
(
x
)))
mslope
=
_op
.
multiply
(
mslope
,
_expr
.
const
(
-
slope
,
dtype
=
"float32"
))
mslope
=
_op
.
multiply
(
mslope
,
_expr
.
const
(
-
slope
,
dtype
=
"float32"
))
return
_op
.
add
(
mslope
,
_op
.
nn
.
relu
(
x
))
return
_op
.
add
(
mslope
,
_op
.
nn
.
relu
(
x
))
el
if
act_type
==
"rrelu"
:
if
act_type
==
"rrelu"
:
# NOTE this is only converted for inference.
# NOTE this is only converted for inference.
lower_bound
=
attrs
.
get_float
(
"lower_bound"
)
lower_bound
=
attrs
.
get_float
(
"lower_bound"
)
upper_bound
=
attrs
.
get_float
(
"upper_bound"
)
upper_bound
=
attrs
.
get_float
(
"upper_bound"
)
...
...
python/tvm/relay/frontend/onnx.py
View file @
e20ef0d4
...
@@ -18,7 +18,6 @@ def dimension_picker(prefix, surfix=''):
...
@@ -18,7 +18,6 @@ def dimension_picker(prefix, surfix=''):
kernel
=
attr
[
'kernel_shape'
]
kernel
=
attr
[
'kernel_shape'
]
if
len
(
kernel
)
==
2
:
if
len
(
kernel
)
==
2
:
return
prefix
+
'2d'
+
surfix
return
prefix
+
'2d'
+
surfix
else
:
raise
NotImplementedError
(
"Only 2d kernel supported."
)
raise
NotImplementedError
(
"Only 2d kernel supported."
)
return
_impl
return
_impl
...
...
python/tvm/relay/frontend/tensorflow.py
View file @
e20ef0d4
...
@@ -175,7 +175,6 @@ def _dimension_picker(prefix, surfix=''):
...
@@ -175,7 +175,6 @@ def _dimension_picker(prefix, surfix=''):
kernel
=
attr
[
'kernel_shape'
]
kernel
=
attr
[
'kernel_shape'
]
if
len
(
kernel
)
==
2
:
if
len
(
kernel
)
==
2
:
return
prefix
+
'2d'
+
surfix
return
prefix
+
'2d'
+
surfix
else
:
raise
NotImplementedError
(
"Only 2d kernel supported."
)
raise
NotImplementedError
(
"Only 2d kernel supported."
)
return
_impl
return
_impl
...
@@ -522,7 +521,6 @@ def _reshape():
...
@@ -522,7 +521,6 @@ def _reshape():
op_name
=
"reshape"
,
op_name
=
"reshape"
,
extras
=
{
'newshape'
:
tuple
(
params_new
.
asnumpy
()
.
flatten
())},
extras
=
{
'newshape'
:
tuple
(
params_new
.
asnumpy
()
.
flatten
())},
ignores
=
[
'Tshape'
])(
inputs
,
attr
)
ignores
=
[
'Tshape'
])(
inputs
,
attr
)
else
:
raise
RuntimeError
(
"Reshape with dynamic shape input not supported yet."
)
raise
RuntimeError
(
"Reshape with dynamic shape input not supported yet."
)
return
_impl
return
_impl
...
@@ -1385,7 +1383,7 @@ class GraphProto(object):
...
@@ -1385,7 +1383,7 @@ class GraphProto(object):
shape
=
self
.
_params
[
name
]
.
shape
,
shape
=
self
.
_params
[
name
]
.
shape
,
dtype
=
self
.
_params
[
name
]
.
dtype
)]
dtype
=
self
.
_params
[
name
]
.
dtype
)]
else
:
else
:
if
key
!=
'dtype'
and
key
!=
'_output_shapes'
and
key
!=
'_class'
:
if
key
not
in
(
'dtype'
,
'_output_shapes'
,
'_class'
)
:
raise
NotImplementedError
\
raise
NotImplementedError
\
(
"Other attributes for a Const(param) Node {} ? ."
.
format
(
key
))
(
"Other attributes for a Const(param) Node {} ? ."
.
format
(
key
))
...
...
python/tvm/relay/frontend/tflite.py
View file @
e20ef0d4
...
@@ -126,13 +126,12 @@ class OperatorConverter(object):
...
@@ -126,13 +126,12 @@ class OperatorConverter(object):
if
tensor_wrapper
.
tensor
.
Type
()
==
TensorType
.
UINT8
:
if
tensor_wrapper
.
tensor
.
Type
()
==
TensorType
.
UINT8
:
return
np
.
frombuffer
(
tensor_wrapper
.
buffer
.
DataAsNumpy
(),
dtype
=
np
.
uint8
)
.
reshape
(
return
np
.
frombuffer
(
tensor_wrapper
.
buffer
.
DataAsNumpy
(),
dtype
=
np
.
uint8
)
.
reshape
(
tensor_wrapper
.
tensor
.
ShapeAsNumpy
())
tensor_wrapper
.
tensor
.
ShapeAsNumpy
())
el
if
tensor_wrapper
.
tensor
.
Type
()
==
TensorType
.
FLOAT32
:
if
tensor_wrapper
.
tensor
.
Type
()
==
TensorType
.
FLOAT32
:
return
np
.
frombuffer
(
tensor_wrapper
.
buffer
.
DataAsNumpy
(),
dtype
=
np
.
float32
)
.
reshape
(
return
np
.
frombuffer
(
tensor_wrapper
.
buffer
.
DataAsNumpy
(),
dtype
=
np
.
float32
)
.
reshape
(
tensor_wrapper
.
tensor
.
ShapeAsNumpy
())
tensor_wrapper
.
tensor
.
ShapeAsNumpy
())
el
if
tensor_wrapper
.
tensor
.
Type
()
==
TensorType
.
INT32
:
if
tensor_wrapper
.
tensor
.
Type
()
==
TensorType
.
INT32
:
return
np
.
frombuffer
(
tensor_wrapper
.
buffer
.
DataAsNumpy
(),
dtype
=
np
.
int32
)
.
reshape
(
return
np
.
frombuffer
(
tensor_wrapper
.
buffer
.
DataAsNumpy
(),
dtype
=
np
.
int32
)
.
reshape
(
tensor_wrapper
.
tensor
.
ShapeAsNumpy
())
tensor_wrapper
.
tensor
.
ShapeAsNumpy
())
else
:
raise
NotImplementedError
(
"Not support tensor type {}"
raise
NotImplementedError
(
"Not support tensor type {}"
.
format
(
str
(
tensor_wrapper
.
tensor
.
Type
())))
.
format
(
str
(
tensor_wrapper
.
tensor
.
Type
())))
...
@@ -145,11 +144,10 @@ class OperatorConverter(object):
...
@@ -145,11 +144,10 @@ class OperatorConverter(object):
if
tensor_type
==
TensorType
.
UINT8
:
if
tensor_type
==
TensorType
.
UINT8
:
return
"uint8"
return
"uint8"
el
if
tensor_type
==
TensorType
.
FLOAT32
:
if
tensor_type
==
TensorType
.
FLOAT32
:
return
"float32"
return
"float32"
el
if
tensor_type
==
TensorType
.
INT32
:
if
tensor_type
==
TensorType
.
INT32
:
return
"int32"
return
"int32"
else
:
raise
NotImplementedError
(
"Not support tensor type {}"
.
format
(
str
(
tensor_type
)))
raise
NotImplementedError
(
"Not support tensor type {}"
.
format
(
str
(
tensor_type
)))
def
convert_conv2d
(
self
,
op
):
def
convert_conv2d
(
self
,
op
):
...
@@ -192,7 +190,7 @@ class OperatorConverter(object):
...
@@ -192,7 +190,7 @@ class OperatorConverter(object):
in_expr
=
self
.
get_expr
(
input_tensor_idx
)
in_expr
=
self
.
get_expr
(
input_tensor_idx
)
if
input_shape_length
==
1
or
input_shape_length
==
2
:
if
input_shape_length
in
(
1
,
2
)
:
# The rule is channel first (after N but before H, W).
# The rule is channel first (after N but before H, W).
# length of 1 means N*H*W*C, do nothing.
# length of 1 means N*H*W*C, do nothing.
# length of 2 means N*H*W, C, do nothing.
# length of 2 means N*H*W, C, do nothing.
...
@@ -275,7 +273,7 @@ class OperatorConverter(object):
...
@@ -275,7 +273,7 @@ class OperatorConverter(object):
in_expr
=
self
.
get_expr
(
input_tensor_idx
)
in_expr
=
self
.
get_expr
(
input_tensor_idx
)
# TFLite is N H W C, our layout is N C H W
# TFLite is N H W C, our layout is N C H W
if
input_shape_length
==
1
or
input_shape_length
==
2
:
if
input_shape_length
in
(
1
,
2
)
:
# The rule is channel first (after N but before H, W).
# The rule is channel first (after N but before H, W).
# length of 1 means N*H*W*C, do nothing.
# length of 1 means N*H*W*C, do nothing.
# length of 2 means N*H*W, C, do nothing.
# length of 2 means N*H*W, C, do nothing.
...
@@ -299,7 +297,7 @@ class OperatorConverter(object):
...
@@ -299,7 +297,7 @@ class OperatorConverter(object):
# 3: N H W C, reshape to N H*W C, transpose to N C H*W
# 3: N H W C, reshape to N H*W C, transpose to N C H*W
# 4: N H W C, transpose to N C H W
# 4: N H W C, transpose to N C H W
# add more if we need target shapes in future
# add more if we need target shapes in future
if
output_shape_length
==
1
or
output_shape_length
==
2
:
if
output_shape_length
in
(
1
,
2
)
:
pass
pass
elif
output_shape_length
==
3
:
elif
output_shape_length
==
3
:
out
=
_op
.
transpose
(
out
,
axes
=
(
0
,
2
,
1
))
out
=
_op
.
transpose
(
out
,
axes
=
(
0
,
2
,
1
))
...
@@ -320,13 +318,12 @@ class OperatorConverter(object):
...
@@ -320,13 +318,12 @@ class OperatorConverter(object):
assert
fused_activation_fn
!=
ActivationFunctionType
.
NONE
assert
fused_activation_fn
!=
ActivationFunctionType
.
NONE
if
fused_activation_fn
==
ActivationFunctionType
.
RELU6
:
if
fused_activation_fn
==
ActivationFunctionType
.
RELU6
:
return
_op
.
clip
(
in_expr
,
a_min
=
0
,
a_max
=
6
)
return
_op
.
clip
(
in_expr
,
a_min
=
0
,
a_max
=
6
)
el
if
fused_activation_fn
==
ActivationFunctionType
.
RELU
:
if
fused_activation_fn
==
ActivationFunctionType
.
RELU
:
return
_op
.
nn
.
relu
(
in_expr
)
return
_op
.
nn
.
relu
(
in_expr
)
el
if
fused_activation_fn
==
ActivationFunctionType
.
RELU_N1_TO_1
:
if
fused_activation_fn
==
ActivationFunctionType
.
RELU_N1_TO_1
:
return
_op
.
clip
(
in_expr
,
a_min
=-
1
,
a_max
=
1
)
return
_op
.
clip
(
in_expr
,
a_min
=-
1
,
a_max
=
1
)
el
if
fused_activation_fn
==
ActivationFunctionType
.
TANH
:
if
fused_activation_fn
==
ActivationFunctionType
.
TANH
:
return
_op
.
tanh
(
in_expr
)
return
_op
.
tanh
(
in_expr
)
else
:
fused_activation_fn_str
=
self
.
activation_fn_type
[
fused_activation_fn
]
fused_activation_fn_str
=
self
.
activation_fn_type
[
fused_activation_fn
]
raise
NotImplementedError
(
"Unsupported fused activation fn {}"
raise
NotImplementedError
(
"Unsupported fused activation fn {}"
.
format
(
fused_activation_fn_str
))
.
format
(
fused_activation_fn_str
))
...
@@ -401,7 +398,7 @@ class OperatorConverter(object):
...
@@ -401,7 +398,7 @@ class OperatorConverter(object):
# weight tensor type should be UINT8 (quantization) or FLOAT32
# weight tensor type should be UINT8 (quantization) or FLOAT32
weight_tensor_type
=
weight_tensor
.
tensor
.
Type
()
weight_tensor_type
=
weight_tensor
.
tensor
.
Type
()
assert
weight_tensor_type
==
TensorType
.
UINT8
or
weight_tensor_type
==
TensorType
.
FLOAT32
assert
weight_tensor_type
in
(
TensorType
.
UINT8
,
TensorType
.
FLOAT32
)
weight_tensor_type_str
=
self
.
get_tensor_type_str
(
weight_tensor_type
)
weight_tensor_type_str
=
self
.
get_tensor_type_str
(
weight_tensor_type
)
in_expr
=
self
.
get_expr
(
input_tensor_idx
)
in_expr
=
self
.
get_expr
(
input_tensor_idx
)
...
@@ -434,7 +431,7 @@ class OperatorConverter(object):
...
@@ -434,7 +431,7 @@ class OperatorConverter(object):
bias_tensor
=
input_tensors
[
2
]
bias_tensor
=
input_tensors
[
2
]
bias_tensor_type
=
bias_tensor
.
tensor
.
Type
()
bias_tensor_type
=
bias_tensor
.
tensor
.
Type
()
# bias tensor type should be INT32 (quantization) or FLOAT32
# bias tensor type should be INT32 (quantization) or FLOAT32
assert
bias_tensor_type
==
TensorType
.
INT32
or
bias_tensor_type
==
TensorType
.
FLOAT32
assert
bias_tensor_type
in
(
TensorType
.
INT32
,
TensorType
.
FLOAT32
)
bias_tensor_type_str
=
self
.
get_tensor_type_str
(
bias_tensor_type
)
bias_tensor_type_str
=
self
.
get_tensor_type_str
(
bias_tensor_type
)
bias_expr
=
self
.
exp_tab
.
new_const
(
self
.
get_tensor_value
(
bias_tensor
),
bias_expr
=
self
.
exp_tab
.
new_const
(
self
.
get_tensor_value
(
bias_tensor
),
dtype
=
bias_tensor_type_str
)
dtype
=
bias_tensor_type_str
)
...
...
python/tvm/relay/op/nn/_nn.py
View file @
e20ef0d4
...
@@ -57,7 +57,7 @@ def compute_conv2d(attrs, inputs, out_type, target):
...
@@ -57,7 +57,7 @@ def compute_conv2d(attrs, inputs, out_type, target):
layout
=
attrs
.
data_layout
layout
=
attrs
.
data_layout
kernel_layout
=
attrs
.
kernel_layout
kernel_layout
=
attrs
.
kernel_layout
out_dtype
=
attrs
.
out_dtype
out_dtype
=
attrs
.
out_dtype
out_dtype
=
(
inputs
[
0
]
.
dtype
if
(
out_dtype
==
"same"
or
out_dtype
==
""
)
out_dtype
=
(
inputs
[
0
]
.
dtype
if
out_dtype
in
(
"same"
,
""
)
else
out_dtype
)
else
out_dtype
)
assert
layout
in
[
"NCHW"
,
"NHWC"
,
"NCHW4c"
]
assert
layout
in
[
"NCHW"
,
"NHWC"
,
"NCHW4c"
]
...
@@ -95,15 +95,15 @@ def schedule_conv2d(attrs, outs, target):
...
@@ -95,15 +95,15 @@ def schedule_conv2d(attrs, outs, target):
with
target
:
with
target
:
if
groups
==
1
and
layout
==
"NCHW"
:
if
groups
==
1
and
layout
==
"NCHW"
:
return
topi
.
generic
.
schedule_conv2d_nchw
(
outs
)
return
topi
.
generic
.
schedule_conv2d_nchw
(
outs
)
el
if
groups
==
1
and
layout
==
"NCHW4c"
:
if
groups
==
1
and
layout
==
"NCHW4c"
:
return
topi
.
generic
.
schedule_conv2d_nchw
(
outs
)
return
topi
.
generic
.
schedule_conv2d_nchw
(
outs
)
el
if
groups
==
1
and
layout
==
"NHWC"
:
if
groups
==
1
and
layout
==
"NHWC"
:
return
topi
.
generic
.
schedule_conv2d_nhwc
(
outs
)
return
topi
.
generic
.
schedule_conv2d_nhwc
(
outs
)
el
if
groups
!=
1
:
if
groups
!=
1
:
if
layout
==
"NCHW"
:
if
layout
==
"NCHW"
:
# TODO(leyuan, merrymercy, Huyuwei): fold depthwise topi into conv2d.
# TODO(leyuan, merrymercy, Huyuwei): fold depthwise topi into conv2d.
return
topi
.
generic
.
schedule_depthwise_conv2d_nchw
(
outs
)
return
topi
.
generic
.
schedule_depthwise_conv2d_nchw
(
outs
)
el
if
layout
==
"NHWC"
and
kernel_layout
==
"HWOI"
:
if
layout
==
"NHWC"
and
kernel_layout
==
"HWOI"
:
return
topi
.
generic
.
schedule_depthwise_conv2d_nhwc
(
outs
)
return
topi
.
generic
.
schedule_depthwise_conv2d_nhwc
(
outs
)
raise
ValueError
(
"No compatible schedule"
)
raise
ValueError
(
"No compatible schedule"
)
...
@@ -127,7 +127,7 @@ def compute_conv2d_transpose(attrs, inputs, out_dtype, target):
...
@@ -127,7 +127,7 @@ def compute_conv2d_transpose(attrs, inputs, out_dtype, target):
groups
=
attrs
.
groups
groups
=
attrs
.
groups
layout
=
attrs
.
data_layout
layout
=
attrs
.
data_layout
out_dtype
=
attrs
.
out_dtype
out_dtype
=
attrs
.
out_dtype
out_dtype
=
(
inputs
[
0
]
.
dtype
if
(
out_dtype
==
"same"
or
out_dtype
==
""
)
out_dtype
=
(
inputs
[
0
]
.
dtype
if
out_dtype
in
(
"same"
,
""
)
else
out_dtype
)
else
out_dtype
)
assert
layout
==
"NCHW"
,
"only support nchw for now"
assert
layout
==
"NCHW"
,
"only support nchw for now"
assert
dilation
==
(
1
,
1
),
"not support dilate now"
assert
dilation
==
(
1
,
1
),
"not support dilate now"
...
...
python/tvm/relay/op/op_attrs.py
View file @
e20ef0d4
...
@@ -6,19 +6,18 @@ from ..base import register_relay_attr_node
...
@@ -6,19 +6,18 @@ from ..base import register_relay_attr_node
@register_relay_attr_node
@register_relay_attr_node
class
Conv2DAttrs
(
Attrs
):
class
Conv2DAttrs
(
Attrs
):
"""Attribute of nn.conv2d"""
"""Attribute of nn.conv2d"""
pass
@register_relay_attr_node
@register_relay_attr_node
class
Conv2DWinogradAttrs
(
Attrs
):
class
Conv2DWinogradAttrs
(
Attrs
):
"""Attribute of nn.contrib_conv2d_winograd_without_weight_transform"""
"""Attribute of nn.contrib_conv2d_winograd_without_weight_transform"""
pass
@register_relay_attr_node
@register_relay_attr_node
class
Conv2DWinogradWeightTransformAttrs
(
Attrs
):
class
Conv2DWinogradWeightTransformAttrs
(
Attrs
):
"""Attribute of nn.contrib_conv2d_winograd_weight_transform"""
"""Attribute of nn.contrib_conv2d_winograd_weight_transform"""
pass
@register_relay_attr_node
@register_relay_attr_node
class
GlobalPool2DAttrs
(
Attrs
):
class
GlobalPool2DAttrs
(
Attrs
):
"""Attribute of nn.global_pool"""
"""Attribute of nn.global_pool"""
pass
python/tvm/relay/testing/inception_v3.py
View file @
e20ef0d4
...
@@ -29,10 +29,9 @@ def Conv(data, num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), name=None,
...
@@ -29,10 +29,9 @@ def Conv(data, num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), name=None,
def
Pooling
(
data
,
kernel
,
stride
,
pad
,
pool_type
,
name
):
def
Pooling
(
data
,
kernel
,
stride
,
pad
,
pool_type
,
name
):
if
pool_type
==
'max'
:
if
pool_type
==
'max'
:
return
relay
.
nn
.
max_pool2d
(
data
=
data
,
pool_size
=
kernel
,
strides
=
stride
,
padding
=
pad
)
return
relay
.
nn
.
max_pool2d
(
data
=
data
,
pool_size
=
kernel
,
strides
=
stride
,
padding
=
pad
)
el
if
pool_type
==
'avg'
:
if
pool_type
==
'avg'
:
return
relay
.
nn
.
avg_pool2d
(
data
=
data
,
pool_size
=
kernel
,
strides
=
stride
,
padding
=
pad
,
return
relay
.
nn
.
avg_pool2d
(
data
=
data
,
pool_size
=
kernel
,
strides
=
stride
,
padding
=
pad
,
count_include_pad
=
True
)
count_include_pad
=
True
)
else
:
raise
ValueError
(
"Invalid pooling type: "
+
pool_type
)
raise
ValueError
(
"Invalid pooling type: "
+
pool_type
)
def
Inception7A
(
data
,
def
Inception7A
(
data
,
...
...
python/tvm/relay/ty.py
View file @
e20ef0d4
...
@@ -172,7 +172,6 @@ class TypeCall(Type):
...
@@ -172,7 +172,6 @@ class TypeCall(Type):
@register_relay_node
@register_relay_node
class
TypeConstraint
(
Type
):
class
TypeConstraint
(
Type
):
"""Abstract class representing a type constraint."""
"""Abstract class representing a type constraint."""
pass
@register_relay_node
@register_relay_node
...
...
python/tvm/rpc/proxy.py
View file @
e20ef0d4
...
@@ -389,7 +389,7 @@ class ProxyServerHandler(object):
...
@@ -389,7 +389,7 @@ class ProxyServerHandler(object):
if
key
in
pool_src
:
if
key
in
pool_src
:
self
.
_pair_up
(
pool_src
.
pop
(
key
),
handler
)
self
.
_pair_up
(
pool_src
.
pop
(
key
),
handler
)
return
return
el
if
key
not
in
pool_dst
:
if
key
not
in
pool_dst
:
pool_dst
[
key
]
=
handler
pool_dst
[
key
]
=
handler
def
cleanup
():
def
cleanup
():
"""Cleanup client connection if timeout"""
"""Cleanup client connection if timeout"""
...
...
python/tvm/rpc/tornado_util.py
View file @
e20ef0d4
...
@@ -95,7 +95,6 @@ class TCPHandler(object):
...
@@ -95,7 +95,6 @@ class TCPHandler(object):
if
msg
:
if
msg
:
self
.
on_message
(
msg
)
self
.
on_message
(
msg
)
return
True
return
True
else
:
# normal close, remote is closed
# normal close, remote is closed
self
.
close
()
self
.
close
()
except
socket
.
error
as
err
:
except
socket
.
error
as
err
:
...
...
python/tvm/rpc/tracker.py
View file @
e20ef0d4
...
@@ -86,7 +86,7 @@ class Scheduler(object):
...
@@ -86,7 +86,7 @@ class Scheduler(object):
value: object
value: object
The resource to remove
The resource to remove
"""
"""
pass
def
summary
(
self
):
def
summary
(
self
):
"""Get summary information of the scheduler."""
"""Get summary information of the scheduler."""
...
...
python/tvm/schedule.py
View file @
e20ef0d4
...
@@ -143,19 +143,16 @@ class Buffer(NodeBase):
...
@@ -143,19 +143,16 @@ class Buffer(NodeBase):
@register_node
@register_node
class
Split
(
NodeBase
):
class
Split
(
NodeBase
):
"""Split operation on axis."""
"""Split operation on axis."""
pass
@register_node
@register_node
class
Fuse
(
NodeBase
):
class
Fuse
(
NodeBase
):
"""Fuse operation on axis."""
"""Fuse operation on axis."""
pass
@register_node
@register_node
class
Singleton
(
NodeBase
):
class
Singleton
(
NodeBase
):
"""Singleton axis."""
"""Singleton axis."""
pass
@register_node
@register_node
...
...
python/tvm/stmt.py
View file @
e20ef0d4
...
@@ -381,7 +381,7 @@ def stmt_list(stmt):
...
@@ -381,7 +381,7 @@ def stmt_list(stmt):
"""
"""
if
isinstance
(
stmt
,
Block
):
if
isinstance
(
stmt
,
Block
):
return
stmt_list
(
stmt
.
first
)
+
stmt_list
(
stmt
.
rest
)
return
stmt_list
(
stmt
.
first
)
+
stmt_list
(
stmt
.
rest
)
el
if
isinstance
(
stmt
,
ProducerConsumer
):
if
isinstance
(
stmt
,
ProducerConsumer
):
return
stmt_list
(
stmt
.
body
)
return
stmt_list
(
stmt
.
body
)
return
[
stmt
]
return
[
stmt
]
...
...
python/tvm/tensor.py
View file @
e20ef0d4
...
@@ -33,7 +33,6 @@ class TensorSlice(NodeGeneric, _expr.ExprOp):
...
@@ -33,7 +33,6 @@ class TensorSlice(NodeGeneric, _expr.ExprOp):
@register_node
@register_node
class
TensorIntrinCall
(
NodeBase
):
class
TensorIntrinCall
(
NodeBase
):
"""Intermediate structure for calling a tensor intrinsic."""
"""Intermediate structure for calling a tensor intrinsic."""
pass
itervar_cls
=
None
itervar_cls
=
None
...
@@ -144,7 +143,6 @@ class Operation(NodeBase):
...
@@ -144,7 +143,6 @@ class Operation(NodeBase):
@register_node
@register_node
class
PlaceholderOp
(
Operation
):
class
PlaceholderOp
(
Operation
):
"""Placeholder operation."""
"""Placeholder operation."""
pass
@register_node
@register_node
...
@@ -164,7 +162,6 @@ class ComputeOp(Operation):
...
@@ -164,7 +162,6 @@ class ComputeOp(Operation):
@register_node
@register_node
class
TensorComputeOp
(
Operation
):
class
TensorComputeOp
(
Operation
):
"""Tensor operation."""
"""Tensor operation."""
pass
@register_node
@register_node
...
@@ -179,7 +176,7 @@ class ScanOp(Operation):
...
@@ -179,7 +176,7 @@ class ScanOp(Operation):
@register_node
@register_node
class
ExternOp
(
Operation
):
class
ExternOp
(
Operation
):
"""Extern operation."""
"""Extern operation."""
pass
@register_node
@register_node
class
HybridOp
(
Operation
):
class
HybridOp
(
Operation
):
...
...
topi/python/topi/arm_cpu/bitserial_conv2d.py
View file @
e20ef0d4
...
@@ -61,7 +61,7 @@ def _declaration_bitserial_conv2d(data, kernel, stride, padding, activation_bits
...
@@ -61,7 +61,7 @@ def _declaration_bitserial_conv2d(data, kernel, stride, padding, activation_bits
if
out_dtype
is
None
:
if
out_dtype
is
None
:
out_dtype
=
data
.
dtype
out_dtype
=
data
.
dtype
assert
data
.
shape
[
0
]
.
value
==
1
,
"only support batch size=1 convolution on rasp"
assert
data
.
shape
[
0
]
.
value
==
1
,
"only support batch size=1 convolution on rasp"
assert
layout
==
"NCHW"
or
layout
==
"NHWC"
,
"only support layouts NCHW and NHWC"
assert
layout
in
(
"NCHW"
,
"NHWC"
)
,
"only support layouts NCHW and NHWC"
if
dorefa
:
if
dorefa
:
assert
layout
==
"NCHW"
,
"Cannot support dorea with NHWC layout yet"
assert
layout
==
"NCHW"
,
"Cannot support dorea with NHWC layout yet"
wkl
=
_get_workload
(
data
,
kernel
,
stride
,
padding
,
out_dtype
,
layout
)
wkl
=
_get_workload
(
data
,
kernel
,
stride
,
padding
,
out_dtype
,
layout
)
...
...
topi/python/topi/arm_cpu/conv2d.py
View file @
e20ef0d4
...
@@ -554,7 +554,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
...
@@ -554,7 +554,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
data_layout_key
=
"data_layout"
if
"data_layout"
in
new_attrs
else
"layout"
data_layout_key
=
"data_layout"
if
"data_layout"
in
new_attrs
else
"layout"
layout
=
attrs
[
data_layout_key
]
layout
=
attrs
[
data_layout_key
]
out_dtype
=
attrs
[
"out_dtype"
]
out_dtype
=
attrs
[
"out_dtype"
]
if
out_dtype
==
""
or
out_dtype
==
"same"
:
if
out_dtype
in
(
"same"
,
""
)
:
out_dtype
=
tinfos
[
0
]
.
dtype
out_dtype
=
tinfos
[
0
]
.
dtype
if
layout
!=
'NCHW'
:
if
layout
!=
'NCHW'
:
...
...
topi/python/topi/cuda/conv2d.py
View file @
e20ef0d4
...
@@ -93,9 +93,8 @@ def conv2d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou
...
@@ -93,9 +93,8 @@ def conv2d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou
if
layout
==
'NCHW'
:
if
layout
==
'NCHW'
:
return
nn
.
conv2d_nchw
(
data
,
kernel
,
strides
,
padding
,
dilation
,
out_dtype
)
return
nn
.
conv2d_nchw
(
data
,
kernel
,
strides
,
padding
,
dilation
,
out_dtype
)
el
if
layout
==
'HWCN'
:
if
layout
==
'HWCN'
:
return
nn
.
conv2d_hwcn
(
data
,
kernel
,
strides
,
padding
,
dilation
,
out_dtype
)
return
nn
.
conv2d_hwcn
(
data
,
kernel
,
strides
,
padding
,
dilation
,
out_dtype
)
else
:
raise
ValueError
(
"not support this layout {} yet"
.
format
(
layout
))
raise
ValueError
(
"not support this layout {} yet"
.
format
(
layout
))
...
...
topi/python/topi/cuda/conv2d_winograd.py
View file @
e20ef0d4
...
@@ -362,7 +362,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, F):
...
@@ -362,7 +362,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, F):
data_layout_key
=
"data_layout"
if
"data_layout"
in
new_attrs
else
"layout"
data_layout_key
=
"data_layout"
if
"data_layout"
in
new_attrs
else
"layout"
layout
=
attrs
[
data_layout_key
]
layout
=
attrs
[
data_layout_key
]
out_dtype
=
attrs
[
"out_dtype"
]
out_dtype
=
attrs
[
"out_dtype"
]
if
out_dtype
==
""
or
out_dtype
==
"same"
:
if
out_dtype
in
(
""
,
"same"
)
:
out_dtype
=
tinfos
[
0
]
.
dtype
out_dtype
=
tinfos
[
0
]
.
dtype
data
,
kernel
=
tinfos
[
0
:
2
]
data
,
kernel
=
tinfos
[
0
:
2
]
...
@@ -428,7 +428,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, F):
...
@@ -428,7 +428,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, F):
)
)
dispatch_ctx
.
update
(
target
,
new_workload
,
cfg
)
dispatch_ctx
.
update
(
target
,
new_workload
,
cfg
)
return
F
.
nn
.
contrib_conv2d_winograd_without_weight_transform
(
*
copy_inputs
,
**
new_attrs
)
return
F
.
nn
.
contrib_conv2d_winograd_without_weight_transform
(
*
copy_inputs
,
**
new_attrs
)
el
if
groups
!=
CI
:
if
groups
!=
CI
:
workload
=
autotvm
.
task
.
args_to_workload
(
workload
=
autotvm
.
task
.
args_to_workload
(
[
tinfos
[
0
],
tinfos
[
1
],
strides
,
padding
,
dilation
,
groups
,
out_dtype
],
[
tinfos
[
0
],
tinfos
[
1
],
strides
,
padding
,
dilation
,
groups
,
out_dtype
],
group_conv2d_nchw
)
group_conv2d_nchw
)
...
...
topi/python/topi/cuda/reduction.py
View file @
e20ef0d4
...
@@ -96,7 +96,7 @@ def schedule_reduce(outs):
...
@@ -96,7 +96,7 @@ def schedule_reduce(outs):
"""Internal travserse function"""
"""Internal travserse function"""
if
isinstance
(
operator
,
tvm
.
tensor
.
PlaceholderOp
):
if
isinstance
(
operator
,
tvm
.
tensor
.
PlaceholderOp
):
return
return
el
if
tag
.
is_injective
(
operator
.
tag
):
if
tag
.
is_injective
(
operator
.
tag
):
sch
[
operator
]
.
compute_inline
()
sch
[
operator
]
.
compute_inline
()
for
tensor
in
operator
.
input_tensors
:
for
tensor
in
operator
.
input_tensors
:
if
tensor
.
op
not
in
scheduled_ops
:
if
tensor
.
op
not
in
scheduled_ops
:
...
...
topi/python/topi/nn/bitserial_conv2d.py
View file @
e20ef0d4
...
@@ -92,14 +92,14 @@ def bitserial_conv2d(data, kernel, stride, padding, activation_bits, weight_bits
...
@@ -92,14 +92,14 @@ def bitserial_conv2d(data, kernel, stride, padding, activation_bits, weight_bits
if
layout
==
'NCHW'
:
if
layout
==
'NCHW'
:
return
spatial_pack_nchw
(
data
,
kernel
,
stride
,
padding
,
activation_bits
,
weight_bits
,
return
spatial_pack_nchw
(
data
,
kernel
,
stride
,
padding
,
activation_bits
,
weight_bits
,
pack_dtype
=
pack_dtype
,
out_dtype
=
out_dtype
,
dorefa
=
dorefa
)
pack_dtype
=
pack_dtype
,
out_dtype
=
out_dtype
,
dorefa
=
dorefa
)
el
if
layout
==
'NHWC'
:
if
layout
==
'NHWC'
:
return
spatial_pack_nhwc
(
data
,
kernel
,
stride
,
padding
,
activation_bits
,
weight_bits
,
return
spatial_pack_nhwc
(
data
,
kernel
,
stride
,
padding
,
activation_bits
,
weight_bits
,
pack_dtype
=
pack_dtype
,
out_dtype
=
out_dtype
,
dorefa
=
dorefa
)
pack_dtype
=
pack_dtype
,
out_dtype
=
out_dtype
,
dorefa
=
dorefa
)
raise
ValueError
(
"not support this layout {} yet"
.
format
(
layout
))
raise
ValueError
(
"not support this layout {} yet"
.
format
(
layout
))
def
_get_workload
(
data
,
kernel
,
stride
,
padding
,
out_dtype
,
layout
):
def
_get_workload
(
data
,
kernel
,
stride
,
padding
,
out_dtype
,
layout
):
""" Get the workload structure. """
""" Get the workload structure. """
assert
layout
==
"NCHW"
or
layout
==
"NHWC"
,
\
assert
layout
in
(
"NCHW"
,
"NHWC"
)
,
\
"Only support layouts NCHW and NHWC"
"Only support layouts NCHW and NHWC"
if
layout
==
"NCHW"
:
if
layout
==
"NCHW"
:
_
,
CI
,
IH
,
IW
=
[
x
.
value
for
x
in
data
.
shape
]
_
,
CI
,
IH
,
IW
=
[
x
.
value
for
x
in
data
.
shape
]
...
...
topi/python/topi/nn/conv2d.py
View file @
e20ef0d4
...
@@ -48,11 +48,10 @@ def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=N
...
@@ -48,11 +48,10 @@ def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=N
# default declaration
# default declaration
if
layout
==
'NCHW'
:
if
layout
==
'NCHW'
:
return
conv2d_nchw
(
input
,
filter
,
strides
,
padding
,
dilation
,
out_dtype
)
return
conv2d_nchw
(
input
,
filter
,
strides
,
padding
,
dilation
,
out_dtype
)
el
if
layout
==
'HWCN'
:
if
layout
==
'HWCN'
:
return
conv2d_hwcn
(
input
,
filter
,
strides
,
padding
,
dilation
,
out_dtype
)
return
conv2d_hwcn
(
input
,
filter
,
strides
,
padding
,
dilation
,
out_dtype
)
el
if
layout
==
'NHWC'
:
if
layout
==
'NHWC'
:
return
conv2d_nhwc
(
input
,
filter
,
strides
,
padding
,
dilation
,
out_dtype
)
return
conv2d_nhwc
(
input
,
filter
,
strides
,
padding
,
dilation
,
out_dtype
)
else
:
raise
ValueError
(
"not support this layout {} yet"
.
format
(
layout
))
raise
ValueError
(
"not support this layout {} yet"
.
format
(
layout
))
...
...
topi/python/topi/testing/upsampling_python.py
View file @
e20ef0d4
...
@@ -17,12 +17,11 @@ def upsampling_python(data, scale, layout='NCHW'):
...
@@ -17,12 +17,11 @@ def upsampling_python(data, scale, layout='NCHW'):
for
c
in
range
(
oshape
[
1
]):
for
c
in
range
(
oshape
[
1
]):
output_np
[
b
,
c
,
:,
:]
=
upsample_nearest
(
data
[
b
,
c
,
:,
:],
scale
)
output_np
[
b
,
c
,
:,
:]
=
upsample_nearest
(
data
[
b
,
c
,
:,
:],
scale
)
return
output_np
return
output_np
el
if
layout
==
'NHWC'
:
if
layout
==
'NHWC'
:
oshape
=
(
ishape
[
0
],
ishape
[
1
]
*
scale
,
ishape
[
1
]
*
scale
,
ishape
[
3
])
oshape
=
(
ishape
[
0
],
ishape
[
1
]
*
scale
,
ishape
[
1
]
*
scale
,
ishape
[
3
])
output_np
=
np
.
zeros
(
oshape
,
dtype
=
data
.
dtype
)
output_np
=
np
.
zeros
(
oshape
,
dtype
=
data
.
dtype
)
for
b
in
range
(
oshape
[
0
]):
for
b
in
range
(
oshape
[
0
]):
for
c
in
range
(
oshape
[
3
]):
for
c
in
range
(
oshape
[
3
]):
output_np
[
b
,
:,
:,
c
]
=
upsample_nearest
(
data
[
b
,
:,
:,
c
],
scale
)
output_np
[
b
,
:,
:,
c
]
=
upsample_nearest
(
data
[
b
,
:,
:,
c
],
scale
)
return
output_np
return
output_np
else
:
raise
ValueError
(
"not support this layout {} yet"
.
format
(
layout
))
raise
ValueError
(
"not support this layout {} yet"
.
format
(
layout
))
topi/python/topi/x86/bitserial_conv2d.py
View file @
e20ef0d4
...
@@ -59,7 +59,7 @@ def _declaration_bitserial_conv2d(data, kernel, stride, padding, activation_bits
...
@@ -59,7 +59,7 @@ def _declaration_bitserial_conv2d(data, kernel, stride, padding, activation_bits
if
out_dtype
is
None
:
if
out_dtype
is
None
:
out_dtype
=
data
.
dtype
out_dtype
=
data
.
dtype
assert
data
.
shape
[
0
]
.
value
==
1
,
"only support batch size=1 convolution on rasp"
assert
data
.
shape
[
0
]
.
value
==
1
,
"only support batch size=1 convolution on rasp"
assert
layout
==
"NCHW"
or
layout
==
"NHWC"
,
"only support layouts NCHW and NHWC"
assert
layout
in
(
"NCHW"
,
"NHWC"
)
,
"only support layouts NCHW and NHWC"
wkl
=
_get_workload
(
data
,
kernel
,
stride
,
padding
,
out_dtype
,
layout
)
wkl
=
_get_workload
(
data
,
kernel
,
stride
,
padding
,
out_dtype
,
layout
)
sch
=
_get_schedule
(
wkl
,
layout
)
sch
=
_get_schedule
(
wkl
,
layout
)
...
...
topi/python/topi/x86/conv2d.py
View file @
e20ef0d4
...
@@ -71,11 +71,10 @@ def _declaration_conv(cfg, data, kernel, strides, padding, dilation, layout, out
...
@@ -71,11 +71,10 @@ def _declaration_conv(cfg, data, kernel, strides, padding, dilation, layout, out
_get_default_config
(
cfg
,
data
,
kernel
,
strides
,
padding
,
out_dtype
)
_get_default_config
(
cfg
,
data
,
kernel
,
strides
,
padding
,
out_dtype
)
return
_declaration_conv_impl
(
cfg
,
data
,
kernel
,
strides
,
return
_declaration_conv_impl
(
cfg
,
data
,
kernel
,
strides
,
padding
,
dilation
,
layout
,
out_dtype
)
padding
,
dilation
,
layout
,
out_dtype
)
el
if
layout
==
'HWCN'
:
if
layout
==
'HWCN'
:
return
nn
.
conv2d_hwcn
(
data
,
kernel
,
strides
,
padding
,
dilation
,
out_dtype
)
return
nn
.
conv2d_hwcn
(
data
,
kernel
,
strides
,
padding
,
dilation
,
out_dtype
)
el
if
layout
==
'NHWC'
:
if
layout
==
'NHWC'
:
return
nn
.
conv2d_nhwc
(
data
,
kernel
,
strides
,
padding
,
dilation
,
out_dtype
)
return
nn
.
conv2d_nhwc
(
data
,
kernel
,
strides
,
padding
,
dilation
,
out_dtype
)
else
:
raise
ValueError
(
"not support this layout {} yet"
.
format
(
layout
))
raise
ValueError
(
"not support this layout {} yet"
.
format
(
layout
))
...
...
vta/python/vta/environment.py
View file @
e20ef0d4
...
@@ -223,9 +223,8 @@ class Environment(object):
...
@@ -223,9 +223,8 @@ class Environment(object):
"""The target host"""
"""The target host"""
if
self
.
TARGET
==
"pynq"
:
if
self
.
TARGET
==
"pynq"
:
return
"llvm -target=armv7-none-linux-gnueabihf"
return
"llvm -target=armv7-none-linux-gnueabihf"
el
if
self
.
TARGET
==
"sim"
:
if
self
.
TARGET
==
"sim"
:
return
"llvm"
return
"llvm"
else
:
raise
ValueError
(
"Unknown target
%
s"
%
self
.
TARGET
)
raise
ValueError
(
"Unknown target
%
s"
%
self
.
TARGET
)
...
...
vta/python/vta/graph.py
View file @
e20ef0d4
...
@@ -169,7 +169,7 @@ def clean_cast(graph):
...
@@ -169,7 +169,7 @@ def clean_cast(graph):
op_name
=
node
.
attr
(
"op_name"
)
op_name
=
node
.
attr
(
"op_name"
)
if
op_name
==
"cast"
:
if
op_name
==
"cast"
:
return
_clean_cast
(
node
.
get_children
(),
target_type
)
return
_clean_cast
(
node
.
get_children
(),
target_type
)
el
if
op_name
==
"relu"
:
if
op_name
==
"relu"
:
data
,
has_clip
=
_clean_cast
(
data
,
has_clip
=
_clean_cast
(
node
.
get_children
(),
target_type
)
node
.
get_children
(),
target_type
)
data
=
nnvm
.
sym
.
relu
(
data
)
data
=
nnvm
.
sym
.
relu
(
data
)
...
...
vta/python/vta/intrin.py
View file @
e20ef0d4
...
@@ -64,7 +64,7 @@ def gemm(env, mock=False):
...
@@ -64,7 +64,7 @@ def gemm(env, mock=False):
dev
.
get_task_qid
(
dev
.
QID_COMPUTE
))
dev
.
get_task_qid
(
dev
.
QID_COMPUTE
))
irb
.
scope_attr
(
dev
.
vta_axis
,
"coproc_uop_scope"
,
irb
.
scope_attr
(
dev
.
vta_axis
,
"coproc_uop_scope"
,
dev
.
vta_push_uop
)
dev
.
vta_push_uop
)
if
index
==
0
or
index
==
2
:
if
index
in
(
0
,
2
)
:
irb
.
emit
(
tvm
.
call_extern
(
irb
.
emit
(
tvm
.
call_extern
(
"int32"
,
"VTAUopPush"
,
"int32"
,
"VTAUopPush"
,
0
,
0
,
0
,
0
,
...
...
vta/python/vta/ir_pass.py
View file @
e20ef0d4
...
@@ -77,7 +77,6 @@ def fold_uop_loop(stmt_in):
...
@@ -77,7 +77,6 @@ def fold_uop_loop(stmt_in):
args
.
append
(
m
[
1
])
args
.
append
(
m
[
1
])
args
+=
op
.
args
[
base_args
+
3
:]
args
+=
op
.
args
[
base_args
+
3
:]
return
tvm
.
call_extern
(
"int32"
,
"VTAUopPush"
,
*
args
)
return
tvm
.
call_extern
(
"int32"
,
"VTAUopPush"
,
*
args
)
else
:
if
op
.
name
not
in
(
"VTATLSCommandHandle"
,
"tvm_thread_context"
):
if
op
.
name
not
in
(
"VTATLSCommandHandle"
,
"tvm_thread_context"
):
raise
RuntimeError
(
"unexpected op
%
s"
%
op
)
raise
RuntimeError
(
"unexpected op
%
s"
%
op
)
return
op
return
op
...
@@ -165,21 +164,20 @@ def cpu_access_rewrite(stmt_in):
...
@@ -165,21 +164,20 @@ def cpu_access_rewrite(stmt_in):
op
.
condition
,
let_stmt
)
op
.
condition
,
let_stmt
)
del
rw_info
[
buffer_var
]
del
rw_info
[
buffer_var
]
return
alloc
return
alloc
el
if
isinstance
(
op
,
tvm
.
expr
.
Load
):
if
isinstance
(
op
,
tvm
.
expr
.
Load
):
buffer_var
=
op
.
buffer_var
buffer_var
=
op
.
buffer_var
if
not
buffer_var
in
rw_info
:
if
not
buffer_var
in
rw_info
:
rw_info
[
buffer_var
]
=
tvm
.
var
(
rw_info
[
buffer_var
]
=
tvm
.
var
(
buffer_var
.
name
+
"_ptr"
,
"handle"
)
buffer_var
.
name
+
"_ptr"
,
"handle"
)
new_var
=
rw_info
[
buffer_var
]
new_var
=
rw_info
[
buffer_var
]
return
tvm
.
make
.
Load
(
op
.
dtype
,
new_var
,
op
.
index
)
return
tvm
.
make
.
Load
(
op
.
dtype
,
new_var
,
op
.
index
)
el
if
isinstance
(
op
,
tvm
.
stmt
.
Store
):
if
isinstance
(
op
,
tvm
.
stmt
.
Store
):
buffer_var
=
op
.
buffer_var
buffer_var
=
op
.
buffer_var
if
not
buffer_var
in
rw_info
:
if
not
buffer_var
in
rw_info
:
rw_info
[
buffer_var
]
=
tvm
.
var
(
rw_info
[
buffer_var
]
=
tvm
.
var
(
buffer_var
.
name
+
"_ptr"
,
"handle"
)
buffer_var
.
name
+
"_ptr"
,
"handle"
)
new_var
=
rw_info
[
buffer_var
]
new_var
=
rw_info
[
buffer_var
]
return
tvm
.
make
.
Store
(
new_var
,
op
.
value
,
op
.
index
)
return
tvm
.
make
.
Store
(
new_var
,
op
.
value
,
op
.
index
)
else
:
raise
RuntimeError
(
"not reached"
)
raise
RuntimeError
(
"not reached"
)
stmt
=
tvm
.
ir_pass
.
IRTransform
(
stmt
=
tvm
.
ir_pass
.
IRTransform
(
stmt_in
,
None
,
_post_order
,
[
"Allocate"
,
"Load"
,
"Store"
])
stmt_in
,
None
,
_post_order
,
[
"Allocate"
,
"Load"
,
"Store"
])
...
@@ -233,22 +231,19 @@ def lift_alloc_to_scope_begin(stmt_in):
...
@@ -233,22 +231,19 @@ def lift_alloc_to_scope_begin(stmt_in):
if
op
.
attr_key
==
"virtual_thread"
:
if
op
.
attr_key
==
"virtual_thread"
:
lift_stmt
.
append
([])
lift_stmt
.
append
([])
return
None
def
_post_order
(
op
):
def
_post_order
(
op
):
if
isinstance
(
op
,
tvm
.
stmt
.
Allocate
):
if
isinstance
(
op
,
tvm
.
stmt
.
Allocate
):
lift_stmt
[
-
1
]
.
append
(
op
)
lift_stmt
[
-
1
]
.
append
(
op
)
return
op
.
body
return
op
.
body
el
if
isinstance
(
op
,
tvm
.
stmt
.
AttrStmt
):
if
isinstance
(
op
,
tvm
.
stmt
.
AttrStmt
):
if
op
.
attr_key
==
"storage_scope"
:
if
op
.
attr_key
==
"storage_scope"
:
lift_stmt
[
-
1
]
.
append
(
op
)
lift_stmt
[
-
1
]
.
append
(
op
)
return
op
.
body
return
op
.
body
el
if
op
.
attr_key
==
"virtual_thread"
:
if
op
.
attr_key
==
"virtual_thread"
:
return
_merge_block
(
lift_stmt
.
pop
()
+
[
op
],
op
.
body
)
return
_merge_block
(
lift_stmt
.
pop
()
+
[
op
],
op
.
body
)
return
op
return
op
el
if
isinstance
(
op
,
tvm
.
stmt
.
For
):
if
isinstance
(
op
,
tvm
.
stmt
.
For
):
return
_merge_block
(
lift_stmt
.
pop
()
+
[
op
],
op
.
body
)
return
_merge_block
(
lift_stmt
.
pop
()
+
[
op
],
op
.
body
)
else
:
raise
RuntimeError
(
"not reached"
)
raise
RuntimeError
(
"not reached"
)
stmt
=
tvm
.
ir_pass
.
IRTransform
(
stmt
=
tvm
.
ir_pass
.
IRTransform
(
stmt_in
,
_pre_order
,
_post_order
,
[
"Allocate"
,
"AttrStmt"
,
"For"
])
stmt_in
,
_pre_order
,
_post_order
,
[
"Allocate"
,
"AttrStmt"
,
"For"
])
...
@@ -297,7 +292,7 @@ def inject_coproc_sync(stmt_in):
...
@@ -297,7 +292,7 @@ def inject_coproc_sync(stmt_in):
sync
=
tvm
.
make
.
Call
(
sync
=
tvm
.
make
.
Call
(
"int32"
,
"vta.coproc_sync"
,
[],
tvm
.
expr
.
Call
.
Intrinsic
,
None
,
0
)
"int32"
,
"vta.coproc_sync"
,
[],
tvm
.
expr
.
Call
.
Intrinsic
,
None
,
0
)
return
tvm
.
make
.
Block
(
stmt
.
body
,
tvm
.
make
.
Evaluate
(
sync
))
return
tvm
.
make
.
Block
(
stmt
.
body
,
tvm
.
make
.
Evaluate
(
sync
))
el
if
_match_pragma
(
stmt
,
"trim_loop"
):
if
_match_pragma
(
stmt
,
"trim_loop"
):
op
=
stmt
.
body
op
=
stmt
.
body
assert
isinstance
(
op
,
tvm
.
stmt
.
For
)
assert
isinstance
(
op
,
tvm
.
stmt
.
For
)
return
tvm
.
make
.
For
(
return
tvm
.
make
.
For
(
...
@@ -584,7 +579,7 @@ def annotate_alu_coproc_scope(stmt_in):
...
@@ -584,7 +579,7 @@ def annotate_alu_coproc_scope(stmt_in):
tvm
.
make
.
StringImm
(
"VTAPushALUOp"
))
tvm
.
make
.
StringImm
(
"VTAPushALUOp"
))
irb
.
emit
(
stmt
)
irb
.
emit
(
stmt
)
return
irb
.
get
()
return
irb
.
get
()
el
if
_match_pragma
(
stmt
,
"skip_alu"
):
if
_match_pragma
(
stmt
,
"skip_alu"
):
return
tvm
.
make
.
Evaluate
(
0
)
return
tvm
.
make
.
Evaluate
(
0
)
return
stmt
return
stmt
...
...
vta/python/vta/top/vta_conv2d.py
View file @
e20ef0d4
...
@@ -193,7 +193,7 @@ def _build(funcs, target, target_host):
...
@@ -193,7 +193,7 @@ def _build(funcs, target, target_host):
tvm_t
=
tvm
.
target
.
create
(
target
)
tvm_t
=
tvm
.
target
.
create
(
target
)
if
tvm_t
.
device_name
==
"vta"
:
if
tvm_t
.
device_name
==
"vta"
:
return
tvm
.
build
(
funcs
,
target
=
"ext_dev"
,
target_host
=
target_host
)
return
tvm
.
build
(
funcs
,
target
=
"ext_dev"
,
target_host
=
target_host
)
el
if
tvm_t
.
device_name
==
"rasp"
or
tvm_t
.
device_name
==
"vtacpu"
:
if
tvm_t
.
device_name
==
"rasp"
or
tvm_t
.
device_name
==
"vtacpu"
:
return
tvm
.
build
(
funcs
,
target
=
target_host
)
return
tvm
.
build
(
funcs
,
target
=
target_host
)
return
tvm
.
build
(
funcs
,
target
=
target
)
return
tvm
.
build
(
funcs
,
target
=
target
)
...
@@ -279,9 +279,8 @@ def schedule_conv2d(attrs, outs, target):
...
@@ -279,9 +279,8 @@ def schedule_conv2d(attrs, outs, target):
target
=
tvm
.
target
.
create
(
target
)
target
=
tvm
.
target
.
create
(
target
)
if
target
.
device_name
==
"vta"
:
if
target
.
device_name
==
"vta"
:
return
schedule_packed_conv2d
(
outs
)
return
schedule_packed_conv2d
(
outs
)
el
if
str
(
target
)
.
startswith
(
"llvm"
):
if
str
(
target
)
.
startswith
(
"llvm"
):
return
tvm
.
create_schedule
([
x
.
op
for
x
in
outs
])
return
tvm
.
create_schedule
([
x
.
op
for
x
in
outs
])
else
:
raise
RuntimeError
(
"not support target
%
s"
%
target
)
raise
RuntimeError
(
"not support target
%
s"
%
target
)
return
_nn
.
schedule_conv2d
(
attrs
,
outs
,
target
)
return
_nn
.
schedule_conv2d
(
attrs
,
outs
,
target
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment