 
 class DnnCrf(DnnCrfBase):
   def __init__(self, *, config: DnnCrfConfig = None, data_path: str = '', dtype: type = tf.float32, mode: str = 'train',
-               train:str = 'll',nn: str, model_path: str = ''):
+               train: str = 'll', nn: str, model_path: str = ''):
     if mode not in ['train', 'predict']:
       raise Exception('mode error')
-    if nn not in ['mlp', 'rnn', 'lstm', 'gru']:
+    if nn not in ['mlp', 'rnn', 'lstm', 'bilstm', 'gru']:
       raise Exception('name of neural network entered is not supported')
 
     DnnCrfBase.__init__(self, config, data_path, mode, model_path)
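Not part of the diff, for context: with 'bilstm' added to the accepted network names, a training-mode construction call could look like the sketch below. The module path, corpus path, checkpoint path and default DnnCrfConfig are assumptions for illustration; note that nn is keyword-only with no default, so it must be passed explicitly.

# Hypothetical usage sketch -- module name, paths and hyperparameters are assumed, not taken from this commit.
from dnn_crf import DnnCrf, DnnCrfConfig

config = DnnCrfConfig()  # assumed default hyperparameters
model = DnnCrf(config=config, data_path='corpus/train.utf8', mode='train',
               train='ll', nn='bilstm', model_path='checkpoints/bilstm')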
@@ -66,7 +66,6 @@ def __init__(self, *, config: DnnCrfConfig = None, data_path: str = '', dtype: t
     self.train = self.optimizer.minimize(self.loss)
     self.train_with_init = self.optimizer.minimize(self.loss_with_init)
 
-
   def fit(self, epochs: int = 100, interval: int = 20):
     with tf.Session() as sess:
       tf.global_variables_initializer().run()
@@ -130,16 +129,16 @@ def fit_batch(self, characters, labels, lengths, sess):
       feed_dict[self.trans_init_curr] = trans_init_neg_indices
     sess.run(self.train_with_init, feed_dict)
 
-  def fit_ll(self,epochs: int = 100, interval: int = 20):
+  def fit_ll(self, epochs: int = 100, interval: int = 20):
     with tf.Session() as sess:
       tf.global_variables_initializer().run()
       saver = tf.train.Saver(max_to_keep=epochs)
       for epoch in range(1, epochs + 1):
         print('epoch:', epoch)
         for _ in range(self.batch_count):
           characters, labels, lengths = self.get_batch()
-          #scores = sess.run(self.output, feed_dict={self.input: characters})
-          feed_dict = {self.input: characters, self.real_indices:labels, self.seq_length:lengths}
+          # scores = sess.run(self.output, feed_dict={self.input: characters})
+          feed_dict = {self.input: characters, self.real_indices: labels, self.seq_length: lengths}
           sess.run(self.train_ll, feed_dict=feed_dict)
           # self.fit_batch(characters, labels, lengths, sess)
         # if epoch % interval == 0:
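Not part of the diff: fit_ll() runs log-likelihood training, feeding each batch's characters, gold label indices and sequence lengths into self.train_ll; the tf.train.Saver is created with max_to_keep=epochs, presumably for per-epoch checkpointing (the save call sits outside this hunk). A minimal invocation sketch, continuing the hypothetical model above (epoch and interval values are arbitrary):

# Illustrative only -- reuses the hypothetical `model` from the earlier sketch.
model.fit_ll(epochs=50, interval=10)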
@@ -178,7 +177,7 @@ def generate_transition_update_index(self, correct_labels, current_labels):
 
     return trans_pos, trans_neg, trans_init_pos, trans_init_neg, update_init
 
-  def predict(self, sentence: str):
+  def predict(self, sentence: str, return_labels=False):
     if self.mode != 'predict':
       raise Exception('mode is not allowed to predict')
     with tf.Session() as sess:
@@ -188,7 +187,10 @@ def predict(self, sentence: str):
       runner = [self.output, self.transition, self.transition_init]
       output, trans, trans_init = sess.run(runner, feed_dict={self.input: input})
       labels = self.viterbi(output, trans, trans_init)
-      return self.tags2words(sentence, labels)
+      if not return_labels:
+        return self.tags2words(sentence, labels)
+      else:
+        return self.tags2words(sentence, labels), labels
 
   def get_embedding_layer(self) -> tf.Tensor:
     embeddings = self.__get_variable([self.dict_size, self.embed_size], 'embeddings')
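Not part of the diff: with the new return_labels flag, predict() can hand back the Viterbi label sequence alongside the output of tags2words. A hedged usage sketch; the checkpoint path and input string are illustrative.

# Illustrative only -- requires mode='predict' and a trained checkpoint at model_path.
tagger = DnnCrf(config=config, mode='predict', nn='bilstm', model_path='checkpoints/bilstm')
words = tagger.predict('an example sentence')                              # default: tags2words output only
words, labels = tagger.predict('an example sentence', return_labels=True)  # also returns the label sequence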
@@ -229,10 +231,10 @@ def get_dropout_layer(self, layer: tf.Tensor) -> tf.Tensor:
     return tf.layers.dropout(layer, self.dropout_rate)
 
   def get_output_layer(self, layer: tf.Tensor) -> tf.Tensor:
-    output_weight = self.__get_variable([self.hidden_units,self.tags_count], 'output_weight')
-    output_bias = self.__get_variable([1, 1, self.tags_count], 'output_bias')
+    output_weight = self.__get_variable([self.hidden_units, self.tags_count], 'output_weight')
+    output_bias = self.__get_variable([1, 1, self.tags_count], 'output_bias')
     self.params += [output_weight, output_bias]
-    return tf.tensordot( layer,output_weight, [[2], [0]]) + output_bias
+    return tf.tensordot(layer, output_weight, [[2], [0]]) + output_bias
 
   def get_loss(self) -> (tf.Tensor, tf.Tensor):
     output_loss = tf.reduce_sum(tf.gather_nd(self.output, self.ll_curr) - tf.gather_nd(self.output, self.ll_corr))
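Not part of the diff: the reformatted get_output_layer contracts the hidden axis of the [batch, time, hidden_units] feature tensor with the [hidden_units, tags_count] weight via tf.tensordot, then broadcasts the [1, 1, tags_count] bias, giving per-tag scores of shape [batch, time, tags_count]. A self-contained shape check in TF 1.x graph style (all dimensions are illustrative):

import numpy as np
import tensorflow as tf  # TF 1.x, matching the tf.Session usage above

batch, steps, hidden_units, tags_count = 2, 5, 8, 4
layer = tf.placeholder(tf.float32, [batch, steps, hidden_units])
output_weight = tf.ones([hidden_units, tags_count])
output_bias = tf.zeros([1, 1, tags_count])

# Contract axis 2 of `layer` with axis 0 of `output_weight`; the bias broadcasts over batch and time.
scores = tf.tensordot(layer, output_weight, [[2], [0]]) + output_bias

with tf.Session() as sess:
  feed = {layer: np.zeros([batch, steps, hidden_units], np.float32)}
  print(sess.run(scores, feed).shape)  # (2, 5, 4)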