Skip to content

Commit 76a0b34

Browse files
authored
Add files via upload
1 parent 279b7f8 commit 76a0b34

File tree

1 file changed

+252
-0
lines changed

1 file changed

+252
-0
lines changed
Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
Created on Fri Oct 29 13:42:38 2021
4+
@author: xiuzhang
5+
"""
6+
import os
7+
import re
8+
import cv2
9+
import pandas as pd
10+
import numpy as np
11+
from PIL import Image
12+
import albumentations as A
13+
from matplotlib import pyplot as plt
14+
from albumentations.pytorch.transforms import ToTensorV2
15+
16+
import torch
17+
import torchvision
18+
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
19+
from torchvision.models.detection import FasterRCNN
20+
from torchvision.models.detection.rpn import AnchorGenerator
21+
from torch.utils.data import DataLoader, Dataset
22+
from torch.utils.data.sampler import SequentialSampler
23+
24+
from dataset import WheatDataset
25+
26+
#-----------------------------------------------------------------------------
27+
#第一步 函数定义
28+
#----------------------------------------------------------------------------
29+
#提取box的四个坐标
30+
def expand_bbox(x):
    """Parse a bounding-box string into its numeric components.

    Args:
        x: Text such as ``"[834.0, 222.0, 56.0, 36.0]"``. Any non-string
            value (for example the float NaN pandas uses for an empty
            cell) is treated as "no box".

    Returns:
        A float ``numpy.ndarray`` with every number found in ``x``, in
        order of appearance, or four ``-1.0`` values when nothing could
        be parsed.
    """
    # re.findall raises TypeError on non-string input, so route such
    # values to the same sentinel that an empty string produces.
    numbers = re.findall(r"([0-9]+[.]?[0-9]*)", x) if isinstance(x, str) else []
    if not numbers:
        return np.full(4, -1.0)
    # Always hand back one type/dtype so callers can stack the rows directly.
    return np.array(numbers, dtype=float)
35+
36+
#训练图像增强 Albumentations
37+
def get_train_transform():
    """Build the Albumentations pipeline applied to training samples.

    Returns:
        An ``A.Compose`` that applies ``A.Flip`` with probability 0.5 and
        then ``ToTensorV2``. Bounding boxes are declared in ``pascal_voc``
        format with their class ids carried in the ``labels`` field.
    """
    return A.Compose([
        # Pass the probability by keyword. In Albumentations releases whose
        # signature is Flip(always_apply=False, p=0.5), a positional 0.5
        # binds to always_apply and the flip is applied to every sample.
        A.Flip(p=0.5),
        ToTensorV2(p=1.0)
    ], bbox_params={'format': 'pascal_voc', 'label_fields': ['labels']})
42+
43+
#验证图像增强
44+
def get_valid_transform():
    """Build the Albumentations pipeline applied to validation samples.

    Returns:
        An ``A.Compose`` containing only ``ToTensorV2``. Bounding boxes
        are declared in ``pascal_voc`` format with their class ids
        carried in the ``labels`` field.
    """
    box_settings = {'format': 'pascal_voc', 'label_fields': ['labels']}
    steps = [ToTensorV2(p=1.0)]
    return A.Compose(steps, bbox_params=box_settings)
48+
49+
def collate_fn(batch):
    """Regroup a batch of per-sample tuples into per-field tuples.

    Args:
        batch: Iterable of equally structured samples, e.g.
            ``[(image_a, target_a), (image_b, target_b)]``.

    Returns:
        A tuple with one entry per field, each entry being a tuple of
        that field's values across the batch.
    """
    per_field = zip(*batch)
    return tuple(per_field)
51+
52+
#-----------------------------------------------------------------------------
53+
#第二步 定义变量并读取数据
54+
#-----------------------------------------------------------------------------
55+
# Input locations: the CSV sits directly under DIR_INPUT, next to the
# train/ and test/ sub-directories.
DIR_INPUT = 'data'
DIR_TRAIN = f'{DIR_INPUT}/train'
DIR_TEST = f'{DIR_INPUT}/test'
train_df = pd.read_csv(f'{DIR_INPUT}/train.csv')
print(train_df.shape)

# Create the four box columns first so the multi-column assignment below
# targets columns that already exist.
train_df['x'] = -1
train_df['y'] = -1
train_df['w'] = -1
train_df['h'] = -1

# Split each textual `bbox` entry into the x / y / w / h columns.
train_df[['x', 'y', 'w', 'h']] = np.stack(train_df['bbox'].apply(lambda x: expand_bbox(x)))
train_df.drop(columns=['bbox'], inplace=True)
# Cast with the builtin float: the np.float alias was deprecated in
# NumPy 1.20 and removed in 1.24, where it raises AttributeError.
train_df['x'] = train_df['x'].astype(float)
train_df['y'] = train_df['y'].astype(float)
train_df['w'] = train_df['w'].astype(float)
train_df['h'] = train_df['h'].astype(float)

# Split by image id so every box of an image lands in the same subset:
# the last 665 unique ids form the validation set, the rest are for training.
image_ids = train_df['image_id'].unique()
valid_ids = image_ids[-665:]
train_ids = image_ids[:-665]
valid_df = train_df[train_df['image_id'].isin(valid_ids)]
train_df = train_df[train_df['image_id'].isin(train_ids)]
print(valid_df.shape, train_df.shape)
print(train_df.head())
82+
83+
"""
84+
(147793, 5)
85+
(25006, 8) (122787, 8)
86+
image_id width height source x y w h
87+
0 b6ab77fd7 1024 1024 usask_1 834.0 222.0 56.0 36.0
88+
1 b6ab77fd7 1024 1024 usask_1 226.0 548.0 130.0 58.0
89+
2 b6ab77fd7 1024 1024 usask_1 377.0 504.0 74.0 160.0
90+
3 b6ab77fd7 1024 1024 usask_1 834.0 95.0 109.0 107.0
91+
4 b6ab77fd7 1024 1024 usask_1 26.0 144.0 124.0 117.0
92+
"""
93+
94+
#-----------------------------------------------------------------------------
95+
#第三步 加载数据
96+
#-----------------------------------------------------------------------------
97+
# Both datasets read images from DIR_TRAIN; they differ only in the rows
# they receive and in the transform pipeline.
train_dataset = WheatDataset(train_df, DIR_TRAIN, get_train_transform())
valid_dataset = WheatDataset(valid_df, DIR_TRAIN, get_valid_transform())

# Loader settings shared by the training and validation loaders.
loader_options = dict(
    batch_size=2,
    shuffle=False,
    num_workers=0,
    collate_fn=collate_fn,
)

train_data_loader = DataLoader(train_dataset, **loader_options)
valid_data_loader = DataLoader(valid_dataset, **loader_options)
115+
116+
#-----------------------------------------------------------------------------
117+
#第四步 数据可视化
118+
#-----------------------------------------------------------------------------
119+
#提取训练数据和类别
120+
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Pull the first training batch and move its tensors to the chosen device.
images, targets, image_ids = next(iter(train_data_loader))
images = [image.to(device) for image in images]
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

# Boxes of the first sample as integers for drawing.
boxes = targets[0]['boxes'].cpu().numpy().astype(np.int32)
# Reorder the first image's axes from (0, 1, 2) to (1, 2, 0) for plotting;
# presumably channel-first to channel-last -- TODO confirm against WheatDataset.
sample = images[0].permute(1, 2, 0).cpu().numpy()

fig, ax = plt.subplots(1, 1, figsize=(10, 8))

# Draw one rectangle and one text label per box.
for box in boxes:
    cv2.rectangle(sample, (box[0], box[1]), (box[2], box[3]), (255, 0, 0), 3)
    ax.text(
        box[0],
        box[1] - 2,
        'wheat',
        bbox=dict(facecolor='blue', alpha=0.5),
        fontsize=12,
        color='white',
    )

ax.set_axis_off()
ax.imshow(sample)
145+
#plt.show()
146+
147+
#-----------------------------------------------------------------------------
148+
#第五步 模型构建
149+
#-----------------------------------------------------------------------------
150+
# Two classes for the detector head: wheat plus background.
num_classes = 2
# One training epoch; `itr` is the iteration counter and starts at 1.
num_epochs, itr = 1, 1
# No learning-rate scheduler is configured.
lr_scheduler = None
154+
155+
class Averager:
    """Accumulate values and report their running arithmetic mean."""

    def __init__(self):
        # Running sum of everything passed to send() since the last reset.
        self.current_total = 0.0
        # Number of send() calls since the last reset (kept as a float).
        self.iterations = 0.0

    def send(self, value):
        """Add ``value`` to the running total and count one more sample."""
        self.current_total = self.current_total + value
        self.iterations = self.iterations + 1

    @property
    def value(self):
        """Mean of the values sent since the last reset, or 0 if there are none."""
        if self.iterations != 0:
            return 1.0 * self.current_total / self.iterations
        return 0

    def reset(self):
        """Discard the accumulated total and sample count."""
        self.current_total, self.iterations = 0.0, 0.0
174+
175+
#load a model pre-trained on COCO
176+
# Start from torchvision's Faster R-CNN (ResNet-50 + FPN) with pretrained weights.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

# Read the input width of the existing classification layer, then replace
# the box predictor with a new head sized for `num_classes`.
in_features = model.roi_heads.box_predictor.cls_score.in_features
new_head = FastRCNNPredictor(in_features, num_classes)
model.roi_heads.box_predictor = new_head

# Move the model to the target device and optimise only trainable parameters.
model.to(device)
params = [weight for weight in model.parameters() if weight.requires_grad]
optimizer = torch.optim.SGD(
    params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005,
)
# Optional step decay, left disabled:
#lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
189+
190+
loss_hist = Averager()
print("Start training....")

# Training iterations
for epoch in range(num_epochs):
    # Start every epoch with a fresh running-loss average.
    loss_hist.reset()

    for images, targets, image_ids in train_data_loader:
        # Move the batch to the device.
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        # Cast every target's boxes to float before passing them to the model.
        for t in targets:
            t['boxes'] = t['boxes'].float()

        # Called with targets, the model returns a dict of losses; sum them.
        loss_dict = model(images, targets)
        losses = sum(loss_dict.values())
        loss_value = losses.item()
        loss_hist.send(loss_value)
        print("loss is :",loss_value)

        # Standard optimisation step.
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        # Progress report every 50 iterations; `itr` is not reset between epochs.
        if not itr % 50:
            print(f"Iteration #{itr}/{len(train_data_loader)} loss: {loss_value}")
        itr += 1

    # Update the learning rate once per epoch when a scheduler is configured.
    if lr_scheduler is not None:
        lr_scheduler.step()
    print(f"Epoch #{epoch} loss: {loss_hist.value}")

# Save the trained weights.
torch.save(model.state_dict(), 'fasterrcnn_resnet50_fpn.pth')
print("Next Test....")
224+
225+
#-----------------------------------------------------------------------------
226+
#第六步 模型测试
227+
#-----------------------------------------------------------------------------
228+
# Pull the first validation batch and move its tensors to the device.
images, targets, image_ids = next(iter(valid_data_loader))
images = list(img.to(device) for img in images)
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
# Ground-truth boxes of the first sample as integers for drawing.
boxes = targets[0]['boxes'].cpu().numpy().astype(np.int32)
# Reorder the first image's axes from (0, 1, 2) to (1, 2, 0) for plotting;
# presumably channel-first to channel-last -- TODO confirm against WheatDataset.
sample = images[0].permute(1, 2, 0).cpu().numpy()

model.eval()
cpu_device = torch.device("cpu")

# Inference only: run the forward pass without autograd so no graph is
# built and no activations are kept for a backward pass that never runs.
with torch.no_grad():
    outputs = model(images)
outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]

fig, ax = plt.subplots(1, 1, figsize=(16, 8))
# NOTE(review): the rectangles below are the ground-truth `boxes` taken from
# `targets`; `outputs` is computed above but never drawn -- confirm intent.
for box in boxes:
    cv2.rectangle(sample,
                  (box[0], box[1]),
                  (box[2], box[3]),
                  (220, 0, 0), 3)

ax.set_axis_off()
ax.imshow(sample)
plt.show()
249+
250+
251+
252+

0 commit comments

Comments
 (0)