Commit d465b618 by Klin
parents 55cf8622 85f1d4cd
GPUID: 0
WORKERS: 1
PRINT_FREQ: 10
SAVE_FREQ: 10
PIN_MEMORY: False
OUTPUT_DIR: 'output'
DATASET:
DATASET: 360CC
ROOT: "../CRNN/image"
CHAR_FILE: 'dataset/txt/char_std_5990.txt'
JSON_FILE: {'train': 'dataset/txt/train.txt', 'val': 'dataset/txt/test.txt'}
SCALE_FACTOR: 0.25
ROT_FACTOR: 30
STD: 0.193
MEAN: 0.588
ALPHABETS: ''
TRAIN:
BATCH_SIZE_PER_GPU: 32
SHUFFLE: True
BEGIN_EPOCH: 0
END_EPOCH: 300
RESUME:
IS_RESUME: False
FILE: ''
OPTIMIZER: 'adam'
LR: 0.0001
WD: 0.0
LR_STEP: [60, 80]
LR_FACTOR: 0.1
MOMENTUM: 0.0
NESTEROV: False
RMSPROP_ALPHA:
RMSPROP_CENTERED:
SAVE: true
TEST:
BATCH_SIZE_PER_GPU: 16
SHUFFLE: True # for random test rather than test on the whole validation set
NUM_TEST_BATCH: 3000
NUM_TEST_DISP: 10 # 每次显示多少个测试比对
MODEL:
NAME: 'lstm-ocr'
IMAGE_SIZE:
OW: 280 # origial width: 280
H: 32
W: 160 # resized width: 160
NUM_CLASSES: 0
NUM_HIDDEN: 256
GPUID: 0
WORKERS: 1
PRINT_FREQ: 10
SAVE_FREQ: 10
PIN_MEMORY: False
OUTPUT_DIR: 'output'
CUDNN:
BENCHMARK: True
DETERMINISTIC: False
ENABLED: True
DATASET:
DATASET: OWN
ROOT: "H:/DL-DATASET/360M/images"
JSON_FILE: {'train': 'lib/dataset/txt/train_own.txt', 'val': 'lib/dataset/txt/test_own.txt'}
SCALE_FACTOR: 0.25
ROT_FACTOR: 30
STD: 0.193
MEAN: 0.588
ALPHABETS: ''
TRAIN:
BATCH_SIZE_PER_GPU: 32
SHUFFLE: True
BEGIN_EPOCH: 0
END_EPOCH: 100
RESUME:
IS_RESUME: False
FILE: ''
OPTIMIZER: 'adam'
LR: 0.0001
WD: 0.0
LR_STEP: [60, 80]
LR_FACTOR: 0.1
MOMENTUM: 0.0
NESTEROV: False
RMSPROP_ALPHA:
RMSPROP_CENTERED:
FINETUNE:
IS_FINETUNE: true
FINETUNE_CHECKPOINIT: 'output/checkpoints/mixed_second_finetune_acc_97P7.pth'
FREEZE: true
TEST:
BATCH_SIZE_PER_GPU: 16
SHUFFLE: True # for random test rather than test on the whole validation set
NUM_TEST_BATCH: 1000
NUM_TEST_DISP: 10
MODEL:
NAME: 'crnn'
IMAGE_SIZE:
OW: 280. # origial width: 280
H: 32
W: 160 # resized width: 160
NUM_CLASSES: 0
NUM_HIDDEN: 256
from __future__ import print_function, absolute_import
import torch.utils.data as data
import os
import numpy as np
import cv2
class _360CC(data.Dataset):
def __init__(self, config, is_train=True):
self.root = config.DATASET.ROOT
self.is_train = is_train
self.inp_h = config.MODEL.IMAGE_SIZE.H
self.inp_w = config.MODEL.IMAGE_SIZE.W
self.dataset_name = config.DATASET.DATASET
self.mean = np.array(config.DATASET.MEAN, dtype=np.float32)
self.std = np.array(config.DATASET.STD, dtype=np.float32)
char_file = config.DATASET.CHAR_FILE # 字典?
with open(char_file, 'rb') as file:
# 读取char_file中的所有行的文字,并以字典形式存储键值(序号)和(去除了前后空格的)字符编码
char_dict = {num: char.strip().decode('gbk', 'ignore') for num, char in enumerate(file.readlines())}
txt_file = config.DATASET.JSON_FILE['train'] if is_train else config.DATASET.JSON_FILE['val']
# convert name:indices to name:string
self.labels = []
with open(txt_file, 'r', encoding='utf-8') as file:
contents = file.readlines()
for c in contents:
imgname = c.split(' ')[0]
indices = c.split(' ')[1:]
# 在这里相当于对train.txt / test.txt中的 以index list形式为label进行了decode,得到了str
string = ''.join([char_dict[int(idx)] for idx in indices])
self.labels.append({imgname: string})
print("load {} images!".format(self.__len__()))
def __len__(self):
return len(self.labels)
def __getitem__(self, idx):
img_name = list(self.labels[idx].keys())[0] # 即 label
img = cv2.imread(os.path.join(self.root, img_name))
# 转化为灰度图像
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# 获取图像原始高度
img_h, img_w = img.shape
# 整体缩放到inp_h,inp_w
img = cv2.resize(img, (0,0), fx=self.inp_w / img_w, fy=self.inp_h / img_h, interpolation=cv2.INTER_CUBIC)
img = np.reshape(img, (self.inp_h, self.inp_w, 1))
img = img.astype(np.float32)
# yaml中配置
img = (img/255. - self.mean) / self.std
# chanel,h,w
img = img.transpose([2, 0, 1])
return img, idx
from ._360cc import _360CC
from ._own import _OWN
def get_dataset(config):
if config.DATASET.DATASET == "360CC":
return _360CC
elif config.DATASET.DATASET == "OWN":
return _OWN
else:
raise NotImplemented()
\ No newline at end of file
from __future__ import print_function, absolute_import
import torch.utils.data as data
import os
import numpy as np
import cv2
class _OWN(data.Dataset):
def __init__(self, config, is_train=True):
self.root = config.DATASET.ROOT
self.is_train = is_train
self.inp_h = config.MODEL.IMAGE_SIZE.H
self.inp_w = config.MODEL.IMAGE_SIZE.W
self.dataset_name = config.DATASET.DATASET
self.mean = np.array(config.DATASET.MEAN, dtype=np.float32)
self.std = np.array(config.DATASET.STD, dtype=np.float32)
txt_file = config.DATASET.JSON_FILE['train'] if is_train else config.DATASET.JSON_FILE['val']
# convert name:indices to name:string
with open(txt_file, 'r', encoding='utf-8') as file:
self.labels = [{c.split(' ')[0]: c.split(' ')[-1][:-1]} for c in file.readlines()]
print("load {} images!".format(self.__len__()))
def __len__(self):
return len(self.labels)
def __getitem__(self, idx):
img_name = list(self.labels[idx].keys())[0]
img = cv2.imread(os.path.join(self.root, img_name))
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
img_h, img_w = img.shape
img = cv2.resize(img, (0,0), fx=self.inp_w / img_w, fy=self.inp_h / img_h, interpolation=cv2.INTER_CUBIC)
img = np.reshape(img, (self.inp_h, self.inp_w, 1))
img = img.astype(np.float32)
img = (img/255. - self.mean) / self.std
img = img.transpose([2, 0, 1])
return img, idx
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -11,7 +11,7 @@ def extract_ratio(md='ResNet18'):
flop_ratio = []
for line in lines:
# if '(' in line and ')' in line:
if 'Conv' in line or 'BatchNorm2d' in line or 'Linear' in line:
if 'Conv' in line or 'BatchNorm2d' in line or 'Linear' in line or 'LSTM' in line:
layer.append(line.split(':')[1].split('(')[0])
r1 = line.split('%')[0].split(',')[-1]
r1 = float(r1)
......
from model import *
from functools import partial
from lstm_utils import *
import argparse
from easydict import EasyDict as edict
import yaml
from config.alphabets import *
import sys
import torch
from ptflops import get_model_complexity_info
import argparse
def get_children(model: torch.nn.Module):
# get children form model!
# 为了后续也能够更新参数,需要用nn.ModuleList来承载
data_path = '../data/ptb'
embed_size = 700
hidden_size = 700
eval_batch_size = 10
dropout = 0.65
tied = True
def parse_arg():
parser = argparse.ArgumentParser(description="train crnn")
# children = nn.ModuleList(model.children())
# print(children)
# 方便对其中的module进行后续的更新
# flatt_children = nn.ModuleList()
# parser.add_argument('--cfg', help='experiment configuration filename', required=True, type=str)
parser.add_argument('--cfg', help='experiment configuration filename', type=str, default='config/360CC_config.yaml')
children = list(model.children())
# flatt_children = nn.ModuleList()
flatt_children = []
if len(children) == 0:
# if model has no children; model is last child! :O
return model
else:
# look for children from children... to the last child!
for child in children:
try:
flatt_children.extend(get_children(child))
except TypeError:
flatt_children.append(get_children(child))
args = parser.parse_args()
with open(args.cfg, 'r') as f:
# config = yaml.load(f, Loader=yaml.FullLoader)
config = yaml.load(f,Loader=yaml.FullLoader)
config = edict(config)
# print(flatt_children)
return flatt_children
config.DATASET.ALPHABETS = alphabet
config.MODEL.NUM_CLASSES = len(config.DATASET.ALPHABETS)
# 定义获取不包含wrapper的所有子模块的函数
def get_all_child_modules(module):
for name, child in module.named_children():
if isinstance(child, nn.Sequential):
yield from get_all_child_modules(child)
elif len(list(child.children())) > 0:
yield from child.children()
else:
yield child
return config
def filter_fn(module, n_inp, outp_shape):
# if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d, torch.nn.ReLU,torch.nn.BatchNorm2d,torch.nn.Linear,torch.nn.AdaptiveAvgPool2d)):
if 'conv' in module or 'bn' in module or 'fc' in module or 'avg' in module or 'relu' in module:
return True
return False
def lstm_constructor(shape,hidden):
return {"x": torch.zeros(shape,dtype=torch.int64),
"hidden": hidden}
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Model Analysis --- params & flops')
parser.add_argument('-m', '--model', metavar='MODEL ARCH', default='resnet18')
args = parser.parse_args()
if args.model == 'ResNet18':
model = resnet18()
elif args.model == 'ResNet50':
model = resnet50()
elif args.model == 'ResNet152':
model = resnet152()
config = parse_arg()
model = model = get_crnn(config)
full_file = 'ckpt/cifar10_' + args.model + '.pt'
full_file = 'ckpt/360cc_lstm-ocr.pt'
model.load_state_dict(torch.load(full_file))
# flat = get_children(model)
# print(flat)
# flat = get_children(model)
# new_model = nn.Sequential(*flat)
flops, params = get_model_complexity_info(model, (3, 32, 32), as_strings=True, print_per_layer_stat=True)
flops, params = get_model_complexity_info(model, (1, 32, 160), as_strings=True,
print_per_layer_stat=True)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def get_optimizer(config, model):
optimizer = None
if config.TRAIN.OPTIMIZER == "sgd":
optimizer = optim.SGD(
filter(lambda p: p.requires_grad, model.parameters()),
lr=config.TRAIN.LR,
momentum=config.TRAIN.MOMENTUM,
weight_decay=config.TRAIN.WD,
nesterov=config.TRAIN.NESTEROV
)
elif config.TRAIN.OPTIMIZER == "adam":
optimizer = optim.Adam(
filter(lambda p: p.requires_grad, model.parameters()),
lr=config.TRAIN.LR,
)
elif config.TRAIN.OPTIMIZER == "rmsprop":
optimizer = optim.RMSprop(
filter(lambda p: p.requires_grad, model.parameters()),
lr=config.TRAIN.LR,
momentum=config.TRAIN.MOMENTUM,
weight_decay=config.TRAIN.WD,
# alpha=config.TRAIN.RMSPROP_ALPHA,
# centered=config.TRAIN.RMSPROP_CENTERED
)
return optimizer
def get_batch_label(d, i):
label = []
for idx in i:
label.append(list(d.labels[idx].values())[0])
return label
# 可能主要用于OWN dataset?
class strLabelConverter(object):
"""Convert between str and label.
NOTE:
Insert `blank` to the alphabet for CTC.
Args:
alphabet (str): set of the possible characters.
ignore_case (bool, default=True): whether or not to ignore all of the case.
"""
def __init__(self, alphabet, ignore_case=False):
self._ignore_case = ignore_case
if self._ignore_case:
alphabet = alphabet.lower()
self.alphabet = alphabet + '-' # for `-1` index
self.dict = {}
for i, char in enumerate(alphabet):
# NOTE: 0 is reserved for 'blank' required by wrap_ctc
self.dict[char] = i + 1
def encode(self, text):
"""Support batch or single str.
Args:
text (str or list of str): texts to convert.
Returns:
torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
torch.IntTensor [n]: length of each text.
"""
length = []
result = []
# 如果是字符形式的输入,则decode_flag = True
decode_flag = True if type(text[0])==bytes else False
# text是list of str
for item in text:
# 变为utf8编码
if decode_flag:
item = item.decode('utf-8','strict')
length.append(len(item))
# 对str中的每个字符
for char in item:
# 在字典中对应的编号
index = self.dict[char]
result.append(index)
text = result # 相当于给翻译成了alphabet中的index
return (torch.IntTensor(text), torch.IntTensor(length))
def decode(self, t, length, raw=False):
"""Decode encoded texts back into strs.
Args:
torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
torch.IntTensor [n]: length of each text.
Raises:
AssertionError: when the texts and its length does not match.
Returns:
text (str or list of str): texts to convert.
"""
# 判断length中是否只含一个元素
if length.numel() == 1:
length = length[0]
assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(), length)
if raw:
return ''.join([self.alphabet[i - 1] for i in t])
else:
char_list = []
for i in range(length):
if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
char_list.append(self.alphabet[t[i] - 1]) # alphabet是字符列表,但顺序被self.dict记录了
# 将一个字符串列表char_list中的所有元素按顺序连接起来形成一个新的字符串。
return ''.join(char_list)
else:
# batch mode
assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(t.numel(), length.sum())
texts = []
index = 0
for i in range(length.numel()): # 遍历各个str,并调用上面的情况去处理decode
l = length[i]
texts.append(
self.decode(
t[index:index + l], torch.IntTensor([l]), raw=raw))
index += l
return texts
def get_char_dict(path):
with open(path, 'rb') as file:
char_dict = {num: char.strip().decode('gbk', 'ignore') for num, char in enumerate(file.readlines())}
def model_info(model): # Plots a line-by-line description of a PyTorch model
n_p = sum(x.numel() for x in model.parameters()) # number parameters
n_g = sum(x.numel() for x in model.parameters() if x.requires_grad) # number gradients
print('\n%5s %50s %9s %12s %20s %12s %12s' % ('layer', 'name', 'gradient', 'parameters', 'shape', 'mu', 'sigma'))
for i, (name, p) in enumerate(model.named_parameters()):
name = name.replace('module_list.', '')
print('%5g %50s %9s %12g %20s %12.3g %12.3g' % (
i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
print('Model Summary: %g layers, %g parameters, %g gradients\n' % (i + 1, n_p, n_g))
\ No newline at end of file
CRNN(
11.77 M, 100.000% Params, 1.26 GMac, 100.000% MACs,
(layers): ModuleDict(
11.77 M, 100.000% Params, 1.26 GMac, 100.000% MACs,
(conv1): Conv2d(640, 0.005% Params, 3.28 MMac, 0.260% MACs, 1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu1): ReLU(0, 0.000% Params, 327.68 KMac, 0.026% MACs, )
(pool1): MaxPool2d(0, 0.000% Params, 327.68 KMac, 0.026% MACs, kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv2): Conv2d(73.86 k, 0.628% Params, 94.54 MMac, 7.494% MACs, 64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu2): ReLU(0, 0.000% Params, 163.84 KMac, 0.013% MACs, )
(pool2): MaxPool2d(0, 0.000% Params, 163.84 KMac, 0.013% MACs, kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv3): Conv2d(295.17 k, 2.508% Params, 94.45 MMac, 7.487% MACs, 128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn3): BatchNorm2d(512, 0.004% Params, 163.84 KMac, 0.013% MACs, 256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu3): ReLU(0, 0.000% Params, 81.92 KMac, 0.006% MACs, )
(conv4): Conv2d(590.08 k, 5.014% Params, 188.83 MMac, 14.968% MACs, 256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu4): ReLU(0, 0.000% Params, 81.92 KMac, 0.006% MACs, )
(pool4): MaxPool2d(0, 0.000% Params, 81.92 KMac, 0.006% MACs, kernel_size=(2, 2), stride=(2, 1), padding=(0, 1), dilation=1, ceil_mode=False)
(conv5): Conv2d(1.18 M, 10.029% Params, 193.55 MMac, 15.342% MACs, 256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn5): BatchNorm2d(1.02 k, 0.009% Params, 167.94 KMac, 0.013% MACs, 512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu5): ReLU(0, 0.000% Params, 83.97 KMac, 0.007% MACs, )
(conv6): Conv2d(2.36 M, 20.053% Params, 387.01 MMac, 30.677% MACs, 512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu6): ReLU(0, 0.000% Params, 83.97 KMac, 0.007% MACs, )
(pool6): MaxPool2d(0, 0.000% Params, 83.97 KMac, 0.007% MACs, kernel_size=(2, 2), stride=(2, 1), padding=(0, 1), dilation=1, ceil_mode=False)
(conv7): Conv2d(1.05 M, 8.915% Params, 43.01 MMac, 3.409% MACs, 512, 512, kernel_size=(2, 2), stride=(1, 1))
(bn7): BatchNorm2d(1.02 k, 0.009% Params, 41.98 KMac, 0.003% MACs, 512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu7): ReLU(0, 0.000% Params, 20.99 KMac, 0.002% MACs, )
(lstm1): LSTM(1.58 M, 13.401% Params, 64.87 MMac, 5.142% MACs, 512, 256, bidirectional=True)
(fc1): Linear(131.33 k, 1.116% Params, 5.37 MMac, 0.426% MACs, in_features=512, out_features=256, bias=True)
(lstm2): LSTM(1.05 M, 8.945% Params, 43.37 MMac, 3.438% MACs, 256, 256, bidirectional=True)
(fc2): Linear(3.46 M, 29.364% Params, 141.41 MMac, 11.209% MACs, in_features=512, out_features=6736, bias=True)
)
)
\ No newline at end of file
title_list:
INT_2 INT_3 INT_4 INT_5 INT_6 INT_7 INT_8 INT_9 INT_10 INT_11 INT_12 INT_13 INT_14 INT_15 INT_16 POT_2 POT_3 POT_4 POT_5 POT_6 POT_7 POT_8 FLOAT_3_E1 FLOAT_4_E1 FLOAT_4_E2 FLOAT_5_E1 FLOAT_5_E2 FLOAT_5_E3 FLOAT_6_E1 FLOAT_6_E2 FLOAT_6_E3 FLOAT_6_E4 FLOAT_7_E1 FLOAT_7_E2 FLOAT_7_E3 FLOAT_7_E4 FLOAT_7_E5 FLOAT_8_E1 FLOAT_8_E2 FLOAT_8_E3 FLOAT_8_E4 FLOAT_8_E5 FLOAT_8_E6
js_flops_list:
938.7733335552051 422.2961529040284 112.00572935619252 28.15632049948652 6.580025053557225 1.593147596698573 0.39127389747221836 0.09726359211545822 0.02424534011691382 0.006096811349829773 0.0015819717562818567 0.0004088973671629054 0.00010591612996765123 4.215430786749398e-05 3.358851529044255e-05 938.0795941095593 276.1298308688724 43.56510410351131 43.31743097440866 43.31651992748872 43.317170796987085 43.316909690885424 245.27340648023244 80.55429874312716 50.01809924502505 38.66623068838267 17.323188075553762 12.265772098808744 26.78403199897619 9.167642811282255 3.1779184189380314 12.23167946700013 22.667666751272108 6.595860316723723 0.8016993676302422 3.1649891963442665 12.231686340919987 21.030391880048075 5.630180110297391 0.20824827230948417 0.7939111200807045 3.164995719146672 12.23168958358928
js_param_list:
1083.2152054453256 407.68102883090796 142.77271018968833 40.694342955319094 9.528033474559528 2.3056188772063395 0.5663514708991146 0.1405112916168392 0.03504902772537398 0.008813186313906787 0.0022363630716486275 0.000566627677334268 0.00014559649495622196 4.1442517690921754e-05 3.062477331841085e-05 1083.0293894297583 286.18087410960214 48.04799916302135 47.56581240023642 47.5640587789577 47.565362072997914 47.56507051619791 306.72068771644035 110.02866499298202 52.861756413549614 51.56753626886545 17.228751722832826 13.602914776865827 35.21196010598172 8.706064889198597 3.4432562752359903 13.53603392058786 29.65906818846962 6.118510949236464 0.8693586052802382 3.417883320602458 13.536039876528541 27.482646086738573 5.183734622246258 0.22716556760816775 0.8541492170710019 3.417896091306703 13.536039813988335
ptq_acc_list:
0.0 0.0 0.0 0.172 0.627 0.772 0.7856666666666666 0.791 0.791 0.7936666666666666 0.793 0.7936666666666666 0.794 0.7936666666666666 0.794 0.0 0.0 0.087 0.09433333333333334 0.086 0.09533333333333334 0.09366666666666666 0.0 0.001 0.0033333333333333335 0.18733333333333332 0.343 0.591 0.4836666666666667 0.5843333333333334 0.7566666666666667 0.5906666666666667 0.557 0.657 0.7823333333333333 0.7503333333333333 0.5883333333333334 0.597 0.6843333333333333 0.792 0.785 0.7486666666666667 0.5933333333333334
acc_loss_list:
1.0 1.0 1.0 0.7834662190516157 0.21065883340327318 0.02811582039446074 0.010910616869492292 0.004196391103650818 0.004196391103650818 0.0008392782207302194 0.001678556441460299 0.0008392782207302194 0.00041963911036503983 0.0008392782207302194 0.00041963911036503983 1.0 1.0 0.8904741921947126 0.8812421317666806 0.8917331095258079 0.8799832144355854 0.8820814099874108 1.0 0.9987410826689047 0.9958036088963492 0.7641628199748216 0.5681913554343264 0.2559798573227025 0.39110365086026017 0.26437263953000417 0.04741921947125467 0.25639949643306753 0.2987830465799412 0.17289131347041542 0.01510700797314311 0.055392362568191404 0.25933697020562313 0.24842635333613097 0.13848090642047836 0.002937473772555558 0.011749895090222373 0.057490558120016744 0.2530423835501468
## update 2023.5.14
### 采用了CRNN的结构实现LSTM-OCR,相比于之前的版本有了更好、更通用的处理OCR问题的能力。我使用了360CC数据集(中文OCR训练数据集,共360w张图片,我只用了其中3w训练数据和3k张测试数据)来训练和测试模型。
1. config文件夹里是配置文件:
- alphabets.py是字典
- 360CC_config.yaml是使用360CC数据集时的配置文件
- OWN_config.yaml是使用其他数据集时的配置文件
2. dataset文件夹:
- txt文件夹中的char_std_5990.txt是360cc数据集的字典
- test.txt中是配置测试集数据
- train.txt中是配置训练集数据
- _360cc.py用于构建360cc dataset来训练/测试
- _own.py用于构建自己的dataset来训练/测试
3. lstm_utils.py中是LSTM-OCR中需要用到的一些特定的函数、类等
4. module.py中增加了对LSTM量化相关的组件,补充了对于BiLSTM的部分权值参数的伪量化,修改了QLinear的quantize_inference方法,使其仍然是真量化。
5. train.py中的train和validate函数考虑了OCR问题的特点,有了较大修改。同理ptq.py中的direct_inference, full_inference, quantize_inference也有一定修改。
6. 由于QLSTM采用伪量化,而其他层采用真量化,因此在model.py中与QLSTM直接连接的层需要通过dequantize_tensor(x)或quantize_tensor(x)来处理下数据。
7. ptq_result_lstm-ocr.txt和ptq_result_lstm-ocr.xlsx是实验结果,仍采用acc作为评价指标,当一个句子被完全识别正确时,算是识别正确一次。
8. 拟合图像如下 <br>
flops:
<img src = "fig/flops_lstmocr.png" class="h-90 auto">
params:
<img src = "fig/params_lstmocr.png" class="h-90 auto">
## update 2023.5.4
......
import argparse
from easydict import EasyDict as edict
import yaml
import os
import torch
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from model import *
from utils import *
from lstm_utils import *
from dataset import get_dataset
from config.alphabets import *
import time
import sys
from torch.utils.tensorboard import SummaryWriter
def parse_arg():
parser = argparse.ArgumentParser(description="train crnn")
# parser.add_argument('--cfg', help='experiment configuration filename', required=True, type=str)
parser.add_argument('--cfg', help='experiment configuration filename', type=str, default='config/360CC_config.yaml')
args = parser.parse_args()
with open(args.cfg, 'r') as f:
# config = yaml.load(f, Loader=yaml.FullLoader)
config = yaml.load(f,Loader=yaml.FullLoader)
config = edict(config)
config.DATASET.ALPHABETS = alphabet
config.MODEL.NUM_CLASSES = len(config.DATASET.ALPHABETS)
return config
def train(config, train_loader, dataset, converter, model, criterion, optimizer, device, epoch):
batch_time = AverageMeter()
data_time = AverageMeter()
losses = AverageMeter()
model.train()
end = time.time()
for i, (inp, idx) in enumerate(train_loader):
# measure data time
data_time.update(time.time() - end)
labels = get_batch_label(dataset, idx)
inp = inp.to(device)
# inference
preds = model(inp).cpu()
# compute loss
batch_size = inp.size(0)
text, length = converter.encode(labels) # length = 一个batch中的总字符长度, text = 一个batch中的字符所对应的下标
preds_size = torch.IntTensor([preds.size(0)] * batch_size) # timestep * batchsize
loss = criterion(preds, text, preds_size, length)
optimizer.zero_grad()
loss.backward()
optimizer.step()
losses.update(loss.item(), inp.size(0))
batch_time.update(time.time()-end)
if i % config.PRINT_FREQ == 0:
msg = 'Epoch: [{0}][{1}/{2}]\t' \
'Time {batch_time.val:.3f}s ({batch_time.avg:.3f}s)\t' \
'Speed {speed:.1f} samples/s\t' \
'Data {data_time.val:.3f}s ({data_time.avg:.3f}s)\t' \
'Loss {loss.val:.5f} ({loss.avg:.5f})\t'.format(
epoch, i, len(train_loader), batch_time=batch_time,
speed=inp.size(0)/batch_time.val,
data_time=data_time, loss=losses)
print(msg)
end = time.time()
def validate(config, val_loader, dataset, converter, model, criterion, device, epoch):
losses = AverageMeter()
model.eval()
n_correct = 0
with torch.no_grad():
for i, (inp, idx) in enumerate(val_loader):
# 一个batch的label的list
labels = get_batch_label(dataset, idx)
inp = inp.to(device)
# inference
preds = model(inp).cpu()
# compute loss
batch_size = inp.size(0)
text, length = converter.encode(labels)
preds_size = torch.IntTensor([preds.size(0)] * batch_size)
loss = criterion(preds, text, preds_size, length)
losses.update(loss.item(), inp.size(0))
_, preds = preds.max(2)
preds = preds.transpose(1, 0).contiguous().view(-1)
# sim_preds是decode后的,得到了一个str或者一个str list (也是一个batch的)
sim_preds = converter.decode(preds.data, preds_size.data, raw=False)
# target应该是一个list 这里相当于在判断字符串是否相等
for pred, target in zip(sim_preds, labels):
if pred == target:
n_correct += 1
if (i + 1) % config.PRINT_FREQ == 0:
print('Epoch: [{0}][{1}/{2}]'.format(epoch, i, len(val_loader)))
# if i == config.TEST.NUM_TEST_BATCH: # 只检查这些数据的 (config.TEST.NUM_TEST_BATCH个batch)
# break
# 只打印展示前 config.TEST.NUM_TEST_DISP 个句子的对比
raw_preds = converter.decode(preds.data, preds_size.data, raw=True)[:config.TEST.NUM_TEST_DISP]
for raw_pred, pred, gt in zip(raw_preds, sim_preds, labels):
print('%-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt))
# num_test_sample = config.TEST.NUM_TEST_BATCH * config.TEST.BATCH_SIZE_PER_GPU
# if num_test_sample > len(dataset):
# num_test_sample = len(dataset)
num_test_sample = len(dataset)
print("[#correct:{} / #total:{}]".format(n_correct, num_test_sample))
accuracy = n_correct / float(num_test_sample)
print('Test loss: {:.4f}, accuray: {:.4f}'.format(losses.avg, accuracy))
return accuracy
def main():
# load config
config = parse_arg()
# construct face related neural networks
model = get_crnn(config)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# define loss function
criterion = torch.nn.CTCLoss()
# 上一轮训练的epoch (起始为0 )
last_epoch = config.TRAIN.BEGIN_EPOCH
train_dataset = get_dataset(config)(config, is_train=True)
train_loader = DataLoader(
dataset=train_dataset,
batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
shuffle=config.TRAIN.SHUFFLE,
num_workers=config.WORKERS,
pin_memory=config.PIN_MEMORY,
)
val_dataset = get_dataset(config)(config, is_train=False)
val_loader = DataLoader(
dataset=val_dataset,
batch_size=config.TEST.BATCH_SIZE_PER_GPU,
shuffle=config.TEST.SHUFFLE,
num_workers=config.WORKERS,
pin_memory=config.PIN_MEMORY,
)
optimizer = get_optimizer(config, model)
if isinstance(config.TRAIN.LR_STEP, list):
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer, config.TRAIN.LR_STEP,
config.TRAIN.LR_FACTOR, last_epoch-1
)
else:
lr_scheduler = torch.optim.lr_scheduler.StepLR(
optimizer, config.TRAIN.LR_STEP,
config.TRAIN.LR_FACTOR, last_epoch - 1
)
save_dir = 'ckpt'
if not os.path.isdir(save_dir):
os.makedirs(save_dir, mode=0o777)
os.chmod(save_dir, mode=0o777)
converter = strLabelConverter(config.DATASET.ALPHABETS)
# for resume, only
if config.TRAIN.RESUME.IS_RESUME:
model.load_state_dict(torch.load(save_dir + '/360cc_' + config.MODEL.NAME + '.pt'))
acc = validate(config, val_loader, val_dataset, converter, model, criterion, device, config.TRAIN.END_EPOCH)
print(f"test accuracy: {acc:.2f}%")
sys.exit()
best_acc = 0.0
for epoch in range(last_epoch, config.TRAIN.END_EPOCH):
# train
train(config, train_loader, train_dataset, converter, model, criterion, optimizer, device, epoch)
lr_scheduler.step()
# validate
acc = validate(config, val_loader, val_dataset, converter, model, criterion, device, epoch)
# save best ckpt
if config.TRAIN.SAVE and acc>best_acc:
torch.save(model.state_dict(), save_dir + '/360cc_' + config.MODEL.NAME + '.pt')
best_acc = max(acc, best_acc)
print("best acc is:", best_acc)
if __name__ == '__main__':
main()
\ No newline at end of file
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
def js_div(p_output, q_output, get_softmax=True):
"""
......@@ -167,3 +169,5 @@ def fold_bn(conv, bn):
return conv
from torch.autograd import Function
class FakeQuantize(Function):
@staticmethod
def forward(ctx, x, qparam):
x = qparam.quantize_tensor(x)
x = qparam.dequantize_tensor(x)
return x
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
class GlobalVariables:
SELF_INPLANES = 0
# -*- coding: utf-8 -*-
# 用于多个module之间共享全局变量
def _init(): # 初始化
global _global_dict
_global_dict = {}
def set_value(value,is_bias=False):
# 定义一个全局变量
if is_bias:
_global_dict[0] = value
else:
_global_dict[1] = value
def get_value(is_bias=False): # 给bias独立于各变量外的精度
if is_bias:
return _global_dict[0]
else:
return _global_dict[1]
## Naive MIA
#### update 2023.5.24
1\. 思路
(1)一共需要训练3个模型,分别是Target Model, Shadow Model,Attack Model. 其中Target Model是被攻击的模型。
(2)假设攻击者已知Target Model的结构并了解其训练所用的数据集的特征,于是Shadow Model采用与Target Model相同的结构。将数据集切分,分别用于Target Model和Shadow Model的训练和测试。
(3)在训练和测试Target Model和Shadow Model时,我们可以分别构建Attack Model的测试集和训练集。Attack Model的输入是Target Model或Shadow Model对一个图片输入的输出向量,输出是该图片是否属于Target Model或Shadow Model的训练集。我们构造Shadow Model就是为了构建Attack Model的训练集(因为我们作为攻击者,只知道Shadow Model的训练集,测试集是什么,而不知道Target Model的),Target Model的输出向量和训练集、测试集可以作为Attack Model的测试集,检验攻击成果。
(4)Attack Model是一个二分类网络,由若干FC,ReLU,BN层组成
(5)最终输出攻击的成功率 (即Attack Model的acc)和一些统计信息
2\. 代码文件说明
attack.py:
核心文件,``get_cmd_arguments()``负责了读取各种参数配置
``split_dataset()``将数据集切分
``get_data_loader()``构造data loader
`` attack_inference``负责Attack Model的推理
``create_attack()``负责的工作较多,包括训练Target Model,训练Shadow Model,通过``train_model()``在训练中构造Attack Model的数据集(也可以load训练好的模型权值后通过``prepare_attack_data`构造). 训练Attack Model,并测试。
model.py :
支持的各种模型结构,目前用到的包括ResNet18,50,152,MobileNetV2
train.py:
``prepare_attack_data()``构建Attack Model训练/测试所用的数据集,attack_X是Target Model或Shadow Model的输出,attack_Y是Target Model或Shadow Model是否为其训练集数据
```train_per_epoch()```负责Target Model或Shadow Model每个epoch的训练
``val_per_epoch()``负责Target Model或Shadow Model每个epoch的测试/验证
``train_attack_model()``负责训练Attack Model的训练
``train_model()``负责组织训练模型,调用上述函数
其余程序文件的含义与之前用到的基本相同
3\. 结果
- 先后分别在CIFAR10和CIFAR100上尝试了对ResNet18,50,152,MobileNetV2的攻击,在CIFAR100上一般能比CIFAR10上取得更好的结果,原因可能是CIFAR100数据集每类数据图片更少,更难训练,更容易过拟合,因此对MIA更脆弱。
只有对使用CIFAR10训练的MobileNetV2的攻击取得了相对显著的结果(Attack Model acc=61.57%),其余情况下,Attack Model的acc均在52%~56%,考虑到二分类问题随机情况下也有50%的acc,攻击效果不是很显著。
- 具体数据
Target Model和Shadow Model的optimizer是Adam,Attack Model的optimizer是SGD,lr_scheduler都是CosineAnnealingLR
* ResNet18 + CIFAR10:
```
Validation Accuracy for the Best Attack Model is: 54.40 %
----Attack Model Testing----
Y:tensor([1, 1, 1, ..., 0, 0, 0])
Length of Attack Model test dataset : [30000]
Attack Test Accuracy is : 53.26%
---Detailed Results----
precision recall f1-score support
Non-Member 0.55 0.38 0.45 15000
Member 0.52 0.69 0.59 15000
accuracy 0.53 30000
macro avg 0.54 0.53 0.52 30000
weighted avg 0.54 0.53 0.52 30000
```
* ResNet50 + CIFAR10:
```
Validation Accuracy for the Best Attack Model is: 55.70 %
----Attack Model Testing----
Y:tensor([1, 1, 1, ..., 0, 0, 0])
Length of Attack Model test dataset : [30000]
Attack Test Accuracy is : 52.20%
---Detailed Results----
precision recall f1-score support
Non-Member 0.52 0.47 0.49 15000
Member 0.52 0.58 0.55 15000
accuracy 0.52 30000
macro avg 0.52 0.52 0.52 30000
weighted avg 0.52 0.52 0.52 30000
```
* ResNet152 + CIFAR10
```
----Attack Model Testing----
Y:tensor([1, 1, 1, ..., 0, 0, 0])
Length of Attack Model test dataset : [30000]
Attack Test Accuracy is : 52.50%
---Detailed Results----
precision recall f1-score support
Non-Member 0.54 0.37 0.44 15000
Member 0.52 0.68 0.59 15000
accuracy 0.53 30000
macro avg 0.53 0.53 0.51 30000
weighted avg 0.53 0.53 0.51 30000
```
* MobileNetV2 + CIFAR10
```
Validation Accuracy for the Best Attack Model is: 65.20 %
----Attack Model Testing----
Y:tensor([1, 1, 1, ..., 0, 0, 0])
Length of Attack Model test dataset : [30000]
Attack Test Accuracy is : 61.57%
---Detailed Results----
precision recall f1-score support
Non-Member 0.63 0.57 0.60 15000
Member 0.61 0.66 0.63 15000
accuracy 0.62 30000
macro avg 0.62 0.62 0.61 30000
weighted avg 0.62 0.62 0.61 30000
```
* ResNet50 + CIFAR100:
```
----Attack Model Testing----
Y:tensor([1, 1, 1, ..., 0, 0, 0])
Length of Attack Model test dataset : [30000]
Attack Test Accuracy is : 54.30%
---Detailed Results----
precision recall f1-score support
Non-Member 0.54 0.61 0.57 15000
Member 0.55 0.47 0.51 15000
accuracy 0.54 30000
macro avg 0.54 0.54 0.54 30000
weighted avg 0.54 0.54 0.54 30000
```
* ResNet152 + CIFAR100:
```
Attack Test Accuracy is : 56.11%
---Detailed Results----
precision recall f1-score support
Non-Member 0.56 0.60 0.58 15000
Member 0.57 0.52 0.54 15000
accuracy 0.56 30000
macro avg 0.56 0.56 0.56 30000
weighted avg 0.56 0.56 0.56 30000
```
* MobileNetV2 + CIFAR100:
```
Validation Accuracy for the Best Attack Model is: 55.50 %
----Attack Model Testing----
Y:tensor([1, 1, 1, ..., 0, 0, 0])
Length of Attack Model test dataset : [30000]
Attack Test Accuracy is : 51.33%
---Detailed Results----
precision recall f1-score support
Non-Member 0.51 0.52 0.52 15000
Member 0.51 0.51 0.51 15000
accuracy 0.51 30000
macro avg 0.51 0.51 0.51 30000
weighted avg 0.51 0.51 0.51 30000
```
4\. 问题及改进方向
Q1:大部分情况下攻击效果较差
A1:尝试换一种MIA攻击方式,尝试Loss Trajectory MIA。之后再仔细检查下有无代码问题。
Q2:对于CIFAR100,Target Model和Shadow Model的训练效果差
A2:因为CIFAR100每类的训练数据本来就比较少,还分别拆分给了Target Model和Shadow Model的train set和test set,导致训练数据过少,因此模型训练效果较差。可以考虑调整训练策略,或是**改用CINIC-10数据集**,CINIC-10 是 CIFAR-10 通过添加下采样的 ImageNet 图像扩展得到,有 270,000 张图像,是 CIFAR 的 4.5 倍,图像大小与 CIFAR 中的一样,不需要对代码进行大量修改,不过有一点问题是图像来源于 CIFAR 和 ImageNet,这些图像的分布不一定相同,可能不利于训练。
Q3:如何把量化也考虑进来
A3:我认为主要取决于我们假设的攻击情景是怎样的。
如果认为攻击者不知道要攻击的模型是全精度的还是量化后的,那么就会采用全精度的Shadow Model,进而根据Shadow Model训练好Attack Model。我们只需要把Target Model进行相应的量化,而后再根据其输出构造Attack Model的test set,就可以将量化考虑进来了。
如果认为攻击者知道要攻击的模型是量化的,同时还知道是用的什么量化数据表示方式、位宽等,那么会比较麻烦。简单的方法是还是训练一个全精度的Shadow Model然后量化,之后再构造Attack Model的训练集。其余步骤于上一个情况相同。可能会取得更好的攻击效果(因为是根据量化后的Shadow Model构建的Attack Model的训练集,Attack Model更能适应量化模型的output). 复杂的方法是引入QAT,但QAT目前在ResNet系列和MobileNetV2上还比较难训练起来。
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment