better support for non-CUDA devices (CPU, MPS) (#192)

This commit is contained in:
Ronghang Hu
2024-08-12 10:46:50 -07:00
committed by GitHub
parent 778e112740
commit 1034ee2a1a
8 changed files with 213 additions and 377 deletions

View File

@@ -284,7 +284,9 @@ class SAM2AutomaticMaskGenerator:
orig_h, orig_w = orig_size
# Run model on this batch
points = torch.as_tensor(points, device=self.predictor.device)
points = torch.as_tensor(
points, dtype=torch.float32, device=self.predictor.device
)
in_points = self.predictor._transforms.transform_coords(
points, normalize=normalize, orig_hw=im_size
)