# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=eval-used,invalid-name,too-many-arguments
"""Utility functions"""
from tvm import relay
from tvm.relay import transform


def has_multiple_inputs(node_list, node_idx, input_names):
    """Check whether a node has multiple input nodes except variable nodes.

    Parameters
    ----------
    node_list : list of dict of str to object
        List of all nodes in a graph.

    node_idx : int
        Node index to be checked.

    input_names : list of str
        List of input names of graph.

    Returns
    -------
    out : bool
        Whether the specified node has multiple input nodes.
    """
    num_inputs = 0
    node = node_list[node_idx]
    for in_entry in node["inputs"]:
        # The first element of each input entry is the index of the
        # producing node in node_list.
        in_node = node_list[in_entry[0]]
        # Count operator nodes and graph-input variables; other "null"
        # (variable) nodes, i.e. parameters, are excluded.
        if in_node["op"] != "null" or is_input_node(in_node, input_names):
            num_inputs += 1
    return num_inputs > 1


def is_input_node(node_entry, input_names):
    """Whether a node is an input node.

    Parameters
    ----------
    node_entry : dict
        Node entry.

    input_names : list of str
        List of input names of graph.

    Returns
    -------
    out : bool
        Whether the node has a "name" that appears in input_names.
    """
    return "name" in node_entry and node_entry["name"] in input_names


def bind_inputs(expr, input_shapes=None, input_dtypes="float32"):
    """Bind input variables of a relay function expression to new shapes and/or dtypes.

    Parameters
    ----------
    expr : tvm.relay.Expr.Function
        Input relay function expression.

    input_shapes : dict of str to tuple of int, optional
        Input shapes. Only parameters whose name appears here are rebound.

    input_dtypes : str or dict of str to str, optional
        Input dtypes. A single string is applied to every input in
        input_shapes.

    Returns
    -------
    out : tvm.relay.Expr.Function
        Bound relay function expression with types re-inferred. If
        input_shapes is None, expr is returned unchanged.

    Raises
    ------
    KeyError
        If input_dtypes is a dict that lacks an entry for a name in
        input_shapes.
    """
    if input_shapes is None:
        return expr
    if isinstance(input_dtypes, str):
        # Broadcast a single dtype string to every input being rebound.
        input_dtypes = {key: input_dtypes for key in input_shapes.keys()}

    # Create a fresh relay variable for each input that should be rebound.
    updated_input_dict = {}
    for input_name, shape in input_shapes.items():
        updated_input_dict[input_name] = relay.var(
            input_name, shape=shape, dtype=input_dtypes[input_name])

    # Map each existing parameter (matched by name_hint) to its replacement.
    rebind_dict = {
        var: updated_input_dict[var.name_hint]
        for var in expr.params
        if var.name_hint in updated_input_dict
    }
    updated_expr = relay.expr.bind(expr, rebind_dict)

    # Wrap in a module so type inference propagates the new shapes/dtypes.
    mod = relay.Module.from_expr(updated_expr)
    mod = transform.InferType()(mod)
    entry = mod["main"]
    return entry if isinstance(updated_expr, relay.Function) else entry.body