import sys
import os
sys.path.append(os.path.join('..'))

import csv
import re

import gensim.downloader as api
from gensim.utils import simple_preprocess
import numpy as np


class MyWikiIterable:
    """Reads a WikiQA-style .tsv file into memory."""

    def __init__(self, fpath):
        # self.type_translator = {'query': 0, 'doc': 1, 'label': 2}
        # self.iter_type = iter_type
        with open(fpath, encoding='utf8') as tsv_file:
            tsv_reader = csv.reader(tsv_file, delimiter='\t', quotechar='"', quoting=csv.QUOTE_NONE)
            self.data_rows = list(tsv_reader)
        self.to_print = ""

    def preprocess_sent(self, sent):
        """Utility function to lower, strip and tokenize each sentence.

        Replace this function if you want to handle preprocessing differently."""
        return re.sub("[^a-zA-Z0-9]", " ", sent.strip().lower()).split()
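    # For example, preprocess_sent("How old is U.S.?") returns
    # ['how', 'old', 'is', 'u', 's']: punctuation is replaced by spaces,
    # so abbreviations split into single letters.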

    def get_stuff(self):
        """Parse the .tsv into parallel lists of queries, candidate documents,
        labels and ids, dropping queries that have no relevant answer."""
        # Defining some constants for .tsv reading. WikiQA columns are:
        # QuestionID, Question, DocumentID, DocumentTitle, SentenceID, Sentence, Label
        QUESTION_ID_INDEX = 0
        QUESTION_INDEX = 1
        ANSWER_INDEX = 5
        ANSWER_ID_INDEX = 4
        LABEL_INDEX = 6

        document_group = []
        label_group = []

        n_relevant_docs = 0
        n_filtered_docs = 0

        query_ids = []
        doc_ids = []
        doc_id_group = []

        queries = []
        docs = []
        labels = []

        # Skip the header row; group consecutive rows that share a QuestionID
        # into one set of candidate answers for that query.
        for i in range(1, len(self.data_rows)):
            # Every row contributes one candidate answer to the current group.
            document_group.append(self.preprocess_sent(self.data_rows[i][ANSWER_INDEX]))
            doc_ids.append(self.data_rows[i][ANSWER_ID_INDEX])
            label_group.append(int(self.data_rows[i][LABEL_INDEX]))
            n_relevant_docs += int(self.data_rows[i][LABEL_INDEX])

            # Flush the group when the next row belongs to a new question or
            # when we are on the last row.
            is_last_row = i == len(self.data_rows) - 1
            if is_last_row or self.data_rows[i][QUESTION_ID_INDEX] != self.data_rows[i + 1][QUESTION_ID_INDEX]:
                # Keep the query only if at least one candidate is relevant.
                if n_relevant_docs > 0:
                    docs.append(document_group)
                    labels.append(label_group)
                    queries.append(self.preprocess_sent(self.data_rows[i][QUESTION_INDEX]))
                    query_ids.append(self.data_rows[i][QUESTION_ID_INDEX])
                    doc_id_group.append(doc_ids)
                    # yield [queries[-1], document_group, label_group, query_ids, doc_ids]
                else:
                    n_filtered_docs += 1

                n_relevant_docs = 0
                document_group = []
                label_group = []
                doc_ids = []

        return queries, docs, labels, query_ids, doc_id_group
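

# Each returned list is parallel: one entry per retained query -- a token list
# in `queries`, a list of candidate token lists in `docs`, a list of 0/1 ints
# in `labels`, plus the matching ids. The prints below are a quick sanity
# check that all five lengths agree.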
queries, doc_group, label_group, query_ids, doc_id_group = MyWikiIterable(
    os.path.join('..', 'experimental_data', 'WikiQACorpus', 'WikiQA-train.tsv')).get_stuff()

print(len(queries))
print(len(doc_group))
print(len(label_group))
print(len(query_ids))
print(len(doc_id_group))

# print(query_ids)
# print(queries)
# print(queries[0], '\n', docs[0], '\n', labels[0], '\n', query_ids[0], '\n', doc_ids[0])
# exit()

# for q, doc, labels, q_id, d_ids in data:
#     for d, l, d_id in zip(doc, labels, d_ids):


def print_qrels(fname):
    """Write the ground-truth relevance judgements as a TREC qrels file."""
    with open(fname, 'w') as f:
        for q, doc, labels, q_id, d_ids in zip(queries, doc_group, label_group, query_ids, doc_id_group):
            for d, l, d_id in zip(doc, labels, d_ids):
                # print(q_id + '\t' + '0' + '\t' + str(d_id) + '\t' + str(l) + '\n')
                f.write(q_id + '\t' + '0' + '\t' + str(d_id) + '\t' + str(l) + '\n')
    print("QRELS done")


def print_my_pred(fname, similarity_fn):
    """Write a TREC run file scoring each candidate with averaged GloVe vectors.

    NOTE: `similarity_fn` is accepted but unused; scores are computed inline
    with the local sent2vec/cosine_similarity helpers."""
    del_kv_model = api.load('glove-wiki-gigaword-300')
    dim_size = del_kv_model.vector_size
    # api.load already returns a KeyedVectors instance; the earlier `.wv`
    # indirection was a deprecated no-op in gensim 3.x and is gone in 4.x.
    kv_model = del_kv_model
    del del_kv_model

    def sent2vec(sent):
        """Average the word vectors of a tokenized sentence; fall back to a
        random vector if the sentence is empty or fully out of vocabulary."""
        if len(sent) == 0:
            print('length is 0, Returning random')
            return np.random.random((dim_size,))
        vec = []
        for word in sent:
            if word in kv_model:
                vec.append(kv_model[word])
        if len(vec) == 0:
            print('No words in vocab, Returning random')
            return np.random.random((kv_model.vector_size,))
        return np.mean(np.array(vec), axis=0)

    def cosine_similarity(vec1, vec2):
        return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))

    i = 0
    with open(fname, 'w') as f:
        for q, doc, labels, q_id, d_ids in zip(queries, doc_group, label_group, query_ids, doc_id_group):
            for d, l, d_id in zip(doc, labels, d_ids):
                my_score = str(cosine_similarity(sent2vec(q), sent2vec(d)))
                # TREC run format: query id, 'Q0', doc id, rank (constant here), score, run tag
                f.write(q_id + '\t' + 'Q0' + '\t' + str(d_id) + '\t' + '99' + '\t' + my_score + '\t' + 'STANDARD' + '\n')
            print(i, "done")
            i += 1


def w2v_similarity_fn(q, d):
    # NOTE: sent2vec and cosine_similarity live inside print_my_pred, so this
    # module-level helper would raise NameError if called; it is only passed
    # along as the (currently unused) similarity_fn argument.
    return cosine_similarity(sent2vec(q), sent2vec(d))


print_qrels('my_test_qrels')

# import cProfile
print_my_pred('my_test_pred_w2v_300', w2v_similarity_fn)
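
# The run file can then be scored against the qrels with a standard evaluation
# tool. A minimal sketch, assuming trec_eval is installed and on the PATH:
#   trec_eval -m map -m recip_rank my_test_qrels my_test_pred_w2v_300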

# print(sent2vec(simple_preprocess("asdasdasdasd helloasdasd"), kv_model))

# Alternative scorer that ranks candidates with a trained DRMM_TKS model
# instead of averaged GloVe vectors:
# def print_my_pred(fname, similarity_fn):
#     # del_kv_model = api.load('glove-wiki-gigaword-300')
#     from drmm_tks import DRMM_TKS
#     dtks_model = DRMM_TKS.load('new_hope_dtks_unk_zero_no_normalize_50topk_12ep')
#     i = 0
#     with open(fname, 'w') as f:
#         for q, doc, labels, q_id, d_ids in zip(queries, doc_group, label_group, query_ids, doc_id_group):
#             for d, l, d_id in zip(doc, labels, d_ids):
#                 my_score = dtks_model.predict([q], [[d]])
#                 f.write(q_id + '\t' + 'Q0' + '\t' + str(d_id) + '\t' + '99' + '\t' + str(my_score[0][0]) + '\t' + 'STANDARD' + '\n')
#             print(i, "done")
#             i += 1