
Commit

BUG
JingqingZ committed Mar 15, 2019
1 parent 552b8c5 commit b9bf256
Showing 2 changed files with 223 additions and 5 deletions.
221 changes: 218 additions & 3 deletions src_reject/error.py
@@ -1119,7 +1119,7 @@ def error_overall_with_rejector():
unseen_filename = "../results/unseen_full_zhang15_dbpedia_kg3_cluster_3group_%s_random%d_unseen%s_max%d_cnn_negative%dincrease%d_randomtext_aug%d/logs/test_full_%d.npz" \
% (config.model, i + 1, "-".join(str(_) for _ in rgroup[1]), 80, config.negative_sample, config.negative_increase, config.augmentation, config.global_test_base_epoch)
seen_filename = "../results/seen_full_zhang15_dbpedia_vwonly_random%d_unseen%s_max%d_cnn/logs/test_full_%d.npz" \
-                % (i + 1, "-".join(str(_) for _ in rgroup[1]), 50, 2)
+                % (i + 1, "-".join(str(_) for _ in rgroup[1]), 50, 1)
elif config.dataset == "dbpedia" and config.unseen_rate == 0.5:
unseen_filename = "../results/unseen_full_zhang15_dbpedia_kg3_cluster_3group_%s_random%d_unseen%s_max%d_cnn_negative%dincrease%d_randomtext_aug%d/logs/test_full_%d.npz" \
% (config.model, i + 1, "-".join(str(_) for _ in rgroup[1]), 80, config.negative_sample, config.negative_increase, config.augmentation, config.global_test_base_epoch)
@@ -1194,13 +1194,17 @@ def error_overall_with_rejector():
c = np.argmax(pred_seen[seen_counter - 1])
pa += 1
elif reject_list[data_idx] == 1 and class_id in rgroup[1]:
-                c = rgroup[1][1] - 1
+                # FIXME:
+                # ORIGINAL: c = rgroup[1][1] - 1
+                c = rgroup[0][1] - 1
na += 1
elif reject_list[data_idx] == 0 and class_id in rgroup[1]:
c = np.argmax(pred_unseen[unseen_counter - 1])
nr += 1
elif reject_list[data_idx] == 0 and class_id in rgroup[0]:
-                c = rgroup[0][1] - 1
+                # FIXME:
+                # ORIGINAL: c = rgroup[0][1] - 1
+                c = rgroup[1][1] - 1
pr += 1
else:
raise Exception("invalid rejection")
@@ -1281,9 +1285,220 @@ def error_overall_with_rejector():
overall_stats["macro-F1"])
print(print_string)

def error_phase1_with_rejector():
import pickle

if config.dataset == "dbpedia":
random_group = dataloader.get_random_group(config.zhang15_dbpedia_class_random_group_path)
elif config.dataset == "20news":
random_group = dataloader.get_random_group(config.news20_class_random_group_path)
else:
raise Exception("invalid dataset")

with open(config.rejector_file, 'rb') as f:
full_reject_list = pickle.load(f)

seen_stats = dict()
unseen_stats = dict()
overall_stats = dict()
print_string = ""

seen_acc_list = list()
unseen_acc_list = list()
overall_acc_list = list()
for i, rgroup in enumerate(random_group):
reject_list = full_reject_list[i]

print(rgroup)
print(len(reject_list))

'''
if config.dataset == "dbpedia" and config.unseen_rate == 0.25:
# unseen_filename = "../results/unseen_full_zhang15_dbpedia_kg3_cluster_3group_random%d_unseen%s_max%d_cnn_negative%dincrease3_randomtext/logs/test_full_%d.npz" \
# % (i + 1, "-".join(str(_) for _ in rgroup[1]), 80, 5, 9)
unseen_filename = "../results/unseen_full_zhang15_dbpedia_kg3_cluster_3group_%s_random%d_unseen%s_max%d_cnn_negative%dincrease%d_randomtext_aug%d/logs/test_full_%d.npz" \
% (config.model, i + 1, "-".join(str(_) for _ in rgroup[1]), 80, config.negative_sample, config.negative_increase, config.augmentation, config.global_test_base_epoch)
seen_filename = "../results/seen_full_zhang15_dbpedia_vwonly_random%d_unseen%s_max%d_cnn/logs/test_full_%d.npz" \
% (i + 1, "-".join(str(_) for _ in rgroup[1]), 50, 2)
elif config.dataset == "dbpedia" and config.unseen_rate == 0.5:
unseen_filename = "../results/unseen_full_zhang15_dbpedia_kg3_cluster_3group_%s_random%d_unseen%s_max%d_cnn_negative%dincrease%d_randomtext_aug%d/logs/test_full_%d.npz" \
% (config.model, i + 1, "-".join(str(_) for _ in rgroup[1]), 80, config.negative_sample, config.negative_increase, config.augmentation, config.global_test_base_epoch)
seen_filename = "../results/seen_full_zhang15_dbpedia_vwonly_random%d_unseen%s_max%d_cnn/logs/test_full_%d.npz" \
% (i + 1, "-".join(str(_) for _ in rgroup[1]), 50, 2)
elif config.dataset == "20news" and config.unseen_rate == 0.25:
unseen_filename = "../results/unseen_selected_tfidf_news20_kg3_cluster_3group_only_random%d_unseen%s_max%d_cnn_negative%dincrease%d_randomtext/logs/test_full_%d.npz" \
% (i + 1, "-".join(str(_) for _ in rgroup[1]), 50, 1, 1, 1)
seen_filename = "../results/seen_selected_tfidf_news20_vwonly_random%d_unseen%s_max%d_cnn/logs/test_%d.npz" \
% (i + 1, "-".join(str(_) for _ in rgroup[1]), 200, 30)
elif config.dataset == "20news" and config.unseen_rate == 0.5:
unseen_filename = "../results/unseen_selected_tfidf_news20_kg3_cluster_3group_%s_random%d_unseen%s_max%d_cnn_negative%dincrease%d_randomtext_aug%d/logs/test_full_%d.npz" \
% (config.model, i + 1, "-".join(str(_) for _ in rgroup[1]), 50, config.negative_sample, config.negative_increase, config.augmentation, config.global_test_base_epoch)
seen_filename = "../results/seen_selected_tfidf_news20_vwonly_random%d_unseen%s_max%d_cnn/logs/test_%d.npz" \
% (i + 1, "-".join(str(_) for _ in rgroup[1]), 200, 10)
else:
raise Exception("exception")
unseen_data = np.load(unseen_filename)
seen_data = np.load(seen_filename)
pred_unseen = unseen_data["pred_unseen"]
pred_seen = seen_data["pred_seen"]
print(pred_unseen.shape)
print(pred_seen.shape)
assert pred_unseen.shape[1] == pred_seen.shape[1]
assert pred_unseen.shape[1] == len(rgroup[0]) + len(rgroup[1])
'''

test_class_list = dataloader.load_data_class(
filename=config.zhang15_dbpedia_test_path if config.dataset == "dbpedia" else config.news20_test_path,
column="class",
)

print(len(test_class_list))
assert len(test_class_list) == len(reject_list)

seen_counter = 0
unseen_counter = 0

pred_overall_seen = list()
gt_overall_seen = list()
pred_overall_unseen = list()
gt_overall_unseen = list()

pa = pr = na = nr = 0
with progressbar.ProgressBar(max_value=len(test_class_list)) as bar:
for data_idx, class_id in enumerate(test_class_list):

'''
if class_id in rgroup[0]:
seen_counter += 1
if seen_counter > pred_seen.shape[0]:
continue
elif class_id in rgroup[1]:
unseen_counter += 1
if unseen_counter > pred_unseen.shape[0]:
continue
else:
raise Exception("invalid class id")
g = np.zeros(len(rgroup[0]) + len(rgroup[1]))
g[class_id - 1] = 1
if class_id in rgroup[0]:
gt_overall_seen.append(g)
else:
gt_overall_unseen.append(g)
'''


# r = np.zeros(len(rgroup[0]) + len(rgroup[1]))
if reject_list[data_idx] == 1 and class_id in rgroup[0]:
# c = np.argmax(pred_seen[seen_counter - 1])
pa += 1
elif reject_list[data_idx] == 1 and class_id in rgroup[1]:
# c = rgroup[1][1] - 1
na += 1
elif reject_list[data_idx] == 0 and class_id in rgroup[1]:
# c = np.argmax(pred_unseen[unseen_counter - 1])
nr += 1
elif reject_list[data_idx] == 0 and class_id in rgroup[0]:
# c = rgroup[0][1] - 1
pr += 1
else:
raise Exception("invalid rejection")

# r[c] = 1
# if class_id in rgroup[0]:
# pred_overall_seen.append(r)
# else:
# pred_overall_unseen.append(r)

bar.update(data_idx)

'''
pred_overall_seen = np.array(pred_overall_seen)
gt_overall_seen = np.array(gt_overall_seen)
pred_overall_unseen = np.array(pred_overall_unseen)
gt_overall_unseen = np.array(gt_overall_unseen)
pred_overall = np.concatenate([pred_overall_seen, pred_overall_unseen], axis=0)
gt_overall = np.concatenate([gt_overall_seen, gt_overall_unseen], axis=0)
classify_stats_seen = utils.get_statistics(
pred_overall_seen,
gt_overall_seen,
single_label_pred=True
)
classify_stats_unseen = utils.get_statistics(
pred_overall_unseen,
gt_overall_unseen,
single_label_pred=True
)
classify_stats_overall = utils.get_statistics(
pred_overall,
gt_overall,
single_label_pred=True
)
for k in classify_stats_seen:
v = classify_stats_seen[k]
if k not in seen_stats:
seen_stats[k] = list()
seen_stats[k].append(v)
for k in classify_stats_unseen:
v = classify_stats_unseen[k]
if k not in unseen_stats:
unseen_stats[k] = list()
unseen_stats[k].append(v)
for k in classify_stats_overall:
v = classify_stats_overall[k]
if k not in overall_stats:
overall_stats[k] = list()
overall_stats[k].append(v)
print("seen ", utils.dict_to_string_4_print(classify_stats_seen))
print("unseen ", utils.dict_to_string_4_print(classify_stats_unseen))
print("overal ", utils.dict_to_string_4_print(classify_stats_overall))
'''
print(pa, pr, na, nr)
seen_acc_list.append(pa / (pa + pr))
unseen_acc_list.append(nr / (na + nr))
overall_acc_list.append((pa + nr) / (pa + pr + na + nr))
print("----------")

print(np.mean(seen_acc_list))
print(np.mean(unseen_acc_list))
print(np.mean(overall_acc_list))

'''
for k in seen_stats:
seen_stats[k] = np.mean(seen_stats[k])
for k in unseen_stats:
unseen_stats[k] = np.mean(unseen_stats[k])
for k in overall_stats:
overall_stats[k] = np.mean(overall_stats[k])
print("=======")
print("seen: %s" % (utils.dict_to_string_4_print(seen_stats)))
print("unseen: %s" % (utils.dict_to_string_4_print(unseen_stats)))
print("overall: %s" % (utils.dict_to_string_4_print(overall_stats)))
print_string += "%.3f/%.3f/%.3f" \
% (1 - seen_stats["single-label-error"],
seen_stats["micro-F1"],
seen_stats["macro-F1"])
print_string += ",%.3f/%.3f/%.3f" \
% (1 - unseen_stats["single-label-error"],
unseen_stats["micro-F1"],
unseen_stats["macro-F1"])
print_string += ",%.3f/%.3f/%.3f" \
% (1 - overall_stats["single-label-error"],
overall_stats["micro-F1"],
overall_stats["macro-F1"])
print(print_string)
'''


if __name__ == "__main__":
error_overall_with_rejector()
# error_phase1_with_rejector()
# error_seen()
exit()
if config.model == "cnnfc":
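Note on the error.py changes above: the new error_phase1_with_rejector() scores only the rejector's routing decisions (a reject value of 1 appears to send an example to the seen-class classifier, 0 to the zero-shot branch), while the rgroup[0]/rgroup[1] swap in error_overall_with_rejector() makes the stand-in prediction for a misrouted example come from the opposite group, so it can no longer accidentally match the true label. Below is a minimal, self-contained sketch of the pa/pr/na/nr accounting; the class groups, labels, and rejector decisions are hypothetical placeholders (the real code loads reject_list from config.rejector_file and the labels via dataloader.load_data_class).

# Sketch of the rejector accounting used in error_phase1_with_rejector().
# Assumed convention (inferred from the diff): decision 1 = keep for the seen
# classifier, decision 0 = reject to the zero-shot (unseen) branch.
seen_classes = {1, 2, 3}               # stands in for rgroup[0]
unseen_classes = {4, 5}                # stands in for rgroup[1]
test_class_list = [1, 4, 2, 5, 3, 4]   # ground-truth class ids (placeholder)
reject_list = [1, 0, 1, 1, 0, 0]       # rejector decisions (placeholder)

pa = pr = na = nr = 0
for decision, class_id in zip(reject_list, test_class_list):
    if decision == 1 and class_id in seen_classes:
        pa += 1   # seen example correctly kept for the seen classifier
    elif decision == 1 and class_id in unseen_classes:
        na += 1   # unseen example wrongly kept
    elif decision == 0 and class_id in unseen_classes:
        nr += 1   # unseen example correctly rejected to the zero-shot branch
    elif decision == 0 and class_id in seen_classes:
        pr += 1   # seen example wrongly rejected
    else:
        raise Exception("invalid rejection")

seen_acc = pa / (pa + pr)                      # rejector accuracy on seen classes
unseen_acc = nr / (na + nr)                    # rejector accuracy on unseen classes
overall_acc = (pa + nr) / (pa + pr + na + nr)  # overall routing accuracy
print(pa, pr, na, nr)
print(seen_acc, unseen_acc, overall_acc)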
7 changes: 5 additions & 2 deletions src_reject/train_seen.py
@@ -455,6 +455,9 @@ def run_dbpedia():

for i, rgroup in enumerate(random_group):

+        if i + 1 < config.random_group_start_idx:
+            continue

# unseen_percentage = 0.0
max_length = 50

@@ -483,9 +486,9 @@ def run_dbpedia():
base_epoch=-1,
gpu_config=gpu_config
)
-        ctl.controller(train_text_seqs, train_class_list, test_text_seqs, test_class_list, train_epoch=2)
+        ctl.controller(train_text_seqs, train_class_list, test_text_seqs, test_class_list, train_epoch=1)
# ctl.controller4test(test_text_seqs, test_class_list, unseen_class_list, base_epoch=5)
-        ctl.controller4test(test_text_seqs, test_class_list, unseen_class_list=ctl.unseen_class, base_epoch=2)
+        ctl.controller4test(test_text_seqs, test_class_list, unseen_class_list=ctl.unseen_class, base_epoch=1)

ctl.sess.close()

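For reference, the run_dbpedia() change above adds a resume point: groups whose 1-based index is below config.random_group_start_idx are skipped, so an interrupted sweep over the random seen/unseen splits can be restarted partway without retraining finished groups. A tiny hypothetical illustration of that pattern (the group contents and start index below are placeholders, not the project's actual configuration):

# Hypothetical resume-by-group-index sketch; groups finished in an earlier run are skipped.
random_group = [(["A", "B"], ["C"]), (["D", "E"], ["F"]), (["G", "H"], ["I"])]
random_group_start_idx = 2   # 1-based index of the first group still to train

for i, rgroup in enumerate(random_group):
    if i + 1 < random_group_start_idx:
        continue   # group i + 1 was completed in a previous run
    print("training group", i + 1, "seen:", rgroup[0], "unseen:", rgroup[1])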
