Calculate per-class accuracy, precision, recall, specificity, and F1 score from a multi-class confusion matrix.
import numpy as np

n_classes = 5

# Confusion matrix: rows are the true classes, columns are the predicted classes.
cm = np.array([[10606.4167, 1464.2083,   61.9167,   11.75  ,  301.4167],
               [  610.5417, 2315.7917,  686.4167,   59.4583,  768.9167],
               [  152.7083, 2469.5833, 8752.2083, 1394.5   , 1248.5833],
               [    3.3333,   29.4167,  385.0833, 1920.375 ,   24.5   ],
               [  211.9167, 1287.5833,  614.7083,   71.875 , 2945.3333]])

# np.set_printoptions(suppress=True, precision=4)

for c in range(n_classes):
    # One-vs-rest counts for class c.
    tp = cm[c, c]                   # true positives: class c predicted as c
    fp = cm[:, c].sum() - tp        # false positives: other classes predicted as c
    fn = cm[c, :].sum() - tp        # false negatives: class c predicted as another class
    tn = cm.sum() - tp - fp - fn    # true negatives: everything else

    accuracy = (tp + tn) / (tp + tn + fp + fn)
    recall = tp / (tp + fn)
    precision = tp / (tp + fp)
    specificity = tn / (tn + fp)
    f1_score = 2 * (precision * recall) / (precision + recall)

    # print("tp:", tp); print("fp:", fp); print("fn:", fn); print("tn:", tn)
    print("for class {}: accuracy {}, recall {}, specificity {}, "
          "precision {}, f1 {}".format(c, round(accuracy, 4), round(recall, 4),
                                       round(specificity, 4), round(precision, 4),
                                       round(f1_score, 4)))