Skip to content

Commit

Permalink
Merge utils
Browse files Browse the repository at this point in the history
  • Loading branch information
plkumjorn committed Apr 2, 2019
1 parent 16244c5 commit 744f2ed
Showing 1 changed file with 42 additions and 10 deletions.
52 changes: 42 additions & 10 deletions src_reject/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,25 +21,52 @@ def make_dirlist(dirlist):
if not os.path.exists(dir):
os.makedirs(dir)

def get_precision_recall_f1(prediction, ground_truth): # 1D data
def get_precision_recall_f1(prediction, ground_truth, with_confusion_matrix = False): # 1D data
# print(prediction.shape, ground_truth.shape, prediction.ndim)
assert prediction.shape == ground_truth.shape and prediction.ndim == 1
# print(prediction)
# print(ground_truth)
TP, FP, FN = 0, 0, 0
# TP, FP, FN = 0, 0, 0
# for i in range(len(prediction)):
# if prediction[i] == 1 and ground_truth[i] == 1:
# TP += 1
# elif prediction[i] == 1 and ground_truth[i] == 0:
# FP += 1
# elif prediction[i] == 0 and ground_truth[i] == 1:
# FN += 1
results = get_confusion_matrix(prediction, ground_truth)
TP, FP, FN = results['TP'], results['FP'], results['FN']
if (TP, FP, FN) == (0, 0, 0):
if with_confusion_matrix:
results['P'], results['R'], results['F1'] = None, None, None
return results
else:
return None, None, None
P = TP / (TP + FP) if TP + FP != 0 else 0
R = TP / (TP + FN) if TP + FN != 0 else 0
F1 = 2 * P * R / (P + R) if P + R != 0 else 0

if with_confusion_matrix:
results['P'], results['R'], results['F1'] = P, R, F1
return results
else:
return P, R, F1

def get_confusion_matrix(prediction, ground_truth):
TP, FP, FN, TN = 0, 0, 0, 0
for i in range(len(prediction)):
if prediction[i] == 1 and ground_truth[i] == 1:
TP += 1
elif prediction[i] == 1 and ground_truth[i] == 0:
FP += 1
elif prediction[i] == 0 and ground_truth[i] == 1:
FN += 1
if (TP, FP, FN) == (0, 0, 0):
return None, None, None
P = TP / (TP + FP) if TP + FP != 0 else 0
R = TP / (TP + FN) if TP + FN != 0 else 0
F1 = 2 * P * R / (P + R) if P + R != 0 else 0
return P, R, F1

elif prediction[i] == 0 and ground_truth[i] == 0:
TN += 1
else:
assert False
return {'TP':TP, 'FP':FP, 'FN':FN, 'TN':TN}

def get_statistics(prediction, ground_truth, single_label_pred=False):

num_data_of_class_gt = np.sum(ground_truth, axis=0)
Expand Down Expand Up @@ -101,7 +128,12 @@ def get_statistics(prediction, ground_truth, single_label_pred=False):

def dict_to_string_4_print(dict):
keys = sorted(dict.keys())
return ', '.join(['%s: %.3f' % (key, dict[key]) for key in keys])
if 'texts_accepted_from_class' in keys:
ans = 'Texts_accepted_from_class:' + str(dict['texts_accepted_from_class']) + '\n'
else:
ans = ''
return ans + ', '.join(['%s: %.3f' % (key, dict[key]) if dict[key] is not None else '%s: None' % (key) for key in keys if key != 'texts_accepted_from_class'])



if __name__ == "__main__":
Expand Down

0 comments on commit 744f2ed

Please sign in to comment.