# -*- coding: UTF-8 -*-
import tensorflow as tf
import numpy as np
import math
from dnlp.core.dnn_crf_base import DnnCrfBase
from dnlp.config.config import DnnCrfConfig


class DnnCrf(DnnCrfBase):
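  """DNN-CRF sequence tagger (the checkpoint names suggest Chinese word segmentation).

  A configurable feature extractor (MLP, LSTM or GRU) produces per-tag emission
  scores, and a CRF layer (transition + initial-transition weights) scores tag
  sequences; training uses a Viterbi-decode-and-compare update (see fit_batch).
  """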
  def __init__(self, *, config: DnnCrfConfig, data_path: str = '', dtype: type = tf.float32, mode: str = 'train', nn: str, model_path: str = ''):
    if mode not in ['train', 'predict']:
      raise ValueError('mode must be "train" or "predict"')
    if nn not in ['mlp', 'lstm', 'gru']:
      raise ValueError('neural network type must be "mlp", "lstm" or "gru"')

    DnnCrfBase.__init__(self, config, data_path, mode, model_path)
    self.dtype = dtype
    self.mode = mode

    # build the graph
    tf.reset_default_graph()
    self.transition = self.__get_variable([self.tags_count, self.tags_count], 'transition')
    self.transition_init = self.__get_variable([self.tags_count], 'transition_init')
    self.params = [self.transition, self.transition_init]
    # input layer
    if mode == 'train':
      self.input = tf.placeholder(tf.int32, [self.batch_size, self.batch_length, self.windows_size])
    else:
      self.input = tf.placeholder(tf.int32, [None, self.windows_size])
    # embedding lookup layer
    self.embedding_layer = self.get_embedding_layer()
    # hidden layer
    if nn == 'mlp':
      self.hidden_layer = self.get_mlp_layer(tf.transpose(self.embedding_layer))
    elif nn == 'lstm':
      self.hidden_layer = self.get_lstm_layer(tf.transpose(self.embedding_layer))
    else:
      self.hidden_layer = self.get_gru_layer(tf.transpose(self.embedding_layer))
    # output layer
    self.output = self.get_output_layer(self.hidden_layer)

    if mode == 'predict':
      self.output = tf.squeeze(self.output, axis=2)
    else:
      # build the training ops
      # placeholders for training
      self.ll_corr = tf.placeholder(tf.int32, shape=[None, 3])
      self.ll_curr = tf.placeholder(tf.int32, shape=[None, 3])
      self.trans_corr = tf.placeholder(tf.int32, [None, 2])
      self.trans_curr = tf.placeholder(tf.int32, [None, 2])
      self.trans_init_corr = tf.placeholder(tf.int32, [None, 1])
      self.trans_init_curr = tf.placeholder(tf.int32, [None, 1])
      # loss
      self.loss, self.loss_with_init = self.get_loss()
      self.optimizer = tf.train.AdagradOptimizer(self.learning_rate)
      self.train = self.optimizer.minimize(self.loss)
      self.train_with_init = self.optimizer.minimize(self.loss_with_init)

  def fit(self, epochs: int = 100, interval: int = 20):
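    """Train for `epochs` epochs, saving a checkpoint every `interval` epochs."""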
    with tf.Session() as sess:
      tf.global_variables_initializer().run()
      saver = tf.train.Saver(max_to_keep=100)
      for epoch in range(1, epochs + 1):
        print('epoch:', epoch)
        for _ in range(self.batch_count):
          characters, labels, lengths = self.get_batch()
          self.fit_batch(characters, labels, lengths, sess)
        if epoch % interval == 0:
          model_path = '../dnlp/models/cws{0}.ckpt'.format(epoch)
          saver.save(sess, model_path)
          self.save_config(model_path)

  def fit_batch(self, characters, labels, lengths, sess):
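    # Perceptron-style CRF update: Viterbi-decode each sequence under the
    # current parameters, find the positions where the decoded tags differ from
    # the gold tags, and collect index lists for the gold ("corr") and
    # predicted ("curr") emissions/transitions compared by the loss in get_loss.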
    scores = sess.run(self.output, feed_dict={self.input: characters})
    transition = self.transition.eval(session=sess)
    transition_init = self.transition_init.eval(session=sess)
    update_labels_pos = None
    update_labels_neg = None
    current_labels = []
    trans_pos_indices = []
    trans_neg_indices = []
    trans_init_pos_indices = []
    trans_init_neg_indices = []
    for i in range(self.batch_size):
      current_label = self.viterbi(scores[:, :lengths[i], i], transition, transition_init)
      current_labels.append(current_label)
      diff_tag = np.subtract(labels[i, :lengths[i]], current_label)
      update_index = np.where(diff_tag != 0)[0]
      update_length = len(update_index)
      if update_length == 0:
        continue
      update_label_pos = np.stack([labels[i, update_index], update_index, i * np.ones([update_length])], axis=-1)
      update_label_neg = np.stack([current_label[update_index], update_index, i * np.ones([update_length])], axis=-1)
      if update_labels_pos is not None:
        update_labels_pos = np.concatenate((update_labels_pos, update_label_pos))
        update_labels_neg = np.concatenate((update_labels_neg, update_label_neg))
      else:
        update_labels_pos = update_label_pos
        update_labels_neg = update_label_neg

      trans_pos_index, trans_neg_index, trans_init_pos, trans_init_neg, update_init = self.generate_transition_update_index(
        labels[i, :lengths[i]], current_labels[i])

      trans_pos_indices.extend(trans_pos_index)
      trans_neg_indices.extend(trans_neg_index)

      if update_init:
        trans_init_pos_indices.append(trans_init_pos)
        trans_init_neg_indices.append(trans_init_neg)

    if update_labels_pos is not None and update_labels_neg is not None:
      feed_dict = {self.input: characters, self.ll_curr: update_labels_neg, self.ll_corr: update_labels_pos,
                   self.trans_curr: trans_neg_indices, self.trans_corr: trans_pos_indices}

      if not trans_init_pos_indices:
        sess.run(self.train, feed_dict)
      else:
        feed_dict[self.trans_init_corr] = trans_init_pos_indices
        feed_dict[self.trans_init_curr] = trans_init_neg_indices
        sess.run(self.train_with_init, feed_dict)

  def generate_transition_update_index(self, correct_labels, current_labels):
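    # Walk the gold and predicted tag sequences in parallel and collect the
    # [previous_tag, current_tag] transition indices to boost (gold) or
    # penalize (predicted); a mismatch at the first position also flags an
    # update of the initial-transition weights.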
    if correct_labels.shape != current_labels.shape:
      print('sequence length is not equal')
      return None

    before_corr = correct_labels[0]
    before_curr = current_labels[0]
    update_init = False

    trans_init_pos = None
    trans_init_neg = None
    trans_pos = []
    trans_neg = []

    if before_corr != before_curr:
      trans_init_pos = [before_corr]
      trans_init_neg = [before_curr]
      update_init = True

    for corr_label, curr_label in zip(correct_labels[1:], current_labels[1:]):
      if corr_label != curr_label or before_corr != before_curr:
        trans_pos.append([before_corr, corr_label])
        trans_neg.append([before_curr, curr_label])
      before_corr = corr_label
      before_curr = curr_label

    return trans_pos, trans_neg, trans_init_pos, trans_init_neg, update_init

  def predict(self, sentence: str):
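    # Map the sentence to window indices, score it with the network, decode the
    # tag sequence with Viterbi, then convert the tags back into words.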
    if self.mode != 'predict':
      raise RuntimeError('predict() is only available in predict mode')
    with tf.Session() as sess:
      tf.global_variables_initializer().run()
      tf.train.Saver().restore(save_path=self.model_path, sess=sess)
      input_ = self.indices2input(self.sentence2indices(sentence))
      runner = [self.output, self.transition, self.transition_init]
      output, trans, trans_init = sess.run(runner, feed_dict={self.input: input_})
      labels = self.viterbi(output, trans, trans_init)
      return self.tags2words(sentence, labels)

  def get_embedding_layer(self) -> tf.Tensor:
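    # Look up and concatenate the embeddings of each character window;
    # concat_embed_size is presumably windows_size * embed_size (defined in
    # DnnCrfBase).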
    embeddings = self.__get_variable([self.dict_size, self.embed_size], 'embeddings')
    self.params.append(embeddings)
    if self.mode == 'train':
      input_size = [self.batch_size, self.batch_length, self.concat_embed_size]
      layer = tf.reshape(tf.nn.embedding_lookup(embeddings, self.input), input_size)
    else:
      layer = tf.reshape(tf.nn.embedding_lookup(embeddings, self.input), [1, -1, self.concat_embed_size])
    return layer

  def get_mlp_layer(self, layer: tf.Tensor) -> tf.Tensor:
    hidden_weight = self.__get_variable([self.hidden_units, self.concat_embed_size], 'hidden_weight')
    hidden_bias = self.__get_variable([self.hidden_units, 1, 1], 'hidden_bias')
    self.params += [hidden_weight, hidden_bias]
    layer = tf.sigmoid(tf.tensordot(hidden_weight, layer, [[1], [0]]) + hidden_bias)
    return layer

  def get_lstm_layer(self, layer: tf.Tensor) -> tf.Tensor:
    lstm = tf.nn.rnn_cell.BasicLSTMCell(self.hidden_units)
    lstm_output, lstm_out_state = tf.nn.dynamic_rnn(lstm, layer, dtype=self.dtype)
    self.params += [v for v in tf.global_variables() if v.name.startswith('rnn')]
    return tf.transpose(lstm_output)

  def get_gru_layer(self, layer: tf.Tensor) -> tf.Tensor:
    gru = tf.nn.rnn_cell.GRUCell(self.hidden_units)
    gru_output, gru_out_state = tf.nn.dynamic_rnn(gru, layer, dtype=self.dtype)
    self.params += [v for v in tf.global_variables() if v.name.startswith('rnn')]
    return tf.transpose(gru_output)

  def get_dropout_layer(self, layer: tf.Tensor) -> tf.Tensor:
    return tf.layers.dropout(layer, self.dropout_rate)

  def get_output_layer(self, layer: tf.Tensor) -> tf.Tensor:
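    # Linear projection from hidden units to one emission score per tag at each
    # position.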
    output_weight = self.__get_variable([self.tags_count, self.hidden_units], 'output_weight')
    output_bias = self.__get_variable([self.tags_count, 1, 1], 'output_bias')
    self.params += [output_weight, output_bias]
    return tf.tensordot(output_weight, layer, [[1], [0]]) + output_bias

  def get_loss(self) -> (tf.Tensor, tf.Tensor):
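    # Margin-style loss: summed score of the predicted ("curr") indices minus
    # the gold ("corr") indices for emissions and transitions, plus L2
    # regularization; the second returned loss adds the initial-transition term.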
    output_loss = tf.reduce_sum(tf.gather_nd(self.output, self.ll_curr) - tf.gather_nd(self.output, self.ll_corr))
    trans_loss = tf.reduce_sum(tf.gather_nd(self.transition, self.trans_curr) - tf.gather_nd(self.transition, self.trans_corr))
    trans_i_curr = tf.gather_nd(self.transition_init, self.trans_init_curr)
    trans_i_corr = tf.gather_nd(self.transition_init, self.trans_init_corr)
    trans_init_loss = tf.reduce_sum(trans_i_curr - trans_i_corr)
    loss = output_loss + trans_loss
    regu = tf.contrib.layers.apply_regularization(tf.contrib.layers.l2_regularizer(self.lam), self.params)
    l1 = loss + regu
    l2 = l1 + trans_init_loss
    return l1, l2

  def __get_variable(self, size, name) -> tf.Variable:
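    # Truncated-normal initialization, scaled by 1 / sqrt(size of the last
    # dimension).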
    return tf.Variable(tf.truncated_normal(size, stddev=1.0 / math.sqrt(size[-1]), dtype=self.dtype), name=name)
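

# A minimal usage sketch (the paths and the no-argument DnnCrfConfig() call are
# hypothetical; the actual constructor arguments live in dnlp.config.config):
#
#   config = DnnCrfConfig()
#   model = DnnCrf(config=config, data_path='../dnlp/data/cws.pickle', nn='lstm')
#   model.fit(epochs=100, interval=20)
#
#   predictor = DnnCrf(config=config, mode='predict', nn='lstm',
#                      model_path='../dnlp/models/cws100.ckpt')
#   print(predictor.predict('我爱自然语言处理'))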