Commit 6183244f by Zhihong Ma

fix: BN trial

parent 147e0fbb
@@ -444,7 +444,7 @@ class QBN(QModule):
         # Quantize the weight; the stored weight can then be used directly in the multiply (zero_point already subtracted), ready for quantized inference after finetuning
         # self.bn_module.weight.data = self.qw.quantize_tensor(self.bn_module.weight.data, self.mode)
         # self.bn_module.weight.data = self.bn_module.weight.data - self.qw.zero_point
-        self.bn_module.weight.data = FakeQuantize.apply(self.bn_module.wegiht, self.qw)
+        self.bn_module.weight.data = FakeQuantize.apply(self.bn_module.weight, self.qw)
         # Quantize the bias
         # Should the bias num_bits also be constrained by the device's quantization bit width?
         # self.bn_module.bias.data = quantize_tensor(self.bn_module.bias.data,
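The change above fixes a typo (wegiht -> weight) in the fake-quantize call on the BN affine weight. For context, a minimal sketch of what a fake-quantize autograd function of this kind usually does; the repo's actual FakeQuantize takes a QParam object rather than an explicit scale/zero_point, so the signature here is an assumption:

import torch

class FakeQuantizeSketch(torch.autograd.Function):
    # Forward: quantize-dequantize, so the float weight carries rounding error.
    # Backward: straight-through estimator, so finetuning gradients survive.
    @staticmethod
    def forward(ctx, x, scale, zero_point, qmin, qmax):
        q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)  # FP -> INT
        return (q - zero_point) * scale                                   # INT -> FP

    @staticmethod
    def backward(ctx, grad_output):
        # Rounding has zero gradient almost everywhere; approximate with identity.
        return grad_output, None, None, None, None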
@@ -510,18 +510,18 @@ class QBN(QModule):
             # qi's mode was already fixed at init time
             self.qi.update(x)  # qi holds the fake-quantization parameters and methods
             x = FakeQuantize.apply(x, self.qi)  # forward: FP -> INT -> FP (qi quantizes the input): quantize, then recover
-        # self.qw.update(self.bn_module.weight.data)
-        # self.qb.update(self.bn_module.bias.data)
+        self.qw.update(self.bn_module.weight.data)
+        self.qb.update(self.bn_module.bias.data)
-        # bn_q = torch.nn.BatchNorm2d(num_features=self.bn_module.num_features, affine=self.bn_module.affine, eps=self.bn_module.eps, momentum=self.bn_module.momentum, track_running_stats=self.bn_module.track_running_stats)
-        # bn_q.weight = FakeQuantize.apply(self.bn_module.weight, self.qw)
-        # bn_q.bias = FakeQuantize.apply(self.bn_module.bias, self.qb)
-        # bn_q.running_mean = self.bn_module.running_mean
-        # bn_q.running_var = self.bn_module.running_var
-        # x = bn_q(x)
-        x = self.bn_module(x)
+        bn_q = torch.nn.BatchNorm2d(num_features=self.bn_module.num_features, affine=self.bn_module.affine, eps=self.bn_module.eps, momentum=self.bn_module.momentum, track_running_stats=self.bn_module.track_running_stats)
+        bn_q.weight.data = FakeQuantize.apply(self.bn_module.weight, self.qw)
+        bn_q.bias.data = FakeQuantize.apply(self.bn_module.bias, self.qb)
+        bn_q.running_mean.data = self.bn_module.running_mean
+        bn_q.running_var.data = self.bn_module.running_var
+        x = bn_q(x)
+        # x = self.bn_module(x)
         if hasattr(self, 'qo'):
             self.qo.update(x)
......
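One subtlety in the rewritten forward: bn_q is constructed fresh on every call, and a new module is in training mode by default, so it normalizes with batch statistics rather than the copied running buffers. A functional sketch that makes that choice explicit; it assumes the repo's FakeQuantize and the qw/qb members of QBN, neither of which is defined in this hunk:

import torch.nn.functional as F

def qbn_forward_sketch(x, bn, qw, qb):
    # bn is the wrapped nn.BatchNorm2d; qw/qb are the QBN quantizers.
    w = FakeQuantize.apply(bn.weight, qw)  # fake-quantized affine weight
    b = FakeQuantize.apply(bn.bias, qb)    # fake-quantized affine bias
    # training=bn.training reproduces whichever statistics the wrapped
    # module would use; pass training=False to force the running stats.
    return F.batch_norm(x, bn.running_mean, bn.running_var,
                        weight=w, bias=b, training=bn.training,
                        momentum=bn.momentum, eps=bn.eps)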
 # -*- coding: utf-8 -*-
 from torch.serialization import load
-from model import *
+# from model import *
 import argparse
 import torch
......
@@ -57,9 +57,11 @@ class LeNet(nn.Module):
             # qi=True: the previous layer's output has not been quantized yet, so it must be quantized here. MaxPool and ReLU change neither the INT values nor the min/max, so layers after them use qi=False
             # If the previous layer is a conv, the data's min/max has changed, so qi=True is needed to quantize
             'qconv1': QConv2d(self.conv_layers['conv1'], qi=True, qo=True, num_bits=num_bits, n_exp=self.n_exp, mode=self.mode),
+            'qbn1': QBN(self.conv_layers['bn1'], qi=False, qo=True, num_bits=num_bits, n_exp=self.n_exp, mode=self.mode),
             'qreluc1': QReLU(n_exp=self.n_exp, mode=self.mode),
             'qpool1': QMaxPooling2d(kernel_size=2, stride=2, padding=0, n_exp=self.n_exp, mode=self.mode),
             'qconv2': QConv2d(self.conv_layers['conv2'], qi=False, qo=True, num_bits=num_bits, n_exp=self.n_exp, mode=self.mode),
+            'qbn2': QBN(self.conv_layers['bn2'], qi=False, qo=True, num_bits=num_bits, n_exp=self.n_exp, mode=self.mode),
             'qreluc2': QReLU(n_exp=self.n_exp, mode=self.mode),
             'qpool2': QMaxPooling2d(kernel_size=2, stride=2, padding=0, n_exp=self.n_exp, mode=self.mode)
         })
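The qi/qo flags in this dict encode the comment above: only a layer whose input range is new (the first conv) owns an input quantizer; everything downstream reuses the producer's qo. A rough stand-in for the observer such a QParam implements (the real class in this repo presumably also handles mode and n_exp, which are omitted here):

class QParamSketch:
    # Tracks the running min/max of observed tensors and derives an
    # asymmetric scale / zero_point for the given bit width.
    def __init__(self, num_bits=8):
        self.num_bits = num_bits
        self.min = None
        self.max = None

    def update(self, tensor):
        t_min, t_max = tensor.min().item(), tensor.max().item()
        self.min = t_min if self.min is None else min(self.min, t_min)
        self.max = t_max if self.max is None else max(self.max, t_max)

    def qparams(self):
        qmax = 2 ** self.num_bits - 1
        scale = (self.max - self.min) / qmax
        zero_point = round(-self.min / scale)
        return scale, zero_point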
@@ -74,7 +76,9 @@ class LeNet(nn.Module):
     def quantize_forward(self, x):
-        for _, layer in self.quantize_conv_layers.items():
+        for s, layer in self.quantize_conv_layers.items():
+            # print(s)
+            # print(layer)
             x = layer(x)
         output = x.view(-1, 16*5*5)
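quantize_forward simply pushes x through the fake-quantized layers so each observer sees real activation ranges. A typical calibration pass might look like this; calibration_loader and the surrounding setup are assumptions, not part of this file:

model.eval()
with torch.no_grad():
    for images, _ in calibration_loader:  # hypothetical DataLoader
        model.quantize_forward(images)    # qi/qo observers widen their ranges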
@@ -88,14 +92,26 @@ class LeNet(nn.Module):
     def freeze(self):
         self.quantize_conv_layers['qconv1'].freeze()
-        self.quantize_conv_layers['qreluc1'].freeze(self.quantize_conv_layers['qconv1'].qo)
-        self.quantize_conv_layers['qpool1'].freeze(self.quantize_conv_layers['qconv1'].qo)
+        self.quantize_conv_layers['qbn1'].freeze(qi=self.quantize_conv_layers['qconv1'].qo)
+        # self.quantize_conv_layers['qreluc1'].freeze(self.quantize_conv_layers['qconv1'].qo)
+        # self.quantize_conv_layers['qpool1'].freeze(self.quantize_conv_layers['qconv1'].qo)
+        # self.quantize_conv_layers['qconv2'].freeze(self.quantize_conv_layers['qconv1'].qo)
+        self.quantize_conv_layers['qreluc1'].freeze(self.quantize_conv_layers['qbn1'].qo)
+        self.quantize_conv_layers['qpool1'].freeze(self.quantize_conv_layers['qbn1'].qo)
+        self.quantize_conv_layers['qconv2'].freeze(self.quantize_conv_layers['qbn1'].qo)
+        self.quantize_conv_layers['qbn2'].freeze(qi=self.quantize_conv_layers['qconv2'].qo)
+        # self.quantize_conv_layers['qreluc2'].freeze(self.quantize_conv_layers['qconv2'].qo)
+        # self.quantize_conv_layers['qpool2'].freeze(self.quantize_conv_layers['qconv2'].qo)
+        # self.quantize_fc_layers['qfc1'].freeze(qi=self.quantize_conv_layers['qconv2'].qo)
-        self.quantize_conv_layers['qconv2'].freeze(self.quantize_conv_layers['qconv1'].qo)
-        self.quantize_conv_layers['qreluc2'].freeze(self.quantize_conv_layers['qconv2'].qo)
-        self.quantize_conv_layers['qpool2'].freeze(self.quantize_conv_layers['qconv2'].qo)
+        self.quantize_conv_layers['qreluc2'].freeze(self.quantize_conv_layers['qbn2'].qo)
+        self.quantize_conv_layers['qpool2'].freeze(self.quantize_conv_layers['qbn2'].qo)
+        self.quantize_fc_layers['qfc1'].freeze(qi=self.quantize_conv_layers['qbn2'].qo)
-        self.quantize_fc_layers['qfc1'].freeze(qi=self.quantize_conv_layers['qconv2'].qo)
         self.quantize_fc_layers['qreluf1'].freeze(self.quantize_fc_layers['qfc1'].qo)
         self.quantize_fc_layers['qfc2'].freeze(qi=self.quantize_fc_layers['qfc1'].qo)
         self.quantize_fc_layers['qreluf2'].freeze(self.quantize_fc_layers['qfc2'].qo)
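The rewiring above follows a single rule: a layer without its own qi is frozen against the qo of the nearest upstream layer that produces one, so inserting qbn1/qbn2 moves every downstream hand-off from the conv qo to the BN qo. A generic sketch of that chaining, assuming freeze accepts the upstream quantizer as qi the way the QBN calls above do:

def freeze_chain_sketch(ordered_layers):
    # ordered_layers: quantized modules in execution order.
    prev_qo = None
    for layer in ordered_layers:
        if prev_qo is None:
            layer.freeze()            # the first layer owns its own qi
        else:
            layer.freeze(qi=prev_qo)  # inherit the upstream output scale
        if getattr(layer, 'qo', None) is not None:
            prev_qo = layer.qo        # becomes qi for the next layer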
@@ -120,6 +136,7 @@ class LeNet(nn.Module):
         x = self.quantize_conv_layers['qconv1'].qi.quantize_tensor(x, self.mode)
         for s, layer in self.quantize_conv_layers.items():
+            print(s)
             x = layer.quantize_inference(x)
         output = x.view(-1, 16*5*5)
......
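Taken together, the intended flow is: calibrate with fake quantization, freeze the scales into each layer, then run integer-domain inference. A hedged end-to-end sketch; model.quantize(), the loaders, and the argmax step are assumptions, since only quantize_forward, freeze, and quantize_inference appear in this commit:

model.quantize()                        # hypothetical: builds the Q* wrappers
model.eval()
with torch.no_grad():
    for images, _ in calibration_loader:
        model.quantize_forward(images)  # observers record activation ranges
    model.freeze()                      # bake scale/zero_point into the layers
    preds = model.quantize_inference(test_images).argmax(dim=1)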