366 lines
14 KiB
Python
366 lines
14 KiB
Python
import torch
|
|
import pickle
|
|
import numpy as np
|
|
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
|
class AttrDict(dict):
    """Dictionary whose items can also be read and written as attributes."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Use the mapping itself as the attribute namespace, so that
        # ``d.key`` and ``d['key']`` always refer to the same value.
        self.__dict__ = self
|
|
|
|
##### https://github.com/githubharald/CTCDecoder/blob/master/src/BeamSearch.py
|
|
class BeamEntry:
    """Information about one single beam at a specific time-step."""

    def __init__(self):
        # Labeling carried by this beam.
        self.labeling = ()

        # Path probabilities: the total (blank and non-blank), the share of
        # paths ending in a non-blank, and the share ending in a blank.
        self.prTotal = 0
        self.prNonBlank = 0
        self.prBlank = 0

        # Language-model score, plus a flag recording whether the LM has
        # already been applied to this beam.
        self.prText = 1
        self.lmApplied = False
|
|
|
|
class BeamState:
    """Information about the beams at a specific time-step."""

    def __init__(self):
        # Maps a labeling (tuple of class indices) to its beam entry.
        self.entries = {}

    def norm(self):
        """Length-normalise the LM score of every beam, in place."""
        for entry in self.entries.values():
            labelingLen = len(entry.labeling)
            # An empty labeling is treated as length 1 to avoid dividing by zero.
            entry.prText = entry.prText ** (1.0 / (labelingLen if labelingLen else 1.0))

    def _ranked(self):
        """Return the beam entries ordered by ``prTotal * prText``, best first."""
        return sorted(self.entries.values(), reverse=True, key=lambda x: x.prTotal * x.prText)

    def sort(self):
        """Return beam-labelings, sorted by probability (most probable first)."""
        return [x.labeling for x in self._ranked()]

    def wordsearch(self, classes, ignore_idx, beamWidth, dict_list):
        """Return the text of the best-ranked beam that appears in ``dict_list``.

        The ``beamWidth`` best beams are converted to text in ranking order.
        The first text found in ``dict_list`` is returned; when none matches,
        the text of the top-ranked beam is returned instead.

        Args:
            classes: sequence mapping a class index to its character.
            ignore_idx: class indices that are dropped from the text.
            beamWidth: number of best beams to examine.
            dict_list: container of accepted words (tested with ``in``).

        Returns:
            The selected text, or ``''`` when there is no beam to examine
            (previously this case raised ``UnboundLocalError``).
        """
        best_text = ''
        for j, candidate in enumerate(self._ranked()[:beamWidth]):
            idx_list = candidate.labeling
            # Drop ignored indices and collapse immediate repetitions.
            text = ''.join(
                classes[l]
                for i, l in enumerate(idx_list)
                if l not in ignore_idx and not (i > 0 and idx_list[i - 1] == l)
            )

            if j == 0:
                # Fallback when no candidate is found in the dictionary.
                best_text = text
            if text in dict_list:
                print('found text: ', text)
                best_text = text
                break
            else:
                print('not in dict: ', text)
        return best_text
|
|
|
|
def applyLM(parentBeam, childBeam, classes, lm):
    """Score a child beam with the language model, at most once per beam.

    The child's LM score is the parent's score multiplied by the (damped)
    bigram probability of the parent's last character followed by the child's
    last character. A parent with an empty labeling is treated as ending in a
    space. Nothing happens when ``lm`` is falsy or the child was already scored.
    """
    if not lm or childBeam.lmApplied:
        return

    # Exponent that damps the influence of the language model.
    lmFactor = 0.01

    if parentBeam.labeling:
        prevIdx = parentBeam.labeling[-1]
    else:
        prevIdx = classes.index(' ')
    firstChar = classes[prevIdx]
    secondChar = classes[childBeam.labeling[-1]]

    # Probability of seeing the two characters next to each other.
    bigramProb = lm.getCharBigram(firstChar, secondChar) ** lmFactor

    childBeam.prText = parentBeam.prText * bigramProb
    childBeam.lmApplied = True
|
|
|
|
def addBeam(beamState, labeling):
    """Register a fresh entry for ``labeling`` unless one already exists."""
    known = beamState.entries
    if labeling in known:
        return
    known[labeling] = BeamEntry()
|
|
|
|
def ctcBeamSearch(mat, classes, ignore_idx, lm, beamWidth=25, dict_list=None):
    """Beam search as described by the paper of Hwang et al. and the paper of Graves et al.

    Args:
        mat: 2-D array indexed as ``mat[t, c]`` (time-step, class index);
            column 0 is used as the blank probability.
        classes: sequence mapping a class index to its character.
        ignore_idx: class indices dropped when labels are mapped to text.
        lm: accepted for interface compatibility; language-model scoring is
            not applied by this implementation.
        beamWidth: number of best beams kept per time-step.
        dict_list: optional container of accepted words. When empty or
            ``None`` the most probable labeling is decoded directly;
            otherwise the result is chosen by ``BeamState.wordsearch``.

    Returns:
        The decoded text.
    """
    blankIdx = 0
    maxT, maxC = mat.shape

    # Initialise the beam state with the empty labeling.
    last = BeamState()
    labeling = ()
    last.entries[labeling] = BeamEntry()
    last.entries[labeling].prBlank = 1
    last.entries[labeling].prTotal = 1

    # Go over all time-steps.
    for t in range(maxT):
        curr = BeamState()

        # Only the best beams of the previous step are expanded.
        bestLabelings = last.sort()[0:beamWidth]

        for labeling in bestLabelings:
            parent = last.entries[labeling]

            # Probability of paths ending with a non-blank: for a non-empty
            # beam, the paths that repeat the last char at the end.
            prNonBlank = 0
            if labeling:
                prNonBlank = parent.prNonBlank * mat[t, labeling[-1]]

            # Probability of paths ending with a blank.
            prBlank = (parent.prTotal) * mat[t, blankIdx]

            # Carry the unchanged labeling over to the current time-step.
            addBeam(curr, labeling)
            entry = curr.entries[labeling]
            entry.labeling = labeling
            entry.prNonBlank += prNonBlank
            entry.prBlank += prBlank
            entry.prTotal += prBlank + prNonBlank
            # Labeling not changed, therefore the LM score is unchanged too.
            entry.prText = parent.prText
            entry.lmApplied = True

            # Extend the labeling by every class. The range must cover all
            # maxC columns: blank lives at index 0 here, so stopping at
            # maxC - 1 would make the last class impossible to decode.
            for c in range(maxC):
                newLabeling = labeling + (c,)

                # A duplicate char at the end is only reachable through
                # paths that ended with a blank.
                if labeling and labeling[-1] == c:
                    prNonBlank = mat[t, c] * parent.prBlank
                else:
                    prNonBlank = mat[t, c] * parent.prTotal

                addBeam(curr, newLabeling)
                child = curr.entries[newLabeling]
                child.labeling = newLabeling
                child.prNonBlank += prNonBlank
                child.prTotal += prNonBlank

        last = curr

    # Normalise LM scores according to beam-labeling length.
    last.norm()

    if dict_list is None or dict_list == []:
        bestLabeling = last.sort()[0]  # most probable labeling
        # Map labels to chars, removing repeated characters and ignored indices.
        res = ''.join(
            classes[l]
            for i, l in enumerate(bestLabeling)
            if l not in ignore_idx and not (i > 0 and bestLabeling[i - 1] == l)
        )
    else:
        res = last.wordsearch(classes, ignore_idx, beamWidth, dict_list)

    return res
|
|
#####
|
|
|
|
def consecutive(data, mode ='first', stepsize=1):
    """Return one boundary element from each run of evenly spaced values.

    ``data`` is split wherever the difference between neighbouring elements
    is not ``stepsize``; each non-empty run then contributes its first or
    its last element.

    Args:
        data: 1-D array of values.
        mode: ``'first'`` to take the start of each run, ``'last'`` to take
            its end.
        stepsize: expected difference between neighbours inside a run.

    Returns:
        A list with one element per run.

    Raises:
        ValueError: if ``mode`` is neither ``'first'`` nor ``'last'``
            (previously this surfaced as an ``UnboundLocalError``).
    """
    if mode not in ('first', 'last'):
        raise ValueError("mode must be 'first' or 'last', got %r" % (mode,))

    runs = np.split(data, np.where(np.diff(data) != stepsize)[0] + 1)
    pick = 0 if mode == 'first' else -1
    return [run[pick] for run in runs if len(run) > 0]
|
|
|
|
def word_segmentation(mat, separator_idx=None, separator_idx_list=None):
    """Split a sequence of class indices into language-tagged segments.

    Args:
        mat: 1-D array of class indices (compared element-wise against the
            separator indices).
        separator_idx: maps a language to ``[start_marker, end_marker]``.
            Defaults to ``{'th': [1, 2], 'en': [3, 4]}``.
        separator_idx_list: marker indices to look for. Defaults to
            ``[1, 2, 3, 4]``.

    Returns:
        A list of ``[lang, [first, last]]`` items with inclusive positions
        into ``mat``. ``lang`` is ``''`` for spans outside a matched
        start/end marker pair.
    """
    if separator_idx is None:
        separator_idx = {'th': [1, 2], 'en': [3, 4]}
    if separator_idx_list is None:
        separator_idx_list = [1, 2, 3, 4]

    result = []
    sep_list = []
    start_idx = 0

    # Language of the currently open start marker, if any. Initialising it
    # here fixes the crash when the first marker encountered is an end marker.
    open_lang = None
    open_idx = None

    for sep_idx in separator_idx_list:
        # Even markers contribute the first position of each run, odd markers
        # the last one.
        mode = 'first' if sep_idx % 2 == 0 else 'last'
        positions = consecutive(np.argwhere(mat == sep_idx).flatten(), mode)
        sep_list += [[item, sep_idx] for item in positions]
    sep_list = sorted(sep_list, key=lambda x: x[0])

    for position, marker in sep_list:
        for lang, markers in separator_idx.items():
            if marker == markers[0]:  # start of a language span
                open_lang = lang
                open_idx = position
            elif marker == markers[1]:  # end of a language span
                if open_lang == lang:  # closes the span opened last
                    if open_idx > start_idx:
                        # Untagged text preceding the start marker.
                        result.append(['', [start_idx, open_idx - 1]])
                    start_idx = position + 1
                    result.append([lang, [open_idx + 1, position - 1]])
                # Reset in every case, so that a repeated end marker cannot
                # re-use a stale start position and emit an overlapping span.
                open_lang = None

    if start_idx <= len(mat) - 1:
        result.append(['', [start_idx, len(mat) - 1]])
    return result
|
|
|
|
class CTCLabelConverter(object):
    """ Convert between text-label and text-index """

    def __init__(self, character, separator_list=None, dict_pathlist=None):
        """Build the character tables and load the optional word dictionaries.

        Args:
            character (str): set of the possible characters.
            separator_list: maps a language to its separator characters.
                Defaults to no separators.
            dict_pathlist: maps a language to the path of a pickled word
                dictionary. Defaults to no dictionaries.
        """
        # ``None`` defaults avoid sharing one mutable dict between instances.
        separator_list = {} if separator_list is None else separator_list
        dict_pathlist = {} if dict_pathlist is None else dict_pathlist

        dict_character = list(character)

        # NOTE: 0 is reserved for 'blank' token required by CTCLoss
        self.dict = {char: i + 1 for i, char in enumerate(dict_character)}

        # dummy '[blank]' token for CTCLoss (index 0)
        self.character = ['[blank]'] + dict_character
        self.separator_list = separator_list

        separator_char = []
        for sep in separator_list.values():
            separator_char += sep

        # Blank plus one index per separator character, counted from 1.
        self.ignore_idx = [0] + [i + 1 for i in range(len(separator_char))]

        dict_list = {}
        for lang, dict_path in dict_pathlist.items():
            # SECURITY: pickle can execute arbitrary code while loading;
            # only point ``dict_pathlist`` at trusted files.
            with open(dict_path, "rb") as input_file:
                dict_list[lang] = pickle.load(input_file)
        self.dict_list = dict_list

    def encode(self, text, batch_max_length=25):
        """convert text-label into text-index.

        input:
            text: text labels of each image. [batch_size]
            batch_max_length: not used by this converter.

        output:
            text: concatenated text index for CTCLoss.
                    [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
            length: length of each text. [batch_size]
        """
        length = [len(s) for s in text]
        indices = [self.dict[char] for char in ''.join(text)]
        return (torch.IntTensor(indices), torch.IntTensor(length))

    def decode_greedy(self, text_index, length):
        """ convert text-index into text-label. """
        texts = []
        index = 0
        for l in length:
            t = text_index[index:index + l]

            # Keep an index only if it is not ignored (blank / separator) and
            # does not repeat the index immediately before it.
            char_list = [
                self.character[t[i]]
                for i in range(l)
                if t[i] not in self.ignore_idx and not (i > 0 and t[i - 1] == t[i])
            ]
            texts.append(''.join(char_list))
            index += l
        return texts

    def decode_beamsearch(self, mat, beamWidth=5):
        """Decode every item of ``mat`` (indexed along axis 0) with CTC beam search."""
        return [
            ctcBeamSearch(mat[i], self.character, self.ignore_idx, None, beamWidth=beamWidth)
            for i in range(mat.shape[0])
        ]

    def decode_wordbeamsearch(self, mat, beamWidth=5):
        """Decode with beam search, using a word dictionary inside tagged segments.

        The per-step argmax of each item is split by ``word_segmentation``;
        segments tagged with a language are decoded against that language's
        dictionary, untagged segments are decoded without one.
        """
        texts = []
        argmax = np.argmax(mat, axis=2)
        for i in range(mat.shape[0]):
            pieces = []
            for lang, (first, last) in word_segmentation(argmax[i]):
                matrix = mat[i, first:last + 1, :]
                dict_list = [] if lang == '' else self.dict_list[lang]
                pieces.append(
                    ctcBeamSearch(matrix, self.character, self.ignore_idx, None,
                                  beamWidth=beamWidth, dict_list=dict_list)
                )
            texts.append(''.join(pieces))
        return texts
|
|
|
|
class AttnLabelConverter(object):
    """ Convert between text-label and text-index """

    def __init__(self, character):
        """Build the index tables.

        character (str): set of the possible characters.
        """
        # [GO] is the start token of the attention decoder, [s] the
        # end-of-sentence token; both precede the real characters.
        self.character = ['[GO]', '[s]'] + list(character)
        self.dict = {char: position for position, char in enumerate(self.character)}

    def encode(self, text, batch_max_length=25):
        """ convert text-label into text-index.

        input:
            text: text labels of each image. [batch_size]
            batch_max_length: max length of text label in the batch. 25 by default

        output:
            text : the input of attention decoder. [batch_size x (max_length+2)]
                +1 for [GO] token and +1 for [s] token. text[:, 0] is the [GO]
                token and text is padded with the [GO] token after the [s] token.
            length : the length of output of attention decoder, which counts the
                [s] token also. [3, 7, ....] [batch_size]
        """
        # +1 for [s] at end of sentence.
        length = [len(label) + 1 for label in text]

        # The width is taken from the argument rather than from max(length),
        # which is not allowed for the multi-gpu setting.
        batch_max_length += 1

        # Additional +1 for [GO] at the first step; filling with 0 pads every
        # row with the [GO] index.
        batch_text = torch.LongTensor(len(text), batch_max_length + 1).fill_(0)
        for row, label in enumerate(text):
            tokens = list(label) + ['[s]']
            indices = [self.dict[token] for token in tokens]
            # Column 0 stays [GO]; the label starts at column 1.
            batch_text[row][1:1 + len(indices)] = torch.LongTensor(indices)

        return (batch_text.to(device), torch.IntTensor(length).to(device))

    def decode(self, text_index, length):
        """ convert text-index into text-label. """
        # One output per entry of ``length``; each row is rendered in full.
        return [
            ''.join([self.character[i] for i in text_index[row, :]])
            for row, _ in enumerate(length)
        ]
|
|
|
|
|
|
class Averager(object):
    """Compute average for torch.Tensor, used for loss average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget everything accumulated so far."""
        self.sum = 0
        self.n_count = 0

    def add(self, v):
        """Accumulate the element count and the element sum of ``v.data``."""
        n_items = v.data.numel()
        total = v.data.sum()
        self.n_count += n_items
        self.sum += total

    def val(self):
        """Return the running mean, or 0 when nothing has been added."""
        if self.n_count == 0:
            return 0
        return self.sum / float(self.n_count)
|