Created
February 20, 2023 07:30
-
-
Save Tobias-Fischer/4cde91f12531e9cbb862f0ee88793cb0 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import matplotlib.pyplot as plt | |
def createPR(S_in, GThard, GTsoft=None, n_thresh=3):
    """Plot two precision-recall curves for a similarity matrix.

    The similarity matrix is thresholded at ``n_thresh`` evenly spaced values
    running from its maximum down to its minimum. At every threshold,
    precision and recall are computed in two different ways:

    * Per column (red curve): a column with no entry at or above the
      threshold counts as "no match returned" (a false negative if the
      ground truth has a match in that column). A column with at least one
      such entry counts as a true positive only when its thresholded entries
      equal the ground-truth column exactly, and as a false positive
      otherwise.
    * Per element (blue curve, the ``*_Stefan`` variables): every matrix
      entry is counted individually as true positive, false positive or
      false negative.

    Thresholds at which the per-column counts leave precision or recall
    undefined are skipped for both curves. Each thresholded matrix and the
    final precision/recall vectors are printed, and the figure is displayed
    with ``plt.show()``. Nothing is returned.

    Args:
        S_in: 2-D similarity matrix; an entry is a match candidate when it is
            greater than or equal to the current threshold. The input is
            copied and never modified. Columns are the units evaluated by the
            per-column metric (presumably one column per query -- confirm
            against the caller).
        GThard: Ground-truth matrix with the same shape as ``S_in``; it is
            cast to boolean.
        GTsoft: Optional soft ground-truth matrix. Entries set in ``GTsoft``
            but not in ``GThard`` have their similarity replaced by the
            matrix minimum, so they can only pass the lowest threshold.
        n_thresh: Number of thresholds to evaluate (default 3, the value that
            was previously hard-coded).
    """
    GT = GThard.astype('bool')  # ensure logical datatype
    S = S_in.copy()  # work on a copy so the caller's matrix stays untouched

    if GTsoft is not None:
        # Cast to bool like GThard: with an integer GTsoft the expression
        # below would be an integer array, which NumPy interprets as row
        # indices instead of a mask.
        soft_only = GTsoft.astype('bool') & ~GT
        S[soft_only] = S.min()

    # Precision and recall vectors; both curves start at (R=0, P=1).
    R = [0.0]
    P = [1.0]
    R_Stefan = [0.0]
    P_Stefan = [1.0]

    # Sweep the threshold from the highest to the lowest similarity.
    for thresh in np.linspace(S.max(), S.min(), n_thresh):
        B = S >= thresh  # apply threshold
        print(B)

        # Per-column counts. True negatives are not tracked because neither
        # precision nor recall uses them.
        TP, FP, FN = 0, 0, 0
        for q in range(S.shape[1]):
            if not B[:, q].any():
                # Nothing exceeded the threshold: no match returned.
                if GT[:, q].any():
                    FN += 1
            elif np.all(GT[:, q] == B[:, q]):
                # A match was returned and it agrees with the ground truth
                # in every row of this column.
                TP += 1
            else:
                FP += 1

        # Element-wise counts over the whole matrix.
        TP_Stefan = np.count_nonzero(GT & B)  # true positives
        FN_Stefan = np.count_nonzero(GT & (~B))  # false negatives
        FP_Stefan = np.count_nonzero((~GT) & B)  # false positives

        # Skip thresholds where per-column precision or recall is undefined.
        # When both denominators are non-zero, B and GT each contain at
        # least one True entry, so the element-wise divisions are safe too.
        if TP + FP == 0 or TP + FN == 0:
            continue

        P.append(TP / (TP + FP))  # precision
        R.append(TP / (TP + FN))  # recall
        P_Stefan.append(TP_Stefan / (TP_Stefan + FP_Stefan))  # precision
        R_Stefan.append(TP_Stefan / (TP_Stefan + FN_Stefan))  # recall

    print('R', R)
    print('P', P)
    print('R_Stefan', R_Stefan)
    print('P_Stefan', P_Stefan)

    plt.plot(R, P, 'r')  # per-column curve
    plt.plot(R_Stefan, P_Stefan, 'b')  # element-wise curve
    plt.show()
# Toy example with a 5 x 5 similarity matrix (presumably rows are database
# entries and columns are queries, going by the names below).
num_DB = 5
num_Q = 5

# Ground truth: entry i of one axis matches entry i of the other, i.e. a
# boolean identity matrix.
GT = np.eye(num_DB, num_Q, dtype=bool)

# Similarity scores. Columns 0 and 1 have their only non-zero score on the
# diagonal (1.0 and 0.5); columns 2, 3 and 4 have their only non-zero score
# off the diagonal, i.e. where the ground truth is False.
S = np.array([
    [1.0, 0.0, 0.0, 0.0, 0.0],
    [0.0, 0.5, 0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.0, 0.5],
    [0.0, 0.0, 0.0, 0.5, 0.0],
])

createPR(S, GT)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment