[New Feature] Support SAM 2.1 (#59)
* support sam 2.1
* refine config path and ckpt path
* update README
@@ -4,12 +4,14 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import logging
 from functools import partial
 from typing import List, Tuple, Union
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from iopath.common.file_io import g_pathmgr
 
 from sam2.modeling.backbones.utils import (
     PatchEmbed,
@@ -193,6 +195,7 @@ class Hiera(nn.Module):
             16,
             20,
         ),
+        weights_path=None,
         return_interm_layers=True,  # return feats from every stage
     ):
         super().__init__()
@@ -262,6 +265,11 @@ class Hiera(nn.Module):
             else [self.blocks[-1].dim_out]
         )
 
+        if weights_path is not None:
+            with g_pathmgr.open(weights_path, "rb") as f:
+                chkpt = torch.load(f, map_location="cpu")
+            logging.info("loading Hiera", self.load_state_dict(chkpt, strict=False))
+
     def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
         h, w = hw
         window_embed = self.pos_embed_window
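The hunk above adds an optional `weights_path` load path to `Hiera.__init__`: the checkpoint is opened through `g_pathmgr`, loaded onto CPU, and applied with `strict=False` so trunk-only checkpoints with missing or extra keys still load. A minimal usage sketch, not part of the diff; the import path and the checkpoint filename are assumptions:

# Sketch only: construct the Hiera trunk and let it pull trunk weights itself.
# The import path and checkpoint path below are assumptions, not from the diff.
from sam2.modeling.backbones.hieradet import Hiera

trunk = Hiera(
    embed_dim=96,  # Hiera-T/S width; other model sizes change this and num_heads
    num_heads=1,
    weights_path="checkpoints/hiera_trunk.pt",  # hypothetical checkpoint path
)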
@@ -289,3 +297,21 @@ class Hiera(nn.Module):
                 outputs.append(feats)
 
         return outputs
+
+    def get_layer_id(self, layer_name):
+        # https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
+        num_layers = self.get_num_layers()
+
+        if layer_name.find("rel_pos") != -1:
+            return num_layers + 1
+        elif layer_name.find("pos_embed") != -1:
+            return 0
+        elif layer_name.find("patch_embed") != -1:
+            return 0
+        elif layer_name.find("blocks") != -1:
+            return int(layer_name.split("blocks")[1].split(".")[1]) + 1
+        else:
+            return num_layers + 1
+
+    def get_num_layers(self) -> int:
+        return len(self.blocks)
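`get_layer_id` and `get_num_layers` follow BEiT's optim_factory (linked in the comment) and exist so fine-tuning code can assign layer-wise learning-rate decay: position/patch embeddings get id 0, block `i` gets id `i + 1`, and everything else gets `num_layers + 1`. A sketch of one way to consume them, reusing the `trunk` from the sketch above; the helper name and hyperparameters are assumptions, not part of the diff:

# Sketch only: layer-wise LR decay driven by get_layer_id()/get_num_layers().
import torch

def build_llrd_param_groups(model, base_lr=1e-4, decay=0.9):
    # hypothetical helper: one optimizer param group per layer id
    num_layers = model.get_num_layers()
    groups = {}
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        layer_id = model.get_layer_id(name)  # 0 .. num_layers + 1
        scale = decay ** (num_layers + 1 - layer_id)  # shallower -> smaller lr
        group = groups.setdefault(layer_id, {"params": [], "lr": base_lr * scale})
        group["params"].append(param)
    return list(groups.values())

optimizer = torch.optim.AdamW(build_llrd_param_groups(trunk), weight_decay=0.05)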
@@ -71,6 +71,7 @@ class FpnNeck(nn.Module):
         self.position_encoding = position_encoding
         self.convs = nn.ModuleList()
         self.backbone_channel_list = backbone_channel_list
+        self.d_model = d_model
         for dim in backbone_channel_list:
             current = nn.Sequential()
             current.add_module(
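The `FpnNeck` hunk stores `d_model` as an attribute next to `backbone_channel_list`. A construction sketch, assuming the class sits in `sam2.modeling.backbones.image_encoder` and pairing it with the 256-dim sine position encoding used in the SAM 2 configs; the channel list shown matches a Hiera-T/S trunk:

# Sketch only: build the FPN neck; import paths and values are assumptions.
from sam2.modeling.backbones.image_encoder import FpnNeck
from sam2.modeling.position_encoding import PositionEmbeddingSine

neck = FpnNeck(
    position_encoding=PositionEmbeddingSine(num_pos_feats=256),
    d_model=256,
    backbone_channel_list=[768, 384, 192, 96],  # stage widths, deepest stage first
)
assert neck.d_model == 256  # the attribute this hunk adds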