Skip to content

Commit 44ca765

Browse files
authored
Create blog48-01-cnn-dataset.py
1 parent 0194bac commit 44ca765

File tree

1 file changed

+197
-0
lines changed

1 file changed

+197
-0
lines changed
Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
Created on Wed Jul 7 18:54:36 2021
4+
@author: xiuzhang CSDN
5+
参考:刘润森老师博客 推荐大家关注 很厉害的一位CV大佬
6+
https://maoli.blog.csdn.net/article/details/117688738
7+
"""
8+
import numpy as np
9+
import pandas as pd
10+
from IPython.display import display
11+
import csv
12+
from PIL import Image
13+
from scipy.ndimage import rotate
14+
15+
#----------------------------------------------------------------
# Step 1: load the data
#----------------------------------------------------------------
# Training images and labels (file names suggest 13440 rows x 1024 pixel columns).
letters_training_images_file_path = "dataset/csvTrainImages 13440x1024.csv"
letters_training_labels_file_path = "dataset/csvTrainLabel 13440x1.csv"
# Testing images and labels (file names suggest 3360 rows).
letters_testing_images_file_path = "dataset/csvTestImages 3360x1024.csv"
letters_testing_labels_file_path = "dataset/csvTestLabel 3360x1.csv"

# header=None: treat the first CSV row as data rather than as column names.
training_letters_images = pd.read_csv(letters_training_images_file_path, header=None)
training_letters_labels = pd.read_csv(letters_training_labels_file_path, header=None)
testing_letters_images = pd.read_csv(letters_testing_images_file_path, header=None)
testing_letters_labels = pd.read_csv(letters_testing_labels_file_path, header=None)

# Quick sanity report: sample counts, first rows, and the distinct label values.
print("%d个32x32像素的训练阿拉伯字母图像" % training_letters_images.shape[0])
print("%d个32x32像素的测试阿拉伯字母图像" % testing_letters_images.shape[0])
print(training_letters_images.head())
print(np.unique(training_letters_labels))
34+
35+
36+
#----------------------------------------------------------------
37+
# 第二步 数值转换为图像特征
38+
#----------------------------------------------------------------
39+
#原始数据集被反射使用np.flip翻转它 通过rotate旋转从而获得更好的图像
40+
def convert_values_to_image(image_values, display=False):
    """Convert one flat row of 1024 pixel values into a 32x32 PIL image.

    The values are reshaped to 32x32 ``uint8``, the row order is reversed
    (``np.flip`` on axis 0) and the result is rotated by -90 degrees. Per the
    original author's note, the raw dataset is stored mirrored and this
    flip + rotation yields a correctly oriented image.

    Args:
        image_values: Array-like holding exactly 1024 values (for example one
            row of the images DataFrame).
        display: When truthy, open the image with ``Image.show()``. The name
            shadows IPython's ``display`` inside this function only; it is
            kept so existing keyword callers continue to work.

    Returns:
        The converted ``PIL.Image.Image``.

    Raises:
        ValueError: If ``image_values`` cannot be reshaped to 32x32.
    """
    # Flat vector -> 32x32 grid of 8-bit pixels.
    pixels = np.asarray(image_values).reshape(32, 32).astype('uint8')
    # Undo the dataset's mirrored storage: reverse rows, then rotate.
    pixels = np.flip(pixels, 0)
    pixels = rotate(pixels, -90)

    new_image = Image.fromarray(pixels)
    if display:
        new_image.show()
    return new_image
52+
53+
# Preview the first training sample as an image (opens an image viewer).
convert_values_to_image(training_letters_images.loc[0], True)


#----------------------------------------------------------------
# Step 3: normalise the images
#----------------------------------------------------------------
# Pixels are converted to float32 and divided by 255; labels become int32
# NumPy arrays. Note that the label names are rebound from DataFrame to ndarray.
training_letters_images_scaled = training_letters_images.values.astype('float32') / 255
training_letters_labels = training_letters_labels.values.astype('int32')
testing_letters_images_scaled = testing_letters_images.values.astype('float32') / 255
testing_letters_labels = testing_letters_labels.values.astype('int32')
print("Training images of letters after scaling")
print(training_letters_images_scaled.shape)
print(training_letters_images_scaled[0:5])
66+
67+
68+
#----------------------------------------------------------------
# Step 4: one-hot encode the labels
#----------------------------------------------------------------
import keras
from keras.utils import to_categorical

number_of_classes = 28
# to_categorical expects class indices starting at 0; the "-1" shifts the
# labels down by one (presumably from 1..28 to 0..27 -- confirm against the CSV).
training_letters_labels_encoded = to_categorical(training_letters_labels - 1,
                                                 num_classes=number_of_classes)
testing_letters_labels_encoded = to_categorical(testing_letters_labels - 1,
                                                num_classes=number_of_classes)
print(training_letters_labels)
print(training_letters_labels_encoded)
print(training_letters_images_scaled.shape)
# Expected shape per the original author: (13440, 1024)
82+
83+
84+
#----------------------------------------------------------------
# Step 5: reshape for the network
#----------------------------------------------------------------
# Each flat row becomes a 32x32 image with a single channel (32x32x1);
# -1 lets NumPy infer the number of samples.
training_letters_images_scaled = training_letters_images_scaled.reshape([-1, 32, 32, 1])
testing_letters_images_scaled = testing_letters_images_scaled.reshape([-1, 32, 32, 1])
print(training_letters_images_scaled.shape,
      training_letters_labels_encoded.shape,
      testing_letters_images_scaled.shape,
      testing_letters_labels_encoded.shape)
# Expected output per the original author:
# (13440, 32, 32, 1) (13440, 28) (3360, 32, 32, 1) (3360, 28)
95+
96+
97+
#----------------------------------------------------------------
98+
# 第六步 CNN模型设计
99+
#----------------------------------------------------------------
100+
from keras.models import Sequential
101+
from keras.layers import Conv2D, MaxPooling2D, GlobalAveragePooling2D, BatchNormalization, Dropout, Dense
102+
103+
#定义模型
104+
def create_model(optimizer='adam', kernel_initializer='he_normal', activation='relu',
                 input_shape=(32, 32, 1), num_classes=28):
    """Build and compile the CNN classifier.

    The network stacks four identical stages -- Conv2D (3x3, 'same' padding),
    BatchNormalization, 2x2 MaxPooling and Dropout(0.2) -- with 16, 32, 64 and
    128 filters respectively, followed by global average pooling and a softmax
    output layer.

    Args:
        optimizer: Optimizer name or instance passed to ``model.compile``.
        kernel_initializer: Initializer used by every convolution layer.
        activation: Activation used by every convolution layer.
        input_shape: Shape of one input sample; defaults to the 32x32x1
            images this script uses.
        num_classes: Width of the softmax output layer; defaults to 28.

    Returns:
        A compiled ``Sequential`` model using categorical cross-entropy loss
        and reporting accuracy.
    """
    model = Sequential()

    # Four convolution stages that differ only in filter count.
    for stage, filters in enumerate((16, 32, 64, 128)):
        conv_kwargs = dict(filters=filters, kernel_size=3, padding='same',
                           kernel_initializer=kernel_initializer,
                           activation=activation)
        if stage == 0:
            # Only the first layer declares the input shape.
            conv_kwargs['input_shape'] = input_shape
        model.add(Conv2D(**conv_kwargs))
        model.add(BatchNormalization())
        model.add(MaxPooling2D(pool_size=2))
        model.add(Dropout(0.2))

    model.add(GlobalAveragePooling2D())

    # Fully connected output layer: one probability per class.
    model.add(Dense(num_classes, activation='softmax'))

    # Loss / metric definition.
    model.compile(loss='categorical_crossentropy', metrics=['accuracy'], optimizer=optimizer)
    return model
137+
138+
# Instantiate the model (Adam optimizer, uniform kernel initialisation, ReLU)
# and print a layer-by-layer summary.
model = create_model(optimizer='Adam', kernel_initializer='uniform', activation='relu')
model.summary()
141+
142+
143+
#----------------------------------------------------------------
# Step 7: draw the model architecture
#----------------------------------------------------------------
# plot_model is exposed from keras.utils; the keras.utils.vis_utils path is
# not present in newer Keras releases, so it is used only as a fallback.
try:
    from keras.utils import plot_model
except ImportError:
    from keras.utils.vis_utils import plot_model
from IPython.display import Image as IPythonImage

# Render the architecture to model.png (with tensor shapes) and show it inline.
plot_model(model, to_file="model.png", show_shapes=True)
display(IPythonImage('model.png'))
151+
152+
153+
#----------------------------------------------------------------
# Step 8: train the model
#----------------------------------------------------------------
from keras.callbacks import ModelCheckpoint

# Write a checkpoint to weights.hdf5 only when the monitored quantity improves.
checkpointer = ModelCheckpoint(filepath='weights.hdf5',
                               verbose=1,
                               save_best_only=True)
# NOTE(review): the test split is passed as validation data, so the "best"
# checkpoint is selected on the test set; consider a separate validation
# split if an unbiased test score is needed.
history = model.fit(training_letters_images_scaled,
                    training_letters_labels_encoded,
                    validation_data=(testing_letters_images_scaled,
                                     testing_letters_labels_encoded),
                    epochs=15,
                    batch_size=20,
                    verbose=1,
                    callbacks=[checkpointer])
print(history)
170+
171+
#----------------------------------------------------------------
172+
# 第九步 绘制图形
173+
#----------------------------------------------------------------
174+
import matplotlib.pyplot as plt
175+
176+
def _plot_curves(train_values, val_values, legend_labels, ylabel, title):
    """Draw one training-vs-validation curve pair on a new 8x6 figure."""
    plt.figure(figsize=[8,6])
    plt.plot(train_values,'r',linewidth=3.0)
    plt.plot(val_values,'b',linewidth=3.0)
    plt.legend(legend_labels,fontsize=18)
    plt.xlabel('Epochs ',fontsize=16)
    plt.ylabel(ylabel,fontsize=16)
    plt.title(title,fontsize=16)


def plot_loss_accuracy(history):
    """Plot loss and accuracy curves from a Keras training history.

    Two figures are created: training vs. validation loss, and training vs.
    validation accuracy (training in red, validation in blue). The figures
    are not shown or saved here; the caller decides that.

    Args:
        history: Object with a ``history`` dict containing ``loss``,
            ``val_loss`` and either ``accuracy``/``val_accuracy`` or the
            legacy ``acc``/``val_acc`` keys, as returned by ``model.fit``.

    Raises:
        KeyError: If the required metrics are missing from ``history``.
    """
    metrics = history.history

    # Loss
    _plot_curves(metrics['loss'], metrics['val_loss'],
                 ['Training loss', 'Validation Loss'], 'Loss', 'Loss Curves')

    # Accuracy -- older Keras releases record it under 'acc' / 'val_acc'.
    acc_key = 'accuracy' if 'accuracy' in metrics else 'acc'
    _plot_curves(metrics[acc_key], metrics['val_' + acc_key],
                 ['Training Accuracy', 'Validation Accuracy'], 'Accuracy', 'Accuracy Curves')
194+
195+
plot_loss_accuracy(history)
# Without an explicit show() the figures never appear when this file is run
# as a plain script (inline/IPython backends display them regardless).
plt.show()
196+
197+

0 commit comments

Comments
 (0)