Created
January 19, 2017 18:36
-
-
Save bxshi/eecc1fe0aafac85e61be9c38190d5fd1 to your computer and use it in GitHub Desktop.
Calculate MeanRank and Hits@K using TensorFlow. From github.com/nddsg/ProjC (private repo right now)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def create_eval_ops(model_input, pred_y, all_triples, eval_triples, n_entity,
                    top_k, idx_1=0, idx_2=1, idx_3=2):
    """Build graph ops computing mean rank and Hits@top_k for link prediction.

    To predict t given <h,r>, use idx_1 = 0, idx_2 = 1, idx_3 = 2.
    To predict h given <t,r>, use idx_1 = 2, idx_2 = 1, idx_3 = 0.

    :param model_input: N by 3 matrix, each row is a h,r,t triple
    :param pred_y: N by ENTITY_VOCAB matrix of scores; row i holds the scores
                   for the query in model_input[i]
    :param all_triples: M by 3 matrix with every triple in the KG; used to
                        build the "filtered" ranking
    :param eval_triples: M_{eval} by 3 matrix of the triples to evaluate, a
                         subset of all_triples. model_input is a subset of
                         eval_triples in which the pair
                         (row[idx_1], row[idx_2]) is unique
    :param n_entity: Number of unique entities in the KG
    :param top_k: The k of Hits@k
    :param idx_1: Column of the first element of the query pair
    :param idx_2: Column of the second element of the query pair
    :param idx_3: Column of the target entity in a h,r,t triple
    :return: 4-element float32 tensor, averaged over eval_triples:
             [unfiltered mean rank, filtered mean rank,
              unfiltered Hits@top_k, filtered Hits@top_k]
    """

    def same_query(triple, candidates):
        # Boolean vector: which rows of `candidates` share the query pair
        # (columns idx_1 and idx_2) with `triple`.
        first_match = tf.equal(triple[idx_1], candidates[:, idx_1])
        second_match = tf.equal(triple[idx_2], candidates[:, idx_2])
        return tf.logical_and(first_match, second_match)

    def rank_of(candidate_scores, reference_score):
        # 1 + number of candidates scoring strictly above the reference, so
        # ties are resolved in favour of the reference.
        beats_reference = tf.greater(candidate_scores, reference_score)
        return tf.reduce_sum(tf.cast(beats_reference, tf.int32)) + 1

    def hit_at_k(rank):
        # 1 when the rank falls inside the top_k cut-off, otherwise 0.
        return tf.where(rank <= top_k, 1, 0)

    def metrics_for_triple(triple):
        # `triple` is a single 3-element h,r,t row taken from eval_triples.

        # Pick the pred_y row(s) belonging to this query and flatten them.
        # Assumes exactly one row of model_input matches, as the docstring
        # requires -- TODO confirm at the call site.
        row_selector = same_query(triple, model_input)
        scores = tf.reshape(tf.boolean_mask(pred_y, row_selector), [-1])

        # Score the model assigned to the true target entity.
        gold_score = scores[triple[idx_3]]

        # Dense boolean vector over entities marking every target known in
        # all_triples for this query. validate_indices=False also skips the
        # duplicate-index check.
        known_selector = same_query(triple, all_triples)
        is_known = tf.sparse_to_dense(
            tf.boolean_mask(all_triples[:, idx_3], known_selector),
            output_shape=[n_entity],
            sparse_values=True,
            default_value=False,
            validate_indices=False)

        # Filtered scores: each known target is replaced by -1e30 so it can
        # never count as outranking the gold entity. The gold score was read
        # before masking, so masking its own slot is harmless.
        # Assumes pred_y is float32 -- TODO confirm.
        kept_part = scores * tf.cast(tf.logical_not(is_known), tf.float32)
        penalty_part = tf.cast(is_known, tf.float32) * 1e30
        filtered_scores = kept_part - penalty_part

        raw_rank = rank_of(scores, gold_score)
        filtered_rank = rank_of(filtered_scores, gold_score)
        raw_hit = hit_at_k(raw_rank)
        filtered_hit = hit_at_k(filtered_rank)

        per_triple = (raw_rank, filtered_rank, raw_hit, filtered_hit)
        return tf.stack([tf.cast(value, tf.float32) for value in per_triple])

    # One 4-vector per evaluated triple, then average down the triple axis.
    per_triple_metrics = tf.map_fn(metrics_for_triple, eval_triples,
                                   dtype=tf.float32, parallel_iterations=20,
                                   back_prop=False, swap_memory=True)
    return tf.reduce_mean(per_triple_metrics, axis=0)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment