Skip to content

Commit 279b7f8

Browse files
authored
Create dataset.py
1 parent 34311f0 commit 279b7f8

1 file changed

Lines changed: 66 additions & 0 deletions

File tree

blog46-FasterRCNN-Wheat/dataset.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
Created on Fri Oct 29 13:42:38 2021
4+
@author: xiuzhang
5+
"""
6+
import numpy as np
7+
import cv2
8+
import torch
9+
from torch.utils.data import Dataset
10+
11+
12+
class WheatDataset(Dataset):
13+
14+
def __init__(self, dataframe, image_dir, transforms=None):
15+
super().__init__()
16+
17+
self.image_ids = dataframe['image_id'].unique()
18+
self.df = dataframe
19+
self.image_dir = image_dir
20+
self.transforms = transforms
21+
22+
def __getitem__(self, index: int):
23+
image_id = self.image_ids[index]
24+
records = self.df[self.df['image_id'] == image_id]
25+
26+
image = cv2.imread(f'{self.image_dir}/{image_id}.jpg', cv2.IMREAD_COLOR)
27+
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
28+
image /= 255.0
29+
30+
boxes = records[['x', 'y', 'w', 'h']].values
31+
boxes[:, 2] = boxes[:, 0] + boxes[:, 2]
32+
boxes[:, 3] = boxes[:, 1] + boxes[:, 3]
33+
34+
area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
35+
area = torch.as_tensor(area, dtype=torch.float32)
36+
37+
# there is only one class
38+
labels = torch.ones((records.shape[0],), dtype=torch.int64)
39+
40+
# suppose all instances are not crowd
41+
iscrowd = torch.zeros((records.shape[0],), dtype=torch.int64)
42+
43+
target = {}
44+
target['boxes'] = boxes
45+
target['labels'] = labels
46+
# target['masks'] = None
47+
target['image_id'] = torch.tensor([index])
48+
target['area'] = area
49+
target['iscrowd'] = iscrowd
50+
51+
if self.transforms:
52+
sample = {
53+
'image': image,
54+
'bboxes': target['boxes'],
55+
'labels': labels
56+
}
57+
sample = self.transforms(**sample)
58+
image = sample['image']
59+
60+
target['boxes'] = torch.stack(tuple(map(torch.tensor, zip(*sample['bboxes'])))).permute(1, 0)
61+
62+
return image, target, image_id
63+
64+
def __len__(self) -> int:
65+
return self.image_ids.shape[0]
66+

0 commit comments

Comments
 (0)