-
Notifications
You must be signed in to change notification settings - Fork 53
/
wechat_utils.py
469 lines (462 loc) · 24.3 KB
/
wechat_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
# -*- coding: utf-8 -*-
"""
Created on Tue Mar 7 12:44:30 2017
@author: Quantum Liu

Report Keras training to WeChat through itchat: `login`, `send_text` and
`send_img` talk to the 'filehelper' account, and `sendmessage` is a Keras
callback that sends progress messages/figures and reacts to text commands.
"""
from keras import __version__ as kv
# Keras major version (Keras 1 names the epoch count 'nb_epoch', later versions 'epochs').
kv = int(kv.split('.')[0])
import platform
# Python major version, used to pick the low-level thread module below.
pv = int(platform.python_version().split('.')[0])
import numpy as np
import scipy.io as sio
import itchat
from keras.callbacks import Callback
import time
import matplotlib
matplotlib.use('Agg')  # non-interactive backend: figures are only saved to files
import matplotlib.pyplot as plt
from math import ceil
from itchat.content import TEXT
if pv > 2:
    import _thread as th
else:
    import thread as th
import os
from os import system
import re
import traceback
# requests' ConnectionError (not the builtin); it shadows the builtin name in this module.
from requests.exceptions import ConnectionError
#==============================================================================
#==============================================================================
# Log-in function: call it first, before anything else in this module.
#登录函数,需要首先调用
#==============================================================================
def login():
    """Log in to WeChat through itchat, then dump the login status.

    itchat's command-line QR-code option `enableCmdQR` is set to 1 on
    Windows and 2 on every other platform; `hotReload=True` enables
    itchat's hot-reload (reuse of a stored session).
    """
    qr_mode = 1 if 'Windows' in platform.system() else 2
    itchat.auto_login(enableCmdQR=qr_mode, hotReload=True)
    itchat.dump_login_status()
#==============================================================================
#
#==============================================================================
def send_text(text, toUserName='filehelper'):
    """Send a text message through itchat (to 'filehelper' by default).

    Connection, NotImplementedError and KeyError failures are printed with
    their traceback and otherwise swallowed, so this function never raises
    for them.

    Args:
        text: message body.
        toUserName: itchat user name of the recipient; defaults to
            'filehelper', the only recipient the original version supported.
    """
    try:
        itchat.send_msg(msg=text, toUserName=toUserName)
    except (ConnectionError, NotImplementedError, KeyError):
        traceback.print_exc()
        print('\nConection error,failed to send the message!\n')
def send_img(filename, toUserName='filehelper'):
    """Send an image file through itchat (to 'filehelper' by default).

    Connection, NotImplementedError and KeyError failures are printed with
    their traceback and otherwise swallowed, so this function never raises
    for them.

    Args:
        filename: path of the image file to send.
        toUserName: itchat user name of the recipient; defaults to
            'filehelper', the only recipient the original version supported.
    """
    try:
        itchat.send_image(filename, toUserName=toUserName)
    except (ConnectionError, NotImplementedError, KeyError):
        traceback.print_exc()
        print('\nConection error,failed to send the figure!\n')
#==============================================================================
#
#==============================================================================
class sendmessage(Callback):
    """Keras callback that reports training to WeChat's 'filehelper' via itchat.

    It sends a message when training starts, when every epoch starts and
    ends (with the epoch's metrics) and when training ends; after every
    epoch it also sends batch/epoch plots. In `on_train_begin` it registers
    an itchat text handler that reacts to commands sent to 'filehelper'
    (the keywords are listed in the *_cmdlist class attributes):

    * 'Stop at:N'                  -- stop after epoch N
    * stop_training_cmdlist        -- stop after the current epoch
    * shut_down_cmdlist            -- save the model (unless a no_save_cmdlist
      keyword is present) and shut the computer down, e.g.
      'Shut down now [test]{120}' saves test.h5 and waits 120 s
    * cancel_cmdlist               -- cancel a scheduled shutdown
    * get_fig_cmdlist              -- send plots, e.g.
      'Show me the figure [loss]{batches}'
    * gpu_cmdlist                  -- send `nvidia-smi -q --display=TYPE`
      output, e.g. 'GPU [MEMORY TEMPERATURE]'
    * prog_cmdlist                 -- send progress and ETA

    Args:
        savelog: when true, the batch and epoch logs are written to .mat
            files at the end of every epoch.
        fexten: stem of the log/figure/model file names; when empty the
            sanitized training start time is used instead.
    """

    # Command keywords: a text sent to 'filehelper' that contains any keyword
    # of a list triggers the corresponding command.
    stop_training_cmdlist = ('Stop now', "That's enough", u'停止训练', u'放弃治疗')
    shut_down_cmdlist = (u'关机', 'Shut down', 'Shut down the computer', u'别浪费电了', u'洗洗睡吧')
    cancel_cmdlist = (u'取消', 'cancel', 'aaaa')
    no_save_cmdlist = (u'不保存模型', "don't save")
    get_fig_cmdlist = (u'获取图表', 'Show me the figure')
    gpu_cmdlist = ('GPU', 'gpu', u'显卡')
    prog_cmdlist = (u'进度', 'Progress')
    # Types accepted by the GPU command; each is passed to `nvidia-smi -q --display=TYPE`.
    # SUPPORTED_CLOCKS and PAGE_RETIREMENT are separate entries so each can be requested.
    type_list = ('MEMORY', 'UTILIZATION', 'ECC', 'TEMPERATURE', 'POWER', 'CLOCK',
                 'COMPUTE', 'PIDS', 'PERFORMANCE', 'SUPPORTED_CLOCKS',
                 'PAGE_RETIREMENT', 'ACCOUNTING')
    # Serializes plotting: pyplot keeps global state and get_fig runs in worker threads.
    _fig_lock = th.allocate_lock()

    def __init__(self, savelog=True, fexten=''):
        # Let the Keras base class set up whatever attributes it defines.
        super(sendmessage, self).__init__()
        self.fexten = fexten if fexten else ''  # stem of log/figure/model file names
        self.savelog = bool(savelog)  # dump logs to .mat files after every epoch

    # ------------------------------------------------------------------
    # Small helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _contains_any(text, keywords):
        """Return True if any keyword is a substring of `text`."""
        return any(k in text for k in keywords)

    def _file_stem(self):
        """File-name stem: `fexten`, or the sanitized training start time."""
        return self.fexten if self.fexten else self.validateTitle(self.localtime)

    def _current_epoch(self):
        """1-based number of the latest started epoch, or 0 before the first one."""
        return self.epoch[-1] + 1 if self.epoch else 0

    def _metric_names(self, logs):
        """Names of the values to record from `logs`.

        Uses params['metrics'] when Keras provides it; otherwise every key of
        `logs` except the per-batch bookkeeping keys 'batch' and 'size'.
        """
        names = self.params.get('metrics')
        if names is None:
            names = [k for k in logs if k not in ('batch', 'size')]
        return names

    def t_send(self, msg, toUserName='filehelper'):
        """Send text `msg`; connection/itchat errors are printed, not raised."""
        try:
            itchat.send_msg(msg=msg, toUserName=toUserName)
        except (ConnectionError, NotImplementedError, KeyError):
            traceback.print_exc()
            print('\nConection error,failed to send the message!\n')

    def t_send_img(self, filename, toUserName='filehelper'):
        """Send image file `filename`; connection/itchat errors are printed, not raised."""
        try:
            itchat.send_image(filename, toUserName=toUserName)
        except (ConnectionError, NotImplementedError, KeyError):
            traceback.print_exc()
            print('\nConection error,failed to send the figure!\n')

    # ------------------------------------------------------------------
    # Shutdown control
    # ------------------------------------------------------------------
    def shutdown(self, sec, save=True, filepath='temp.h5'):
        """Optionally save the model, then schedule a computer shutdown.

        Args:
            sec: delay before shutting down, in seconds. On non-Windows
                systems `shutdown` takes whole minutes, so the delay is
                rounded down to minutes with a minimum of 1.
            save: whether to save the model first.
            filepath: file the model is saved to.
        """
        if save:
            self.model.save(filepath, overwrite=True)
            self.t_send('Command accepted,the model has already been saved,shutting down the computer....', toUserName='filehelper')
        else:
            self.t_send('Command accepted,shutting down the computer....', toUserName='filehelper')
        if 'Windows' in platform.system():
            cmd = 'shutdown -s -t %d' % sec
        else:
            # POSIX shutdown expects the delay as '+minutes'; '-t' is not a delay option there.
            cmd = 'shutdown -h +%d' % max(int(sec / 60), 1)
        th.start_new_thread(system, (cmd,))

    def cancel(self):
        """Cancel a scheduled shutdown."""
        self.t_send('Command accepted,cancel shutting down the computer....', toUserName='filehelper')
        cmd = 'shutdown -a' if 'Windows' in platform.system() else 'shutdown -c'
        th.start_new_thread(system, (cmd,))

    # ------------------------------------------------------------------
    # String helpers
    # ------------------------------------------------------------------
    def GetMiddleStr(self, content, startStr, endStr):
        """Return the text between the first `startStr` and the next `endStr` after it.

        Returns '' when either delimiter is missing.
        """
        start = content.find(startStr)
        if start < 0:
            return ''
        start += len(startStr)
        end = content.find(endStr, start)
        if end < 0:
            return ''
        return content[start:end]

    def validateTitle(self, title):
        """Turn `title` into a file name by dropping the characters /\\:*?"<>| and spaces."""
        rstr = r"[\/\\\:\*\?\"\<\>\|]"  # '/\:*?"<>|'
        new_title = re.sub(rstr, "", title).replace(' ', '')
        return new_title

    # ------------------------------------------------------------------
    # Progress report
    # ------------------------------------------------------------------
    def _batches_per_epoch(self):
        """Best-effort number of batches per epoch from self.params (0.0 if unknown).

        Uses 'steps' when set; otherwise ceil(samples / batch_size), reading
        'samples' or, failing that, Keras 1's 'nb_sample'.
        """
        params = self.params
        steps = params.get('steps')
        if steps:
            return float(steps)
        samples = params.get('samples', params.get('nb_sample'))
        batch_size = params.get('batch_size')
        if samples and batch_size:
            return float(ceil(samples * 1.0 / batch_size))
        return 0.0

    def prog(self):
        """Send (and print) overall and current-epoch progress with ETAs.

        Progress is counted in batches. When the number of batches per
        epoch is unknown, the progress fractions stay at their 0.01 floor.
        """
        per_epoch = self._batches_per_epoch()
        total = per_epoch * self.stopped_epoch
        # +0.01 keeps both fractions non-zero so the ETA formulas never divide by zero.
        prog_total = (self.t_batches / total if total else 0) + 0.01
        prog_epoch = (self.c_batches / per_epoch if per_epoch else 0) + 0.01
        now = time.time()
        elapsed = now - self.train_start
        eta_t = elapsed * ((1 / prog_total) - 1)
        if self.t_epochs:
            # Estimate the current epoch from the mean duration of the finished epochs.
            t_mean = float(sum(self.t_epochs)) / len(self.t_epochs)
            eta_e = t_mean * (1 - prog_epoch)
        else:
            eta_e = elapsed * ((1 / prog_epoch) - 1)
        t_end = time.asctime(time.localtime(now + eta_t))
        e_end = time.asctime(time.localtime(now + eta_e))
        m = ('\nTotal:\nProg:' + str(prog_total * 100.)[:5] + '%\nEpoch:' + str(self._current_epoch())
             + '/' + str(self.stopped_epoch) + '\nETA:' + str(eta_t)[:8] + 'sec\nTrain will be finished at '
             + t_end + '\nCurrent epoch:\nPROG:' + str(prog_epoch * 100.)[:5] + '%\nETA:' + str(eta_e)[:8]
             + 'sec\nCurrent epoch will be finished at ' + e_end)
        self.t_send(msg=m)
        print(m)

    # ------------------------------------------------------------------
    # Figures
    # ------------------------------------------------------------------
    def _plot_logs(self, logs, kind, unit, metrics):
        """Plot the selected series of `logs`, save <stem>_<kind>.jpg and send it.

        Args:
            logs: dict mapping a metric name to its list of recorded values.
            kind: 'batches' or 'epochs'; used in the figure name, subplot
                titles, file name and the message sent after the image.
            unit: x-axis label ('batch' or 'epoch').
            metrics: names to plot. If it contains 'all', every metric is
                plotted. Names missing from `logs` are ignored, and when none
                match, every metric is plotted.

        Nothing is drawn or sent when `logs` is empty.
        """
        available = list(logs)
        if 'all' in metrics:
            names = available
        else:
            names = [k for k in available if k in metrics] or available
        if not names:
            return
        colors = 'rgbyck'
        nb_rows = int(ceil(len(names) / 2.0))
        fig = plt.figure('all_subs_' + kind)
        try:
            for i, k in enumerate(names):
                ax = fig.add_subplot(nb_rows, 2, i + 1)
                # Copy: the training thread may append to the list while we plot.
                data = list(logs[k])
                ax.plot(range(len(data)), data, colors[i % len(colors)] + '-', label=k)
                ax.set_title(k + ' in ' + kind, fontsize=14)
                ax.set_xlabel(unit, fontsize=10)
                ax.set_ylabel(k, fontsize=10)
            filename = self._file_stem() + '_' + kind + '.jpg'
            fig.tight_layout()
            fig.savefig(filename)
        finally:
            plt.close(fig)
        self.t_send_img(filename, toUserName='filehelper')
        time.sleep(.5)
        self.t_send('Sending ' + kind + ' figure', toUserName='filehelper')

    def get_fig(self, level='all', metrics=None):
        """Plot the recorded training logs and send the figure(s) to 'filehelper'.

        The callers in this class run it via th.start_new_thread. It simply
        returns rather than calling th.exit(): returning ends a worker thread
        just the same, and it is also safe if called from the main thread.

        Args:
            level: 'batches', 'epochs' or 'all'. Any other value plots both.
                While no epoch has finished, 'all' and 'epochs' fall back to
                'batches'.
            metrics: metric names to plot. None or ['all'] means every metric.

        Any exception is printed, and 'Failed to send figure' is sent.
        """
        if metrics is None:
            metrics = ['all']
        try:
            with self._fig_lock:
                if not self.logs_epochs and level in ('all', 'epochs'):
                    level = 'batches'
                if level != 'epochs':
                    self._plot_logs(self.logs_batches, 'batches', 'batch', metrics)
                if level != 'batches':
                    self._plot_logs(self.logs_epochs, 'epochs', 'epoch', metrics)
        except Exception:
            traceback.print_exc()
            self.t_send('Failed to send figure', toUserName='filehelper')

    # ------------------------------------------------------------------
    # GPU status
    # ------------------------------------------------------------------
    def gpu_status(self, av_type_list):
        """Send the output of `nvidia-smi -q --display=TYPE` for each TYPE in `av_type_list`.

        Output is trimmed to start at 'Attached GPUs' when present, and all
        spaces are removed.
        """
        for t in av_type_list:
            pipe = os.popen('nvidia-smi -q --display=' + t)
            try:
                content = " ".join(pipe.readlines())
            finally:
                pipe.close()
            index = content.find('Attached GPUs')
            # find() returns -1 when the marker is missing; then keep the whole output.
            s = content[max(index, 0):].replace(' ', '').rstrip('\n')
            self.t_send(s, toUserName='filehelper')
            time.sleep(.5)

    # ------------------------------------------------------------------
    # Keras callback hooks
    # ------------------------------------------------------------------
    def on_train_begin(self, logs=None):
        """Reset the counters and logs, announce the start, and start the itchat command listener."""
        self.epoch = []  # indices of the epochs started so far
        self.t_epochs = []  # wall-clock duration of each finished epoch, in seconds
        self.t_batches = 0  # batches finished since training started
        self.c_batches = 0  # batches finished in the current epoch
        self.logs_batches = {}
        self.logs_epochs = {}
        self.train_start = time.time()
        self.localtime = time.asctime(time.localtime(self.train_start))
        self.mesg = 'Train started at: ' + self.localtime
        self.t_send(self.mesg, toUserName='filehelper')
        self.stopped_epoch = (self.params['epochs'] if kv - 1 else self.params['nb_epoch'])

        @itchat.msg_register(TEXT)
        def manualstop(msg):
            # itchat text handler: only messages sent to 'filehelper' are commands.
            text = msg['Text']
            if msg['ToUserName'] == 'filehelper':
                self._handle_command(text)

        th.start_new_thread(itchat.run, ())

    def _handle_command(self, text):
        """Run every command whose keyword appears in `text` (a message sent to 'filehelper')."""
        print('\n', text, '\n')

        # 'Stop at:N' -> stop training after epoch N, e.g. 'Stop at:8'.
        if 'Stop at' in text:
            found = re.findall(r'\d+', text)
            if found:
                self.stopped_epoch = int(found[0])
                if kv - 1:
                    self.params['epochs'] = self.stopped_epoch
                else:
                    self.params['nb_epoch'] = self.stopped_epoch
                self.t_send('Command accepted,training will be stopped at epoch' + str(self.stopped_epoch), toUserName='filehelper')

        # Stop training once the current epoch finishes.
        if self._contains_any(text, self.stop_training_cmdlist):
            self.model.stop_training = True
            self.t_send('Command accepted,stop training now at epoch' + str(self._current_epoch()), toUserName='filehelper')

        # A cancel message can also contain a shutdown keyword (the Chinese for
        # "cancel shutdown" contains the Chinese "shut down" keyword), so cancel
        # takes precedence and no shutdown is scheduled for it.
        if self._contains_any(text, self.cancel_cmdlist):
            self.cancel()
        elif self._contains_any(text, self.shut_down_cmdlist):
            # e.g. 'Shut down now [test]{120}': save the model as test.h5 and shut
            # down after 120 s; "don't save" in the text skips saving the model.
            save = not self._contains_any(text, self.no_save_cmdlist)
            name = self.GetMiddleStr(text, '[', ']')
            filepath = (name if name else self._file_stem()) + '.h5'
            print('\n', filepath, '\n')
            try:
                sec = int(self.GetMiddleStr(text, '{', '}'))
            except ValueError:
                sec = 0
            # Missing, unparsable or too short (<= 30 s) delays fall back to 120 s.
            if sec <= 30:
                sec = 120
            self.shutdown(sec, save=save, filepath=filepath)

        # Send figures, e.g. 'Show me the figure [loss]{batches}'; the default is
        # all metrics at both levels.
        if self._contains_any(text, self.get_fig_cmdlist):
            metrics = self.GetMiddleStr(text, '[', ']').split() or ['all']
            level = self.GetMiddleStr(text, '{', '}') or 'all'
            if level in ('all', 'epochs', 'batches'):
                th.start_new_thread(self.get_fig, (level, metrics))
            else:
                print("\nGot no level,using default 'all'\n")
                self.t_send("Got no level,using default 'all'", toUserName='filehelper')
                th.start_new_thread(self.get_fig, ('all', metrics))

        # GPU status, e.g. 'GPU [MEMORY TEMPERATURE]'; the default type is MEMORY.
        if self._contains_any(text, self.gpu_cmdlist):
            requested = self.GetMiddleStr(text, '[', ']').split() or ['MEMORY']
            av_type_list = [t.upper() for t in requested if t.upper() in self.type_list]
            if av_type_list:
                self.gpu_status(av_type_list)
            else:
                self.t_send('Available GPU info types: ' + ' '.join(self.type_list), toUserName='filehelper')

        # Progress / ETA report.
        if self._contains_any(text, self.prog_cmdlist):
            try:
                self.prog()
            except Exception:
                traceback.print_exc()

    def on_batch_end(self, batch, logs=None):
        """Record this batch's metric values and advance the batch counters."""
        logs = logs or {}
        for k in self._metric_names(logs):
            if k in logs:
                self.logs_batches.setdefault(k, []).append(logs[k])
        self.c_batches += 1
        self.t_batches += 1

    def on_epoch_begin(self, epoch, logs=None):
        """Remember the epoch index and start time, and announce the epoch."""
        self.t_s = time.time()
        self.epoch.append(epoch)
        self.c_batches = 0
        self.t_send('Epoch' + str(epoch + 1) + '/' + str(self.stopped_epoch) + ' started', toUserName='filehelper')
        self.mesg = ('Epoch:' + str(epoch + 1) + ' ')

    def on_epoch_end(self, epoch, logs=None):
        """Record the epoch metrics, enforce stopped_epoch, save the logs, and send figures and a summary."""
        logs = logs or {}
        for k in self._metric_names(logs):
            if k in logs:
                self.mesg += (k + ': ' + str(logs[k])[:5] + ' ')
                self.logs_epochs.setdefault(k, []).append(logs[k])
        if epoch + 1 >= self.stopped_epoch:
            self.model.stop_training = True
        self.t_epochs.append(time.time() - self.t_s)
        if self.savelog:
            stem = self._file_stem()
            sio.savemat(stem + '_logs_batches' + '.mat', {'log': np.array(self.logs_batches)})
            # Epoch logs get their own file (previously they overwrote the batch log).
            sio.savemat(stem + '_logs_epochs' + '.mat', {'log': np.array(self.logs_epochs)})
        th.start_new_thread(self.get_fig, ())
        self.t_send(self.mesg, toUserName='filehelper')

    def on_train_end(self, logs=None):
        """Announce the epoch at which training stopped."""
        self.t_send('Train stopped at epoch' + str(self._current_epoch()), toUserName='filehelper')