Skip to content

Commit

Permalink
minor update
Browse files Browse the repository at this point in the history
  • Loading branch information
JingqingZ committed Mar 29, 2019
1 parent c4065cb commit fe8e753
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 28 deletions.
7 changes: 3 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Integrating Semantic Knowledge to Tackle Zero-shot Text Classification. NAACL-HLT 2019. (Accepted)
# Integrating Semantic Knowledge to Tackle Zero-shot Text Classification. NAACL-HLT 2019. Oral (Accepted)
### Jingqing Zhang, Piyawat Lertvittayakumjorn, Yike Guo

##### Jingqing and Piyawat contributed equally to this project.
Expand Down Expand Up @@ -195,11 +195,10 @@ The arguments of the command represent
We would like to thank Douglas McIlwraith, Nontawat Charoenphakdee,
and three anonymous reviewers for helpful suggestions. Jingqing and
Piyawat would also like to acknowledge the support from
[LexisNexis HPCC Systems Academic Program] and Anandamahidol
LexisNexis® Risk Solutions HPCC Systems® academic program
and Anandamahidol
Foundation, respectively.

[LexisNexis HPCC Systems Academic Program]: https://hpccsystems.com/community/academics

<h2 id="Citation">Citation</h2>
Pending

Expand Down
2 changes: 1 addition & 1 deletion data/20-newsgroups/clean/classLabels20news.csv
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
ClassCode,ClassLabel,ConceptNet,Count,ClassDescription,Hierarchy,ClassWord
1,alt.atheism,atheism,799,the belief or theory that God does not exist,alt,atheism
2,comp.graphics,graphics,973,pictures produced by computers,computer,graphics
3,comp.os.ms-windows.misc,operating system,985,the software that tells the parts of a computer how to work together and what to do,computer;os;ms;windows,os
3,comp.os.ms-windows.misc,os,985,the software that tells the parts of a computer how to work together and what to do,computer;os;ms;windows,os
4,comp.sys.ibm.pc.hardware,ibm,982,ibm personal computer equipments,computer;system;pc;hardware,ibm
5,comp.sys.mac.hardware,mac,961,mac computer equipment,computer;system;hardware,mac
6,comp.windows.x,windows,980,windows x,computer;x;,windows
Expand Down
18 changes: 16 additions & 2 deletions src_reject/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,25 @@
POS_OF_WORD_path = "../data/POS_OF_WORD.pickle"
WORD_TOPIC_TRANSLATION_path = "../data/WORD_TOPIC_TRANSLATION.pickle"

'''
if dataset in ["dbpedia", "20news"] and unseen_rate in [0.25, 0.5, 0.75]:
rejector_file = "../results/%s_unseen%.2f_augmented%d.pickle" % (dataset, unseen_rate, args.naug)
rejector_file = "../results/%s_unseen%.2f_augmented%d.pickle" % (dataset, unseen_rate, args.naug)
else:
rejector_file = None

'''

if dataset == "dbpedia" and unseen_rate == 0.25:
rejector_file = "./dbpedia_unseen0.25_augmented12000.pickle"
elif dataset == "dbpedia" and unseen_rate == 0.5:
rejector_file = "./dbpedia_unseen0.50_augmented8000.pickle"
elif dataset == "20news" and unseen_rate == 0.25:
rejector_file = "./20news_unseen0.25_augmented4000.pickle"
elif dataset == "20news" and unseen_rate == 0.5:
rejector_file = "./20news_unseen0.50_augmented3000.pickle"
else:
rejector_file = None




##################################
Expand Down
36 changes: 18 additions & 18 deletions src_reject/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,7 +744,7 @@ def error_seen():
# "" if not config.global_full_test else "_full", epoch)
elif config.dataset == "20news":
filename = "../results/seen_selected_tfidf_news20_vwonly_random%d_unseen%s_max%d_cnn/logs/test_full_%d.npz" \
% (i + 1, "-".join(str(_) for _ in rgroup[1]), 200, 5)
% (i + 1, "-".join(str(_) for _ in rgroup[1]), 200, 20)
# filename = "../results/unseen_selected_tfidf_news20_kg3_cluster_3group_%s_random%d_unseen%s_max%d_cnn_negative%dincrease%d_randomtext_aug%d/logs/test%s_%d.npz" \
# % (config.model, i + 1, "-".join(str(_) for _ in rgroup[1]), 50, config.negative_sample, config.negative_increase, config.augmentation,
# "" if not config.global_full_test else "_full", epoch)
Expand All @@ -760,7 +760,7 @@ def error_seen():
if k not in overall_stats:
overall_stats[k] = list()
overall_stats[k].append(v)
print_string += "%.3f/%.3f/%.3f," \
print_string += "%.4f/%.4f/%.4f," \
% (1 - classify_stats["single-label-error"],
classify_stats["micro-F1"],
classify_stats["macro-F1"])
Expand All @@ -771,7 +771,7 @@ def error_seen():
print("=======")
print("overall: %s" % (utils.dict_to_string_4_print(overall_stats)))
print("for Google Sheets, split by comma")
print_string += "%.3f/%.3f/%.3f" \
print_string += "%.4f/%.4f/%.4f" \
% (1 - overall_stats["single-label-error"],
overall_stats["micro-F1"],
overall_stats["macro-F1"])
Expand Down Expand Up @@ -863,7 +863,7 @@ def error_unseen():
if k not in overall_stats:
overall_stats[k] = list()
overall_stats[k].append(v)
print_string += "%.3f/%.3f/%.3f," \
print_string += "%.4f/%.4f/%.4f," \
% (1 - classify_stats["single-label-error"],
classify_stats["micro-F1"],
classify_stats["macro-F1"])
Expand All @@ -876,7 +876,7 @@ def error_unseen():
if len(overall_stats) == 0:
continue
print("for Google Sheets, split by comma")
print_string += "%.3f/%.3f/%.3f" \
print_string += "%.4f/%.4f/%.4f" \
% (1 - overall_stats["single-label-error"],
overall_stats["micro-F1"],
overall_stats["macro-F1"])
Expand Down Expand Up @@ -910,7 +910,7 @@ def error_unseen_best():
if k not in overall_stats:
overall_stats[k] = list()
overall_stats[k].append(v)
print_string += "%.3f/%.3f/%.3f," \
print_string += "%.4f/%.4f/%.4f," \
% (1 - best_stats["single-label-error"],
best_stats["micro-F1"],
best_stats["macro-F1"])
Expand All @@ -921,7 +921,7 @@ def error_unseen_best():
print("=======")
print("overall: %s" % (utils.dict_to_string_4_print(overall_stats)))
# print("for Google Sheets, split by comma")
print_string += "%.3f/%.3f/%.3f" \
print_string += "%.4f/%.4f/%.4f" \
% (1 - overall_stats["single-label-error"],
overall_stats["micro-F1"],
overall_stats["macro-F1"])
Expand Down Expand Up @@ -1074,15 +1074,15 @@ def error_overall():
print(epoch, "seen: %s" % (utils.dict_to_string_4_print(seen_stats)))
print(epoch, "unseen: %s" % (utils.dict_to_string_4_print(unseen_stats)))
print(epoch, "overall: %s" % (utils.dict_to_string_4_print(overall_stats)))
print_string += "%.3f/%.3f/%.3f" \
print_string += "%.4f/%.4f/%.4f" \
% (1 - seen_stats["single-label-error"],
seen_stats["micro-F1"],
seen_stats["macro-F1"])
print_string += ",%.3f/%.3f/%.3f" \
print_string += ",%.4f/%.4f/%.4f" \
% (1 - unseen_stats["single-label-error"],
unseen_stats["micro-F1"],
unseen_stats["macro-F1"])
print_string += ",%.3f/%.3f/%.3f" \
print_string += ",%.4f/%.4f/%.4f" \
% (1 - overall_stats["single-label-error"],
overall_stats["micro-F1"],
overall_stats["macro-F1"])
Expand Down Expand Up @@ -1129,7 +1129,7 @@ def error_overall_with_rejector():
unseen_filename = "../results/unseen_selected_tfidf_news20_kg3_cluster_3group_only_random%d_unseen%s_max%d_cnn_negative%dincrease%d_randomtext/logs/test_full_%d.npz" \
% (i + 1, "-".join(str(_) for _ in rgroup[1]), 50, 1, 1, 1)
seen_filename = "../results/seen_selected_tfidf_news20_vwonly_random%d_unseen%s_max%d_cnn/logs/test_full_%d.npz" \
% (i + 1, "-".join(str(_) for _ in rgroup[1]), 200, 5)
% (i + 1, "-".join(str(_) for _ in rgroup[1]), 200, 20)
elif config.dataset == "20news" and config.unseen_rate == 0.5:
unseen_filename = "../results/unseen_selected_tfidf_news20_kg3_cluster_3group_%s_random%d_unseen%s_max%d_cnn_negative%dincrease%d_randomtext_aug%d/logs/test_full_%d.npz" \
% (config.model, i + 1, "-".join(str(_) for _ in rgroup[1]), 50, config.negative_sample, config.negative_increase, config.augmentation, config.global_test_base_epoch)
Expand Down Expand Up @@ -1271,15 +1271,15 @@ def error_overall_with_rejector():
print("seen: %s" % (utils.dict_to_string_4_print(seen_stats)))
print("unseen: %s" % (utils.dict_to_string_4_print(unseen_stats)))
print("overall: %s" % (utils.dict_to_string_4_print(overall_stats)))
print_string += "%.3f/%.3f/%.3f" \
print_string += "%.4f/%.4f/%.4f" \
% (1 - seen_stats["single-label-error"],
seen_stats["micro-F1"],
seen_stats["macro-F1"])
print_string += ",%.3f/%.3f/%.3f" \
print_string += ",%.4f/%.4f/%.4f" \
% (1 - unseen_stats["single-label-error"],
unseen_stats["micro-F1"],
unseen_stats["macro-F1"])
print_string += ",%.3f/%.3f/%.3f" \
print_string += ",%.4f/%.4f/%.4f" \
% (1 - overall_stats["single-label-error"],
overall_stats["micro-F1"],
overall_stats["macro-F1"])
Expand Down Expand Up @@ -1480,15 +1480,15 @@ def error_phase1_with_rejector():
print("seen: %s" % (utils.dict_to_string_4_print(seen_stats)))
print("unseen: %s" % (utils.dict_to_string_4_print(unseen_stats)))
print("overall: %s" % (utils.dict_to_string_4_print(overall_stats)))
print_string += "%.3f/%.3f/%.3f" \
print_string += "%.4f/%.4f/%.4f" \
% (1 - seen_stats["single-label-error"],
seen_stats["micro-F1"],
seen_stats["macro-F1"])
print_string += ",%.3f/%.3f/%.3f" \
print_string += ",%.4f/%.4f/%.4f" \
% (1 - unseen_stats["single-label-error"],
unseen_stats["micro-F1"],
unseen_stats["macro-F1"])
print_string += ",%.3f/%.3f/%.3f" \
print_string += ",%.4f/%.4f/%.4f" \
% (1 - overall_stats["single-label-error"],
overall_stats["micro-F1"],
overall_stats["macro-F1"])
Expand All @@ -1499,7 +1499,7 @@ def error_phase1_with_rejector():
if __name__ == "__main__":
error_overall_with_rejector()
# error_phase1_with_rejector()
# error_seen()
error_seen()
exit()
if config.model == "cnnfc":
error_unseen()
Expand Down
10 changes: 7 additions & 3 deletions src_reject/train_seen.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,11 @@ def run_20news():

for i, rgroup in enumerate(random_group):

# if i + 1 < config.random_group_start_idx:
# continue
if i + 1 != 6 and i + 1 != 7:
continue

max_length = 200

with tf.Graph().as_default() as graph:
Expand Down Expand Up @@ -598,11 +603,10 @@ def run_20news():
base_epoch=-1,
gpu_config=gpu_config
)
ctl.controller(train_text_seqs, train_class_list, test_text_seqs, test_class_list, train_epoch=5)
ctl.controller4test(test_text_seqs, test_class_list, unseen_class_list=ctl.unseen_class, base_epoch=5)
ctl.controller(train_text_seqs, train_class_list, test_text_seqs, test_class_list, train_epoch=20)
ctl.controller4test(test_text_seqs, test_class_list, unseen_class_list=ctl.unseen_class, base_epoch=20)

ctl.sess.close()
time.sleep(20)

def run_amazon():

Expand Down

0 comments on commit fe8e753

Please sign in to comment.