# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test legalize pass"""
import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_runtime
from tvm.relay.qnn.op import register_qnn_legalize
from tvm.relay import transform, analysis


def run_opt_pass(expr, passes):
    """Optimize ``expr`` with ``passes`` under an opt_level=3 PassContext.

    ``passes`` may be a single pass or a list of passes; a single pass is
    wrapped in a list before building the ``transform.Sequential``.

    Returns the optimized ``main`` function when ``expr`` is a
    ``relay.Function``, and only that function's body otherwise.
    """
    pass_list = passes if isinstance(passes, list) else [passes]
    module = relay.Module.from_expr(expr)
    pipeline = transform.Sequential(pass_list)
    with transform.PassContext(opt_level=3):
        module = pipeline(module)
    optimized = module["main"]
    if isinstance(expr, relay.Function):
        return optimized
    return optimized.body


def test_qnn_legalize():
    """Test directly replacing an operator with a new one"""
    input_shape = (1, 64, 56, 56)

    def unit_requantize(tensor):
        # Requantize with unit scales and zero offsets, producing int8.
        return relay.qnn.op.requantize(tensor,
                                       input_scale=1,
                                       input_zero_point=0,
                                       output_scale=1,
                                       output_zero_point=0,
                                       out_dtype='int8')

    def before():
        inp = relay.var("x", shape=input_shape, dtype='int8')
        return relay.Function([inp], unit_requantize(inp))

    # Registers a QNN legalization hook for qnn.requantize at level 100.
    # This runs when the decorator executes, i.e. before the passes below.
    @register_qnn_legalize("qnn.requantize", level=100)
    def legalize_qnn_requantize(attrs, inputs, types):
        # Insert an "add 0" in front of the input so the rewrite is observable.
        shifted = relay.add(relay.const(0, 'int8'), inputs[0])
        return unit_requantize(shifted)

    def expected():
        inp = relay.var("x", shape=input_shape, dtype='int8')
        shifted = relay.add(relay.const(0, 'int8'), inp)
        return relay.Function([inp], unit_requantize(shifted))

    # Check that Relay Legalize does not change the graph.
    actual = run_opt_pass(before(), relay.transform.Legalize())
    reference = run_opt_pass(before(), transform.InferType())
    assert analysis.alpha_equal(actual, reference), "Actual = \n" + str(actual)

    # Check that QNN Legalize modifies the graph via the hook registered above.
    actual = run_opt_pass(actual, relay.qnn.transform.Legalize())
    reference = run_opt_pass(expected(), transform.InferType())
    assert analysis.alpha_equal(actual, reference), "Actual = \n" + str(actual)


def test_qnn_legalize_qnn_conv2d():
    """Check that QNN Legalize inserts a cast for qnn.conv2d on an AVX-512 target.

    Runs once with uint8 and once with int8 inputs, using the same dtype for
    both data and kernel; only the presence of 'cast' in the text is asserted.
    """
    data_shape = (1, 64, 256, 256)
    kernel_shape = (128, 64, 3, 3)
    for dtype in ('uint8', 'int8'):
        data = relay.var("data", shape=data_shape, dtype=dtype)
        kernel = relay.var("kernel", shape=kernel_shape, dtype=dtype)
        conv = relay.qnn.op.conv2d(data, kernel,
                                   input_zero_point=1,
                                   kernel_zero_point=1,
                                   kernel_size=(3, 3),
                                   strides=(1, 1),
                                   dilation=(1, 1),
                                   out_dtype='int32',
                                   data_layout='NCHW',
                                   kernel_layout='OIHW')
        func = relay.Function(analysis.free_vars(conv), conv)
        module = relay.Module.from_expr(func)
        # The target context is what the QNN legalization consults here.
        with tvm.target.create('llvm -mcpu=skylake-avx512'):
            module = relay.qnn.transform.Legalize()(module)
        assert 'cast' in module.astext()


if __name__ == "__main__":
    test_qnn_legalize()
    test_qnn_legalize_qnn_conv2d()