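# Word spotting experiment on scanned document pages: dense SIFT descriptors are extracted
# over a corpus image, quantized into visual words (k-means or a precomputed codebook),
# aggregated into per-word spatial-pyramid histograms and queried against each other,
# optionally through an inverted file structure (IFS) and a vote accumulator. Retrieval
# quality is reported as mean recall and mean average precision against the .gtp
# ground-truth annotations.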
import argparse
import collections
import itertools
import os
import sys
import time

try:
    import cPickle as pickle
except ImportError:
    import pickle

import PIL.Image
import numpy as np
import vlfeat
import scipy.cluster.vq
import scipy.spatial.distance
import matplotlib.pyplot as plt

PAGES_PATH = os.path.join('data', 'pages')
GT_PATH = os.path.join('data', 'GT')
IFS_MATCH_IMAGES_PATH = os.path.join('ifs_match_images')
MATCH_IMAGES_PATH = os.path.join('match_images')
CODE_BOOK_PATH = os.path.join('data', 'codebook.bin')
SPATIAL_PYRAMID_TYPES = ['L', 'R', 'G', 'GL', 'GR', 'LR', 'GLR']
CELL_MARGIN_TYPES = ['none', 'horizontal', 'vertical', 'both']

argument_parser = argparse.ArgumentParser() | |
argument_parser.add_argument('--step-size', '-s', type=int, default=15) | |
argument_parser.add_argument('--cell-size', '-c', type=int, default=3) | |
argument_parser.add_argument('--centroids', '-C', type=int, default=40) | |
argument_parser.add_argument('--k-means-iterations', '-k', type=int, default=20) | |
argument_parser.add_argument('--distance-metric', choices=['cityblock', 'cosine', 'euclidean'], default='cosine') | |
argument_parser.add_argument('--spatial-pyramid-type', '-S', choices=SPATIAL_PYRAMID_TYPES, default='LR') | |
argument_parser.add_argument('--pages', '-p', type=int, default=1) | |
argument_parser.add_argument('--accumulator-percentile', '-a', type=float, default=95.0) | |
argument_parser.add_argument('--use-ifs', '-I', action='store_true') | |
argument_parser.add_argument('--use-accumulator', '-A', action='store_true') | |
argument_parser.add_argument('--save-images', '-sa', action='store_true') | |
argument_parser.add_argument('--verbose', action='store_true') | |
argument_parser.add_argument('--cell-margin', choices=CELL_MARGIN_TYPES, default='horizontal') | |
SpatialPyramid = collections.namedtuple('SpatialPyramid', ['global_', 'left', 'right']) | |
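
# os.makedirs in Python 2 has no exist_ok parameter, so its Python 3 behaviour is backported here.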
def makedirs(name, mode=0777, exist_ok=False):
    if not exist_ok:
        return os.makedirs(name, mode)

    # Taken from Python 3
    try:
        os.makedirs(name, mode)
    except OSError:
        # Cannot rely on checking for EEXIST, since the operating system
        # could give priority to other errors like EACCES or EROFS
        if not exist_ok or not os.path.isdir(name):
            raise

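
# A .gtp ground-truth file lists one word annotation per line: "x1 y1 x2 y2 word",
# with the bounding box given in page pixel coordinates.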
def load_gtp_file(path):
    entries = collections.defaultdict(list)
    with open(path) as file:
        for line in file:
            line = line.strip()
            if not line:
                continue
            x1, y1, x2, y2, word = line.split()
            entries[word].append((int(x1), int(y1), int(x2), int(y2)))

    return entries

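
# The precomputed codebook is a flat binary file of 4096 SIFT centroids, 128 float32 values each.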
def load_codebook(path):
    with open(path, 'rb') as input_file:
        code_book = np.fromfile(input_file, dtype='float32')
    code_book = np.reshape(code_book, (4096, 128))
    return code_book

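
# A SpatialPyramid concatenates visual-word occurrence histograms (np.bincount over the
# quantized labels) for the whole word (G), its left half (L) and its right half (R);
# type_ selects which of the three levels actually contribute counts.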
def make_spatial_pyramid(data, length, type_='GLR'):
    count = len(data)
    left_index = int(np.floor(count / 2))
    right_index = int(np.ceil(count / 2))

    if type_ == 'L':
        data = [], data[:left_index], []
    elif type_ == 'R':
        data = [], [], data[right_index:]
    elif type_ == 'G':
        data = data, [], []
    elif type_ == 'GL':
        data = data, data[:left_index], []
    elif type_ == 'GR':
        data = data, [], data[right_index:]
    elif type_ == 'LR':
        data = [], data[:left_index], data[right_index:]
    elif type_ == 'GLR':
        data = data, data[:left_index], data[right_index:]
    else:
        raise ValueError('unknown spatial pyramid type: %r' % type_)

    spatial_pyramid = SpatialPyramid(*(np.bincount(datum, minlength=length) for datum in data))
    return np.concatenate(spatial_pyramid)

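
# All requested pages are concatenated horizontally into a single corpus image; the
# ground-truth boxes of each page are shifted by that page's x offset so they index
# directly into the corpus image.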
def load_corpus(page_names):
    defaultdict_factory = lambda: collections.defaultdict(defaultdict_factory)
    corpus = collections.defaultdict(defaultdict_factory)

    offset = 0
    images = []
    corpus_gtp = collections.defaultdict(list)
    for page_name in page_names:
        corpus['pages'][page_name]['offset'] = offset

        # Load page image
        image_path = os.path.join(PAGES_PATH, '%s.png' % page_name)
        corpus['pages'][page_name]['image_path'] = image_path
        image = PIL.Image.open(image_path)
        corpus['pages'][page_name]['image'] = image
        images.append(image)

        # Load page GTP
        gtp_path = os.path.join(GT_PATH, '%s.gtp' % page_name)
        corpus['pages'][page_name]['gtp_path'] = gtp_path
        gtp = load_gtp_file(gtp_path)
        corpus['pages'][page_name]['gtp'] = gtp

        # Update global corpus GTP with current offset
        for word, coordinates in gtp.items():
            for x1, y1, x2, y2 in coordinates:
                corpus_gtp[word].append((x1 + offset, y1, x2 + offset, y2))

        offset += image.width

    # Create corpus image by concatenating page images horizontally
    width = sum(image.width for image in images)
    max_height = max(image.height for image in images)
    corpus_image = PIL.Image.new(images[0].mode, (width, max_height))
    x_offset = 0
    for image in images:
        corpus_image.paste(image, (x_offset, 0))
        x_offset += image.width

    corpus['gtp'] = corpus_gtp
    corpus['image'] = corpus_image
    corpus['data'] = np.array(corpus_image, dtype='float32')
    return corpus

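
# Sweeps the accumulator percentile from 0 to 100 in steps of 5 (with the IFS and the
# accumulator forced on), then plots mean average precision and runtime per setting.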
def pre_main(arguments):
    load_codebook(CODE_BOOK_PATH)
    page_names = [os.path.splitext(filename)[0] for filename in sorted(os.listdir(PAGES_PATH))[:arguments.pages]]
    corpus = load_corpus(page_names)

    results1 = collections.OrderedDict()
    results2 = collections.OrderedDict()
    for accumulator_percentile in range(0, 105, 5):
        print accumulator_percentile
        arguments.use_ifs = True
        arguments.use_accumulator = True
        arguments.accumulator_percentile = accumulator_percentile
        start = time.time()
        mean_average_precision = main(arguments, corpus)
        duration = int(time.time() - start)
        results1[accumulator_percentile] = mean_average_precision
        results2[accumulator_percentile] = duration

    plt.plot(range(len(results1)), results1.values(), 'o')
    plt.xlabel('Accumulator Percentile')
    plt.ylabel('Mean Average Precision')
    plt.xticks(range(len(results1)), results1.keys())
    plt.grid(True)
    plt.ylim(0, 1)
    plt.tight_layout()
    plt.show()

    plt.plot(range(len(results2)), results2.values(), 'o')
    plt.xlabel('Accumulator Percentile')
    plt.ylabel('Runtime (seconds)')
    plt.xticks(range(len(results2)), results2.keys())
    plt.grid(True)
    plt.tight_layout()
    plt.show()

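
# Retrieval pipeline: dense SIFT over the corpus image, per-word descriptor selection via
# the ground-truth boxes, quantization into visual words, spatial-pyramid histograms per
# word occurrence, optional IFS / accumulator candidate filtering, distance-based ranking,
# and average precision / recall per query.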
def main(arguments, corpus):
    # Calculate dense SIFT frames and descriptors for the corpus image (normalized to [0, 1])
    frames, descriptors = vlfeat.vl_dsift(
        corpus['data'] / corpus['data'].max(), step=arguments.step_size, size=arguments.cell_size,
        fast=True, float_descriptors=True)
    # Find all frames and descriptors contained inside word boundaries (minus a cell margin of cell_size * 2)
    cell_margin = 2 * arguments.cell_size
    words_frames = []
    words_descriptors = []
    previous_frame_index = 0
    word_data_indices = collections.OrderedDict()
    word_coordinates = collections.OrderedDict()
    for word, coordinates in corpus['gtp'].items():
        # Filter word frames within word bounding box
        for variation, (x1, y1, x2, y2) in enumerate(coordinates):
            if arguments.cell_margin == 'none':
                mask = (
                    (frames[:, 0] >= x1) & (frames[:, 1] >= y1) &
                    (frames[:, 0] <= x2) & (frames[:, 1] <= y2))
            elif arguments.cell_margin == 'horizontal':
                mask = (
                    (frames[:, 0] >= x1 + cell_margin) & (frames[:, 1] >= y1) &
                    (frames[:, 0] <= x2 - cell_margin) & (frames[:, 1] <= y2))
            elif arguments.cell_margin == 'vertical':
                mask = (
                    (frames[:, 0] >= x1) & (frames[:, 1] >= y1 + cell_margin) &
                    (frames[:, 0] <= x2) & (frames[:, 1] <= y2 - cell_margin))
            elif arguments.cell_margin == 'both':
                mask = (
                    (frames[:, 0] >= x1 + cell_margin) & (frames[:, 1] >= y1 + cell_margin) &
                    (frames[:, 0] <= x2 - cell_margin) & (frames[:, 1] <= y2 - cell_margin))
            else:
                raise ValueError('unknown cell margin type: %r' % arguments.cell_margin)

            # Get matching frames/descriptors for the word
            word_frames = frames[mask]
            words_frames.append(word_frames)
            words_descriptors.append(descriptors[mask])

            # Count how many frames are contained inside the bounding box
            frame_count = word_frames.shape[0]

            # Note at which index and how many (following) frames/descriptors are part of a word
            key = word, variation
            word_data_indices[key] = previous_frame_index, frame_count
            word_coordinates[key] = x1, y1, x2, y2
            previous_frame_index += frame_count

    words_frames = np.concatenate(words_frames)
    words_descriptors = np.concatenate(words_descriptors)
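
    # Quantize descriptors into visual words: either assign them to the precomputed
    # 4096-entry codebook, or cluster them from scratch with k-means.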
    if arguments.centroids == 4096:
        code_book = load_codebook(CODE_BOOK_PATH)
        labels, _ = scipy.cluster.vq.vq(words_descriptors, code_book)
    else:
        # Calculate labels
        _, labels = scipy.cluster.vq.kmeans2(
            words_descriptors, arguments.centroids, iter=arguments.k_means_iterations, minit='points')

    # Word -> labels mapping
    # noinspection PyArgumentList
    word_labels = collections.OrderedDict(
        (key, labels[start:start + length]) for key, (start, length) in word_data_indices.items())

    # Create (word, variation) -> spatial pyramid mapping
    # noinspection PyArgumentList
    spatial_pyramids = collections.OrderedDict(
        (key, make_spatial_pyramid(labels, arguments.centroids, arguments.spatial_pyramid_type))
        for key, labels in word_labels.items())
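
    # Inverted file structure: for every spatial-pyramid bin, record the set of word indices
    # whose pyramid has a non-zero count in that bin, so a query only has to consider words
    # that share at least one non-zero bin with it.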
    # Create IFS database
    ifs_height = len(spatial_pyramids.values()[0])
    ifs = [set() for count in range(ifs_height)]
    for word_index, spatial_pyramid in enumerate(spatial_pyramids.values()):
        for index, count in enumerate(spatial_pyramid):
            if count:
                ifs[index].add(word_index)

    # Map each word to the set of word indices of all its occurrences (variations)
    word_variation_indices = collections.defaultdict(set)
    for word_index, (word, variation) in enumerate(spatial_pyramids.keys()):
        word_variation_indices[word].add(word_index)
    # Run each word occurrence as a query against all others
    spatial_pyramids_values = spatial_pyramids.values()
    word_coordinates_values = word_coordinates.values()
    average_precisions = []
    average_recalls = []
    for word_index, ((word, variation), query) in enumerate(spatial_pyramids.items()):
        # Skip words with no findable duplicates in the IFS database
        appearances = len(word_variation_indices[word]) - 1
        if not appearances:
            if arguments.verbose:
                print >> sys.stderr, 'No duplicate appearances for (%s, %d)!' % (word, variation)
            continue

        if arguments.use_ifs:
            ifs_candidate_indices = list(itertools.chain(*(ifs[index] for index, count in enumerate(query) if count)))
            candidate_indices = set(ifs_candidate_indices)
            if not candidate_indices:
                if arguments.verbose:
                    print >> sys.stderr, 'No candidates for (%s, %d) after IFS!' % (word, variation)
                average_precisions.append(0)
                average_recalls.append(0)
                continue

            if arguments.use_accumulator:
                # Count how many non-zero query bins each candidate shares with the query
                # noinspection PyArgumentList
                accumulator = collections.Counter(ifs_candidate_indices)

                # No candidates left after applying the IFS
                if not accumulator:
                    if arguments.verbose:
                        print >> sys.stderr, 'No candidates for (%s, %d) after IFS + Accumulator!' % (word, variation)
                    average_precisions.append(0)
                    average_recalls.append(0)
                    continue

                # Keep only candidates whose vote count reaches the requested percentile
                most_common = accumulator.most_common()
                rankings = sorted(set(accumulator.values()))
                percentile_ranking = rankings[max(0, int(len(rankings) * arguments.accumulator_percentile / 100.0) - 1)]
                candidate_indices = set(
                    index for index, count in
                    itertools.takewhile(lambda item: item[1] >= percentile_ranking, most_common))
        else:
            candidate_indices = set(range(len(spatial_pyramids)))

        # The query itself is never a candidate
        candidate_indices -= {word_index}
        if not candidate_indices:
            if arguments.verbose:
                print >> sys.stderr, 'No candidates for (%s, %d)' % (word, variation)
            average_precisions.append(0)
            average_recalls.append(0)
            continue
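
        # Rank the remaining candidates by the distance between their spatial pyramids and the query's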
        candidate_pyramids = np.array([spatial_pyramids_values[index] for index in candidate_indices])
        query = query.reshape((1, query.shape[0]))
        distances = scipy.spatial.distance.cdist(query, candidate_pyramids, metric=arguments.distance_metric)[0]

        # Translate index in distance array to index of candidate
        distances_indices = range(distances.shape[0])
        distance_index_to_candidate_index = {
            distance_index: candidate_index for distance_index, candidate_index in
            zip(distances_indices, candidate_indices)}
        distances_sorted_indices = np.argsort(distances)
        sorted_candidate_indices = [
            distance_index_to_candidate_index[distance_index] for distance_index in distances_sorted_indices]

        hits = [1 if index in word_variation_indices[word] else 0 for index in sorted_candidate_indices]
        true_positives = sum(hits)
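
        # Average precision: sum of precision-at-k over the ranks where a hit occurs,
        # divided by the number of relevant occurrences of the word (appearances)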
        # Calculate accumulated hits at index
        hits_at_k = []
        current_hits = 0
        for hit in hits:
            if hit:
                current_hits += 1
            hits_at_k.append(current_hits)

        average_precision = sum(
            (current_hits / float(index)) * hit for index, (hit, current_hits) in
            enumerate(zip(hits, hits_at_k), start=1)) / float(appearances)
        average_precisions.append(average_precision)
        average_recalls.append(true_positives / float(appearances))
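
        # With --save-images, write the query crop and its ranked candidate crops to
        # match_images/<word>_<variation>/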
        if arguments.save_images:
            match_images_path = os.path.join(MATCH_IMAGES_PATH, '%s_%d' % (word, variation))
            makedirs(match_images_path, exist_ok=True)
            coordinates = word_coordinates_values[word_index]
            corpus['image'].crop(coordinates).save(os.path.join(match_images_path, '0_original.png'))
            for rank, candidate_word_index in enumerate(sorted_candidate_indices, start=1):
                coordinates = word_coordinates_values[candidate_word_index]
                path = os.path.join(match_images_path, 'candidate_%d.png' % rank)
                corpus['image'].crop(coordinates).save(path)

        # print 'Word %s (Variation: %d): %.2f%%' % (word, variation, average_precision * 100)

    print 'Mean Recall: %f' % (np.mean(average_recalls) * 100)
    mean_average_precision = np.mean(average_precisions)
    print 'Mean Average Precision: %f' % (mean_average_precision * 100)
    return mean_average_precision

if __name__ == '__main__':
    arguments = argument_parser.parse_args()
    pre_main(arguments)
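
# Example invocation (the script name below is only a placeholder; as written, pre_main
# sweeps the accumulator percentile over the first --pages pages):
#   python word_spotting.py --pages 2 --centroids 40 --save-images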