Commit 4a625ed6 by Zhihong Ma

feat: LSTM-OCR with CRNN structure

parent 2d918317
GPUID: 0
WORKERS: 1
PRINT_FREQ: 10
SAVE_FREQ: 10
PIN_MEMORY: False
OUTPUT_DIR: 'output'
DATASET:
DATASET: 360CC
ROOT: "../CRNN/image"
CHAR_FILE: 'dataset/txt/char_std_5990.txt'
JSON_FILE: {'train': 'dataset/txt/train.txt', 'val': 'dataset/txt/test.txt'}
SCALE_FACTOR: 0.25
ROT_FACTOR: 30
STD: 0.193
MEAN: 0.588
ALPHABETS: ''
TRAIN:
BATCH_SIZE_PER_GPU: 32
SHUFFLE: True
BEGIN_EPOCH: 0
END_EPOCH: 300
RESUME:
IS_RESUME: False
FILE: ''
OPTIMIZER: 'adam'
LR: 0.0001
WD: 0.0
LR_STEP: [60, 80]
LR_FACTOR: 0.1
MOMENTUM: 0.0
NESTEROV: False
RMSPROP_ALPHA:
RMSPROP_CENTERED:
SAVE: true
TEST:
BATCH_SIZE_PER_GPU: 16
SHUFFLE: True # for random test rather than test on the whole validation set
NUM_TEST_BATCH: 3000
NUM_TEST_DISP: 10 # 每次显示多少个测试比对
MODEL:
NAME: 'lstm-ocr'
IMAGE_SIZE:
OW: 280 # origial width: 280
H: 32
W: 160 # resized width: 160
NUM_CLASSES: 0
NUM_HIDDEN: 256
GPUID: 0
WORKERS: 1
PRINT_FREQ: 10
SAVE_FREQ: 10
PIN_MEMORY: False
OUTPUT_DIR: 'output'
CUDNN:
BENCHMARK: True
DETERMINISTIC: False
ENABLED: True
DATASET:
DATASET: OWN
ROOT: "H:/DL-DATASET/360M/images"
JSON_FILE: {'train': 'lib/dataset/txt/train_own.txt', 'val': 'lib/dataset/txt/test_own.txt'}
SCALE_FACTOR: 0.25
ROT_FACTOR: 30
STD: 0.193
MEAN: 0.588
ALPHABETS: ''
TRAIN:
BATCH_SIZE_PER_GPU: 32
SHUFFLE: True
BEGIN_EPOCH: 0
END_EPOCH: 100
RESUME:
IS_RESUME: False
FILE: ''
OPTIMIZER: 'adam'
LR: 0.0001
WD: 0.0
LR_STEP: [60, 80]
LR_FACTOR: 0.1
MOMENTUM: 0.0
NESTEROV: False
RMSPROP_ALPHA:
RMSPROP_CENTERED:
FINETUNE:
IS_FINETUNE: true
FINETUNE_CHECKPOINIT: 'output/checkpoints/mixed_second_finetune_acc_97P7.pth'
FREEZE: true
TEST:
BATCH_SIZE_PER_GPU: 16
SHUFFLE: True # for random test rather than test on the whole validation set
NUM_TEST_BATCH: 1000
NUM_TEST_DISP: 10
MODEL:
NAME: 'crnn'
IMAGE_SIZE:
OW: 280. # origial width: 280
H: 32
W: 160 # resized width: 160
NUM_CLASSES: 0
NUM_HIDDEN: 256
from __future__ import print_function, absolute_import
import torch.utils.data as data
import os
import numpy as np
import cv2
class _360CC(data.Dataset):
def __init__(self, config, is_train=True):
self.root = config.DATASET.ROOT
self.is_train = is_train
self.inp_h = config.MODEL.IMAGE_SIZE.H
self.inp_w = config.MODEL.IMAGE_SIZE.W
self.dataset_name = config.DATASET.DATASET
self.mean = np.array(config.DATASET.MEAN, dtype=np.float32)
self.std = np.array(config.DATASET.STD, dtype=np.float32)
char_file = config.DATASET.CHAR_FILE # 字典?
with open(char_file, 'rb') as file:
# 读取char_file中的所有行的文字,并以字典形式存储键值(序号)和(去除了前后空格的)字符编码
char_dict = {num: char.strip().decode('gbk', 'ignore') for num, char in enumerate(file.readlines())}
txt_file = config.DATASET.JSON_FILE['train'] if is_train else config.DATASET.JSON_FILE['val']
# convert name:indices to name:string
self.labels = []
with open(txt_file, 'r', encoding='utf-8') as file:
contents = file.readlines()
for c in contents:
imgname = c.split(' ')[0]
indices = c.split(' ')[1:]
# 在这里相当于对train.txt / test.txt中的 以index list形式为label进行了decode,得到了str
string = ''.join([char_dict[int(idx)] for idx in indices])
self.labels.append({imgname: string})
print("load {} images!".format(self.__len__()))
def __len__(self):
return len(self.labels)
def __getitem__(self, idx):
img_name = list(self.labels[idx].keys())[0] # 即 label
img = cv2.imread(os.path.join(self.root, img_name))
# 转化为灰度图像
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# 获取图像原始高度
img_h, img_w = img.shape
# 整体缩放到inp_h,inp_w
img = cv2.resize(img, (0,0), fx=self.inp_w / img_w, fy=self.inp_h / img_h, interpolation=cv2.INTER_CUBIC)
img = np.reshape(img, (self.inp_h, self.inp_w, 1))
img = img.astype(np.float32)
# yaml中配置
img = (img/255. - self.mean) / self.std
# chanel,h,w
img = img.transpose([2, 0, 1])
return img, idx
from ._360cc import _360CC
from ._own import _OWN
def get_dataset(config):
if config.DATASET.DATASET == "360CC":
return _360CC
elif config.DATASET.DATASET == "OWN":
return _OWN
else:
raise NotImplemented()
\ No newline at end of file
from __future__ import print_function, absolute_import
import torch.utils.data as data
import os
import numpy as np
import cv2
class _OWN(data.Dataset):
def __init__(self, config, is_train=True):
self.root = config.DATASET.ROOT
self.is_train = is_train
self.inp_h = config.MODEL.IMAGE_SIZE.H
self.inp_w = config.MODEL.IMAGE_SIZE.W
self.dataset_name = config.DATASET.DATASET
self.mean = np.array(config.DATASET.MEAN, dtype=np.float32)
self.std = np.array(config.DATASET.STD, dtype=np.float32)
txt_file = config.DATASET.JSON_FILE['train'] if is_train else config.DATASET.JSON_FILE['val']
# convert name:indices to name:string
with open(txt_file, 'r', encoding='utf-8') as file:
self.labels = [{c.split(' ')[0]: c.split(' ')[-1][:-1]} for c in file.readlines()]
print("load {} images!".format(self.__len__()))
def __len__(self):
return len(self.labels)
def __getitem__(self, idx):
img_name = list(self.labels[idx].keys())[0]
img = cv2.imread(os.path.join(self.root, img_name))
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
img_h, img_w = img.shape
img = cv2.resize(img, (0,0), fx=self.inp_w / img_w, fy=self.inp_h / img_h, interpolation=cv2.INTER_CUBIC)
img = np.reshape(img, (self.inp_h, self.inp_w, 1))
img = img.astype(np.float32)
img = (img/255. - self.mean) / self.std
img = img.transpose([2, 0, 1])
return img, idx
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -11,7 +11,7 @@ def extract_ratio(md='ResNet18'): ...@@ -11,7 +11,7 @@ def extract_ratio(md='ResNet18'):
flop_ratio = [] flop_ratio = []
for line in lines: for line in lines:
# if '(' in line and ')' in line: # if '(' in line and ')' in line:
if 'Conv' in line or 'BatchNorm2d' in line or 'Linear' in line: if 'Conv' in line or 'BatchNorm2d' in line or 'Linear' in line or 'LSTM' in line:
layer.append(line.split(':')[1].split('(')[0]) layer.append(line.split(':')[1].split('(')[0])
r1 = line.split('%')[0].split(',')[-1] r1 = line.split('%')[0].split(',')[-1]
r1 = float(r1) r1 = float(r1)
......
from model import * from model import *
from functools import partial
from lstm_utils import *
import argparse
from easydict import EasyDict as edict
import yaml
from config.alphabets import *
import sys
import torch import torch
from ptflops import get_model_complexity_info from ptflops import get_model_complexity_info
import argparse
def get_children(model: torch.nn.Module): data_path = '../data/ptb'
# get children form model! embed_size = 700
# 为了后续也能够更新参数,需要用nn.ModuleList来承载 hidden_size = 700
eval_batch_size = 10
dropout = 0.65
tied = True
def parse_arg():
parser = argparse.ArgumentParser(description="train crnn")
# children = nn.ModuleList(model.children()) # parser.add_argument('--cfg', help='experiment configuration filename', required=True, type=str)
# print(children) parser.add_argument('--cfg', help='experiment configuration filename', type=str, default='config/360CC_config.yaml')
# 方便对其中的module进行后续的更新
# flatt_children = nn.ModuleList()
children = list(model.children()) args = parser.parse_args()
# flatt_children = nn.ModuleList()
flatt_children = [] with open(args.cfg, 'r') as f:
if len(children) == 0: # config = yaml.load(f, Loader=yaml.FullLoader)
# if model has no children; model is last child! :O config = yaml.load(f,Loader=yaml.FullLoader)
return model config = edict(config)
else:
# look for children from children... to the last child!
for child in children:
try:
flatt_children.extend(get_children(child))
except TypeError:
flatt_children.append(get_children(child))
# print(flatt_children) config.DATASET.ALPHABETS = alphabet
return flatt_children config.MODEL.NUM_CLASSES = len(config.DATASET.ALPHABETS)
# 定义获取不包含wrapper的所有子模块的函数 return config
def get_all_child_modules(module):
for name, child in module.named_children():
if isinstance(child, nn.Sequential):
yield from get_all_child_modules(child)
elif len(list(child.children())) > 0:
yield from child.children()
else:
yield child
def filter_fn(module, n_inp, outp_shape): def lstm_constructor(shape,hidden):
# if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d, torch.nn.ReLU,torch.nn.BatchNorm2d,torch.nn.Linear,torch.nn.AdaptiveAvgPool2d)): return {"x": torch.zeros(shape,dtype=torch.int64),
if 'conv' in module or 'bn' in module or 'fc' in module or 'avg' in module or 'relu' in module: "hidden": hidden}
return True
return False
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Model Analysis --- params & flops') config = parse_arg()
parser.add_argument('-m', '--model', metavar='MODEL ARCH', default='resnet18') model = model = get_crnn(config)
args = parser.parse_args()
if args.model == 'ResNet18':
model = resnet18()
elif args.model == 'ResNet50':
model = resnet50()
elif args.model == 'ResNet152':
model = resnet152()
full_file = 'ckpt/cifar10_' + args.model + '.pt' full_file = 'ckpt/360cc_lstm-ocr.pt'
model.load_state_dict(torch.load(full_file)) model.load_state_dict(torch.load(full_file))
# flat = get_children(model) flops, params = get_model_complexity_info(model, (1, 32, 160), as_strings=True,
# print(flat) print_per_layer_stat=True)
# flat = get_children(model)
# new_model = nn.Sequential(*flat)
flops, params = get_model_complexity_info(model, (3, 32, 32), as_strings=True, print_per_layer_stat=True)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def get_optimizer(config, model):
optimizer = None
if config.TRAIN.OPTIMIZER == "sgd":
optimizer = optim.SGD(
filter(lambda p: p.requires_grad, model.parameters()),
lr=config.TRAIN.LR,
momentum=config.TRAIN.MOMENTUM,
weight_decay=config.TRAIN.WD,
nesterov=config.TRAIN.NESTEROV
)
elif config.TRAIN.OPTIMIZER == "adam":
optimizer = optim.Adam(
filter(lambda p: p.requires_grad, model.parameters()),
lr=config.TRAIN.LR,
)
elif config.TRAIN.OPTIMIZER == "rmsprop":
optimizer = optim.RMSprop(
filter(lambda p: p.requires_grad, model.parameters()),
lr=config.TRAIN.LR,
momentum=config.TRAIN.MOMENTUM,
weight_decay=config.TRAIN.WD,
# alpha=config.TRAIN.RMSPROP_ALPHA,
# centered=config.TRAIN.RMSPROP_CENTERED
)
return optimizer
def get_batch_label(d, i):
label = []
for idx in i:
label.append(list(d.labels[idx].values())[0])
return label
# 可能主要用于OWN dataset?
class strLabelConverter(object):
"""Convert between str and label.
NOTE:
Insert `blank` to the alphabet for CTC.
Args:
alphabet (str): set of the possible characters.
ignore_case (bool, default=True): whether or not to ignore all of the case.
"""
def __init__(self, alphabet, ignore_case=False):
self._ignore_case = ignore_case
if self._ignore_case:
alphabet = alphabet.lower()
self.alphabet = alphabet + '-' # for `-1` index
self.dict = {}
for i, char in enumerate(alphabet):
# NOTE: 0 is reserved for 'blank' required by wrap_ctc
self.dict[char] = i + 1
def encode(self, text):
"""Support batch or single str.
Args:
text (str or list of str): texts to convert.
Returns:
torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
torch.IntTensor [n]: length of each text.
"""
length = []
result = []
# 如果是字符形式的输入,则decode_flag = True
decode_flag = True if type(text[0])==bytes else False
# text是list of str
for item in text:
# 变为utf8编码
if decode_flag:
item = item.decode('utf-8','strict')
length.append(len(item))
# 对str中的每个字符
for char in item:
# 在字典中对应的编号
index = self.dict[char]
result.append(index)
text = result # 相当于给翻译成了alphabet中的index
return (torch.IntTensor(text), torch.IntTensor(length))
def decode(self, t, length, raw=False):
"""Decode encoded texts back into strs.
Args:
torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
torch.IntTensor [n]: length of each text.
Raises:
AssertionError: when the texts and its length does not match.
Returns:
text (str or list of str): texts to convert.
"""
# 判断length中是否只含一个元素
if length.numel() == 1:
length = length[0]
assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(), length)
if raw:
return ''.join([self.alphabet[i - 1] for i in t])
else:
char_list = []
for i in range(length):
if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
char_list.append(self.alphabet[t[i] - 1]) # alphabet是字符列表,但顺序被self.dict记录了
# 将一个字符串列表char_list中的所有元素按顺序连接起来形成一个新的字符串。
return ''.join(char_list)
else:
# batch mode
assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(t.numel(), length.sum())
texts = []
index = 0
for i in range(length.numel()): # 遍历各个str,并调用上面的情况去处理decode
l = length[i]
texts.append(
self.decode(
t[index:index + l], torch.IntTensor([l]), raw=raw))
index += l
return texts
def get_char_dict(path):
with open(path, 'rb') as file:
char_dict = {num: char.strip().decode('gbk', 'ignore') for num, char in enumerate(file.readlines())}
def model_info(model): # Plots a line-by-line description of a PyTorch model
n_p = sum(x.numel() for x in model.parameters()) # number parameters
n_g = sum(x.numel() for x in model.parameters() if x.requires_grad) # number gradients
print('\n%5s %50s %9s %12s %20s %12s %12s' % ('layer', 'name', 'gradient', 'parameters', 'shape', 'mu', 'sigma'))
for i, (name, p) in enumerate(model.named_parameters()):
name = name.replace('module_list.', '')
print('%5g %50s %9s %12g %20s %12.3g %12.3g' % (
i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
print('Model Summary: %g layers, %g parameters, %g gradients\n' % (i + 1, n_p, n_g))
\ No newline at end of file
CRNN(
11.77 M, 100.000% Params, 1.26 GMac, 100.000% MACs,
(layers): ModuleDict(
11.77 M, 100.000% Params, 1.26 GMac, 100.000% MACs,
(conv1): Conv2d(640, 0.005% Params, 3.28 MMac, 0.260% MACs, 1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu1): ReLU(0, 0.000% Params, 327.68 KMac, 0.026% MACs, )
(pool1): MaxPool2d(0, 0.000% Params, 327.68 KMac, 0.026% MACs, kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv2): Conv2d(73.86 k, 0.628% Params, 94.54 MMac, 7.494% MACs, 64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu2): ReLU(0, 0.000% Params, 163.84 KMac, 0.013% MACs, )
(pool2): MaxPool2d(0, 0.000% Params, 163.84 KMac, 0.013% MACs, kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv3): Conv2d(295.17 k, 2.508% Params, 94.45 MMac, 7.487% MACs, 128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn3): BatchNorm2d(512, 0.004% Params, 163.84 KMac, 0.013% MACs, 256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu3): ReLU(0, 0.000% Params, 81.92 KMac, 0.006% MACs, )
(conv4): Conv2d(590.08 k, 5.014% Params, 188.83 MMac, 14.968% MACs, 256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu4): ReLU(0, 0.000% Params, 81.92 KMac, 0.006% MACs, )
(pool4): MaxPool2d(0, 0.000% Params, 81.92 KMac, 0.006% MACs, kernel_size=(2, 2), stride=(2, 1), padding=(0, 1), dilation=1, ceil_mode=False)
(conv5): Conv2d(1.18 M, 10.029% Params, 193.55 MMac, 15.342% MACs, 256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn5): BatchNorm2d(1.02 k, 0.009% Params, 167.94 KMac, 0.013% MACs, 512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu5): ReLU(0, 0.000% Params, 83.97 KMac, 0.007% MACs, )
(conv6): Conv2d(2.36 M, 20.053% Params, 387.01 MMac, 30.677% MACs, 512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu6): ReLU(0, 0.000% Params, 83.97 KMac, 0.007% MACs, )
(pool6): MaxPool2d(0, 0.000% Params, 83.97 KMac, 0.007% MACs, kernel_size=(2, 2), stride=(2, 1), padding=(0, 1), dilation=1, ceil_mode=False)
(conv7): Conv2d(1.05 M, 8.915% Params, 43.01 MMac, 3.409% MACs, 512, 512, kernel_size=(2, 2), stride=(1, 1))
(bn7): BatchNorm2d(1.02 k, 0.009% Params, 41.98 KMac, 0.003% MACs, 512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu7): ReLU(0, 0.000% Params, 20.99 KMac, 0.002% MACs, )
(lstm1): LSTM(1.58 M, 13.401% Params, 64.87 MMac, 5.142% MACs, 512, 256, bidirectional=True)
(fc1): Linear(131.33 k, 1.116% Params, 5.37 MMac, 0.426% MACs, in_features=512, out_features=256, bias=True)
(lstm2): LSTM(1.05 M, 8.945% Params, 43.37 MMac, 3.438% MACs, 256, 256, bidirectional=True)
(fc2): Linear(3.46 M, 29.364% Params, 141.41 MMac, 11.209% MACs, in_features=512, out_features=6736, bias=True)
)
)
\ No newline at end of file
title_list:
INT_2 INT_3 INT_4 INT_5 INT_6 INT_7 INT_8 INT_9 INT_10 INT_11 INT_12 INT_13 INT_14 INT_15 INT_16 POT_2 POT_3 POT_4 POT_5 POT_6 POT_7 POT_8 FLOAT_3_E1 FLOAT_4_E1 FLOAT_4_E2 FLOAT_5_E1 FLOAT_5_E2 FLOAT_5_E3 FLOAT_6_E1 FLOAT_6_E2 FLOAT_6_E3 FLOAT_6_E4 FLOAT_7_E1 FLOAT_7_E2 FLOAT_7_E3 FLOAT_7_E4 FLOAT_7_E5 FLOAT_8_E1 FLOAT_8_E2 FLOAT_8_E3 FLOAT_8_E4 FLOAT_8_E5 FLOAT_8_E6
js_flops_list:
938.7733335552051 422.2961529040284 112.00572935619252 28.15632049948652 6.580025053557225 1.593147596698573 0.39127389747221836 0.09726359211545822 0.02424534011691382 0.006096811349829773 0.0015819717562818567 0.0004088973671629054 0.00010591612996765123 4.215430786749398e-05 3.358851529044255e-05 938.0795941095593 276.1298308688724 43.56510410351131 43.31743097440866 43.31651992748872 43.317170796987085 43.316909690885424 245.27340648023244 80.55429874312716 50.01809924502505 38.66623068838267 17.323188075553762 12.265772098808744 26.78403199897619 9.167642811282255 3.1779184189380314 12.23167946700013 22.667666751272108 6.595860316723723 0.8016993676302422 3.1649891963442665 12.231686340919987 21.030391880048075 5.630180110297391 0.20824827230948417 0.7939111200807045 3.164995719146672 12.23168958358928
js_param_list:
1083.2152054453256 407.68102883090796 142.77271018968833 40.694342955319094 9.528033474559528 2.3056188772063395 0.5663514708991146 0.1405112916168392 0.03504902772537398 0.008813186313906787 0.0022363630716486275 0.000566627677334268 0.00014559649495622196 4.1442517690921754e-05 3.062477331841085e-05 1083.0293894297583 286.18087410960214 48.04799916302135 47.56581240023642 47.5640587789577 47.565362072997914 47.56507051619791 306.72068771644035 110.02866499298202 52.861756413549614 51.56753626886545 17.228751722832826 13.602914776865827 35.21196010598172 8.706064889198597 3.4432562752359903 13.53603392058786 29.65906818846962 6.118510949236464 0.8693586052802382 3.417883320602458 13.536039876528541 27.482646086738573 5.183734622246258 0.22716556760816775 0.8541492170710019 3.417896091306703 13.536039813988335
ptq_acc_list:
0.0 0.0 0.0 0.172 0.627 0.772 0.7856666666666666 0.791 0.791 0.7936666666666666 0.793 0.7936666666666666 0.794 0.7936666666666666 0.794 0.0 0.0 0.087 0.09433333333333334 0.086 0.09533333333333334 0.09366666666666666 0.0 0.001 0.0033333333333333335 0.18733333333333332 0.343 0.591 0.4836666666666667 0.5843333333333334 0.7566666666666667 0.5906666666666667 0.557 0.657 0.7823333333333333 0.7503333333333333 0.5883333333333334 0.597 0.6843333333333333 0.792 0.785 0.7486666666666667 0.5933333333333334
acc_loss_list:
1.0 1.0 1.0 0.7834662190516157 0.21065883340327318 0.02811582039446074 0.010910616869492292 0.004196391103650818 0.004196391103650818 0.0008392782207302194 0.001678556441460299 0.0008392782207302194 0.00041963911036503983 0.0008392782207302194 0.00041963911036503983 1.0 1.0 0.8904741921947126 0.8812421317666806 0.8917331095258079 0.8799832144355854 0.8820814099874108 1.0 0.9987410826689047 0.9958036088963492 0.7641628199748216 0.5681913554343264 0.2559798573227025 0.39110365086026017 0.26437263953000417 0.04741921947125467 0.25639949643306753 0.2987830465799412 0.17289131347041542 0.01510700797314311 0.055392362568191404 0.25933697020562313 0.24842635333613097 0.13848090642047836 0.002937473772555558 0.011749895090222373 0.057490558120016744 0.2530423835501468
## update 2023.5.14
### 采用了CRNN的结构实现LSTM-OCR,相比于之前的版本有了更好、更通用的处理OCR问题的能力。我使用了360CC数据集(中文OCR训练数据集,共360w张图片,我只用了其中3w训练数据和3k张测试数据)来训练和测试模型。
1. config文件夹里是配置文件:
- alphabets.py是字典
- 360CC_config.yaml是使用360CC数据集时的配置文件
- OWN_config.yaml是使用其他数据集时的配置文件
2. dataset文件夹:
- txt文件夹中的char_std_5990.txt是360cc数据集的字典
- test.txt中是配置测试集数据
- train.txt中是配置训练集数据
- _360cc.py用于构建360cc dataset来训练/测试
- _own.py用于构建自己的dataset来训练/测试
3. lstm_utils.py中是LSTM-OCR中需要用到的一些特定的函数、类等
4. module.py中增加了对LSTM量化相关的组件,补充了对于BiLSTM的部分权值参数的伪量化,修改了QLinear的quantize_inference方法,使其仍然是真量化。
5. train.py中的train和validate函数考虑了OCR问题的特点,有了较大修改。同理ptq.py中的direct_inference, full_inference, quantize_inference也有一定修改。
6. 由于QLSTM采用伪量化,而其他层采用真量化,因此在model.py中与QLSTM直接连接的层需要通过dequantize_tensor(x)或quantize_tensor(x)来处理下数据。
7. ptq_result_lstm-ocr.txt和ptq_result_lstm-ocr.xlsx是实验结果,仍采用acc作为评价指标,当一个句子被完全识别正确时,算是识别正确一次。
8. 拟合图像如下 <br>
flops:
<img src = "fig/flops_lstmocr.png" class="h-90 auto">
params:
<img src = "fig/params_lstmocr.png" class="h-90 auto">
## update 2023.5.4 ## update 2023.5.4
......
import argparse
from easydict import EasyDict as edict
import yaml
import os
import torch
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from model import *
from utils import *
from lstm_utils import *
from dataset import get_dataset
from config.alphabets import *
import time
import sys
from torch.utils.tensorboard import SummaryWriter
def parse_arg():
parser = argparse.ArgumentParser(description="train crnn")
# parser.add_argument('--cfg', help='experiment configuration filename', required=True, type=str)
parser.add_argument('--cfg', help='experiment configuration filename', type=str, default='config/360CC_config.yaml')
args = parser.parse_args()
with open(args.cfg, 'r') as f:
# config = yaml.load(f, Loader=yaml.FullLoader)
config = yaml.load(f,Loader=yaml.FullLoader)
config = edict(config)
config.DATASET.ALPHABETS = alphabet
config.MODEL.NUM_CLASSES = len(config.DATASET.ALPHABETS)
return config
def train(config, train_loader, dataset, converter, model, criterion, optimizer, device, epoch):
batch_time = AverageMeter()
data_time = AverageMeter()
losses = AverageMeter()
model.train()
end = time.time()
for i, (inp, idx) in enumerate(train_loader):
# measure data time
data_time.update(time.time() - end)
labels = get_batch_label(dataset, idx)
inp = inp.to(device)
# inference
preds = model(inp).cpu()
# compute loss
batch_size = inp.size(0)
text, length = converter.encode(labels) # length = 一个batch中的总字符长度, text = 一个batch中的字符所对应的下标
preds_size = torch.IntTensor([preds.size(0)] * batch_size) # timestep * batchsize
loss = criterion(preds, text, preds_size, length)
optimizer.zero_grad()
loss.backward()
optimizer.step()
losses.update(loss.item(), inp.size(0))
batch_time.update(time.time()-end)
if i % config.PRINT_FREQ == 0:
msg = 'Epoch: [{0}][{1}/{2}]\t' \
'Time {batch_time.val:.3f}s ({batch_time.avg:.3f}s)\t' \
'Speed {speed:.1f} samples/s\t' \
'Data {data_time.val:.3f}s ({data_time.avg:.3f}s)\t' \
'Loss {loss.val:.5f} ({loss.avg:.5f})\t'.format(
epoch, i, len(train_loader), batch_time=batch_time,
speed=inp.size(0)/batch_time.val,
data_time=data_time, loss=losses)
print(msg)
end = time.time()
def validate(config, val_loader, dataset, converter, model, criterion, device, epoch):
losses = AverageMeter()
model.eval()
n_correct = 0
with torch.no_grad():
for i, (inp, idx) in enumerate(val_loader):
# 一个batch的label的list
labels = get_batch_label(dataset, idx)
inp = inp.to(device)
# inference
preds = model(inp).cpu()
# compute loss
batch_size = inp.size(0)
text, length = converter.encode(labels)
preds_size = torch.IntTensor([preds.size(0)] * batch_size)
loss = criterion(preds, text, preds_size, length)
losses.update(loss.item(), inp.size(0))
_, preds = preds.max(2)
preds = preds.transpose(1, 0).contiguous().view(-1)
# sim_preds是decode后的,得到了一个str或者一个str list (也是一个batch的)
sim_preds = converter.decode(preds.data, preds_size.data, raw=False)
# target应该是一个list 这里相当于在判断字符串是否相等
for pred, target in zip(sim_preds, labels):
if pred == target:
n_correct += 1
if (i + 1) % config.PRINT_FREQ == 0:
print('Epoch: [{0}][{1}/{2}]'.format(epoch, i, len(val_loader)))
# if i == config.TEST.NUM_TEST_BATCH: # 只检查这些数据的 (config.TEST.NUM_TEST_BATCH个batch)
# break
# 只打印展示前 config.TEST.NUM_TEST_DISP 个句子的对比
raw_preds = converter.decode(preds.data, preds_size.data, raw=True)[:config.TEST.NUM_TEST_DISP]
for raw_pred, pred, gt in zip(raw_preds, sim_preds, labels):
print('%-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt))
# num_test_sample = config.TEST.NUM_TEST_BATCH * config.TEST.BATCH_SIZE_PER_GPU
# if num_test_sample > len(dataset):
# num_test_sample = len(dataset)
num_test_sample = len(dataset)
print("[#correct:{} / #total:{}]".format(n_correct, num_test_sample))
accuracy = n_correct / float(num_test_sample)
print('Test loss: {:.4f}, accuray: {:.4f}'.format(losses.avg, accuracy))
return accuracy
def main():
# load config
config = parse_arg()
# construct face related neural networks
model = get_crnn(config)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# define loss function
criterion = torch.nn.CTCLoss()
# 上一轮训练的epoch (起始为0 )
last_epoch = config.TRAIN.BEGIN_EPOCH
train_dataset = get_dataset(config)(config, is_train=True)
train_loader = DataLoader(
dataset=train_dataset,
batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
shuffle=config.TRAIN.SHUFFLE,
num_workers=config.WORKERS,
pin_memory=config.PIN_MEMORY,
)
val_dataset = get_dataset(config)(config, is_train=False)
val_loader = DataLoader(
dataset=val_dataset,
batch_size=config.TEST.BATCH_SIZE_PER_GPU,
shuffle=config.TEST.SHUFFLE,
num_workers=config.WORKERS,
pin_memory=config.PIN_MEMORY,
)
optimizer = get_optimizer(config, model)
if isinstance(config.TRAIN.LR_STEP, list):
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer, config.TRAIN.LR_STEP,
config.TRAIN.LR_FACTOR, last_epoch-1
)
else:
lr_scheduler = torch.optim.lr_scheduler.StepLR(
optimizer, config.TRAIN.LR_STEP,
config.TRAIN.LR_FACTOR, last_epoch - 1
)
save_dir = 'ckpt'
if not os.path.isdir(save_dir):
os.makedirs(save_dir, mode=0o777)
os.chmod(save_dir, mode=0o777)
converter = strLabelConverter(config.DATASET.ALPHABETS)
# for resume, only
if config.TRAIN.RESUME.IS_RESUME:
model.load_state_dict(torch.load(save_dir + '/360cc_' + config.MODEL.NAME + '.pt'))
acc = validate(config, val_loader, val_dataset, converter, model, criterion, device, config.TRAIN.END_EPOCH)
print(f"test accuracy: {acc:.2f}%")
sys.exit()
best_acc = 0.0
for epoch in range(last_epoch, config.TRAIN.END_EPOCH):
# train
train(config, train_loader, train_dataset, converter, model, criterion, optimizer, device, epoch)
lr_scheduler.step()
# validate
acc = validate(config, val_loader, val_dataset, converter, model, criterion, device, epoch)
# save best ckpt
if config.TRAIN.SAVE and acc>best_acc:
torch.save(model.state_dict(), save_dir + '/360cc_' + config.MODEL.NAME + '.pt')
best_acc = max(acc, best_acc)
print("best acc is:", best_acc)
if __name__ == '__main__':
main()
\ No newline at end of file
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.optim as optim
def js_div(p_output, q_output, get_softmax=True): def js_div(p_output, q_output, get_softmax=True):
""" """
...@@ -166,4 +168,6 @@ def fold_bn(conv, bn): ...@@ -166,4 +168,6 @@ def fold_bn(conv, bn):
return conv return conv
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment