Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add condlane #20

Merged
merged 15 commits into from
Sep 11, 2021
Prev Previous commit
Next Next commit
remove some magic number
  • Loading branch information
Turoad committed Jul 8, 2021
commit 9faf0dc46834ba47e6fbd092304a7141ae1b154e
3 changes: 1 addition & 2 deletions configs/condlane/resnet101_culane.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
in_channels=[64, 128, 256, 512]
)

featuremap_out_channel = 128
featuremap_out_stride = 8
sample_y = range(589, 230, -20)

batch_size = 8
Expand Down Expand Up @@ -95,6 +93,7 @@
nms_thr = 4
img_scale = (800, 320)
crop_bbox = [0, 270, 1640, 590]
mask_size = (1, 80, 200)

train_process = [
dict(type='Alaug',
Expand Down
5 changes: 1 addition & 4 deletions lanedet/datasets/culane.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,10 @@ def load_annotation(self, line):
lanes = [[(lane[i], lane[i + 1]) for i in range(0, len(lane), 2) if lane[i] >= 0 and lane[i + 1] >= 0]
for lane in data]
lanes = [list(set(lane)) for lane in lanes] # remove duplicated points
lanes = [lane for lane in lanes if len(lane) > 2] # remove lanes with less than 2 points
lanes = [lane for lane in lanes if len(lane) > 3] # remove lanes with less than 2 points

lanes = [sorted(lane, key=lambda x: x[1]) for lane in lanes] # sort by y
# gt_points = [[float(lane[i]) for i in range(len(lane))] for lane in data]
# gt_points = [gt_point for gt_point in gt_points if len(gt_point) > 3]
infos['lanes'] = lanes
# infos['gt_points'] = gt_points

return infos

Expand Down
1 change: 1 addition & 0 deletions lanedet/engine/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def validate(self):
data = self.to_cuda(data)
with torch.no_grad():
output = self.net(data)
output = self.net.module.get_lanes(output)
predictions.extend(output)
if self.cfg.view:
self.val_loader.dataset.view(output, data['meta'])
Expand Down
29 changes: 6 additions & 23 deletions lanedet/models/heads/condlane.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,9 +705,8 @@ def __init__(self,
self.feat_width = location_configs['size'][-1]
self.mlp = MLP(self.feat_width, 64, 2, 2)

# TODO(check)
self.post_process = CondLanePostProcessor(
mask_size=(1, 80, 200), hm_thr=0.5, use_offset=True,
mask_size=self.cfg.mask_size, hm_thr=0.5, use_offset=True,
nms_thr=4)

self.loss_impl = CondLaneLoss(cfg.loss_weights, 1)
Expand Down Expand Up @@ -749,10 +748,6 @@ def parse_gt(self, gts, device):
gts['row']).to(device)).unsqueeze(0).unsqueeze(0)
row_mask = (torch.from_numpy(
gts['row_mask']).to(device)).unsqueeze(0).unsqueeze(0)
# reg = (gts['reg']).to(device)#.unsqueeze(0)
# reg_mask = (gts['reg_mask']).to(device)#.unsqueeze(0)
# row = (gts['row']).to(device).unsqueeze(0)#.unsqueeze(0)
# row_mask = (gts['row_mask']).to(device).unsqueeze(0)#.unsqueeze(0)
if 'range' in gts:
lane_range = torch.from_numpy(gts['range']).to(device) # new add: squeeze
#lane_range = (gts['range']).to(device).squeeze(0) # new add: squeeze
Expand Down Expand Up @@ -782,19 +777,12 @@ def parse_pos(self, gt_masks, hm_shape, device, mask_shape=None):
for m in m_img:
gts = self.parse_gt(m, device=device)
reg, reg_mask, row, row_mask, lane_range = gts
# print('reg: ', reg.shape)
# print('reg_mask: ', reg_mask.shape)
# print('row ns:', row.shape)
# print('row_mask:', row_mask.shape)
# print('lane_range:', lane_range.shape)
label = m['label']
num += len(m['points'])
for p in m['points']:
pos = idx * n * hm_h * hm_w + label * hm_h * hm_w + p[
1] * hm_w + p[0]
# pos = [idx, label, p[1], p[0]]
poses.append(pos)
# m['label'] = torch.from_numpy(np.array(m['label'])).to(device)
for i in range(len(m['points'])):
labels.append(label)
regs.append(reg)
Expand Down Expand Up @@ -878,8 +866,8 @@ def _format(heat, inds):
def forward_train(self, output, batch):
img_metas = batch['img_metas']
gt_batch_masks = [m['gt_masks'] for m in img_metas]
hm_shape = [20, 50]
mask_shape = [80, 200]
hm_shape = img_metas[0]['hm_shape']
mask_shape = img_metas[0]['mask_shape']
inputs = output
pos, labels, num_ins, gts = self.parse_pos(
gt_batch_masks, hm_shape, inputs[0].device, mask_shape=mask_shape)
Expand Down Expand Up @@ -1025,17 +1013,13 @@ def forward(

def get_lanes(self, output):
out_seeds, out_hm = output['seeds'], output['hm']
#TODO
ret = []
for seeds, hm in zip(out_seeds, out_hm):
downscale=4
lanes, seed = self.post_process(seeds, downscale)
crop_bbox = [0, 270, 1640, 590]
img_shape = (320, 800, 3)
lanes, seed = self.post_process(seeds, self.cfg.mask_down_scale)
result = adjust_result(
lanes=lanes,
crop_bbox=crop_bbox,
img_shape=img_shape,
crop_bbox=self.cfg.crop_bbox,
img_shape=self.cfg.img_scale,
tgt_shape=(590, 1640),
)
lanes = []
Expand All @@ -1044,7 +1028,6 @@ def get_lanes(self, output):
for x, y in lane:
coord.append([x, y])
coord = np.array(coord)
#coord = np.flip(coord, axis=0)
coord[:, 0] /= self.cfg.ori_img_w
coord[:, 1] /= self.cfg.ori_img_h
lanes.append(Lane(coord))
Expand Down
6 changes: 5 additions & 1 deletion lanedet/models/necks/fpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,11 @@ def __init__(self,

def forward(self, inputs):
"""Forward function."""
assert len(inputs) == len(self.in_channels)
assert len(inputs) >= len(self.in_channels)

if len(inputs) > len(self.in_channels):
for _ in range(len(inputs) - len(self.in_channels)):
del inputs[0]

# build laterals
laterals = [
Expand Down
1 change: 0 additions & 1 deletion lanedet/models/nets/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,5 @@ def forward(self, batch):
output.update(self.heads.loss(out, batch))
else:
output = self.heads(fea)
output = self.heads.get_lanes(output)

return output
2 changes: 1 addition & 1 deletion tools/detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(self, cfg):
def preprocess(self, img_path):
ori_img = cv2.imread(img_path)
img = ori_img[self.cfg.cut_height:, :, :].astype(np.float32)
data = {'img': img}
data = {'img': img, 'lanes': []}
data = self.processes(data)
data['img'] = data['img'].unsqueeze(0)
data.update({'img_path':img_path, 'ori_img':ori_img})
Expand Down