Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add condlane #20

Merged
merged 15 commits into from
Sep 11, 2021
Prev Previous commit
Next Next commit
update condlane
  • Loading branch information
Turoad committed Jun 19, 2021
commit 8d6e0227a0b00322349ab397441dcbfa94a2a784
2 changes: 1 addition & 1 deletion configs/condlane/resnet101_culane.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
location_configs=dict(size=(batch_size, 1, 80, 200), device='cuda:0')
)

optimizer = dict(type='Adam', lr=6e-4, betas=(0.9, 0.999), eps=1e-8)
optimizer = dict(type='AdamW', lr=3e-4, betas=(0.9, 0.999), eps=1e-8)

epochs = 16
total_iter = (88880 // batch_size) * epochs
Expand Down
26 changes: 10 additions & 16 deletions lanedet/datasets/culane.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,17 @@ def load_annotation(self, line):
anno_path = img_path[:-3] + 'lines.txt' # remove sufix jpg and add lines.txt
with open(anno_path, 'r') as anno_file:
data = [list(map(float, line.split())) for line in anno_file.readlines()]
# lanes = [[(lane[i], lane[i + 1]) for i in range(0, len(lane), 2) if lane[i] >= 0 and lane[i + 1] >= 0]
# for lane in data]
# lanes = [list(set(lane)) for lane in lanes] # remove duplicated points
# lanes = [lane for lane in lanes if len(lane) >= 2] # remove lanes with less than 2 points
lanes = [[(lane[i], lane[i + 1]) for i in range(0, len(lane), 2) if lane[i] >= 0 and lane[i + 1] >= 0]
for lane in data]
lanes = [list(set(lane)) for lane in lanes] # remove duplicated points
lanes = [lane for lane in lanes if len(lane) > 3] # remove lanes with less than 2 points

# lanes = [sorted(lane, key=lambda x: x[1]) for lane in lanes] # sort by y
gt_points = [[float(lane[i]) for i in range(len(lane))] for lane in data]
gt_points = [gt_point for gt_point in gt_points if len(gt_point) > 3]
#infos['lanes'] = lanes
infos['gt_points'] = gt_points
infos['id_classes'] = [1 for i in range(len(gt_points))]
infos['id_instances'] = [i + 1 for i in range(len(gt_points))]
lanes = [sorted(lane, key=lambda x: x[1]) for lane in lanes] # sort by y
# gt_points = [[float(lane[i]) for i in range(len(lane))] for lane in data]
# gt_points = [gt_point for gt_point in gt_points if len(gt_point) > 3]

infos['lanes'] = lanes
# infos['gt_points'] = gt_points

return infos

Expand All @@ -69,11 +68,6 @@ def get_prediction_string(self, pred):
ys = np.arange(270, 590, 8) / self.cfg.ori_img_h
out = []
for lane in pred:
# points = lane.points
# lane_str = ' '.join(['{:.5f} {:.5f}'.format(x, y) for x, y in points])
# if lane_str != '':
# out.append(lane_str)
# continue
xs = lane(ys)
valid_mask = (xs >= 0) & (xs < 1)
xs = xs * self.cfg.ori_img_w
Expand Down
26 changes: 10 additions & 16 deletions lanedet/datasets/process/alaug.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,21 @@ def __call__(self, data):
else:
keypoints_val = keypoints_val + points_val

if 'lanes' in data:
points_val = []
for lane in data['lanes']:
points_val.extend(lane)

points_index = [len(lane) for lane in data['lanes']]
keypoints_val = points_val

aug = self.__augmentor(
image=img,
keypoints=keypoints_val,
bboxes=bboxes,
mask=masks,
bbox_labels=bbox_labels)

data['img'] = aug['image']
data['img_shape'] = data['img'].shape
if 'gt_bboxes' in data:
Expand All @@ -156,23 +165,8 @@ def __call__(self, data):
return None
if 'gt_masks' in data:
data['gt_masks'] = [np.array(aug['mask'])]
if 'gt_keypoints' in data:
kp_list = [[0 for j in range(i * 2)] for i in keypoints_index]
for i in range(len(keypoints_index)):
for j in range(keypoints_index[i]):
kp_list[i][2 * j] = aug['keypoints'][
self.cal_sum_list(keypoints_index, i) + j][0]
kp_list[i][2 * j + 1] = aug['keypoints'][
self.cal_sum_list(keypoints_index, i) + j][1]
data['gt_keypoints'] = []
valid = []
for i in range(kp_group_num):
index = int(aug['bboxes'][i][-1])
valid.append(index)
data['gt_keypoints'].append(kp_list[index])
data['gt_keypoints_ignore'] = data['gt_keypoints_ignore'][valid]

if 'gt_points' in data:
if 'gt_points' in data or 'lanes' in data:
start_idx = num_keypoints if 'gt_keypoints' in data else 0
points = aug['keypoints'][start_idx:]
kp_list = [[0 for j in range(i * 2)] for i in points_index]
Expand Down
24 changes: 13 additions & 11 deletions lanedet/datasets/process/collect_lane.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,9 +228,12 @@ def min_dis_one_point(points, idx):

# gt heatmap and ins of bank
gt_points = results['gt_points']
# gt_points = results['lanes']
valid_gt = []
for pts in gt_points:
id_class = 1
# pts = np.array(pts) / self.down_scale
#pts = [(pt[0]/self.down_scale, pt[1]/self.down_scale) for pt in pts]
pts = convert_list(pts, self.down_scale)
pts = sorted(pts, key=cmp_to_key(lambda a, b: b[-1] - a[-1]))
pts = clamp_line(
Expand All @@ -240,21 +243,20 @@ def min_dis_one_point(points, idx):

# draw gt_hm_lane
gt_hm_lane_ends = []
radius = []
for l in valid_gt:
label = l[1]
point = (l[0][0][0] * ratio_hm_mask, l[0][0][1] * ratio_hm_mask)
gt_hm_lane_ends.append([point, l[0]])
for i, p in enumerate(gt_hm_lane_ends):
r = self.radius
radius.append(r)

if len(gt_hm_lane_ends) >= 2:
endpoints = [p[0] for p in gt_hm_lane_ends]
for j in range(len(endpoints)):
dis = min_dis_one_point(endpoints, j)
if dis < 1.5 * radius[j]:
radius[j] = int(max(dis / 1.5, 1) + 0.49999)
radius = [self.radius for _ in range(len(gt_hm_lane_ends))]

# if len(gt_hm_lane_ends) >= 2:
# endpoints = [p[0] for p in gt_hm_lane_ends]
# for j in range(len(endpoints)):
# dis = min_dis_one_point(endpoints, j)
# if dis < 1.5 * radius[j]:
# raise False
# radius[j] = int(max(dis / 1.5, 1) + 0.49999)

for (end_point, line), r in zip(gt_hm_lane_ends, radius):
pos = np.zeros((mask_h), np.float32)
pos_mask = np.zeros((mask_h), np.float32)
Expand Down
2 changes: 0 additions & 2 deletions lanedet/models/heads/condlane.py
Original file line number Diff line number Diff line change
Expand Up @@ -960,9 +960,7 @@ def parse_pos(seeds, batchsize, num_classes, h, w, device):
h_mask, w_mask = f_mask.size()[2:]
hms, params = z['hm'], z['params']
hms = torch.clamp(hms.sigmoid(), min=1e-4, max=1 - 1e-4)
print(params.shape)
params = params.view(m_batchsize, self.num_classes, -1, h_hm, w_hm)
print(params.shape)
# with Timer("Elapsed time in two branch: %f"): # 0.6ms
mask_branchs = self.mask_branch(f_mask)
reg_branchs = mask_branchs
Expand Down