{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "5fa21d44",
"metadata": {},
"outputs": [],
"source": [
"# Copyright (c) Meta Platforms, Inc. and affiliates.\n",
"# Lightly adapted from https://github.com/facebookresearch/segment-anything/blob/main/notebooks/automatic_mask_generator_example.ipynb"
]
},
{
"cell_type": "markdown",
"id": "b7c0041e",
"metadata": {},
"source": [
"# Automatically generating object masks with SAM 2"
]
},
{
"cell_type": "markdown",
"id": "289bb0b4",
"metadata": {},
"source": [
"Because SAM 2 processes prompts efficiently, it can generate masks for an entire image by sampling a large number of prompts over it.\n",
"\n",
"The `SAM2AutomaticMaskGenerator` class implements this capability. It samples single-point input prompts on a grid over the image; from each point, SAM 2 can predict multiple masks. The masks are then filtered for quality and deduplicated using non-maximum suppression (NMS). Additional options can further improve mask quality and quantity, such as running prediction on multiple crops of the image or post-processing masks to remove small disconnected regions and holes."
]
},
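{
"cell_type": "markdown",
"id": "sketch-grid-nms-md",
"metadata": {},
"source": [
"The two core steps described above, sampling point prompts on a grid and deduplicating masks with NMS, can be illustrated with a small NumPy sketch. This is **not** the `SAM2AutomaticMaskGenerator` implementation. `make_point_grid` and `box_nms` are hypothetical helpers. The sketch assumes that points are placed at the centres of grid cells in normalized `[0, 1]` coordinates, and that duplicates are suppressed greedily by bounding-box IoU, keeping the higher-scoring box. In the example, box 1 overlaps box 0 with IoU 81 / 100 = 0.81. This exceeds the 0.7 threshold, so only boxes 0 and 2 are kept."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-grid-nms-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def make_point_grid(n_per_side):\n",
"    # Hypothetical helper: n x n point prompts as normalized (x, y) coordinates in [0, 1]\n",
"    offset = 1.0 / (2 * n_per_side)  # centre of the first grid cell\n",
"    coords = np.linspace(offset, 1.0 - offset, n_per_side)\n",
"    xs, ys = np.meshgrid(coords, coords)\n",
"    return np.stack([xs.ravel(), ys.ravel()], axis=-1)  # shape (n_per_side**2, 2)\n",
"\n",
"def box_nms(boxes, scores, iou_thresh):\n",
"    # Hypothetical helper: greedy NMS over XYXY boxes, returns indices of the kept boxes\n",
"    order = np.argsort(scores)[::-1]  # highest score first\n",
"    keep = []\n",
"    while order.size > 0:\n",
"        i = order[0]\n",
"        keep.append(int(i))\n",
"        rest = order[1:]\n",
"        # intersection of box i with every remaining box\n",
"        x0 = np.maximum(boxes[i, 0], boxes[rest, 0])\n",
"        y0 = np.maximum(boxes[i, 1], boxes[rest, 1])\n",
"        x1 = np.minimum(boxes[i, 2], boxes[rest, 2])\n",
"        y1 = np.minimum(boxes[i, 3], boxes[rest, 3])\n",
"        inter = np.clip(x1 - x0, 0, None) * np.clip(y1 - y0, 0, None)\n",
"        area_i = (boxes[i, 2] - boxes[i, 0]) * (boxes[i, 3] - boxes[i, 1])\n",
"        area_rest = (boxes[rest, 2] - boxes[rest, 0]) * (boxes[rest, 3] - boxes[rest, 1])\n",
"        iou = inter / (area_i + area_rest - inter)\n",
"        order = rest[iou <= iou_thresh]  # discard boxes that overlap box i too much\n",
"    return keep\n",
"\n",
"print(make_point_grid(4).shape)  # (16, 2)\n",
"boxes = np.array([[0, 0, 10, 10], [1, 1, 10, 10], [20, 20, 30, 30]], dtype=float)\n",
"scores = np.array([0.9, 0.8, 0.7])\n",
"print(box_nms(boxes, scores, iou_thresh=0.7))  # [0, 2]"
]
},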
{
"cell_type": "markdown",
"id": "c0b71431",
"metadata": {},
"source": [
"## Environment setup"
]
},
{
"cell_type": "markdown",
"id": "47e5a78f",
"metadata": {},
"source": [
"If you are running this notebook locally with Jupyter, first install SAM 2 in your environment by following the installation instructions in the repository."
]
},
{
"cell_type": "markdown",
"id": "fd2bc687",
"metadata": {},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "560725a2",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch\n",
"import matplotlib.pyplot as plt\n",
"from PIL import Image\n",
"\n",
"# use bfloat16 autocast for the entire notebook (entered once and never exited)\n",
"torch.autocast(device_type=\"cuda\", dtype=torch.bfloat16).__enter__()\n",
"\n",
"if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:\n",
"    # turn on TensorFloat-32 (TF32) for Ampere and newer GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)\n",
"    torch.backends.cuda.matmul.allow_tf32 = True\n",
"    torch.backends.cudnn.allow_tf32 = True"
]
},
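{
"cell_type": "markdown",
"id": "sketch-autocast-md",
"metadata": {},
"source": [
"Calling `__enter__()` on `torch.autocast` above keeps bfloat16 autocasting active for the rest of the session. If you prefer to limit it to a particular block, you can use the standard context-manager form of `torch.autocast` instead. The sketch below is illustrative only, and the tensors are random placeholders. It picks `cpu` when no GPU is available, since bfloat16 autocast also covers matrix multiplication on CPU."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-autocast-code",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"device_type = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"a = torch.randn(4, 4, device=device_type)\n",
"b = torch.randn(4, 4, device=device_type)\n",
"\n",
"# this with-block enables bfloat16 autocast just for the code it contains\n",
"with torch.autocast(device_type=device_type, dtype=torch.bfloat16):\n",
"    c = a @ b  # matmul is autocast to bfloat16\n",
"\n",
"print(c.dtype)  # torch.bfloat16"
]
},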
{
"cell_type": "code",
"execution_count": 3,
"id": "74b6e5f0",
"metadata": {},
"outputs": [],
"source": [
"def show_anns(anns, borders=True):\n",
"    # Overlay each mask in a random semi-transparent color on the current axes\n",
"    if len(anns) == 0:\n",
"        return\n",
"    # draw the largest masks first so smaller ones stay visible on top\n",
"    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)\n",
"    ax = plt.gca()\n",
"    ax.set_autoscale_on(False)\n",
"\n",
"    # transparent RGBA canvas with the same height and width as the masks\n",
"    img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))\n",
"    img[:, :, 3] = 0\n",
"    for ann in sorted_anns:\n",
"        m = ann['segmentation']\n",
"        color_mask = np.concatenate([np.random.random(3), [0.5]])  # random RGB, alpha 0.5\n",
"        img[m] = color_mask\n",
"        if borders:\n",
"            import cv2\n",
"            contours, _ = cv2.findContours(m.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)\n",
"            # Try to smooth contours\n",
"            contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]\n",
"            cv2.drawContours(img, contours, -1, (0, 0, 1, 0.4), thickness=1)\n",
"\n",
"    ax.imshow(img)"
]
},
{
"cell_type": "markdown",
"id": "27c41445",
"metadata": {},
"source": [
"## Example image"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "ad354922",
"metadata": {},
"outputs": [],
"source": [
"image = Image.open('images/cars.jpg')\n",
"image = np.array(image.convert(\"RGB\"))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "e0ac8c67",
"metadata": {},
"outputs": [],
"source": [
"plt.figure(figsize=(20,20))\n",
"plt.imshow(image)\n",
"plt.axis('off')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "b8c2824a",
"metadata": {},
"source": [
"## Automatic mask generation"
]
},
{
"cell_type": "markdown",
"id": "d9ef74c5",
"metadata": {},
"source": [
"To run automatic mask generation, pass a SAM 2 model to the `SAM2AutomaticMaskGenerator` class. Set the paths below to the SAM 2 checkpoint and its matching model config."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "1848a108",
"metadata": {},
"outputs": [],
"source": [
"from sam2.build_sam import build_sam2\n",
"from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator\n",
"\n",
"sam2_checkpoint = \"../checkpoints/sam2_hiera_large.pt\"\n",
"model_cfg = \"sam2_hiera_l.yaml\"  # the config must match the checkpoint (here: Hiera-L)\n",
"\n",
"sam2 = build_sam2(model_cfg, sam2_checkpoint, device='cuda', apply_postprocessing=False)\n",
"\n",
"mask_generator = SAM2AutomaticMaskGenerator(sam2)"
]
},
{
"cell_type": "markdown",
"id": "d6b1ea21",
"metadata": {},
"source": [
"To generate masks, run `generate` on an image."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "391771c1",
"metadata": {},
"outputs": [],
"source": [
"masks = mask_generator.generate(image)"
]
},
{
"cell_type": "markdown",
"id": "e36a1a39",
"metadata": {},
"source": [
"Mask generation returns a list of masks, where each mask is a dictionary containing various data about the mask. These keys are:\n",
"* `segmentation` : the binary mask\n",
"* `area` : the area of the mask in pixels\n",
"* `bbox` : the bounding box of the mask in XYWH format\n",
"* `predicted_iou` : the model's own prediction of the mask's quality\n",
"* `point_coords` : the sampled input point that generated this mask\n",
"* `stability_score` : an additional measure of mask quality\n",
"* `crop_box` : the crop of the image used to generate this mask, in XYWH format"
]
},
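{
"cell_type": "markdown",
"id": "sketch-mask-records-md",
"metadata": {},
"source": [
"Because each record is a plain dictionary, you can filter, sort and convert the results with ordinary Python. The sketch below uses two synthetic records with made-up values (not model output) that only mimic the keys listed above. `xywh_to_xyxy` is a hypothetical helper, and the 0.8 `predicted_iou` cut-off is an arbitrary example value."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-mask-records-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def xywh_to_xyxy(bbox):\n",
"    # Hypothetical helper: convert [x, y, w, h] to [x0, y0, x1, y1]\n",
"    x, y, w, h = bbox\n",
"    return [x, y, x + w, y + h]\n",
"\n",
"# Two synthetic records with the same keys as the output of generate() (values are made up)\n",
"seg_a = np.zeros((8, 8), dtype=bool)\n",
"seg_a[1:4, 1:5] = True  # 3 rows x 4 columns\n",
"seg_b = np.zeros((8, 8), dtype=bool)\n",
"seg_b[5:7, 5:7] = True  # 2 rows x 2 columns\n",
"records = [\n",
"    {'segmentation': seg_a, 'area': int(seg_a.sum()), 'bbox': [1, 1, 4, 3],\n",
"     'predicted_iou': 0.95, 'point_coords': [[2.0, 2.0]], 'stability_score': 0.97,\n",
"     'crop_box': [0, 0, 8, 8]},\n",
"    {'segmentation': seg_b, 'area': int(seg_b.sum()), 'bbox': [5, 5, 2, 2],\n",
"     'predicted_iou': 0.72, 'point_coords': [[6.0, 6.0]], 'stability_score': 0.90,\n",
"     'crop_box': [0, 0, 8, 8]},\n",
"]\n",
"\n",
"# keep confident masks only, largest first\n",
"confident = [r for r in records if r['predicted_iou'] >= 0.8]\n",
"confident.sort(key=lambda r: r['area'], reverse=True)\n",
"for r in confident:\n",
"    print(r['area'], xywh_to_xyxy(r['bbox']))  # 12 [1, 1, 5, 4]"
]
},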
{
"cell_type": "code",
"execution_count": 8,
"id": "4fae8d66",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"60\n",
"dict_keys(['segmentation', 'area', 'bbox', 'predicted_iou', 'point_coords', 'stability_score', 'crop_box'])\n"
]
}
],
"source": [
"print(len(masks))\n",
"print(masks[0].keys())"
]
},
{
"cell_type": "markdown",
"id": "53009a1f",
"metadata": {},
"source": [
"Show all the masks overlaid on the image."
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "77ac29c5",
"metadata": {},
"outputs": [],
"source": [
"plt.figure(figsize=(20,20))\n",
"plt.imshow(image)\n",
"show_anns(masks)\n",
"plt.axis('off')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "00b3d6b2",
"metadata": {},
"source": [
"## Automatic mask generation options"
]
},
{
"cell_type": "markdown",
"id": "183de84e",
"metadata": {},
"source": [
"Automatic mask generation has several tunable parameters that control how densely points are sampled and which thresholds are used to remove low-quality or duplicate masks. Generation can also be run automatically on crops of the image to improve performance on smaller objects, and post-processing can remove stray pixels and holes. Here is an example configuration that samples more masks:"
]
},
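{
"cell_type": "code",
"execution_count": 10,
"id": "68364513",
"metadata": {},
"outputs": [],
"source": [
"mask_generator_2 = SAM2AutomaticMaskGenerator(\n",
"    model=sam2,\n",
"    points_per_side=64,  # 64 x 64 grid of point prompts on the full image\n",
"    points_per_batch=128,  # number of points the model processes at once\n",
"    pred_iou_thresh=0.7,  # drop masks whose predicted IoU is below this value\n",
"    stability_score_thresh=0.92,  # drop masks whose stability score is below this value\n",
"    stability_score_offset=0.7,  # logit offset used when computing the stability score\n",
"    crop_n_layers=1,  # also run on one extra layer of overlapping image crops\n",
"    box_nms_thresh=0.7,  # box IoU cutoff used by NMS to remove duplicate masks\n",
"    crop_n_points_downscale_factor=2,  # halve points_per_side in each successive crop layer\n",
"    min_mask_region_area=25.0,  # remove disconnected regions and holes smaller than 25 px\n",
"    use_m2m=True,  # refine masks with an extra step that reuses previous mask predictions\n",
")"
]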
},
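{
"cell_type": "markdown",
"id": "sketch-stability-md",
"metadata": {},
"source": [
"Two of the options above can be made concrete with small sketches. The first is the stability score. This sketch assumes the definition used in the original Segment Anything code: the IoU between the binary masks obtained by thresholding a mask's logits at `mask_threshold + offset` and at `mask_threshold - offset`. Here `offset` plays the role of `stability_score_offset`. `stability_score` is a hypothetical stand-alone helper, not a library function. Masks whose logits change sharply at the boundary score close to 1."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-stability-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def stability_score(logits, mask_threshold=0.0, offset=1.0):\n",
"    # Hypothetical helper: IoU between a strict and a lenient binarization of the same logits.\n",
"    # The strict mask is a subset of the lenient one, so IoU = |strict| / |lenient|.\n",
"    strict = logits > (mask_threshold + offset)\n",
"    lenient = logits > (mask_threshold - offset)\n",
"    union = lenient.sum()\n",
"    return float(strict.sum() / union) if union > 0 else 0.0\n",
"\n",
"sharp = np.array([[-5.0, -5.0, 5.0, 5.0]])  # confident boundary\n",
"soft = np.array([[-0.5, 0.5, 1.5, 5.0]])    # logits close to the cutoff\n",
"print(stability_score(sharp))  # 1.0\n",
"print(stability_score(soft))   # 0.5"
]
},
{
"cell_type": "markdown",
"id": "sketch-prompt-count-md",
"metadata": {},
"source": [
"The second sketch estimates the cost of a configuration. It assumes the crop layout of the original Segment Anything implementation:\n",
"* layer 0 is the full image;\n",
"* crop layer `i` uses `2**i` x `2**i` overlapping crops;\n",
"* each crop is sampled with `points_per_side / crop_n_points_downscale_factor**i` points per side;\n",
"* each crop is batched separately.\n",
"\n",
"Under these assumptions, the hypothetical helper `count_prompts` estimates the number of point prompts and model batches. It ignores the extra refinement work done when `use_m2m=True`. The configuration above then issues 8192 prompts in 64 batches. A 32 x 32 grid without crops and 64 points per batch issues 1024 prompts in 16 batches."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-prompt-count-code",
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"def count_prompts(points_per_side, crop_n_layers, downscale_factor, points_per_batch):\n",
"    # Hypothetical helper: total point prompts and model batches over all crop layers\n",
"    total_points, total_batches = 0, 0\n",
"    for layer in range(crop_n_layers + 1):\n",
"        n_crops = (2 ** layer) ** 2  # layer 0 is the full image; layer 1 has 2 x 2 crops\n",
"        n_side = int(points_per_side / downscale_factor ** layer)\n",
"        points_per_crop = n_side ** 2\n",
"        total_points += n_crops * points_per_crop\n",
"        # each crop is assumed to be batched separately\n",
"        total_batches += n_crops * math.ceil(points_per_crop / points_per_batch)\n",
"    return total_points, total_batches\n",
"\n",
"print(count_prompts(64, 1, 2, 128))  # (8192, 64)\n",
"print(count_prompts(32, 0, 1, 64))   # (1024, 16)"
]
},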
{
"cell_type": "code",
"execution_count": 11,
"id": "bebcdaf1",
"metadata": {},
"outputs": [],
"source": [
"masks2 = mask_generator_2.generate(image)\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "fb702ae3",
"metadata": {},
"outputs": [],
"source": [
"plt.figure(figsize=(20,20))\n",
"plt.imshow(image)\n",
"show_anns(masks2)\n",
"plt.axis('off')\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 5
}