Skip to content

Predict from a video file #25

Open
@Masrur02

Description

Hi,
When I ran the code python grounded_sam2_local_demo.py
the result was good with a prompt text="car. road."
grounded_sam2_annotated_image_with_mask

But, when I have modified the code to read images from a video file and keep looping

import cv2
import torch
import numpy as np
import supervision as sv
from torchvision.ops import box_convert
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from grounding_dino.groundingdino.util.inference import load_model, load_image, predict
import time
import os

# Environment settings
# Use bfloat16 only where supported

# Build SAM2 image predictor
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
sam2_predictor = SAM2ImagePredictor(sam2_model)

# Build Grounding DINO model
model_id = "IDEA-Research/grounding-dino-tiny"
device = "cuda" if torch.cuda.is_available() else "cpu"
grounding_model = load_model(
    model_config_path="grounding_dino/groundingdino/config/GroundingDINO_SwinT_OGC.py", 
    model_checkpoint_path="gdino_checkpoints/groundingdino_swint_ogc.pth",
    device=device
)

# Setup the input text prompt for Grounding DINO
text = "road. car."
output_dir = "test"
os.makedirs(output_dir, exist_ok=True)

# Capture video
video_path = 'notebooks/videos/indy.mp4'
cap = cv2.VideoCapture(video_path)
frame_num = 0

while cap.isOpened():
    start_time = time.time()
    ret, frame = cap.read()
    if not ret:
        break

    #time.sleep(0.1)

    # Convert the frame to the required format for processing
    image_source, image = load_image(frame)
   

    sam2_predictor.set_image(image_source)

    boxes, confidences, labels = predict(
        model=grounding_model,
        image=image,
        caption=text,
        box_threshold=0.35,
        text_threshold=0.25
    )

    # Process the box prompt for SAM2
    h, w, _ = frame.shape
    boxes = boxes * torch.Tensor([w, h, w, h])
    input_boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()

    # Enable mixed precision only for the specific block
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        if torch.cuda.get_device_properties(0).major >= 8:
            # Enable tfloat32 for Ampere GPUs
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True

        # Perform SAM2 prediction within the mixed precision context
        masks, scores, logits = sam2_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=input_boxes,
            multimask_output=False,
        )

    # Post-process the output of the model to get the masks, scores, and logits for visualization
    if masks.ndim == 4:
        masks = masks.squeeze(1)

    confidences = confidences.numpy().tolist()
    class_names = labels
    class_ids = np.array(list(range(len(class_names))))

    labels = [
        f"{class_name} {confidence:.2f}"
        for class_name, confidence
        in zip(class_names, confidences)
    ]

    # Calculate FPS
    end_time = time.time()
    fps = 1 / (end_time - start_time)
    
    # Visualize image with supervision API
    detections = sv.Detections(
        xyxy=input_boxes,  # (n, 4)
        mask=masks.astype(bool),  # (n, h, w)
        class_id=class_ids
    )

    box_annotator = sv.BoxAnnotator()
    annotated_frame = box_annotator.annotate(scene=frame.copy(), detections=detections)

    label_annotator = sv.LabelAnnotator()
    annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)

    mask_annotator = sv.MaskAnnotator()
    annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
    mask_image_save_path = os.path.join(output_dir, f"{frame_num:04d}_mask.jpg")

    cv2.imwrite(mask_image_save_path, annotated_frame)
    print(f"FPS for frame {frame_num}: {fps:.2f}")

    frame_num += 1

cap.release()
cv2.destroyAllWindows()

the result has become very bad

0002_mask

What is the reason? Can you please help??

TIA

Activity

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions