# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from PIL.Image import Image

from sam2.modeling.sam2_base import SAM2Base
from sam2.utils.transforms import SAM2Transforms


class SAM2ImagePredictor:
    def __init__(
        self,
        sam_model: SAM2Base,
        mask_threshold=0.0,
        max_hole_area=0.0,
        max_sprinkle_area=0.0,
    ) -> None:
"""
Uses SAM - 2 to calculate the image embedding for an image , and then
allow repeated , efficient mask prediction given prompts .
Arguments :
sam_model ( Sam - 2 ) : The model to use for mask prediction .
mask_threshold ( float ) : The threshold to use when converting mask logits
to binary masks . Masks are thresholded at 0 by default .
fill_hole_area ( int ) : If fill_hole_area > 0 , we fill small holes in up to
the maximum area of fill_hole_area in low_res_masks .
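
        Example:
            A minimal construction sketch; the config and checkpoint names
            below are illustrative placeholders, not values this class
            requires::

                >>> from sam2.build_sam import build_sam2
                >>> model = build_sam2("sam2_hiera_l.yaml", "sam2_hiera_large.pt")
                >>> predictor = SAM2ImagePredictor(model)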
"""
        super().__init__()
        self.model = sam_model
        self._transforms = SAM2Transforms(
            resolution=self.model.image_size,
            mask_threshold=mask_threshold,
            max_hole_area=max_hole_area,
            max_sprinkle_area=max_sprinkle_area,
        )

        # Predictor state
        self._is_image_set = False
        self._features = None
        self._orig_hw = None
        # Whether the predictor is set for single image or a batch of images
        self._is_batch = False

        # Predictor config
        self.mask_threshold = mask_threshold

        # Spatial dim for backbone feature maps
        self._bb_feat_sizes = [
            (256, 256),
            (128, 128),
            (64, 64),
        ]

    @classmethod
    def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor":
"""
        Load a pretrained model from the Hugging Face hub.

        Arguments:
          model_id (str): The Hugging Face repository ID.
          **kwargs: Additional arguments to pass to the model constructor.

        Returns:
          (SAM2ImagePredictor): The loaded model.
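
        Example:
            A short usage sketch; "facebook/sam2-hiera-large" is one of the
            SAM 2 checkpoints published on the Hugging Face hub and is assumed
            to be reachable at run time::

                >>> predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")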
"""
        from sam2.build_sam import build_sam2_hf
        sam_model = build_sam2_hf(model_id, **kwargs)
        return cls(sam_model)

    @torch.no_grad()
    def set_image(
        self,
        image: Union[np.ndarray, Image],
    ) -> None:
"""
Calculates the image embeddings for the provided image , allowing
masks to be predicted with the ' predict ' method .
Arguments :
image ( np . ndarray or PIL Image ) : The input image to embed in RGB format . The image should be in HWC format if np . ndarray , or WHC format if PIL Image
with pixel values in [ 0 , 255 ] .
image_format ( str ) : The color format of the image , in [ ' RGB ' , ' BGR ' ] .
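
        Example:
            A minimal sketch; "truck.jpg" is an illustrative path::

                >>> from PIL import Image
                >>> image = Image.open("truck.jpg").convert("RGB")
                >>> predictor.set_image(image)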
"""
        self.reset_predictor()
        # Transform the image to the form expected by the model
        if isinstance(image, np.ndarray):
            logging.info("For numpy array image, we assume (HxWxC) format")
            self._orig_hw = [image.shape[:2]]
        elif isinstance(image, Image):
            w, h = image.size
            self._orig_hw = [(h, w)]
        else:
            raise NotImplementedError("Image format not supported")

        input_image = self._transforms(image)
        input_image = input_image[None, ...].to(self.device)

        assert (
            len(input_image.shape) == 4 and input_image.shape[1] == 3
        ), f"input_image must be of size 1x3xHxW, got {input_image.shape}"
        logging.info("Computing image embeddings for the provided image...")
        backbone_out = self.model.forward_image(input_image)
        _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
        # Add no_mem_embed, which is added to the lowest-resolution feature map during training on videos
        if self.model.directly_add_no_mem_embed:
            vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed

        feats = [
            feat.permute(1, 2, 0).view(1, -1, *feat_size)
            for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
        ][::-1]
        self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
        self._is_image_set = True
        logging.info("Image embeddings computed.")

    @torch.no_grad()
    def set_image_batch(
        self,
        image_list: List[np.ndarray],
    ) -> None:
"""
Calculates the image embeddings for the provided image batch , allowing
masks to be predicted with the ' predict_batch ' method .
Arguments :
image_list ( List [ np . ndarray ] ) : The input images to embed in RGB format . The image should be in HWC format if np . ndarray
with pixel values in [ 0 , 255 ] .
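
        Example:
            A minimal sketch; the file paths are illustrative::

                >>> import numpy as np
                >>> from PIL import Image
                >>> images = [
                ...     np.array(Image.open(p).convert("RGB"))
                ...     for p in ("img0.jpg", "img1.jpg")
                ... ]
                >>> predictor.set_image_batch(images)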
"""
        self.reset_predictor()
        assert isinstance(image_list, list)
        self._orig_hw = []
        for image in image_list:
            assert isinstance(
                image, np.ndarray
            ), "Images are expected to be an np.ndarray in RGB format, and of shape HWC"
            self._orig_hw.append(image.shape[:2])
        # Transform the images to the form expected by the model
        img_batch = self._transforms.forward_batch(image_list)
        img_batch = img_batch.to(self.device)
        batch_size = img_batch.shape[0]
        assert (
            len(img_batch.shape) == 4 and img_batch.shape[1] == 3
        ), f"img_batch must be of size Bx3xHxW, got {img_batch.shape}"
        logging.info("Computing image embeddings for the provided images...")
        backbone_out = self.model.forward_image(img_batch)
        _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
        # Add no_mem_embed, which is added to the lowest-resolution feature map during training on videos
        if self.model.directly_add_no_mem_embed:
            vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed

        feats = [
            feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
            for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
        ][::-1]
        self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
        self._is_image_set = True
        self._is_batch = True
        logging.info("Image embeddings computed.")

    def predict_batch(
        self,
        point_coords_batch: List[np.ndarray] = None,
        point_labels_batch: List[np.ndarray] = None,
        box_batch: List[np.ndarray] = None,
        mask_input_batch: List[np.ndarray] = None,
        multimask_output: bool = True,
        return_logits: bool = False,
        normalize_coords=True,
    ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
""" This function is very similar to predict(...), however it is used for batched mode, when the model is expected to generate predictions on multiple images.
It returns a tupele of lists of masks , ious , and low_res_masks_logits .
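
        Example:
            A minimal sketch, assuming `set_image_batch` was called with two
            images and each gets a single foreground click (the coordinates
            are illustrative)::

                >>> import numpy as np
                >>> pts = [np.array([[500, 375]]), np.array([[60, 80]])]
                >>> lbls = [np.array([1]), np.array([1])]
                >>> masks_list, ious_list, lowres_list = predictor.predict_batch(
                ...     point_coords_batch=pts,
                ...     point_labels_batch=lbls,
                ...     multimask_output=True,
                ... )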
"""
        assert self._is_batch, "This function should only be used when in batched mode"
        if not self._is_image_set:
            raise RuntimeError(
                "An image must be set with .set_image_batch(...) before mask prediction."
            )
        num_images = len(self._features["image_embed"])
        all_masks = []
        all_ious = []
        all_low_res_masks = []
        for img_idx in range(num_images):
            # Transform input prompts
            point_coords = (
                point_coords_batch[img_idx] if point_coords_batch is not None else None
            )
            point_labels = (
                point_labels_batch[img_idx] if point_labels_batch is not None else None
            )
            box = box_batch[img_idx] if box_batch is not None else None
            mask_input = (
                mask_input_batch[img_idx] if mask_input_batch is not None else None
            )
            mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
                point_coords,
                point_labels,
                box,
                mask_input,
                normalize_coords,
                img_idx=img_idx,
            )
            masks, iou_predictions, low_res_masks = self._predict(
                unnorm_coords,
                labels,
                unnorm_box,
                mask_input,
                multimask_output,
                return_logits=return_logits,
                img_idx=img_idx,
            )
            masks_np = masks.squeeze(0).float().detach().cpu().numpy()
            iou_predictions_np = (
                iou_predictions.squeeze(0).float().detach().cpu().numpy()
            )
            low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy()
            all_masks.append(masks_np)
            all_ious.append(iou_predictions_np)
            all_low_res_masks.append(low_res_masks_np)

        return all_masks, all_ious, all_low_res_masks

    def predict(
        self,
        point_coords: Optional[np.ndarray] = None,
        point_labels: Optional[np.ndarray] = None,
        box: Optional[np.ndarray] = None,
        mask_input: Optional[np.ndarray] = None,
        multimask_output: bool = True,
        return_logits: bool = False,
        normalize_coords=True,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Predict masks for the given input prompts , using the currently set image .
Arguments :
point_coords ( np . ndarray or None ) : A Nx2 array of point prompts to the
model . Each point is in ( X , Y ) in pixels .
point_labels ( np . ndarray or None ) : A length N array of labels for the
point prompts . 1 indicates a foreground point and 0 indicates a
background point .
box ( np . ndarray or None ) : A length 4 array given a box prompt to the
model , in XYXY format .
mask_input ( np . ndarray ) : A low resolution mask input to the model , typically
coming from a previous prediction iteration . Has form 1 xHxW , where
for SAM , H = W = 256.
multimask_output ( bool ) : If true , the model will return three masks .
For ambiguous input prompts ( such as a single click ) , this will often
produce better masks than a single prediction . If only a single
mask is needed , the model ' s predicted quality score can be used
to select the best mask . For non - ambiguous prompts , such as multiple
input prompts , multimask_output = False can give better results .
return_logits ( bool ) : If true , returns un - thresholded masks logits
instead of a binary mask .
normalize_coords ( bool ) : If true , the point coordinates will be normalized to the range [ 0 , 1 ] and point_coords is expected to be wrt . image dimensions .
Returns :
( np . ndarray ) : The output masks in CxHxW format , where C is the
number of masks , and ( H , W ) is the original image size .
( np . ndarray ) : An array of length C containing the model ' s
predictions for the quality of each mask .
( np . ndarray ) : An array of shape CxHxW , where C is the number
of masks and H = W = 256. These low resolution logits can be passed to
a subsequent iteration as mask input .
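
        Example:
            A minimal sketch of a single-click prediction, assuming an image
            was set with `set_image` (the click coordinates are illustrative)::

                >>> import numpy as np
                >>> masks, scores, low_res_logits = predictor.predict(
                ...     point_coords=np.array([[500, 375]]),  # one (X, Y) click
                ...     point_labels=np.array([1]),  # 1 = foreground
                ...     multimask_output=True,
                ... )
                >>> best_mask = masks[np.argmax(scores)]  # highest-scoring mask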
"""
        if not self._is_image_set:
            raise RuntimeError(
                "An image must be set with .set_image(...) before mask prediction."
            )

        # Transform input prompts
        mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
            point_coords, point_labels, box, mask_input, normalize_coords
        )

        masks, iou_predictions, low_res_masks = self._predict(
            unnorm_coords,
            labels,
            unnorm_box,
            mask_input,
            multimask_output,
            return_logits=return_logits,
        )

        masks_np = masks.squeeze(0).float().detach().cpu().numpy()
        iou_predictions_np = iou_predictions.squeeze(0).float().detach().cpu().numpy()
        low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy()
        return masks_np, iou_predictions_np, low_res_masks_np

    def _prep_prompts(
        self, point_coords, point_labels, box, mask_logits, normalize_coords, img_idx=-1
    ):
        unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None
        if point_coords is not None:
            assert (
                point_labels is not None
            ), "point_labels must be supplied if point_coords is supplied."
            point_coords = torch.as_tensor(
                point_coords, dtype=torch.float, device=self.device
            )
            unnorm_coords = self._transforms.transform_coords(
                point_coords, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
            )
            labels = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
            if len(unnorm_coords.shape) == 2:
                unnorm_coords, labels = unnorm_coords[None, ...], labels[None, ...]
        if box is not None:
            box = torch.as_tensor(box, dtype=torch.float, device=self.device)
            unnorm_box = self._transforms.transform_boxes(
                box, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
            )  # Bx2x2
        if mask_logits is not None:
            mask_input = torch.as_tensor(
                mask_logits, dtype=torch.float, device=self.device
            )
            if len(mask_input.shape) == 3:
                mask_input = mask_input[None, :, :, :]
        return mask_input, unnorm_coords, labels, unnorm_box

    @torch.no_grad()
    def _predict(
        self,
        point_coords: Optional[torch.Tensor],
        point_labels: Optional[torch.Tensor],
        boxes: Optional[torch.Tensor] = None,
        mask_input: Optional[torch.Tensor] = None,
        multimask_output: bool = True,
        return_logits: bool = False,
        img_idx: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Predict masks for the given input prompts , using the currently set image .
Input prompts are batched torch tensors and are expected to already be
transformed to the input frame using SAM2Transforms .
Arguments :
point_coords ( torch . Tensor or None ) : A BxNx2 array of point prompts to the
model . Each point is in ( X , Y ) in pixels .
point_labels ( torch . Tensor or None ) : A BxN array of labels for the
point prompts . 1 indicates a foreground point and 0 indicates a
background point .
boxes ( np . ndarray or None ) : A Bx4 array given a box prompt to the
model , in XYXY format .
mask_input ( np . ndarray ) : A low resolution mask input to the model , typically
coming from a previous prediction iteration . Has form Bx1xHxW , where
for SAM , H = W = 256. Masks returned by a previous iteration of the
predict method do not need further transformation .
multimask_output ( bool ) : If true , the model will return three masks .
For ambiguous input prompts ( such as a single click ) , this will often
produce better masks than a single prediction . If only a single
mask is needed , the model ' s predicted quality score can be used
to select the best mask . For non - ambiguous prompts , such as multiple
input prompts , multimask_output = False can give better results .
return_logits ( bool ) : If true , returns un - thresholded masks logits
instead of a binary mask .
Returns :
( torch . Tensor ) : The output masks in BxCxHxW format , where C is the
number of masks , and ( H , W ) is the original image size .
( torch . Tensor ) : An array of shape BxC containing the model ' s
predictions for the quality of each mask .
( torch . Tensor ) : An array of shape BxCxHxW , where C is the number
of masks and H = W = 256. These low res logits can be passed to
a subsequent iteration as mask input .
"""
        if not self._is_image_set:
            raise RuntimeError(
                "An image must be set with .set_image(...) before mask prediction."
            )

        if point_coords is not None:
            concat_points = (point_coords, point_labels)
        else:
            concat_points = None

        # Embed prompts
        if boxes is not None:
            box_coords = boxes.reshape(-1, 2, 2)
            box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device)
            box_labels = box_labels.repeat(boxes.size(0), 1)
            # we merge "boxes" and "points" into a single "concat_points" input (where
            # boxes are added at the beginning) to sam_prompt_encoder
            if concat_points is not None:
                concat_coords = torch.cat([box_coords, concat_points[0]], dim=1)
                concat_labels = torch.cat([box_labels, concat_points[1]], dim=1)
                concat_points = (concat_coords, concat_labels)
            else:
                concat_points = (box_coords, box_labels)

        sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
            points=concat_points,
            boxes=None,
            masks=mask_input,
        )

        # Predict masks
        batched_mode = (
            concat_points is not None and concat_points[0].shape[0] > 1
        )  # multi object prediction
        high_res_features = [
            feat_level[img_idx].unsqueeze(0)
            for feat_level in self._features["high_res_feats"]
        ]
        low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder(
            image_embeddings=self._features["image_embed"][img_idx].unsqueeze(0),
            image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
            repeat_image=batched_mode,
            high_res_features=high_res_features,
        )

        # Upscale the masks to the original image resolution
        masks = self._transforms.postprocess_masks(
            low_res_masks, self._orig_hw[img_idx]
        )
        low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0)
        if not return_logits:
            masks = masks > self.mask_threshold

        return masks, iou_predictions, low_res_masks

    def get_image_embedding(self) -> torch.Tensor:
        """
        Returns the image embeddings for the currently set image, with
        shape 1xCxHxW, where C is the embedding dimension and (H, W) are
        the embedding spatial dimension of SAM (typically C=256, H=W=64).
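
        Example:
            A short sketch, assuming an image was already set::

                >>> emb = predictor.get_image_embedding()
                >>> emb.shape  # typically torch.Size([1, 256, 64, 64])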
"""
if not self . _is_image_set :
raise RuntimeError (
" An image must be set with .set_image(...) to generate an embedding. "
)
assert (
self . _features is not None
) , " Features must exist if an image has been set. "
return self . _features [ " image_embed " ]

    @property
    def device(self) -> torch.device:
        return self.model.device

    def reset_predictor(self) -> None:
        """
        Resets the image embeddings and other state variables.
        """
        self._is_image_set = False
        self._features = None
        self._orig_hw = None
        self._is_batch = False