# OBSS SAHI Tool
# Code written by Fatih C Akyon, 2020.
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np

from sahi.utils.import_utils import is_available
from sahi.utils.torch import is_torch_cuda_available


class DetectionModel:
def __init__(
self,
model_path: Optional[str] = None,
model: Optional[Any] = None,
config_path: Optional[str] = None,
device: Optional[str] = None,
mask_threshold: float = 0.5,
confidence_threshold: float = 0.3,
category_mapping: Optional[Dict] = None,
category_remapping: Optional[Dict] = None,
load_at_init: bool = True,
        image_size: Optional[int] = None,
):
"""
Init object detection/instance segmentation model.
Args:
            model_path: str
                Path to the detection/instance segmentation model weights file
            model: Any
                Optional already loaded model object; if given, it is passed to set_model()
                instead of loading from model_path
            config_path: str
                Path to the config file of the detection model (e.g. an mmdetection config)
            device: str
                Torch device, "cpu" or "cuda"
            mask_threshold: float
                Value to threshold mask pixels, should be between 0 and 1
            confidence_threshold: float
                All predictions with score < confidence_threshold will be discarded
            category_mapping: dict: str to str
                Mapping from category id (str) to category name (str) e.g. {"1": "pedestrian"}
            category_remapping: dict: str to int
                Remaps category ids after inference; keyed by the original category id as a string, e.g. {"1": 3}
            load_at_init: bool
                If True, automatically loads the model at initialization
image_size: int
Inference input size.
"""
self.model_path = model_path
self.config_path = config_path
self.model = None
self.device = device
self.mask_threshold = mask_threshold
self.confidence_threshold = confidence_threshold
self.category_mapping = category_mapping
self.category_remapping = category_remapping
self.image_size = image_size
self._original_predictions = None
self._object_prediction_list_per_image = None

        # automatically set device if it is None
        if not self.device:
            self.device = "cuda:0" if is_torch_cuda_available() else "cpu"

        # automatically load model if load_at_init is True
        if load_at_init:
            if model:
                self.set_model(model)
            else:
                self.load_model()

    def load_model(self):
        """
        Subclasses should implement this so that the detection model is
        initialized and assigned to self.model.
        (self.model_path, self.config_path, and self.device should be utilized)
        """
        raise NotImplementedError()

    def set_model(self, model: Any, **kwargs):
        """
        Subclasses should implement this to set up this DetectionModel from an
        already loaded model (assigning it to self.model).
        Args:
            model: Any
                Loaded model
        """
        raise NotImplementedError()

    def unload_model(self):
"""
Unloads the model from CPU/GPU.
"""
self.model = None
if is_available("torch"):
from sahi.utils.torch import empty_cuda_cache
empty_cuda_cache()

    def perform_inference(self, image: np.ndarray):
        """
        Subclasses should implement this so that inference is performed on the given
        image using self.model and the raw result is assigned to self._original_predictions.
        Args:
            image: np.ndarray
                A numpy array that contains the image to be predicted.
        """
        raise NotImplementedError()

    def _create_object_prediction_list_from_original_predictions(
        self,
        shift_amount_list: Optional[List[List[int]]] = [[0, 0]],
        full_shape_list: Optional[List[List[int]]] = None,
    ):
        """
        Subclasses should implement this so that self._original_predictions is converted
        to a list of prediction.ObjectPrediction per image and assigned to
        self._object_prediction_list_per_image. self.mask_threshold can also be utilized.
        Args:
            shift_amount_list: list of list
                To shift the box and mask predictions from sliced image to full sized image, should
                be in the form of List[[shift_x, shift_y],[shift_x, shift_y],...]
            full_shape_list: list of list
                Size of the full image after shifting, should be in the form of
                List[[height, width],[height, width],...]
        """
        raise NotImplementedError()

    def _apply_category_remapping(self):
"""
Applies category remapping based on mapping given in self.category_remapping
"""
# confirm self.category_remapping is not None
if self.category_remapping is None:
raise ValueError("self.category_remapping cannot be None")
# remap categories
for object_prediction_list in self._object_prediction_list_per_image:
for object_prediction in object_prediction_list:
old_category_id_str = str(object_prediction.category.id)
new_category_id_int = self.category_remapping[old_category_id_str]
object_prediction.category.id = new_category_id_int

    def convert_original_predictions(
self,
shift_amount: Optional[List[int]] = [0, 0],
full_shape: Optional[List[int]] = None,
):
"""
Converts original predictions of the detection model to a list of
prediction.ObjectPrediction object. Should be called after perform_inference().
Args:
shift_amount: list
To shift the box and mask predictions from sliced image to full sized image, should be in the form of [shift_x, shift_y]
full_shape: list
Size of the full image after shifting, should be in the form of [height, width]
"""
self._create_object_prediction_list_from_original_predictions(
shift_amount_list=shift_amount,
full_shape_list=full_shape,
)
if self.category_remapping:
self._apply_category_remapping()

    @property
    def object_prediction_list(self):
        return self._object_prediction_list_per_image[0]

    @property
    def object_prediction_list_per_image(self):
        return self._object_prediction_list_per_image

    @property
    def original_predictions(self):
        return self._original_predictions
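

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): a minimal
# example of how a concrete subclass plugs into DetectionModel. It assumes
# sahi.prediction.ObjectPrediction accepts bbox / category_id / category_name /
# score / shift_amount / full_shape keyword arguments; DummyDetectionModel and
# its fake detector are hypothetical stand-ins, not part of the SAHI API.
# ---------------------------------------------------------------------------


class DummyDetectionModel(DetectionModel):
    def load_model(self):
        # A real subclass would load weights from self.model_path onto
        # self.device; a trivial callable stands in for the framework model.
        self.model = lambda image: [([10, 20, 50, 80], 0, 0.9)]  # (xyxy box, category id, score)

    def perform_inference(self, image: np.ndarray):
        # Run the underlying model and keep its raw output for later conversion.
        self._original_predictions = self.model(image)

    def _create_object_prediction_list_from_original_predictions(
        self,
        shift_amount_list: Optional[List[List[int]]] = [[0, 0]],
        full_shape_list: Optional[List[List[int]]] = None,
    ):
        from sahi.prediction import ObjectPrediction

        # convert_original_predictions() forwards flat [shift_x, shift_y] and
        # [height, width] lists; normalize them to per-image lists of pairs.
        if shift_amount_list and not isinstance(shift_amount_list[0], (list, tuple)):
            shift_amount_list = [shift_amount_list]
        if full_shape_list and not isinstance(full_shape_list[0], (list, tuple)):
            full_shape_list = [full_shape_list]

        shift_amount = shift_amount_list[0]
        full_shape = None if full_shape_list is None else full_shape_list[0]
        object_prediction_list = [
            ObjectPrediction(
                bbox=list(bbox),
                category_id=int(category_id),
                category_name=(self.category_mapping or {}).get(str(category_id), str(category_id)),
                score=float(score),
                shift_amount=shift_amount,
                full_shape=full_shape,
            )
            for bbox, category_id, score in self._original_predictions
            if score >= self.confidence_threshold
        ]
        self._object_prediction_list_per_image = [object_prediction_list]


if __name__ == "__main__":
    # End-to-end flow: inference -> conversion -> ObjectPrediction list in
    # full-image coordinates.
    detection_model = DummyDetectionModel(category_mapping={"0": "pedestrian"})
    detection_model.perform_inference(np.zeros((320, 320, 3), dtype=np.uint8))
    detection_model.convert_original_predictions(shift_amount=[128, 256], full_shape=[1080, 1920])
    print(detection_model.object_prediction_list)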