Commit fed79b3a by Animesh Jain Committed by Zhi

[QNN] Quantize - Fixing the sequence of lowering. (#4316)

parent dc5f70ad
...@@ -48,8 +48,8 @@ bool QuantizeRel(const Array<Type>& types, ...@@ -48,8 +48,8 @@ bool QuantizeRel(const Array<Type>& types,
const auto* quantize_attrs = attrs.as<QuantizeAttrs>(); const auto* quantize_attrs = attrs.as<QuantizeAttrs>();
const Array<tvm::Expr> oshape = data->shape; const Array<tvm::Expr> oshape = data->shape;
const DataType out_dtype = quantize_attrs->out_dtype; const DataType out_dtype = quantize_attrs->out_dtype;
CHECK(out_dtype == Int(8) || out_dtype == UInt(8)) CHECK(out_dtype == Int(8) || out_dtype == UInt(8) || out_dtype == Int(32))
<< "Output type should be one of [int8, unit8 ] but was " << out_dtype; << "Output type should be one of [int8, unit8, int32] but was " << out_dtype;
// assign output type // assign output type
reporter->Assign(types[1], TensorTypeNode::make(oshape, out_dtype)); reporter->Assign(types[1], TensorTypeNode::make(oshape, out_dtype));
return true; return true;
...@@ -72,12 +72,12 @@ Expr MakeQuantize(Expr data, ...@@ -72,12 +72,12 @@ Expr MakeQuantize(Expr data,
Expr QuantizeLower(const Expr& input_tensor, Expr QuantizeLower(const Expr& input_tensor,
const QuantizeAttrs* attrs) { const QuantizeAttrs* attrs) {
const auto out_dtype = attrs->out_dtype; const auto out_dtype = attrs->out_dtype;
const auto output_zero_point = MakeConstantScalar(Int(32), attrs->output_zero_point); const auto output_zero_point = MakeConstantScalar(Float(32), attrs->output_zero_point);
const auto scale = MakeConstantScalar(Float(32), attrs->output_scale); const auto scale = MakeConstantScalar(Float(32), attrs->output_scale);
const int32_t min_val = GetQmin(out_dtype); const int32_t min_val = GetQmin(out_dtype);
const int32_t max_val = GetQmax(out_dtype); const int32_t max_val = GetQmax(out_dtype);
auto scale_data = Cast(Round(Divide(input_tensor, scale)), Int(32)); auto scale_data = Divide(input_tensor, scale);
auto add_zero_point = Add(scale_data, output_zero_point); auto add_zero_point = Cast(Round(Add(scale_data, output_zero_point)), Int(32));
auto clamped_output = Clip(add_zero_point, min_val, max_val); auto clamped_output = Clip(add_zero_point, min_val, max_val);
auto clamp_out_dtype = Cast(clamped_output, out_dtype); auto clamp_out_dtype = Cast(clamped_output, out_dtype);
return clamp_out_dtype; return clamp_out_dtype;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment