|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | +""" |
| 3 | +Created on Fri Oct 29 13:42:38 2021 |
| 4 | +@author: xiuzhang |
| 5 | +""" |
| 6 | +import numpy as np |
| 7 | +import cv2 |
| 8 | +import torch |
| 9 | +from torch.utils.data import Dataset |
| 10 | + |
| 11 | + |
class WheatDataset(Dataset):
    """Detection dataset of wheat-head images with bounding-box targets.

    Expects ``dataframe`` with columns ``image_id``, ``x``, ``y``, ``w``,
    ``h`` — one row per box, with (x, y) the top-left corner and (w, h)
    the box size in pixels — and images stored as
    ``<image_dir>/<image_id>.jpg``.

    ``__getitem__`` returns ``(image, target, image_id)`` where ``target``
    follows the torchvision detection convention (``boxes`` in
    ``[x_min, y_min, x_max, y_max]`` format, ``labels``, ``image_id``,
    ``area``, ``iscrowd``).
    """

    def __init__(self, dataframe, image_dir, transforms=None):
        super().__init__()

        # One dataset entry per unique image; the per-image boxes are
        # regrouped from the flat dataframe in __getitem__.
        self.image_ids = dataframe['image_id'].unique()
        self.df = dataframe
        self.image_dir = image_dir
        # Optional albumentations-style callable taking
        # image/bboxes/labels keyword arguments.
        self.transforms = transforms

    def __getitem__(self, index: int):
        """Load the index-th unique image and its detection target.

        Raises:
            FileNotFoundError: if the image file is missing or unreadable.
        """
        image_id = self.image_ids[index]
        records = self.df[self.df['image_id'] == image_id]

        image = cv2.imread(f'{self.image_dir}/{image_id}.jpg', cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread silently returns None on a missing/corrupt file;
            # fail fast here instead of crashing opaquely in cvtColor.
            raise FileNotFoundError(f'{self.image_dir}/{image_id}.jpg')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        image /= 255.0  # scale to [0, 1]

        # Convert [x, y, w, h] -> [x_min, y_min, x_max, y_max].
        # Explicit float32 cast: DataFrame .values may be int64/float64.
        boxes = records[['x', 'y', 'w', 'h']].values.astype(np.float32)
        boxes[:, 2] = boxes[:, 0] + boxes[:, 2]
        boxes[:, 3] = boxes[:, 1] + boxes[:, 3]

        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        area = torch.as_tensor(area, dtype=torch.float32)

        # there is only one class (wheat head); 0 is reserved for background
        labels = torch.ones((records.shape[0],), dtype=torch.int64)

        # suppose all instances are not crowd
        iscrowd = torch.zeros((records.shape[0],), dtype=torch.int64)

        target = {}
        # BUGFIX: boxes were previously left as a numpy array when
        # self.transforms was None; always return a float32 tensor so the
        # target format is consistent for downstream detection models.
        target['boxes'] = torch.as_tensor(boxes, dtype=torch.float32)
        target['labels'] = labels
        target['image_id'] = torch.tensor([index])
        target['area'] = area
        target['iscrowd'] = iscrowd

        if self.transforms:
            sample = {
                'image': image,
                'bboxes': target['boxes'],
                'labels': labels
            }
            sample = self.transforms(**sample)
            image = sample['image']

            # BUGFIX: the old torch.stack(zip(*bboxes)) dance raised on an
            # empty bbox list (e.g. a crop that removed every box);
            # reshape(-1, 4) keeps a valid (0, 4) tensor in that case.
            target['boxes'] = torch.as_tensor(
                sample['bboxes'], dtype=torch.float32).reshape(-1, 4)

        return image, target, image_id

    def __len__(self) -> int:
        return self.image_ids.shape[0]
| 66 | + |
0 commit comments