textrank++ demo
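The script below reads a tab-separated corpus (one document per line: doc_id, title, content, ground-truth term ranks), builds a word/sentence graph for each document, runs an iterative rank over it, and reports precision/recall/F1 at top 3/5/10 against the annotated term-importance levels. A hypothetical input line (terms and levels made up for illustration; the fourth field is a \001-separated list of term_level pairs, where level 0 = most important, 1 = secondary, 2 = unimportant):

doc42<TAB>some title<TAB>some body text ...<TAB>alpha_0\001beta_1\001gamma_2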
#! /usr/bin/env python
# -*- coding: utf-8 -*-
"""
textrank++ demo: rank the words and sentences of a document graph and
evaluate the top-ranked terms against annotated importance labels.
"""
import sys
import argparse

from tqdm import tqdm

# project-local helper modules
from text_utils.segmenter import SentenceSegmenter
from text_utils.tokenizer import Tokenizer
from text_utils.stopwords import is_stopword
from utils.metrics import evaluate, summary_metrics
from utils.client import Client
from graph import DocGraph
def main(args):
    # print('load stopwords')
    # load_stopwords(args.stopwords)
    print('loading text graph ...')
    segmenter = SentenceSegmenter(token_limits=args.max_sent_len)
    tokenizer = Tokenizer()

    # optional remote services for term weighting and sentence encoding
    select_client = None
    if args.select_ip:
        select_client = Client(args.select_ip, args.select_port)
    encode_client = None
    if args.encode_ip:
        encode_client = Client(args.encode_ip, args.encode_port)

    topn_results_1 = []
    topn_results_2 = []
    # sample_count = 0
    with open(args.input, 'r') as fin:
        for line in tqdm(fin):
            items = line.strip().split('\t')
            doc_id = items[0]
            title = items[1]
            content = items[2]
            # labels = items[3]
            # gt_rank holds "\001"-separated "<term>_<level>" pairs, where
            # level 0 = most important, 1 = secondary, 2 = unimportant
            gt_rank = items[3]
            gt_rank_list = gt_rank.split("\001")
            important_term_set = set([item.split("_")[0] for item in gt_rank_list if item.split("_")[1] == "0"])
            middle_term_set = set([item.split("_")[0] for item in gt_rank_list if item.split("_")[1] == "1"])
            unimportant_term_set = set([item.split("_")[0] for item in gt_rank_list if item.split("_")[1] == "2"])
            title_tokenized_str = ' '.join(tokenizer.tokenize(title))
            sent_tokenized_strs = []
            for i, sent in enumerate(segmenter.segment(content)):
                tokens = tokenizer.tokenize(sent)
                if len(tokens) >= 2:
                    sent_tokenized_strs.append(' '.join(tokens))

            weight_preds = None
            if select_client:
                weight_preds = select_client.predict([title_tokenized_str] + sent_tokenized_strs)
            sent_encodes = None
            if encode_client:
                sent_encodes = encode_client.predict([title_tokenized_str] + sent_tokenized_strs)

            doc_graph = DocGraph(doc_id=doc_id,
                                 title=title_tokenized_str,
                                 sentences=sent_tokenized_strs,
                                 term_weights=weight_preds,
                                 sent_encodes=sent_encodes,
                                 filter_term_func=lambda t: not is_stopword(t),
                                 sim_measure=args.sim_measure)
            ranks = doc_graph.rank(iters=args.maxiters, normalize=True)
            def _valid_token(token):
                is_valid = doc_graph.get_vertex(token).meta['vertex_type'] == 'word'
                if is_valid and args.only_title:
                    is_valid = (token in title)
                return is_valid

            sorted_pred_ranks = [key for key in sorted(ranks, key=ranks.get, reverse=True) if _valid_token(key)]
            result1 = evaluate([3, 5, 10], sorted_pred_ranks, important_term_set)
            result2 = evaluate([3, 5, 10], sorted_pred_ranks, important_term_set | middle_term_set)
            topn_results_1.append(result1)
            topn_results_2.append(result2)
            if args.manual_check:
                print('Title: %s' % title)
                print('Content: %s' % content)
                print('TOP-5 important sentences:')
                sorted_sents = [k for k in sorted(ranks, key=ranks.get, reverse=True) if doc_graph.get_vertex(k).meta['vertex_type'] == 'sentence']
                summary_sents = {}
                for t, k in enumerate(sorted_sents):
                    vertex = doc_graph.get_vertex(k)
                    idx = vertex.meta['vertex_index']
                    x = vertex.meta['vertex_str']
                    print('[%d] (%.2f) %s' % (idx, ranks[k], x))
                    summary_sents[idx] = x
                    if t >= 4:
                        break
                print('Static summary:')
                print(' '.join([summary_sents[k] for k in sorted(summary_sents)]))
                print('TOP-10 important words:')
                key_words = sorted_pred_ranks[:10]
                print(' '.join(['%s - %.2f' % (k, ranks[k]) for k in key_words]))
                print('Most important words: %s' % (' '.join(important_term_set)))
                print('Secondary important words: %s' % (' '.join(middle_term_set)))
                for n, p, r, f in result1:
                    print('TOP-%d: P: %.2f, R: %.2f, F1: %.2f' % (n, p, r, f))
                print('#############################')
    # aggregate metrics: strict (important terms only) and relaxed (important + secondary)
    precision_1, recall_1, f1_score_1 = summary_metrics(topn_results_1, verbose=True)
    precision_2, recall_2, f1_score_2 = summary_metrics(topn_results_2, verbose=True)
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("-input", type=str, help="the input file")
    parser.add_argument("-algo", type=str, help="the rank algorithm name")
    parser.add_argument("-sim_measure", type=str, help="the similarity measure", default="euclidean_sim")
    parser.add_argument("-stopwords", type=str, help="the stopwords file")
    parser.add_argument("-maxiters", type=int, help="the maximum number of iterations", default=150)
    # argparse's type=bool treats any non-empty string as True, so use store_true flags
    parser.add_argument("-only_title", action="store_true", help="only consider the title words")
    parser.add_argument("-select_port", type=int, help="the select-word service port", default=30311)
    parser.add_argument("-select_ip", type=str, help="the select-word service ip")
    parser.add_argument("-encode_port", type=int, help="the sentence-encode service port", default=30310)
    parser.add_argument("-encode_ip", type=str, help="the sentence-encode service ip")
    parser.add_argument("-manual_check", action="store_true", help="whether to print detailed results")
    parser.add_argument("-max_sent_len", type=int, help="the maximum sentence length", default=20)
    return parser.parse_args()
def usage():
    print("Usage: %s" % (sys.argv[0]), file=sys.stderr)


if __name__ == "__main__":
    args = parse_args()
    main(args)
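A minimal invocation sketch (the script name, input path, and service addresses are hypothetical; the term-weight and sentence-encode services are optional and skipped when their IPs are omitted):

python textrank_demo.py -input corpus.tsv -select_ip 10.0.0.1 -encode_ip 10.0.0.2 -manual_check

DocGraph is not included in the gist, so the similarity measures it actually supports are unknown; below is a self-contained sketch of two plausible measures over sentence encodings. euclidean_sim is named after the -sim_measure default, and cosine_sim is an added assumption, not necessarily the project's implementation:

import numpy as np

def euclidean_sim(u, v):
    # assumed form: map euclidean distance into a similarity in (0, 1]
    return 1.0 / (1.0 + float(np.linalg.norm(np.asarray(u) - np.asarray(v))))

def cosine_sim(u, v):
    # standard cosine similarity between two encoding vectors
    return float(np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v)))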