Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add condlane #20

Merged
merged 15 commits into from
Sep 11, 2021
Prev Previous commit
Next Next commit
minor fix
  • Loading branch information
Turoad committed Jun 28, 2021
commit d3e23e884fb051e8ea028a087dc84c6c76a7134e
6 changes: 3 additions & 3 deletions configs/condlane/resnet101_culane.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
featuremap_out_stride = 8
sample_y = range(589, 230, -20)

batch_size = 4
batch_size = 8
aggregator = dict(
type='TransConvEncoderModule',
in_dim=2048,
Expand Down Expand Up @@ -73,8 +73,8 @@
)

seg_loss_weight = 1.0
eval_ep = 2
save_ep = 2
eval_ep = 1
save_ep = 1

img_norm = dict(
mean=[75.3, 76.6, 77.6],
Expand Down
3 changes: 2 additions & 1 deletion lanedet/datasets/culane.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,13 @@ def load_annotation(self, line):
lanes = [[(lane[i], lane[i + 1]) for i in range(0, len(lane), 2) if lane[i] >= 0 and lane[i + 1] >= 0]
for lane in data]
lanes = [list(set(lane)) for lane in lanes] # remove duplicated points
lanes = [lane for lane in lanes if len(lane) > 3] # remove lanes with less than 2 points
lanes = [lane for lane in lanes if len(lane) > 2] # remove lanes with less than 2 points

lanes = [sorted(lane, key=lambda x: x[1]) for lane in lanes] # sort by y
# gt_points = [[float(lane[i]) for i in range(len(lane))] for lane in data]
# gt_points = [gt_point for gt_point in gt_points if len(gt_point) > 3]
infos['lanes'] = lanes
# infos['gt_points'] = gt_points

return infos

Expand Down
1 change: 0 additions & 1 deletion lanedet/datasets/process/collect_lane.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,6 @@ def min_dis_one_point(points, idx):
# for j in range(len(endpoints)):
# dis = min_dis_one_point(endpoints, j)
# if dis < 1.5 * radius[j]:
# raise False
# radius[j] = int(max(dis / 1.5, 1) + 0.49999)

for (end_point, line), r in zip(gt_hm_lane_ends, radius):
Expand Down