Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
36201fe9
Commit
36201fe9
authored
Oct 02, 2019
by
Animesh Jain
Committed by
Zhi
Oct 02, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[QNN][Relay] Calling Dialect passes from inside Relay Build API. (#3971)
parent
a7873b0a
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
177 additions
and
34 deletions
+177
-34
include/tvm/relay/op.h
+16
-0
include/tvm/relay/qnn/transform.h
+60
-0
src/relay/backend/build_module.cc
+10
-5
src/relay/ir/op.cc
+11
-0
src/relay/pass/legalize.cc
+33
-25
src/relay/qnn/pass/legalize.cc
+47
-0
tests/python/relay/test_op_qnn_conv2d.py
+0
-1
tests/python/relay/test_op_qnn_dequantize.py
+0
-1
tests/python/relay/test_op_qnn_quantize.py
+0
-1
tests/python/relay/test_op_qnn_requantize.py
+0
-1
No files found.
include/tvm/relay/op.h
View file @
36201fe9
...
@@ -154,6 +154,12 @@ class Op : public relay::Expr {
...
@@ -154,6 +154,12 @@ class Op : public relay::Expr {
template
<
typename
ValueType
>
template
<
typename
ValueType
>
inline
static
OpMap
<
ValueType
>
GetAttr
(
const
std
::
string
&
attr_name
);
inline
static
OpMap
<
ValueType
>
GetAttr
(
const
std
::
string
&
attr_name
);
/*!
/*!
* \brief Checks if an attr is present in the registry.
* \param attr_name The name of the attribute.
* \return bool True if the attr is present.
*/
inline
static
bool
HasAttr
(
const
std
::
string
&
attr_name
);
/*!
* \brief Get an Op for a given operator name.
* \brief Get an Op for a given operator name.
* Will raise an error if the op has not been registered.
* Will raise an error if the op has not been registered.
* \param op_name Name of the operator.
* \param op_name Name of the operator.
...
@@ -171,6 +177,12 @@ class Op : public relay::Expr {
...
@@ -171,6 +177,12 @@ class Op : public relay::Expr {
* \return reference to GenericOpMap
* \return reference to GenericOpMap
*/
*/
TVM_DLL
static
const
GenericOpMap
&
GetGenericAttr
(
const
std
::
string
&
key
);
TVM_DLL
static
const
GenericOpMap
&
GetGenericAttr
(
const
std
::
string
&
key
);
/*!
* \brief Checks if the key is present in the registry
* \param key The attribute key
* \return bool True if the key is present
*/
TVM_DLL
static
const
bool
HasGenericAttr
(
const
std
::
string
&
key
);
};
};
/*! \brief Helper structure to register operators */
/*! \brief Helper structure to register operators */
...
@@ -393,6 +405,10 @@ inline OpMap<ValueType> Op::GetAttr(const std::string& key) {
...
@@ -393,6 +405,10 @@ inline OpMap<ValueType> Op::GetAttr(const std::string& key) {
return
OpMap
<
ValueType
>
(
Op
::
GetGenericAttr
(
key
));
return
OpMap
<
ValueType
>
(
Op
::
GetGenericAttr
(
key
));
}
}
inline
bool
Op
::
HasAttr
(
const
std
::
string
&
key
)
{
return
Op
::
HasGenericAttr
(
key
);
}
inline
OpNode
*
OpRegistry
::
get
()
{
inline
OpNode
*
OpRegistry
::
get
()
{
return
const_cast
<
OpNode
*>
(
op_
.
operator
->
());
return
const_cast
<
OpNode
*>
(
op_
.
operator
->
());
}
}
...
...
include/tvm/relay/qnn/transform.h
0 → 100644
View file @
36201fe9
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/relay/qnn/transform.h
*
* This file implements a pass manager for QNN ops using Relay Pass manager.
*/
#ifndef TVM_RELAY_QNN_TRANSFORM_H_
#define TVM_RELAY_QNN_TRANSFORM_H_
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/relay/transform.h>
namespace
tvm
{
namespace
relay
{
using
relay
::
transform
::
Pass
;
namespace
qnn
{
namespace
transform
{
/*!
* \brief Legalizes a QNN expr. Contains specifically two types of Legalizations. First,
* converts/Lowers an expression containing QNN ops to an expression containing only core Relay ops.
* Each QNN op is lowered to a sequence of exisiting Relay ops. This is a target-independent pass.
* One can register the lowering/transformation function for this op using FTVMQnnCanonicalize
* attr_name for FTVMLegalize op attribute. Second, as opposed to Relay Legalize, this one legalizes
* only QNN ops. One can register a transformation/legalization function for an op by using the
* FTVMQnnLegalize attr_name for FTVMLegalize op attribute. The isolation of QNN and Relay Legalize
* gives us separation of concerns, leading to a better software practice. The legalization can be
* configured to happen per target.
*
* \return The pass.
*/
TVM_DLL
Pass
Legalize
();
}
// namespace transform
}
// namespace qnn
}
// namespace relay
}
// namespace tvm
#endif // TVM_RELAY_QNN_TRANSFORM_H_
src/relay/backend/build_module.cc
View file @
36201fe9
...
@@ -27,6 +27,7 @@
...
@@ -27,6 +27,7 @@
#include <tvm/runtime/vm.h>
#include <tvm/runtime/vm.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/qnn/transform.h>
#include <memory>
#include <memory>
#include "utils.h"
#include "utils.h"
...
@@ -286,6 +287,15 @@ class RelayBuildModule : public runtime::ModuleNode {
...
@@ -286,6 +287,15 @@ class RelayBuildModule : public runtime::ModuleNode {
const
TargetsMap
&
targets
,
const
TargetsMap
&
targets
,
const
std
::
unordered_map
<
std
::
string
,
runtime
::
NDArray
>&
params
)
{
const
std
::
unordered_map
<
std
::
string
,
runtime
::
NDArray
>&
params
)
{
Array
<
Pass
>
pass_seqs
;
Array
<
Pass
>
pass_seqs
;
// Run all dialect legalization passes.
pass_seqs
.
push_back
(
relay
::
qnn
::
transform
::
Legalize
());
// Legalize pass is restricted to homogeneous execution for now.
if
(
targets
.
size
()
==
1
)
{
pass_seqs
.
push_back
(
transform
::
Legalize
());
}
pass_seqs
.
push_back
(
transform
::
SimplifyInference
());
pass_seqs
.
push_back
(
transform
::
SimplifyInference
());
PackedFunc
fskip
=
PackedFunc
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
PackedFunc
fskip
=
PackedFunc
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
Expr
expr
=
args
[
0
];
Expr
expr
=
args
[
0
];
...
@@ -309,11 +319,6 @@ class RelayBuildModule : public runtime::ModuleNode {
...
@@ -309,11 +319,6 @@ class RelayBuildModule : public runtime::ModuleNode {
pass_seqs
.
push_back
(
transform
::
CanonicalizeCast
());
pass_seqs
.
push_back
(
transform
::
CanonicalizeCast
());
pass_seqs
.
push_back
(
transform
::
CanonicalizeOps
());
pass_seqs
.
push_back
(
transform
::
CanonicalizeOps
());
// Legalize pass is restricted to homogeneous execution for now.
if
(
targets
.
size
()
==
1
)
{
pass_seqs
.
push_back
(
transform
::
Legalize
());
}
// Alter layout transformation is only applied to homogeneous execution yet.
// Alter layout transformation is only applied to homogeneous execution yet.
if
(
targets
.
size
()
==
1
)
{
if
(
targets
.
size
()
==
1
)
{
pass_seqs
.
push_back
(
transform
::
AlterOpLayout
());
pass_seqs
.
push_back
(
transform
::
AlterOpLayout
());
...
...
src/relay/ir/op.cc
View file @
36201fe9
...
@@ -84,6 +84,17 @@ const GenericOpMap& Op::GetGenericAttr(const std::string& key) {
...
@@ -84,6 +84,17 @@ const GenericOpMap& Op::GetGenericAttr(const std::string& key) {
return
*
it
->
second
.
get
();
return
*
it
->
second
.
get
();
}
}
// Check if a key is present in the registry.
const
bool
Op
::
HasGenericAttr
(
const
std
::
string
&
key
)
{
OpManager
*
mgr
=
OpManager
::
Global
();
std
::
lock_guard
<
std
::
mutex
>
lock
(
mgr
->
mutex
);
auto
it
=
mgr
->
attr
.
find
(
key
);
if
(
it
==
mgr
->
attr
.
end
())
{
return
false
;
}
return
true
;
}
void
OpRegistry
::
UpdateAttr
(
const
std
::
string
&
key
,
void
OpRegistry
::
UpdateAttr
(
const
std
::
string
&
key
,
TVMRetValue
value
,
TVMRetValue
value
,
int
plevel
)
{
int
plevel
)
{
...
...
src/relay/pass/legalize.cc
View file @
36201fe9
...
@@ -46,32 +46,40 @@ class Legalizer : public ExprMutator {
...
@@ -46,32 +46,40 @@ class Legalizer : public ExprMutator {
Expr
new_e
=
ExprMutator
::
VisitExpr_
(
call_node
);
Expr
new_e
=
ExprMutator
::
VisitExpr_
(
call_node
);
Call
new_call
=
Downcast
<
Call
>
(
new_e
);
Call
new_call
=
Downcast
<
Call
>
(
new_e
);
// Check if the string is registered in the OpRegistry.
if
(
!
Op
::
HasAttr
(
legalize_map_attr_name_
))
{
return
new_e
;
}
// Collect the registered legalize function.
// Collect the registered legalize function.
auto
fop_legalize
=
Op
::
GetAttr
<
FTVMLegalize
>
(
legalize_map_attr_name_
);
auto
fop_legalize
=
Op
::
GetAttr
<
FTVMLegalize
>
(
legalize_map_attr_name_
);
Op
op
=
Downcast
<
Op
>
(
call_node
->
op
);
auto
call_op
=
call_node
->
op
;
if
(
call_op
.
as
<
OpNode
>
())
{
if
(
fop_legalize
.
count
(
op
))
{
Op
op
=
Downcast
<
Op
>
(
call_node
->
op
);
// Collect the new_args.
tvm
::
Array
<
Expr
>
call_args
=
new_call
->
args
;
if
(
fop_legalize
.
count
(
op
))
{
// Collect the new_args.
// Collect input and output dtypes to pass on to Legalize API.
tvm
::
Array
<
Expr
>
call_args
=
new_call
->
args
;
tvm
::
Array
<
tvm
::
relay
::
Type
>
types
;
for
(
auto
arg
:
call_node
->
args
)
{
// Collect input and output dtypes to pass on to Legalize API.
types
.
push_back
(
arg
->
checked_type
());
tvm
::
Array
<
tvm
::
relay
::
Type
>
types
;
}
for
(
auto
arg
:
call_node
->
args
)
{
types
.
push_back
(
call_node
->
checked_type
());
types
.
push_back
(
arg
->
checked_type
());
}
// Transform the op by calling the registered legalize function.
types
.
push_back
(
call_node
->
checked_type
());
Expr
legalized_value
=
fop_legalize
[
op
](
call_node
->
attrs
,
call_args
,
types
);
// Transform the op by calling the registered legalize function.
// Reassign new_e if the transformation succeeded.
Expr
legalized_value
=
fop_legalize
[
op
](
call_node
->
attrs
,
call_args
,
types
);
if
(
legalized_value
.
defined
())
{
// Check that the returned Expr from legalize is CallNode.
// Reassign new_e if the transformation succeeded.
const
CallNode
*
legalized_call_node
=
legalized_value
.
as
<
CallNode
>
();
if
(
legalized_value
.
defined
())
{
CHECK
(
legalized_call_node
)
// Check that the returned Expr from legalize is CallNode.
<<
"Can only replace the original operator with another call node"
;
const
CallNode
*
legalized_call_node
=
legalized_value
.
as
<
CallNode
>
();
CHECK
(
legalized_call_node
)
new_e
=
legalized_value
;
<<
"Can only replace the original operator with another call node"
;
new_e
=
legalized_value
;
}
}
}
}
}
...
@@ -95,7 +103,7 @@ Pass Legalize(const std::string& legalize_map_attr_name) {
...
@@ -95,7 +103,7 @@ Pass Legalize(const std::string& legalize_map_attr_name) {
[
=
](
Function
f
,
Module
m
,
PassContext
pc
)
{
[
=
](
Function
f
,
Module
m
,
PassContext
pc
)
{
return
Downcast
<
Function
>
(
relay
::
legalize
::
Legalize
(
f
,
legalize_map_attr_name
));
return
Downcast
<
Function
>
(
relay
::
legalize
::
Legalize
(
f
,
legalize_map_attr_name
));
};
};
return
CreateFunctionPass
(
pass_func
,
3
,
"Legalize"
,
{
ir
::
StringImm
::
make
(
"InferType"
)});
return
CreateFunctionPass
(
pass_func
,
0
,
"Legalize"
,
{
ir
::
StringImm
::
make
(
"InferType"
)});
}
}
TVM_REGISTER_API
(
"relay._transform.Legalize"
).
set_body_typed
(
Legalize
);
TVM_REGISTER_API
(
"relay._transform.Legalize"
).
set_body_typed
(
Legalize
);
...
...
src/relay/qnn/pass/legalize.cc
0 → 100644
View file @
36201fe9
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file relay/qnn/pass/legalize.cc
* \brief The Legalize wrapper for QNN.
*/
#include <tvm/relay/qnn/transform.h>
namespace
tvm
{
namespace
relay
{
namespace
qnn
{
namespace
transform
{
Pass
Legalize
()
{
Array
<
Pass
>
pass_seqs
;
pass_seqs
.
push_back
(
relay
::
transform
::
Legalize
(
"FTVMQnnLegalize"
));
pass_seqs
.
push_back
(
relay
::
transform
::
Legalize
(
"FTVMQnnCanonicalize"
));
relay
::
transform
::
Pass
seq
=
relay
::
transform
::
Sequential
(
pass_seqs
);
return
seq
;
}
TVM_REGISTER_API
(
"relay.qnn._transform.Legalize"
).
set_body_typed
(
Legalize
);
}
// namespace transform
}
// namespace qnn
}
// namespace relay
}
// namespace tvm
tests/python/relay/test_op_qnn_conv2d.py
View file @
36201fe9
...
@@ -77,7 +77,6 @@ def get_qnn_func(data,
...
@@ -77,7 +77,6 @@ def get_qnn_func(data,
mod
=
relay
.
Function
(
relay
.
analysis
.
free_vars
(
func
),
func
)
mod
=
relay
.
Function
(
relay
.
analysis
.
free_vars
(
func
),
func
)
mod
=
relay
.
Module
.
from_expr
(
mod
)
mod
=
relay
.
Module
.
from_expr
(
mod
)
mod
=
relay
.
qnn
.
transform
.
CanonicalizeOps
()(
mod
)
return
mod
return
mod
def
get_funcs
(
data_shape
,
def
get_funcs
(
data_shape
,
...
...
tests/python/relay/test_op_qnn_dequantize.py
View file @
36201fe9
...
@@ -31,7 +31,6 @@ def test_dequantize_op():
...
@@ -31,7 +31,6 @@ def test_dequantize_op():
input_zero_point
=
input_zero_point
)
input_zero_point
=
input_zero_point
)
mod
=
relay
.
Function
(
relay
.
analysis
.
free_vars
(
quantized_output
),
quantized_output
)
mod
=
relay
.
Function
(
relay
.
analysis
.
free_vars
(
quantized_output
),
quantized_output
)
mod
=
relay
.
Module
.
from_expr
(
mod
)
mod
=
relay
.
Module
.
from_expr
(
mod
)
mod
=
relay
.
qnn
.
transform
.
CanonicalizeOps
()(
mod
)
with
relay
.
build_config
(
opt_level
=
3
):
with
relay
.
build_config
(
opt_level
=
3
):
graph
,
lib
,
params
=
relay
.
build
(
mod
,
"llvm"
,
params
=
None
)
graph
,
lib
,
params
=
relay
.
build
(
mod
,
"llvm"
,
params
=
None
)
rt_mod
=
graph_runtime
.
create
(
graph
,
lib
,
ctx
=
tvm
.
cpu
(
0
))
rt_mod
=
graph_runtime
.
create
(
graph
,
lib
,
ctx
=
tvm
.
cpu
(
0
))
...
...
tests/python/relay/test_op_qnn_quantize.py
View file @
36201fe9
...
@@ -31,7 +31,6 @@ def test_quantize_op():
...
@@ -31,7 +31,6 @@ def test_quantize_op():
output_zero_point
=
output_zero_point
,
out_dtype
=
out_dtype
)
output_zero_point
=
output_zero_point
,
out_dtype
=
out_dtype
)
mod
=
relay
.
Function
(
relay
.
analysis
.
free_vars
(
quantized_output
),
quantized_output
)
mod
=
relay
.
Function
(
relay
.
analysis
.
free_vars
(
quantized_output
),
quantized_output
)
mod
=
relay
.
Module
.
from_expr
(
mod
)
mod
=
relay
.
Module
.
from_expr
(
mod
)
mod
=
relay
.
qnn
.
transform
.
CanonicalizeOps
()(
mod
)
with
relay
.
build_config
(
opt_level
=
3
):
with
relay
.
build_config
(
opt_level
=
3
):
graph
,
lib
,
params
=
relay
.
build
(
mod
,
"llvm"
,
params
=
None
)
graph
,
lib
,
params
=
relay
.
build
(
mod
,
"llvm"
,
params
=
None
)
rt_mod
=
graph_runtime
.
create
(
graph
,
lib
,
ctx
=
tvm
.
cpu
(
0
))
rt_mod
=
graph_runtime
.
create
(
graph
,
lib
,
ctx
=
tvm
.
cpu
(
0
))
...
...
tests/python/relay/test_op_qnn_requantize.py
View file @
36201fe9
...
@@ -49,7 +49,6 @@ def test_requantize():
...
@@ -49,7 +49,6 @@ def test_requantize():
mod
=
relay
.
Function
(
relay
.
analysis
.
free_vars
(
mod
),
mod
)
mod
=
relay
.
Function
(
relay
.
analysis
.
free_vars
(
mod
),
mod
)
mod
=
relay
.
Module
.
from_expr
(
mod
)
mod
=
relay
.
Module
.
from_expr
(
mod
)
mod
=
relay
.
qnn
.
transform
.
CanonicalizeOps
()(
mod
)
return
mod
return
mod
def
same_scale_test
():
def
same_scale_test
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment