class DnnCrf(DnnCrfBase):
  def __init__(self, *, config: DnnCrfConfig = None, data_path: str = '', dtype: type = tf.float32, mode: str = 'train',
-               train: str = 'll', nn: str, model_path: str = ''):
+               predict: str = 'll', nn: str, model_path: str = ''):
    if mode not in ['train', 'predict']:
      raise Exception('mode error')
    if nn not in ['mlp', 'rnn', 'lstm', 'bilstm', 'gru']:
@@ -27,113 +27,42 @@ def __init__(self, *, config: DnnCrfConfig = None, data_path: str = '', dtype: t
    if mode == 'train':
      self.input = tf.placeholder(tf.int32, [self.batch_size, self.batch_length, self.windows_size])
      self.real_indices = tf.placeholder(tf.int32, [self.batch_size, self.batch_length])
-      self.seq_length = tf.placeholder(tf.int32, [self.batch_size])
    else:
      self.input = tf.placeholder(tf.int32, [None, self.windows_size])

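+    # sequence lengths carry a dynamic batch dimension, so the same placeholder
+    # serves both training batches and single-sentence prediction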
+    self.seq_length = tf.placeholder(tf.int32, [None])
+
    # lookup-table (embedding) layer
    self.embedding_layer = self.get_embedding_layer()
    # hidden layer
    if nn == 'mlp':
      self.hidden_layer = self.get_mlp_layer(tf.transpose(self.embedding_layer))
    elif nn == 'lstm':
      self.hidden_layer = self.get_lstm_layer(self.embedding_layer)
+    elif nn == 'bilstm':
+      self.hidden_layer = self.get_bilstm_layer(self.embedding_layer)
    elif nn == 'gru':
-      self.hidden_layer = self.get_gru_layer(tf.transpose(self.embedding_layer))
+      self.hidden_layer = self.get_gru_layer(self.embedding_layer)
    else:
-      self.hidden_layer = self.get_rnn_layer(tf.transpose(self.embedding_layer))
+      self.hidden_layer = self.get_rnn_layer(self.embedding_layer)
    # output layer
    self.output = self.get_output_layer(self.hidden_layer)

    if mode == 'predict':
-      self.output = tf.squeeze(tf.transpose(self.output), axis=2)
+      if predict != 'll':
+        self.output = tf.squeeze(tf.transpose(self.output), axis=2)
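+      # crf_decode runs Viterbi decoding inside the graph, returning the best
+      # tag sequence and its score for each sequence in the batch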
+      self.seq, self.best_score = tf.contrib.crf.crf_decode(self.output, self.transition, self.seq_length)
      self.sess = tf.Session()
      self.sess.run(tf.global_variables_initializer())
      tf.train.Saver().restore(save_path=self.model_path, sess=self.sess)
-    elif train == 'll':
-      self.ll_loss, _ = tf.contrib.crf.crf_log_likelihood(self.output, self.real_indices, self.seq_length,
-                                                          self.transition)
-      self.optimizer = tf.train.AdagradOptimizer(self.learning_rate)
-      self.train_ll = self.optimizer.minimize(-self.ll_loss)
    else:
-      # build the training ops
-      # placeholders used for training
-      self.ll_corr = tf.placeholder(tf.int32, shape=[None, 3])
-      self.ll_curr = tf.placeholder(tf.int32, shape=[None, 3])
-      self.trans_corr = tf.placeholder(tf.int32, [None, 2])
-      self.trans_curr = tf.placeholder(tf.int32, [None, 2])
-      self.trans_init_corr = tf.placeholder(tf.int32, [None, 1])
-      self.trans_init_curr = tf.placeholder(tf.int32, [None, 1])
-      # loss function
-      self.loss, self.loss_with_init = self.get_loss()
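+      # crf_log_likelihood returns the log-likelihood of the gold tag paths;
+      # minimizing its negation below maximizes the likelihood of the training data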
+      self.loss, _ = tf.contrib.crf.crf_log_likelihood(self.output, self.real_indices, self.seq_length,
+                                                       self.transition)
      self.optimizer = tf.train.AdagradOptimizer(self.learning_rate)
-      self.train = self.optimizer.minimize(self.loss)
-      self.train_with_init = self.optimizer.minimize(self.loss_with_init)
+      self.new_optimizer = tf.train.AdamOptimizer()
+      self.train = self.optimizer.minimize(-self.loss)

  def fit(self, epochs: int = 100, interval: int = 20):
-    with tf.Session() as sess:
-      tf.global_variables_initializer().run()
-      saver = tf.train.Saver(max_to_keep=100)
-      for epoch in range(1, epochs + 1):
-        print('epoch:', epoch)
-        for _ in range(self.batch_count):
-          characters, labels, lengths = self.get_batch()
-          self.fit_batch(characters, labels, lengths, sess)
-        # if epoch % interval == 0:
-        model_path = '../dnlp/models/cws{0}.ckpt'.format(epoch)
-        saver.save(sess, model_path)
-        self.save_config(model_path)
-
-  def fit_batch(self, characters, labels, lengths, sess):
-    scores = sess.run(self.output, feed_dict={self.input: characters})
-    transition = self.transition.eval(session=sess)
-    transition_init = self.transition_init.eval(session=sess)
-    update_labels_pos = None
-    update_labels_neg = None
-    current_labels = []
-    trans_pos_indices = []
-    trans_neg_indices = []
-    trans_init_pos_indices = []
-    trans_init_neg_indices = []
-    for i in range(self.batch_size):
-      current_label = self.viterbi(scores[:, :lengths[i], i], transition, transition_init)
-      current_labels.append(current_label)
-      diff_tag = np.subtract(labels[i, :lengths[i]], current_label)
-      update_index = np.where(diff_tag != 0)[0]
-      update_length = len(update_index)
-      if update_length == 0:
-        continue
-      update_label_pos = np.stack([labels[i, update_index], update_index, i * np.ones([update_length])], axis=-1)
-      update_label_neg = np.stack([current_label[update_index], update_index, i * np.ones([update_length])], axis=-1)
-      if update_labels_pos is not None:
-        np.concatenate((update_labels_pos, update_label_pos))
-        np.concatenate((update_labels_neg, update_label_neg))
-      else:
-        update_labels_pos = update_label_pos
-        update_labels_neg = update_label_neg
-
-      trans_pos_index, trans_neg_index, trans_init_pos, trans_init_neg, update_init = self.generate_transition_update_index(
-        labels[i, :lengths[i]], current_labels[i])
-
-      trans_pos_indices.extend(trans_pos_index)
-      trans_neg_indices.extend(trans_neg_index)
-
-      if update_init:
-        trans_init_pos_indices.append(trans_init_pos)
-        trans_init_neg_indices.append(trans_init_neg)
-
-    if update_labels_pos is not None and update_labels_neg is not None:
-      feed_dict = {self.input: characters, self.ll_curr: update_labels_neg, self.ll_corr: update_labels_pos,
-                   self.trans_curr: trans_neg_indices, self.trans_corr: trans_pos_indices}
-
-      if not trans_init_pos_indices:
-        sess.run(self.train, feed_dict)
-      else:
-        feed_dict[self.trans_init_corr] = trans_init_pos_indices
-        feed_dict[self.trans_init_curr] = trans_init_neg_indices
-        sess.run(self.train_with_init, feed_dict)
-
-  def fit_ll(self, epochs: int = 100, interval: int = 20):
    with tf.Session() as sess:
      tf.global_variables_initializer().run()
      saver = tf.train.Saver(max_to_keep=epochs)
@@ -143,44 +72,13 @@ def fit_ll(self, epochs: int = 100, interval: int = 20):
          characters, labels, lengths = self.get_batch()
          # scores = sess.run(self.output, feed_dict={self.input: characters})
          feed_dict = {self.input: characters, self.real_indices: labels, self.seq_length: lengths}
-          sess.run(self.train_ll, feed_dict=feed_dict)
+          sess.run(self.train, feed_dict=feed_dict)
          # self.fit_batch(characters, labels, lengths, sess)
        # if epoch % interval == 0:
        model_path = '../dnlp/models/cws{0}.ckpt'.format(epoch)
        saver.save(sess, model_path)
        self.save_config(model_path)

-  def fit_batch_ll(self):
-    pass
-
-  def generate_transition_update_index(self, correct_labels, current_labels):
-    if correct_labels.shape != current_labels.shape:
-      print('sequence length is not equal')
-      return None
-
-    before_corr = correct_labels[0]
-    before_curr = current_labels[0]
-    update_init = False
-
-    trans_init_pos = None
-    trans_init_neg = None
-    trans_pos = []
-    trans_neg = []
-
-    if before_corr != before_curr:
-      trans_init_pos = [before_corr]
-      trans_init_neg = [before_curr]
-      update_init = True
-
-    for _, (corr_label, curr_label) in enumerate(zip(correct_labels[1:], current_labels[1:])):
-      if corr_label != curr_label or before_corr != before_curr:
-        trans_pos.append([before_corr, corr_label])
-        trans_neg.append([before_curr, curr_label])
-      before_corr = corr_label
-      before_curr = curr_label
-
-    return trans_pos, trans_neg, trans_init_pos, trans_init_neg, update_init
-
  def predict(self, sentence: str, return_labels=False):
    if self.mode != 'predict':
      raise Exception('mode is not allowed to predict')
@@ -194,6 +92,22 @@ def predict(self, sentence: str, return_labels=False):
    else:
      return self.tags2words(sentence, labels), self.tag2sequences(labels)

+  def predict_ll(self, sentence: str, return_labels=False):
+    if self.mode != 'predict':
+      raise Exception('mode is not allowed to predict')
+
+    input = self.indices2input(self.sentence2indices(sentence))
+    runner = [self.seq, self.best_score, self.output, self.transition]
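+    # a single sentence is decoded by feeding its length as a one-element batch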
+    labels, best_score, output, trans = self.sess.run(runner,
+                                                      feed_dict={self.input: input, self.seq_length: [len(sentence)]})
+    # print(output)
+    # print(trans)
+    labels = np.squeeze(labels)
+    if return_labels:
+      return self.tags2words(sentence, labels), self.tag2sequences(labels)
+    else:
+      return self.tags2words(sentence, labels)
+
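+    # hypothetical usage (names and checkpoint path assumed for illustration,
+    # not part of this diff):
+    #   crf = DnnCrf(config=config, mode='predict', nn='bilstm',
+    #                model_path='../dnlp/models/cws100.ckpt')
+    #   words = crf.predict_ll('小明来自南京师范大学')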
  def get_embedding_layer(self) -> tf.Tensor:
    embeddings = self.__get_variable([self.dict_size, self.embed_size], 'embeddings')
    self.params.append(embeddings)
@@ -215,19 +129,27 @@ def get_rnn_layer(self, layer: tf.Tensor) -> tf.Tensor:
    rnn = tf.nn.rnn_cell.BasicRNNCell(self.hidden_units)  # BasicRNNCell: the abstract RNNCell cannot be instantiated
    rnn_output, rnn_out_state = tf.nn.dynamic_rnn(rnn, layer, dtype=self.dtype)
    self.params += [v for v in tf.global_variables() if v.name.startswith('rnn')]
-    return tf.transpose(rnn_output)
+    return rnn_output

  def get_lstm_layer(self, layer: tf.Tensor) -> tf.Tensor:
    lstm = tf.nn.rnn_cell.LSTMCell(self.hidden_units)
    lstm_output, lstm_out_state = tf.nn.dynamic_rnn(lstm, layer, dtype=self.dtype)
    self.params += [v for v in tf.global_variables() if v.name.startswith('rnn')]
    return lstm_output

+  def get_bilstm_layer(self, layer: tf.Tensor) -> tf.Tensor:
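+    # each direction gets half the hidden units, so the concatenated
+    # forward/backward outputs have hidden_units channels in total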
+    lstm_fw = tf.nn.rnn_cell.LSTMCell(self.hidden_units // 2)
+    lstm_bw = tf.nn.rnn_cell.LSTMCell(self.hidden_units // 2)
+    bilstm_output, bilstm_output_state = tf.nn.bidirectional_dynamic_rnn(lstm_fw, lstm_bw, layer, self.seq_length,
+                                                                         dtype=self.dtype)
+    self.params += [v for v in tf.global_variables() if v.name.startswith('rnn')]
+    return tf.concat([bilstm_output[0], bilstm_output[1]], -1)
+
  def get_gru_layer(self, layer: tf.Tensor) -> tf.Tensor:
    gru = tf.nn.rnn_cell.GRUCell(self.hidden_units)
    gru_output, gru_out_state = tf.nn.dynamic_rnn(gru, layer, dtype=self.dtype)
    self.params += [v for v in tf.global_variables() if v.name.startswith('rnn')]
-    return tf.transpose(gru_output)
+    return gru_output

  def get_dropout_layer(self, layer: tf.Tensor) -> tf.Tensor:
    return tf.layers.dropout(layer, self.dropout_rate)
@@ -238,17 +160,5 @@ def get_output_layer(self, layer: tf.Tensor) -> tf.Tensor:
    self.params += [output_weight, output_bias]
    return tf.tensordot(layer, output_weight, [[2], [0]]) + output_bias

-  def get_loss(self) -> (tf.Tensor, tf.Tensor):
-    output_loss = tf.reduce_sum(tf.gather_nd(self.output, self.ll_curr) - tf.gather_nd(self.output, self.ll_corr))
-    trans_loss = tf.gather_nd(self.transition, self.trans_curr) - tf.gather_nd(self.transition, self.trans_corr)
-    trans_i_curr = tf.gather_nd(self.transition_init, self.trans_init_curr)
-    trans_i_corr = tf.gather_nd(self.transition_init, self.trans_init_corr)
-    trans_init_loss = tf.reduce_sum(trans_i_curr - trans_i_corr)
-    loss = output_loss + trans_loss
-    regu = tf.contrib.layers.apply_regularization(tf.contrib.layers.l2_regularizer(self.lam), self.params)
-    l1 = loss + regu
-    l2 = l1 + trans_init_loss
-    return l1, l2
-
  def __get_variable(self, size, name) -> tf.Variable:
    return tf.Variable(tf.truncated_normal(size, stddev=1.0 / math.sqrt(size[-1]), dtype=self.dtype), name=name)