[update] main inference script

This commit is contained in:
Cheng-Yen Yang
2024-11-19 22:30:23 -08:00
parent c17e4cecc0
commit 47ebf4528b
4 changed files with 188 additions and 5 deletions

3
.gitignore vendored
View File

@@ -155,5 +155,6 @@ cython_debug/
# evaluation results # evaluation results
evaluation_results/* evaluation_results/*
raw_results/* results/*
debug/* debug/*
visualization/*

View File

@@ -6,9 +6,57 @@ This repository is the official implementation of SAMURAI: Adapting Segment Anyt
https://github.com/user-attachments/assets/9d368ca7-2e9b-4fed-9da0-d2efbf620d88 https://github.com/user-attachments/assets/9d368ca7-2e9b-4fed-9da0-d2efbf620d88
## Code ## Getting Started
Coming soon! #### SAMURAI Installation
SAM 2 needs to be installed first before use. The code requires `python>=3.10`, as well as `torch>=2.3.1` and `torchvision>=0.18.1`. Please follow the instructions [here](https://github.com/facebookresearch/sam2?tab=readme-ov-file) to install both PyTorch and TorchVision dependencies. You can install **the SAMURAI version** of SAM 2 on a GPU machine using:
```
cd sam2
pip install -e .
pip install -e ".[notebooks]"
```
Please see [INSTALL.md](https://github.com/facebookresearch/sam2/blob/main/INSTALL.md) from the original SAM 2 repository for FAQs on potential issues and solutions.
```
pip install requirements.txt
```
#### SAM 2.1 Checkpoint Download
```
cd checkpoints && \
./download_ckpts.sh && \
cd ..
```
#### Data Preparation
Please prepare the data in the following format:
```
data/LaSOT
├── airplane/
│ ├── airplane-1/
│ │ ├── full_occlusion.txt
│ │ ├── groundtruth.txt
│ │ ├── img
│ │ ├── nlp.txt
│ │ └── out_of_view.txt
│ ├── airplane-2/
│ ├── airplane-3/
│ ├── ...
├── basketball
├── bear
├── bicycle
...
├── training_set.txt
└── testing_set.txt
```
#### Main Inference
```
python scripts/main_inference.py
```
## Acknowledgment ## Acknowledgment

View File

@@ -217,8 +217,7 @@ class SAM2Base(torch.nn.Module):
self.memory_bank_obj_score_threshold = memory_bank_obj_score_threshold self.memory_bank_obj_score_threshold = memory_bank_obj_score_threshold
self.memory_bank_kf_score_threshold = memory_bank_kf_score_threshold self.memory_bank_kf_score_threshold = memory_bank_kf_score_threshold
print(f"SAMURAI mode: {self.samurai_mode}") print(f"\033[93mSAMURAI mode: {self.samurai_mode}\033[0m")
print(f"Stable frames threshold: {self.stable_frames_threshold}")
# Model compilation # Model compilation
if compile_image_encoder: if compile_image_encoder:

135
scripts/main_inference.py Normal file
View File

@@ -0,0 +1,135 @@
import cv2
import gc
import numpy as np
import os
import os.path as osp
import pdb
import torch
from sam2.build_sam import build_sam2_video_predictor
from tqdm import tqdm
def load_lasot_gt(gt_path):
with open(gt_path, 'r') as f:
gt = f.readlines()
# bbox in first frame are prompts
prompts = {}
fid = 0
for line in gt:
x, y, w, h = map(int, line.split(','))
prompts[fid] = ((x, y, x+w, y+h), 0)
fid += 1
return prompts
color = [
(255, 0, 0),
]
testing_set = "data/LaSOT/testing_set.txt"
with open(testing_set, 'r') as f:
test_videos = f.readlines()
exp_name = "samurai"
model_name = "base_plus"
checkpoint = f"sam2/checkpoints/sam2.1_hiera_{model_name}.pt"
if model_name == "base_plus":
model_cfg = "configs/samurai/sam2.1_hiera_b+.yaml"
else:
model_cfg = f"configs/samurai/sam2.1_hiera_{model_name[0]}.yaml"
video_folder= "data/LaSOT"
pred_folder = f"results/{exp_name}/{exp_name}_{model_name}"
save_to_video = True
if save_to_video:
vis_folder = f"visualization/{exp_name}/{model_name}"
os.makedirs(vis_folder, exist_ok=True)
vis_mask = {}
vis_bbox = {}
test_videos = sorted(test_videos)
for vid, video in enumerate(test_videos):
cat_name = video.split('-')[0]
cid_name = video.split('-')[1]
video_basename = video.strip()
frame_folder = osp.join(video_folder, cat_name, video.strip(), "img")
num_frames = len(os.listdir(osp.join(video_folder, cat_name, video.strip(), "img")))
print(f"\033[91mRunning video [{vid+1}/{len(test_videos)}]: {video} with {num_frames} frames\033[0m")
height, width = cv2.imread(osp.join(frame_folder, "00000001.jpg")).shape[:2]
predictor = build_sam2_video_predictor(model_cfg, checkpoint, device="cuda:0")
predictions = []
if save_to_video:
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(osp.join(vis_folder, f'{video_basename}.mp4'), fourcc, 30, (width, height))
# Start processing frames
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
state = predictor.init_state(frame_folder, offload_video_to_cpu=True, offload_state_to_cpu=True, async_loading_frames=True)
prompts = load_lasot_gt(osp.join(video_folder, cat_name, video.strip(), "groundtruth.txt"))
bbox, track_label = prompts[0]
frame_idx, object_ids, masks = predictor.add_new_points_or_box(state, box=bbox, frame_idx=0, obj_id=0)
for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
mask_to_vis = {}
bbox_to_vis = {}
assert len(masks) == 1 and len(object_ids) == 1, "Only one object is supported right now"
for obj_id, mask in zip(object_ids, masks):
mask = mask[0].cpu().numpy()
mask = mask > 0.0
non_zero_indices = np.argwhere(mask)
if len(non_zero_indices) == 0:
bbox = [0, 0, 0, 0]
else:
y_min, x_min = non_zero_indices.min(axis=0).tolist()
y_max, x_max = non_zero_indices.max(axis=0).tolist()
bbox = [x_min, y_min, x_max-x_min, y_max-y_min]
bbox_to_vis[obj_id] = bbox
mask_to_vis[obj_id] = mask
if save_to_video:
img = cv2.imread(f'{frame_folder}/{frame_idx+1:08d}.jpg')
if img is None:
break
for obj_id in mask_to_vis.keys():
mask_img = np.zeros((height, width, 3), np.uint8)
mask_img[mask_to_vis[obj_id]] = color[(obj_id+1)%len(color)]
img = cv2.addWeighted(img, 1, mask_img, 0.75, 0)
for obj_id in bbox_to_vis.keys():
cv2.rectangle(img, (bbox_to_vis[obj_id][0], bbox_to_vis[obj_id][1]), (bbox_to_vis[obj_id][0]+bbox_to_vis[obj_id][2], bbox_to_vis[obj_id][1]+bbox_to_vis[obj_id][3]), color[(obj_id)%len(color)], 2)
x1, y1, x2, y2 = prompts[frame_idx][0]
cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
out.write(img)
predictions.append(bbox_to_vis)
os.makedirs(pred_folder, exist_ok=True)
with open(osp.join(pred_folder, f'{video_basename}.txt'), 'w') as f:
for pred in predictions:
x, y, w, h = pred[0]
f.write(f"{x},{y},{w},{h}\n")
if save_to_video:
out.release()
del predictor
del state
gc.collect()
torch.clear_autocast_cache()
torch.cuda.empty_cache()