Add codes for Phase 1 with data augmentation
plkumjorn committed Mar 22, 2019
1 parent 4be2a49 commit c087496
Showing 4 changed files with 842 additions and 3 deletions.
24 changes: 23 additions & 1 deletion README.md
@@ -115,6 +115,28 @@ python3 train_reject.py \
--train 1
```

- With data augmentation, an example:
```bash
python3 train_reject_augmented.py \
--data dbpedia \
--unseen 0.5 \
--model vw \
--nepoch 3 \
--rgidx 1 \
--naug 100 \
--train 1
```

The arguments of the command represent:
* `data`: Dataset, either `dbpedia` or `20news`.
* `unseen`: Rate of unseen classes, either `0.25` or `0.5`.
* `model`: The model to be trained. This argument can only be:
    * `vw`: the inputs are embeddings of words (from the text)
* `nepoch`: The number of epochs for training.
* `train`: In Phase 1, this argument does not affect the program. The program will run training and testing together.
* `rgidx`: Optional, the random group index to start from, e.g. `5` means training starts from the 5th random group; defaults to `1`. Use this argument to resume after the program is accidentally interrupted.
* `naug`: The number of augmented examples per unseen class (see the sketch below).
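
As a rough illustration of what `--naug` controls, the hypothetical snippet below keeps at most `naug` augmented rows per unseen class. The file name `full_augmented.csv` comes from `src_reject/config.py`, but the column names and the sampling strategy here are assumptions, not the repository's actual implementation.

```python
# Hypothetical illustration of --naug: keep at most `naug` augmented examples
# per unseen class. Column names and sampling strategy are assumed, not taken
# from train_reject_augmented.py.
import pandas as pd

naug = 100                   # value passed via --naug
unseen_classes = [3, 7]      # example unseen class ids for one random group

aug = pd.read_csv("full_augmented.csv", header=None,
                  names=["class", "title", "text"])

subset = (
    aug[aug["class"].isin(unseen_classes)]
       .groupby("class")
       .head(naug)           # at most naug augmented rows per unseen class
)
print(subset["class"].value_counts())
```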


### How to train / test the traditional classifier in Phase 2

@@ -134,7 +156,7 @@ The arguments of the command represent
* `model`: The model to be trained. This argument can only be:
    * `vw`: the inputs are embeddings of words (from the text)
* `sepoch`: Repeat the training of each epoch several times. The positive/negative sample ratio and the learning rate stay the same within an epoch no matter how many times that epoch is repeated (see the sketch after this list).
* `train`: In Phase 1, this argument does not affect the program. The program will run training and testing together.
* `train`: For the traditional classifier, this argument does not affect the program. The program will run training and testing together.
* `rgidx`: Optional, the random group index to start from, e.g. `5` means training starts from the 5th random group; defaults to `1`. Use this argument to resume after the program is accidentally interrupted.
* `gpu`: Optional, GPU occupation percentage, by default `1.0`, which means full occupation of available GPUs.
* `baseepoch`: Optional, specify which epoch to test.
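
As a rough sketch of how `--sepoch` is meant to behave (assumed here, not taken from the repository's code), repeating an epoch reuses that epoch's learning rate and positive/negative sampling ratio instead of advancing the per-epoch schedule:

```python
# Assumed sketch of --sepoch: every repetition of epoch `e` reuses the same
# learning rate and positive/negative ratio, so repetitions add extra updates
# without advancing the per-epoch schedule. The schedule itself is made up.
nepoch = 5   # --nepoch
sepoch = 2   # --sepoch

for e in range(nepoch):
    lr = 0.001 * (0.5 ** e)          # hypothetical per-epoch learning rate
    pos_neg_ratio = 1.0 / (e + 1)    # hypothetical per-epoch sampling ratio
    for rep in range(sepoch):        # repeat the epoch sepoch times
        # train_one_epoch(model, lr=lr, pos_neg_ratio=pos_neg_ratio)  # placeholder
        print(f"epoch {e}, repetition {rep}: lr={lr:.6f}, ratio={pos_neg_ratio:.2f}")
```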
2 changes: 2 additions & 0 deletions src_reject/config.py
@@ -17,6 +17,7 @@
parser.add_argument("--fulltest", type=int, required=False, help="full test or not")
parser.add_argument("--threshold", type=float, required=False, help="threshold for seen")
parser.add_argument("--nott", type=int, required=False, help="no. of original texts to be translated")
parser.add_argument("--naug", type=int, default = 100, required=False, help="no. of augmented data per unseen class")
args = parser.parse_args()
print(args)

@@ -171,6 +172,7 @@
zhang15_dbpedia_dir = zhang15_dir + "dbpedia_csv/"

zhang15_dbpedia_full_data_path = zhang15_dbpedia_dir + "full.csv"
zhang15_dbpedia_full_augmented_path = zhang15_dbpedia_dir + "full_augmented.csv"

zhang15_dbpedia_train_path = zhang15_dbpedia_dir + "train.csv"
zhang15_dbpedia_train_processed_path = zhang15_dbpedia_dir + "processed_train_text.pkl"
4 changes: 2 additions & 2 deletions src_reject/train_reject.py
@@ -557,9 +557,9 @@ def print_summary_of_all_iterations(iteration_statistics):
print(len(train_class_list), len(train_text_seqs), len(test_class_list), len(test_text_seqs))
iteration_statistics = []
num_epoch = config.args.nepoch

unseen_percentage = config.unseen_rate

for i in range(config.args.rgidx-1, 10):
unseen_percentage = config.unseen_rate
seen_classes, unseen_classes = random_set[i]['seen'], random_set[i]['unseen']
print(seen_classes, unseen_classes)
stat_list = []
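
For reference, the loop above iterates over the 10 random seen/unseen splits starting from index `rgidx - 1`, which is what makes `--rgidx` usable for resuming an interrupted run. A minimal standalone sketch of that resume pattern (the random groups here are placeholders, not the repository's data):

```python
# Minimal sketch of the --rgidx resume pattern seen in train_reject.py:
# start from the rgidx-th random group instead of the first one.
rgidx = 5                                             # --rgidx (1-based)
random_set = [{"seen": [0, 1], "unseen": [2]}] * 10   # placeholder groups

for i in range(rgidx - 1, 10):                        # skip groups already finished
    seen_classes = random_set[i]["seen"]
    unseen_classes = random_set[i]["unseen"]
    print(f"random group {i + 1}: seen={seen_classes}, unseen={unseen_classes}")
```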
