# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#pylint: disable=unused-argument,inconsistent-return-statements
"""Internal module for registering attribute for annotation."""
from __future__ import absolute_import

from ... import target as _target
from .. import expr as _expr
from .. import analysis as _analysis
from ..base import register_relay_node
from ..op import op as _reg
from . import _quantize
from .quantize import _forward_op


def register_partition_function(op_name, frewrite=None, level=10):
    """Register a partition rewrite function for an operator.

    The function is attached to ``op_name`` under the ``"FQPartitionRewrite"``
    attribute key through ``_reg._Register``.

    Parameters
    ----------
    op_name : str
        Name of the operator the rewrite function applies to.

    frewrite : callable, optional
        The rewrite function, called as ``frewrite(ref_call, new_args, ctx)``
        by the functions registered in this module. When omitted, a decorator
        is returned instead so this can be used as
        ``@register_partition_function(op_name)``.

    level : int, optional
        Registration level forwarded to ``_reg._Register``.

    Returns
    -------
    result
        Whatever ``_reg._Register`` returns when ``frewrite`` is given,
        otherwise a one-argument decorator performing the registration.
    """
    def _register(func):
        return _reg._Register(op_name, "FQPartitionRewrite", func, level)
    return _register(frewrite) if frewrite is not None else _register


@register_relay_node
class QPartitionExpr(_expr.TempExpr):
    """Temporary expression wrapping ``expr`` during the partition rewrite.

    The node is built through ``_quantize.make_partition_expr``; the wrapped
    expression is read back as ``.expr`` (see ``partition_expr_check``).
    """
    def __init__(self, expr):
        self.__init_handle_by_constructor__(
            _quantize.make_partition_expr, expr)


def partition_expr_check(expr):
    """Unwrap ``expr`` if it is a ``QPartitionExpr``.

    Returns
    -------
    (cond, inner) : tuple
        ``(True, expr.expr)`` when ``expr`` is a ``QPartitionExpr``,
        otherwise ``(False, expr)`` with the argument returned unchanged.
    """
    if isinstance(expr, QPartitionExpr):
        return True, expr.expr
    return False, expr


@register_partition_function("nn.conv2d")
def conv2d_partition_function(ref_call, new_args, ctx):
    """Rewrite function for conv2d for partition"""
    data_cond, data = partition_expr_check(new_args[0])
    kernel_cond, kernel = partition_expr_check(new_args[1])

    assert not kernel_cond, "conv2d kernel must not be a QPartitionExpr"
    if data_cond:
        # The input comes from another partitioned region: realize it so
        # the boundary is made explicit before this conv2d consumes it.
        data = new_args[0].realize()
    ret = _forward_op(ref_call, [data, kernel])
    return QPartitionExpr(ret)


def identity_partition_function(ref_call, new_args, ctx):
    """Rewrite function for single-input ops that extend a partition.

    When the first argument is a ``QPartitionExpr`` the op is re-applied to
    the unwrapped expression and the result is wrapped again; otherwise
    ``None`` is returned and no rewrite is requested.
    """
    cond, expr = partition_expr_check(new_args[0])
    if cond:
        return QPartitionExpr(_forward_op(ref_call, [expr]))
    return None


register_partition_function("clip", identity_partition_function)
register_partition_function("nn.relu", identity_partition_function)
register_partition_function("nn.max_pool2d", identity_partition_function)


def add_partition_generic(ref_call, new_args, ctx):
    """Rewrite function for ewise add for partition for generic devices

    Returns ``None`` when neither operand is a ``QPartitionExpr``. Otherwise
    the add is re-emitted through ``_forward_op``; the result is wrapped in a
    new ``QPartitionExpr`` only for the ``add(partitioned, constant)`` case.
    """
    lhs_cond, lhs = partition_expr_check(new_args[0])
    rhs_cond, rhs = partition_expr_check(new_args[1])

    if not lhs_cond and not rhs_cond:
        # trivial case
        return None

    if lhs_cond and rhs_cond:
        # - introduced by the first residual connection in ResNet
        #     ...
        #     %0 = nn.conv2d(%data, %meta[relay.Constant])
        #     %1 = add(%0, %meta[relay.Constant])
        #     %2 = nn.relu(%1)
        #     %3 = nn.max_pool2d(%2)
        #     ...
        #     %9 = nn.conv2d(%8, %meta[relay.Constant])
        #     %10 = add(%9, %meta[relay.Constant])
        #     %11 = add(%3, %10)  <- need to insert annotations for %3, %10
        #     ...
        lhs = new_args[0].realize()
        rhs = new_args[1].realize()
        return _forward_op(ref_call, [lhs, rhs])

    if rhs_cond:
        # Only the rhs is partitioned.
        # - introduced by residual connection in ResNet
        #     ...
        #     %13 = nn.conv2d(%12, %meta[relay.Constant])
        #     %14 = add(%13, %meta[relay.Constant])
        #     %15 = annotation.cast_hint(%14, 'int8')
        #     %16 = annotation.stop_fusion(%15)
        #     %17 = add(%5, %16)
        #     %18 = nn.relu(%17)
        #     ...
        #     %24 = nn.conv2d(%23, %meta[relay.Constant])
        #     %25 = add(%24, %meta[relay.Constant])
        #     %26 = add(%18, %25)  <- need to insert annotations for %25
        #     ...
        rhs = new_args[1].realize()
        return _forward_op(ref_call, [lhs, rhs])

    # Only the lhs is partitioned.
    if _analysis.check_constant(rhs):
        # - introduced by batch_norm: add(out, bias)
        return QPartitionExpr(_forward_op(ref_call, [lhs, rhs]))
    # - introduced by residual connection in MobileNetV2
    #     ...
    #     %81 = add(%80, meta[relay.Constant])
    #     %82 = annotation.cast_hint(%81, 'int8')
    #     %83 = annotation.stop_fusion(%82)
    #     %84 = add(%79, %83)
    #     ...
    #     %95 = nn.conv2d(%94, %meta[relay.Constant])
    #     %96 = add(%95, %meta[relay.Constant])
    #     %97 = add(%96, %84)  <- need to insert annotations for %96
    #     ...
    lhs = new_args[0].realize()
    return _forward_op(ref_call, [lhs, rhs])


# TODO(ziheng) enhance `register_partition_function` to dispatch
# for target automatically
@register_partition_function("add")
def add_partition_function(ref_call, new_args, ctx):
    """Rewrite function for ewise add for partition

    Both the cuda and the default path currently delegate to
    ``add_partition_generic``; the cuda branch is kept as the hook for
    future target-specific rules.
    """
    target = _target.current_target()
    if target and 'cuda' in target.keys:
        # TODO(wuwei/ziheng) cuda specific rules
        return add_partition_generic(ref_call, new_args, ctx)
    return add_partition_generic(ref_call, new_args, ctx)


@register_partition_function("multiply")
def multiply_partition_function(ref_call, new_args, ctx):
    """Rewrite function for ewise multiply for partition"""
    lhs_cond, lhs = partition_expr_check(new_args[0])
    rhs_cond, rhs = partition_expr_check(new_args[1])
    if lhs_cond:
        # introduced by bn: multiply(out, scale)
        return QPartitionExpr(_forward_op(ref_call, [lhs, rhs]))
    # lhs is known to be unpartitioned here, so only rhs needs checking.
    assert not rhs_cond, "multiply rhs must not be a QPartitionExpr"
    return None