Commit 36201fe9 by Animesh Jain Committed by Zhi

[QNN][Relay] Calling Dialect passes from inside Relay Build API. (#3971)

parent a7873b0a
......@@ -154,6 +154,12 @@ class Op : public relay::Expr {
template <typename ValueType>
inline static OpMap<ValueType> GetAttr(const std::string& attr_name);
/*!
* \brief Checks if an attr is present in the registry.
* \param attr_name The name of the attribute.
* \return bool True if the attr is present.
*/
inline static bool HasAttr(const std::string& attr_name);
/*!
* \brief Get an Op for a given operator name.
* Will raise an error if the op has not been registered.
* \param op_name Name of the operator.
......@@ -171,6 +177,12 @@ class Op : public relay::Expr {
* \return reference to GenericOpMap
*/
TVM_DLL static const GenericOpMap& GetGenericAttr(const std::string& key);
/*!
* \brief Checks if the key is present in the registry
* \param key The attribute key
* \return bool True if the key is present
*/
TVM_DLL static const bool HasGenericAttr(const std::string& key);
};
/*! \brief Helper structure to register operators */
......@@ -393,6 +405,10 @@ inline OpMap<ValueType> Op::GetAttr(const std::string& key) {
return OpMap<ValueType>(Op::GetGenericAttr(key));
}
inline bool Op::HasAttr(const std::string& key) {
return Op::HasGenericAttr(key);
}
inline OpNode* OpRegistry::get() {
return const_cast<OpNode*>(op_.operator->());
}
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/relay/qnn/transform.h
*
* This file implements a pass manager for QNN ops using Relay Pass manager.
*/
#ifndef TVM_RELAY_QNN_TRANSFORM_H_
#define TVM_RELAY_QNN_TRANSFORM_H_
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/relay/transform.h>
namespace tvm {
namespace relay {
using relay::transform::Pass;
namespace qnn {
namespace transform {
/*!
* \brief Legalizes a QNN expr. Contains specifically two types of Legalizations. First,
* converts/Lowers an expression containing QNN ops to an expression containing only core Relay ops.
* Each QNN op is lowered to a sequence of exisiting Relay ops. This is a target-independent pass.
* One can register the lowering/transformation function for this op using FTVMQnnCanonicalize
* attr_name for FTVMLegalize op attribute. Second, as opposed to Relay Legalize, this one legalizes
* only QNN ops. One can register a transformation/legalization function for an op by using the
* FTVMQnnLegalize attr_name for FTVMLegalize op attribute. The isolation of QNN and Relay Legalize
* gives us separation of concerns, leading to a better software practice. The legalization can be
* configured to happen per target.
*
* \return The pass.
*/
TVM_DLL Pass Legalize();
} // namespace transform
} // namespace qnn
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_QNN_TRANSFORM_H_
......@@ -27,6 +27,7 @@
#include <tvm/runtime/vm.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/qnn/transform.h>
#include <memory>
#include "utils.h"
......@@ -286,6 +287,15 @@ class RelayBuildModule : public runtime::ModuleNode {
const TargetsMap& targets,
const std::unordered_map<std::string, runtime::NDArray>& params) {
Array<Pass> pass_seqs;
// Run all dialect legalization passes.
pass_seqs.push_back(relay::qnn::transform::Legalize());
// Legalize pass is restricted to homogeneous execution for now.
if (targets.size() == 1) {
pass_seqs.push_back(transform::Legalize());
}
pass_seqs.push_back(transform::SimplifyInference());
PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
Expr expr = args[0];
......@@ -309,11 +319,6 @@ class RelayBuildModule : public runtime::ModuleNode {
pass_seqs.push_back(transform::CanonicalizeCast());
pass_seqs.push_back(transform::CanonicalizeOps());
// Legalize pass is restricted to homogeneous execution for now.
if (targets.size() == 1) {
pass_seqs.push_back(transform::Legalize());
}
// Alter layout transformation is only applied to homogeneous execution yet.
if (targets.size() == 1) {
pass_seqs.push_back(transform::AlterOpLayout());
......
......@@ -84,6 +84,17 @@ const GenericOpMap& Op::GetGenericAttr(const std::string& key) {
return *it->second.get();
}
// Check if a key is present in the registry.
const bool Op::HasGenericAttr(const std::string& key) {
OpManager* mgr = OpManager::Global();
std::lock_guard<std::mutex> lock(mgr->mutex);
auto it = mgr->attr.find(key);
if (it == mgr->attr.end()) {
return false;
}
return true;
}
void OpRegistry::UpdateAttr(const std::string& key,
TVMRetValue value,
int plevel) {
......
......@@ -46,32 +46,40 @@ class Legalizer : public ExprMutator {
Expr new_e = ExprMutator::VisitExpr_(call_node);
Call new_call = Downcast<Call>(new_e);
// Check if the string is registered in the OpRegistry.
if (!Op::HasAttr(legalize_map_attr_name_)) {
return new_e;
}
// Collect the registered legalize function.
auto fop_legalize = Op::GetAttr<FTVMLegalize>(legalize_map_attr_name_);
Op op = Downcast<Op>(call_node->op);
if (fop_legalize.count(op)) {
// Collect the new_args.
tvm::Array<Expr> call_args = new_call->args;
// Collect input and output dtypes to pass on to Legalize API.
tvm::Array<tvm::relay::Type> types;
for (auto arg : call_node->args) {
types.push_back(arg->checked_type());
}
types.push_back(call_node->checked_type());
// Transform the op by calling the registered legalize function.
Expr legalized_value = fop_legalize[op](call_node->attrs, call_args, types);
// Reassign new_e if the transformation succeeded.
if (legalized_value.defined()) {
// Check that the returned Expr from legalize is CallNode.
const CallNode* legalized_call_node = legalized_value.as<CallNode>();
CHECK(legalized_call_node)
<< "Can only replace the original operator with another call node";
new_e = legalized_value;
auto call_op = call_node->op;
if (call_op.as<OpNode>()) {
Op op = Downcast<Op>(call_node->op);
if (fop_legalize.count(op)) {
// Collect the new_args.
tvm::Array<Expr> call_args = new_call->args;
// Collect input and output dtypes to pass on to Legalize API.
tvm::Array<tvm::relay::Type> types;
for (auto arg : call_node->args) {
types.push_back(arg->checked_type());
}
types.push_back(call_node->checked_type());
// Transform the op by calling the registered legalize function.
Expr legalized_value = fop_legalize[op](call_node->attrs, call_args, types);
// Reassign new_e if the transformation succeeded.
if (legalized_value.defined()) {
// Check that the returned Expr from legalize is CallNode.
const CallNode* legalized_call_node = legalized_value.as<CallNode>();
CHECK(legalized_call_node)
<< "Can only replace the original operator with another call node";
new_e = legalized_value;
}
}
}
......@@ -95,7 +103,7 @@ Pass Legalize(const std::string& legalize_map_attr_name) {
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(relay::legalize::Legalize(f, legalize_map_attr_name));
};
return CreateFunctionPass(pass_func, 3, "Legalize", {ir::StringImm::make("InferType")});
return CreateFunctionPass(pass_func, 0, "Legalize", {ir::StringImm::make("InferType")});
}
TVM_REGISTER_API("relay._transform.Legalize").set_body_typed(Legalize);
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file relay/qnn/pass/legalize.cc
* \brief The Legalize wrapper for QNN.
*/
#include <tvm/relay/qnn/transform.h>
namespace tvm {
namespace relay {
namespace qnn {
namespace transform {
Pass Legalize() {
Array<Pass> pass_seqs;
pass_seqs.push_back(relay::transform::Legalize("FTVMQnnLegalize"));
pass_seqs.push_back(relay::transform::Legalize("FTVMQnnCanonicalize"));
relay::transform::Pass seq = relay::transform::Sequential(pass_seqs);
return seq;
}
TVM_REGISTER_API("relay.qnn._transform.Legalize").set_body_typed(Legalize);
} // namespace transform
} // namespace qnn
} // namespace relay
} // namespace tvm
......@@ -77,7 +77,6 @@ def get_qnn_func(data,
mod = relay.Function(relay.analysis.free_vars(func), func)
mod = relay.Module.from_expr(mod)
mod = relay.qnn.transform.CanonicalizeOps()(mod)
return mod
def get_funcs(data_shape,
......
......@@ -31,7 +31,6 @@ def test_dequantize_op():
input_zero_point=input_zero_point)
mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output)
mod = relay.Module.from_expr(mod)
mod = relay.qnn.transform.CanonicalizeOps()(mod)
with relay.build_config(opt_level=3):
graph, lib, params = relay.build(mod, "llvm", params=None)
rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
......
......@@ -31,7 +31,6 @@ def test_quantize_op():
output_zero_point=output_zero_point,out_dtype=out_dtype)
mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output)
mod = relay.Module.from_expr(mod)
mod = relay.qnn.transform.CanonicalizeOps()(mod)
with relay.build_config(opt_level=3):
graph, lib, params = relay.build(mod, "llvm", params=None)
rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
......
......@@ -49,7 +49,6 @@ def test_requantize():
mod = relay.Function(relay.analysis.free_vars(mod), mod)
mod = relay.Module.from_expr(mod)
mod = relay.qnn.transform.CanonicalizeOps()(mod)
return mod
def same_scale_test():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment