1+ # -*- coding: utf-8 -*-
2+ """
3+ Created on Fri Oct 29 13:42:38 2021
4+ @author: xiuzhang
5+ """
6+ import os
7+ import re
8+ import cv2
9+ import pandas as pd
10+ import numpy as np
11+ from PIL import Image
12+ import albumentations as A
13+ from matplotlib import pyplot as plt
14+ from albumentations .pytorch .transforms import ToTensorV2
15+
16+ import torch
17+ import torchvision
18+ from torchvision .models .detection .faster_rcnn import FastRCNNPredictor
19+ from torchvision .models .detection import FasterRCNN
20+ from torchvision .models .detection .rpn import AnchorGenerator
21+ from torch .utils .data import DataLoader , Dataset
22+ from torch .utils .data .sampler import SequentialSampler
23+
24+ from dataset import WheatDataset
25+
26+ #-----------------------------------------------------------------------------
27+ #第一步 函数定义
28+ #----------------------------------------------------------------------------
29+ #提取box的四个坐标
def expand_bbox(x):
    """Pull the numeric fields out of a bounding-box string.

    Args:
        x: Text containing the box numbers, e.g. "[834.0, 222.0, 56.0, 36.0]".

    Returns:
        A NumPy array holding every matched number as a string, in order of
        appearance. When the text contains no number at all, the plain list
        ``[-1, -1, -1, -1]`` is returned instead.
    """
    numbers = re.findall("([0-9]+[.]?[0-9]*)", x)
    if not numbers:
        # Placeholder for rows without a parsable box.
        return [-1, -1, -1, -1]
    return np.array(numbers)
35+
36+ #训练图像增强 Albumentations
def get_train_transform():
    """Build the Albumentations pipeline applied to training samples.

    Returns:
        An ``A.Compose`` that flips the image (together with its boxes) with
        probability 0.5 and then converts the image to a tensor. Boxes are
        declared in ``pascal_voc`` format and their class ids are read from
        the ``labels`` field.
    """
    return A.Compose([
        # Pass the probability by keyword. Given positionally, 0.5 binds to
        # the transform's first parameter, which in albumentations 1.x is
        # ``always_apply`` -- a truthy value that makes the flip unconditional.
        A.Flip(p=0.5),
        ToTensorV2(p=1.0),
    ], bbox_params={'format': 'pascal_voc', 'label_fields': ['labels']})
42+
43+ #验证图像增强
def get_valid_transform():
    """Build the Albumentations pipeline applied to validation samples.

    Returns:
        An ``A.Compose`` whose only step converts the image to a tensor.
        Boxes are declared in ``pascal_voc`` format and their class ids are
        read from the ``labels`` field.
    """
    steps = [ToTensorV2(p=1.0)]
    box_settings = {'format': 'pascal_voc', 'label_fields': ['labels']}
    return A.Compose(steps, bbox_params=box_settings)
48+
def collate_fn(batch):
    """Regroup a batch of per-sample tuples into per-field tuples.

    Args:
        batch: Iterable of samples, each itself an iterable of fields.

    Returns:
        A tuple with one entry per field; entry ``i`` is a tuple holding
        field ``i`` of every sample. An empty batch gives an empty tuple.
    """
    per_field = zip(*batch)
    return tuple(per_field)
51+
52+ #-----------------------------------------------------------------------------
53+ #第二步 定义变量并读取数据
54+ #-----------------------------------------------------------------------------
# Input locations. The path pieces are joined without any whitespace so that
# they resolve to "data/train", "data/test" and "data/train.csv".
DIR_INPUT = 'data'
DIR_TRAIN = f'{DIR_INPUT}/train'
DIR_TEST = f'{DIR_INPUT}/test'
train_df = pd.read_csv(f'{DIR_INPUT}/train.csv')
print(train_df.shape)

# Create the four coordinate columns before filling them in one assignment.
train_df['x'] = -1
train_df['y'] = -1
train_df['w'] = -1
train_df['h'] = -1

# Parse the four box values out of the textual 'bbox' column.
train_df[['x', 'y', 'w', 'h']] = np.stack(train_df['bbox'].apply(expand_bbox))
train_df.drop(columns=['bbox'], inplace=True)

# expand_bbox yields strings, so convert to floats. The builtin ``float`` is
# used because the ``np.float`` alias was removed in NumPy 1.24.
train_df['x'] = train_df['x'].astype(float)
train_df['y'] = train_df['y'].astype(float)
train_df['w'] = train_df['w'].astype(float)
train_df['h'] = train_df['h'].astype(float)

# Split by image id: the last 665 unique ids form the validation set and all
# earlier ids form the training set.
image_ids = train_df['image_id'].unique()
valid_ids = image_ids[-665:]
train_ids = image_ids[:-665]
valid_df = train_df[train_df['image_id'].isin(valid_ids)]
train_df = train_df[train_df['image_id'].isin(train_ids)]
print(valid_df.shape, train_df.shape)
print(train_df.head())
82+
83+ """
84+ (147793, 5)
85+ (25006, 8) (122787, 8)
86+ image_id width height source x y w h
87+ 0 b6ab77fd7 1024 1024 usask_1 834.0 222.0 56.0 36.0
88+ 1 b6ab77fd7 1024 1024 usask_1 226.0 548.0 130.0 58.0
89+ 2 b6ab77fd7 1024 1024 usask_1 377.0 504.0 74.0 160.0
90+ 3 b6ab77fd7 1024 1024 usask_1 834.0 95.0 109.0 107.0
91+ 4 b6ab77fd7 1024 1024 usask_1 26.0 144.0 124.0 117.0
92+ """
93+
94+ #-----------------------------------------------------------------------------
95+ #第三步 加载数据
96+ #-----------------------------------------------------------------------------
# Wrap both splits in datasets. Validation images are read from DIR_TRAIN
# too, because the validation rows were split off the training table.
train_dataset = WheatDataset(train_df, DIR_TRAIN, get_train_transform())
valid_dataset = WheatDataset(valid_df, DIR_TRAIN, get_valid_transform())

# Settings shared by both loaders: batches of two samples, loaded in the main
# process, regrouped per field by collate_fn.
# NOTE(review): shuffling is disabled for the training loader as well --
# confirm this is intended.
_loader_options = dict(
    batch_size=2,
    shuffle=False,
    num_workers=0,
    collate_fn=collate_fn,
)

train_data_loader = DataLoader(train_dataset, **_loader_options)
valid_data_loader = DataLoader(valid_dataset, **_loader_options)
115+
116+ #-----------------------------------------------------------------------------
117+ #第四步 数据可视化
118+ #-----------------------------------------------------------------------------
119+ #提取训练数据和类别
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Take the first training batch and move its tensors to the chosen device.
images, targets, image_ids = next(iter(train_data_loader))
images = [image.to(device) for image in images]
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

# Boxes of the first sample as integer pixel coordinates.
boxes = targets[0]['boxes'].cpu().numpy().astype(np.int32)

# permute() only re-strides the tensor, so the resulting array is not
# C-contiguous. OpenCV drawing functions need a contiguous buffer to draw in
# place, hence the explicit contiguous copy.
sample = np.ascontiguousarray(images[0].permute(1, 2, 0).cpu().numpy())

fig, ax = plt.subplots(1, 1, figsize=(10, 8))

# Draw each wheat box and label it.
for box in boxes:
    # Plain Python ints are accepted by every OpenCV build, unlike NumPy
    # integer scalars.
    x_min, y_min, x_max, y_max = (int(v) for v in box[:4])
    cv2.rectangle(sample,
                  (x_min, y_min),
                  (x_max, y_max),
                  (255, 0, 0), 3)

    ax.text(x_min,
            y_min - 2,
            '{:s}'.format('wheat'),
            bbox=dict(facecolor='blue', alpha=0.5),
            fontsize=12,
            color='white')

ax.set_axis_off()
ax.imshow(sample)
# plt.show()
146+
147+ #-----------------------------------------------------------------------------
148+ #第五步 模型构建
149+ #-----------------------------------------------------------------------------
# Training configuration.
num_epochs = 1       # number of epochs to run
num_classes = 2      # 1 class (wheat) + background
itr = 1              # running iteration counter, starts at 1
lr_scheduler = None  # no learning-rate scheduler is configured
154+
class Averager:
    """Running mean of the values passed to ``send``.

    Attributes are ``current_total`` (sum of the values received so far) and
    ``iterations`` (how many values were received).
    """

    def __init__(self):
        # A fresh averager starts in the same state reset() produces.
        self.reset()

    def send(self, value):
        """Add ``value`` to the total and count one more iteration."""
        self.iterations += 1
        self.current_total += value

    @property
    def value(self):
        """Mean of the values received so far, or 0 when there are none."""
        if not self.iterations:
            return 0
        return 1.0 * self.current_total / self.iterations

    def reset(self):
        """Forget every value received so far."""
        self.current_total = 0.0
        self.iterations = 0.0
174+
175+ #load a model pre-trained on COCO
# Load a model pre-trained on COCO.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

# Replace the pre-trained box predictor with a new one sized for num_classes;
# the new head keeps the input feature count of the existing classifier.
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

# Move the model to the target device and optimize only the parameters that
# require gradients.
model.to(device)
params = [weight for weight in model.parameters() if weight.requires_grad]
_sgd_settings = dict(lr=0.005, momentum=0.9, weight_decay=0.0005)
optimizer = torch.optim.SGD(params, **_sgd_settings)
# lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

# Tracks the mean training loss.
loss_hist = Averager()
print("Start training....")
192+
193+ # 迭代训练
# Training: one pass over the training loader per epoch.
for epoch in range(num_epochs):
    # The loss average is tracked per epoch.
    loss_hist.reset()

    for images, targets, image_ids in train_data_loader:
        # Move the batch to the training device.
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        # Cast the boxes to float before handing them to the model.
        for target in targets:
            target['boxes'] = target['boxes'].float()

        # With targets supplied the model returns a dict of losses; their
        # sum is the quantity that gets optimized.
        loss_dict = model(images, targets)
        losses = sum(loss_dict.values())
        loss_value = losses.item()
        loss_hist.send(loss_value)
        print("loss is :", loss_value)

        # Standard optimization step.
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        # Progress report every 50 iterations.
        if itr % 50 == 0:
            print(f"Iteration #{itr} /{len(train_data_loader)} loss: {loss_value} ")
        itr += 1

    # Update the learning rate once per epoch when a scheduler is configured.
    if lr_scheduler is not None:
        lr_scheduler.step()
    print(f"Epoch #{epoch} loss: {loss_hist.value} ")

# Persist the trained weights once all epochs are done.
torch.save(model.state_dict(), 'fasterrcnn_resnet50_fpn.pth')
print("Next Test....")
224+
225+ #-----------------------------------------------------------------------------
226+ #第六步 模型测试
227+ #-----------------------------------------------------------------------------
# Take the first validation batch and move its tensors to the device.
images, targets, image_ids = next(iter(valid_data_loader))
images = [img.to(device) for img in images]
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

# Ground-truth boxes of the first sample as integer pixel coordinates.
boxes = targets[0]['boxes'].cpu().numpy().astype(np.int32)

# permute() only re-strides the tensor; OpenCV needs a C-contiguous buffer to
# draw in place, hence the explicit contiguous copy.
sample = np.ascontiguousarray(images[0].permute(1, 2, 0).cpu().numpy())

model.eval()
cpu_device = torch.device("cpu")

# Inference needs no gradients; skipping autograd bookkeeping saves memory.
with torch.no_grad():
    outputs = model(images)
outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]

# Keep only the predictions for the first sample whose score reaches the
# threshold, so the figure shows what the model detected and not just the
# ground truth.
detection_threshold = 0.5
pred_scores = outputs[0]['scores'].numpy()
pred_boxes = outputs[0]['boxes'].numpy()
pred_boxes = pred_boxes[pred_scores >= detection_threshold].astype(np.int32)

fig, ax = plt.subplots(1, 1, figsize=(16, 8))

# Ground-truth boxes.
for box in boxes:
    # Plain Python ints are accepted by every OpenCV build.
    x_min, y_min, x_max, y_max = (int(v) for v in box[:4])
    cv2.rectangle(sample,
                  (x_min, y_min),
                  (x_max, y_max),
                  (220, 0, 0), 3)

# Predicted boxes, drawn in a different colour and with a thinner line.
for box in pred_boxes:
    x_min, y_min, x_max, y_max = (int(v) for v in box[:4])
    cv2.rectangle(sample,
                  (x_min, y_min),
                  (x_max, y_max),
                  (0, 0, 220), 2)

ax.set_axis_off()
ax.imshow(sample)
plt.show()
249+
250+
251+
252+
0 commit comments