diff --git a/.gitignore b/.gitignore
index 6b30907..a668dfe 100644
--- a/.gitignore
+++ b/.gitignore
@@ -155,5 +155,6 @@ cython_debug/
 
 # evaluation results
 evaluation_results/*
-raw_results/*
+results/*
 debug/*
+visualization/*
diff --git a/README.md b/README.md
index c8ba726..b8efd5a 100644
--- a/README.md
+++ b/README.md
@@ -6,9 +6,59 @@ This repository is the official implementation of SAMURAI: Adapting Segment Anyt
 
 https://github.com/user-attachments/assets/9d368ca7-2e9b-4fed-9da0-d2efbf620d88
 
-## Code
+## Getting Started
 
-Coming soon!
+#### SAMURAI Installation
+
+SAM 2 needs to be installed before use. The code requires `python>=3.10`, as well as `torch>=2.3.1` and `torchvision>=0.18.1`. Please follow the instructions [here](https://github.com/facebookresearch/sam2?tab=readme-ov-file) to install both PyTorch and TorchVision. You can then install **the SAMURAI version** of SAM 2 on a GPU machine using:
+```
+cd sam2
+pip install -e .
+pip install -e ".[notebooks]"
+```
+
+Please see [INSTALL.md](https://github.com/facebookresearch/sam2/blob/main/INSTALL.md) from the original SAM 2 repository for FAQs on potential issues and solutions.
+
+Then install the remaining dependencies:
+```
+pip install -r requirements.txt
+```
+
+#### SAM 2.1 Checkpoint Download
+
+```
+cd checkpoints && \
+./download_ckpts.sh && \
+cd ..
+```
+
+#### Data Preparation
+
+Please prepare the data in the following format:
+```
+data/LaSOT
+├── airplane/
+│   ├── airplane-1/
+│   │   ├── full_occlusion.txt
+│   │   ├── groundtruth.txt
+│   │   ├── img/
+│   │   ├── nlp.txt
+│   │   └── out_of_view.txt
+│   ├── airplane-2/
+│   ├── airplane-3/
+│   ├── ...
+├── basketball/
+├── bear/
+├── bicycle/
+...
+├── training_set.txt
+└── testing_set.txt
+```
+
+#### Main Inference
+```
+python scripts/main_inference.py
+```
 
 ## Acknowledgment
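Before running `scripts/main_inference.py`, it can be worth sanity-checking that the dataset tree matches the layout shown in the README above. A minimal sketch of such a check, assuming the `data/LaSOT` root and `testing_set.txt` naming from the README (the `check_lasot_layout` helper itself is illustrative and not part of this patch):

```python
import os.path as osp

def check_lasot_layout(root="data/LaSOT"):
    # Each line of testing_set.txt names a sequence such as "airplane-1";
    # the category folder is the prefix before the first dash.
    with open(osp.join(root, "testing_set.txt")) as f:
        videos = [line.strip() for line in f if line.strip()]

    missing = []
    for video in videos:
        cat_name = video.split('-')[0]
        seq_dir = osp.join(root, cat_name, video)
        for entry in ("groundtruth.txt", "img"):
            if not osp.exists(osp.join(seq_dir, entry)):
                missing.append(osp.join(seq_dir, entry))

    print(f"{len(videos)} sequences checked, {len(missing)} paths missing")
    for path in missing:
        print(f"  missing: {path}")

if __name__ == "__main__":
    check_lasot_layout()
```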
diff --git a/sam2/sam2/modeling/sam2_base.py b/sam2/sam2/modeling/sam2_base.py
index 658e671..0ae6b9f 100644
--- a/sam2/sam2/modeling/sam2_base.py
+++ b/sam2/sam2/modeling/sam2_base.py
@@ -217,8 +217,7 @@ class SAM2Base(torch.nn.Module):
         self.memory_bank_obj_score_threshold = memory_bank_obj_score_threshold
         self.memory_bank_kf_score_threshold = memory_bank_kf_score_threshold
 
-        print(f"SAMURAI mode: {self.samurai_mode}")
-        print(f"Stable frames threshold: {self.stable_frames_threshold}")
+        print(f"\033[93mSAMURAI mode: {self.samurai_mode}\033[0m")
 
         # Model compilation
         if compile_image_encoder:
diff --git a/scripts/main_inference.py b/scripts/main_inference.py
new file mode 100644
index 0000000..542b5da
--- /dev/null
+++ b/scripts/main_inference.py
@@ -0,0 +1,138 @@
+import gc
+import os
+import os.path as osp
+
+import cv2
+import numpy as np
+import torch
+
+from sam2.build_sam import build_sam2_video_predictor
+
+
+def load_lasot_gt(gt_path):
+    with open(gt_path, 'r') as f:
+        gt = f.readlines()
+
+    # One "x,y,w,h" box per line/frame; the first frame's box serves as the prompt.
+    prompts = {}
+    for fid, line in enumerate(gt):
+        x, y, w, h = map(int, line.split(','))
+        prompts[fid] = ((x, y, x + w, y + h), 0)
+
+    return prompts
+
+
+color = [
+    (255, 0, 0),
+]
+
+testing_set = "data/LaSOT/testing_set.txt"
+with open(testing_set, 'r') as f:
+    test_videos = f.readlines()
+
+exp_name = "samurai"
+model_name = "base_plus"
+
+checkpoint = f"sam2/checkpoints/sam2.1_hiera_{model_name}.pt"
+if model_name == "base_plus":
+    model_cfg = "configs/samurai/sam2.1_hiera_b+.yaml"
+else:
+    model_cfg = f"configs/samurai/sam2.1_hiera_{model_name[0]}.yaml"
+
+video_folder = "data/LaSOT"
+pred_folder = f"results/{exp_name}/{exp_name}_{model_name}"
+
+save_to_video = True
+if save_to_video:
+    vis_folder = f"visualization/{exp_name}/{model_name}"
+    os.makedirs(vis_folder, exist_ok=True)
+
+test_videos = sorted(test_videos)
+for vid, video in enumerate(test_videos):
+
+    cat_name = video.split('-')[0]
+    video_basename = video.strip()
+    frame_folder = osp.join(video_folder, cat_name, video_basename, "img")
+
+    num_frames = len(os.listdir(frame_folder))
+
+    print(f"\033[91mRunning video [{vid+1}/{len(test_videos)}]: {video_basename} with {num_frames} frames\033[0m")
+
+    height, width = cv2.imread(osp.join(frame_folder, "00000001.jpg")).shape[:2]
+
+    predictor = build_sam2_video_predictor(model_cfg, checkpoint, device="cuda:0")
+
+    predictions = []
+
+    if save_to_video:
+        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+        out = cv2.VideoWriter(osp.join(vis_folder, f'{video_basename}.mp4'), fourcc, 30, (width, height))
+
+    # Start processing frames
+    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
+        state = predictor.init_state(frame_folder, offload_video_to_cpu=True, offload_state_to_cpu=True, async_loading_frames=True)
+
+        prompts = load_lasot_gt(osp.join(video_folder, cat_name, video_basename, "groundtruth.txt"))
+
+        # Prompt the tracker with the ground-truth box of the first frame only.
+        bbox, track_label = prompts[0]
+        frame_idx, object_ids, masks = predictor.add_new_points_or_box(state, box=bbox, frame_idx=0, obj_id=0)
+
+        for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
+            mask_to_vis = {}
+            bbox_to_vis = {}
+
+            assert len(masks) == 1 and len(object_ids) == 1, "Only one object is supported right now"
+            for obj_id, mask in zip(object_ids, masks):
+                mask = mask[0].cpu().numpy()
+                mask = mask > 0.0
+                # Derive a tight (x, y, w, h) box from the predicted mask.
+                non_zero_indices = np.argwhere(mask)
+                if len(non_zero_indices) == 0:
+                    bbox = [0, 0, 0, 0]
+                else:
+                    y_min, x_min = non_zero_indices.min(axis=0).tolist()
+                    y_max, x_max = non_zero_indices.max(axis=0).tolist()
+                    bbox = [x_min, y_min, x_max - x_min, y_max - y_min]
+                bbox_to_vis[obj_id] = bbox
+                mask_to_vis[obj_id] = mask
+
+            if save_to_video:
+                img = cv2.imread(f'{frame_folder}/{frame_idx+1:08d}.jpg')
+                if img is None:
+                    break
+
+                # Overlay the predicted mask.
+                for obj_id in mask_to_vis.keys():
+                    mask_img = np.zeros((height, width, 3), np.uint8)
+                    mask_img[mask_to_vis[obj_id]] = color[obj_id % len(color)]
+                    img = cv2.addWeighted(img, 1, mask_img, 0.75, 0)
+
+                # Draw the predicted box.
+                for obj_id in bbox_to_vis.keys():
+                    x, y, w, h = bbox_to_vis[obj_id]
+                    cv2.rectangle(img, (x, y), (x + w, y + h), color[obj_id % len(color)], 2)
+
+                # Draw the ground-truth box of the current frame in green.
+                x1, y1, x2, y2 = prompts[frame_idx][0]
+                cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
+                out.write(img)
+
+            predictions.append(bbox_to_vis)
+
+    # Save predictions as one "x,y,w,h" line per frame.
+    os.makedirs(pred_folder, exist_ok=True)
+    with open(osp.join(pred_folder, f'{video_basename}.txt'), 'w') as f:
+        for pred in predictions:
+            x, y, w, h = pred[0]
+            f.write(f"{x},{y},{w},{h}\n")
+
+    if save_to_video:
+        out.release()
+
+    # Free GPU and host memory before the next sequence.
+    del predictor
+    del state
+    gc.collect()
+    torch.clear_autocast_cache()
+    torch.cuda.empty_cache()
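Each file the script writes to `results/` holds one `x,y,w,h` prediction per line, in the same format as LaSOT's `groundtruth.txt`. As a rough consistency check on a single sequence, a plain mean-IoU comparison can be sketched as follows (this is not the official LaSOT evaluation protocol; `iou_xywh` and `mean_iou` are hypothetical helpers, and the example paths assume the defaults above):

```python
import numpy as np

def iou_xywh(a, b):
    # Intersection-over-union of two (x, y, w, h) boxes.
    ax2, ay2 = a[0] + a[2], a[1] + a[3]
    bx2, by2 = b[0] + b[2], b[1] + b[3]
    iw = max(0, min(ax2, bx2) - max(a[0], b[0]))
    ih = max(0, min(ay2, by2) - max(a[1], b[1]))
    inter = iw * ih
    union = a[2] * a[3] + b[2] * b[3] - inter
    return inter / union if union > 0 else 0.0

def load_boxes(path):
    with open(path) as f:
        return [tuple(map(int, line.split(','))) for line in f if line.strip()]

def mean_iou(pred_path, gt_path):
    preds, gts = load_boxes(pred_path), load_boxes(gt_path)
    return float(np.mean([iou_xywh(p, g) for p, g in zip(preds, gts)]))

if __name__ == "__main__":
    print(mean_iou("results/samurai/samurai_base_plus/airplane-1.txt",
                   "data/LaSOT/airplane/airplane-1/groundtruth.txt"))
```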