File tree Expand file tree Collapse file tree 2 files changed +23
-8
lines changed
Expand file tree Collapse file tree 2 files changed +23
-8
lines changed Original file line number Diff line number Diff line change 11# -*- coding: UTF-8 -*-
22import pickle
3+ from sklearn .metrics import f1_score ,precision_score ,recall_score
34from dnlp .utils .constant import TAG_BEGIN , TAG_INSIDE , TAG_END , TAG_SINGLE
45
56
@@ -75,13 +76,26 @@ def evaluate_cws(model, data_path: str):
7576 characters = data ['characters' ]
7677 labels_true = data ['labels' ]
7778 c_count = 0
79+
7880 p_count = 0
81+
7982 r_count = 0
83+
84+ all_labels_true = []
85+ all_labels_predict = []
8086 for sentence , label in zip (characters , labels_true ):
81- words , labels_predict = model .predict (sentence , return_labels = True )
87+ words , labels_predict = model .predict_ll (sentence , return_labels = True )
88+ #print("============")
89+ #print(words)
90+ all_labels_predict .extend (labels_predict )
91+ all_labels_true .extend (label )
8292 c , p , r = get_cws_statistics (label , labels_predict )
8393 c_count += c
8494 p_count += p
8595 r_count += r
8696 print (c_count / p_count )
8797 print (c_count / r_count )
98+ average = 'macro'
99+ print (precision_score (all_labels_true ,all_labels_predict ,average = average ))
100+ print (recall_score (all_labels_true ,all_labels_predict ,average = average ))
101+
Original file line number Diff line number Diff line change 77
88
99def train_cws ():
10- data_path = '../dnlp/data/cws/pku_training .pickle'
10+ data_path = '../dnlp/data/cws/msr_training .pickle'
1111 config = DnnCrfConfig ()
12- dnncrf = DnnCrf (config = config , data_path = data_path , nn = 'lstm ' )
13- dnncrf .fit_ll ()
12+ dnncrf = DnnCrf (config = config , data_path = data_path , nn = 'bilstm ' )
13+ dnncrf .fit ()
1414
1515
1616def test_cws ():
1717 sentence = '小明来自南京师范大学'
18- model_path = '../dnlp/models/cws1.ckpt'
18+ sentence = '中国人民决心继承邓小平同志的遗志,继续把建设有中国特色社会主义事业推向前进。'
19+ model_path = '../dnlp/models/cws4.ckpt'
1920 config = DnnCrfConfig ()
20- dnncrf = DnnCrf (config = config , mode = 'predict' , model_path = model_path , nn = 'lstm ' )
21- res , labels = dnncrf .predict (sentence , return_labels = True )
21+ dnncrf = DnnCrf (config = config , mode = 'predict' , model_path = model_path , nn = 'bilstm ' )
22+ res , labels = dnncrf .predict_ll (sentence , return_labels = True )
2223 print (res )
23- evaluate_cws (dnncrf , '../dnlp/data/cws/pku_test .pickle' )
24+ evaluate_cws (dnncrf , '../dnlp/data/cws/msr_test .pickle' )
2425
2526
2627if __name__ == '__main__' :
You can’t perform that action at this time.
0 commit comments