Skip to content

Commit

Permalink
Save result files from Phase 1
Browse files Browse the repository at this point in the history
  • Loading branch information
plkumjorn committed Mar 23, 2019
1 parent c087496 commit 2277916
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 11 deletions.
12 changes: 3 additions & 9 deletions src_reject/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,17 +117,11 @@
POS_OF_WORD_path = "../data/POS_OF_WORD.pickle"
WORD_TOPIC_TRANSLATION_path = "../data/WORD_TOPIC_TRANSLATION.pickle"

# TODO by Peter: how to get these rejector files
if dataset == "dbpedia" and unseen_rate == 0.25:
rejector_file = "./dbpedia_unseen0.25_augmented12000.pickle"
elif dataset == "dbpedia" and unseen_rate == 0.5:
rejector_file = "./dbpedia_unseen0.50_augmented8000.pickle"
elif dataset == "20news" and unseen_rate == 0.25:
rejector_file = "./20news_unseen0.25_augmented4000.pickle"
elif dataset == "20news" and unseen_rate == 0.5:
rejector_file = "./20news_unseen0.50_augmented3000.pickle"
if dataset in ["dbpedia", "20news"] and unseen_rate in [0.25, 0.5, 0.75]:
rejector_file = "../results/%s_unseen%.2f_augmented%d.pickle" % (dataset, unseen_rate, args.naug)
else:
rejector_file = None



##################################
Expand Down
4 changes: 4 additions & 0 deletions src_reject/train_reject.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,7 @@ def print_summary_of_all_iterations(iteration_statistics):
iteration_statistics = []
num_epoch = config.args.nepoch
unseen_percentage = config.unseen_rate
pass_to_phase2 = []

for i in range(config.args.rgidx-1, 10):
seen_classes, unseen_classes = random_set[i]['seen'], random_set[i]['unseen']
Expand Down Expand Up @@ -609,6 +610,7 @@ def print_summary_of_all_iterations(iteration_statistics):
else:
gt_accepted.append(0)
accepted_stats = utils.get_precision_recall_f1(np.array(accepted), np.array(gt_accepted), with_confusion_matrix = True)
pass_to_phase2.append(accepted)

avg_classifier_stat = dict()
for key in stat_list[0]:
Expand All @@ -633,6 +635,8 @@ def print_summary_of_all_iterations(iteration_statistics):
iteration_statistics=iteration_statistics
)

pickle.dump(pass_to_phase2, results_path + "%s_unseen%.2f_augmented0.pickle" % (dataset_name, unseen_percentage))

pass


Expand Down
5 changes: 3 additions & 2 deletions src_reject/train_reject_augmented.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,7 @@ def print_summary_of_all_iterations(iteration_statistics):
num_epoch = config.args.nepoch
unseen_percentage = config.unseen_rate
num_augmented = config.args.naug

pass_to_phase2 = []
for i in range(config.args.rgidx-1, 10):
seen_classes, unseen_classes = random_set[i]['seen'], random_set[i]['unseen']

Expand Down Expand Up @@ -779,6 +779,7 @@ def print_summary_of_all_iterations(iteration_statistics):
else:
gt_accepted.append(0)
accepted_stats = utils.get_precision_recall_f1(np.array(accepted), np.array(gt_accepted), with_confusion_matrix = True)
pass_to_phase2.append(accepted)

avg_classifier_stat = dict()
for key in stat_list[0]:
Expand Down Expand Up @@ -807,7 +808,7 @@ def print_summary_of_all_iterations(iteration_statistics):
iteration_statistics=iteration_statistics
)


pickle.dump(pass_to_phase2, results_path + "%s_unseen%.2f_augmented%d.pickle" % (dataset_name, unseen_percentage, num_augmented))

pass

Expand Down

0 comments on commit 2277916

Please sign in to comment.