/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file legalize.cc
 * \brief Converts an expr to another expr. This pass can be used to transform an op based on its
 * shape, dtype or layout to another op or a sequence of ops.
 */

#include <tvm/te/operation.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>

#include <string>

namespace tvm {
namespace relay {

namespace legalize {

/*!
 * \brief Expression mutator that calls the FTVMLegalize function registered for each visited op
 *  call and, when that function returns a defined expression, substitutes it for the call.
 *
 *  The legalize functions are looked up under a configurable op-attribute name.
 */
class Legalizer : public ExprMutator {
 public:
  /*!
   * \param legalize_map_attr_name Name of the op attribute under which the FTVMLegalize
   *  functions are registered.
   */
  explicit Legalizer(const std::string& legalize_map_attr_name)
      : legalize_map_attr_name_{legalize_map_attr_name} {}

  /*!
   * \brief Mutate the call's children first, then try to legalize the call itself.
   * \param call_node The original (pre-mutation) call. Its checked types are what is handed to
   *  the legalize function, so type annotations must already be populated.
   * \return The expression returned by the registered legalize function if it is defined;
   *  otherwise the call with its arguments mutated.
   */
  Expr VisitExpr_(const CallNode* call_node) override {
    // Rewrite the sub-expressions; call_node itself is left untouched.
    Expr new_e = ExprMutator::VisitExpr_(call_node);
    Call new_call = Downcast<Call>(new_e);

    // Nothing to do unless some op registered this attribute name in the OpRegistry.
    if (!Op::HasAttr(legalize_map_attr_name_)) {
      return new_e;
    }

    // Collect the registered legalize functions.
    auto fop_legalize = Op::GetAttr<FTVMLegalize>(legalize_map_attr_name_);

    // Only calls to primitive ops can have a legalize function attached.
    if (!call_node->op.as<OpNode>()) {
      return new_e;
    }
    Op op = Downcast<Op>(call_node->op);
    if (!fop_legalize.count(op)) {
      return new_e;
    }

    // The legalize function receives the already-mutated arguments...
    tvm::Array<Expr> call_args = new_call->args;

    // ...and the types of the original inputs followed by the original output type.
    tvm::Array<tvm::relay::Type> types;
    for (const Expr& arg : call_node->args) {
      types.push_back(arg->checked_type());
    }
    types.push_back(call_node->checked_type());

    // Transform the op by calling the registered legalize function.
    Expr legalized_value = fop_legalize[op](call_node->attrs, call_args, types);

    // An undefined result leaves the (argument-mutated) call in place.
    if (!legalized_value.defined()) {
      return new_e;
    }

    // The replacement must itself be a call.
    const CallNode* legalized_call_node = legalized_value.as<CallNode>();
    CHECK(legalized_call_node) << "Can only replace the original operator with another call node";
    return legalized_value;
  }

 private:
  /*! \brief Op attribute name under which the FTVMLegalize functions are looked up. */
  std::string legalize_map_attr_name_;
};

/*!
 * \brief Apply the FTVMLegalize functions registered under \p legalize_map_attr_name to every
 *  op call in \p expr.
 * \param expr The expression to legalize.
 * \param legalize_map_attr_name Op attribute name holding the legalize functions.
 * \return The legalized expression.
 */
Expr Legalize(const Expr& expr, const std::string& legalize_map_attr_name) {
  return Legalizer(legalize_map_attr_name).Mutate(expr);
}

}  // namespace legalize

namespace transform {

/*!
 * \brief Create the "Legalize" function pass.
 *
 *  The pass is created with opt level 1. It lists "InferType" as a required pass, because
 *  the legalizer reads checked types.
 * \param legalize_map_attr_name Op attribute name holding the legalize functions.
 * \return The created pass.
 */
Pass Legalize(const std::string& legalize_map_attr_name) {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [=](Function f, IRModule /*m*/, PassContext /*pc*/) {
        return Downcast<Function>(relay::legalize::Legalize(f, legalize_map_attr_name));
      };
  return CreateFunctionPass(pass_func, 1, "Legalize", {tir::StringImmNode::make("InferType")});
}

TVM_REGISTER_GLOBAL("relay._transform.Legalize").set_body_typed(Legalize);

}  // namespace transform

}  // namespace relay
}  // namespace tvm