# File: Grounded-SAM-2/lib/train/dataset/coco.py
# Timestamp: 2024-11-19 22:12:54 -08:00
#
# Size: 157 lines, 5.6 KiB (Python)

import os
from .base_image_dataset import BaseImageDataset
import torch
import random
from collections import OrderedDict
from lib.train.data import jpeg4py_loader
from lib.train.admin import env_settings
from pycocotools.coco import COCO
class MSCOCO(BaseImageDataset):
    """ The COCO object detection dataset.

    Every entry of this dataset is a single (non-crowd) object annotation, not a whole image: the
    internal `image_list` stores COCO annotation ids, and all `im_id` arguments are indices into it.

    Publication:
        Microsoft COCO: Common Objects in Context.
        Tsung-Yi Lin, Michael Maire, Serge J. Belongie, Lubomir D. Bourdev, Ross B. Girshick, James Hays, Pietro Perona,
        Deva Ramanan, Piotr Dollar and C. Lawrence Zitnick
        ECCV, 2014
        https://arxiv.org/pdf/1405.0312.pdf

    Download the images along with annotations from http://cocodataset.org/#download. The root folder should be
    organized as follows.
        - coco_root
            - annotations
                - instances_train2014.json
                - instances_train2017.json
            - images
                - train2014
                - train2017

    Note: You also have to install the coco pythonAPI from https://github.com/cocodataset/cocoapi.
    """

    def __init__(self, root=None, image_loader=jpeg4py_loader, data_fraction=None, min_area=None,
                 split="train", version="2014"):
        """
        args:
            root - path to coco root folder. If None, env_settings().coco_dir is used.
            image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py)
                                            is used by default.
            data_fraction - Fraction of dataset to be used. A random subset of the annotations is sampled
                            (random.sample). The complete dataset is used by default (None).
            min_area - Only objects with area strictly greater than min_area are kept. No area filtering is
                       applied by default (None).
            split - 'train' or 'val'.
            version - version of coco dataset (2014 or 2017)
        """
        root = env_settings().coco_dir if root is None else root
        super().__init__('COCO', root, image_loader)

        self.img_pth = os.path.join(root, 'images/{}{}/'.format(split, version))
        self.anno_path = os.path.join(root, 'annotations/instances_{}{}.json'.format(split, version))

        self.coco_set = COCO(self.anno_path)
        self.cats = self.coco_set.cats

        self.class_list = self.get_class_list()  # the parent class thing would happen in the sampler

        # NOTE: despite the name, image_list holds annotation ids (one entry per object instance).
        self.image_list = self._get_image_list(min_area=min_area)

        if data_fraction is not None:
            self.image_list = random.sample(self.image_list, int(len(self.image_list) * data_fraction))

        # Must be built after sub-sampling: it stores indices into the final image_list.
        self.im_per_class = self._build_im_per_class()

    def _get_image_list(self, min_area=None):
        """ Return the ids of all non-crowd annotations, optionally keeping only those whose 'area' is
        strictly greater than min_area."""
        anns = self.coco_set.anns
        image_list = [a for a in anns if anns[a]['iscrowd'] == 0]

        if min_area is not None:
            image_list = [a for a in image_list if anns[a]['area'] > min_area]

        return image_list

    def get_num_classes(self):
        """ Return the number of categories in the loaded annotation file."""
        return len(self.class_list)

    def get_name(self):
        """ Return the dataset name, 'coco'."""
        return 'coco'

    def has_class_info(self):
        """ Class information is available for every entry."""
        return True

    def has_segmentation_info(self):
        """ Segmentation masks are available for every entry."""
        return True

    def get_class_list(self):
        """ Return the names of all categories, in the iteration order of the COCO category dict."""
        return [cat['name'] for cat in self.cats.values()]

    def _build_im_per_class(self):
        """ Build a dict mapping each class name to the list of image_list indices showing that class."""
        im_per_class = {}
        for i, im in enumerate(self.image_list):
            class_name = self.cats[self.coco_set.anns[im]['category_id']]['name']
            im_per_class.setdefault(class_name, []).append(i)

        return im_per_class

    def get_images_in_class(self, class_name):
        """ Return the list of image_list indices belonging to class_name.

        raises:
            KeyError - if no entry of the (possibly sub-sampled) dataset has that class.
        """
        return self.im_per_class[class_name]

    def get_image_info(self, im_id):
        """ Return the annotation of entry im_id as a dict of tensors.

        returns:
            dict with keys
                'bbox' - tensor with the 4 values of the annotation's 'bbox' field
                'mask' - tensor built from the mask returned by COCO.annToMask
                'valid' - boolean tensor, true if the third and fourth bbox values (presumably width and
                          height, following the COCO [x, y, w, h] convention) are both positive
                'visible' - copy of 'valid' converted with .byte()
        """
        anno = self._get_anno(im_id)

        bbox = torch.Tensor(anno['bbox']).view(4,)

        mask = torch.Tensor(self.coco_set.annToMask(anno))

        valid = (bbox[2] > 0) & (bbox[3] > 0)
        visible = valid.clone().byte()

        return {'bbox': bbox, 'mask': mask, 'valid': valid, 'visible': visible}

    def _get_anno(self, im_id):
        """ Return the raw COCO annotation dict for entry im_id (an index into image_list)."""
        return self.coco_set.anns[self.image_list[im_id]]

    def _get_image(self, im_id):
        """ Load, with self.image_loader, the image that the annotation of entry im_id belongs to."""
        image_info = self.coco_set.loadImgs([self._get_anno(im_id)['image_id']])[0]
        return self.image_loader(os.path.join(self.img_pth, image_info['file_name']))

    def get_meta_info(self, im_id):
        """ Return an OrderedDict with the class name and supercategory of entry im_id.

        This is best effort: if the entry, its category, or one of the category fields cannot be looked up,
        a dict with all values set to None is returned instead of raising.
        """
        try:
            cat_dict_current = self.cats[self._get_anno(im_id)['category_id']]
            object_meta = OrderedDict({'object_class_name': cat_dict_current['name'],
                                       'motion_class': None,
                                       'major_class': cat_dict_current['supercategory'],
                                       'root_class': None,
                                       'motion_adverb': None})
        except (LookupError, TypeError):
            # Only failed lookups fall back to empty meta info; a bare except here would also swallow
            # KeyboardInterrupt and SystemExit.
            object_meta = OrderedDict({'object_class_name': None,
                                       'motion_class': None,
                                       'major_class': None,
                                       'root_class': None,
                                       'motion_adverb': None})
        return object_meta

    def get_class_name(self, im_id):
        """ Return the category name of entry im_id."""
        cat_dict_current = self.cats[self._get_anno(im_id)['category_id']]
        return cat_dict_current['name']

    def get_image(self, image_id, anno=None):
        """ Return the image, annotation and meta info of entry image_id.

        args:
            image_id - index into image_list
            anno - annotation to return. If None, it is loaded with get_image_info.

        returns:
            tuple (frame, anno, object_meta)
        """
        frame = self._get_image(image_id)

        if anno is None:
            anno = self.get_image_info(image_id)

        object_meta = self.get_meta_info(image_id)

        return frame, anno, object_meta