@@ -30,6 +30,7 @@ def __init__(self, *, config: DnnCrfConfig = None, data_path: str = '', dtype: t
3030 self .seq_length = tf .placeholder (tf .int32 , [self .batch_size ])
3131 else :
3232 self .input = tf .placeholder (tf .int32 , [None , self .windows_size ])
33+
3334 # 查找表层
3435 self .embedding_layer = self .get_embedding_layer ()
3536 # 隐藏层
@@ -46,6 +47,9 @@ def __init__(self, *, config: DnnCrfConfig = None, data_path: str = '', dtype: t
4647
4748 if mode == 'predict' :
4849 self .output = tf .squeeze (tf .transpose (self .output ), axis = 2 )
50+ self .sess = tf .Session ()
51+ self .sess .run (tf .global_variables_initializer ())
52+ tf .train .Saver ().restore (save_path = self .model_path , sess = self .sess )
4953 elif train == 'll' :
5054 self .ll_loss , _ = tf .contrib .crf .crf_log_likelihood (self .output , self .real_indices , self .seq_length ,
5155 self .transition )
@@ -180,17 +184,15 @@ def generate_transition_update_index(self, correct_labels, current_labels):
180184 def predict (self , sentence : str , return_labels = False ):
181185 if self .mode != 'predict' :
182186 raise Exception ('mode is not allowed to predict' )
183- with tf .Session () as sess :
184- tf .global_variables_initializer ().run ()
185- tf .train .Saver ().restore (save_path = self .model_path , sess = sess )
186- input = self .indices2input (self .sentence2indices (sentence ))
187- runner = [self .output , self .transition , self .transition_init ]
188- output , trans , trans_init = sess .run (runner , feed_dict = {self .input : input })
189- labels = self .viterbi (output , trans , trans_init )
190- if not return_labels :
191- return self .tags2words (sentence , labels )
192- else :
193- return self .tags2words (sentence , labels ), labels
187+
188+ input = self .indices2input (self .sentence2indices (sentence ))
189+ runner = [self .output , self .transition , self .transition_init ]
190+ output , trans , trans_init = self .sess .run (runner , feed_dict = {self .input : input })
191+ labels = self .viterbi (output , trans , trans_init )
192+ if not return_labels :
193+ return self .tags2words (sentence , labels )
194+ else :
195+ return self .tags2words (sentence , labels ), self .tag2sequences (labels )
194196
195197 def get_embedding_layer (self ) -> tf .Tensor :
196198 embeddings = self .__get_variable ([self .dict_size , self .embed_size ], 'embeddings' )
0 commit comments