-
Notifications
You must be signed in to change notification settings - Fork 110
/
train.py
32 lines (26 loc) · 859 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import tensorflow as tf
from keras.callbacks import ModelCheckpoint, LambdaCallback
from keras.callbacks import EarlyStopping, TensorBoard
import argparse
import midi
import os
from constants import *
from dataset import *
from generate import *
from midi_util import midi_encode
from model import *
def main():
    """Script entry point: obtain the models, then run training on them."""
    # The result of build_or_load() (not defined in this view; presumably
    # star-imported) is handed straight to train() without a temporary.
    train(build_or_load())
def train(models):
    """Load the training set and fit the first model in ``models``.

    Args:
        models: sequence whose first element is a compiled Keras model; only
            ``models[0]`` is trained here. (Presumably the value returned by
            ``build_or_load()`` -- confirm against its definition.)

    Training runs for at most 1000 epochs and stops early once the training
    loss has not improved for 5 consecutive epochs. Weights from the epoch
    with the lowest training loss are saved to ``MODEL_FILE``, and TensorBoard
    scalar logs are written under ``out/logs``.
    """
    print('Loading data')
    # `styles`, BATCH_SIZE and SEQ_LEN are not defined in this function;
    # presumably they come from the module's star imports.
    train_data, train_labels = load_all(styles, BATCH_SIZE, SEQ_LEN)

    cbs = [
        # fit() below gets no validation data, so the training loss is the
        # only quantity these callbacks can monitor.
        ModelCheckpoint(MODEL_FILE, monitor='loss', save_best_only=True,
                        save_weights_only=True),
        EarlyStopping(monitor='loss', patience=5),
        # histogram_freq is 0 on purpose: standalone Keras 2.x computes
        # histograms from validation_data, which fit() does not receive here.
        # Newer 2.x releases raise ValueError at the end of the first epoch in
        # that situation; older ones silently write no histograms.
        TensorBoard(log_dir='out/logs', histogram_freq=0),
    ]

    print('Training')
    models[0].fit(train_data, train_labels, epochs=1000, callbacks=cbs,
                  batch_size=BATCH_SIZE)
if __name__ == '__main__':
    # Kick off training only when executed as a script; importing this
    # module has no side effects beyond its definitions.
    main()