Skip to content

Commit fc15308

Browse files
authored
Add files via upload
1 parent 655da23 commit fc15308

File tree

1 file changed

+97
-0
lines changed

1 file changed

+97
-0
lines changed
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# -*- coding: utf-8 -*-
2+
import os
3+
import cv2
4+
import numpy as np
5+
#from sklearn.cross_validation import train_test_split
6+
from sklearn.model_selection import train_test_split
7+
from sklearn.metrics import confusion_matrix, classification_report
8+
9+
#----------------------------------------------------------------------------------
10+
# 第一步 切分训练集和测试集
11+
#----------------------------------------------------------------------------------
12+
13+
X = [] #定义图像名称
14+
Y = [] #定义图像分类类标
15+
Z = [] #定义图像像素
16+
17+
for i in range(0, 10):
18+
#遍历文件夹,读取图片
19+
for f in os.listdir("photo/%s" % i):
20+
#获取图像名称
21+
X.append("photo//" +str(i) + "//" + str(f))
22+
#获取图像类标即为文件夹名称
23+
Y.append(i)
24+
25+
X = np.array(X)
26+
Y = np.array(Y)
27+
28+
#随机率为100% 选取其中的30%作为测试集
29+
X_train, X_test, y_train, y_test = train_test_split(X, Y,
30+
test_size=0.3, random_state=1)
31+
32+
print(len(X_train), len(X_test), len(y_train), len(y_test))
33+
34+
#----------------------------------------------------------------------------------
35+
# 第二步 图像读取及转换为像素直方图
36+
#----------------------------------------------------------------------------------
37+
38+
#训练集
39+
XX_train = []
40+
for i in X_train:
41+
#读取图像
42+
#print i
43+
image = cv2.imread(i)
44+
45+
#图像像素大小一致
46+
img = cv2.resize(image, (256,256),
47+
interpolation=cv2.INTER_CUBIC)
48+
49+
#计算图像直方图并存储至X数组
50+
hist = cv2.calcHist([img], [0,1], None,
51+
[256,256], [0.0,255.0,0.0,255.0])
52+
53+
XX_train.append(((hist/255).flatten()))
54+
55+
#测试集
56+
XX_test = []
57+
for i in X_test:
58+
#读取图像
59+
#print i
60+
image = cv2.imread(i)
61+
62+
#图像像素大小一致
63+
img = cv2.resize(image, (256,256),
64+
interpolation=cv2.INTER_CUBIC)
65+
66+
#计算图像直方图并存储至X数组
67+
hist = cv2.calcHist([img], [0,1], None,
68+
[256,256], [0.0,255.0,0.0,255.0])
69+
70+
XX_test.append(((hist/255).flatten()))
71+
72+
#----------------------------------------------------------------------------------
73+
# 第三步 基于决策树的图像分类处理
74+
#----------------------------------------------------------------------------------
75+
76+
from sklearn.tree import DecisionTreeClassifier
77+
clf = DecisionTreeClassifier().fit(XX_train, y_train)
78+
predictions_labels = clf.predict(XX_test)
79+
80+
print(u'预测结果:')
81+
print(predictions_labels)
82+
83+
print(u'算法评价:')
84+
print((classification_report(y_test, predictions_labels)))
85+
86+
#输出前10张图片及预测结果
87+
k = 0
88+
while k<10:
89+
#读取图像
90+
print(X_test[k])
91+
image = cv2.imread(X_test[k])
92+
print(predictions_labels[k])
93+
#显示图像
94+
cv2.imshow("img", image)
95+
cv2.waitKey(0)
96+
cv2.destroyAllWindows()
97+
k = k + 1

0 commit comments

Comments
 (0)