[New Feature] Support SAM 2.1 (#59)

* support sam 2.1

* refine config path and ckpt path

* update README
This commit is contained in:
Ren Tianhe
2024-10-10 14:55:50 +08:00
committed by GitHub
parent e899ad99e8
commit 82e503604f
340 changed files with 39100 additions and 608 deletions

View File

@@ -4,12 +4,14 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
from functools import partial
from typing import List, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from iopath.common.file_io import g_pathmgr
from sam2.modeling.backbones.utils import (
PatchEmbed,
@@ -193,6 +195,7 @@ class Hiera(nn.Module):
16,
20,
),
weights_path=None,
return_interm_layers=True, # return feats from every stage
):
super().__init__()
@@ -262,6 +265,11 @@ class Hiera(nn.Module):
else [self.blocks[-1].dim_out]
)
if weights_path is not None:
with g_pathmgr.open(weights_path, "rb") as f:
chkpt = torch.load(f, map_location="cpu")
logging.info("loading Hiera", self.load_state_dict(chkpt, strict=False))
def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
h, w = hw
window_embed = self.pos_embed_window
@@ -289,3 +297,21 @@ class Hiera(nn.Module):
outputs.append(feats)
return outputs
def get_layer_id(self, layer_name):
# https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
num_layers = self.get_num_layers()
if layer_name.find("rel_pos") != -1:
return num_layers + 1
elif layer_name.find("pos_embed") != -1:
return 0
elif layer_name.find("patch_embed") != -1:
return 0
elif layer_name.find("blocks") != -1:
return int(layer_name.split("blocks")[1].split(".")[1]) + 1
else:
return num_layers + 1
def get_num_layers(self) -> int:
return len(self.blocks)

View File

@@ -71,6 +71,7 @@ class FpnNeck(nn.Module):
self.position_encoding = position_encoding
self.convs = nn.ModuleList()
self.backbone_channel_list = backbone_channel_list
self.d_model = d_model
for dim in backbone_channel_list:
current = nn.Sequential()
current.add_module(