fix mask shape bug

This commit is contained in:
rentainhe
2024-08-09 01:54:40 +08:00
parent 9c5786fc09
commit 80676e866b
4 changed files with 5 additions and 17 deletions

View File

@@ -81,11 +81,7 @@ masks, scores, logits = sam2_predictor.predict(
Post-process the output of the model to get the masks, scores, and logits for visualization
"""
# convert the shape to (n, H, W)
if masks.ndim == 3:
masks = masks[None]
scores = scores[None]
logits = logits[None]
elif masks.ndim == 4:
if masks.ndim == 4:
masks = masks.squeeze(1)