# -*- coding: utf-8 -*-
from torch.autograd import Function


class FakeQuantize(Function):
    """Fake (simulated) quantization with a straight-through gradient.

    The forward pass sends the input through ``qparam``'s quantize ->
    dequantize pair. The result is a floating-point tensor that carries the
    rounding error of quantization. The backward pass ignores that rounding
    and passes the incoming gradient through unchanged (straight-through
    estimator, STE).
    """

    @staticmethod
    def forward(ctx, x, qparam):
        """Quantize ``x`` and immediately dequantize it again.

        Args:
            ctx: autograd context (unused here).
            x: input tensor to fake-quantize.
            qparam: object exposing ``mode``, ``quantize_tensor(t, mode)`` and
                ``dequantize_tensor(t, mode)``. According to the original
                author's note, it also holds the scale / zero-point / n_exp
                state those methods rely on, so no extra arguments are passed.

        Returns:
            The dequantized tensor: the floating-point image of the quantized
            (integer-valued) representation of ``x``.
        """
        mode = qparam.mode
        quantized = qparam.quantize_tensor(x, mode)  # integer-valued grid
        return qparam.dequantize_tensor(quantized, mode)  # back to float

    @staticmethod
    def backward(ctx, grad_output):
        """Straight-through estimator: treat fake quantization as identity.

        Returns one gradient per ``forward`` input. The gradient for ``x`` is
        ``grad_output`` unchanged. It is not clipped, even for values outside
        the quantization range. ``qparam`` gets ``None`` because it is not a
        differentiable tensor input.
        """
        grad_x = grad_output
        grad_qparam = None
        return grad_x, grad_qparam