
Commit e146761

add core algorithm implementation package
1 parent 9edbaf0 commit e146761

5 files changed: 389 additions & 0 deletions

python/dnlp/core/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
# -*- coding: UTF-8 -*-

python/dnlp/core/dnn_crf.py

Lines changed: 213 additions & 0 deletions
@@ -0,0 +1,213 @@
# -*- coding: UTF-8 -*-
import math

import numpy as np
import tensorflow as tf

from dnlp.core.dnn_crf_base import DnnCrfBase
from dnlp.config.config import DnnCrfConfig


class DnnCrf(DnnCrfBase):
  def __init__(self, *, config: DnnCrfConfig, data_path: str = '', dtype: type = tf.float32, mode: str = 'train',
               nn: str, model_path: str = ''):
    if mode not in ['train', 'predict']:
      raise ValueError('mode must be "train" or "predict"')
    if nn not in ['mlp', 'lstm', 'gru']:
      raise ValueError('nn must be one of "mlp", "lstm" or "gru"')

    DnnCrfBase.__init__(self, config, data_path, mode, model_path)
    self.dtype = dtype
    self.mode = mode

    # build the graph
    tf.reset_default_graph()
    self.transition = self.__get_variable([self.tags_count, self.tags_count], 'transition')
    self.transition_init = self.__get_variable([self.tags_count], 'transition_init')
    self.params = [self.transition, self.transition_init]
    # input layer
    if mode == 'train':
      self.input = tf.placeholder(tf.int32, [self.batch_size, self.batch_length, self.windows_size])
    else:
      self.input = tf.placeholder(tf.int32, [None, self.windows_size])
    # embedding lookup layer
    self.embedding_layer = self.get_embedding_layer()
    # hidden layer
    if nn == 'mlp':
      self.hidden_layer = self.get_mlp_layer(tf.transpose(self.embedding_layer))
    elif nn == 'lstm':
      self.hidden_layer = self.get_lstm_layer(tf.transpose(self.embedding_layer))
    else:
      self.hidden_layer = self.get_gru_layer(tf.transpose(self.embedding_layer))
    # output layer
    self.output = self.get_output_layer(self.hidden_layer)

    if mode == 'predict':
      self.output = tf.squeeze(self.output, axis=2)
    else:
      # build the training ops
      # placeholders used during training
      self.ll_corr = tf.placeholder(tf.int32, shape=[None, 3])
      self.ll_curr = tf.placeholder(tf.int32, shape=[None, 3])
      self.trans_corr = tf.placeholder(tf.int32, [None, 2])
      self.trans_curr = tf.placeholder(tf.int32, [None, 2])
      self.trans_init_corr = tf.placeholder(tf.int32, [None, 1])
      self.trans_init_curr = tf.placeholder(tf.int32, [None, 1])
      # loss functions
      self.loss, self.loss_with_init = self.get_loss()
      self.optimizer = tf.train.AdagradOptimizer(self.learning_rate)
      self.train = self.optimizer.minimize(self.loss)
      self.train_with_init = self.optimizer.minimize(self.loss_with_init)

  def fit(self, epochs: int = 100, interval: int = 20):
    with tf.Session() as sess:
      tf.global_variables_initializer().run()
      saver = tf.train.Saver(max_to_keep=100)
      for epoch in range(1, epochs + 1):
        print('epoch:', epoch)
        for _ in range(self.batch_count):
          characters, labels, lengths = self.get_batch()
          self.fit_batch(characters, labels, lengths, sess)
        # save a checkpoint every `interval` epochs
        if epoch % interval == 0:
          model_path = '../dnlp/models/cws{0}.ckpt'.format(epoch)
          saver.save(sess, model_path)
          self.save_config(model_path)

  def fit_batch(self, characters, labels, lengths, sess):
    scores = sess.run(self.output, feed_dict={self.input: characters})
    transition = self.transition.eval(session=sess)
    transition_init = self.transition_init.eval(session=sess)
    update_labels_pos = None
    update_labels_neg = None
    current_labels = []
    trans_pos_indices = []
    trans_neg_indices = []
    trans_init_pos_indices = []
    trans_init_neg_indices = []
    for i in range(self.batch_size):
      current_label = self.viterbi(scores[:, :lengths[i], i], transition, transition_init)
      current_labels.append(current_label)
      diff_tag = np.subtract(labels[i, :lengths[i]], current_label)
      update_index = np.where(diff_tag != 0)[0]
      update_length = len(update_index)
      if update_length == 0:
        continue
      update_label_pos = np.stack([labels[i, update_index], update_index, i * np.ones([update_length])], axis=-1)
      update_label_neg = np.stack([current_label[update_index], update_index, i * np.ones([update_length])], axis=-1)
      if update_labels_pos is not None:
        # np.concatenate returns a new array, so the result must be assigned back
        update_labels_pos = np.concatenate((update_labels_pos, update_label_pos))
        update_labels_neg = np.concatenate((update_labels_neg, update_label_neg))
      else:
        update_labels_pos = update_label_pos
        update_labels_neg = update_label_neg

      trans_pos_index, trans_neg_index, trans_init_pos, trans_init_neg, update_init = \
        self.generate_transition_update_index(labels[i, :lengths[i]], current_labels[i])

      trans_pos_indices.extend(trans_pos_index)
      trans_neg_indices.extend(trans_neg_index)

      if update_init:
        trans_init_pos_indices.append(trans_init_pos)
        trans_init_neg_indices.append(trans_init_neg)

    if update_labels_pos is not None and update_labels_neg is not None:
      feed_dict = {self.input: characters, self.ll_curr: update_labels_neg, self.ll_corr: update_labels_pos,
                   self.trans_curr: trans_neg_indices, self.trans_corr: trans_pos_indices}

      if not trans_init_pos_indices:
        sess.run(self.train, feed_dict)
      else:
        feed_dict[self.trans_init_corr] = trans_init_pos_indices
        feed_dict[self.trans_init_curr] = trans_init_neg_indices
        sess.run(self.train_with_init, feed_dict)

  def generate_transition_update_index(self, correct_labels, current_labels):
    if correct_labels.shape != current_labels.shape:
      print('sequence length is not equal')
      return None

    before_corr = correct_labels[0]
    before_curr = current_labels[0]
    update_init = False

    trans_init_pos = None
    trans_init_neg = None
    trans_pos = []
    trans_neg = []

    if before_corr != before_curr:
      trans_init_pos = [before_corr]
      trans_init_neg = [before_curr]
      update_init = True

    for corr_label, curr_label in zip(correct_labels[1:], current_labels[1:]):
      if corr_label != curr_label or before_corr != before_curr:
        trans_pos.append([before_corr, corr_label])
        trans_neg.append([before_curr, curr_label])
      before_corr = corr_label
      before_curr = curr_label

    return trans_pos, trans_neg, trans_init_pos, trans_init_neg, update_init

  def predict(self, sentence: str):
    if self.mode != 'predict':
      raise Exception('predict() is only available in predict mode')
    with tf.Session() as sess:
      tf.global_variables_initializer().run()
      tf.train.Saver().restore(save_path=self.model_path, sess=sess)
      input_ = self.indices2input(self.sentence2indices(sentence))
      runner = [self.output, self.transition, self.transition_init]
      output, trans, trans_init = sess.run(runner, feed_dict={self.input: input_})
      labels = self.viterbi(output, trans, trans_init)
      return self.tags2words(sentence, labels)

  def get_embedding_layer(self) -> tf.Tensor:
    embeddings = self.__get_variable([self.dict_size, self.embed_size], 'embeddings')
    self.params.append(embeddings)
    if self.mode == 'train':
      input_size = [self.batch_size, self.batch_length, self.concat_embed_size]
      layer = tf.reshape(tf.nn.embedding_lookup(embeddings, self.input), input_size)
    else:
      layer = tf.reshape(tf.nn.embedding_lookup(embeddings, self.input), [1, -1, self.concat_embed_size])
    return layer

  def get_mlp_layer(self, layer: tf.Tensor) -> tf.Tensor:
    hidden_weight = self.__get_variable([self.hidden_units, self.concat_embed_size], 'hidden_weight')
    hidden_bias = self.__get_variable([self.hidden_units, 1, 1], 'hidden_bias')
    self.params += [hidden_weight, hidden_bias]
    layer = tf.sigmoid(tf.tensordot(hidden_weight, layer, [[1], [0]]) + hidden_bias)
    return layer

  def get_lstm_layer(self, layer: tf.Tensor) -> tf.Tensor:
    lstm = tf.nn.rnn_cell.BasicLSTMCell(self.hidden_units)
    lstm_output, lstm_out_state = tf.nn.dynamic_rnn(lstm, layer, dtype=self.dtype)
    self.params += [v for v in tf.global_variables() if v.name.startswith('rnn')]
    return tf.transpose(lstm_output)

  def get_gru_layer(self, layer: tf.Tensor) -> tf.Tensor:
    gru = tf.nn.rnn_cell.GRUCell(self.hidden_units)
    gru_output, gru_out_state = tf.nn.dynamic_rnn(gru, layer, dtype=self.dtype)
    self.params += [v for v in tf.global_variables() if v.name.startswith('rnn')]
    return tf.transpose(gru_output)

  def get_dropout_layer(self, layer: tf.Tensor) -> tf.Tensor:
    return tf.layers.dropout(layer, self.dropout_rate)

  def get_output_layer(self, layer: tf.Tensor) -> tf.Tensor:
    output_weight = self.__get_variable([self.tags_count, self.hidden_units], 'output_weight')
    output_bias = self.__get_variable([self.tags_count, 1, 1], 'output_bias')
    self.params += [output_weight, output_bias]
    return tf.tensordot(output_weight, layer, [[1], [0]]) + output_bias

  def get_loss(self) -> (tf.Tensor, tf.Tensor):
    output_loss = tf.reduce_sum(tf.gather_nd(self.output, self.ll_curr) - tf.gather_nd(self.output, self.ll_corr))
    # reduce the gathered transition scores so the loss is a scalar
    trans_loss = tf.reduce_sum(tf.gather_nd(self.transition, self.trans_curr) - tf.gather_nd(self.transition, self.trans_corr))
    trans_i_curr = tf.gather_nd(self.transition_init, self.trans_init_curr)
    trans_i_corr = tf.gather_nd(self.transition_init, self.trans_init_corr)
    trans_init_loss = tf.reduce_sum(trans_i_curr - trans_i_corr)
    loss = output_loss + trans_loss
    regu = tf.contrib.layers.apply_regularization(tf.contrib.layers.l2_regularizer(self.lam), self.params)
    l1 = loss + regu
    l2 = l1 + trans_init_loss
    return l1, l2

  def __get_variable(self, size, name) -> tf.Variable:
    return tf.Variable(tf.truncated_normal(size, stddev=1.0 / math.sqrt(size[-1]), dtype=self.dtype), name=name)
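A minimal usage sketch (not part of the diff): it assumes DnnCrfConfig() supplies workable default hyperparameters and that a pickled training set exists at data/cws_training.pickle; both the data path and the checkpoint name are illustrative, the latter following the format hard-coded in fit().

# -*- coding: UTF-8 -*-
# Hypothetical driver script; paths and config defaults are assumptions.
from dnlp.config.config import DnnCrfConfig
from dnlp.core.dnn_crf import DnnCrf

config = DnnCrfConfig()  # assumes usable default hyperparameters

# train an LSTM-CRF segmenter on a pre-pickled corpus (illustrative path)
trainer = DnnCrf(config=config, data_path='data/cws_training.pickle', nn='lstm', mode='train')
trainer.fit(epochs=100, interval=20)

# restore a checkpoint saved by fit() and segment a sentence
predictor = DnnCrf(config=config, nn='lstm', mode='predict', model_path='../dnlp/models/cws100.ckpt')
print(predictor.predict('小明来自南京师范大学'))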

python/dnlp/core/dnn_crf_base.py

Lines changed: 173 additions & 0 deletions
@@ -0,0 +1,173 @@
# -*- coding: UTF-8 -*-
import pickle

import numpy as np

from dnlp.config.config import DnnCrfConfig
from dnlp.utils.constant import BATCH_PAD, STRT_VAL, END_VAL, TAG_PAD, TAG_BEGIN, TAG_INSIDE, TAG_SINGLE


class DnnCrfBase(object):
  def __init__(self, config: DnnCrfConfig, data_path: str = '', mode: str = 'train', model_path: str = ''):
    # load data
    self.data_path = data_path
    self.config_suffix = '.config.pickle'
    if mode == 'train':
      self.dictionary, self.tags, self.characters, self.labels = self.__load_data()
    else:
      self.model_path = model_path
      self.config_path = self.model_path + self.config_suffix
      self.dictionary, self.tags = self.__load_config()
    self.tags_count = len(self.tags) - 1  # ignore TAG_PAD
    self.tags_map = self.__generate_tag_map()
    self.dict_size = len(self.dictionary)
    # initialize hyperparameters
    self.skip_left = config.skip_left
    self.skip_right = config.skip_right
    self.embed_size = config.embed_size
    self.hidden_units = config.hidden_units
    self.learning_rate = config.learning_rate
    self.lam = config.lam
    self.dropout_rate = config.dropout_rate
    self.windows_size = self.skip_left + self.skip_right + 1
    self.concat_embed_size = self.embed_size * self.windows_size
    self.batch_length = config.batch_length
    self.batch_size = config.batch_size
    # training-data bookkeeping
    if mode == 'train':
      self.sentences_length = list(map(len, self.characters))
      self.sentences_count = len(self.sentences_length)
      self.batch_count = self.sentences_count // self.batch_size
      self.batch_start = 0

  def __load_data(self) -> (dict, tuple, np.ndarray, np.ndarray):
    with open(self.data_path, 'rb') as f:
      data = pickle.load(f)
    return data['dictionary'], data['tags'], data['characters'], data['labels']

  def __load_config(self) -> (dict, tuple):
    with open(self.config_path, 'rb') as cf:
      config = pickle.load(cf)
    return config['dictionary'], config['tags']

  def save_config(self, model_path: str):
    config_path = model_path + self.config_suffix
    config = {'dictionary': self.dictionary, 'tags': self.tags}
    with open(config_path, 'wb') as cf:
      pickle.dump(config, cf)

  def __generate_tag_map(self):
    tags_map = {}
    for i, tag in enumerate(self.tags):
      tags_map[tag] = i
    return tags_map

  def get_batch(self) -> (np.ndarray, np.ndarray, np.ndarray):
    if self.batch_start + self.batch_size > self.sentences_count:
      # wrap around to the beginning of the corpus
      new_start = self.batch_start + self.batch_size - self.sentences_count
      chs_batch = self.characters[self.batch_start:] + self.characters[:new_start]
      lls_batch = self.labels[self.batch_start:] + self.labels[:new_start]
      len_batch = self.sentences_length[self.batch_start:] + self.sentences_length[:new_start]
    else:
      new_start = self.batch_start + self.batch_size
      chs_batch = self.characters[self.batch_start:new_start]
      lls_batch = self.labels[self.batch_start:new_start]
      len_batch = self.sentences_length[self.batch_start:new_start]
    for i, (chs, lls) in enumerate(zip(chs_batch, lls_batch)):
      if len(chs) > self.batch_length:
        # truncate sentences longer than batch_length
        chs_batch[i] = chs[:self.batch_length]
        lls_batch[i] = list(map(lambda t: self.tags_map[t], lls[:self.batch_length]))
        len_batch[i] = self.batch_length
      else:
        # pad shorter sentences with BATCH_PAD / TAG_PAD
        ext_size = self.batch_length - len(chs)
        chs_batch[i] = chs + ext_size * [self.dictionary[BATCH_PAD]]
        lls_batch[i] = list(map(lambda t: self.tags_map[t], lls)) + ext_size * [self.tags_map[TAG_PAD]]

    self.batch_start = new_start
    return self.indices2input(chs_batch), np.array(lls_batch, dtype=np.int32), np.array(len_batch, dtype=np.int32)

  def viterbi(self, emission: np.ndarray, transition: np.ndarray, transition_init: np.ndarray):
    length = emission.shape[1]
    path = np.ones([self.tags_count, length], dtype=np.int32) * -1
    corr_path = np.zeros([length], dtype=np.int32)
    path_score = np.ones([self.tags_count, length], dtype=np.float64) * np.finfo('f').min
    path_score[:, 0] = transition_init + emission[:, 0]

    for pos in range(1, length):
      for t in range(self.tags_count):
        for prev in range(self.tags_count):
          temp = path_score[prev][pos - 1] + transition[prev][t] + emission[t][pos]
          if temp >= path_score[t][pos]:
            path[t][pos] = prev
            path_score[t][pos] = temp

    # backtrack from the best-scoring final tag
    max_index = np.argmax(path_score[:, -1])
    corr_path[length - 1] = max_index
    for i in range(length - 1, 0, -1):
      max_index = path[max_index][i]
      corr_path[i - 1] = max_index

    return corr_path

  def sentence2indices(self, sentence: str) -> list:
    return [self.dictionary[ch] for ch in sentence]

  def indices2input(self, indices: list) -> np.ndarray:
    res = []
    if isinstance(indices[0], list):
      for idx in indices:
        res.append(self.__indices2input_single(idx))
    else:
      res = self.__indices2input_single(indices)

    return np.array(res, np.int32)

  def __indices2input_single(self, indices: list) -> list:
    # pad the sequence, then slide a window of size skip_left + skip_right + 1
    ext_indices = [STRT_VAL] * self.skip_left
    ext_indices.extend(indices + [END_VAL] * self.skip_right)
    seq = []
    for index in range(self.skip_left, len(ext_indices) - self.skip_right):
      seq.append(ext_indices[index - self.skip_left: index + self.skip_right + 1])

    return seq

  def tags2words(self, sentence: str, tags_seq: np.ndarray) -> list:
    words = []
    word = ''
    for tag_index, tag in enumerate(tags_seq):
      if tag == self.tags_map[TAG_SINGLE]:
        words.append(sentence[tag_index])
      elif tag == self.tags_map[TAG_BEGIN]:
        word = sentence[tag_index]
      elif tag == self.tags_map[TAG_INSIDE]:
        word += sentence[tag_index]
      else:
        words.append(word + sentence[tag_index])
        word = ''
    # handle a sentence whose last tag is I (word still open)
    if word != '':
      words.append(word)

    return words

  def tags2entities(self, sentence: str, tags_seq: np.ndarray, return_start: bool = False):
    entities = []
    entity_starts = []
    entity = ''

    for tag_index, tag in enumerate(tags_seq):
      if tag == 0:
        # tag 0 marks a non-entity character
        continue
      elif tag == 1:
        # tag 1 starts a new entity
        if entity:
          entities.append(entity)
        entity = sentence[tag_index]
        entity_starts.append(tag_index)
      else:
        entity += sentence[tag_index]
    if entity != '':
      entities.append(entity)
    if return_start:
      return entities, entity_starts
    else:
      return entities
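A quick way to sanity-check the viterbi decoder in isolation (not part of the diff; all scores below are made up). The sketch bypasses __init__ with object.__new__ purely so that no training data or checkpoint is needed:

import numpy as np
from dnlp.core.dnn_crf_base import DnnCrfBase

# Test-only trick: skip __init__ so viterbi can run without a data file.
stub = object.__new__(DnnCrfBase)
stub.tags_count = 2

emission = np.array([[2.0, 0.5, 0.1],
                     [0.1, 1.5, 2.5]])  # shape [tags_count, length]
transition = np.array([[0.5, 1.0],
                       [1.0, 0.5]])     # score of moving prev -> next tag
transition_init = np.array([0.3, 0.1])  # score of starting in each tag

# Tag 0 wins at position 0 (2.0 + 0.3 vs 0.1 + 0.1), then the path switches
# to tag 1, where the later emissions dominate.
print(stub.viterbi(emission, transition, transition_init))  # [0 1 1]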

python/dnlp/core/mmtnn.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
# -*- coding: UTF-8 -*-

python/dnlp/core/re_cnn.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
# -*- coding: UTF-8 -*-
