Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
1e4aea81
Commit
1e4aea81
authored
Aug 22, 2019
by
Animesh Jain
Committed by
Yizhi Liu
Aug 23, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Legalize][QNN] Pass out_types to Legalize. Update QNN requantize to read from out_types. (#3782)
parent
17f8f96b
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
100 additions
and
54 deletions
+100
-54
python/tvm/relay/op/nn/_nn.py
+18
-4
src/relay/pass/legalize.cc
+9
-3
src/relay/qnn/op/dequantize.cc
+2
-2
src/relay/qnn/op/quantize.cc
+2
-2
src/relay/qnn/op/requantize.cc
+20
-13
tests/python/relay/test_pass_legalize.py
+8
-7
topi/python/topi/arm_cpu/conv2d.py
+30
-12
topi/python/topi/nn/conv2d.py
+11
-11
No files found.
python/tvm/relay/op/nn/_nn.py
View file @
1e4aea81
...
@@ -206,10 +206,24 @@ def alter_op_layout_conv2d(attrs, inputs, tinfos):
...
@@ -206,10 +206,24 @@ def alter_op_layout_conv2d(attrs, inputs, tinfos):
return
topi
.
nn
.
conv2d_alter_layout
(
attrs
,
inputs
,
tinfos
,
op
)
return
topi
.
nn
.
conv2d_alter_layout
(
attrs
,
inputs
,
tinfos
,
op
)
@reg.register_legalize
(
"nn.conv2d"
)
@reg.register_legalize
(
"nn.conv2d"
)
def
legalize_conv2d
(
attrs
,
inputs
,
arg_dtypes
):
def
legalize_conv2d
(
attrs
,
inputs
,
types
):
"""Legalize conv2d"""
"""Legalize conv2d op.
from
...
import
op
return
topi
.
nn
.
conv2d_legalize
(
attrs
,
inputs
,
arg_dtypes
,
op
)
Parameters
----------
attrs : tvm.attrs.Attrs
Attributes of current convolution
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
types : list of types
List of input and output types
Returns
-------
result : tvm.relay.Expr
The legalized expr
"""
return
topi
.
nn
.
conv2d_legalize
(
attrs
,
inputs
,
types
)
reg
.
register_pattern
(
"nn.conv2d"
,
OpPattern
.
OUT_ELEMWISE_FUSABLE
)
reg
.
register_pattern
(
"nn.conv2d"
,
OpPattern
.
OUT_ELEMWISE_FUSABLE
)
...
...
src/relay/pass/legalize.cc
View file @
1e4aea81
...
@@ -42,11 +42,17 @@ Expr Legalizer(const Call& ref_call, const Array<Expr>& new_args, const NodeRef&
...
@@ -42,11 +42,17 @@ Expr Legalizer(const Call& ref_call, const Array<Expr>& new_args, const NodeRef&
Expr
new_e
;
Expr
new_e
;
bool
modified
=
false
;
bool
modified
=
false
;
if
(
fop_legalize
.
count
(
op
))
{
if
(
fop_legalize
.
count
(
op
))
{
tvm
::
Array
<
tvm
::
relay
::
Type
>
arg_types
;
// Collect input and output dtypes to pass on to Legalize API.
tvm
::
Array
<
tvm
::
relay
::
Type
>
types
;
for
(
auto
&
expr
:
ref_call
->
args
)
{
for
(
auto
&
expr
:
ref_call
->
args
)
{
arg_
types
.
push_back
(
expr
->
checked_type
());
types
.
push_back
(
expr
->
checked_type
());
}
}
Expr
legalized_value
=
fop_legalize
[
op
](
ref_call
->
attrs
,
new_args
,
arg_types
);
types
.
push_back
(
ref_call
->
checked_type
());
// Transform the op by calling the registered legalize function.
Expr
legalized_value
=
fop_legalize
[
op
](
ref_call
->
attrs
,
new_args
,
types
);
// Check if the transformation succeeded. If not, revert back to the original ref_call->op.
if
(
legalized_value
.
defined
())
{
if
(
legalized_value
.
defined
())
{
new_e
=
legalized_value
;
new_e
=
legalized_value
;
modified
=
true
;
modified
=
true
;
...
...
src/relay/qnn/op/dequantize.cc
View file @
1e4aea81
...
@@ -74,12 +74,12 @@ Expr DequantizeLower(const Expr& input_tensor,
...
@@ -74,12 +74,12 @@ Expr DequantizeLower(const Expr& input_tensor,
Expr
DequantizeLegalize
(
const
Attrs
&
attrs
,
Expr
DequantizeLegalize
(
const
Attrs
&
attrs
,
const
Array
<
Expr
>&
new_args
,
const
Array
<
Expr
>&
new_args
,
const
Array
<
tvm
::
relay
::
Type
>&
arg_
types
)
{
const
Array
<
tvm
::
relay
::
Type
>&
types
)
{
CHECK_EQ
(
new_args
.
size
(),
1
);
CHECK_EQ
(
new_args
.
size
(),
1
);
auto
&
data
=
new_args
[
0
];
auto
&
data
=
new_args
[
0
];
const
auto
*
dequantize_attrs
=
attrs
.
as
<
DequantizeAttrs
>
();
const
auto
*
dequantize_attrs
=
attrs
.
as
<
DequantizeAttrs
>
();
CHECK
(
dequantize_attrs
!=
nullptr
);
CHECK
(
dequantize_attrs
!=
nullptr
);
CHECK_EQ
(
arg_types
.
size
(),
1
);
CHECK_EQ
(
types
.
size
(),
2
);
return
DequantizeLower
(
data
,
dequantize_attrs
);
return
DequantizeLower
(
data
,
dequantize_attrs
);
}
}
...
...
src/relay/qnn/op/quantize.cc
View file @
1e4aea81
...
@@ -85,13 +85,13 @@ Expr QuantizeLower(const Expr& input_tensor,
...
@@ -85,13 +85,13 @@ Expr QuantizeLower(const Expr& input_tensor,
Expr
QuantizeLegalize
(
const
Attrs
&
attrs
,
Expr
QuantizeLegalize
(
const
Attrs
&
attrs
,
const
Array
<
Expr
>&
new_args
,
const
Array
<
Expr
>&
new_args
,
const
Array
<
tvm
::
relay
::
Type
>&
arg_
types
)
{
const
Array
<
tvm
::
relay
::
Type
>&
types
)
{
CHECK_EQ
(
new_args
.
size
(),
1
);
CHECK_EQ
(
new_args
.
size
(),
1
);
auto
&
data
=
new_args
[
0
];
auto
&
data
=
new_args
[
0
];
const
auto
*
quantize_attrs
=
attrs
.
as
<
QuantizeAttrs
>
();
const
auto
*
quantize_attrs
=
attrs
.
as
<
QuantizeAttrs
>
();
CHECK
(
quantize_attrs
!=
nullptr
);
CHECK
(
quantize_attrs
!=
nullptr
);
CHECK_EQ
(
arg_types
.
size
(),
1
);
CHECK_EQ
(
types
.
size
(),
2
);
return
QuantizeLower
(
data
,
quantize_attrs
);
return
QuantizeLower
(
data
,
quantize_attrs
);
}
}
...
...
src/relay/qnn/op/requantize.cc
View file @
1e4aea81
...
@@ -109,7 +109,7 @@ std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(double double_multiplie
...
@@ -109,7 +109,7 @@ std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(double double_multiplie
* 7) Cast to the out_dtype.
* 7) Cast to the out_dtype.
*/
*/
Expr
RequantizeLower
(
const
Expr
&
input_tensor
,
const
RequantizeAttrs
*
param
,
Expr
RequantizeLower
(
const
Expr
&
input_tensor
,
const
RequantizeAttrs
*
param
,
const
Array
<
IndexExpr
>&
input_shape
)
{
const
Array
<
IndexExpr
>&
input_shape
,
const
DataType
&
out_dtype
)
{
double
double_multiplier
=
param
->
input_scale
/
param
->
output_scale
;
double
double_multiplier
=
param
->
input_scale
/
param
->
output_scale
;
// Choose high precision datatype to be int64. This is for avoiding overflow
// Choose high precision datatype to be int64. This is for avoiding overflow
...
@@ -173,10 +173,10 @@ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
...
@@ -173,10 +173,10 @@ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
auto
shifted_int64_t
=
Add
(
output_zp
,
scaled_int64_t
);
auto
shifted_int64_t
=
Add
(
output_zp
,
scaled_int64_t
);
// 7) Clip to the out_dtype min/max.
// 7) Clip to the out_dtype min/max.
auto
q_min
=
GetQmin
(
param
->
out_dtype
);
auto
q_min
=
GetQmin
(
out_dtype
);
auto
q_max
=
GetQmax
(
param
->
out_dtype
);
auto
q_max
=
GetQmax
(
out_dtype
);
auto
clipped_t
=
Clip
(
shifted_int64_t
,
q_min
,
q_max
);
auto
clipped_t
=
Clip
(
shifted_int64_t
,
q_min
,
q_max
);
return
Cast
(
clipped_t
,
param
->
out_dtype
);
return
Cast
(
clipped_t
,
out_dtype
);
}
}
/*
/*
...
@@ -193,25 +193,32 @@ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
...
@@ -193,25 +193,32 @@ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
* Q_output = zp_output + (scale_input)/(scale_ouptut) * (Q_input - zp_input)
* Q_output = zp_output + (scale_input)/(scale_ouptut) * (Q_input - zp_input)
*/
*/
Expr
RequantizeLegalize
(
const
Attrs
&
attrs
,
const
Array
<
Expr
>&
new_args
,
Expr
RequantizeLegalize
(
const
Attrs
&
attrs
,
const
Array
<
Expr
>&
new_args
,
const
Array
<
tvm
::
relay
::
Type
>&
arg_
types
)
{
const
Array
<
tvm
::
relay
::
Type
>&
types
)
{
CHECK_EQ
(
new_args
.
size
(),
1
);
CHECK_EQ
(
new_args
.
size
(),
1
);
auto
&
quantized_data
=
new_args
[
0
];
auto
&
quantized_data
=
new_args
[
0
];
const
auto
*
param
=
attrs
.
as
<
RequantizeAttrs
>
();
const
auto
*
param
=
attrs
.
as
<
RequantizeAttrs
>
();
CHECK
(
param
!=
nullptr
);
CHECK
(
param
!=
nullptr
);
// Find input shape.
// Find input shape.
CHECK_EQ
(
arg_types
.
size
(),
1
);
CHECK_EQ
(
types
.
size
(),
2
);
auto
input_dtype
=
arg_types
[
0
];
auto
in_type
=
types
[
0
];
auto
input_tensor_type
=
input_dtype
.
as
<
TensorTypeNode
>
();
auto
in_tensor_type
=
in_type
.
as
<
TensorTypeNode
>
();
CHECK
(
input_tensor_type
!=
nullptr
)
<<
"Type information missing."
CHECK
(
in_tensor_type
!=
nullptr
)
<<
"Type information missing."
<<
" Please run infer_type pass."
;
<<
" Please run infer_type pass."
;
Array
<
IndexExpr
>
input_shape
=
input_tensor_type
->
shape
;
Array
<
IndexExpr
>
input_shape
=
in_tensor_type
->
shape
;
// Find the output dtype.
auto
out_type
=
types
[
1
];
auto
out_tensor_type
=
out_type
.
as
<
TensorTypeNode
>
();
CHECK
(
out_tensor_type
!=
nullptr
)
<<
"Type information missing."
<<
" Please run infer_type pass."
;
auto
out_dtype
=
out_tensor_type
->
dtype
;
// Check rounding validity.
// Check rounding validity.
CHECK
(
param
->
rounding
==
"UPWARD"
||
param
->
rounding
==
"TONEAREST"
)
CHECK
(
param
->
rounding
==
"UPWARD"
||
param
->
rounding
==
"TONEAREST"
)
<<
"QNN requantize supports two rounding modes - UPWARD and "
<<
"QNN requantize supports two rounding modes - UPWARD and "
<<
"TONEAREST"
;
<<
"TONEAREST"
;
return
RequantizeLower
(
quantized_data
,
param
,
input_shape
);
return
RequantizeLower
(
quantized_data
,
param
,
input_shape
,
out_dtype
);
}
}
/*
/*
...
@@ -261,7 +268,7 @@ The requantize operator converts one quantized tensor to another quantized
...
@@ -261,7 +268,7 @@ The requantize operator converts one quantized tensor to another quantized
tensor. For the output tensor, we are provided with output scale and zero
tensor. For the output tensor, we are provided with output scale and zero
point. The computation looks like this
point. The computation looks like this
Q_output = zp_output + (scale_input)/(scale_ou
pt
ut) * (Q_input - zp_input)
Q_output = zp_output + (scale_input)/(scale_ou
tp
ut) * (Q_input - zp_input)
)code"
TVM_ADD_FILELINE
)
)code"
TVM_ADD_FILELINE
)
.
set_attrs_type_key
(
"relay.attrs.RequantizeAttrs"
)
.
set_attrs_type_key
(
"relay.attrs.RequantizeAttrs"
)
...
...
tests/python/relay/test_pass_legalize.py
View file @
1e4aea81
...
@@ -47,7 +47,7 @@ def test_legalize():
...
@@ -47,7 +47,7 @@ def test_legalize():
return
y
return
y
@register_legalize
(
"nn.conv2d"
,
level
=
100
)
@register_legalize
(
"nn.conv2d"
,
level
=
100
)
def
legalize_conv2d
(
attrs
,
inputs
,
arg_
types
):
def
legalize_conv2d
(
attrs
,
inputs
,
types
):
data
,
weight
=
inputs
data
,
weight
=
inputs
weight
=
relay
.
multiply
(
weight
,
relay
.
const
(
2.0
,
"float32"
))
weight
=
relay
.
multiply
(
weight
,
relay
.
const
(
2.0
,
"float32"
))
return
relay
.
nn
.
conv2d
(
data
,
weight
,
**
attrs
)
return
relay
.
nn
.
conv2d
(
data
,
weight
,
**
attrs
)
...
@@ -80,7 +80,7 @@ def test_legalize_none():
...
@@ -80,7 +80,7 @@ def test_legalize_none():
called
=
[
False
]
called
=
[
False
]
@register_legalize
(
"nn.global_max_pool2d"
,
level
=
101
)
@register_legalize
(
"nn.global_max_pool2d"
,
level
=
101
)
def
legalize_conv2d
(
attrs
,
inputs
,
arg_
types
):
def
legalize_conv2d
(
attrs
,
inputs
,
types
):
called
[
0
]
=
True
called
[
0
]
=
True
return
None
return
None
...
@@ -103,12 +103,13 @@ def test_legalize_multi_input():
...
@@ -103,12 +103,13 @@ def test_legalize_multi_input():
return
func
return
func
@register_legalize
(
"concatenate"
,
level
=
100
)
@register_legalize
(
"concatenate"
,
level
=
100
)
def
legalize_concatenate
(
attrs
,
inputs
,
arg_
types
):
def
legalize_concatenate
(
attrs
,
inputs
,
types
):
# Check that the correct multi-input case is handled.
# Check that the correct multi-input case is handled.
assert
len
(
inputs
)
==
1
assert
len
(
inputs
)
==
1
assert
isinstance
(
inputs
[
0
],
tvm
.
relay
.
expr
.
Tuple
)
assert
isinstance
(
inputs
[
0
],
tvm
.
relay
.
expr
.
Tuple
)
assert
len
(
arg_types
)
==
1
assert
len
(
types
)
==
2
assert
isinstance
(
arg_types
[
0
],
tvm
.
relay
.
ty
.
TupleType
)
assert
isinstance
(
types
[
0
],
tvm
.
relay
.
ty
.
TupleType
)
assert
isinstance
(
types
[
1
],
tvm
.
relay
.
ty
.
TensorType
)
return
None
return
None
def
expected
():
def
expected
():
...
@@ -153,9 +154,9 @@ def test_legalize_arm_layout_functional():
...
@@ -153,9 +154,9 @@ def test_legalize_arm_layout_functional():
return
func
return
func
@register_legalize
(
"nn.conv2d"
,
level
=
101
)
@register_legalize
(
"nn.conv2d"
,
level
=
101
)
def
legalize_conv2d
(
attrs
,
inputs
,
arg_
types
):
def
legalize_conv2d
(
attrs
,
inputs
,
types
):
from
topi.arm_cpu.conv2d
import
_conv2d_legalize
from
topi.arm_cpu.conv2d
import
_conv2d_legalize
return
_conv2d_legalize
(
attrs
,
inputs
,
arg_types
,
tvm
.
relay
.
op
)
return
_conv2d_legalize
(
attrs
,
inputs
,
types
)
a
=
before
()
a
=
before
()
b
=
run_opt_pass
(
a
,
transform
.
Legalize
())
b
=
run_opt_pass
(
a
,
transform
.
Legalize
())
...
...
topi/python/topi/arm_cpu/conv2d.py
View file @
1e4aea81
...
@@ -18,10 +18,11 @@
...
@@ -18,10 +18,11 @@
"""Conv2D schedule for ARM CPU"""
"""Conv2D schedule for ARM CPU"""
from
__future__
import
absolute_import
as
_abs
from
__future__
import
absolute_import
as
_abs
import
warnings
import
logging
import
tvm
import
tvm
from
tvm
import
autotvm
from
tvm
import
autotvm
from
tvm
import
relay
import
tvm.contrib.nnpack
import
tvm.contrib.nnpack
from
..generic
import
schedule_conv2d_nchw
,
schedule_conv2d_winograd_without_weight_transform
,
\
from
..generic
import
schedule_conv2d_nchw
,
schedule_conv2d_winograd_without_weight_transform
,
\
...
@@ -35,6 +36,8 @@ from ..nn import conv2d_legalize
...
@@ -35,6 +36,8 @@ from ..nn import conv2d_legalize
from
..nn.util
import
get_const_int
,
get_pad_tuple
from
..nn.util
import
get_const_int
,
get_pad_tuple
from
..nn.winograd_util
import
winograd_transform_matrices
from
..nn.winograd_util
import
winograd_transform_matrices
logger
=
logging
.
getLogger
(
'topi'
)
@autotvm.register_topi_compute
(
conv2d
,
'arm_cpu'
,
[
'direct'
])
@autotvm.register_topi_compute
(
conv2d
,
'arm_cpu'
,
[
'direct'
])
def
conv2d_arm_cpu
(
cfg
,
data
,
kernel
,
strides
,
padding
,
dilation
,
layout
,
out_dtype
):
def
conv2d_arm_cpu
(
cfg
,
data
,
kernel
,
strides
,
padding
,
dilation
,
layout
,
out_dtype
):
"""TOPI compute callback for conv2d
"""TOPI compute callback for conv2d
...
@@ -671,7 +674,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
...
@@ -671,7 +674,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
if
layout
!=
'NCHW'
:
if
layout
!=
'NCHW'
:
return
None
return
None
if
dilation
!=
(
1
,
1
):
if
dilation
!=
(
1
,
1
):
warnings
.
warn
(
"Does not support weight pre-transform for dilated convolution."
)
logger
.
warning
(
"Does not support weight pre-transform for dilated convolution."
)
return
None
return
None
data
,
kernel
=
tinfos
[
0
:
2
]
data
,
kernel
=
tinfos
[
0
:
2
]
...
@@ -786,31 +789,46 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
...
@@ -786,31 +789,46 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
return
None
return
None
@conv2d_legalize.register
(
"arm_cpu"
)
@conv2d_legalize.register
(
"arm_cpu"
)
def
_conv2d_legalize
(
attrs
,
inputs
,
arg_types
,
F
):
def
_conv2d_legalize
(
attrs
,
inputs
,
arg_types
):
if
F
.
__name__
!=
'tvm.relay.op'
:
"""Legalizes Conv2D op.
return
None
Parameters
----------
attrs : tvm.attrs.Attrs
Attributes of current convolution
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
types : list of types
List of input and output types
Returns
-------
result : tvm.relay.Expr
The legalized expr
"""
if
attrs
[
'data_layout'
]
==
'NHWC'
:
if
attrs
[
'data_layout'
]
==
'NHWC'
:
data
,
kernel
=
inputs
data
,
kernel
=
inputs
if
attrs
[
'kernel_layout'
]
==
'HWIO'
:
if
attrs
[
'kernel_layout'
]
==
'HWIO'
:
# Handle HWIO layout. This is common in TF graph.
# Handle HWIO layout. This is common in TF graph.
kernel
=
F
.
transpose
(
kernel
,
axes
=
(
3
,
2
,
0
,
1
))
kernel
=
relay
.
transpose
(
kernel
,
axes
=
(
3
,
2
,
0
,
1
))
elif
attrs
[
'kernel_layout'
]
==
'HWOI'
:
elif
attrs
[
'kernel_layout'
]
==
'HWOI'
:
# Handle HWOI layout. This is common in TF depthwise conv2d graph.
# Handle HWOI layout. This is common in TF depthwise conv2d graph.
kernel
=
F
.
transpose
(
kernel
,
axes
=
(
2
,
3
,
0
,
1
))
kernel
=
relay
.
transpose
(
kernel
,
axes
=
(
2
,
3
,
0
,
1
))
elif
attrs
[
'kernel_layout'
]
!=
'OIHW'
:
elif
attrs
[
'kernel_layout'
]
!=
'OIHW'
:
return
None
return
None
warnings
.
warn
(
"Legalize arm_cpu - NHWC schedule absent. Inserting layout transforms to "
logger
.
warning
(
"Legalize arm_cpu - NHWC schedule absent. Inserting layout transforms to "
+
"fallback to NCHW. This can result in performance degradation."
)
+
"fallback to NCHW. This can result in performance degradation."
)
# Set new attrs for the tranposed conv.
# Set new attrs for the tranposed conv.
new_attrs
=
{
k
:
attrs
[
k
]
for
k
in
attrs
.
keys
()}
new_attrs
=
{
k
:
attrs
[
k
]
for
k
in
attrs
.
keys
()}
new_attrs
[
'data_layout'
]
=
'NCHW'
new_attrs
[
'data_layout'
]
=
'NCHW'
new_attrs
[
'kernel_layout'
]
=
'OIHW'
new_attrs
[
'kernel_layout'
]
=
'OIHW'
# Convert from NHWC to NCHW.
# Convert from NHWC to NCHW.
data
=
F
.
transpose
(
data
,
axes
=
(
0
,
3
,
1
,
2
))
data
=
relay
.
transpose
(
data
,
axes
=
(
0
,
3
,
1
,
2
))
conv
=
F
.
nn
.
conv2d
(
data
,
kernel
,
**
new_attrs
)
conv
=
relay
.
nn
.
conv2d
(
data
,
kernel
,
**
new_attrs
)
# Convert back to original NHWC layout.
# Convert back to original NHWC layout.
out
=
F
.
transpose
(
conv
,
axes
=
(
0
,
2
,
3
,
1
))
out
=
relay
.
transpose
(
conv
,
axes
=
(
0
,
2
,
3
,
1
))
return
out
return
out
return
None
return
None
topi/python/topi/nn/conv2d.py
View file @
1e4aea81
...
@@ -72,22 +72,22 @@ def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=N
...
@@ -72,22 +72,22 @@ def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=N
@tvm.target.generic_func
@tvm.target.generic_func
def
conv2d_legalize
(
attrs
,
inputs
,
arg_dtypes
,
F
):
def
conv2d_legalize
(
attrs
,
inputs
,
types
):
"""Legalizes Conv2D op.
"""Legalizes Conv2D op.
Parameters
Parameters
----------
----------
attrs :
nnvm.top.AttrDict or
tvm.attrs.Attrs
attrs : tvm.attrs.Attrs
Attributes of current convolution
Attributes of current convolution
inputs : list of tvm.relay.Expr
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized.
The args of the Relay expr to be legalized
arg_dtypes : list of types
types : list of types
List of types of input arguments
List of input and output types
F: symbol
The context, can be either nnvm.sym or relay.op
Returns
Note
-------
----
result : tvm.relay.Expr
Unlike other TOPI functions, this function operates on both graph level and operator level,
The legalized expr
so we have to pass 'F' to make it support our two versions of graph IR, NNVM and Relay.
"""
"""
# not to change by default
# not to change by default
return
None
return
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment