Created
May 27, 2019 19:43
-
-
Save jorgeramirez/052285d70fe131c3601b94c91788000d to your computer and use it in GitHub Desktop.
get_mmr_regression_gain_multi.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from ast import literal_eval as make_tuple | |
import random | |
import math | |
import sys | |
import itertools | |
import gc | |
import math | |
import datetime | |
import os | |
import threading | |
import multiprocessing | |
import concurrent.futures | |
import time | |
from code.cnndm_acl18.PyRouge.Rouge.Rouge import Rouge | |
from code.cnndm_acl18.Document import Document | |
# Module-wide ROUGE scorer shared by the scoring helpers below.
# ``use_ngram_buf=True`` is forwarded to the project's ``Rouge`` class;
# presumably it enables an n-gram cache (``main`` resets ``rouge.ngram_buf``
# to an empty dict from time to time) -- TODO confirm against PyRouge.
rouge = Rouge(use_ngram_buf=True)
def load_data(src_file, tgt_file):
    """Read two parallel text files into a list of ``Document`` objects.

    Both files are opened as UTF-8 and consumed line by line in lockstep.
    Each line is stripped and split on the literal marker ``##SENT##``.

    Args:
        src_file: Path of the file whose lines become the first
            ``Document`` argument.
        tgt_file: Path of the file whose lines become the second
            ``Document`` argument.

    Returns:
        A list with one entry per line pair: ``Document(src_sents,
        tgt_sents)``, or ``None`` when either stripped line is empty.
        Reading stops at the end of the shorter file.
    """
    separator = '##SENT##'
    loaded = []
    with open(src_file, 'r', encoding='utf-8') as src_reader:
        with open(tgt_file, 'r', encoding='utf-8') as tgt_reader:
            for raw_src, raw_tgt in zip(src_reader, tgt_reader):
                source_text = raw_src.strip()
                target_text = raw_tgt.strip()
                if source_text and target_text:
                    entry = Document(source_text.split(separator),
                                     target_text.split(separator))
                else:
                    # A blank side still yields a list entry, so the result
                    # keeps exactly one item per line pair read.
                    entry = None
                loaded.append(entry)
    return loaded
def load_upperbound(filepath):
    """Parse an oracle file into a list of ``(combination, score)`` pairs.

    Every line is stripped and split on tabs.  The first field is turned
    into a Python literal with ``make_tuple`` (``ast.literal_eval``) unless
    it contains the text ``None``; the second field is converted to float.

    Args:
        filepath: Path of the UTF-8 oracle file.

    Returns:
        A list of ``(comb, score)`` tuples in file order, where ``comb`` is
        ``None`` or the evaluated first field.
    """
    with open(filepath, 'r', encoding='utf-8') as reader:
        entries = []
        for raw_line in reader:
            fields = raw_line.strip().split('\t')
            comb_text = fields[0]
            # Substring test, not equality: any first field that contains
            # 'None' is treated as "no combination".
            comb = None if 'None' in comb_text else make_tuple(comb_text)
            entries.append((comb, float(fields[1])))
    return entries
def get_mmr_order(oracle, doc):
    """Greedily order the oracle's sentence indices by ROUGE-2 F score.

    Each index in ``oracle[0]`` is first scored on its own; the
    best-scoring one is selected and the rest are kept in descending score
    order.  Then, until none remain, every remaining index is scored
    together with the already selected ones and the best-scoring one is
    appended.  A score is the ``['rouge-2']['f'][0]`` entry returned by
    ``rouge.compute_rouge`` for ``doc.summary_sents`` and the chosen
    ``doc.doc_sents``.

    Args:
        oracle: Sequence whose first element is an iterable of sentence
            indices into ``doc.doc_sents``; other elements are ignored.
        doc: Object exposing ``doc_sents`` and ``summary_sents``.

    Returns:
        A tuple of the same indices in selection order.  An empty index
        collection yields ``()`` (previously this raised ``IndexError``).
    """
    indices = list(oracle[0])
    if not indices:
        # Nothing to order; avoids indexing into an empty ranking below.
        return ()

    def score(sent_ids):
        # Fresh argument lists per call, exactly as the scorer was called
        # before, in case compute_rouge keeps or mutates its inputs.
        picked = [doc.doc_sents[idx] for idx in sent_ids]
        result = rouge.compute_rouge([doc.summary_sents], [picked])
        return result['rouge-2']['f'][0]

    single_scores = [score([idx]) for idx in indices]
    # The stable sort keeps input order among equal scores.
    ranked = sorted(zip(indices, single_scores), key=lambda pair: -pair[1])
    selected = [ranked[0][0]]
    remaining = [idx for idx, _ in ranked[1:]]

    while remaining:
        joint_scores = [score(selected + [idx]) for idx in remaining]
        # max() returns the first maximal position, matching the previous
        # "stable sort descending, take the head" tie-breaking.
        best_pos = max(range(len(joint_scores)), key=joint_scores.__getitem__)
        selected.append(remaining.pop(best_pos))

    return tuple(selected)
def get_mmr_regression(oracle, doc):
    """Compute per-step ROUGE-2 F gains for every document sentence.

    For each index in ``oracle`` (in order), every sentence of
    ``doc.doc_sents`` is appended in turn to the sentences selected so far
    and scored with ``rouge.compute_rouge`` against ``doc.summary_sents``
    (entry ``['rouge-2']['f'][0]``).  The value recorded for a sentence is
    its score minus the best score of the previous step (0 for the first
    step).  The oracle sentence is then added to the selection.

    Args:
        oracle: Iterable of sentence indices into ``doc.doc_sents``.
        doc: Object exposing ``doc_sents`` and ``summary_sents``.

    Returns:
        ``(ids, text)`` where ``ids`` is the tuple of processed indices and
        ``text`` holds one tab-separated field per step, each field being
        the space-separated gains of all document sentences.
    """
    chosen_sents = []
    chosen_ids = []
    step_fields = []
    baseline = 0
    for sent_id in oracle:
        step_scores = []
        for sentence in doc.doc_sents:
            result = rouge.compute_rouge(
                [doc.summary_sents], [chosen_sents + [sentence]])
            step_scores.append(result['rouge-2']['f'][0])
        chosen_sents.append(doc.doc_sents[sent_id])
        chosen_ids.append(sent_id)
        step_fields.append(
            ' '.join(str(value - baseline) for value in step_scores))
        # The next step is measured against the best candidate of this
        # step, which is not necessarily the oracle sentence just added.
        baseline = max(step_scores)
    return tuple(chosen_ids), '\t'.join(step_fields)
def main(src_file, tgt_file, oracle_file, output_file):
    """Write ordered oracle selections and their ROUGE gains for one file set.

    Loads the documents and oracle entries, prints the mean score of all
    oracle entries that have a combination, then writes one line per
    (document, oracle) pair to ``output_file``:

    * ``None\\t0`` when the document or the oracle combination is missing;
    * otherwise ``<ordered tuple>\\t<gains>`` as produced by
      ``get_mmr_order`` followed by ``get_mmr_regression``.

    Args:
        src_file: Path passed to ``load_data`` as the source file.
        tgt_file: Path passed to ``load_data`` as the target file.
        oracle_file: Path passed to ``load_upperbound``.
        output_file: Path of the UTF-8 file to (over)write.
    """
    docs = load_data(src_file, tgt_file)
    oracles = load_upperbound(oracle_file)

    valid_scores = [score for comb, score in oracles if comb is not None]
    if valid_scores:
        upper_bound = sum(valid_scores) / len(valid_scores)
    else:
        # No usable oracle entry: report NaN instead of dividing by zero.
        upper_bound = float('nan')
    print('upper bound: {0}'.format(upper_bound))

    with open(output_file, 'w', encoding='utf-8') as writer:
        for count, (doc, oracle) in enumerate(zip(docs, oracles), start=1):
            if count % 100 == 0:
                print(count)
                # NOTE(review): assumed to belong under the progress check,
                # i.e. the scorer's n-gram buffer is dropped every 100
                # examples -- confirm against the original indentation.
                rouge.ngram_buf = {}
            # load_data yields None for blank lines; such a document cannot
            # be scored even if the oracle file has an entry for it.
            if doc is None or oracle[0] is None:
                writer.write('None\t0' + '\n')
                continue
            ordered = get_mmr_order(oracle, doc)
            ordered, rouge_scores = get_mmr_regression(ordered, doc)
            writer.write('{0}\t{1}'.format(ordered, rouge_scores) + '\n')
def main_chunks(in_file, out_file, init, limit):
    """Run ``main`` on the numbered chunk files ``init`` .. ``limit - 1``.

    For chunk ``i`` the inputs are ``<in_file>_<iii>.src.txt``,
    ``.tgt.txt`` and ``.oracle.txt`` and the output is
    ``<out_file>_<iii>.regain.txt`` (``iii`` is ``i`` zero-padded to three
    digits).  Processing stops at the first chunk whose source file does
    not exist; only the source file is checked.

    Args:
        in_file: Prefix of the input chunk files.
        out_file: Prefix of the output chunk files.
        init: First chunk number (inclusive).
        limit: Last chunk number (exclusive).
    """
    for chunk_id in range(init, limit):
        in_stem = (in_file, chunk_id)
        src_path = "%s_%03d.src.txt" % in_stem
        tgt_path = "%s_%03d.tgt.txt" % in_stem
        oracle_path = "%s_%03d.oracle.txt" % in_stem
        out_path = "%s_%03d.regain.txt" % (out_file, chunk_id)
        print("src_file/tgt_file/oracle %s fout %s" % (src_path, out_path))
        if os.path.isfile(src_path):
            main(src_path, tgt_path, oracle_path, out_path)
            continue
        # A missing chunk ends this worker's range early.
        print("input %s not found" % src_path)
        return
def _main_multi(in_file, out_file, total_chunks):
    """Process one unchunked file set, or fan chunk ranges out to processes.

    With ``total_chunks == 0`` the files ``<in_file>.src.txt``,
    ``.tgt.txt`` and ``.oracle.txt`` are processed in this process and the
    result goes to ``<out_file>.regain.txt``.  No worker processes are
    started in that case (previously the function fell through and spawned
    them anyway).

    Otherwise ``cpu_count() + 1`` worker processes are started.  Worker
    ``i`` runs ``main_chunks`` on chunk numbers ``[i * steps, (i + 1) *
    steps)``, where ``steps`` is 1, or ``total_chunks // n_cpu + 1`` when
    there are more chunks than CPUs.  The combined range can extend past
    ``total_chunks``; ``main_chunks`` stops at the first chunk whose source
    file is missing.

    Args:
        in_file: Input path prefix.
        out_file: Output path prefix.
        total_chunks: Number of chunk files, or 0 for a single unchunked
            file set.
    """
    if total_chunks == 0:
        # we do not multiprocess the file
        main(
            "%s.src.txt" % in_file,
            "%s.tgt.txt" % in_file,
            "%s.oracle.txt" % in_file,
            "%s.regain.txt" % out_file,
        )
        print("%s done!" % in_file)
        return

    n_cpu = multiprocessing.cpu_count()
    steps = 1
    if n_cpu < total_chunks:
        steps = (total_chunks // n_cpu) + 1

    processes = []
    for i in range(n_cpu + 1):
        print(i)
        worker = multiprocessing.Process(
            target=main_chunks,
            args=(in_file, out_file, i * steps, (i + 1) * steps),
        )
        processes.append(worker)
        worker.start()

    # Wait for every worker before reporting completion.
    for worker in processes:
        worker.join()
    print("%s done!" % in_file)
if __name__ == "__main__":
    # Command line: IN_PREFIX OUT_PREFIX TOTAL_CHUNKS
    # TOTAL_CHUNKS == 0 processes the single file set IN_PREFIX.src.txt /
    # .tgt.txt / .oracle.txt; any other value processes numbered chunks.
    #
    # For a one-off run on explicit paths, main() can be called directly:
    #   main(r"sample_data/train.txt.src.100",
    #        r"sample_data/train.txt.tgt.100",
    #        r"sample_data/train.rouge_bigram_F1.oracle.100",
    #        r"sample_data/train.rouge_bigram_F1.oracle.100.regGain")
    if len(sys.argv) < 4:
        # Fail with a usage hint instead of an IndexError traceback.
        sys.exit("usage: %s IN_PREFIX OUT_PREFIX TOTAL_CHUNKS" % sys.argv[0])
    _main_multi(sys.argv[1], sys.argv[2], int(sys.argv[3]))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment