Unverified Commit 3f47b327 by mbaret Committed by GitHub

[RELAY] Move frontend utils (#5345)

* [RELAY] Move frontend utils

The util file currently under frontend is used from
outside of frontend (in qnn/op/legalizations). This suggests
that the file should be pushed up to a higher level.

The benefit from this change is that importing qnn no longer
also imports all the frontends.

* Inline get_scalar_from_constant

Change-Id: I1cc64e9ecb0eadb6ac0f7b62e6ea174644af4ad4

* Remove util.py from Relay

Change-Id: If9cd7cf3fc0bd1861a3a9b5604f338e084d8db96

* Shorten functions

Change-Id: Ieb537d82e6ee52421ff05a90cd00a03679ffebf2

* Line length

Change-Id: I1d216b7e73a060c4f118f5da50ce58b18eba907f
parent 952def53
...@@ -29,7 +29,6 @@ from .. import function as _function ...@@ -29,7 +29,6 @@ from .. import function as _function
from .. import op as _op from .. import op as _op
from .. import qnn as _qnn from .. import qnn as _qnn
from ... import nd as _nd from ... import nd as _nd
from .util import get_scalar_from_constant
from .common import ExprTable from .common import ExprTable
from .common import infer_shape as _infer_shape from .common import infer_shape as _infer_shape
...@@ -2281,6 +2280,17 @@ class OperatorConverter(object): ...@@ -2281,6 +2280,17 @@ class OperatorConverter(object):
def has_expr(self, input_tensor_idx): def has_expr(self, input_tensor_idx):
return self.exp_tab.has_expr(get_tensor_name(self.subgraph, input_tensor_idx)) return self.exp_tab.has_expr(get_tensor_name(self.subgraph, input_tensor_idx))
def get_scalar_from_constant(expr):
""" Returns scalar value from Relay constant scalar. """
assert isinstance(expr, _expr.Constant) and not expr.data.shape, \
"Expr is not a constant scalar."
value = expr.data.asnumpy()
assert value.dtype == np.dtype(np.int32) or value.dtype == np.dtype(np.float32), \
"value must be float32/int32"
return np.asscalar(value)
def build_str_map(obj): def build_str_map(obj):
"""Build string map of TFLite enum int value """Build string map of TFLite enum int value
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import, redefined-builtin, invalid-name
""" Utility functions that are used across many directories. """
from __future__ import absolute_import
import numpy as np
from .. import expr as _expr
def get_scalar_from_constant(expr):
""" Returns scalar value from Relay constant scalar. """
assert isinstance(expr, _expr.Constant) and not expr.data.shape, \
"Expr is not a constant scalar."
value = expr.data.asnumpy()
if value.dtype == np.dtype(np.int32):
return int(value)
if value.dtype == np.dtype(np.float32):
return float(value)
assert False, "Constant expr must be float32/int32"
return None # To suppress pylint
...@@ -20,8 +20,8 @@ from __future__ import absolute_import ...@@ -20,8 +20,8 @@ from __future__ import absolute_import
import tvm import tvm
from tvm import relay from tvm import relay
import numpy as np
from .. import op as reg from .. import op as reg
from ...frontend.util import get_scalar_from_constant
################################################# #################################################
# Register the functions for different operators. # Register the functions for different operators.
...@@ -54,6 +54,15 @@ def qnn_dense_legalize(attrs, inputs, types): ...@@ -54,6 +54,15 @@ def qnn_dense_legalize(attrs, inputs, types):
# Helper functions. # Helper functions.
################### ###################
def get_scalar_from_constant(expr):
""" Returns scalar value from Relay constant scalar. """
assert isinstance(expr, relay.Constant) and not expr.data.shape, \
"Expr is not a constant scalar."
value = expr.data.asnumpy()
assert value.dtype == np.dtype(np.int32) or value.dtype == np.dtype(np.float32), \
"value must be float32/int32"
return np.asscalar(value)
# Helper function for lowering in the abscence of fast Int8 arithmetic units. # Helper function for lowering in the abscence of fast Int8 arithmetic units.
def helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay_op): def helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay_op):
""" Converts QNN operators into a sequence of Relay operators that are friendly to HW that do """ Converts QNN operators into a sequence of Relay operators that are friendly to HW that do
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment