diff --git a/groundingdino/models/GroundingDINO/groundingdino.py b/groundingdino/models/GroundingDINO/groundingdino.py index a5758fd..cd97028 100644 --- a/groundingdino/models/GroundingDINO/groundingdino.py +++ b/groundingdino/models/GroundingDINO/groundingdino.py @@ -206,6 +206,21 @@ class GroundingDINO(nn.Module): nn.init.xavier_uniform_(proj[0].weight, gain=1) nn.init.constant_(proj[0].bias, 0) + def set_image_tensor(self, samples: NestedTensor): + if isinstance(samples, (list, torch.Tensor)): + samples = nested_tensor_from_tensor_list(samples) + self.features, self.poss = self.backbone(samples) + + def unset_image_tensor(self): + if hasattr(self, 'features'): + del self.features + if hasattr(self,'poss'): + del self.poss + + def set_image_features(self, features , poss): + self.features = features + self.poss = poss + def init_ref_points(self, use_num_queries): self.refpoint_embed = nn.Embedding(use_num_queries, self.query_dim) @@ -282,14 +297,14 @@ class GroundingDINO(nn.Module): } # import ipdb; ipdb.set_trace() - if isinstance(samples, (list, torch.Tensor)): samples = nested_tensor_from_tensor_list(samples) - features, poss = self.backbone(samples) + if not hasattr(self, 'features') or not hasattr(self, 'poss'): + self.set_image_tensor(samples) srcs = [] masks = [] - for l, feat in enumerate(features): + for l, feat in enumerate(self.features): src, mask = feat.decompose() srcs.append(self.input_proj[l](src)) masks.append(mask) @@ -298,7 +313,7 @@ class GroundingDINO(nn.Module): _len_srcs = len(srcs) for l in range(_len_srcs, self.num_feature_levels): if l == _len_srcs: - src = self.input_proj[l](features[-1].tensors) + src = self.input_proj[l](self.features[-1].tensors) else: src = self.input_proj[l](srcs[-1]) m = samples.mask @@ -306,11 +321,11 @@ class GroundingDINO(nn.Module): pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype) srcs.append(src) masks.append(mask) - poss.append(pos_l) + self.poss.append(pos_l) input_query_bbox = input_query_label = attn_mask = dn_meta = None hs, reference, hs_enc, ref_enc, init_box_proposal = self.transformer( - srcs, masks, input_query_bbox, poss, input_query_label, attn_mask, text_dict + srcs, masks, input_query_bbox, self.poss, input_query_label, attn_mask, text_dict ) # deformable-detr-like anchor update @@ -344,7 +359,9 @@ class GroundingDINO(nn.Module): # interm_class = self.transformer.enc_out_class_embed(hs_enc[-1], text_dict) # out['interm_outputs'] = {'pred_logits': interm_class, 'pred_boxes': interm_coord} # out['interm_outputs_for_matching_pre'] = {'pred_logits': interm_class, 'pred_boxes': init_box_proposal} - + unset_image_tensor = kw.get('unset_image_tensor', True) + if unset_image_tensor: + self.unset_image_tensor() ## If necessary return out @torch.jit.unused @@ -392,3 +409,4 @@ def build_groundingdino(args): ) return model +