Skip to content

Commit

Permalink
Add codes for topic translation
Browse files Browse the repository at this point in the history
  • Loading branch information
plkumjorn committed Mar 22, 2019
1 parent a40088d commit 2bcc64c
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 44 deletions.
42 changes: 21 additions & 21 deletions data/20-newsgroups/clean/classLabels20news.csv
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
ClassCode,ClassLabel,ConceptNet,Count,ClassDescription,Hierarchy
1,alt.atheism,atheism,799,the belief or theory that God does not exist,alt
2,comp.graphics,graphics,973,pictures produced by computers,computer
3,comp.os.ms-windows.misc,operating system,985,the software that tells the parts of a computer how to work together and what to do,computer;os;ms;windows
4,comp.sys.ibm.pc.hardware,ibm,982,ibm personal computer equipment,computer;system;pc;hardware
5,comp.sys.mac.hardware,mac,961,mac computer equipment,computer;system;hardware
6,comp.windows.x,windows,980,windows x,computer;x;
7,misc.forsale,sale,972,the process of selling goods or services for money,
8,rec.autos,auto,990,relating to cars,recreation
9,rec.motorcycles,motorcycle,994,a road vehicle that has two wheels and an engine and looks like a large heavy bicycle,recreation
10,rec.sport.baseball,baseball,994,a game played by two teams of nine players who get points by hitting a ball with a bat and then running around four bases,recreation;sport
11,rec.sport.hockey,hockey,999,a game played on grass by two teams of 11 players who try to score goals by hitting a ball with a curved stick called a hockey stick,recreation;sport
12,sci.crypt,crypt,991,the use of codes to put information on a website into a form that can only be read by users with permission,science
13,sci.electronics,electronics,981,using electricity and extremely small electrical parts such as microchips and transistors,science
14,sci.med,medical,990,relating to medicine and the treatment of injuries and diseases,science
15,sci.space,space,987,the whole of the universe outside the Earth’s atmosphere,science
16,soc.religion.christian,christian,997,the religion based on the teachings of Jesus Christ. Its followers worship in a church.,social
17,talk.politics.guns,gun,910,"a weapon that shoots bullets, for example a pistol or a rifle. You load a gun with ammunition and pull the trigger to use it",talk;politics
18,talk.politics.mideast,mideast,940,"the region of the world that consists of the countries east of the Mediterranean Sea and west of India. It includes Egypt, Jordan, Israel, Lebanon, Syria, Turkey, Iran, and Iraq.",talk;politics
19,talk.politics.misc,politics,775,the ideas and activities involved in getting power in a country or over a particular area of the world,talk
20,talk.religion.misc,religion,628,the belief in the existence of a god or gods,talk
ClassCode,ClassLabel,ConceptNet,Count,ClassDescription,Hierarchy,ClassWord
1,alt.atheism,atheism,799,the belief or theory that God does not exist,alt,atheism
2,comp.graphics,graphics,973,pictures produced by computers,computer,graphics
3,comp.os.ms-windows.misc,operating system,985,the software that tells the parts of a computer how to work together and what to do,computer;os;ms;windows,os
4,comp.sys.ibm.pc.hardware,ibm,982,ibm personal computer equipment,computer;system;pc;hardware,ibm
5,comp.sys.mac.hardware,mac,961,mac computer equipment,computer;system;hardware,mac
6,comp.windows.x,windows,980,windows x,computer;x;,windows
7,misc.forsale,sale,972,the process of selling goods or services for money,,sale
8,rec.autos,auto,990,relating to cars,recreation,auto
9,rec.motorcycles,motorcycle,994,a road vehicle that has two wheels and an engine and looks like a large heavy bicycle,recreation,motorcycle
10,rec.sport.baseball,baseball,994,a game played by two teams of nine players who get points by hitting a ball with a bat and then running around four bases,recreation;sport,baseball
11,rec.sport.hockey,hockey,999,a game played on grass by two teams of 11 players who try to score goals by hitting a ball with a curved stick called a hockey stick,recreation;sport,hockey
12,sci.crypt,crypt,991,the use of codes to put information on a website into a form that can only be read by users with permission,science,cryptography
13,sci.electronics,electronics,981,using electricity and extremely small electrical parts such as microchips and transistors,science,electronics
14,sci.med,medical,990,relating to medicine and the treatment of injuries and diseases,science,medical
15,sci.space,space,987,the whole of the universe outside the Earth’s atmosphere,science,space
16,soc.religion.christian,christian,997,the religion based on the teachings of Jesus Christ. Its followers worship in a church.,social,christian
17,talk.politics.guns,gun,910,"a weapon that shoots bullets, for example a pistol or a rifle. You load a gun with ammunition and pull the trigger to use it",talk;politics,gun
18,talk.politics.mideast,mideast,940,"the region of the world that consists of the countries east of the Mediterranean Sea and west of India. It includes Egypt, Jordan, Israel, Lebanon, Syria, Turkey, Iran, and Iraq.",talk;politics,mideast
19,talk.politics.misc,politics,775,the ideas and activities involved in getting power in a country or over a particular area of the world,talk,politics
20,talk.religion.misc,religion,628,the belief in the existence of a god or gods,talk,religion
26 changes: 14 additions & 12 deletions src_reject/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,20 @@
import argparse

parser = argparse.ArgumentParser(description='configurations')
parser.add_argument("--data", type=str, required=True, help="dataset: dbpedia or 20news")
parser.add_argument("--unseen", type=float, required=True, help="unseen rate: 0.25 0.5 0.75")
# parser.add_argument("--aug", type=int, required=True, help="augmentation: 0 4000 8000 12000 16000 20000")
parser.add_argument("--model", type=str, required=True, help="model: vwvcvkg vwvc vwvkg vcvkg kgonly cnnfc rnnfc")
parser.add_argument("--data", type=str, required=False, help="dataset: dbpedia or 20news")
parser.add_argument("--unseen", type=float, required=False, help="unseen rate: 0.25 0.5 0.75")
# parser.add_argument("--aug", type=int, required=False, help="augmentation: 0 4000 8000 12000 16000 20000")
parser.add_argument("--model", type=str, required=False, help="model: vwvcvkg vwvc vwvkg vcvkg kgonly cnnfc rnnfc")
parser.add_argument("--ns", type=int, default=2, required=False, help="negative samples: integer, the ratio of positive and negative samples, the higher the more negative samples")
parser.add_argument("--ni", type=int, default=2, required=False, help="negative increase: integer, the speed of increasing negative samples during training per epoch")
parser.add_argument("--sepoch", type=int, required=True, help="small epoch: integer, repeat training of each epoch for several times so that the ratio of posi/negative, learning rate both keep the same")
parser.add_argument("--sepoch", type=int, required=False, help="small epoch: integer, repeat training of each epoch for several times so that the ratio of posi/negative, learning rate both keep the same")
parser.add_argument("--rgidx", type=int, default=1, required=False, help="random group starting index: e.g. if 5, the training will start from the 5th random group, by default 1")
parser.add_argument("--train", type=int, required=True, help="train or not")
parser.add_argument("--train", type=int, required=False, help="train or not")
parser.add_argument("--gpu", type=float, default=1.0, required=False, help="gpu occupation percentage")
parser.add_argument("--baseepoch", type=int, required=False, help="base epoch for testing")
parser.add_argument("--fulltest", type=int, required=False, help="full test or not")
parser.add_argument("--threshold", type=float, required=False, help="threshold for seen")
parser.add_argument("--nott", type=int, required=False, help="no. of original texts to be translated")
args = parser.parse_args()
print(args)

Expand Down Expand Up @@ -173,9 +174,9 @@
zhang15_dbpedia_train_path = zhang15_dbpedia_dir + "train.csv"
zhang15_dbpedia_train_processed_path = zhang15_dbpedia_dir + "processed_train_text.pkl"

# TODO by Peter: how to get augmented data
zhang15_dbpedia_train_aug_path = zhang15_dbpedia_dir + "train_augmented_aggregated.csv"
zhang15_dbpedia_train_aug_processed_path = zhang15_dbpedia_dir + "processed_train_aug_text.pkl"
zhang15_dbpedia_train_augmented_path = zhang15_dbpedia_dir + "train_augmented.csv"
zhang15_dbpedia_train_augmented_aggregated_path = zhang15_dbpedia_dir + "train_augmented_aggregated.csv"
zhang15_dbpedia_train_augmented_processed_path = zhang15_dbpedia_dir + "processed_train_augmented_text.pkl"

zhang15_dbpedia_test_path = zhang15_dbpedia_dir + "test.csv"
zhang15_dbpedia_test_processed_path = zhang15_dbpedia_dir + "processed_test_text.pkl"
Expand Down Expand Up @@ -285,9 +286,10 @@
news20_test_path = news20_dir + "test.csv"
news20_test_processed_path = news20_dir + "processed_test_text.pkl"

# TODO by Peter, how to get augmented data
news20_train_aug_path = news20_dir + "train_augmented.csv"
news20_train_aug_processed_path = news20_dir + "processed_train_aug_text.pkl"
news20_train_augmented_path = news20_dir + "train_augmented.csv"
news20_train_augmented_aggregated_path = news20_dir + "train_augmented_aggregated.csv"
news20_train_augmented_processed_path = news20_dir + "processed_train_augmented_text.pkl"


news20_vocab_path = news20_dir + "vocab.txt"

Expand Down
37 changes: 26 additions & 11 deletions src_reject/topic_translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,24 @@
from nltk.corpus import wordnet as wn
import language_check

# Raise the csv module's per-field size limit as far as this platform allows.
# csv.field_size_limit() can raise OverflowError when the requested value is
# too large for the platform, so start from sys.maxsize and shrink the
# candidate tenfold until the call is accepted.
maxInt = sys.maxsize
decrement = True

while decrement:
    decrement = False
    try:
        csv.field_size_limit(maxInt)
    except OverflowError:
        # Floor division keeps the value an exact int (no float round trip).
        maxInt //= 10
        decrement = True

lemmatizer = WordNetLemmatizer()
stop_words = set(stopwords.words('english'))


glove_file = datapath(config.word_embed_file_path)
tmp_file = get_tmpfile(config.word_embed_gensim_file_path)
_ = glove2word2vec(glove_file, tmp_file)
glove_model = KeyedVectors.load_word2vec_format(tmp_file)
if not os.path.isfile(config.word_embed_gensim_file_path):
_ = glove2word2vec(config.word_embed_file_path, config.word_embed_gensim_file_path)
glove_model = KeyedVectors.load_word2vec_format(config.word_embed_gensim_file_path)

tool = language_check.LanguageTool('en-US')

Expand Down Expand Up @@ -95,7 +104,7 @@ def topic_transfer(text, from_class, to_class):

return ans_sentence

def augment_train(class_label_path, train_augmented_path, train_path):
def augment_train(class_label_path, train_augmented_path, train_path, nott):
global POS_OF_WORD, WORD_TOPIC_TRANSLATION
if os.path.isfile(config.POS_OF_WORD_path):
POS_OF_WORD = pickle.load(open(config.POS_OF_WORD_path, "rb"))
Expand All @@ -106,7 +115,7 @@ def augment_train(class_label_path, train_augmented_path, train_path):
class_dict = dataloader.load_class_dict(
class_file=class_label_path,
class_code_column="ClassCode",
class_name_column="ConceptNet"
class_name_column="ClassWord"
)

fieldnames = ['No.','from_class', 'to_class', 'text']
Expand All @@ -119,18 +128,17 @@ def augment_train(class_label_path, train_augmented_path, train_path):
reader = csv.DictReader(csvfile)
rows = list(reader)
random.shuffle(rows)
if nott is not None: # no. of texts to be translated
rows = rows[:min(nott, len(rows))]
count = 0
with progressbar.ProgressBar(max_value=len(rows)) as bar:
for idx, row in enumerate(rows):
text = row['text']
class_id = int(row['class'])
class_name = class_dict[class_id]
# print('---Original----', class_name)
# print(text)

for cidx in class_dict:
if cidx != int(row['class']):
# print('---------------', class_dict[idx])
# print(topic_transfer(text, from_class = class_name, to_class = class_dict[idx]))
try:
writer.writerow({'No.':count, 'from_class': class_id, 'to_class': cidx, 'text':topic_transfer(text, from_class = class_name, to_class = class_dict[cidx])})
count += 1
Expand All @@ -143,4 +151,11 @@ def augment_train(class_label_path, train_augmented_path, train_path):
csvwritefile.close()

if __name__ == "__main__":
print(sys.argv)
print(config.dataset, config.args.nott)
if config.dataset == "dbpedia":
augment_train(config.zhang15_dbpedia_class_label_path, config.zhang15_dbpedia_train_augmented_aggregated_path, config.zhang15_dbpedia_train_path, config.args.nott)
elif config.dataset == "20news":
augment_train(config.news20_class_label_path, config.news20_train_augmented_aggregated_path, config.news20_train_path, config.args.nott)
else:
raise Exception("config.dataset %s not found" % config.dataset)
pass

0 comments on commit 2bcc64c

Please sign in to comment.