Created
May 27, 2019 19:44
-
-
Save jorgeramirez/9ef253e65505cc04e4b561417f03feff to your computer and use it in GitHub Desktop.
find_oracle_multi.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
import itertools | |
import gc | |
import math | |
import datetime | |
import os | |
import threading | |
import multiprocessing | |
import concurrent.futures | |
import time | |
from code.cnndm_acl18.PyRouge.Rouge.Rouge import Rouge | |
from code.cnndm_acl18.Document import Document | |
# Largest number of sentences a candidate combination may contain
# (solve_one stops growing the combination size past this).
MAX_COMB_L = 5
# Largest number of combinations solve_one will enumerate for one size.
MAX_COMB_NUM = 100000

# Module-wide scorer shared by solve_one(); solve() periodically resets
# its ``ngram_buf`` attribute.
rouge = Rouge(use_ngram_buf=True)
def c_n_x(n, x):
    """Return the binomial coefficient C(n, x) as an exact integer.

    Args:
        n: size of the set being chosen from.
        x: number of elements chosen.

    Returns:
        The number of ways to choose ``x`` items out of ``n``; ``0`` when
        ``x`` is negative or larger than ``n`` (no such selection exists).
    """
    if x < 0 or x > n:
        return 0
    # C(n, x) == C(n, n - x): iterate over the smaller of the two so the
    # running product stays as short as possible.
    x = min(x, n - x)
    res = 1
    for i in range(n, n - x, -1):
        res *= i
    # The product of x consecutive integers is divisible by x!, so every
    # floor division below is exact.
    for i in range(x, 0, -1):
        res //= i
    return res
def solve_one(document):
    """Search for the sentence combination that best matches the summary.

    Sentences whose 'rouge-2' recall against the summary is positive are
    kept as candidates. Combinations of growing size are then scored by
    'rouge-2' F1, and the best-scoring combination seen is returned.

    Args:
        document: object exposing ``doc_len``, ``summary_len``,
            ``doc_sents`` and ``summary_sents``; ``None`` is accepted and
            treated like an empty document.

    Returns:
        A ``(combination, score)`` pair where ``combination`` is a tuple of
        indices into ``document.doc_sents`` (or ``None`` when nothing scored
        above zero) and ``score`` is the matching 'rouge-2' F1 value.
    """
    # None marks a blank input line; there is nothing to score.
    if document is None or document.doc_len == 0 or document.summary_len == 0:
        return None, 0

    # Only sentences with a positive 'rouge-2' recall are worth combining.
    candidates = []
    for idx, sent in enumerate(document.doc_sents):
        scores = rouge.compute_rouge([document.summary_sents], [sent])
        if scores['rouge-2']['r'][0] > 0:
            candidates.append(idx)

    all_best_l = 1
    all_best_score = 0
    all_best_comb = None
    # The upper bound is inclusive so that a lone candidate, or the complete
    # candidate set, is also evaluated.
    for size in range(1, len(candidates) + 1):
        if size > MAX_COMB_L:
            print('Exceed MAX_COMB_L')
            break
        # c_n_x returns an exact int, so a plain comparison is sufficient.
        comb_num = c_n_x(len(candidates), size)
        if comb_num > MAX_COMB_NUM:
            print('Exceed MAX_COMB_NUM')
            break

        l_best_score = 0
        l_best_choice = None
        for comb in itertools.combinations(candidates, size):
            c_string = [document.doc_sents[idx] for idx in comb]
            rouge_scores = rouge.compute_rouge(
                [document.summary_sents], [c_string])
            rouge_bigram_f1 = rouge_scores['rouge-2']['f'][0]
            if rouge_bigram_f1 > l_best_score:
                l_best_score = rouge_bigram_f1
                l_best_choice = comb

        if l_best_score > all_best_score:
            all_best_l = size
            all_best_score = l_best_score
            all_best_comb = l_best_choice
        elif size > all_best_l:
            # A larger size that fails to improve on the best so far ends
            # the search.
            break
    return all_best_comb, all_best_score
def solve(documents, output_file):
    """Score every document and write one result line per document.

    Each output line has the form ``'<combination>\\t <score>'``, in the same
    order as ``documents``.

    Args:
        documents: iterable of document objects; ``None`` entries are
            written as ``'None\\t 0'`` without being scored.
        output_file: path of the UTF-8 text file to create or overwrite.
    """
    # buffering=1 selects line buffering, so each result is flushed as soon
    # as its line is complete; the ``with`` block closes the file even if
    # scoring raises.
    with open(output_file, 'w', encoding='utf-8', buffering=1) as writer:
        for idx, doc in enumerate(documents):
            if idx % 50 == 0:
                # Every 50 documents: log progress, drop the scorer's n-gram
                # buffer and force a garbage collection.
                print(datetime.datetime.now())
                rouge.ngram_buf = {}
                gc.collect()
            comb = (None, 0) if doc is None else solve_one(doc)
            writer.write('{0}\t {1}'.format(comb[0], comb[1]) + '\n')
def load_data(src_file, tgt_file):
    """Read parallel source/target files into a list of documents.

    The two files are consumed line by line in lockstep. Each line is
    stripped and split on the ``##SENT##`` marker, and the resulting sentence
    lists are handed to ``Document``. When either line of a pair is blank
    after stripping, ``None`` is stored instead so positions stay aligned
    with the input lines.

    Args:
        src_file: path to the UTF-8 source (document) file.
        tgt_file: path to the UTF-8 target (summary) file.

    Returns:
        A list with one ``Document`` or ``None`` per line pair.
    """
    separator = '##SENT##'
    loaded = []
    with open(src_file, 'r', encoding='utf-8') as src_reader, \
            open(tgt_file, 'r', encoding='utf-8') as tgt_reader:
        for raw_src, raw_tgt in zip(src_reader, tgt_reader):
            source_text = raw_src.strip()
            target_text = raw_tgt.strip()
            if source_text and target_text:
                loaded.append(Document(source_text.split(separator),
                                       target_text.split(separator)))
            else:
                loaded.append(None)
    return loaded
def main(src_file, tgt_file, outfile_name):
    """Load one source/target file pair and write its results.

    Args:
        src_file: path to the source (document) file.
        tgt_file: path to the target (summary) file.
        outfile_name: path of the result file produced by ``solve``.
    """
    solve(load_data(src_file, tgt_file), outfile_name)
def main_chunks(in_file, out_file, init, limit):
    """Process the numbered chunks ``init`` .. ``limit - 1`` one after another.

    For chunk ``i`` the inputs are ``<in_file>_<i:03d>.src.txt`` and
    ``<in_file>_<i:03d>.tgt.txt`` and the output is
    ``<out_file>_<i:03d>.oracle.txt``. Processing stops at the first chunk
    whose source file does not exist.

    Args:
        in_file: path prefix of the chunked input files.
        out_file: path prefix of the chunked output files.
        init: first chunk number (inclusive).
        limit: last chunk number (exclusive).
    """
    for chunk_no in range(init, limit):
        input_key = (in_file, chunk_no)
        src_file = "%s_%03d.src.txt" % input_key
        tgt_file = "%s_%03d.tgt.txt" % input_key
        fout = "%s_%03d.oracle.txt" % (out_file, chunk_no)
        print("src_file/tgt_file %s fout %s" % (src_file, fout))
        if not os.path.isfile(src_file):
            # A missing source file ends this worker's range early.
            print("input %s not found" % src_file)
            return
        main(src_file, tgt_file, fout)
def _main_multi(in_file, out_file, total_chunks):
    """Fan the numbered chunks out over worker processes and wait for them.

    ``cpu_count() + 1`` processes are started; worker ``i`` runs
    ``main_chunks`` over chunk numbers ``[i * steps, (i + 1) * steps)``.

    Args:
        in_file: path prefix of the chunked input files.
        out_file: path prefix of the chunked output files.
        total_chunks: number of chunks expected; only used to size the
            per-worker range.
    """
    n_cpu = multiprocessing.cpu_count()
    # One chunk per worker unless there are more chunks than cores.
    steps = 1
    if n_cpu < total_chunks:
        steps = (total_chunks // n_cpu) + 1

    processes = []
    # The extra (n_cpu + 1)-th worker extends the covered range by one more
    # stride; workers whose first source file is missing return immediately
    # from main_chunks.
    for i in range(n_cpu + 1):
        print(i)
        p = multiprocessing.Process(
            target=main_chunks,
            args=(in_file, out_file, i * steps, (i + 1) * steps))
        processes.append(p)
        p.start()
    for process in processes:
        process.join()
    print("%s done!" % in_file)
if __name__ == "__main__":
    # Expected arguments: input path prefix, output path prefix, chunk count.
    if len(sys.argv) < 4:
        # sys.exit with a string prints it to stderr and exits with status 1.
        sys.exit("usage: %s IN_PREFIX OUT_PREFIX TOTAL_CHUNKS" % sys.argv[0])
    _main_multi(sys.argv[1], sys.argv[2], int(sys.argv[3]))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment